Enables the per_tensor lowering patterns for weight per_packing #2391
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2391
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 2 New Failures: as of commit 9f01b51 with merge base 8c6d754, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from c698531 to 67d4a79.
Hi @jerryzh168, @fadara01, @Xia-Weiwen, can you please review this PR?
Thanks, can you add some tests in https://github.com/pytorch/ao/tree/main/test/quantization/pt2e?
Force-pushed from 67d4a79 to d863085.
Hi @jerryzh168,
Force-pushed from 2caf61d to e51e9ec.
Thanks for your PR!
Hi @fadara01, thanks for the response. To recreate the experiment:
Quant script:
Current setup:
Ahhh that's amazing! I remember doing a PoC for this exact thing back in the day and I had to tweak qlinear/qconv, hence my question.
Hi @jerryzh168, @fadara01, can you please approve and merge this change?
@pytorchbot rebase
1 similar comment
@pytorchbot rebase
Force-pushed from b5a6358 to ab75a9b.
Hi @jerryzh168, @fadara01, can you please approve and merge this change?
Force-pushed from ab75a9b to ad1ff8d.
Hi @jerryzh168, @fadara01, can you please approve and merge this change?
    X86InductorQuantizer,
)

if TORCH_VERSION_AT_LEAST_2_7:
This is deprecated btw, please use:
if torch_version_at_least("2.8.0"):
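For reference, a minimal sketch of the suggested gate (assuming the torch_version_at_least helper is importable from torchao.utils; quant_lift_up is the pre-grad pass from this PR's inductor passes):

import torch
from torchao.utils import torch_version_at_least  # assumed location of the helper

# quant_lift_up: pre-grad pass defined in this PR's inductor_passes module
if torch_version_at_least("2.8.0"):
    torch._inductor.config.pre_grad_custom_pass = quant_lift_up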
)

if TORCH_VERSION_AT_LEAST_2_7:
    torch._inductor.config.pre_grad_custom_pass = quant_lift_up
What happens when multiple backends set this one?
torch._inductor.config.pre_grad_custom_pass = quant_lift_up
What happens when multiple backends set this one?
torch._inductor.config.pre_grad_custom_pass = quant_lift_up
Previously, the last writer won, so we now chain instead of overwriting, and multiple backends can safely coexist.
Details:
Previously:
torch._inductor.config.pre_grad_custom_pass = quant_lift_up
A single global callable meant the last assignment silently overwrote any prior pass.
Change:
Chain ARM’s pass after any existing pass instead of overwriting it. This guarantees both passes run in a deterministic order.
Added a helper function that chains rather than overwrites:
def _chain_pregrad_pass(new_pass):
    prev = getattr(torch._inductor.config, "pre_grad_custom_pass", None)
    if prev is None or prev is new_pass:
        return new_pass

    def _chained(gm):
        # run previous first, then ours (conservative ordering)
        prev(gm)
        new_pass(gm)

    return _chained
Replacing the direct assignment with chaining:
if torch_version_at_least("2.8.0"):
    torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(quant_lift_up)
Now both passes (prev -> ARM) will execute.
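For illustration, a minimal sketch (toy pass names, reusing the _chain_pregrad_pass helper above) showing that when two backends register pre-grad passes, both get to run:

import torch

def x86_pass(gm):
    print("x86 pre-grad pass ran")

def arm_pass(gm):
    print("arm pre-grad pass ran")

torch._inductor.config.pre_grad_custom_pass = x86_pass                       # first backend registers directly
torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(arm_pass)  # second backend chains
torch._inductor.config.pre_grad_custom_pass(None)                            # runs x86_pass first, then arm_pass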
Sorry, missed this one.
Force-pushed from ad1ff8d to c9417fa.
Hi @jerryzh168, can you check this once?
from torchao.quantization.pt2e.inductor_passes.arm import (
    _register_quantization_weight_pack_pass,
)
from torchao.quantization.pt2e.inductor_passes.x86 import (
This seems to be introducing a dependency between arm and x86; is it possible to remove?
If you are really reusing this, it might be better to refactor it into a separate file and have both x86 and arm depend on it, I think.
sure
I’ve removed the ARM→x86 import and refactored quant_lift_up into a shared file (utils.py), so both backends depend on a neutral module instead of each other.
Path:
ao/torchao/quantization/pt2e/inductor_passes/utils.py
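A rough sketch of the layout being described (illustrative stubs, not the actual file contents of this PR):

# torchao/quantization/pt2e/inductor_passes/utils.py  (shared, backend-neutral)
def quant_lift_up(graph_module):
    # lift quantize nodes above data-movement ops so later fusion patterns can match
    ...

# torchao/quantization/pt2e/inductor_passes/x86.py
from torchao.quantization.pt2e.inductor_passes.utils import quant_lift_up

# torchao/quantization/pt2e/inductor_passes/arm.py
from torchao.quantization.pt2e.inductor_passes.utils import quant_lift_up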
    _register_quantization_weight_pack_pass,
)
from torchao.quantization.pt2e.inductor_passes.x86 import (
    quant_lift_up,
I thought this is a prev_pass?
In the chaining helper, prev is the existing torch._inductor.config.pre_grad_custom_pass (if any). We now set:
torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(quant_lift_up)
which composes prev (if present) with quant_lift_up (the new pass). If prev is quant_lift_up, we skip wrapping to avoid running it twice.
Force-pushed from c9417fa to 9f01b51.
This PR is an extension of PR #2139.
Major changes:
1) Introduced a lowering pattern for "per_tensor" quantized weights.
2) Modified the original API get_default_arm_inductor_quantization_config to let the user choose between "per_tensor" and "per_channel" granularity for the model's weight quantization (a hedged usage sketch is included at the end of this description).
Supported shapes:
Tested and verified for different models:
Example script for reference:
Results
All times are in seconds, taken on an AWS Graviton3E 32-core instance.
Pip list
cc: @jerryzh168, @fadara01, @Xia-Weiwen
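For reference, a hedged usage sketch of the per_tensor / per_channel weight-granularity option described in the major changes. The import paths, the set_global-style API (modeled on the X86 inductor quantizer), and the weight_granularity keyword name are assumptions for illustration, not the exact API from this PR.

import torch
from torch.export import export

# Assumed torchao pt2e entry points and ARM quantizer location (introduced in PR #2139).
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import (
    ArmInductorQuantizer,
    get_default_arm_inductor_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(128, 64)).eval()
example_inputs = (torch.randn(4, 128),)

quantizer = ArmInductorQuantizer()
quantizer.set_global(
    # "weight_granularity" is a hypothetical keyword for the per_tensor / per_channel choice.
    get_default_arm_inductor_quantization_config(weight_granularity="per_tensor")
)

exported = export(model, example_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)            # calibration
converted = convert_pt2e(prepared)
compiled = torch.compile(converted)  # inductor lowering picks up the per_tensor weight-packing patterns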