
Conversation

@choudhary-devang (Collaborator) commented Jun 17, 2025

This PR is an extension of PR #2139.

Major changes:
1) Introduced a lowering pattern for "per_tensor" quantized weights.
2) Modified the original API get_default_arm_inductor_quantization_config so users can choose between "per_tensor" and "per_channel" granularity for model weight quantization.

Supported shapes:

  1. s8:s8:f32 (per_tensor / per_channel) - input: s8, weight: s8, output: f32
  2. u8:s8:f32 (per_tensor / per_channel) - input: u8, weight: s8, output: f32

Tested and verified on different models:

  • BERT
  • ResNet
  • ViT
  • Custom models

Example script for reference:

import torch
from transformers import BertModel
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, convert_pt2e
import torchao.quantization.pt2e.quantizer.arm_inductor_quantizer as aiq
from torchao.quantization.pt2e.quantizer.arm_inductor_quantizer import ArmInductorQuantizer
import torch._inductor.config as config
# Enable C++ wrapper for Inductor
config.cpp_wrapper = True
config.freezing = True

model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)

# Set the model to eval mode
model = model.eval()

# Create the data, using dummy data here as an example
traced_bs = 32
seq_length = 128
x = torch.randint(0, 10000, (traced_bs, seq_length))
attention_mask = torch.ones((traced_bs, seq_length))
example_inputs = (x, attention_mask)

# Capture the FX Graph to be quantized
with torch.no_grad():
    exported_model = torch.export.export_for_training(model, example_inputs).module()
    # Set up the quantizer and prepare the model for post-training quantization
    quantizer = ArmInductorQuantizer()
    quantizer.set_global(aiq.get_default_arm_inductor_quantization_config(is_dynamic=True, is_per_channel=True))
    prepared_model = prepare_pt2e(exported_model, quantizer)
    converted_model = convert_pt2e(prepared_model)
    converted_model = torch.compile(converted_model)
    with torch.profiler.profile(record_shapes=True) as prof:
        for _ in range(200):
            converted_model(*example_inputs)
print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_time_total"))
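
The script above uses per_channel weight granularity. To exercise the per_tensor lowering added in this PR, only the quantizer config call changes (a minimal sketch, assuming the is_per_channel flag shown above is what selects the granularity):

# Per-tensor weight quantization: same flow as above, only the config differs.
quantizer = ArmInductorQuantizer()
quantizer.set_global(
    aiq.get_default_arm_inductor_quantization_config(
        is_dynamic=True,
        is_per_channel=False,  # assumption: False selects per_tensor weight granularity
    )
)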

Results

Model  | FP32    | Quant (int8) | Speedup
-------|---------|--------------|---------
resnet | 62.967  | 44.482       | 1.415561
bert   | 103.879 | 71.953       | 1.443706
vit    | 69.031  | 59.973       | 1.151035

All times are in seconds, measured on an AWS Graviton3E 32-core instance.

Pip list

(screenshot of installed package versions omitted)

cc: @jerryzh168, @fadara01, @Xia-Weiwen

pytorch-bot bot commented Jun 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2391

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures

As of commit 9f01b51 with merge base 8c6d754:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Jun 17, 2025
@choudhary-devang (Collaborator, Author)

Hi @jerryzh168, @fadara01, @Xia-Weiwen, can you please review this PR?
Thank you.

@jerryzh168 (Contributor)

Thanks, can you add some tests in https://github.com/pytorch/ao/tree/main/test/quantization/pt2e

@jerryzh168 added the "topic: improvement" label on Jun 26, 2025
@choudhary-devang (Collaborator, Author)

Hi @jerryzh168,
I have added test cases specific to these changes. To keep them separate, they live in a new file: ao/test/quantization/pt2e/test_arm_inductor_quantizer_per_tensor.py
Can you please review this?
Thank you.

@choudhary-devang force-pushed the Per_tensor_lowering branch 2 times, most recently from 2caf61d to e51e9ec on July 20, 2025 07:42
@fadara01

Thanks for your PR!
Do we see any speedups (against fp32) for e.g. bert / resnet50 as a result of this lowering?
Do we need to do any work in PyTorch (qconv and qlinear) to support such lowerings?

@choudhary-devang (Collaborator, Author)

Thanks for your PR! Do we see any speedups (against fp32) for e.g. bert / resnet50 as a result of this lowering? Do we need to do any work in pytorch - qconv and qlinear to support such lowerings?

Hi @fadara01, thanks for the response.
I have updated the description to include some of the details; we don't need any changes in PyTorch.
For my experiments I used pip install torch torchvision.

To recreate the experiment:

FP32 script

import torch
from transformers import BertModel

# model loading
model_name = "bert-base-uncased"
model = BertModel.from_pretrained(model_name)
# Create the data, using dummy data here as an example
traced_bs = 32
seq_length = 128
x = torch.randint(0, 10000, (traced_bs, seq_length))
attention_mask = torch.ones((traced_bs, seq_length))
example_inputs = (x, attention_mask)

# Inference
with torch.no_grad():
    model = torch.compile(model)
    with torch.profiler.profile(record_shapes=True) as prof:
        for _ in range(200):
            model(*example_inputs)
print(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_time_total"))

Quant script: identical to the example script in the PR description above.

Current setup
Kernel (oneDNN verbose output):
onednn_verbose,v1,primitive,exec,cpu,matmul,lowp_gemm:acl,undef,src:s8:a:blocked:ab::f0 wei:s8::blocked:ab::f0 bia:f32:a:blocked:ab::f0_mask2 dst:f32:a:blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+wei:0:f32 attr-zero-points:src0:0:s32,,50x512:512x1000,0.224854
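
The kernel line above comes from oneDNN's verbose log. A minimal way to reproduce it is to enable oneDNN's standard verbose switch before running the quant script (the snippet below is illustrative, not part of the PR):

import os
# ONEDNN_VERBOSE=1 asks oneDNN to print one line per executed primitive;
# set it before any model execution (or export it in the shell before launching Python).
os.environ.setdefault("ONEDNN_VERBOSE", "1")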

@fadara01 commented Jul 21, 2025

Ahhh that's amazing! I remember doing a PoC for this exact thing back in the day and I had to tweak qlinear/qconv, hence my question.

@choudhary-devang (Collaborator, Author)

Hi @jerryzh168, @fadara01, can you please approve and merge this change?
Thank you.

@choudhary-devang (Collaborator, Author)

@pytorchbot rebase

@choudhary-devang (Collaborator, Author)

@pytorchbot rebase

@choudhary-devang force-pushed the Per_tensor_lowering branch 2 times, most recently from b5a6358 to ab75a9b on July 31, 2025 05:09
@choudhary-devang (Collaborator, Author)

Hi @jerryzh168, @fadara01, can you please approve and merge this change?
Thank you.

@choudhary-devang (Collaborator, Author)

Hi @jerryzh168, @fadara01, can you please approve and merge this change?
Thank you.

X86InductorQuantizer,
)

if TORCH_VERSION_AT_LEAST_2_7:
Contributor

This is deprecated btw; please use:

if torch_version_at_least("2.8.0"):

)

if TORCH_VERSION_AT_LEAST_2_7:
torch._inductor.config.pre_grad_custom_pass = quant_lift_up
Contributor

What happens when multiple backends set this one?

torch._inductor.config.pre_grad_custom_pass = quant_lift_up

Collaborator Author

What happens when multiple backends set this one?

torch._inductor.config.pre_grad_custom_pass = quant_lift_up

Previously, the last writer won. We now chain passes instead of overwriting, so multiple backends can safely coexist.

Details:
Previously:

torch._inductor.config.pre_grad_custom_pass = quant_lift_up

A single global callable meant the last assignment silently overwrote any prior pass.

Change:
Chain ARM’s pass after any existing pass, instead of overwriting it.
This guarantees both passes run in a deterministic order.

Added a helper function to chain rather than overwrite:

def _chain_pregrad_pass(new_pass):
    prev = getattr(torch._inductor.config, "pre_grad_custom_pass", None)
    if prev is None or prev is new_pass:
        return new_pass
    def _chained(gm):
        # run previous first, then ours (conservative ordering)
        prev(gm)
        new_pass(gm)
    return _chained

Replacing the direct assignment with chaining:

if torch_version_at_least("2.8.0"):
    torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(quant_lift_up)

Now both passes (prev -> ARM) will execute.
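
A quick illustration of the chaining behaviour with dummy passes (purely illustrative; it reuses the _chain_pregrad_pass helper defined above):

import torch
import torch._inductor.config  # make sure the inductor config module is loaded

def x86_pass(gm):
    print("x86 pre-grad pass")

def arm_pass(gm):
    print("arm pre-grad pass")

# First backend registers its pass (pre_grad_custom_pass defaults to None, so the pass is used directly).
torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(x86_pass)
# Second backend chains onto the existing pass instead of overwriting it.
torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(arm_pass)

# Running the combined pass executes both, previous first, then ARM's:
torch._inductor.config.pre_grad_custom_pass(None)  # prints "x86 pre-grad pass", then "arm pre-grad pass"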

@jerryzh168 (Contributor) left a comment

Sorry, missed this one.

@choudhary-devang (Collaborator, Author)

Hi @jerryzh168, can you check this once:
#2391 (comment)
If everything looks okay to you, please merge this change.

from torchao.quantization.pt2e.inductor_passes.arm import (
_register_quantization_weight_pack_pass,
)
from torchao.quantization.pt2e.inductor_passes.x86 import (
Contributor

This seems to introduce a dependency between arm and x86; is it possible to remove it?

Contributor

If you are really reusing this, it might be better to refactor it into a separate file and have both x86 and arm depend on it, I think.

Collaborator Author

Sure.

Collaborator Author

I’ve removed the ARM→x86 import and refactored quant_lift_up into a shared file (utils.py) so both backends depend on a neutral module instead of each other.
Path:
ao/torchao/quantization/pt2e/inductor_passes/utils.py
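
A minimal sketch of what that looks like from the backend side (the import line below is illustrative of the described layout; the actual module contents are not shown in this thread):

# In both arm.py and x86.py, the shared pass is now imported from the neutral module
# instead of one backend importing it from the other:
from torchao.quantization.pt2e.inductor_passes.utils import quant_lift_up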

_register_quantization_weight_pack_pass,
)
from torchao.quantization.pt2e.inductor_passes.x86 import (
quant_lift_up,
Contributor

I thought this was a prev_pass?

Collaborator Author

In the chaining helper, prev is the existing torch._inductor.config.pre_grad_custom_pass (if any). We now set:

torch._inductor.config.pre_grad_custom_pass = _chain_pregrad_pass(quant_lift_up)

which composes prev (if present) with quant_lift_up (new). If prev is quant_lift_up, we skip wrapping to avoid running it twice.
