
 import torch
 import torch.nn.functional as F
+from parameterized import parameterized
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401

 from torchao import quantize_
@@ ... @@
     Int8DynActInt4WeightQATLinear,
 )
 from torchao.quantization.qat.utils import (
-    _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
     _GenericFakeQuantize,
@@ ... @@
     MappingType,
     TorchAODType,
     ZeroPointDomain,
+    choose_qparams_affine,
+    dequantize_affine,
     fake_quantize_affine,
+    quantize_affine,
 )
 from torchao.quantization.unified import (
     TwoStepQuantizer,
 )
 from torchao.quantization.utils import (
+    _get_per_token_block_size,
     get_group_qparams_symmetric,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
@@ -134,12 +138,13 @@ def forward(self, x):


 class M4(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dtype: torch.dtype = torch.float32):
         super().__init__()
-        self.linear = torch.nn.Linear(512, 256, bias=False).to(torch.float)
+        self.dtype = dtype
+        self.linear = torch.nn.Linear(512, 256, bias=False).to(dtype)

     def example_inputs(self):
-        return (torch.randn(1, 512).to(torch.float),)
+        return (torch.randn(1, 512).to(self.dtype),)

     def forward(self, x):
         return self.linear(x)
@@ -219,30 +224,41 @@ def test_fake_quantize_per_token(self):
         torch.manual_seed(self.SEED)
         x = torch.randn(100, 256).requires_grad_()
         x2 = copy.deepcopy(x)
-        # TODO: use torch.ops.aten.quantized_decomposed version instead
-        (s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)
+        block_size = _get_per_token_block_size(x)
+        (s, zp) = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.ASYMMETRIC,
+            block_size=block_size,
+            target_dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            scale_dtype=torch.float32,
+            zero_point_dtype=torch.int32,
+        )

         # fake quant op
         out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
         out.sum().backward()

         # compare against PTQ ops
-        out_ptq = torch.ops.quantized_decomposed.quantize_per_token(
+        out_ptq = quantize_affine(
             x2,
+            block_size,
             s,
             zp,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
         )
-        out_ptq = torch.ops.quantized_decomposed.dequantize_per_token(
+        out_ptq = dequantize_affine(
             out_ptq,
+            block_size,
             s,
             zp,
+            torch.int8,
             qmin,
             qmax,
-            torch.int8,
-            torch.float32,
+            output_dtype=torch.float32,
         )
         torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

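Note: the new reference path above swaps the torch.ops.quantized_decomposed per-token kernels for torchao's generic affine primitives. The standalone sketch below is only an illustration of that round trip; it assumes the same choose_qparams_affine / quantize_affine / dequantize_affine signatures used in this diff and the torchao.quantization.quant_primitives import path used elsewhere in this test file, and the 4x256 input is made up.

import torch

from torchao.quantization.quant_primitives import (  # assumed import path
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

# Hypothetical per-token example: one (scale, zero_point) pair per row,
# so each quantization block spans the full last dimension.
x = torch.randn(4, 256)
block_size = [1, x.shape[-1]]

scale, zp = choose_qparams_affine(
    x,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=block_size,
    target_dtype=torch.int8,
    quant_min=-128,
    quant_max=127,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)
xq = quantize_affine(x, block_size, scale, zp, torch.int8, -128, 127)
xdq = dequantize_affine(
    xq, block_size, scale, zp, torch.int8, -128, 127, output_dtype=torch.float32
)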
@@ -1004,8 +1020,15 @@ def linear_forward_8da4w(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     Baseline for int8 dynamic per token asymmetric + int4 per group symmetric quant.
     """
     # activations
-    (s, zp) = _choose_qparams_per_token_asymmetric(
-        x, torch.float32, torch.int32
+    (s, zp) = choose_qparams_affine(
+        x,
+        mapping_type=MappingType.ASYMMETRIC,
+        block_size=_get_per_token_block_size(x),
+        target_dtype=torch.int8,
+        quant_min=-128,
+        quant_max=127,
+        scale_dtype=torch.float32,
+        zero_point_dtype=torch.int32,
     )
     (qmin, qmax) = _get_qmin_qmax(8)
     x_fq = _fake_quantize_per_token(x, s, zp, qmin, qmax)
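The block_size=_get_per_token_block_size(x) argument above is what makes the affine ops behave per-token. Below is a hedged reimplementation for illustration only, not the torchao helper itself; the name per_token_block_size is invented.

import torch

def per_token_block_size(x: torch.Tensor) -> list:
    # Assumed behavior: size 1 along every leading dim and the full width of
    # the last dim, i.e. one quantization group per token.
    return [1] * (x.dim() - 1) + [x.shape[-1]]

assert per_token_block_size(torch.randn(2, 3, 8)) == [1, 1, 8]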
@@ -1427,10 +1450,11 @@ def test_qat_linear_bias(self):
         example_inputs = m.example_inputs()
         m(*example_inputs)

+    @parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_fake_quantize_per_token_vs_convert(self):
+    def test_fake_quantize_per_token_vs_convert(self, dtype: torch.dtype):
         """
         Test that the following produce the exact same numerics:
         1. FakeQuantizer with asymmetric per_token config
@@ -1439,17 +1463,18 @@ def test_fake_quantize_per_token_vs_convert(self):
         from torchao.quantization.utils import per_token_dynamic_quant

         torch.manual_seed(self.SEED)
-        x = torch.randn(1, 235, 2048)
+        x = torch.randn(1, 235, 2048).to(dtype)
         config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
         fake_quantizer = FakeQuantizer(config)
         fake_quantizer_out = fake_quantizer(x)
         baseline_out = per_token_dynamic_quant(x)
         torch.testing.assert_close(fake_quantizer_out, baseline_out, atol=0, rtol=0)

+    @parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_qat_8da4w_prepare_vs_convert(self):
+    def test_qat_8da4w_prepare_vs_convert(self, dtype: torch.dtype):
         """
         Test that the prepare and convert steps of Int8DynActInt4QATQuantizer produces
         numerics that match exactly over N trials.
@@ -1463,7 +1488,7 @@ def test_qat_8da4w_prepare_vs_convert(self):

         for seed in range(self.SEED, self.SEED + num_trials):
             torch.manual_seed(seed)
-            m = M4()
+            m = M4(dtype)
             torch.manual_seed(seed)
             x = m.example_inputs()

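The dtype sweep added in the hunks above relies on parameterized.expand, which generates one test method per listed value. A minimal self-contained sketch of the pattern follows; the class and test names are invented.

import unittest

import torch
from parameterized import parameterized


class ExampleDtypeTest(unittest.TestCase):
    @parameterized.expand([torch.float32, torch.bfloat16, torch.float16])
    def test_tensor_dtype(self, dtype: torch.dtype):
        # Each listed dtype becomes its own generated test case.
        x = torch.randn(2, 4).to(dtype)
        self.assertEqual(x.dtype, dtype)


if __name__ == "__main__":
    unittest.main()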