@wenxindongwork wenxindongwork commented Oct 14, 2025

This PR introduces support for Data Parallelism in the vLLM TPU backend.

Data Parallelism is a sharding strategy intended for the following scenarios:

  1. Model replication. Replicating the model across a large slice increases the overall throughput of the system.
  2. KV cache de-duplication for large models with a small number of KV heads (e.g. DeepSeekV3, Qwen3 235B) and models that use fp8 KV cache quantization. By default we shard the KV heads by TP; if there are not enough heads to shard, we end up replicating the heads and the KV cache, which wastes memory. Attention DP eliminates this waste by replicating the attention layer and sending different data to each attention replica (see the sketch after this list).
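
As a rough illustration of the memory argument (the head count and TP degree below are hypothetical, not taken from this PR):

```python
# Illustrative arithmetic only; numbers are hypothetical.
num_kv_heads = 4          # a model with few KV heads
tensor_parallel_size = 8  # TP degree on the slice

# With pure TP, each KV head (and its KV cache) ends up replicated on
# tensor_parallel_size // num_kv_heads chips:
kv_replication = max(1, tensor_parallel_size // num_kv_heads)  # -> 2x memory waste

# With attention DP, that surplus factor is spent on data parallelism for the
# attention layers instead, so each chip holds a distinct shard of the KV cache:
attn_dp_size = kv_replication  # -> 2 attention replicas, no duplicated KV cache
```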

Note that the vLLM TPU DP design (SPMD) is very different from the vLLM GPU DP design (MPMD). vLLM GPU DP launches multiple vLLM EngineCore instances (one per DP rank) that communicate between processes, whereas vLLM TPU DP launches a single vLLM EngineCore and does the data sharding within that one instance. SPMD is a more TPU- and JAX-native approach to DP.

DP scheduler
We introduce a new DP scheduler class that extends the base vLLM Scheduler to support data-parallel (DP) execution. It manages request distribution and KV cache allocation across multiple DP ranks, where each rank has its own logical KV cache shard and processes a subset of the total requests.
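
The sketch below illustrates the request-distribution idea; it is a simplification, not the actual DPScheduler code, and the helper names are hypothetical:

```python
# Minimal sketch of distributing requests across DP ranks. The real DPScheduler
# extends vLLM's Scheduler and also tracks per-rank KV cache allocation.
class DPSchedulerSketch:
    def __init__(self, dp_size: int):
        self.dp_size = dp_size
        # One request queue (and, logically, one KV cache shard) per DP rank.
        self.requests_per_rank = {rank: [] for rank in range(dp_size)}

    def assign(self, request_id: str) -> int:
        # Place the request on the least-loaded rank to keep KV cache usage balanced.
        rank = min(self.requests_per_rank, key=lambda r: len(self.requests_per_rank[r]))
        self.requests_per_rank[rank].append(request_id)
        return rank
```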

Input preparation
See changes in tpu_jax_runner._prepare_inputs(). When DP is enabled, we assign each request to a DP rank. Input tokens must be grouped by DP rank (e.g. input tokens from DP rank 0 come before input tokens from DP rank 1, and so on). This sorting and padding happens in the _prepare_inputs_dp function. The input tokens are then sharded along the DP axis, so that each attention-layer replica processes one shard of the global input.
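
A minimal sketch of that sort-and-pad step, assuming NumPy inputs; the function name and signature below are illustrative, and the real logic lives in _prepare_inputs_dp:

```python
import numpy as np

def prepare_inputs_dp_sketch(token_ids: np.ndarray, dp_ranks: np.ndarray,
                             dp_size: int, pad_id: int = 0) -> np.ndarray:
    """Group tokens by DP rank and pad every group to the same length."""
    groups = [token_ids[dp_ranks == rank] for rank in range(dp_size)]
    max_len = max(len(g) for g in groups)
    padded = np.stack([
        np.pad(g, (0, max_len - len(g)), constant_values=pad_id) for g in groups
    ])
    # `padded` has shape [dp_size, max_len]; its leading axis is then sharded
    # across the DP mesh axis so each attention replica sees exactly one shard.
    return padded
```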

Sharding and Attention Axis Updates:

We introduced a new mesh axis, attn_dp, to enable attention-only DP. Additionally, we modified the sharding annotations for Llama3 to enable both model-wise and attention-only DP. All models should follow the same general pattern: attention weights (e.g. the qkv projections and out_proj) are replicated across the data and attn_dp dimensions, the input data is sharded across the data and attn_dp dimensions, and MLP and MoE layers are replicated across the data dimension but not the attn_dp dimension.
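
One plausible way to express this pattern with a JAX mesh (a sketch under assumed axis sizes; the actual mesh shape and annotations in the repo may differ):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical layout for 8 chips: data=1, attn_dp=2, model=4.
devices = np.array(jax.devices()).reshape(1, 2, 4)
mesh = Mesh(devices, axis_names=("data", "attn_dp", "model"))

# Attention weights (qkv projs, out_proj): replicated over data and attn_dp,
# sharded over model (TP).
attn_weight_sharding = NamedSharding(mesh, P(None, "model"))

# Activations entering attention: batch dimension sharded over data and attn_dp.
attn_input_sharding = NamedSharding(mesh, P(("data", "attn_dp"), None))

# MLP / MoE weights: replicated over data but not over attn_dp; outside of
# attention the attn_dp axis is used together with model for tensor parallelism.
mlp_weight_sharding = NamedSharding(mesh, P(None, ("attn_dp", "model")))
```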

Usage:
It is recommended to use DP with async scheduling.

Model-wise DP: python examples/offline_inference.py --tensor_parallel_size=4 --data_parallel_size=2 --async-scheduling
Attention DP: python examples/offline_inference.py --tensor_parallel_size=8 --kv-cache-dtype=fp8 --additional_config='{"sharding":{"sharding_strategy": {"enable_dp_attention":1}}}' --async-scheduling

Attention DP is triggered automatically when the enable_dp_attention flag is passed, and the exact dp_size is determined automatically from the number of KV heads and the TP size.
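
For illustration, a hedged sketch of how that dp_size could be derived; the real derivation lives in this PR's sharding configuration code and may differ:

```python
def infer_attn_dp_size(num_kv_heads: int, tensor_parallel_size: int) -> int:
    # Hypothetical helper, not the actual implementation. When TP exceeds the
    # number of KV heads, the surplus factor becomes the attention-DP degree
    # instead of replicating KV heads (and their KV cache).
    if tensor_parallel_size <= num_kv_heads:
        return 1
    assert tensor_parallel_size % num_kv_heads == 0, "TP must be a multiple of KV heads"
    return tensor_parallel_size // num_kv_heads
```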

Complementary vLLM PR

We had to make some minor changes to upstream vLLM:
https://github.com/vllm-project/vllm/pull/27365/files

Tests

  • New e2e Buildkite test on v6e-8 (test_data_parallel.py) verifying (1) model-wise data parallelism, (2) attention data parallelism, and (3) an output correctness check. Performance tests are not added in this PR.
  • Unit tests for the DPScheduler class and the _prepare_inputs_dp function.

Buildkite: https://buildkite.com/tpu-commons/tpu-inference-ci/builds/4993

Performance

This PR introduces a functional DP implementation.

Model-wise DP

  • Achieves ~72-80% of expected throughput.

Attention DP

  • 1.5x gain in effective KV cache size.

Future PRs

  • Support DP for speculative decoding, sequence parallelism, LoRA, and structured decoding.
  • Support DP for Torchax backend.

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@wenxindongwork wenxindongwork force-pushed the dp_attention branch 4 times, most recently from 6d38e1d to da4519e on October 22, 2025 19:04
@wenxindongwork wenxindongwork self-assigned this Oct 27, 2025
@wenxindongwork wenxindongwork marked this pull request as ready for review October 27, 2025 18:17
@wenxindongwork wenxindongwork force-pushed the dp_attention branch 3 times, most recently from a1ea387 to f257867 on October 27, 2025 22:13
kyuyeunk commented Nov 5, 2025

Are you planning to create a separate PR for the sharding-annotation-related changes? It seems like that change makes the PR significantly larger than just the DP-related changes and makes it hard to review.

@wenxindongwork wenxindongwork merged commit a27922a into main Nov 6, 2025
3 checks passed
wenxindongwork added a commit that referenced this pull request Nov 6, 2025
sixiang-google pushed a commit that referenced this pull request Nov 6, 2025
sierraisland pushed a commit that referenced this pull request Nov 7, 2025
sierraisland pushed a commit that referenced this pull request Nov 8, 2025