Commit fb4eb52

Fix dequantize_affine before iOS18
1 parent ea1d2de · commit fb4eb52

2 files changed: +46 -1

coremltools/converters/mil/frontend/torch/quantization_ops.py

Lines changed: 0 additions & 1 deletion

@@ -803,7 +803,6 @@ def dequantize_affine(context, node):
         int_data.astype(quantized_np_dtype),
         zero_point,
         scale,
-        axis=-1,
         name=node.name,
     )
     context.add(output, node.name)
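
Why the change: before iOS18 this lowering goes through a constexpr dequantize helper, and the hardcoded axis=-1 pinned the scale and zero point to the last axis even when the scales are per output channel (axis 0); dropping the argument presumably lets the helper infer the axis from the scale's shape. A minimal NumPy sketch (illustrative only, not coremltools code; names are hypothetical) of what an axis-scoped affine dequantize computes:

import numpy as np

def affine_dequantize(q, zero_point, scale, axis):
    # dequant = (q - zero_point) * scale, with scale and zero_point
    # laid out along `axis` and broadcast across the other dims.
    shape = [1] * q.ndim
    shape[axis] = -1
    zp = np.reshape(zero_point, shape).astype(np.float32)
    s = np.reshape(scale, shape).astype(np.float32)
    return (q.astype(np.float32) - zp) * s

q = np.random.randint(-128, 128, size=(4, 128)).astype(np.int8)
scale = np.random.rand(4).astype(np.float32)  # one scale per output channel -> axis 0
zero_point = np.zeros(4, dtype=np.int8)

w = affine_dequantize(q, zero_point, scale, axis=0)  # scale reshapes to (4, 1) and broadcasts
# With the old hardcoded axis=-1, the 4 scales would be reshaped to (1, 4),
# which cannot broadcast against the (4, 128) quantized weight.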

coremltools/converters/mil/frontend/torch/test/test_torch_quantization_ops.py

Lines changed: 46 additions & 0 deletions

@@ -272,6 +272,52 @@ def forward(self, x):
         prog = res[1]._mil_program
         assert get_op_types_in_program(prog) == ["constexpr_blockwise_shift_scale", "linear"]
 
+    @pytest.mark.skipif(not _HAS_TORCHAO, reason=MSG_TORCHAO_NOT_FOUND)
+    @pytest.mark.parametrize(
+        "compute_unit, has_zeros, minimum_deployment_target",
+        itertools.product(compute_units, [True, False], [ct.target.IOS16, ct.target.IOS17]),
+    )
+    def test_dequantize_affine_before_ios18(self, compute_unit, has_zeros, minimum_deployment_target):
+
+        quant_min = -128
+        quant_max = 127
+
+        n = 4
+        k = 128
+        input_dtype = torch.int8
+        int_data = torch.randint(low=quant_min, high=quant_max, size=(n, k)).to(input_dtype)
+        scale = torch.rand(n, 1)
+
+        zero_point = None
+        if has_zeros:
+            zero_point = torch.randint(low=quant_min, high=quant_max, size=(n, 1)).to(input_dtype)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("int_data", int_data)
+                self.register_buffer("scale", scale)
+                self.register_buffer("zero_point", zero_point)
+
+            def forward(self, x):
+                w = torchao_quant.dequantize_affine(self.int_data, [1, k], self.scale, self.zero_point, input_dtype, quant_min, quant_max)
+                return torch.nn.functional.linear(x, w)
+
+
+        model = Model()
+        model = model.to(torch.device("cpu"))
+
+        input_shape = [(3, k)]
+        res = self.run_compare_torch(
+            input_shape,
+            model,
+            minimum_deployment_target=minimum_deployment_target,
+            compute_unit=compute_unit,
+            rtol=0.1,
+            frontend=TorchFrontend.TORCHEXPORT,
+        )
+        prog = res[1]._mil_program
+        assert get_op_types_in_program(prog) == ["constexpr_affine_dequantize", "linear"]
 
 
     # TODO(rdar://108463675): refactor torch op tests later to parametrize quantized vs standard ops
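
The new test pins down the pre-iOS18 behavior: with minimum_deployment_target at IOS16 or IOS17, torchao's dequantize_affine should lower to the iOS16 constexpr_affine_dequantize MIL op, whereas the iOS18+ path covered by the existing test above lowers to constexpr_blockwise_shift_scale. A hedged end-to-end sketch of the same conversion outside the test harness, assuming the Model class from the diff and torch.export support in ct.convert:

import coremltools as ct
import torch

model = Model().eval()  # the Model defined in the test above
exported = torch.export.export(model, (torch.rand(3, 128),))

# Converting with a pre-iOS18 target should yield a MIL program whose ops are
# ["constexpr_affine_dequantize", "linear"], matching the test's assertion.
mlmodel = ct.convert(exported, minimum_deployment_target=ct.target.IOS17)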
