@@ -272,6 +272,52 @@ def forward(self, x):
272272        prog = res[1]._mil_program
273273        assert get_op_types_in_program(prog) == ["constexpr_blockwise_shift_scale", "linear"]
274274
275+    @pytest.mark.skipif(not _HAS_TORCHAO, reason=MSG_TORCHAO_NOT_FOUND)
276+    @pytest.mark.parametrize(
277+        "compute_unit, has_zeros, minimum_deployment_target",
278+        itertools.product(compute_units, [True, False], [ct.target.IOS16, ct.target.IOS17]),
279+    )
280+    def test_dequantize_affine_before_ios18(self, compute_unit, has_zeros, minimum_deployment_target):
281+
282+        quant_min = -128
283+        quant_max = 127
284+
285+        n = 4
286+        k = 128
287+        input_dtype = torch.int8
288+        int_data = torch.randint(low=quant_min, high=quant_max, size=(n, k)).to(input_dtype)
289+        scale = torch.rand(n, 1)
290+
291+        zero_point = None
292+        if has_zeros:
293+            zero_point = torch.randint(low=quant_min, high=quant_max, size=(n, 1)).to(input_dtype)
294+
295+        class Model(torch.nn.Module):
296+            def __init__(self):
297+                super().__init__()
298+                self.register_buffer("int_data", int_data)
299+                self.register_buffer("scale", scale)
300+                self.register_buffer("zero_point", zero_point)
301+
302+            def forward(self, x):
303+                w = torchao_quant.dequantize_affine(self.int_data, [1, k], self.scale, self.zero_point, input_dtype, quant_min, quant_max)
304+                return torch.nn.functional.linear(x, w)
305+
306+
307+        model = Model()
308+        model = model.to(torch.device("cpu"))
309+
310+        input_shape = [(3, k)]
311+        res = self.run_compare_torch(
312+            input_shape,
313+            model,
314+            minimum_deployment_target=minimum_deployment_target,
315+            compute_unit=compute_unit,
316+            rtol=0.1,
317+            frontend=TorchFrontend.TORCHEXPORT,
318+        )
319+        prog = res[1]._mil_program
320+        assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "linear"]
275321
276322
277323 # TODO(rdar://108463675): refactor torch op tests later to parametrize quantized vs standard ops