[ROCm][torch.compile] Adding ROCm-specific fusion pass for integrating aiter act/rms MXFP4 operators #25860
Conversation
Code Review
This pull request introduces fusion passes for ROCm to optimize `silu_mul` operations with mxfp4 quantization. The changes include adding a new `ROCmFusionPass`, refactoring `gemm_with_dynamic_quant` to be compatible with `torch.compile`, and implementing new fused custom ops. The overall approach is sound and follows existing patterns for fusion passes in vLLM. I've found one issue related to incorrect type hints in a fake implementation function, which could cause problems with `torch.compile`.
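For readers unfamiliar with how such passes hook into `torch.compile`: vLLM's fusion passes are built on the inductor pattern matcher, which rewrites the captured FX graph by replacing a matched subgraph with a fused op. Below is a minimal, hypothetical sketch of that mechanism on recent PyTorch; the pattern, the float8 stand-in for MXFP4 quantization, and the example shapes are illustrative assumptions, not the code from this PR.

```python
# Minimal sketch, assuming a silu_mul + dynamic-quant subgraph; not the PR's code.
import torch
import torch.nn.functional as F
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)

patterns = PatternMatcherPass()


def silu_mul_quant_pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Unfused form as it appears in the captured graph; float8 stands in for
    # the MXFP4 dynamic quantization, which has no stock torch dtype.
    return (F.silu(x) * y).to(torch.float8_e4m3fn)


def silu_mul_quant_replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # The real pass would call an aiter-backed fused custom op here; the same
    # math is repeated so this sketch stays self-contained and runnable.
    return (F.silu(x) * y).to(torch.float8_e4m3fn)


# Register the rewrite: any occurrence of the pattern in the FX graph is
# replaced by the replacement function's graph.
register_replacement(
    silu_mul_quant_pattern,
    silu_mul_quant_replacement,
    [torch.empty(8, 16), torch.empty(8, 16)],  # example inputs used for tracing
    fwd_only,
    patterns,
)
```

In the PR itself, the replacement would call the ROCm aiter fused kernel rather than re-expressing the same ops, and the resulting pass would be wired into vLLM's existing inductor-pass infrastructure.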
```python
def gemm_with_dynamic_quant_fake(
    result: torch.Tensor,
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    x_scales: torch.Tensor = None,
    rocm_use_aiter_fp4_asm_gemm: bool = False,
    out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.Tensor:
    return torch.empty((*x.shape[:-1], weight.shape[0]),
                       dtype=out_dtype,
                       device=x.device)
    return
```
The type hints for `gemm_with_dynamic_quant_fake` are incorrect and do not match the function's implementation or its real counterpart `gemm_with_dynamic_quant`.

- The `x_scales` parameter is type-hinted as `torch.Tensor` but has a default value of `None`. The type hint should be `Optional[torch.Tensor]` to reflect this.
- The function is type-hinted to return `torch.Tensor`, but it implicitly returns `None`. The return type hint should be `None`.

These inconsistencies can cause issues with static type checkers and with `torch.compile`'s fake tensor propagation, which relies on correct function signatures for its analysis.
Suggested change:

```diff
 def gemm_with_dynamic_quant_fake(
     result: torch.Tensor,
     x: torch.Tensor,
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
-    x_scales: torch.Tensor = None,
+    x_scales: Optional[torch.Tensor] = None,
     rocm_use_aiter_fp4_asm_gemm: bool = False,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
-) -> torch.Tensor:
-    return torch.empty((*x.shape[:-1], weight.shape[0]),
-                       dtype=out_dtype,
-                       device=x.device)
+) -> None:
     return
```
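As background on why the fake signature matters: for a custom op that mutates a preallocated `result` argument, the fake implementation should allocate nothing and return `None`; `torch.compile` only uses the fake for shape/dtype propagation, and the output metadata is already carried by `result`. Here is a minimal sketch using the stock `torch.library` API (PyTorch 2.4+); the op name `demo::gemm_with_dynamic_quant` and the plain-matmul body are made up for illustration, and vLLM registers its ops through its own `direct_register_custom_op` helper rather than this decorator.

```python
import torch


# Sketch only: a mutating custom op whose fake impl returns None.
@torch.library.custom_op("demo::gemm_with_dynamic_quant", mutates_args=("result",))
def gemm_with_dynamic_quant(result: torch.Tensor, x: torch.Tensor,
                            weight: torch.Tensor) -> None:
    # A real kernel would dynamically quantize `x` and call the aiter GEMM;
    # here we just write a plain matmul into the preallocated output.
    result.copy_(x @ weight.t())


@gemm_with_dynamic_quant.register_fake
def _(result: torch.Tensor, x: torch.Tensor, weight: torch.Tensor) -> None:
    # Nothing to allocate: `result` already has the right shape and dtype,
    # so the fake simply returns None, matching the real op's signature.
    return None


# Usage under torch.compile: the compiler traces through the fake impl.
@torch.compile
def run(out: torch.Tensor, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    torch.ops.demo.gemm_with_dynamic_quant(out, x, w)
    return out


out = run(torch.empty(4, 16), torch.randn(4, 8), torch.randn(16, 8))
print(out.shape)  # torch.Size([4, 16])
```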
This pull request has merge conflicts that must be resolved before it can be merged.
Instead of creating a new pass, can we add these patterns to the existing passes? Also, please wait for #24604, which will add better pattern-matching utilities.
This PR adds a few fusion passes for `silu_mul_quant_mxfp4` and `add_rmsnorm_quant_mxfp4`, giving roughly a 2% perf gain.
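For context, the unfused computations these two patterns target look roughly like the pure-PyTorch reference below. The MXFP4 dynamic-quantization step that follows each of them is elided because it is aiter-specific, and the function names here are descriptive rather than taken from the PR.

```python
import torch
import torch.nn.functional as F


def silu_mul(x: torch.Tensor) -> torch.Tensor:
    """SiLU on the gate half multiplied by the up half of the MLP activation."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def add_rmsnorm(x: torch.Tensor, residual: torch.Tensor,
                weight: torch.Tensor, eps: float = 1e-6):
    """Residual add followed by RMSNorm, as used between decoder layers."""
    x = x + residual
    variance = x.pow(2).mean(-1, keepdim=True)
    normed = x * torch.rsqrt(variance + eps) * weight
    return normed, x  # (normalized output, updated residual)
```

Each fusion pass matches one of these subgraphs followed by the dynamic MXFP4 quantization of its output and replaces the pair with a single fused aiter kernel call.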