1 file changed: +3 additions, −0 deletions

82  82  from .utils import _get_per_token_block_size
8383
84  84  logger = logging.getLogger(__name__)
    85+ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
8586
8687__all__ = [
87  88      "swap_conv2d_1x1_to_linear",
@@ -939,6 +940,7 @@ def float8_dynamic_activation_float8_weight(
939940 mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
940941
941942 """
    943+ assert is_cuda_8_9, "Float8 dynamic activation quantization is only supported on CUDA 8.9 and above"
942 944  if mm_config is None:
943 945      mm_config = Float8MMConfig(use_fast_accum=True)
944946
@@ -993,6 +995,7 @@ def float8_static_activation_float8_weight(
993 995      weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn
994996 mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
995997 """
     998+ assert is_cuda_8_9, "Float8 static activation quantization is only supported on CUDA 8.9 and above"
996  999  if mm_config is None:
997 1000      mm_config = Float8MMConfig(use_fast_accum=True)
9981001
You can’t perform that action at this time.
0 commit comments