@@ -1030,30 +1030,43 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())


-def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
+@dataclass
+class Float8WeightOnlyConfig(AOBaseConfig):
     """
-    Applies float8 weight-only symmetric per-channel quantization to linear layers.
+    Configuration for applying float8 weight-only symmetric per-channel quantization to linear layers.

     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.

     Note:
         The actual matmul will be computed in original precision of the weight tensor.
-
     """
-    from torchao.dtypes import to_affine_quantized_floatx

-    def apply_float8wo_quant(weight):
-        block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=None,
-            _layout=Float8Layout(mm_config=None),
-        )
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+
+
+# for BC
+float8_weight_only = Float8WeightOnlyConfig
+
+
+@register_quantize_module_handler(Float8WeightOnlyConfig)
+def _float8_weight_only_transform(
+    module: torch.nn.Module, config: Float8WeightOnlyConfig
+) -> torch.nn.Module:
+    from torchao.dtypes import to_affine_quantized_floatx

-    return _get_linear_subclass_inserter(apply_float8wo_quant)
+    weight = module.weight
+    block_size = (1, weight.shape[1])
+    new_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=config.weight_dtype,
+        scale_dtype=None,
+        _layout=Float8Layout(mm_config=None),
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 _fp8_granularities = Union[PerTensor, PerRow]
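
Side note (not part of the diff): after this hunk, `float8_weight_only(...)` keeps working because it is now an alias for `Float8WeightOnlyConfig`, and the module rewrite itself happens in the handler registered above. A minimal usage sketch, assuming `quantize_` from `torchao.quantization` dispatches `AOBaseConfig` instances to these registered handlers and that the config class is re-exported from `torchao.quantization`:

import torch
from torchao.quantization import quantize_
from torchao.quantization import Float8WeightOnlyConfig  # assumed re-export of the class defined above

# Toy model; float8 weight-only keeps the matmul in the weight's original (bf16) precision.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# New-style config object; the BC alias means float8_weight_only() builds the same config.
quantize_(model, Float8WeightOnlyConfig(weight_dtype=torch.float8_e4m3fn))

x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
y = model(x)  # weights are stored in float8, dequantized for the matmul
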
@@ -1170,16 +1183,10 @@ def _fp8_mm_compat(weight: torch.Tensor) -> bool:
     return is_compatible


-def float8_dynamic_activation_float8_weight(
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+@dataclass
+class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 dynamic symmetric quantization to both activations and weights of linear layers.
+    Configuration for applying float8 dynamic symmetric quantization to both activations and weights of linear layers.

     Args:
         activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
@@ -1192,104 +1199,149 @@ def float8_dynamic_activation_float8_weight(
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.

     """
-    assert (
-        is_sm_at_least_89() or is_MI300()
-    ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)

-    activation_granularity, weight_granularity = _normalize_granularity(granularity)
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None

-    def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        if isinstance(weight_granularity, PerRow):
-            assert (
-                weight.dtype == torch.bfloat16
-            ), "PerRow quantization only works for bfloat16 precision input weight"
+    def __post_init__(self):
+        assert (
+            is_sm_at_least_89() or is_MI300()
+        ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)

-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )

-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
+# for bc
+float8_dynamic_activation_float8_weight = Float8DynamicActivationFloat8WeightConfig

-        quantized_weight = to_linear_activation_quantized(
-            quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
-        )
-        return quantized_weight

-    return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)
+@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)
+def _float8_dynamic_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig
+):
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+    weight = module.weight

+    activation_granularity, weight_granularity = _normalize_granularity(granularity)

-def float8_static_activation_float8_weight(
-    scale: torch.Tensor,
-    activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    weight_dtype: torch.dtype = torch.float8_e4m3fn,
-    granularity: Optional[
-        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
-    ] = None,
-    mm_config: Optional[Float8MMConfig] = None,
-):
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    if isinstance(weight_granularity, PerRow):
+        assert (
+            weight.dtype == torch.bfloat16
+        ), "PerRow quantization only works for bfloat16 precision input weight"
+
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )
+
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }
+
+    quantized_weight = to_linear_activation_quantized(
+        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
+@dataclass
+class Float8StaticActivationFloat8WeightConfig(AOBaseConfig):
     """
-    Applies float8 static symmetric quantization to
+    Configuration for applying float8 static symmetric quantization to both activations and weights of linear layers.

     Args:
         scale (torch.Tensor): The scale tensor for activation quantization.
         activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
     """
-    assert (
-        is_sm_at_least_89() or is_MI300()
-    ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
-    if mm_config is None:
-        mm_config = Float8MMConfig(use_fast_accum=True)

+    scale: torch.Tensor
+    activation_dtype: torch.dtype = torch.float8_e4m3fn
+    weight_dtype: torch.dtype = torch.float8_e4m3fn
+    granularity: Optional[
+        Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
+    ] = None
+    mm_config: Optional[Float8MMConfig] = None
+
+    def __post_init__(self):
+        assert (
+            is_sm_at_least_89() or is_MI300()
+        ), "Float8 static activation quantization is only supported on CUDA 8.9 and above"
+        if self.mm_config is None:
+            self.mm_config = Float8MMConfig(use_fast_accum=True)
+
+
+# for bc
+float8_static_activation_float8_weight = Float8StaticActivationFloat8WeightConfig
+
+
+@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
+def _float8_static_activation_float8_weight_transform(
+    module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
+):
+    scale = config.scale
+    activation_dtype = config.activation_dtype
+    weight_dtype = config.weight_dtype
+    granularity = config.granularity
+    mm_config = config.mm_config
+
+    weight = module.weight
     activation_granularity, weight_granularity = _normalize_granularity(granularity)
     assert isinstance(
         activation_granularity, PerTensor
     ), "Static quantization only supports PerTensor granularity"

-    def apply_float8_static_activation_quant(weight: torch.Tensor):
-        if not _fp8_mm_compat(weight):
-            return weight
-        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
-            input_float=weight,
-            block_size=block_size,
-            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
-            _layout=Float8Layout(mm_config=mm_config),
-        )
+    if not _fp8_mm_compat(weight):
+        # TODO(future PR): this should really throw an exception instead of silently
+        # not doing what the user asked
+        return module
+    block_size = get_block_size(weight.shape, weight_granularity)
+    quantized_weight = to_affine_quantized_floatx(
+        input_float=weight,
+        block_size=block_size,
+        target_dtype=weight_dtype,
+        scale_dtype=torch.float32,
+        _layout=Float8Layout(mm_config=mm_config),
+    )

-        input_quant_func = _input_activation_quant_func_fp8
-        input_quant_kwargs = {
-            "activation_granularity": activation_granularity,
-            "activation_dtype": activation_dtype,
-        }
-
-        quantized_weight = (
-            to_weight_tensor_with_linear_activation_quantization_metadata(
-                quantized_weight,
-                input_quant_func,
-                scale=scale,
-                zero_point=None,
-                quant_kwargs=input_quant_kwargs,
-            )
-        )
-        return quantized_weight
+    input_quant_func = _input_activation_quant_func_fp8
+    input_quant_kwargs = {
+        "activation_granularity": activation_granularity,
+        "activation_dtype": activation_dtype,
+    }

-    return _get_linear_subclass_inserter(apply_float8_static_activation_quant)
+    quantized_weight = to_weight_tensor_with_linear_activation_quantization_metadata(
+        quantized_weight,
+        input_quant_func,
+        scale=scale,
+        zero_point=None,
+        quant_kwargs=input_quant_kwargs,
+    )
+
+    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module


 def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
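
A corresponding usage sketch for the dynamic and static float8 configs introduced above (not part of the diff): it assumes these config classes and the `PerRow`/`PerTensor` granularities are importable from `torchao.quantization`, that `quantize_` routes config instances to the transforms registered via `register_quantize_module_handler`, and that the hardware checks in `__post_init__` pass (CUDA SM >= 8.9 or MI300). The static activation scale below is a placeholder standing in for a calibrated value.

import torch
from torchao.quantization import (  # assumed re-exports
    quantize_,
    Float8DynamicActivationFloat8WeightConfig,
    Float8StaticActivationFloat8WeightConfig,
    PerRow,
    PerTensor,
)

# Dynamic float8 activations + float8 weights; PerRow granularity requires bf16 weights.
dyn_model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
quantize_(dyn_model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# Static float8 activations need a pre-computed PerTensor scale (placeholder value here,
# not a real calibration result).
act_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
static_model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
quantize_(
    static_model,
    Float8StaticActivationFloat8WeightConfig(scale=act_scale, granularity=PerTensor()),
)
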