[Bug Fix][NOT READY FOR MERGE] Change the order where we apply router weights on input for MoE #16744
Conversation
Signed-off-by: Zijing Liu <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
```python
if apply_router_weight_on_input:
    assert topk_ids.shape[1] == 1, \
        "Can only apply router weight on input when topk is 1!"
    qcurr_hidden_states = qcurr_hidden_states * curr_topk_weights
```
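To illustrate why the assert restricts this to topk = 1, here is a minimal numpy sketch (hypothetical, not the vLLM code): for a plain matmul, scaling the input by the single router weight is equivalent to scaling the output.

```python
import numpy as np

# Hypothetical sketch: with topk == 1, each token is routed to exactly one
# expert, so scaling the hidden state by its single router weight before the
# expert matmul is equivalent to scaling the matmul output afterwards.
rng = np.random.default_rng(0)
hidden = rng.standard_normal((4, 8))      # 4 tokens, hidden dim 8
w13 = rng.standard_normal((8, 16))        # one expert's weight matrix
topk_weights = rng.random((4, 1))         # one router weight per token

out_pre = (hidden * topk_weights) @ w13   # weight applied on input
out_post = (hidden @ w13) * topk_weights  # weight applied on output

assert np.allclose(out_pre, out_post)
```

With topk > 1 the same hidden-state row feeds several experts with different weights, so a single pre-scaling of the input would be ambiguous, which is what the assert guards against.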
This won't work for fp8, because qcurr_hidden_states would be an fp8 dtype like torch.float8_e4m3fn while curr_topk_weights is in fp16 for Scout, and multiplication is not implemented for fp8, so we can't do the casting here.
I think we need to do this in the fused MoE Triton kernel (fused_moe_kernel) itself, where the result is accumulated in float32, maybe by passing in a flag to multiply by the topk weights before the matmul. (edit: ah, I see this is marked FIXME)
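The reviewer's point can be sketched with a crude numpy stand-in for low-precision quantization (hypothetical, not fp8 arithmetic): scaling an already-quantized activation and re-quantizing it adds rounding error, whereas folding the router weight into the higher-precision accumulation path does not.

```python
import numpy as np

def fake_quant(x, step=0.25):
    """Crude stand-in for fp8 quantization: snap values to a coarse grid."""
    return np.round(x / step) * step

rng = np.random.default_rng(1)
hidden = rng.standard_normal((4, 8)).astype(np.float32)
w13 = rng.standard_normal((8, 16)).astype(np.float32)
topk_w = rng.random((4, 1)).astype(np.float32)  # router weights, topk == 1

q_hidden = fake_quant(hidden)

# Naive: scale the quantized tensor, then re-quantize. This adds an extra
# rounding step (and in real fp8 the elementwise multiply is not even
# implemented for torch.float8_e4m3fn).
naive = fake_quant(q_hidden * topk_w) @ w13

# Kernel-side: run the matmul on the quantized input and fold the router
# weight into the float32 accumulation, as suggested for fused_moe_kernel.
# (Equivalent to pre-scaling here because topk == 1 and the op is linear.)
fused = (q_hidden @ w13) * topk_w

# The extra quantization step changes the result.
assert not np.allclose(naive, fused)
```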
Good catch. If we want to fix this properly, we probably need to modify all the existing MoE kernels to make sure they apply the router weights in the right order. Meanwhile, let me do this before moe_kernel_prepare_input.
Can we move this somewhere above, where quantization has not been applied yet? e.g. line 1267
Signed-off-by: Zijing Liu <[email protected]>
```python
    expert_ids,
    num_tokens_post_padded,
    apply_router_weight_on_input,
    # FIXME: Always False here because fused_moe_kernel
```
This comment is probably not needed; instead, we should add the explanation where apply_router_weight_on_input is actually used.
Closing the PR; refer to #16801 for the root-cause fix.
What does this PR do?
Issue
Firstly, we are working on a Llama4 Scout INT4 checkpoint, and we noticed a huge eval score drop at TP=1.
Secondly, we also noticed that the INT4 Llama4 model returns different results for the same prompt in the same batch. This applies to all TPs=[1,2,4,8].
There is a user report of a similar issue when running an on-the-fly Scout checkpoint: #16337
Issue details and repro: https://docs.google.com/document/d/10k3yuyZ4OmN278hChwPGDcxgqdfXCYqUH0_rsDVMxYo/edit?usp=sharing
Root Cause
For Llama4, we need to apply router weights on the MoE inputs (hidden states) before any matrix multiplications (e.g. W13). However, a number of our MoE kernels currently apply router weights after the matrix multiplications (e.g. moe_wna16_gemm, the Triton kernels, etc.). The cutlass_fp8_moe_kernel applies the router weights in the right order.
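One way the ordering can matter, sketched in numpy under the assumption that a nonlinearity such as SiLU sits between the expert matmuls (as in a SwiGLU expert; the names and shapes below are illustrative, not the vLLM code): pre-scaling the input is no longer equivalent to post-scaling the output, even with topk = 1.

```python
import numpy as np

def silu(x):
    # SiLU activation: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(2)
x = rng.standard_normal((4, 8))    # 4 tokens, hidden dim 8
w1 = rng.standard_normal((8, 16))  # first expert matmul (stand-in for W13)
w2 = rng.standard_normal((16, 8))  # second expert matmul (stand-in for W2)
g = rng.random((4, 1))             # router weight per token (topk == 1)

on_input = silu((x * g) @ w1) @ w2   # router weight applied before W13
on_output = (silu(x @ w1) @ w2) * g  # router weight applied on the output

# Because SiLU is nonlinear, the two orderings produce different results.
assert not np.allclose(on_input, on_output)
```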
Test Plan
TP=1, INT4 checkpoint: the score goes back up to the normal range. Full eval is still in progress.
TP=8, BF16 Scout HF public checkpoint
Ref TP=8 BF16 Scout HF public checkpoint, before the change: