     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_fbgemm_fp8,
-    to_fbgemm_int4,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
 from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
 from torchao.quantization.quantize_.common import (
     KernelPreference,
+    PackingFormat,
 )
 from torchao.quantization.quantize_.workflows import (
     Float8Tensor,
     Int4PreshuffledTensor,
+    Int4Tensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -1119,6 +1120,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
         `zero_point_domain`: data type of zero points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
         `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
         `preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT
+        `packing_format`: the packing format for the int4 tensor, available from VERSION 2 and above
     """

     group_size: int = 128
@@ -1127,6 +1129,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.NONE
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
+    # only used in VERSION >= 2
+    packing_format: PackingFormat = PackingFormat.PLAIN
+    VERSION: int = 1


 # for BC
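A minimal usage sketch of the new fields, assuming the top-level `quantize_` entry point and `Int4WeightOnlyConfig` export from `torchao.quantization` (the model and shapes here are made up for illustration):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common import PackingFormat

# Toy bfloat16 model; in_features must be divisible by group_size.
model = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16))

# VERSION=2 routes PLAIN to Int4Tensor.from_float and PRESHUFFLED to
# Int4PreshuffledTensor.from_float (see _int4_weight_only_quantize_tensor below).
quantize_(model, Int4WeightOnlyConfig(group_size=128, packing_format=PackingFormat.PLAIN, VERSION=2))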
@@ -1144,15 +1149,36 @@ def _int4_weight_only_quantize_tensor(weight, config):
     layout = config.layout
     use_hqq = config.use_hqq
     zero_point_domain = config.zero_point_domain
+    packing_format = config.packing_format

     if weight.shape[-1] % group_size != 0:
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
         return weight

+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    if config.VERSION == 2:
+        if packing_format == PackingFormat.PRESHUFFLED:
+            new_weight = Int4PreshuffledTensor.from_float(
+                weight,
+                block_size,
+                activation_dtype=torch.bfloat16,
+            )
+            return new_weight
+        elif packing_format == PackingFormat.PLAIN:
+            new_weight = Int4Tensor.from_float(
+                weight,
+                block_size,
+            )
+            return new_weight
+        else:
+            raise ValueError(f"Unsupported packing format: {packing_format}")
+
+    assert config.VERSION == 1
+
     mapping_type = MappingType.ASYMMETRIC
-    block_size = tuple([1 for _ in range(weight.dim() - 1)] + [group_size])
     target_dtype = torch.int32
     quant_min = 0
     quant_max = 15
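To make the block_size convention above concrete, a small self-contained check (plain arithmetic, no torchao assumptions beyond what the diff shows):

import torch

# For a 2-D weight [out_features, in_features] with group_size=128, the expression
# tuple([1 for _ in range(weight.ndim - 1)] + [group_size]) yields (1, 128):
# one quantization group per output row, spanning 128 input features.
weight = torch.randn(512, 256, dtype=torch.bfloat16)
group_size = 128
block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
assert block_size == (1, 128)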
@@ -1224,6 +1250,46 @@ def _int4_weight_only_transform(
     return module


+@dataclass
+class Float8ActivationInt4WeightConfig(AOBaseConfig):
+    """Configuration for applying float8 dynamic per-row activation quantization and int4
+    per-group weight quantization to linear modules
+
+    Args:
+        `group_size`: group size for groupwise quantization of the weight
+        `packing_format`: how the weight is packed, only preshuffled is supported
+    """
+
+    group_size: int = 128
+    packing_format: PackingFormat = "preshuffled"
+
+
+@register_quantize_module_handler(Float8ActivationInt4WeightConfig)
+def _float8_activation_int4_weight_transform(
+    module: torch.nn.Module, config: Float8ActivationInt4WeightConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying float8 activation + int4 weight quant requires module to have a weight attribute"
+        + f" but {module} does not have one"
+    )
+    group_size = config.group_size
+    packing_format = config.packing_format
+
+    assert packing_format == "preshuffled", (
+        f"only preshuffled packing_format supported right now, got: {packing_format}"
+    )
+    weight = module.weight
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+    new_weight = Int4PreshuffledTensor.from_float(
+        module.weight,
+        block_size,
+        activation_dtype=torch.float8_e4m3fn,
+    )
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
+
+
 @dataclass
 class Int8WeightOnlyConfig(AOBaseConfig):
     """
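A minimal usage sketch for the new config; the import path is an assumption (the class is defined in torchao/quantization/quant_api.py), and running it requires hardware/kernels that support the preshuffled int4 + float8 path:

import torch
from torchao.quantization import quantize_
# Assumed import location for illustration only.
from torchao.quantization.quant_api import Float8ActivationInt4WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16))
# Activations are dynamically quantized to float8 (e4m3) per row; the weight is
# packed as preshuffled int4 with groups of 128.
quantize_(model, Float8ActivationInt4WeightConfig(group_size=128, packing_format="preshuffled"))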
@@ -1677,6 +1743,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
         return weight
+
     if isinstance(weight_granularity, PerRow):
         assert weight.dtype == torch.bfloat16, (
             "PerRow quantization only works for bfloat16 precision input weight"
@@ -2145,7 +2212,7 @@ def _(module: torch.nn.Module, config: FbgemmConfig) -> torch.nn.Module:
                 activation_dtype=torch.bfloat16,
             )
         else:
-            weight = to_fbgemm_int4(
+            weight = Int4Tensor.from_float(
                 module.weight,
                 config.block_size,
             )