[Inductor][float8] Support qlinear for float8 in inductor #2565
Conversation
Add fp8 dequant promotion
Dr. CI: ✅ No failures as of commit 4fb5f7a with merge base 8e2ca35. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2565. (Note: links to docs will display an error until the docs builds have completed. There is 1 currently active SEV; if your PR is affected, please view it on the HUD.)
@jerryzh168 Could you help review this PR?
@shiyang-weng the registration PR was reverted in #2672; we'd need to land it again without breaking BC.
Hi @jerryzh168, could you please provide a reproducer so that we can fix that? Thanks.
Yeah, this is the test: ao/test/dtypes/test_affine_quantized_float.py, line 735 (at 418593c).
Pull Request Overview
This PR adds float8_e4m3fn support to PyTorch Inductor for qlinear operations, implementing quantization patterns specifically for FP8 data types. The implementation handles differences in FP8 quantization API requirements, including tensor-based scales and modified quantize/dequantize operations.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| torchao/quantization/pt2e/inductor_passes/x86.py | Adds FP8 quantization support with new patterns, updates existing functions to handle FP8 operations, and modifies view operation handling |
| test/quantization/pt2e/test_x86inductor_fusion.py | Adds comprehensive test coverage for FP8 quantization patterns and refactors test helpers to support FP8 |
x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None
w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None
Copilot AI · Sep 23, 2025
[nitpick] The extraction of qparams has inconsistent patterns. The first two use tuple unpacking while x_zp and w_zp use conditional extraction. For better maintainability and consistency, consider using the same pattern for all parameters.
Suggested change:
- x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None
- w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None
+ x_zp = kwargs.get("x_zp")
+ w_zp = kwargs.get("w_zp")
    is_tensor_overload,
    is_fp8,
) in linear_weight_prepack_cases:
    if is_fp8 and not is_tensor_overload:
Copilot AI · Sep 23, 2025
[nitpick] This skip condition appears in multiple places (lines 1429 and 1506). Consider extracting this logic into a helper function or constant to avoid code duplication and improve maintainability.
Suggested change:
- if is_fp8 and not is_tensor_overload:
+ if _should_skip_fp8_case(is_fp8, is_tensor_overload):
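One possible definition of the suggested helper (the name comes from Copilot's suggestion above; the body is an assumption based on the existing condition):

```python
def _should_skip_fp8_case(is_fp8: bool, is_tensor_overload: bool) -> bool:
    # Assumed helper: FP8 patterns are only generated for the tensor-overload
    # quant/dequant ops, so FP8 cases without the tensor overload are skipped.
    return is_fp8 and not is_tensor_overload
```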
if output_dtype == torch.float8_e4m3fn:
    # For float8, torchao.quantize_affine_float8 requires the scale to be a tensor.
    # Only the case where the scale node is aten.full is supported for now.
    assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default
Copilot AI · Sep 23, 2025
The assertion assumes kwargs["o_inv_scale"] is always a node object, but there's no validation that it has a target attribute. This could cause an AttributeError if the object doesn't have this attribute.
Suggested change:
- assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default
+ assert hasattr(kwargs["o_inv_scale"], "target") and kwargs["o_inv_scale"].target is torch.ops.aten.full.default, (
+     "Expected kwargs['o_inv_scale'] to be a node object with 'target' attribute set to torch.ops.aten.full.default"
+ )
# Check if the scale was created by torch.tensor
return (
    len(node.all_input_nodes) == 2
    and node.all_input_nodes[1].target == torch.tensor
Copilot AI · Sep 23, 2025
[nitpick] Using torch.tensor as a target comparison might be fragile since it's comparing against a function object. Consider using a more robust method to identify tensor creation nodes, such as checking the function name or using a more specific target.
Suggested change:
- and node.all_input_nodes[1].target == torch.tensor
+ and torch.fx.node._qualified_name(node.all_input_nodes[1].target) == "torch.tensor"
class FP8QDQLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, has_bias):
        super().__init__()
        self.qtype = torch.float8_e4m3fn
        self.weight = torch.randn((out_features, in_features)).to(self.qtype)
        self.weight_scale = 2.0
        self.scale = 2.0
        self.bias = None
        if has_bias:
            self.bias = torch.randn((out_features,))
Copilot AI · Sep 23, 2025
[nitpick] The hardcoded scale values (2.0) should be configurable parameters or documented constants to improve maintainability and make the test more flexible.
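A minimal sketch of how the scales could be made configurable (parameter names here are illustrative, not from the PR; the defaults keep the current 2.0 values):

```python
import torch

class FP8QDQLinear(torch.nn.Module):
    # Hypothetical variant of the test helper with configurable scale values
    def __init__(self, in_features, out_features, has_bias,
                 weight_scale: float = 2.0, act_scale: float = 2.0):
        super().__init__()
        self.qtype = torch.float8_e4m3fn
        self.weight = torch.randn((out_features, in_features)).to(self.qtype)
        self.weight_scale = weight_scale
        self.scale = act_scale
        self.bias = torch.randn((out_features,)) if has_bias else None
```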
if is_fp8:
    # fp8_convert_ does not support dynamic quantization or QAT yet
    assert not is_dynamic
    assert not is_qat
Copilot AI · Sep 23, 2025
[nitpick] This assertion pattern appears multiple times in the test file (lines 206-208 and 1954-1957). Consider extracting this validation into a helper function to reduce code duplication.
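A possible shape for such a helper (name and exact signature are assumptions, not from the PR):

```python
def _assert_fp8_static_ptq_only(is_fp8: bool, is_dynamic: bool, is_qat: bool) -> None:
    # Assumed helper: FP8 conversion currently covers only static PTQ,
    # so dynamic quantization and QAT are rejected for FP8 cases.
    if is_fp8:
        assert not is_dynamic, "FP8 does not support dynamic quantization yet"
        assert not is_qat, "FP8 does not support QAT yet"
```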
This PR is used to support FP8 in PyTorch.
CC @mingfeima for review.
Do you need to change the quant flow code to produce this op? I'd recommend doing this by defining a new observer, using this API: ao/test/quantization/pt2e/test_quantize_pt2e.py, line 2315 (at c96f2dd), def test_observer_callback(self).
- quantize_affine_float8/dequantize_affine_float8 not decomposed on inductor
- remove redundant unittest.skipIf
- fix rebase issue
- change dispatch key to a flag decomposed
- support scaled_mm on inductor
- fix rebase issue
- support dequant promtion for fp8
- add ut
- remove redundant codes
- fix lint
- resolve conflict
- change to use qlinear
- add ut
- fix lint
- support fp8 quant_lift_up
- add reshape into _VIEW_METHOD_OPS
- add quant_input_check
- fix lint
- refine ut
- remove fp8 dynamic quant ut
- fix output_scale issue
- add float8_e4m3fn to dtype_list
- refine code
- refine code
- fix bugs
- add comment
- merge main
- change to use non-decomposed q/dq
- fix lint
- add version check
- change version
- fix attention bug; update ut
- add liftup oplist
For float8_e4m3fn, this PR adds qlinear support on inductor.
For FP8 there are the following issues: the quantize/dequantize API differs from the int8 path, in particular quantize_affine_float8/dequantize_affine_float8 require tensor-valued scales, so the existing patterns cannot be reused as-is.
Based on these issues, the PR adds new FP8 patterns, updates the existing pass functions to handle FP8 operations, and extends the tests to cover FP8.
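For context, here is a minimal sketch of the FP8 dequantize → linear → quantize pattern that the qlinear fusion targets (helper names and shapes are illustrative assumptions, not the PR's code):

```python
import torch

def fp8_qdq_linear_pattern(x_fp8, x_scale, w_fp8, w_scale, bias, out_scale):
    # Dequantize the FP8 activation and weight, run linear, then re-quantize the
    # output to FP8; this is the kind of graph the inductor pass fuses into a
    # single qlinear call.
    x = x_fp8.to(torch.float32) * x_scale
    w = w_fp8.to(torch.float32) * w_scale
    y = torch.nn.functional.linear(x, w, bias)
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.clamp(y / out_scale, finfo.min, finfo.max).to(torch.float8_e4m3fn)

# Example usage with tensor-valued scales (required for the FP8 path):
x = torch.randn(4, 16).to(torch.float8_e4m3fn)
w = torch.randn(8, 16).to(torch.float8_e4m3fn)
b = torch.randn(8)
y = fp8_qdq_linear_pattern(x, torch.tensor(2.0), w, torch.tensor(2.0), b, torch.tensor(2.0))
```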