Commit f9c071d

[Quant][Inductor] Bug fix: mutation nodes not handled correctly for QLinearPointwiseBinaryPT2E (pytorch#127592)
Fixes pytorch#127402

- Revert some changes to `ir.MutationOutput` and inductor/test_flex_attention.py
- Add checks of mutation for QLinearPointwiseBinaryPT2E

Pull Request resolved: pytorch#127592
Approved by: https://github.com/leslie-fang-intel, https://github.com/Chillee
1 parent 1cd4199 commit f9c071d

3 files changed (+40, -10 lines)


test/inductor/test_flex_attention.py

Lines changed: 6 additions & 4 deletions
@@ -776,11 +776,13 @@ def f(q, k, v):
         metrics.reset()
         f(q, k, v)
         accessed_bytes = 1 * 8 * 1024 * 64 * torch.float32.itemsize
-        logsumexp_bytes = 1 * 8 * 1024 * torch.float32.itemsize
         num_accesses = 4  # q, k, v reads, one output.
-        self.assertEqual(
-            metrics.num_bytes_accessed, accessed_bytes * num_accesses + logsumexp_bytes
-        )
+        # TODO: Get rid of this fudge factor
+        # We need this fudge factor for now, since
+        # 1. For some reason we materialize the output of the attention unnecessarily (it's related to the mutation somehow)
+        # 2. We also write the extraneous logsumexp
+        num_accesses += 2
+        self.assertLess(metrics.num_bytes_accessed, accessed_bytes * num_accesses)
 
     @supported_platform
     @skip("Triton bug ")  # https://github.com/pytorch/pytorch/issues/124571

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 33 additions & 0 deletions
@@ -233,6 +233,7 @@ def _test_code_common(
         rtol=1.3e-6,
         check_quantization=False,
         check_dynamic=None,
+        num_include_ops=None,
     ):
         with torch.no_grad():
             clone_inputs = self._clone_inputs(inputs)
@@ -245,6 +246,12 @@ def _test_code_common(
             )
             for op in include_ops:
                 self.assertIn(op, source_code)
+            if num_include_ops is not None:
+                assert len(include_ops) == len(num_include_ops)
+                for i in range(len(include_ops)):
+                    self.assertEqual(
+                        source_code.count(include_ops[i]), num_include_ops[i]
+                    )
             for op in exclude_ops:
                 self.assertNotIn(op, source_code)
             if check_dynamic is not None:
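For context, a minimal standalone sketch of the exact-count check that `num_include_ops` adds; the generated-source string and values below are made up for illustration:

# Hypothetical illustration of the new num_include_ops check in _test_code_common.
source_code = "op_qlinear_pointwise.call(...); op_qlinear_pointwise.call(...)"
include_ops = ["op_qlinear_pointwise.call"]
num_include_ops = [2]

assert len(include_ops) == len(num_include_ops)
for op, expected in zip(include_ops, num_include_ops):
    assert op in source_code                  # existing membership check
    assert source_code.count(op) == expected  # new: op must appear exactly `expected` times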
@@ -1808,6 +1815,32 @@ def matcher_check_fn():
             matcher_check_fn=matcher_check_fn,
             is_qat=is_qat,
         )
+        if torch._inductor.config.cpp_wrapper:
+            # For CPP wrapper
+            self._test_code_common(
+                mod,
+                (v,),
+                [
+                    "op_qlinear_pointwise.call",
+                    "op_qlinear_pointwise_binary.call",
+                ],
+                [],
+                check_quantization=True,
+                num_include_ops=[2, 2],
+            )
+        else:
+            # For python wrapper
+            self._test_code_common(
+                mod,
+                (v,),
+                [
+                    "torch.ops.onednn.qlinear_pointwise.default",
+                    "torch.ops.onednn.qlinear_pointwise.binary",
+                ],
+                [],
+                check_quantization=True,
+                num_include_ops=[2, 2],
+            )
 
     @skipIfNoDynamoSupport
     @skipIfNoONEDNN

torch/_inductor/ir.py

Lines changed: 1 addition & 6 deletions
@@ -4775,15 +4775,10 @@ def get_mutation_names(self):
 
     def __init__(self, layout, mutated_node, node_doing_mutating):
         # NB: Do not directly construct this - use `mark_node_as_mutating`
-        super().__init__(None, layout, [mutated_node], ())
+        super().__init__(None, layout, [mutated_node, node_doing_mutating], ())
         self.node_doing_mutating = node_doing_mutating
         self.name = V.graph.register_buffer(self)
 
-    def get_read_writes(self):
-        read_writes = super().get_read_writes()
-        read_writes.reads.add(dependencies.WeakDep(self.node_doing_mutating.get_name()))
-        return read_writes
-
     def should_allocate(self):
         return False
