
@xytpai xytpai commented Sep 29, 2025

This PR adds a few fusion passes for silu_mul_quant_mxfp4 and add_rmsnorm_quant_mxfp4, yielding roughly a 2% performance gain.

(benchmark screenshot)
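For readers unfamiliar with the pattern being fused, the silu_mul portion can be sketched as a scalar Python reference (an illustration only, not the actual aiter kernel; the real pass additionally folds the subsequent MXFP4 quantization into the same fused op):

```python
import math

def silu(v: float) -> float:
    # SiLU (a.k.a. swish) activation: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))

def silu_mul(x: list[float], y: list[float]) -> list[float]:
    # The unfused graph the pass matches: out = silu(x) * y,
    # normally followed by MXFP4 quantization of `out`.
    return [silu(a) * b for a, b in zip(x, y)]

print(silu_mul([0.0, 1.0], [2.0, 3.0]))
```

Fusing this elementwise chain with the quantization step avoids materializing the intermediate activation, which is the usual motivation for such fusions.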

@mergify mergify bot added the rocm Related to AMD ROCm label Sep 29, 2025
@xytpai xytpai marked this pull request as draft September 29, 2025 06:58
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces fusion passes for ROCm to optimize silu_mul operations with mxfp4 quantization. The changes include adding a new ROCmFusionPass, refactoring gemm_with_dynamic_quant to be compatible with torch.compile, and implementing new fused custom ops. The overall approach is sound and follows existing patterns for fusion passes in vLLM. I've found one issue related to incorrect type hints in a fake implementation function, which could cause problems with torch.compile.

Comment on lines 80 to +87

```python
def gemm_with_dynamic_quant_fake(
    result: torch.Tensor,
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    x_scales: torch.Tensor = None,
    rocm_use_aiter_fp4_asm_gemm: bool = False,
    out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.Tensor:
    return torch.empty((*x.shape[:-1], weight.shape[0]),
                       dtype=out_dtype,
                       device=x.device)
```

Severity: high

The type hints for gemm_with_dynamic_quant_fake are incorrect and do not match the function's implementation or its real counterpart gemm_with_dynamic_quant.

  1. The x_scales parameter is type-hinted as torch.Tensor but has a default value of None. The type hint should be Optional[torch.Tensor] to reflect this.
  2. The function is type-hinted to return torch.Tensor, but the real op is an out-variant that writes into result and returns None, so the fake should do the same. The return type hint should be None.

These inconsistencies can cause issues with static type checkers and torch.compile's fake tensor propagation, which relies on correct function signatures for its analysis.

Suggested change

```python
def gemm_with_dynamic_quant_fake(
    result: torch.Tensor,
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    x_scales: torch.Tensor = None,
    rocm_use_aiter_fp4_asm_gemm: bool = False,
    out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.Tensor:
    return torch.empty((*x.shape[:-1], weight.shape[0]),
                       dtype=out_dtype,
                       device=x.device)
```

becomes:

```python
def gemm_with_dynamic_quant_fake(
    result: torch.Tensor,
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    x_scales: Optional[torch.Tensor] = None,
    rocm_use_aiter_fp4_asm_gemm: bool = False,
    out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> None:
    return
```
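One way to guard against this class of mismatch is to compare the real and fake signatures mechanically. The sketch below uses a stand-in `Tensor` class and a string stand-in for `torch.dtype` so it runs without torch installed; the point is only that torch.compile's fake-tensor propagation expects the fake to mirror the real op's parameter names, annotations, defaults, and return type:

```python
import inspect
from typing import Optional

class Tensor:
    """Stand-in for torch.Tensor (this check does not need torch)."""

def gemm_with_dynamic_quant(
    result: Tensor,
    x: Tensor,
    weight: Tensor,
    weight_scale: Tensor,
    x_scales: Optional[Tensor] = None,
    rocm_use_aiter_fp4_asm_gemm: bool = False,
    out_dtype: Optional[str] = "bfloat16",  # stand-in for torch.dtype
) -> None:
    """Out-variant op: writes into `result`, returns nothing."""

def gemm_with_dynamic_quant_fake(
    result: Tensor,
    x: Tensor,
    weight: Tensor,
    weight_scale: Tensor,
    x_scales: Optional[Tensor] = None,
    rocm_use_aiter_fp4_asm_gemm: bool = False,
    out_dtype: Optional[str] = "bfloat16",
) -> None:
    return

# A fake impl must mirror the real op's signature exactly;
# inspect.Signature equality compares parameters and return annotation.
real = inspect.signature(gemm_with_dynamic_quant)
fake = inspect.signature(gemm_with_dynamic_quant_fake)
assert real == fake
print("signatures match")
```

A check like this could live in a unit test next to the op registration, so a drifting fake signature fails CI instead of surfacing as an opaque torch.compile error.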

@xytpai xytpai changed the title [ROCm][torch.compile] Add act_mxfp4 fusion pass [ROCm][torch.compile] Adding ROCm-specific fusion pass for integrating aiter act/rms MXFP4 operators Sep 29, 2025
@xytpai xytpai marked this pull request as ready for review September 29, 2025 07:50

mergify bot commented Oct 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @xytpai.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Oct 3, 2025
Collaborator

@ProExpertProg ProExpertProg left a comment


Instead of creating a new pass, can we add these patterns to the existing passes? Also, please wait for #24604, which will add better pattern-matching utilities.
