 import unittest
 
 import torch
+import torch.nn.functional as F
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torchao.dtypes import (
     TensorCoreTiledLayoutType,
 )
 from torchao.quantization.prototype.qat.api import (
     ComposableQATQuantizer,
+    FakeQuantizeConfig,
+    QuantizationGranularity,
+)
+from torchao.quantization.prototype.qat.fake_quantizer import (
+    FakeQuantizer,
+)
+from torchao.quantization.prototype.qat.linear import (
+    FakeQuantizedLinear,
 )
 from torchao.quantization.prototype.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _get_qmin_qmax,
     _GenericFakeQuantize,
 )
 from torchao.quantization.quant_api import (
@@ -92,15 +102,10 @@ def forward(self, x):
 class TestQAT(unittest.TestCase):
     SEED = 123
 
-    def _get_qmin_qmax(self, n_bit: int):
-        qmin = -(2 ** (n_bit - 1))
-        qmax = 2 ** (n_bit - 1) - 1
-        return (qmin, qmax)
-
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantize_per_channel_group(self):
         n_bit = 4
-        (qmin, qmax) = self._get_qmin_qmax(n_bit)
+        (qmin, qmax) = _get_qmin_qmax(n_bit)
         group_size = 128
 
         torch.manual_seed(self.SEED)
@@ -126,7 +131,7 @@ def test_fake_quantize_per_channel_group(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_fake_quantize_per_token(self):
-        (qmin, qmax) = self._get_qmin_qmax(8)
+        (qmin, qmax) = _get_qmin_qmax(8)
 
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
@@ -165,11 +170,11 @@ def _set_ptq_weight(
             Int4WeightOnlyQATLinear,
         )
         n_bit = 4
-        (qmin, qmax) = self._get_qmin_qmax(n_bit)
+        (qmin, qmax) = _get_qmin_qmax(n_bit)
+        group_size = qat_linear.weight_fake_quantizer.config.group_size
         if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
             assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
             fp32_weight = qat_linear.weight
-            group_size = qat_linear.groupsize
             (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
             q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
                 fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
@@ -180,7 +185,7 @@ def _set_ptq_weight(
         elif isinstance(ptq_linear, WeightOnlyInt4Linear):
             assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
             (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
-                qat_linear.weight, n_bit, qat_linear.groupsize,
+                qat_linear.weight, n_bit, group_size,
             )
             q_weight = torch.ops.aten._convert_weight_to_int4pack(
                 q_weight.to("cuda"), qat_linear.inner_k_tiles,
@@ -218,31 +223,36 @@ def test_qat_8da4w_linear(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATQuantizer
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
         m2 = copy.deepcopy(m)
-        subclass_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        module_swap_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
-        subclass_model = subclass_quantizer.prepare(m)
-        module_swap_model = module_swap_quantizer.prepare(m2)
+        qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        ptq_quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
 
         # Compare model values
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
-        subclass_out = subclass_model(*x)
-        module_swap_out = module_swap_model(*x2)
-        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
         # Convert QAT model and compare model values
-        subclass_model = subclass_quantizer.convert(subclass_model)
-        module_swap_model = module_swap_quantizer.convert(module_swap_model)
-        subclass_out = subclass_model(*x)
-        module_swap_out = module_swap_model(*x2)
-        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
@@ -275,9 +285,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        self.assertFalse(qat_model.linear1._fake_quant_enabled)
-        self.assertFalse(qat_model.linear2._fake_quant_enabled)
-        self.assertFalse(qat_model.sub.linear._fake_quant_enabled)
+        self.assertFalse(qat_model.linear1.activation_fake_quantizer.enabled)
+        self.assertFalse(qat_model.linear1.weight_fake_quantizer.enabled)
+        self.assertFalse(qat_model.linear2.activation_fake_quantizer.enabled)
+        self.assertFalse(qat_model.linear2.weight_fake_quantizer.enabled)
+        self.assertFalse(qat_model.sub.linear.activation_fake_quantizer.enabled)
+        self.assertFalse(qat_model.sub.linear.weight_fake_quantizer.enabled)
 
         # Disabled fake quant is just a normal linear
         m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight)
@@ -292,9 +305,12 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
 
         # Renable fake quant
         qat_model.apply(enable_8da4w_fake_quant)
-        self.assertTrue(qat_model.linear1._fake_quant_enabled)
-        self.assertTrue(qat_model.linear2._fake_quant_enabled)
-        self.assertTrue(qat_model.sub.linear._fake_quant_enabled)
+        self.assertTrue(qat_model.linear1.activation_fake_quantizer.enabled)
+        self.assertTrue(qat_model.linear1.weight_fake_quantizer.enabled)
+        self.assertTrue(qat_model.linear2.activation_fake_quantizer.enabled)
+        self.assertTrue(qat_model.linear2.weight_fake_quantizer.enabled)
+        self.assertTrue(qat_model.sub.linear.activation_fake_quantizer.enabled)
+        self.assertTrue(qat_model.sub.linear.weight_fake_quantizer.enabled)
 
         # Fake quant should be applied as normal
         quantizer2 = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
@@ -407,7 +423,7 @@ def test_qat_generic_fake_quantize(self):
         the numerics of existing fake quantize ops in Pytorch in both
         the forward and the backward passes.
         """
-        (qmin, qmax) = self._get_qmin_qmax(4)
+        (qmin, qmax) = _get_qmin_qmax(4)
         py_input = torch.randn(16, 64).float().requires_grad_()
         py_s = torch.randn(16).float()
         py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
@@ -521,7 +537,7 @@ def test_qat_4w_quantizer_gradients(self):
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
         inner_k_tiles = 8
@@ -530,29 +546,34 @@ def test_qat_4w_quantizer(self):
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
-        subclass_quantizer = Int4WeightOnlyQATQuantizer(
+        qat_quantizer = Int4WeightOnlyQATQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
-        module_swap_quantizer = Int4WeightOnlyQATQuantizer(
+        ptq_quantizer = Int4WeightOnlyQuantizer(
             groupsize=group_size, inner_k_tiles=inner_k_tiles,
         )
-        subclass_model = subclass_quantizer.prepare(m)
-        module_swap_model = module_swap_quantizer.prepare(m2)
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
 
         # Compare model values
         torch.manual_seed(self.SEED)
         x = [i.to(device).to(dtype) for i in m.example_inputs()]
         x2 = copy.deepcopy(x)
-        subclass_out = subclass_model(*x)
-        module_swap_out = module_swap_model(*x2)
-        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        self._assert_close_4w(qat_out, ptq_out)
 
         # Convert QAT model and compare model values
-        subclass_model = subclass_quantizer.convert(subclass_model)
-        module_swap_model = module_swap_quantizer.convert(module_swap_model)
-        subclass_out = subclass_model(*x)
-        module_swap_out = module_swap_model(*x2)
-        torch.testing.assert_close(subclass_out, module_swap_out, atol=0, rtol=0)
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
 
     class _MyQATQuantizer(TwoStepQuantizer):
         """
@@ -603,5 +624,127 @@ def test_qat_4w_embedding(self):
         converted = quantizer.convert(model)
         converted_out = converted(*x)
 
+    def test_fake_quantize_config(self):
+        """
+        Test initialization and property setting of `FakeQuantizeConfig`.
+        """
+        # basic configs
+        per_token_config = FakeQuantizeConfig(8, "per_token")
+        self.assertEqual(per_token_config.bit_width, 8)
+        self.assertEqual(per_token_config.granularity, QuantizationGranularity.PER_TOKEN)
+        self.assertIsNone(per_token_config.group_size)
+        per_channel_config = FakeQuantizeConfig(4, "per_channel")
+        self.assertEqual(per_channel_config.bit_width, 4)
+        self.assertEqual(per_channel_config.granularity, QuantizationGranularity.PER_CHANNEL)
+        self.assertIsNone(per_channel_config.group_size)
+
+        # initialize per_group config using only group size
+        per_group_config = FakeQuantizeConfig(4, group_size=32)
+        self.assertEqual(per_group_config.bit_width, 4)
+        self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
+        self.assertEqual(per_group_config.group_size, 32)
+
+        # set granularity after initialization, should accept str as before
+        per_group_config.granularity = "per_token"
+        self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_TOKEN)
+
+        # set group_size after initialization, should also update granularity
+        per_group_config.group_size = 16
+        self.assertEqual(per_group_config.granularity, QuantizationGranularity.PER_GROUP)
+        self.assertEqual(per_group_config.group_size, 16)
+
+        # bad config1: no granularity or group size provided
+        with self.assertRaisesRegex(ValueError, "group_size or granularity must be set"):
+            FakeQuantizeConfig(8)
+
+        # bad config2: 'per_group' but no group size
+        with self.assertRaisesRegex(ValueError, "no group_size was set"):
+            FakeQuantizeConfig(8, "per_group")
+
+        # bad config3: group size was set but granularity was not 'per_group'
+        with self.assertRaisesRegex(ValueError, "group_size was set"):
+            FakeQuantizeConfig(8, "per_token", group_size=16)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantized_linear_8da4w(self):
+        """
+        Test that we can express int8 dynamic activations + int4 weights with `FakeQuantizedLinear`.
+        """
+        group_size = 128
+        torch.manual_seed(self.SEED)
+        fq_linear = FakeQuantizedLinear(
+            256,
+            688,
+            bias=False,
+            activation_config=FakeQuantizeConfig(8, "per_token", symmetric=False),
+            weight_config=FakeQuantizeConfig(4, group_size=group_size),
+        )
+
+        def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+            """
+            Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
+            """
+            # activations
+            (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
+            (qmin, qmax) = _get_qmin_qmax(8)
+            x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
+
+            # weights
+            (s, zp) = get_group_qparams_symmetric(weight, 4, group_size, torch.float32)
+            zp = zp.to(torch.int32)
+            (qmin, qmax) = _get_qmin_qmax(4)
+            w_fq = _fake_quantize_per_channel_group(weight, s, zp, qmin, qmax, group_size)
+            return F.linear(x_fq, w_fq)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256)
+        x2 = copy.deepcopy(x)
+        fq_out = fq_linear(x)
+        baseline_out = linear_forward_8da4w(x2, fq_linear.weight)
+        torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
+    def test_fake_quantized_linear_4w(self):
+        """
+        Test that we can express int4 weight only (tinygemm) with `FakeQuantizedLinear`.
+        """
+        group_size = 128
+        weight_config = FakeQuantizeConfig(
+            bit_width=4,
+            group_size=group_size,
+            symmetric=False,
+            zero_point_domain=ZeroPointDomain.FLOAT,
+        )
+        torch.manual_seed(self.SEED)
+        fq_linear = FakeQuantizedLinear(
+            256,
+            688,
+            bias=False,
+            activation_config=None,
+            weight_config=weight_config,
+        )
+
+        def linear_forward_4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+            """
+            Baseline for int4 weight only fake quantization that simulates the tinygemm kernel.
+            """
+            (qmin, qmax) = _get_qmin_qmax(4, symmetric=False)
+            (s, zp) = get_groupwise_affine_qparams(weight, 4, group_size, torch.float32)
+            zp = zp.to(torch.int32)
+            w_fq = _fake_quantize_per_channel_group(
+                weight, s, zp, qmin, qmax, group_size, zero_point_domain=ZeroPointDomain.FLOAT,
+            )
+            return F.linear(x, w_fq)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256)
+        x2 = copy.deepcopy(x)
+        fq_out = fq_linear(x)
+        baseline_out = linear_forward_4w(x2, fq_linear.weight)
+        torch.testing.assert_close(baseline_out, fq_out, atol=0, rtol=0)
+
+
 if __name__ == "__main__":
     unittest.main()