
Commit ddce558

[Quant][PT2E][X86] Enable annotation of aten.mul.Tensor with X86InductorQuantizer (#2075)

1 parent: 31d17c0

2 files changed: +117 −0 lines changed
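
For context, here is a minimal usage sketch of the new capability. Only the quantizer setup mirrors the test added below; the module name (Mul) and the export/prepare/convert flow (torch.export.export_for_training, plus prepare_pt2e/convert_pt2e from torch.ao.quantization.quantize_pt2e) are assumptions about the standard PT2E workflow, not part of this commit.

# Hedged sketch, not part of this commit: quantize a standalone tensor-tensor
# multiply with X86InductorQuantizer. The export/prepare/convert calls are
# assumed standard PT2E APIs; the quantizer calls mirror the new test.
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

import torchao.quantization.pt2e.quantizer.x86_inductor_quantizer as xiq
from torchao.quantization.pt2e.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)


class Mul(torch.nn.Module):  # hypothetical example module
    def forward(self, x, y):
        return x * y  # traces to torch.ops.aten.mul.Tensor


example_inputs = (torch.randn(64, 64), torch.randn(64, 64))

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
# Opt in for torch.mul; _map_module_function_to_aten_operator_type
# (extended in this commit) maps it to torch.ops.aten.mul.Tensor.
quantizer.set_function_type_qconfig(
    torch.mul, quantizer.get_global_quantization_config()
)

m = torch.export.export_for_training(Mul().eval(), example_inputs).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate
m = convert_pt2e(m)  # the mul is now wrapped in quantize/dequantize pairs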

test/quantization/pt2e/test_x86inductor_quantizer.py

Lines changed: 67 additions & 0 deletions
@@ -2876,6 +2876,73 @@ def test_lowering_to_x86(self):
             lower=True,
         )
 
+    @skipIfNoX86
+    def test_annotate_mul_tensor(self):
+        class M1(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * y
+
+        class M2(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * y.sum(-1)
+
+        class M3(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * y.sum()
+
+        class M4(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return x * y.sum().item()
+
+        for Mod in [M1, M2, M3, M4]:
+            with override_quantized_engine("x86"), torch.no_grad():
+                m = Mod().eval()
+                example_inputs = (torch.randn(64, 64), torch.randn(64, 64))
+                quantizer = X86InductorQuantizer().set_global(
+                    xiq.get_default_x86_inductor_quantization_config()
+                )
+                quantizer.set_function_type_qconfig(
+                    torch.mul, quantizer.get_global_quantization_config()
+                )
+                node_occurrence = {
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default: 2
+                    if isinstance(m, M1)
+                    else 0,
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2
+                    if isinstance(m, M1)
+                    else 0,
+                    torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default: 0,
+                }
+                node_list = [
+                    torch.ops.aten.mul.Tensor,
+                ]
+                if isinstance(m, M1):
+                    node_list = [
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+                    ] + node_list
+
+                self._test_quantizer(
+                    m,
+                    example_inputs,
+                    quantizer,
+                    node_occurrence,
+                    node_list,
+                )
+
 
 if __name__ == "__main__":
     run_tests()
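
Why only M1 expects quantize/dequantize nodes: it is the only case that multiplies two tensors of identical shape. In M2 the right operand reduces to shape (64,), in M3 to a 0-dim tensor, and in M4 .item() yields a Python scalar rather than a tensor, so the tensor-and-same-shape checks in the new _annotate_mul_tensor recipe (second file below) skip those muls and they run in fp32.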

torchao/quantization/pt2e/quantizer/x86_inductor_quantizer.py

Lines changed: 50 additions & 0 deletions
@@ -93,6 +93,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation):
     torch.ops.aten.conv1d.default,
     torch.ops.aten.conv2d.default,
     torch.ops.aten.linear.default,
+    torch.ops.aten.mul.Tensor,
 }
 
 # A superset of default_quantizable_ops includes operators support the int8 data type
@@ -219,6 +220,12 @@ def _map_module_function_to_aten_operator_type():
             ],
             torch.ops.aten.matmul.default,
         ),
+        (
+            [
+                torch.mul,
+            ],
+            torch.ops.aten.mul.Tensor,
+        ),
     )
     for map_item in map_list:
         module_function_to_aten_operator.update(dict.fromkeys(map_item[0], map_item[1]))  # type: ignore[arg-type, call-overload]
@@ -735,6 +742,7 @@ def _annotate_with_config(
         self._annotate_conv2d_fusion_pattern(model, quantization_config, filter_fn)
         self._annotate_linear_fusion_pattern(model, quantization_config, filter_fn)
         self._annotate_matmul(model, quantization_config, filter_fn)
+        self._annotate_mul_tensor(model, quantization_config, filter_fn)
 
         # Step2: Recipe to propagate annotation for patterns beside conv/linear.
         # Go through all the nodes from start to end.
@@ -1577,5 +1585,47 @@ def _annotate_linear_binary_unary(
                 )
             )
 
+    def _annotate_mul_tensor(
+        self,
+        model: torch.fx.GraphModule,
+        quantization_config: Optional[QuantizationConfig],
+        filter_fn: Optional[FilterFn] = None,
+    ):
+        def _is_tensor(n: Node):
+            return isinstance(n, Node) and isinstance(
+                n.meta["val"], torch._subclasses.fake_tensor.FakeTensor
+            )
+
+        def _same_shape(n1: Node, n2: Node):
+            return n1.meta["val"].shape == n2.meta["val"].shape
+
+        for node in model.graph.nodes:
+            if node.target != torch.ops.aten.mul.Tensor:
+                continue
+
+            if _skip_annotate([node], filter_fn):
+                continue
+
+            if quantization_config is None:
+                _annotate_nodes_not_quantize(node)
+                continue
+
+            assert len(node.args) == 2
+            if not (_is_tensor(node.args[0]) and _is_tensor(node.args[1])):
+                continue
+
+            if not _same_shape(node.args[0], node.args[1]):
+                continue
+
+            input_qspec_map = {}
+            mul_node = node
+            for input_node in mul_node.args:
+                input_qspec_map[input_node] = get_input_act_qspec(quantization_config)
+            mul_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
+                input_qspec_map=input_qspec_map,
+                _annotated=True,
+                _is_output_of_quantized_pattern=True,
+            )
+
     def validate(self, model: torch.fx.GraphModule) -> None:
         pass
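
Design note: the recipe only annotates a mul whose two arguments both carry FakeTensor metadata with identical shapes, which keeps broadcasted and scalar multiplications out of int8. When the quantization config is None, the node is explicitly marked not-to-quantize via _annotate_nodes_not_quantize, presumably so that later annotation-propagation passes leave it untouched.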
