
Conversation

shiyang-weng
Contributor

@shiyang-weng shiyang-weng commented Jul 17, 2025

For float8_e4m3fn, support

register_qlinear_weight_prepack
_register_qlinear_unary_fusion
_register_qlinear_binary_fusion
quant_lift_up

in Inductor.

For FP8, there are the following issues:

  1. q/dq switches to quantize_affine_float8/dequantize_affine_float8.
  2. The q/dq API changes: the fp8 q/dq ops require the scale to be a tensor.
  3. pt2e does not support float8 yet.

Based on these issues:

  1. The fp8 q/dq pattern needs to be matched separately.
  2. The scale needs to be handled separately.
  3. We implement a function (fp8_convert_) that adds q/dq before the linear layers in the model, and add it to test/quantization/pt2e/test_x86inductor_fusion.py. The resulting q/dq-linear pattern is sketched below.
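For reference, a minimal sketch of the q/dq-linear pattern the fusion passes are meant to match. It uses plain tensor ops rather than the real quantize_affine_float8/dequantize_affine_float8 primitives, and the names fp8_quant, fp8_dequant, and QDQLinear are illustrative only; the tensor-valued scale mirrors the requirement described above.

import torch
import torch.nn.functional as F

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def fp8_quant(t, scale):
    # Divide by the tensor-valued scale, clamp to the e4m3fn range, cast to fp8.
    return (t / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)

def fp8_dequant(t, scale):
    # Cast back to float and multiply by the same tensor-valued scale.
    return t.to(torch.float32) * scale

class QDQLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.x_scale = torch.tensor(2.0)  # activation scale as a tensor
        self.w_scale = torch.tensor(2.0)  # weight scale as a tensor
        self.weight = fp8_quant(torch.randn(out_features, in_features), self.w_scale)

    def forward(self, x):
        # q/dq around the activation and dq on the fp8 weight, followed by a plain
        # linear; this is the pattern the Inductor passes are meant to rewrite into
        # a fused qlinear op.
        x = fp8_dequant(fp8_quant(x, self.x_scale), self.x_scale)
        w = fp8_dequant(self.weight, self.w_scale)
        return F.linear(x, w)

out = QDQLinear(16, 8)(torch.randn(4, 16))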


pytorch-bot bot commented Jul 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2565

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 4fb5f7a with merge base 8e2ca35:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@shiyang-weng shiyang-weng marked this pull request as draft July 17, 2025 02:59
@meta-cla meta-cla bot added the CLA Signed label Jul 17, 2025
@shiyang-weng shiyang-weng marked this pull request as ready for review August 1, 2025 01:29
@shiyang-weng
Contributor Author

@jerryzh168 Could you help review this PR?

@shiyang-weng
Contributor Author

@jerryzh168 Could you help review this PR?

@jerryzh168
Contributor

jerryzh168 commented Aug 5, 2025

@shiyang-weng the registration PR was reverted in #2672; we'd need to land that again without breaking BC.

@Xia-Weiwen
Collaborator

@shiyang-weng the registration PR was reverted in #2672; we'd need to land that again without breaking BC.

Hi @jerryzh168, could you please provide a reproducer so that we can fix that? Thanks.

@jerryzh168
Contributor

Yeah, this is the test:

def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):

It was added in that PR; it needs an H100 GPU to run.

@Xia-Weiwen Xia-Weiwen requested a review from Copilot September 23, 2025 03:05
Contributor

@Copilot Copilot AI left a comment


Pull Request Overview

This PR adds float8_e4m3fn support to PyTorch Inductor for qlinear operations, implementing quantization patterns specifically for FP8 data types. The implementation handles differences in FP8 quantization API requirements, including tensor-based scales and modified quantize/dequantize operations.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

File: torchao/quantization/pt2e/inductor_passes/x86.py
Adds FP8 quantization support with new patterns, updates existing functions to handle FP8 operations, and modifies view-operation handling.

File: test/quantization/pt2e/test_x86inductor_fusion.py
Adds comprehensive test coverage for FP8 quantization patterns and refactors test helpers to support FP8.


Comment on lines +1107 to +1108
x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None
w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None

Copilot AI Sep 23, 2025


[nitpick] The extraction of qparams has inconsistent patterns. The first two use tuple unpacking while x_zp and w_zp use conditional extraction. For better maintainability and consistency, consider using the same pattern for all parameters.

Suggested change
- x_zp = kwargs["x_zp"] if "x_zp" in kwargs else None
- w_zp = kwargs["w_zp"] if "w_zp" in kwargs else None
+ x_zp = kwargs.get("x_zp")
+ w_zp = kwargs.get("w_zp")


is_tensor_overload,
is_fp8,
) in linear_weight_prepack_cases:
if is_fp8 and not is_tensor_overload:

Copilot AI Sep 23, 2025


[nitpick] This skip condition appears in multiple places (lines 1429 and 1506). Consider extracting this logic into a helper function or constant to avoid code duplication and improve maintainability.

Suggested change
- if is_fp8 and not is_tensor_overload:
+ if _should_skip_fp8_case(is_fp8, is_tensor_overload):

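For completeness, a sketch of what such a helper could look like; the name _should_skip_fp8_case comes from the suggestion above and is hypothetical.

def _should_skip_fp8_case(is_fp8, is_tensor_overload):
    # fp8 q/dq patterns are only registered for the tensor-overload variants,
    # so any fp8 case without a tensor overload is skipped.
    return is_fp8 and not is_tensor_overload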

if output_dtype == torch.float8_e4m3fn:
# For float8, torchao.quantize_affine_float8 requires tensor as scale
# Support scale node is full firstly
assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default

Copilot AI Sep 23, 2025


The assertion assumes kwargs["o_inv_scale"] is always a node object, but there's no validation that it has a target attribute. This could cause an AttributeError if the object doesn't have this attribute.

Suggested change
- assert kwargs["o_inv_scale"].target is torch.ops.aten.full.default
+ assert hasattr(kwargs["o_inv_scale"], "target") and kwargs["o_inv_scale"].target is torch.ops.aten.full.default, (
+     "Expected kwargs['o_inv_scale'] to be a node object with 'target' attribute set to torch.ops.aten.full.default"
+ )


# check if scale created by torch.tensor
return (
len(node.all_input_nodes) == 2
and node.all_input_nodes[1].target == torch.tensor

Copilot AI Sep 23, 2025


[nitpick] Using torch.tensor as a target comparison might be fragile since it's comparing against a function object. Consider using a more robust method to identify tensor creation nodes, such as checking the function name or using a more specific target.

Suggested change
- and node.all_input_nodes[1].target == torch.tensor
+ and torch.fx.node._qualified_name(node.all_input_nodes[1].target) == "torch.tensor"


Comment on lines +104 to +113
class FP8QDQLinear(torch.nn.Module):
def __init__(self, in_features, out_features, has_bias):
super().__init__()
self.qtype = torch.float8_e4m3fn
self.weight = torch.randn((out_features, in_features)).to(self.qtype)
self.weight_scale = 2.0
self.scale = 2.0
self.bias = None
if has_bias:
self.bias = torch.randn((out_features,))

Copilot AI Sep 23, 2025


[nitpick] The hardcoded scale values (2.0) should be configurable parameters or documented constants to improve maintainability and make the test more flexible.

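One way to address this, sketched under the assumption that callers keep using 2.0 as the default: make the scales constructor arguments with a documented default. DEFAULT_FP8_SCALE is a hypothetical name.

import torch

DEFAULT_FP8_SCALE = 2.0  # documented default used by the existing tests

class FP8QDQLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, has_bias,
                 scale=DEFAULT_FP8_SCALE, weight_scale=DEFAULT_FP8_SCALE):
        super().__init__()
        self.qtype = torch.float8_e4m3fn
        self.weight = torch.randn((out_features, in_features)).to(self.qtype)
        self.weight_scale = weight_scale
        self.scale = scale
        self.bias = torch.randn((out_features,)) if has_bias else None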

Comment on lines +1954 to +1957
if is_fp8:
# fp8_convert_ not support dynamic and qat yet
assert not is_dynamic
assert not is_qat

Copilot AI Sep 23, 2025


[nitpick] This assertion pattern appears multiple times in the test file (lines 206-208 and 1954-1957). Consider extracting this validation into a helper function to reduce code duplication.

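A sketch of the suggested extraction; the helper name is hypothetical.

def _assert_fp8_static_ptq_only(is_fp8, is_dynamic, is_qat):
    # fp8_convert_ currently covers static post-training quantization only.
    if is_fp8:
        assert not is_dynamic, "fp8_convert_ does not support dynamic quantization yet"
        assert not is_qat, "fp8_convert_ does not support QAT yet"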

@shiyang-weng
Contributor Author

This PR is used to support fp8 on PT, but it is not in PT 2.8, so I added a version check to the UT (a sketch of such a guard is shown below).
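A minimal sketch of such a guard, assuming the required fp8 q/dq support only exists in torch builds newer than 2.8; the class name and exact version cutoff are assumptions.

import unittest
import torch
from packaging import version

# Skip FP8 fusion tests on torch builds that predate the fp8 q/dq support.
_TORCH_SUPPORTS_FP8_QDQ = version.parse(torch.__version__.split("+")[0]) > version.parse("2.8.0")

@unittest.skipIf(not _TORCH_SUPPORTS_FP8_QDQ, "requires a torch build newer than 2.8 for fp8 q/dq")
class TestFP8QLinearFusion(unittest.TestCase):
    pass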

@shiyang-weng shiyang-weng marked this pull request as draft September 25, 2025 06:09
@shiyang-weng
Contributor Author

CC @mingfeima for review

@shiyang-weng shiyang-weng marked this pull request as ready for review September 29, 2025 07:22
Contributor

@jerryzh168 jerryzh168 left a comment


Do you need to change the quant flow code to produce this op?

I'd recommend doing this by defining a new observer; use this API:

def test_observer_callback(self):
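A rough sketch of the bookkeeping such an observer would do, kept as a plain nn.Module because the exact pt2e observer base class and the test_observer_callback API are not shown here; the amax tracking and the tensor-valued scale are the parts relevant to this PR, and all names below are illustrative.

import torch

class Fp8AmaxObserver(torch.nn.Module):
    def __init__(self, dtype=torch.float8_e4m3fn):
        super().__init__()
        self.dtype = dtype
        self.register_buffer("amax", torch.zeros(()))

    def forward(self, x):
        # Observers pass activations through unchanged while recording statistics.
        self.amax.copy_(torch.maximum(self.amax, x.detach().abs().max()))
        return x

    def calculate_scale(self):
        # Tensor-valued scale, as the fp8 q/dq ops require.
        return (self.amax / torch.finfo(self.dtype).max).clamp(min=1e-12)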

@jerryzh168 jerryzh168 merged commit a52a64a into pytorch:main Oct 8, 2025
18 checks passed
jainapurva pushed a commit that referenced this pull request Oct 9, 2025
* quantize_affine_float8/dequantize_affine_float8 not decomposed on inductor

* remove redundant unittest.skipIf

* fix rebase issue

* change dispatch key to a flag decomposed

* support scaled_mm on inductor

* fix rebase issue

* support dequant promotion for fp8

* add ut

* remove redundant codes

* fix lint

* resolve conflict

* change to use qlinear

* add ut

* fix lint

* support fp8 quant_lift_up

* add reshape into _VIEW_METHOD_OPS

* add quant_input_check

* fix lint

* refine ut

* remove fp8 dynamic quant ut

* fix output_scale issue

* add float8_e4m3fn to dtype_list

* refine code

* refine code

* fix bugs

* add comment

* merge main

* change to use non-decomposed q/dq

* fix lint

* add version check

* change version

* fix attention bug; update ut

* add liftup oplist
Labels: CLA Signed, topic: not user facing