
Conversation

Contributor

@tjtanaa tjtanaa commented Apr 7, 2025

Description

This PR fixes two bugs:

  1. TypeError: ROCmFlashAttentionImpl.__init__() got an unexpected keyword argument 'use_irope'
  2. topk_weights passed to invoke_fused_moe_kernel is not contiguous under V1 + ROCm + torch.compile + Dynamo + HIPGraph mode.

tjtanaa and others added 2 commits April 7, 2025 15:55
Co-authored-by: Hongxia Yang <[email protected]>
Signed-off-by: tjtanaa <[email protected]>

github-actions bot commented Apr 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@hongxiayang
Collaborator

cc @houseroad

@tjtanaa tjtanaa marked this pull request as ready for review April 7, 2025 16:43
@hongxiayang
Collaborator

@simon-mo @houseroad @SageMoore Can you help merge this? It will unblock us on the aiter integration for performance improvement. Thanks!

# with torch.compile Dynamo.
# V1 Engine on ROCm with eager mode is fine.
# V0 Engine on ROCm with HIPGraph is fine.
topk_weights = topk_weights.view(-1).reshape(topk_weights.shape)
Collaborator

Is it possible that this is related to Inductor always putting matrices in row-major order? If so, should we add a modifier to the custom op?

See comment in torch_bindings.cpp:

  // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
  // to override this for many GEMMs with the following tag. Otherwise,
  // torch.compile will force all input tensors to be contiguous(), which
  // will break many custom ops that require column-major weight matrices.
  // TODO: remove this for PyTorch 2.8, when the default is planned to switch
  // to match exact eager-mode strides.
  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;

Contributor Author

I think the issue might not be related to Inductor, as it does not happen on CUDA: topk_weights.stride() returns (1, 1) on CUDA but (1, 1024) on ROCm.
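
As a standalone illustration (synthetic tensor, not the actual kernel input), the view(-1).reshape(...) workaround from the diff above rebuilds canonical strides; note it relies on the tensor still being viewable as 1-D (here the odd stride sits on a size-1 dimension), otherwise view(-1) would raise:

    import torch

    # Synthetic (4, 1) tensor whose size-1 dimension carries a large,
    # non-canonical stride, loosely mimicking the (1, 1024) stride seen
    # for topk_weights on ROCm under torch.compile.
    x = torch.arange(4.0).as_strided((4, 1), (1, 1024))
    print(x.stride())  # (1, 1024)

    # The workaround from the diff: flatten and reshape back, which
    # produces a tensor with canonical strides.
    y = x.view(-1).reshape(x.shape)
    print(y.stride())  # (1, 1)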

Collaborator

@houseroad houseroad Apr 7, 2025

Can we create an issue to track this hack if @ProExpertProg's suggestion doesn't work?

Collaborator

An issue has already been created: ROCm/pytorch#2020

Contributor Author

@tjtanaa tjtanaa Apr 8, 2025

@ProExpertProg @houseroad
The topk_weights is generated by Llama4MoE.custom_routing_function, which is just a series of native PyTorch operators, so no custom ops are involved in producing topk_weights:

class Llama4MoE(nn.Module):

    @staticmethod
    def custom_routing_function(
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
        router_scores = torch.sigmoid(router_scores.float()).to(
            hidden_states.dtype)
        return (router_scores, router_indices.to(torch.int32))

The returned router_scores is the topk_weights input to fused_moe.
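
As a quick, illustrative sanity check (the sizes and topk below are made up, and this assumes the class excerpt above plus import torch are in scope), calling the routing function eagerly returns tensors with canonical strides; the (1, 1024) stride reported above only appears on the compiled ROCm path:

    import torch

    hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)  # (num_tokens, hidden_size)
    gating_output = torch.randn(4, 16)                        # (num_tokens, num_experts)

    router_scores, router_indices = Llama4MoE.custom_routing_function(
        hidden_states, gating_output, topk=1, renormalize=False)

    print(router_scores.shape, router_scores.dtype)    # torch.Size([4, 1]) torch.bfloat16
    print(router_indices.shape, router_indices.dtype)  # torch.Size([4, 1]) torch.int32
    print(router_scores.stride())                      # (1, 1) in eager mode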

Collaborator

It's not about the custom op that generates the tensor but the one that consumes it: the wna16 op consumes this tensor and it might get transposed, unless that op is marked with the tag.

Contributor Author

@ProExpertProg
Thank you for the leads. It seems there is a way to add the tag through the PyTorch Python API as well. We have exposed the tags interface through direct_register_custom_op from vllm/utils.py, a function proposed by Kaichao to register custom ops that are not traceable by torch.compile. Adding tags=(torch.Tag.needs_fixed_stride_order,) does resolve the issue.

direct_register_custom_op(
    op_name="inplace_fused_experts",
    op_func=inplace_fused_experts,
    mutates_args=["hidden_states"],
    fake_impl=inplace_fused_experts_fake,
+    tags=(torch.Tag.needs_fixed_stride_order,),
)
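
For reference, here is a minimal sketch of attaching the same tag directly through the plain torch.library Python API, which is roughly what direct_register_custom_op wraps; the namespace, op name, schema, and implementation below are made up for illustration and assume a recent PyTorch:

    import torch
    from torch.library import Library

    # Hypothetical namespace and op, used only to demonstrate the tag.
    my_lib = Library("my_ns", "FRAGMENT")

    # Tagging the op at definition time tells torch.compile to preserve the
    # exact eager-mode input strides instead of forcing them contiguous.
    my_lib.define(
        "scale_rows(Tensor x, Tensor w) -> Tensor",
        tags=(torch.Tag.needs_fixed_stride_order,),
    )

    def scale_rows_impl(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Toy implementation: scale each row of x by the matching entry of w.
        return x * w.unsqueeze(-1)

    my_lib.impl("scale_rows", scale_rows_impl, "CompositeExplicitAutograd")

    out = torch.ops.my_ns.scale_rows(torch.randn(4, 8), torch.randn(4))
    print(out.shape)  # torch.Size([4, 8])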

Collaborator

Yep this looks right, great work!

Contributor

@SageMoore SageMoore left a comment

This looks reasonable. Let's try @ProExpertProg's suggestion for fixing the topk_weights issue.

Collaborator

@houseroad houseroad left a comment

Looks fine for unblocking now. We need to create 2 follow-ups.

raise ValueError(
    "ROCmFlashAttention does not support blocksparse attention.")

if use_irope:
Collaborator

Create an issue to track this progress?

Contributor

I think the output will be incorrect with global attention.

Collaborator

I remember it seemed reasonable, but we should definitely have the right approach here.

Collaborator

Agreed about tracking this issue if we want to fully support V0. We will create one internally. Does that sound good to you?
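
For context on bug 1 and the fallback discussed above, here is a minimal, hypothetical sketch, not this PR's actual code: the constructor accepts use_irope so the kwarg no longer raises a TypeError, and the assumed V0 behavior is to warn and fall back to global attention:

    import logging

    logger = logging.getLogger(__name__)

    class ROCmFlashAttentionImplSketch:
        """Hypothetical stand-in; not the real ROCmFlashAttentionImpl."""

        def __init__(self, *, use_irope: bool = False, **kwargs) -> None:
            # Accepting the kwarg avoids the reported
            # "unexpected keyword argument 'use_irope'" TypeError.
            if use_irope:
                # Assumed behavior: warn and fall back to global attention,
                # which is what the comments above question the accuracy of.
                logger.warning(
                    "use_irope is not supported by this backend; "
                    "falling back to global attention.")
            self.use_irope = use_irope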

Collaborator

@ProExpertProg ProExpertProg left a comment

Thanks for adding support for tags! LGTM assuming this tag fixed the original issue!

@DarkLight1337 DarkLight1337 added the ready (ONLY add when PR is ready to merge/full CI is needed) label Apr 8, 2025
Member

@DarkLight1337 DarkLight1337 left a comment

Stamp

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) April 8, 2025 15:51
@vllm-bot vllm-bot merged commit 2976dc2 into vllm-project:main Apr 9, 2025
59 of 63 checks passed
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
…nImpl and Triton Fused MoE bugs (vllm-project#16198)

Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: kliuae <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: kliuae <[email protected]>
Signed-off-by: Yang Wang <[email protected]>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
…nImpl and Triton Fused MoE bugs (vllm-project#16198)

Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: kliuae <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: kliuae <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
…nImpl and Triton Fused MoE bugs (vllm-project#16198)

Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: kliuae <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: kliuae <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…nImpl and Triton Fused MoE bugs (vllm-project#16198)

Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: kliuae <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: kliuae <[email protected]>
Signed-off-by: Mu Huai <[email protected]>