1818 fake_quantize_per_channel_group ,
1919 fake_quantize_per_token ,
2020)
21- from torchao .quantization .utils import get_group_qparams_symmetric
21+ from torchao .quantization .quant_primitives import (
22+ fake_quantize_affine ,
23+ ZeroPointDomain ,
24+ )
25+ from torchao .quantization .utils import (
26+ get_group_qparams_symmetric ,
27+ get_groupwise_affine_qparams ,
28+ groupwise_affine_quantize_tensor ,
29+ )
2230from torchao .utils import TORCH_VERSION_AFTER_2_4
2331
2432
2533# TODO: put this in a common test utils file
34+ _CUDA_IS_AVAILABLE = torch .cuda .is_available ()
35+
2636class Sub (torch .nn .Module ):
2737 def __init__ (self ):
2838 super ().__init__ ()
29- self .linear = torch .nn .Linear (32 , 32 , bias = False ).to (torch .float )
39+ self .linear = torch .nn .Linear (256 , 256 , bias = False ).to (torch .float )
3040
3141 def example_inputs (self ):
32- return (torch .randn (1 , 32 ).to (torch .float ),)
42+ return (torch .randn (1 , 256 ).to (torch .float ),)
3343
3444 def forward (self , x ):
3545 return self .linear (x )
3646
3747class M (torch .nn .Module ):
3848 def __init__ (self ):
3949 super ().__init__ ()
40- self .linear1 = torch .nn .Linear (64 , 32 , bias = False ).to (torch .float )
50+ self .linear1 = torch .nn .Linear (512 , 256 , bias = False ).to (torch .float )
4151 self .sub = Sub ()
42- self .linear2 = torch .nn .Linear (32 , 64 , bias = False ).to (torch .float )
52+ self .linear2 = torch .nn .Linear (256 , 512 , bias = False ).to (torch .float )
4353
4454 def example_inputs (self ):
45- return (torch .randn (1 , 64 ).to (torch .float ),)
55+ return (torch .randn (1 , 512 ).to (torch .float ),)
4656
4757 def forward (self , x ):
4858 x = self .linear1 (x )
@@ -111,23 +121,46 @@ def test_fake_quantize_per_token(self):
111121
112122 def _set_ptq_weight (
113123 self ,
114- ptq_linear : "Int8DynActInt4WeightLinear" ,
115- fp32_weight : torch .Tensor ,
116- group_size : int ,
124+ ptq_linear : torch .nn .Module ,
125+ qat_linear : torch .nn .Module ,
117126 ):
118127 """
119128 Set the weight to the quantized version of the given fp32 weights,
120129 for making linear outputs comparable with QAT.
121130 """
131+ from torchao .quantization .GPTQ import (
132+ Int8DynActInt4WeightLinear ,
133+ WeightOnlyInt4Linear ,
134+ )
135+ from torchao .quantization .prototype .qat import (
136+ Int8DynActInt4WeightQATLinear ,
137+ Int4WeightOnlyQATLinear ,
138+ )
122139 n_bit = 4
123140 (qmin , qmax ) = self ._get_qmin_qmax (n_bit )
124- (s , zp ) = get_group_qparams_symmetric (fp32_weight , n_bit , group_size )
125- q_weight = torch .ops .quantized_decomposed .quantize_per_channel_group (
126- fp32_weight , s , zp , qmin , qmax , torch .int8 , group_size ,
127- )
128- ptq_linear .weight = q_weight
129- ptq_linear .scales = s
130- ptq_linear .zeros = zp
141+ if isinstance (ptq_linear , Int8DynActInt4WeightLinear ):
142+ assert isinstance (qat_linear , Int8DynActInt4WeightQATLinear )
143+ fp32_weight = qat_linear .weight
144+ group_size = qat_linear .groupsize
145+ (s , zp ) = get_group_qparams_symmetric (fp32_weight , n_bit , group_size )
146+ q_weight = torch .ops .quantized_decomposed .quantize_per_channel_group (
147+ fp32_weight , s , zp , qmin , qmax , torch .int8 , group_size ,
148+ )
149+ ptq_linear .weight = q_weight
150+ ptq_linear .scales = s
151+ ptq_linear .zeros = zp
152+ elif isinstance (ptq_linear , WeightOnlyInt4Linear ):
153+ assert isinstance (qat_linear , Int4WeightOnlyQATLinear )
154+ (q_weight , scales_and_zeros ) = groupwise_affine_quantize_tensor (
155+ qat_linear .weight , n_bit , qat_linear .groupsize ,
156+ )
157+ q_weight = torch .ops .aten ._convert_weight_to_int4pack (
158+ q_weight .to ("cuda" ), qat_linear .inner_k_tiles ,
159+ )
160+ ptq_linear .weight = q_weight
161+ ptq_linear .scales_and_zeros = scales_and_zeros
162+ else :
163+ raise ValueError ("Unknown ptq_linear type: %s" % type (ptq_linear ))
131164
132165 @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
133166 def test_qat_8da4w_linear (self ):
@@ -144,7 +177,7 @@ def test_qat_8da4w_linear(self):
144177 )
145178
146179 # Force the weights to be the same
147- self ._set_ptq_weight (ptq_linear , qat_linear . weight , group_size )
180+ self ._set_ptq_weight (ptq_linear , qat_linear )
148181
149182 # Compare linear values
150183 torch .manual_seed (self .SEED )
@@ -280,7 +313,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
280313 loss_fn1 = torch .nn .CrossEntropyLoss ()
281314 loss_fn2 = torch .nn .CrossEntropyLoss ()
282315 example_inputs = nn_model .example_inputs ()
283- target = torch .randn (1 , 64 ).float ()
316+ target = torch .randn (1 , 512 ).float ()
284317 output1 = nn_model (* example_inputs )
285318 output2 = qat_model (* example_inputs )
286319 torch .testing .assert_close (output1 , output2 , atol = 0 , rtol = 0 )
@@ -322,6 +355,130 @@ def test_qat_generic_fake_quantize(self):
322355 torch .testing .assert_close (py_out , ao_out , atol = 0 , rtol = 0 )
323356 torch .testing .assert_close (py_input .grad , ao_input .grad , atol = 0 , rtol = 0 )
324357
358+ def _assert_close_4w (self , val , ref ):
359+ # Note: for int4 weight-only quantization, we do not expect exact match
360+ # because torch._weight_int4pack_mm and torch.mm do not match exactly.
361+ # Here we use the same error bar as PyTorch core to determine closeness:
362+ # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
363+ mean_err = ((val - ref ) / ref ).mean ().abs ()
364+ print (mean_err )
365+ self .assertTrue (mean_err < 0.05 )
366+
367+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
368+ @unittest .skipIf (not _CUDA_IS_AVAILABLE , "skipping when cuda is not available" )
369+ def test_qat_4w_primitives (self ):
370+ n_bit = 4
371+ group_size = 32
372+ inner_k_tiles = 8
373+ scales_precision = torch .bfloat16
374+ device = torch .device ("cuda" )
375+ dtype = torch .bfloat16
376+ torch .manual_seed (self .SEED )
377+ x = torch .randn (100 , 256 , dtype = dtype , device = device )
378+ weight = torch .randn (512 , 256 , dtype = dtype , device = device )
379+
380+ # PTQ
381+ (q_weight , scales_and_zeros ) = groupwise_affine_quantize_tensor (
382+ weight , n_bit , group_size , scales_precision ,
383+ )
384+ q_weight = torch .ops .aten ._convert_weight_to_int4pack (
385+ q_weight .to (device ), inner_k_tiles ,
386+ )
387+ ptq_out = torch .ops .aten ._weight_int4pack_mm (
388+ x , q_weight , group_size , scales_and_zeros
389+ )
390+
391+ # QAT
392+ block_size = (1 , group_size )
393+ quant_min = 0
394+ quant_max = 2 ** n_bit - 1
395+ scales , zero_points = get_groupwise_affine_qparams (
396+ weight , n_bit , group_size , scales_precision ,
397+ )
398+ w_fq = fake_quantize_affine (
399+ weight ,
400+ block_size ,
401+ scales ,
402+ zero_points ,
403+ torch .int32 ,
404+ quant_min ,
405+ quant_max ,
406+ zero_point_domain = ZeroPointDomain .FLOAT ,
407+ )
408+ qat_out = torch .nn .functional .linear (x , w_fq )
409+
410+ self ._assert_close_4w (qat_out , ptq_out )
411+
412+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
413+ @unittest .skipIf (not _CUDA_IS_AVAILABLE , "skipping when cuda is not available" )
414+ def test_qat_4w_linear (self ):
415+ from torchao .quantization .prototype .qat import Int4WeightOnlyQATLinear
416+ from torchao .quantization .GPTQ import WeightOnlyInt4Linear
417+
418+ group_size = 128
419+ device = torch .device ("cuda" )
420+ dtype = torch .bfloat16
421+ torch .manual_seed (self .SEED )
422+ qat_linear = Int4WeightOnlyQATLinear (
423+ 256 , 688 , bias = False , groupsize = group_size , device = device ,
424+ )
425+ ptq_linear = WeightOnlyInt4Linear (
426+ 256 , 688 , bias = False , groupsize = group_size , device = device ,
427+ )
428+
429+ # Force the weights to be the same
430+ self ._set_ptq_weight (ptq_linear , qat_linear )
431+
432+ # Compare linear values
433+ torch .manual_seed (self .SEED )
434+ x = torch .randn (100 , 256 , dtype = dtype , device = device )
435+ x2 = copy .deepcopy (x )
436+ qat_out = qat_linear (x )
437+ ptq_out = ptq_linear (x2 )
438+ self ._assert_close_4w (qat_out , ptq_out )
439+
440+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch version is 2.4 or lower" )
441+ @unittest .skipIf (not _CUDA_IS_AVAILABLE , "skipping when cuda is not available" )
442+ def test_qat_4w_quantizer (self ):
443+ from torchao .quantization .prototype .qat import Int4WeightOnlyQATQuantizer
444+ from torchao .quantization .GPTQ import Int4WeightOnlyQuantizer
445+
446+ group_size = 32
447+ inner_k_tiles = 8
448+ device = torch .device ("cuda" )
449+ dtype = torch .bfloat16
450+ torch .manual_seed (self .SEED )
451+ m = M ().to (device ).to (dtype )
452+ m2 = copy .deepcopy (m )
453+ qat_quantizer = Int4WeightOnlyQATQuantizer (
454+ groupsize = group_size , inner_k_tiles = inner_k_tiles ,
455+ )
456+ ptq_quantizer = Int4WeightOnlyQuantizer (
457+ groupsize = group_size , inner_k_tiles = inner_k_tiles ,
458+ )
459+ qat_model = qat_quantizer .prepare (m )
460+ ptq_model = ptq_quantizer .quantize (m2 )
461+
462+ # Compare model values
463+ torch .manual_seed (self .SEED )
464+ x = [i .to (device ).to (dtype ) for i in m .example_inputs ()]
465+ x2 = copy .deepcopy (x )
466+ qat_out = qat_model (* x )
467+ ptq_out = ptq_model (* x2 )
468+ self ._assert_close_4w (qat_out , ptq_out )
469+
470+ # Convert QAT model and compare model values
471+ converted_model = qat_quantizer .convert (qat_model )
472+ converted_out = converted_model (* x )
473+ torch .testing .assert_close (converted_out , ptq_out , atol = 0 , rtol = 0 )
474+
475+ # Compare converted state dict
476+ ptq_state_dict = ptq_model .state_dict ()
477+ converted_state_dict = converted_model .state_dict ()
478+ self .assertEqual (ptq_state_dict .keys (), converted_state_dict .keys ())
479+ for k in ptq_state_dict .keys ():
480+ torch .testing .assert_close (ptq_state_dict [k ], converted_state_dict [k ], atol = 0 , rtol = 0 )
481+
325482
326483if __name__ == "__main__" :
327484 unittest .main ()
0 commit comments