-
-
Notifications
You must be signed in to change notification settings - Fork 10.8k
[ROCm][torch.compile] Adding ROCm-specific fusion pass for integrating aiter act/rms MXFP4 operators #25860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xytpai
wants to merge
2
commits into
vllm-project:main
Choose a base branch
from
ROCm:xyt/fused_act_quant4
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+290
−18
Open
[ROCm][torch.compile] Adding ROCm-specific fusion pass for integrating aiter act/rms MXFP4 operators #25860
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| from typing import Any, NamedTuple | ||
|
|
||
| import torch | ||
| import torch._inductor.pattern_matcher as pm | ||
| from torch import fx | ||
| from torch._higher_order_ops.auto_functionalize import auto_functionalized | ||
| from torch._inductor.pattern_matcher import PatternMatcherPass | ||
| from torch._ops import OpOverload | ||
|
|
||
| from vllm.config import VllmConfig | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.quantization.utils.quant_utils import ( | ||
| GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, | ||
| kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) | ||
| from vllm.platforms import current_platform | ||
|
|
||
| from .inductor_pass import enable_fake_mode | ||
| from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| def empty_bf16(*args, **kwargs): | ||
| return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") | ||
|
|
||
|
|
||
| def empty_fp32(*args, **kwargs): | ||
| return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") | ||
|
|
||
|
|
||
| def empty_i32(*args, **kwargs): | ||
| return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") | ||
|
|
||
|
|
||
| def empty_fp4(*args, **kwargs): | ||
| return torch.empty(*args, **kwargs, dtype=torch.uint8, device="cuda") | ||
|
|
||
|
|
||
| class SiluMulMXFP4GemmPattern: | ||
| def __init__(self): | ||
| pass | ||
|
|
||
| def register(self, pm_pass: PatternMatcherPass): | ||
|
|
||
| def pattern(result: torch.Tensor, | ||
| result_silu_mul: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, | ||
| scale: torch.Tensor): | ||
| at1 = auto_functionalized(torch.ops._C.silu_and_mul.default, | ||
| result=result_silu_mul, | ||
| input=input) | ||
| at2 = auto_functionalized(torch.ops.vllm.gemm_with_dynamic_quant.default, | ||
| result=result, | ||
| x=at1[1], | ||
| weight=weight, | ||
| weight_scale=scale, | ||
| x_scales=None) | ||
| return at2[1] | ||
|
|
||
|
|
||
| def replacement(result: torch.Tensor, | ||
| result_silu_mul: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, | ||
| scale: torch.Tensor): | ||
| at = auto_functionalized(torch.ops.vllm.silu_and_mul_mxfp4_gemm.default, | ||
| result=result, | ||
| x=input, | ||
| weight=weight, | ||
| weight_scale=scale) | ||
| return at[1] | ||
|
|
||
| inputs = [ | ||
| empty_bf16(5, 4), # result | ||
| empty_bf16(5, 4), # result_silu_mul | ||
| empty_bf16(5, 4), # input | ||
| empty_fp4(5, 4), # weight | ||
| empty_fp4(1, 1), # scale | ||
| ] | ||
|
|
||
| pm.register_replacement( | ||
| pattern, | ||
| replacement, | ||
| inputs, | ||
| pm.fwd_only, | ||
| pm_pass, | ||
| ) | ||
|
|
||
|
|
||
| ADD_RMS_OP = torch.ops._C.fused_add_rms_norm.default | ||
|
|
||
|
|
||
| class AddRMSNormMXFP4GemmPattern: | ||
| def __init__(self, epsilon: float): | ||
| self.epsilon = epsilon | ||
| self.FUSED_OP = torch.ops.vllm.add_rmsnorm_mxfp4_gemm.default | ||
| self.QUANT_F4GEMM_OP = torch.ops.vllm.gemm_with_dynamic_quant.default | ||
|
|
||
| def register(self, pm_pass: PatternMatcherPass): | ||
|
|
||
| def pattern( | ||
| result: torch.Tensor, input: torch.Tensor, | ||
| residual: torch.Tensor, weight_rms: torch.Tensor, | ||
| weight_gemm: torch.Tensor, scale: torch.Tensor): | ||
| at1 = auto_functionalized(ADD_RMS_OP, | ||
| input=input, | ||
| residual=residual, | ||
| weight=weight_rms, | ||
| epsilon=self.epsilon) | ||
| at2 = auto_functionalized(self.QUANT_F4GEMM_OP, | ||
| result=result, | ||
| x=at1[1], | ||
| weight=weight_gemm, | ||
| weight_scale=scale, | ||
| x_scales=None) | ||
| return at2[1], at1[2] | ||
|
|
||
| def replacement( | ||
| result: torch.Tensor, input: torch.Tensor, | ||
| residual: torch.Tensor, weight_rms: torch.Tensor, | ||
| weight_gemm: torch.Tensor, scale: torch.Tensor): | ||
| at = auto_functionalized(self.FUSED_OP, | ||
| result=result, | ||
| input=input, | ||
| residual=residual, | ||
| residual_out=residual, | ||
| weight_rms=weight_rms, | ||
| weight_gemm=weight_gemm, | ||
| scale=scale, | ||
| epsilon=self.epsilon) | ||
| return at[1], at[2] | ||
|
|
||
| inputs = [ | ||
| empty_bf16(4, 4), # result | ||
| empty_bf16(4, 4), # input | ||
| empty_bf16(4, 4), # residual | ||
| empty_bf16(1, 4), # weight_rms | ||
| empty_fp4(4, 4), # weight_gemm | ||
| empty_fp4(1, 1), # scale | ||
| ] | ||
|
|
||
| pm.register_replacement( | ||
| pattern, | ||
| replacement, | ||
| inputs, | ||
| pm.fwd_only, | ||
| pm_pass) | ||
|
|
||
|
|
||
| class ROCmFusionPass(VllmPatternMatcherPass): | ||
| """ | ||
| This pass fuses a pre-defined set of custom ops into fused ops. | ||
| It uses the torch pattern matcher to find the patterns and replace them. | ||
| """ | ||
|
|
||
| @enable_fake_mode | ||
| def __init__(self, config: VllmConfig): | ||
| super().__init__(config) | ||
|
|
||
| self.patterns: PatternMatcherPass = PatternMatcherPass( | ||
| pass_name="rocm_fusion_pass") | ||
|
|
||
| SiluMulMXFP4GemmPattern().register(self.patterns) | ||
|
|
||
| for epsilon in [1e-5, 1e-6]: | ||
| AddRMSNormMXFP4GemmPattern(epsilon).register(self.patterns) | ||
|
|
||
| self.dump_patterns(config, self.patterns) | ||
|
|
||
| @VllmInductorPass.time_and_log | ||
| def __call__(self, graph: fx.Graph): | ||
| self.matched_count = self.patterns.apply(graph) | ||
| logger.debug("Replaced %s patterns", self.matched_count) | ||
|
|
||
| def uuid(self) -> Any: | ||
| return self.hash_source(self, SiluMulMXFP4GemmPattern) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hints for
gemm_with_dynamic_quant_fakeare incorrect and do not match the function's implementation or its real counterpartgemm_with_dynamic_quant.x_scalesparameter is type-hinted astorch.Tensorbut has a default value ofNone. The type hint should beOptional[torch.Tensor]to reflect this.torch.Tensor, but it implicitly returnsNone. The return type hint should beNone.These inconsistencies can cause issues with static type checkers and
torch.compile's fake tensor propagation, which relies on correct function signatures for its analysis.