Skip to content

Commit 50736ef

Browse files
added the testcases
1 parent 6cd82d4 commit 50736ef

File tree

1 file changed

+265
-0
lines changed

1 file changed

+265
-0
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# BSD-3-Clause
3+
4+
# Owner(s): ["oncall: quantization"]
5+
6+
import functools
7+
import platform
8+
import unittest
9+
from typing import Dict
10+
11+
import torch
12+
import torch.nn as nn
13+
from torch.testing._internal.common_quantization import (
14+
NodeSpec as ns,
15+
)
16+
from torch.testing._internal.common_quantization import (
17+
QuantizationTestCase,
18+
skipIfNoInductorSupport,
19+
)
20+
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo
21+
22+
import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as armiq
23+
from torchao.quantization.pt2e.inductor_passes.arm import (
24+
_register_quantization_weight_pack_pass,
25+
)
26+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
27+
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import (
28+
ArmInductorQuantizer,
29+
)
30+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7
31+
32+
33+
# ----------------------------------------------------------------------------- #
34+
# Helper decorators #
35+
# ----------------------------------------------------------------------------- #
36+
def _is_aarch64():
    """Return True when the host identifies itself as ``aarch64``."""
    # platform.processor() is documented to return "" when the value cannot
    # be determined, so platform.machine() is consulted as well. Every host
    # that matched the processor()-only check still matches here.
    return "aarch64" in (platform.machine(), platform.processor())


def skipIfNoArm(fn):
    """Skip the decorated test function or test class unless running on aarch64.

    Args:
        fn: Either a class or a callable.

    Returns:
        For a class: the same class object, flagged with ``__unittest_skip__``
        and ``__unittest_skip_why__`` when the host is not aarch64 (checked
        once, at decoration time).
        For a callable: a wrapper that checks the host on every call.

    Raises:
        unittest.SkipTest: Raised by the wrapper when the host is not aarch64.
    """
    reason = "Quantized operations require Arm."
    if isinstance(fn, type):
        if not _is_aarch64():
            fn.__unittest_skip__ = True
            fn.__unittest_skip_why__ = reason
        return fn

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not _is_aarch64():
            raise unittest.SkipTest(reason)
        return fn(*args, **kwargs)

    return wrapper
51+
52+
53+
# ----------------------------------------------------------------------------- #
54+
# Mini-models #
55+
# ----------------------------------------------------------------------------- #
56+
class _SingleConv2d(nn.Module):
57+
def __init__(self):
58+
super().__init__()
59+
self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=1)
60+
61+
def forward(self, x):
62+
return self.conv(x)
63+
64+
65+
class _SingleLinear(nn.Module):
66+
def __init__(self, bias: bool = False):
67+
super().__init__()
68+
self.linear = nn.Linear(16, 16, bias=bias)
69+
70+
def forward(self, x):
71+
return self.linear(x)
72+
73+
74+
if TORCH_VERSION_AT_LEAST_2_5:
75+
from torch.export import export_for_training
76+
77+
78+
# ----------------------------------------------------------------------------- #
79+
# Base harness #
80+
# ----------------------------------------------------------------------------- #
81+
class _ArmInductorPerTensorTestCase(QuantizationTestCase):
    """Shared harness: export, PT2E prepare/convert, optional lowering, node checks."""

    def _test_quantizer(
        self,
        model: torch.nn.Module,
        example_inputs: tuple[torch.Tensor, ...],
        quantizer: ArmInductorQuantizer,
        expected_node_occurrence: Dict[torch._ops.OpOverload, int],
        expected_node_list=None,
        *,
        is_qat: bool = False,
        lower: bool = False,
    ):
        """Quantize ``model`` with ``quantizer`` and check the resulting graph nodes.

        Args:
            model: Module to quantize; it is switched to eval mode before export.
            example_inputs: Positional inputs used for export and for every
                forward run made by this method.
            quantizer: Quantizer handed to ``prepare_pt2e``.
            expected_node_occurrence: Maps an op overload to the number of
                ``call_function`` nodes expected for it in the final graph.
            expected_node_list: Optional sequence of ops expected in the graph;
                a falsy value is treated as an empty list.
            is_qat: NOTE(review): accepted but not read anywhere in this
                method; ``prepare_pt2e`` is used unconditionally. Confirm
                whether a QAT-specific prepare step was intended.
            lower: When true, additionally run the weight-pack registration,
                freezing passes and constant folding before checking nodes.
        """
        exported_program = export_for_training(model.eval(), example_inputs)
        gm = prepare_pt2e(exported_program.module(), quantizer)
        # Run the prepared module once on the example inputs before converting.
        gm(*example_inputs)
        gm = convert_pt2e(gm)

        if lower:
            # Register weight-pack pass (only affects per-tensor path; harmless otherwise)
            _register_quantization_weight_pack_pass(per_channel=False)
            from torch._inductor.constant_folding import constant_fold
            from torch._inductor.fx_passes.freezing_patterns import freezing_passes

            gm.recompile()
            freezing_passes(gm, example_inputs)
            constant_fold(gm)
            # Execute the lowered module once on the example inputs.
            gm(*example_inputs)

        # checkGraphModuleNodes is given NodeSpec entries, so wrap each raw op.
        occurrence_specs = {}
        for op, count in expected_node_occurrence.items():
            occurrence_specs[ns.call_function(op)] = count
        ordered_specs = list(map(ns.call_function, expected_node_list or []))

        self.checkGraphModuleNodes(
            gm,
            expected_node_occurrence=occurrence_specs,
            expected_node_list=ordered_specs,
        )
119+
120+
121+
# ----------------------------------------------------------------------------- #
122+
# Test-suite #
123+
# ----------------------------------------------------------------------------- #
124+
@skipIfNoInductorSupport
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
class TestQuantizePT2EArmInductorPerTensor(_ArmInductorPerTensorTestCase):
    """Node-count checks for ArmInductorQuantizer on single-layer models.

    Each test builds a quantizer from the default Arm Inductor config, runs
    ``_test_quantizer`` with ``lower=True`` and asserts how many
    quantize/dequantize nodes of each kind are present in the final graph.
    Every test is skipped unless ``skipIfNoArm`` allows it to run.
    """

    @staticmethod
    def _make_quantizer(**config_kwargs):
        """Return an ArmInductorQuantizer whose global config uses ``config_kwargs``.

        The keyword arguments are forwarded unchanged to
        ``armiq.get_default_arm_inductor_quantization_config``.
        """
        config = armiq.get_default_arm_inductor_quantization_config(**config_kwargs)
        return ArmInductorQuantizer().set_global(config)

    # ------------------------------------------------------------------ #
    # 1. Conv2d - per-tensor static PTQ                                   #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_conv2d_per_tensor_weight(self):
        """Conv2d with per-tensor weights: no per-channel dequantize expected."""
        example_inputs = (torch.randn(2, 3, 16, 16),)
        q = self._make_quantizer(is_per_channel=False)
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(_SingleConv2d(), example_inputs, q, expected, lower=True)

    # ------------------------------------------------------------------ #
    # 2. Linear - per-tensor static PTQ                                   #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_linear_per_tensor_weight(self):
        """Linear with per-tensor weights: no per-channel dequantize expected."""
        example_inputs = (torch.randn(4, 16),)
        q = self._make_quantizer(is_per_channel=False)
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(_SingleLinear(), example_inputs, q, expected, lower=True)

    # ------------------------------------------------------------------ #
    # 3. Linear - per-tensor **dynamic**                                  #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_linear_dynamic_per_tensor_weight(self):
        """Dynamic Linear with per-tensor weights: expects the ``.tensor`` overloads."""
        example_inputs = (torch.randn(8, 16),)
        q = self._make_quantizer(is_dynamic=True, is_per_channel=False)
        expected = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(_SingleLinear(), example_inputs, q, expected, lower=True)

    # ------------------------------------------------------------------ #
    # 4. Conv2d - **per-channel** static PTQ                              #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_conv2d_per_channel_weight(self):
        """Conv2d with per-channel weights: one per-channel dequantize expected."""
        example_inputs = (torch.randn(2, 3, 16, 16),)
        q = self._make_quantizer(is_per_channel=True)
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        self._test_quantizer(_SingleConv2d(), example_inputs, q, expected, lower=True)

    # ------------------------------------------------------------------ #
    # 5. Linear - **per-channel** static PTQ                              #
    # ------------------------------------------------------------------ #
    @skipIfNoArm
    def test_linear_per_channel_weight(self):
        """Linear with per-channel weights: one per-channel dequantize expected."""
        example_inputs = (torch.randn(4, 16),)
        q = self._make_quantizer(is_per_channel=True)
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
        }
        self._test_quantizer(_SingleLinear(), example_inputs, q, expected, lower=True)

    # ------------------------------------------------------------------ #
    # 6. Conv2d - **QAT** per-tensor                                      #
    # ------------------------------------------------------------------ #
    @skipIfTorchDynamo("slow under Dynamo")
    @skipIfNoArm
    def test_conv2d_qat_per_tensor_weight(self):
        """QAT-config Conv2d with per-tensor weights: no per-channel dequantize expected."""
        example_inputs = (torch.randn(2, 3, 16, 16),)
        # is_per_channel=False is passed explicitly, like every other
        # per-tensor test here, so the test does not depend on the config
        # function's default for that flag.
        q = self._make_quantizer(is_qat=True, is_per_channel=False)
        expected = {
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(
            _SingleConv2d(),
            example_inputs,
            q,
            expected,
            is_qat=True,
            lower=True,
        )

    # ------------------------------------------------------------------ #
    # 7. Linear - **dynamic + QAT** per-tensor                            #
    # ------------------------------------------------------------------ #
    @skipIfTorchDynamo("slow under Dynamo")
    @skipIfNoArm
    def test_linear_dynamic_qat_per_tensor_weight(self):
        """Dynamic QAT-config Linear with per-tensor weights."""
        example_inputs = (torch.randn(8, 16),)
        q = self._make_quantizer(is_dynamic=True, is_qat=True, is_per_channel=False)
        expected = {
            torch.ops.quantized_decomposed.choose_qparams.tensor: 1,
            torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
            torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
        }
        self._test_quantizer(
            _SingleLinear(),
            example_inputs,
            q,
            expected,
            is_qat=True,
            lower=True,
        )
262+
263+
264+
if __name__ == "__main__":
    # Hand control to the run_tests entry point imported from
    # torch.testing._internal.common_utils when executed as a script.
    run_tests()

0 commit comments

Comments
 (0)