Conversation

@youzhedian (Contributor) commented Aug 27, 2025

This PR adds Decode Context Parallel (DCP) support for MLA inference, fully compatible with chunked prefill and automatic prefix caching (APC).

You can enable DCP with `--decode-context-parallel-size`/`-dcp xxx` (only the flashmla backend is supported for now). `tp_size` must be divisible by `dcp_size`, because DCP does not change the world size: it simply reuses the GPUs of the TP group and splits one TP group into `tp_size // dcp_size` DCP groups, e.g.

with -tp 8 -dcp 8, we use 8 GPUs
with -tp 8 -dcp 4, we use 8 GPUs
with -tp 4 -dcp 4 -pp 2, we use 8 GPUs

The KV-cache token budget is always increased by a factor of `dcp_size`.
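To make the sizing rules above concrete, here is a small self-contained sketch in plain Python (not vLLM code; the per-rank token budget below is an arbitrary illustrative number):

```python
def dcp_layout(tp_size: int, dcp_size: int, pp_size: int = 1,
               kv_token_budget_per_rank: int = 100_000) -> dict:
    """Illustrative arithmetic for the DCP sizing rules described above."""
    # tp_size must be divisible by dcp_size: DCP reuses the TP group's GPUs.
    assert tp_size % dcp_size == 0, "tp_size must be divisible by dcp_size"
    return {
        # The world size is not changed by DCP.
        "world_size": tp_size * pp_size,
        # One TP group is split into tp_size // dcp_size DCP groups.
        "dcp_groups_per_tp_group": tp_size // dcp_size,
        # The KV-cache token budget grows by a factor of dcp_size.
        "kv_token_budget": kv_token_budget_per_rank * dcp_size,
    }

# The three examples above all use 8 GPUs in total.
assert dcp_layout(tp_size=8, dcp_size=8)["world_size"] == 8
assert dcp_layout(tp_size=8, dcp_size=4)["world_size"] == 8
assert dcp_layout(tp_size=4, dcp_size=4, pp_size=2)["world_size"] == 8
```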

This DCP implementation stores the KV cache in an interleaved style: the KV cache for the token with index `i` is always stored on the GPU whose `dcp_rank` equals `i % dcp_world_size`:

e.g. DCP2, a request with prompt_len=5, generation_len=4:
KV cache stored on dcp_rank0: tokens 0, 2, 4, 6, 8
KV cache stored on dcp_rank1: tokens 1, 3, 5, 7
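A minimal sketch of this interleaved mapping (illustrative plain Python, not the PR's kernels), reproducing the DCP2 example above:

```python
from collections import defaultdict

def interleaved_kv_placement(num_tokens: int, dcp_world_size: int) -> dict[int, list[int]]:
    """Map each token index to the dcp_rank that stores its KV cache."""
    placement: dict[int, list[int]] = defaultdict(list)
    for token_idx in range(num_tokens):
        placement[token_idx % dcp_world_size].append(token_idx)
    return dict(placement)

# DCP2, prompt_len=5 + generation_len=4 -> 9 tokens in total.
print(interleaved_kv_placement(num_tokens=9, dcp_world_size=2))
# {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7]}
```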

deepseek-ai/DeepSeek-V2-Lite-Chat gsm8k eval:

- TP4PP2DCP4

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6687|±  | 0.013|
|     |       |strict-match    |     5|exact_match|↑  |0.6657|±  | 0.013|

- TP4PP2DCP1

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6657|±  | 0.013|
|     |       |strict-match    |     5|exact_match|↑  |0.6611|±  | 0.013|

For more information, please refer to the introduction doc.

Future work (these items will be tackled in follow-up PRs; community contributions are warmly welcomed):

  • DCP support for MLA fullgraph
  • Extend cp_gather_cache to handle the scaled KV cache and supersede gather_and_maybe_dequant_cache
  • DCP support for the triton_mla/cutlass_mla backends (only the flashmla backend is supported now)
  • KV-cache deduplication via DCP in GQA (e.g., GQA8-TP16 with DCP2)
  • DCP support for MTP
  • Ring-CP-style prefill context parallel (PCP) implementation to optimize TTFT

@mergify mergify bot commented Aug 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @youzhedian.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 27, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces Context Parallelism (CP) support for MLA inference, which is a significant feature enhancement. The changes are extensive, touching configuration, parallel state management, scheduling, KV cache, and attention backends. The implementation seems well-thought-out, with new CUDA kernels for CP-specific operations and corresponding Python wrappers and tests. The end-to-end tests comparing CP with TP are a good validation strategy.

My review found one critical issue in the cuda_communicator.py file, where a reduce_scatter operation was using a potentially non-contiguous tensor, which could lead to incorrect results. The suggested patch fixes this issue. The rest of the implementation for context parallelism appears solid.
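For illustration only, the kind of contiguity fix described above might look like the following sketch; the helper name, shapes, and use of torch.distributed.reduce_scatter_tensor are assumptions, not the actual vLLM cuda_communicator.py code:

```python
import torch
import torch.distributed as dist

def reduce_scatter_dim0(input_: torch.Tensor, group=None) -> torch.Tensor:
    """Reduce-scatter a tensor along dim 0 across the given process group."""
    world_size = dist.get_world_size(group=group)
    # A transposed or sliced view may be non-contiguous; feeding such a view
    # to the collective can yield incorrect results, so force a dense layout.
    input_ = input_.contiguous()
    output = torch.empty(
        (input_.shape[0] // world_size, *input_.shape[1:]),
        dtype=input_.dtype,
        device=input_.device,
    )
    dist.reduce_scatter_tensor(output, input_, op=dist.ReduceOp.SUM, group=group)
    return output
```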

@youkaichao youkaichao (Member) left a comment

Thanks for the great work!

As discussed, there can be two types of CP: CP for prefill (where the world size is enlarged by CP) and CP for decode (where the world size does not change with CP). If possible, let's name the current PR's options decode-context-parallel-size and dcp_size, to leave room for prefill CP in the future.

@youkaichao youkaichao changed the title [Feature] Support Context Parallel for MLA [Feature] Support Decode Context Parallel for MLA Aug 27, 2025
@youkaichao (Member) commented

@youzhedian To accelerate the review and merge (especially CI testing), maybe we can split the kernel-side changes into a separate PR and get it merged first. Then follow-up PRs can use pre-compiled wheels from that PR, with much faster CI testing.

@hmellor (Member) commented Aug 27, 2025

I've just come across this PR adding CP to non-MLA attention: #23703

@youzhedian (Contributor, Author) commented

> @youzhedian To accelerate the review and merge (especially CI testing), maybe we can split the kernel-side changes into a separate PR and get it merged first. Then follow-up PRs can use pre-compiled wheels from that PR, with much faster CI testing.

#23791 as suggested

@LucasWilkinson (Collaborator) commented

Cool, thanks for taking this on! I think this can be done without any GPU model-runner changes; I was working on a prototype but unfortunately it got back-burnered for a few months 😞. Anyway, just sharing here as an alternative solution that requires fewer changes to the core code but is potentially more susceptible to imbalance (it's not fully functional yet):

#22789

@youzhedian youzhedian reopened this Aug 28, 2025
@facebook-github-bot commented

@wushidonguc has imported this pull request. If you are a Meta employee, you can view this in D81728831.

eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
@gary-wjc commented

Is this PR compatible with #22668? @youzhedian

xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025

Labels

ci/build, ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), v1

8 participants