
Conversation

cboss6
Contributor

@cboss6 cboss6 commented Aug 27, 2025

Purpose

This PR introduces a novel static expert load-balancing placement strategy (called Zigzag) designed for MoE models with multiple expert groups, such as the DeepSeek series.

(Screenshot: expert-group selection heatmap)

Through our heatmap analysis, we observed that in multi-expert-group MoE models such as DeepSeek, experts within the same group tend to be selected together in practical scenarios. Therefore, distributing them across different devices can bring performance benefits.

The zigzag expert placement feature has been validated on DeepSeek-R1, demonstrating an ~8% improvement in QPM (queries per minute) over the default configuration in our online serving benchmark on a single node with 8× H20 GPUs.

The zigzag strategy optimizes how experts are distributed across parallel ranks by implementing a staggered placement pattern, which helps achieve better load balancing across expert parallel groups. This is particularly beneficial for models that use grouped top-k routing, where experts are organized into logical groups and the routing decisions are made within these groups. The implementation ensures that experts are distributed more evenly across ranks, reducing load imbalance and improving overall throughput performance in production environments.
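
For intuition, the sketch below contrasts the default contiguous placement with a round-robin ("zigzag") placement. It is a minimal illustration only; the function names and parameters (`global_num_experts`, `ep_size`) are hypothetical and do not reflect this PR's actual implementation.

```python
# Illustrative sketch only: contrast the default contiguous placement with a
# round-robin ("zigzag") placement. All names here are hypothetical.

def contiguous_placement(global_num_experts: int, ep_size: int) -> list[list[int]]:
    # Default: each EP rank holds a contiguous slice, so a whole expert group
    # tends to land on a single device.
    per_rank = global_num_experts // ep_size
    return [list(range(r * per_rank, (r + 1) * per_rank)) for r in range(ep_size)]

def zigzag_placement(global_num_experts: int, ep_size: int) -> list[list[int]]:
    # Round-robin: global expert e is placed on rank e % ep_size, so the
    # experts of any one group are spread across all devices.
    return [list(range(r, global_num_experts, ep_size)) for r in range(ep_size)]

if __name__ == "__main__":
    # 16 experts across 4 EP ranks.
    print(contiguous_placement(16, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7], ...]
    print(zigzag_placement(16, 4))      # [[0, 4, 8, 12], [1, 5, 9, 13], ...]
```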

Performance

Test Platform:

- vLLM version: vllm/vllm-openai:v0.10.1.1
- Model: DeepSeek-V2-Chat-0628
- GPU: H20 × 8
- Parallel config 1: tp=8, enable_expert_parallel=True
- Benchmark config: input_len=1024, output_len=512, request_rate=8, max_concurrency=8, num_prompts=32

Benchmark command:

```bash
python3 ./bench_serving.py \
    --backend vllm \
    --dataset-name random \
    --model ${MODEL_PATH} \
    --random-input-len 1024 \
    --random-output-len 128 \
    --random-range-ratio 0.5 \
    --tokenizer ./tokenizer \
    --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --request-rate 8 \
    --max-concurrency 8 \
    --num-prompts 32 \
    --base-url http://127.0.0.1:8000 \
    --port 8000
```

(Screenshot: serving benchmark results)

Conclusion: with only expert parallelism enabled, Zigzag improves throughput and reduces end-to-end latency by approximately 3%.

Accuracy Test

Tested with DeepSeek-V2-Chat-0628 on H20 × 8 with the following serving command:

```bash
python3 -u -m vllm.entrypoints.openai.api_server \
    --model ${model_path} \
    --trust-remote-code \
    --gpu-memory-utilization 0.85 \
    -tp 8 \
    --enable-expert-parallel \
    --enable_round_robin_expert_placement
```

Note: DeepSeek-V2 scores poorly on the chosen datasets in general; this test is only meant to confirm that Zigzag has no impact on accuracy.

| Dataset | vLLM v0.10.1.1 | This PR |
|---------|----------------|---------|
| AIME24  | 13.33%         | 10.00%  |
| GPQA    | 41.91%         | 45.96%  |
| Math500 | 72.20%         | 71.20%  |

Tested with DeepSeek-R1-0528 on H20 × 8 and verified that Zigzag has no impact on accuracy.

| Dataset | Accuracy (baseline) | Accuracy (zigzag-MR) |
|---------|---------------------|----------------------|
| AIME24  | 79.80%              | 80.00%               |
| GPQA    | 71.50%              | 71.21%               |
| Math500 | 97.30%              | 95.20%               |

Usage

To try out the Zigzag static EPLB strategy, enable it with the following options:

```python
from vllm import LLM

model_path = '/model/path/to/DeepSeek/series'
model = LLM(
    model=model_path,
    enable_expert_parallel=True,
    enable_round_robin_expert_placement=True,
)
```
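
When serving via `vllm.entrypoints.openai.api_server`, the same behavior is enabled with `--enable-expert-parallel` and `--enable_round_robin_expert_placement`, as in the accuracy-test command above.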

Compatibility

The zigzag pattern is designed for MoE models with multiple expert groups, such as the DeepSeek series. Note that MoE models without expert groups do not benefit from this method.
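
For readers unfamiliar with grouped routing, here is a simplified sketch of group-limited top-k expert selection, assuming a DeepSeek-style router. It illustrates the idea only and is not vLLM's actual routing code; all names are hypothetical.

```python
import torch

def grouped_topk(scores: torch.Tensor, num_groups: int, topk_groups: int, topk: int) -> torch.Tensor:
    """Simplified group-limited top-k: first pick the best expert groups for
    each token, then pick the top-k experts only within those groups."""
    num_tokens, num_experts = scores.shape
    group_size = num_experts // num_groups
    # Score each group by its best expert.
    group_scores = scores.view(num_tokens, num_groups, group_size).max(dim=-1).values
    top_groups = group_scores.topk(topk_groups, dim=-1).indices
    # Mask out experts whose group was not selected.
    group_mask = torch.zeros(num_tokens, num_groups, dtype=torch.bool)
    group_mask.scatter_(1, top_groups, True)
    expert_mask = group_mask.repeat_interleave(group_size, dim=1)
    masked_scores = scores.masked_fill(~expert_mask, float("-inf"))
    return masked_scores.topk(topk, dim=-1).indices

# Because a token's selected experts come from only `topk_groups` groups,
# contiguous placement concentrates them on few ranks; spreading each group's
# experts across ranks (zigzag) evens out the per-rank load.

# Example: 2 tokens, 16 experts in 4 groups, pick 2 groups then top-4 experts.
scores = torch.rand(2, 16)
print(grouped_topk(scores, num_groups=4, topk_groups=2, topk=4))
```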

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a novel "Zigzag" static expert placement strategy for MoE models, which is a welcome performance optimization. The implementation is mostly sound, with the necessary configuration options and logic added. My review includes two main points of feedback. Firstly, there's some redundant code in the new zigzag placement logic that can be removed to improve clarity and correctness. Secondly, the assertion for validating the zigzag placement configuration could be improved by splitting it into multiple assertions with more specific error messages, which would enhance the developer experience when debugging configuration issues.

@cboss6 cboss6 force-pushed the cboss/zigzag-vllm branch from f2add2d to a89dcd3 on August 27, 2025 13:27
@DarkLight1337
Member

You can fix the pre-commit about .md files by merging from main

@cboss6
Contributor Author

cboss6 commented Aug 28, 2025

You can fix the pre-commit about .md files by merging from main

Done. Could you please take another look? Thanks!

@cboss6 cboss6 changed the title [Feat] A novel static EPLB placement strategy for MoE models. [EPLB] A novel static EPLB placement strategy for MoE models. Sep 1, 2025
@cboss6 cboss6 changed the title [EPLB] A novel static EPLB placement strategy for MoE models. [Feat][EPLB] A novel static EPLB placement strategy for MoE models. Sep 1, 2025
@hmellor
Member

hmellor commented Sep 1, 2025

Should Zigzag always be used on grouped expert models?

I'm wondering if we should bother to make this configurable and just use Zigzag as the way we place experts for grouped expert models.

@cboss6
Contributor Author

cboss6 commented Sep 2, 2025

Should Zigzag always be used on grouped expert models?

I'm wondering if we should bother to make this configurable and just use Zigzag as the way we place experts for grouped expert models.

Thanks for the reply! Yes, I believe Zigzag could be a more suitable default placement strategy for grouped expert models.
To avoid adding another user-facing option, I’ve added a lightweight check to determine whether to apply the Zigzag pattern automatically.
BTW, this approach has also been working well as the default in our internal vllm repo.

@abmfy
Member

abmfy commented Sep 4, 2025

QQ: are you using the random dataset for benchmarking here?

@abmfy
Member

abmfy commented Sep 5, 2025

Just curious why the accuracy is good, since this PR doesn't seem to be modifying the weight loader; the weights are loaded onto GPUs assuming, say, experts [0, 1, 2, 3] go to GPU 0 and [4, 5, 6, 7] to GPU 1, etc.
But later GPU 0 will assume it has [0, 4, 8, 12] according to the expert_map.
So this should cause misalignment between the expert router and the called experts.

Could you please provide the scripts you used for benchmarking and accuracy tests? Thank you!

@abmfy
Member

abmfy commented Sep 5, 2025

Also, with EPLB enabled, I don't think this holds, as it directly breaks the algorithm's assumption about the physical experts' locations; moreover, the EPLB algorithm already accounts for expert groups.

@cboss6
Contributor Author

cboss6 commented Sep 5, 2025

@abmfy
This is a static EPLB method, and it does not affect dynamic EPLB methods in the current vLLM. It is only applied during the model’s initialization. Since the model’s weight loader comes afterward, there is no need to adjust the experts’ weights here.
Additionally, this feature has been running in our internal services for several months, and we have established mature validation tests to ensure its reliability.

@cboss6
Contributor Author

cboss6 commented Sep 5, 2025

Just curious why the accuracy is good, since this PR doesn't seem to be modifying the weight loader; the weights are loaded onto GPUs assuming, say, experts [0, 1, 2, 3] go to GPU 0 and [4, 5, 6, 7] to GPU 1, etc. But later GPU 0 will assume it has [0, 4, 8, 12] according to the expert_map. So this should cause misalignment between the expert router and the called experts.

Could you please provide the scripts you used for benchmarking and accuracy tests? Thank you!

As for the weight loading, as I recall, the calling order is roughly:
gpu_model_runner.py::load_model -> model_loader::initialize_model -> model.load_weights

My change (static expert placement) is applied during model initialization, before load_weights runs. In other words, I finalize the expert_map first, and that mapping is then used by the weight loader.

@cboss6
Contributor Author

cboss6 commented Sep 5, 2025

Also, with EPLB enabled, I don't think this holds, as it directly breaks the algorithm's assumption about the physical experts' locations; moreover, the EPLB algorithm already accounts for expert groups.

Static placement just decides the initial expert-to-device mapping. Dynamic EPLB still has full control at runtime to rebalance traffic or remap experts according to its own strategy. They don't interfere; in fact, they can complement each other, because a well-chosen static placement can provide a good starting point, while dynamic EPLB continues to adapt to load patterns.

@abmfy
Member

abmfy commented Sep 5, 2025

Thanks for your quick response and the contribution!

Just curious why the accuracy is good, since this PR doesn't seem to be modifying the weight loader; the weights are loaded onto GPUs assuming, say, experts [0, 1, 2, 3] go to GPU 0 and [4, 5, 6, 7] to GPU 1, etc. But later GPU 0 will assume it has [0, 4, 8, 12] according to the expert_map. So this should cause misalignment between the expert router and the called experts.
Could you please provide the scripts you used for benchmarking and accuracy tests? Thank you!

As for the weight loading, as I recall, the calling order is roughly: gpu_model_runner.py::load_model -> model_loader::initialize_model -> model.load_weights

My change (static expert placement) is applied during model initialization, before load_weights runs. In other words, I finalize the expert_map first, and that mapping is then used by the weight loader.

No, actually the weight loader relies on the mapping in fused_moe.py#L1739-L1768. But this PR doesn’t seem to be doing that (though it’s certainly doable), which is why I’m asking.

@abmfy This is a static EPLB method, and it does not affect dynamic EPLB methods in the current vLLM. It is only applied during the model’s initialization. Since the model’s weight loader comes afterward, there is no need to adjust the experts’ weights here. Additionally, this feature has been running in our internal services for several months, and we have established mature validation tests to ensure its reliability.

That’s good to hear! I’m not questioning the reliability — just a bit concerned that some modifications (e.g., to the weight loader since it doesn’t use the expert_map) might be missing in this PR when upstreaming.

Also, with EPLB enabled, I don't think this holds, as it directly breaks the algorithm's assumption about the physical experts' locations; moreover, the EPLB algorithm already accounts for expert groups.

Static placement just decides the initial expert-to-device mapping. Dynamic EPLB still has full control at runtime to rebalance traffic or remap experts according to its own strategy. They don't interfere; in fact, they can complement each other, because a well-chosen static placement can provide a good starting point, while dynamic EPLB continues to adapt to load patterns.

I think we can add an option to enable the round-robin arrangement, but it shouldn’t be the default, since I’m concerned that other components (e.g., the EP kernels) may rely on the current linear expert arrangement.

In the current EPLB implementation, it also assumes a linear arrangement of physical experts, so the two cannot work together. That said, I agree we could support the round-robin arrangement only when EPLB is disabled, since EPLB is typically used with redundant experts and adapting round-robin in that case would be difficult.

@roywei

roywei commented Sep 6, 2025

cc @wpc @charlotte12l @luccafong


mergify bot commented Sep 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @cboss6.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 8, 2025
cboss6 and others added 17 commits September 16, 2025 15:43
Co-authored-by: Harry Mellor <[email protected]>
Signed-off-by: Chen Bruce <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
Signed-off-by: bruceszchen <[email protected]>
@hmellor hmellor enabled auto-merge (squash) September 16, 2025 08:36
@hmellor hmellor merged commit 7ea5c73 into vllm-project:main Sep 16, 2025
49 checks passed
@github-project-automation github-project-automation bot moved this from To Triage to Done in gpt-oss Issues & Enhancements Sep 16, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…llm-project#23745)

Signed-off-by: bruceszchen <[email protected]>
Signed-off-by: Chen Bruce <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Chen Bruce <[email protected]>
Co-authored-by: lemon412 <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>