[Bug] [ROCm] Fix Llama 4 Enablement Bug on ROCm: V0 ROCmFlashAttentionImpl and Triton Fused MoE bugs #16198
Conversation
Co-authored-by: Hongxia Yang <[email protected]> Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
cc @houseroad
Signed-off-by: tjtanaa <[email protected]>
@simon-mo @houseroad @SageMoore Can you help merge this? It will unblock the aiter integration for performance improvements. Thanks!
# with torch.compile Dynamo.
# V1 Engine on ROCm with eager mode is fine.
# V0 Engine on ROCm with HIPGraph is fine.
topk_weights = topk_weights.view(-1).reshape(topk_weights.shape)
Is it possible that this is related to Inductor always putting matrices in row-major order? And we should add a modifier to the custom op?
See comment in torch_bindings.cpp:
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
// to override this for many GEMMs with the following tag. Otherwise,
// torch.compile will force all input tensors to be contiguous(), which
// will break many custom ops that require column-major weight matrices.
// TODO: remove this for PyTorch 2.8, when the default is planned to switch
// to match exact eager-mode strides.
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
I think the issue might not be related to Inductor, as it does not happen on CUDA: topk_weights.stride() returns (1, 1) on CUDA but (1, 1024) on ROCm.
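For context, here is a minimal repro sketch of the stride pattern described above. The shapes are made up for illustration; only the (1, 1024) stride mirrors what was observed on ROCm.

import torch

# Hypothetical repro: a (num_tokens, top_k=1) tensor whose size-1 last dim
# keeps a large leftover stride from a wider parent buffer.
parent = torch.randn(4, 1024)
topk_weights = torch.as_strided(parent, size=(4, 1), stride=(1, 1024))

print(topk_weights.stride())         # (1, 1024) -- leftover stride in the size-1 dim
print(topk_weights.is_contiguous())  # True: PyTorch ignores strides of size-1 dims

# The workaround in this PR flattens and reshapes, which materializes
# canonical strides before the tensor is handed to the Triton kernel.
fixed = topk_weights.view(-1).reshape(topk_weights.shape)
print(fixed.stride())                # (1, 1)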
Can we create an issue to track this hack if @ProExpertProg's suggestion doesn't work?
There is already an issue tracking this: ROCm/pytorch#2020
@ProExpertProg @houseroad
The topk_weights is generated by Llama4MoE.custom_routing_function, which is just a series of native PyTorch operators, so no custom ops are involved in generating topk_weights.
vllm/vllm/model_executor/models/llama4.py, line 44 in 027b204:
class Llama4MoE(nn.Module):
@staticmethod
def custom_routing_function(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
router_scores = torch.sigmoid(router_scores.float()).to(
hidden_states.dtype)
return (router_scores, router_indices.to(torch.int32))
The router_scores returned here is the topk_weights input to fused_moe.
It's not about the custom op generating the tensor but about consuming it: the wna16 op consumes this tensor and it might get transposed, unless that op is marked with the tag.
@ProExpertProg
Thank you for the leads. It seems there is a way to add the tag through the PyTorch Python API as well. We have exposed the tags interface through direct_register_custom_op in vllm/utils.py, which is a function proposed by Kaichao to register custom ops that are not traceable by torch.compile. Adding tags=(torch.Tag.needs_fixed_stride_order,) does resolve the issue.
direct_register_custom_op(
op_name="inplace_fused_experts",
op_func=inplace_fused_experts,
mutates_args=["hidden_states"],
fake_impl=inplace_fused_experts_fake,
+ tags=(torch.Tag.needs_fixed_stride_order,),
)
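For reference, a standalone sketch of how such a tag can be attached at registration time using the plain PyTorch Python API. This is not vLLM's actual direct_register_custom_op; the library and op names are hypothetical, and it assumes a PyTorch version in which torch.library.Library.define accepts a tags keyword.

import torch
from torch.library import Library

# Hypothetical library for illustration only.
my_lib = Library("my_moe_lib", "FRAGMENT")

# Tagging the schema tells torch.compile to keep the eager-mode strides of
# the inputs instead of forcing them to be contiguous().
my_lib.define(
    "scale_rows(Tensor x, Tensor weights) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order,),
)

def scale_rows_impl(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # Toy body standing in for a kernel that is sensitive to input strides.
    return x * weights

my_lib.impl("scale_rows", scale_rows_impl, "CompositeExplicitAutograd")

# The tag is visible on the registered op.
assert torch.Tag.needs_fixed_stride_order in torch.ops.my_moe_lib.scale_rows.default.tags

In the diff above, the same effect is achieved by threading the tags argument through direct_register_custom_op.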
Yep this looks right, great work!
This looks reasonable. Let's try @ProExpertProg's suggestion for fixing the topk_weights issue.
Looks fine for unblocking now. We need to create two follow-ups.
raise ValueError(
    "ROCmFlashAttention does not support blocksparse attention.")

if use_irope:
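For readers following along, the excerpt above stops at the guard. A rough sketch of what such a constructor guard could look like is shown below; it is simplified and hypothetical, not the PR's exact code, and the real __init__ takes many more attention parameters.

import logging
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

class ROCmFlashAttentionImpl:

    def __init__(self,
                 blocksparse_params: Optional[Dict[str, Any]] = None,
                 use_irope: bool = False) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmFlashAttention does not support blocksparse attention.")
        if use_irope:
            # The V0 ROCm backend has no iRoPE (chunked local attention)
            # support yet, so it only warns and falls back to global
            # attention, which the thread below notes may be incorrect
            # for long contexts.
            logger.warning(
                "Using irope in V0 is not supported yet; falling back to "
                "global attention.")
        self.use_irope = use_irope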
Create an issue to track this progress?
I think the output will be incorrect with global attention.
I remember it seems reasonable, but we should definitely have the right approach here.
Agreed about tracking this issue if we want to fully support V0. We will create one internally. Does that sound good to you?
Signed-off-by: kliuae <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Thanks for adding support for tags! LGTM assuming this tag fixed the original issue!
Stamp
…nImpl and Triton Fused MoE bugs (vllm-project#16198) Signed-off-by: tjtanaa <[email protected]> Signed-off-by: kliuae <[email protected]> Co-authored-by: Hongxia Yang <[email protected]> Co-authored-by: kliuae <[email protected]> Signed-off-by: Yang Wang <[email protected]>
…nImpl and Triton Fused MoE bugs (vllm-project#16198) Signed-off-by: tjtanaa <[email protected]> Signed-off-by: kliuae <[email protected]> Co-authored-by: Hongxia Yang <[email protected]> Co-authored-by: kliuae <[email protected]>
…nImpl and Triton Fused MoE bugs (vllm-project#16198) Signed-off-by: tjtanaa <[email protected]> Signed-off-by: kliuae <[email protected]> Co-authored-by: Hongxia Yang <[email protected]> Co-authored-by: kliuae <[email protected]> Signed-off-by: Mu Huai <[email protected]>
Description
This PR fixes two bugs:
1. TypeError: ROCmFlashAttentionImpl.__init__() got an unexpected keyword argument 'use_irope'
2. topk_weights in invoke_fused_moe_kernel not being contiguous under V1 + ROCm + torch.compile + Dynamo + HIPGraph mode
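A quick sanity check for the MoE part of the fix is sketched below. The module path and the vllm op namespace are assumptions and may differ between vLLM versions; the expected result is that the registered op now carries the stride tag.

import torch

# Importing the fused MoE module triggers the custom-op registration
# (hypothetical import path; adjust to your vLLM version).
from vllm.model_executor.layers.fused_moe import fused_moe  # noqa: F401

op = torch.ops.vllm.inplace_fused_experts.default
print(torch.Tag.needs_fixed_stride_order in op.tags)  # expected: True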