Skip to content

Commit c9417fa

Browse files
Replace global overwrite with chained pre_grad_custom_pass
1 parent 50736ef commit c9417fa

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

torchao/quantization/pt2e/quantizer/arm_inductor_quantizer.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,33 @@
4545
QuantizationAnnotation,
4646
QuantizationSpec,
4747
)
48-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_7
48+
49+
50+
def _chain_pregrad_pass(new_pass):
51+
"""
52+
Chain `new_pass` after any existing torch._inductor.config.pre_grad_custom_pass.
53+
If none exists or it's already the same callable, return `new_pass` as-is.
54+
"""
55+
prev = getattr(torch._inductor.config, "pre_grad_custom_pass", None)
56+
if prev is None or prev is new_pass:
57+
return new_pass
58+
59+
def _chained(graph_module):
60+
# Run previous pass first, then ours (order chosen to be conservative).
61+
prev(graph_module)
62+
new_pass(graph_module)
63+
64+
return _chained
65+
66+
67+
from torchao.utils import torch_version_at_least
4968

5069
from .x86_inductor_quantizer import (
5170
X86InductorQuantizer,
5271
)
5372

54-
if TORCH_VERSION_AT_LEAST_2_7:
55-
torch._inductor.config.pre_grad_custom_pass = quant_lift_up
73+
# Module-level side effect: on sufficiently new PyTorch, install the
# quant_lift_up pre-grad pass and register the weight-pack pass.
# NOTE(review): chaining (rather than overwriting) preserves any
# pre_grad_custom_pass another component installed earlier.
if torch_version_at_least("2.8.0"):
    torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(quant_lift_up)
    _register_quantization_weight_pack_pass()
5776

5877
FilterFn: TypeAlias = Callable[[List[Node]], bool]

0 commit comments

Comments
 (0)