File tree: 1 file changed, +7 −0 lines changed
Original file line number | Diff line number | Diff line change
82 82  from .utils import _get_per_token_block_size
8383
84 84  logger = logging.getLogger(__name__)
   85+ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
8586
8687__all__ = [
8788 "swap_conv2d_1x1_to_linear" ,
@@ -939,6 +940,9 @@ def float8_dynamic_activation_float8_weight(
939940 mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
940941
941942 """
943+ assert (
944+ is_cuda_8_9
945+ ), "Float8 dynamic activation quantization is only supported on CUDA 8.9 and above"
942946 if mm_config is None :
943947 mm_config = Float8MMConfig (use_fast_accum = True )
944948
@@ -993,6 +997,9 @@ def float8_static_activation_float8_weight(
993  997      weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn
994  998      mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
995999 """
1000+ assert (
1001+ is_cuda_8_9
1002+ ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
9961003 if mm_config is None :
9971004 mm_config = Float8MMConfig (use_fast_accum = True )
9981005
You can’t perform that action at this time.
0 commit comments