
Conversation

@liuzijing2014 liuzijing2014 commented Apr 16, 2025

What does this PR do?

Issue

First, while working on the Llama4 Scout INT4 checkpoint, we noticed a huge eval score drop at TP=1.

  • Mmlu_pro, full eval, tp=1: 0.013
  • Mmlu_pro, full eval, tp=2: 0.7143
  • Mmlu_pro, full eval, tp=4: 0.7119
  • Mmlu_pro, full eval, tp=8: 0.7117

Second, the INT4 Llama4 model returns different results for the same prompt within the same batch. This applies to all TP values [1, 2, 4, 8].

There is a user report of a similar issue when running an on-the-fly-quantized Scout checkpoint: #16337

========== Prompt ==============
Prompt: 'The capital of France is'
================================
========== Batch2 ==============
Generated text: ' Paris. The capital of Germany is Berlin. The capital of Italy is Rome. The capital of Spain is Madrid. The capital of England is London. The capital of America is Washington. The capital of China is Beijing. The capital of Japan is Tokyo. The capital of Australia is Canberra. The capital of India is New'

Generated text: ' niezaners why or五个 inalfora Hangアイ Rick in catching yourselves FCND\n Power,� :adev and niezane - need \nainn alkcusk. -size -----------------------------------------------------------------------------\nSYS124reiba trick and,, , hosverあたり For ... \n�idać sanctuary of the\n indeed selfstar cuns涉及,'
================================

Issue details and re-produce: https://docs.google.com/document/d/10k3yuyZ4OmN278hChwPGDcxgqdfXCYqUH0_rsDVMxYo/edit?usp=sharing

Root Cause

For Llama4, the router weights must be applied to the MoE inputs (hidden states) before any matrix multiplications (e.g. W13). However, several of our MoE kernels (e.g. moe_wna16_gemm, the Triton kernels) currently apply the router weights after the matrix multiplications. cutlass_fp8_moe_kernel is applying the router weights in the right order.
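A minimal sketch (toy shapes, not vLLM code) of why the order matters: the Llama4 expert is nonlinear (a SwiGLU-style MLP), so scaling its input by the router weight is not equivalent to scaling its output.

```python
import torch

torch.manual_seed(0)

def expert(x, w13, w2):
    # SwiGLU-style expert: gate/up projection, SiLU, down projection
    gate, up = (x @ w13).chunk(2, dim=-1)
    return (torch.nn.functional.silu(gate) * up) @ w2

d, h = 8, 16
x = torch.randn(4, d)
w13 = torch.randn(d, 2 * h)
w2 = torch.randn(h, d)
router_w = torch.rand(4, 1)  # top-1 router weight per token

# Llama4 semantics: scale the hidden states BEFORE the expert matmuls
out_before = expert(router_w * x, w13, w2)
# Buggy kernel order: run the expert first, scale the output after
out_after = router_w * expert(x, w13, w2)

# Because the expert is nonlinear (SiLU), the two orders disagree
print(torch.allclose(out_before, out_after))  # False
```

For a purely linear expert the two orders would agree; the SiLU in between is exactly what makes the post-matmul application wrong for Llama4.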

Test Plan

TP=1, INT4 checkpoint: the score is back in the normal range. Full eval is still in progress.

|       Tasks       |Version|    Filter    |n-shot|  Metric   |   |Value |   |Stderr|
|-------------------|------:|--------------|-----:|-----------|---|-----:|---|-----:|
|mmlu_pro           |    2.0|custom-extract|      |exact_match|↑  |0.7188|±  |0.0040|
| - biology         |    2.1|custom-extract|     5|exact_match|↑  |0.8271|±  |0.0141|
| - business        |    2.1|custom-extract|     5|exact_match|↑  |0.7630|±  |0.0151|
| - chemistry       |    2.1|custom-extract|     5|exact_match|↑  |0.7818|±  |0.0123|
| - computer_science|    2.1|custom-extract|     5|exact_match|↑  |0.6756|±  |0.0231|
| - economics       |    2.1|custom-extract|     5|exact_match|↑  |0.8033|±  |0.0137|
| - engineering     |    2.1|custom-extract|     5|exact_match|↑  |0.6233|±  |0.0156|
| - health          |    2.1|custom-extract|     5|exact_match|↑  |0.7200|±  |0.0157|
| - history         |    2.1|custom-extract|     5|exact_match|↑  |0.6220|±  |0.0249|
| - law             |    2.1|custom-extract|     5|exact_match|↑  |0.5014|±  |0.0151|
| - math            |    2.1|custom-extract|     5|exact_match|↑  |0.8105|±  |0.0107|
| - other           |    2.1|custom-extract|     5|exact_match|↑  |0.6580|±  |0.0156|
| - philosophy      |    2.1|custom-extract|     5|exact_match|↑  |0.5812|±  |0.0221|
| - physics         |    2.1|custom-extract|     5|exact_match|↑  |0.7860|±  |0.0114|
| - psychology      |    2.1|custom-extract|     5|exact_match|↑  |0.7744|±  |0.0148|
| Tasks |Version|Filter|n-shot|     Metric      |   |Value |   |Stderr|
|-------|------:|------|-----:|-----------------|---|-----:|---|-----:|
|chartqa|      0|none  |     0|anywhere_accuracy|↑  |0.8868|±  |0.0063|
|       |       |none  |     0|exact_match      |↑  |0.6600|±  |0.0095|
|       |       |none  |     0|relaxed_accuracy |↑  |0.8844|±  |0.0064|

TP=8, BF16 Scout HF public checkpoint

|       Tasks       |Version|    Filter    |n-shot|  Metric   |   |Value |   |Stderr|
|-------------------|------:|--------------|-----:|-----------|---|-----:|---|-----:|
|mmlu_pro           |    2.0|custom-extract|      |exact_match|↑  |0.7153|±  |0.0040|
| - biology         |    2.1|custom-extract|     5|exact_match|↑  |0.7015|±  |0.0171|
| - business        |    2.1|custom-extract|     5|exact_match|↑  |0.7820|±  |0.0147|
| - chemistry       |    2.1|custom-extract|     5|exact_match|↑  |0.7898|±  |0.0121|
| - computer_science|    2.1|custom-extract|     5|exact_match|↑  |0.6683|±  |0.0233|
| - economics       |    2.1|custom-extract|     5|exact_match|↑  |0.8152|±  |0.0134|
| - engineering     |    2.1|custom-extract|     5|exact_match|↑  |0.6316|±  |0.0155|
| - health          |    2.1|custom-extract|     5|exact_match|↑  |0.7323|±  |0.0155|
| - history         |    2.1|custom-extract|     5|exact_match|↑  |0.6037|±  |0.0251|
| - law             |    2.1|custom-extract|     5|exact_match|↑  |0.4741|±  |0.0151|
| - math            |    2.1|custom-extract|     5|exact_match|↑  |0.8157|±  |0.0106|
| - other           |    2.1|custom-extract|     5|exact_match|↑  |0.6851|±  |0.0153|
| - philosophy      |    2.1|custom-extract|     5|exact_match|↑  |0.5932|±  |0.0220|
| - physics         |    2.1|custom-extract|     5|exact_match|↑  |0.7875|±  |0.0114|
| - psychology      |    2.1|custom-extract|     5|exact_match|↑  |0.7682|±  |0.0149|
| Tasks |Version|Filter|n-shot|     Metric      |   |Value |   |Stderr|
|-------|------:|------|-----:|-----------------|---|-----:|---|-----:|
|chartqa|      0|none  |     0|anywhere_accuracy|↑  |0.8896|±  |0.0063|
|       |       |none  |     0|exact_match      |↑  |0.6540|±  |0.0095|
|       |       |none  |     0|relaxed_accuracy |↑  |0.8864|±  |0.0063|

Reference: TP=8, BF16 Scout HF public checkpoint, before the change:

|       Tasks       |Version|    Filter    |n-shot|  Metric   |   |Value |   |Stderr|
|-------------------|------:|--------------|-----:|-----------|---|-----:|---|-----:|
|mmlu_pro           |    2.0|custom-extract|      |exact_match|↑  |0.7149|±  |0.0040|
| - biology         |    2.1|custom-extract|     5|exact_match|↑  |0.7057|±  |0.0170|
| - business        |    2.1|custom-extract|     5|exact_match|↑  |0.7681|±  |0.0150|
| - chemistry       |    2.1|custom-extract|     5|exact_match|↑  |0.7818|±  |0.0123|
| - computer_science|    2.1|custom-extract|     5|exact_match|↑  |0.6829|±  |0.0230|
| - economics       |    2.1|custom-extract|     5|exact_match|↑  |0.8152|±  |0.0134|
| - engineering     |    2.1|custom-extract|     5|exact_match|↑  |0.6171|±  |0.0156|
| - health          |    2.1|custom-extract|     5|exact_match|↑  |0.7286|±  |0.0156|
| - history         |    2.1|custom-extract|     5|exact_match|↑  |0.5958|±  |0.0252|
| - law             |    2.1|custom-extract|     5|exact_match|↑  |0.4705|±  |0.0150|
| - math            |    2.1|custom-extract|     5|exact_match|↑  |0.8238|±  |0.0104|
| - other           |    2.1|custom-extract|     5|exact_match|↑  |0.6948|±  |0.0152|
| - philosophy      |    2.1|custom-extract|     5|exact_match|↑  |0.5872|±  |0.0221|
| - physics         |    2.1|custom-extract|     5|exact_match|↑  |0.7891|±  |0.0113|
| - psychology      |    2.1|custom-extract|     5|exact_match|↑  |0.7832|±  |0.0146|
| Tasks |Version|Filter|n-shot|     Metric      |   |Value |   |Stderr|
|-------|------:|------|-----:|-----------------|---|-----:|---|-----:|
|chartqa|      0|none  |     0|anywhere_accuracy|↑  |0.8900|±  |0.0063|
|       |       |none  |     0|exact_match      |↑  |0.6564|±  |0.0095|
|       |       |none  |     0|relaxed_accuracy |↑  |0.8872|±  |0.0063|


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

if apply_router_weight_on_input:
    assert topk_ids.shape[1] == 1, (
        "Can only apply router weight on input when topk is 1!")
    qcurr_hidden_states = qcurr_hidden_states * curr_topk_weights
@sarckk (Collaborator) commented Apr 16, 2025
This won't work for fp8: qcurr_hidden_states would be in an fp8 dtype like torch.float8_e4m3fn while curr_topk_weights is fp16 for Scout, and multiplication is not implemented for fp8, so we can't do the scaling here.

I think we need to do this in the fused MoE Triton kernel (fused_moe_kernel) itself, where the result is accumulated in float32, perhaps by passing in a flag to multiply by the topk weights before the matmul (edit: ah, I see this is marked FIXME)
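The flag-in-kernel idea can be sketched outside Triton (a toy stand-in, not the real fused_moe_kernel signature; int8 stands in for the fp8/int4 activation dtype): the kernel dequantizes tiles into its float32 accumulator anyway, so the router weight can be folded in there without ever multiplying in the low-bit dtype.

```python
import torch

def fused_expert_matmul(a_q, a_scale, w, topk_w, apply_router_weight_on_input):
    # toy stand-in for fused_moe_kernel: quantized activations are
    # dequantized to float32 inside the kernel, so the router weight
    # can be applied there without a low-precision multiply
    a32 = a_q.float() * a_scale            # dequantize to the accumulator dtype
    if apply_router_weight_on_input:
        a32 = a32 * topk_w                 # pre-matmul, as Llama4 requires
    out = a32 @ w.float()
    if not apply_router_weight_on_input:
        out = out * topk_w                 # legacy post-matmul order
    return out

a = torch.randn(4, 8)
scale = (a.abs().max() / 127.0).item()
a_q = (a / scale).round().clamp(-127, 127).to(torch.int8)
w = torch.randn(8, 16)
topk_w = torch.rand(4, 1)  # top-1 router weights

pre = fused_expert_matmul(a_q, scale, w, topk_w, True)
post = fused_expert_matmul(a_q, scale, w, topk_w, False)
print(torch.allclose(pre, post, atol=1e-5))
```

For this single linear matmul the two orders agree; across the real nonlinear W13 → SiLU → W2 chain they do not, which is why the flag must take effect before W13.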

@liuzijing2014 (Collaborator, Author) commented Apr 16, 2025

Good catch. To fix this properly, we probably need to modify all the existing MoE kernels to make sure they apply the router weights in the right order. Meanwhile, let me do this before moe_kernel_prepare_input.
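The interim workaround — scaling the high-precision hidden states before quantization — can be sketched as follows (toy symmetric int8 quantization standing in for the int4/fp8 path; quantize_sym_int8 is a hypothetical helper, not a vLLM function):

```python
import torch

def quantize_sym_int8(x):
    # toy symmetric per-tensor quantization (stand-in for the fp8/int4 path)
    scale = (x.abs().max() / 127.0).item()
    q = torch.clamp((x / scale).round(), -127, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 8)
topk_w = torch.rand(4, 1)  # top-1 router weights

# Apply router weights on the high-precision hidden states first,
# THEN quantize -- this avoids multiplying in a low-precision dtype
q, scale = quantize_sym_int8(topk_w * x)
deq = q.float() * scale

# dequantized values match the pre-scaled inputs up to quantization error
print(torch.allclose(deq, topk_w * x, atol=scale))
```

The ordering mirrors the comment above: the multiply happens while the activations are still fp16/bf16, before moe_kernel_prepare_input quantizes them.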

A Collaborator commented:

Can we move this somewhere above, before quantization is applied? e.g. line 1267.

Signed-off-by: Zijing Liu <[email protected]>
expert_ids,
num_tokens_post_padded,
apply_router_weight_on_input,
# FIXME: Always False here because fused_moe_kernel
A Collaborator commented:

This comment is probably not needed here; instead, add the explanation where apply_router_weight_on_input is actually used.

@liuzijing2014 (Collaborator, Author) commented:

Closing the PR; refer to #16801 for the root-cause fix.

@liuzijing2014 liuzijing2014 deleted the in4-fix branch April 17, 2025 22:18