     fake_quantize_per_channel_group,
     fake_quantize_per_token,
 )
-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.utils import (
+    get_group_qparams_symmetric,
+    get_groupwise_affine_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
+    groupwise_affine_quantize_tensor,
+    groupwise_affine_quantize_tensor_from_qparams,
+)
 from torchao.utils import TORCH_VERSION_AFTER_2_4


 # TODO: put this in a common test utils file
+_CUDA_IS_AVAILABLE = torch.cuda.is_available()
+
 class Sub(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
+        self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float)

     def example_inputs(self):
-        return (torch.randn(1, 32).to(torch.float),)
+        return (torch.randn(1, 256).to(torch.float),)

     def forward(self, x):
         return self.linear(x)

 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
         self.sub = Sub()
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)

     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, 512).to(torch.float),)

     def forward(self, x):
         x = self.linear1(x)
@@ -111,23 +119,46 @@ def test_fake_quantize_per_token(self):

     def _set_ptq_weight(
         self,
-        ptq_linear: "Int8DynActInt4WeightLinear",
-        fp32_weight: torch.Tensor,
-        group_size: int,
+        ptq_linear: torch.nn.Module,
+        qat_linear: torch.nn.Module,
     ):
         """
         Set the weight to the quantized version of the given fp32 weights,
         for making linear outputs comparable with QAT.
         """
+        from torchao.quantization.GPTQ import (
+            Int8DynActInt4WeightLinear,
+            WeightOnlyInt4Linear,
+        )
+        from torchao.quantization.prototype.qat import (
+            Int8DynActInt4WeightQATLinear,
+            Int4WeightOnlyQATLinear,
+        )
         n_bit = 4
         (qmin, qmax) = self._get_qmin_qmax(n_bit)
-        (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
-        q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
-            fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
-        )
-        ptq_linear.weight = q_weight
-        ptq_linear.scales = s
-        ptq_linear.zeros = zp
+        if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
+            assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
+            fp32_weight = qat_linear.weight
+            group_size = qat_linear.groupsize
+            (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
+            q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+                fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
+            )
+            ptq_linear.weight = q_weight
+            ptq_linear.scales = s
+            ptq_linear.zeros = zp
+        elif isinstance(ptq_linear, WeightOnlyInt4Linear):
+            assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
+            (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+                qat_linear.weight, n_bit, qat_linear.groupsize,
+            )
+            q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                q_weight.to("cuda"), qat_linear.inner_k_tiles,
+            )
+            ptq_linear.weight = q_weight
+            ptq_linear.scales_and_zeros = scales_and_zeros
+        else:
+            raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear))

     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
@@ -144,7 +175,7 @@ def test_qat_8da4w_linear(self):
         )

         # Force the weights to be the same
-        self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
+        self._set_ptq_weight(ptq_linear, qat_linear)

         # Compare linear values
         torch.manual_seed(self.SEED)
@@ -280,7 +311,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         loss_fn1 = torch.nn.CrossEntropyLoss()
         loss_fn2 = torch.nn.CrossEntropyLoss()
         example_inputs = nn_model.example_inputs()
-        target = torch.randn(1, 64).float()
+        target = torch.randn(1, 512).float()
         output1 = nn_model(*example_inputs)
         output2 = qat_model(*example_inputs)
         torch.testing.assert_close(output1, output2, atol=0, rtol=0)
@@ -322,6 +353,123 @@ def test_qat_generic_fake_quantize(self):
         torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
         torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)

+    def _assert_close_4w(self, val, ref):
+        # Note: for int4 weight-only quantization, we do not expect exact match
+        # because torch._weight_int4pack_mm and torch.mm do not match exactly.
+        # Here we use the same error bar as PyTorch core to determine closeness:
+        # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
+        mean_err = ((val - ref) / ref).mean().abs()
+        print(mean_err)
+        self.assertTrue(mean_err < 0.05)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_primitives(self):
+        n_bit = 4
+        group_size = 32
+        inner_k_tiles = 8
+        scales_precision = torch.bfloat16
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        weight = torch.randn(512, 256, dtype=dtype, device=device)
+
+        # PTQ
+        (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+            weight, n_bit, group_size, scales_precision,
+        )
+        q_weight = torch.ops.aten._convert_weight_to_int4pack(
+            q_weight.to(device), inner_k_tiles,
+        )
+        ptq_out = torch.ops.aten._weight_int4pack_mm(
+            x, q_weight, group_size, scales_and_zeros
+        )
+
+        # QAT
+        scales, zero_points = get_groupwise_affine_qparams(
+            weight, n_bit, group_size, scales_precision,
+        )
+        w_q = groupwise_affine_quantize_tensor_from_qparams(
+            weight, scales, zero_points, n_bit, group_size, cast_dtypes=False,
+        )
+        w_dq = groupwise_affine_dequantize_tensor_from_qparams(
+            w_q, scales, zero_points, n_bit, group_size, cast_dtypes=False,
+        )
+        qat_out = torch.nn.functional.linear(x, w_dq)
+
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_linear(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
+        from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+
+        group_size = 128
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        qat_linear = Int4WeightOnlyQATLinear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+        ptq_linear = WeightOnlyInt4Linear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+
+        # Force the weights to be the same
+        self._set_ptq_weight(ptq_linear, qat_linear)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        x2 = copy.deepcopy(x)
+        qat_out = qat_linear(x)
+        ptq_out = ptq_linear(x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_quantizer(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+
+        group_size = 32
+        inner_k_tiles = 8
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        m = M().to(device).to(dtype)
+        m2 = copy.deepcopy(m)
+        qat_quantizer = Int4WeightOnlyQATQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        ptq_quantizer = Int4WeightOnlyQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
+
+        # Compare model values
+        torch.manual_seed(self.SEED)
+        x = [i.to(device).to(dtype) for i in m.example_inputs()]
+        x2 = copy.deepcopy(x)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+        # Convert QAT model and compare model values
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+

 if __name__ == "__main__":
     unittest.main()
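
For readers skimming the diff, the end-to-end flow exercised by test_qat_4w_quantizer is roughly prepare -> fine-tune -> convert. Below is a minimal sketch assuming the prototype QAT API used above; the model construction and the training step are placeholders, not part of this change:

import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer

# Placeholder model (M is the toy module defined in this test file); values match the test.
model = M().to(torch.device("cuda")).to(torch.bfloat16)
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)

# prepare(): swap in QAT linears that fake-quantize int4 weights during the forward pass
model = quantizer.prepare(model)

# ... fine-tune `model` here as usual (omitted) ...

# convert(): replace the QAT linears with real int4 weight-only quantized linears;
# the test asserts this matches what Int4WeightOnlyQuantizer.quantize() produces.
model = quantizer.convert(model)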