
[RFC]: Support Context Parallelism with Fully Sharded KV Cache and Ring Attention #26133

@qiruiyangmeta

Description


Motivation.

Context parallelism introduces an additional degree of parallelism to LLM inference. While tensor parallelism and pipeline parallelism distribute model weights and layers across devices, context parallelism partitions the sequence (context) dimension of a request across devices. By combining context parallelism with these other forms of parallelism, systems can achieve more scalable and efficient inference, maximizing hardware utilization and reducing latency.
Context parallelism improves performance as the context length grows by distributing both computation and the KV cache across multiple GPUs. This lowers processing latency and can also reduce the memory required per GPU, which matters especially for extremely large KV caches (e.g., sequence lengths on the order of 1 million tokens).

Proposed Change.

Within the model, attention is the only component with a dependency along the sequence dimension, since each token must attend to all previous tokens in the same sequence. In contrast, FFN and element-wise operations are computed independently for each token. For a more in-depth treatment of context parallelism in LLM inference, including partial attention, see the MLSys paper at https://arxiv.org/pdf/2411.01783.
To implement context parallelism in vLLM, the design needs to:

  • Be aware of these dependencies to minimize synchronization overhead,
  • Remain flexible to support various backends, and
  • Avoid major changes to core components to ensure system stability.

Partition Sequence Dimension

Causal attention imposes a varying computational load per token, as shown in the figure below. To distribute the workload evenly, tokens are partitioned across context parallelism (CP) ranks: the sequence is divided into 2 × cp_world_size chunks, and each CP rank i is assigned both the i-th chunk and the (2 × cp_world_size - i - 1)-th chunk. Pairing an early (cheap) chunk with a late (expensive) chunk balances the compute load across CP ranks.

[Figure: per-token causal-attention compute load and the paired-chunk partitioning across CP ranks]
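As a concrete illustration of this pairing scheme, here is a minimal Python sketch (not vLLM code; `partition_tokens` and its signature are made up for illustration) that computes the two chunks owned by a CP rank:

```python
# Minimal sketch: balanced sequence partitioning for context parallelism.
# Each CP rank i gets chunk i and chunk (2 * cp_world_size - i - 1), so a
# cheap early chunk is paired with an expensive late chunk under causal attention.
def partition_tokens(num_tokens: int, cp_world_size: int, cp_rank: int) -> list[range]:
    num_chunks = 2 * cp_world_size
    chunk_size = (num_tokens + num_chunks - 1) // num_chunks  # ceil division

    def chunk(idx: int) -> range:
        start = idx * chunk_size
        return range(start, min(start + chunk_size, num_tokens))

    return [chunk(cp_rank), chunk(num_chunks - cp_rank - 1)]

# Example: 16 tokens, cp_world_size = 2
#   rank 0 -> chunks 0 and 3 -> tokens [0..3]  and [12..15]
#   rank 1 -> chunks 1 and 2 -> tokens [4..7]  and [8..11]
```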

Prefill

During the prefill phase, both the query (Q) and key-value (KV) tensors are sharded across GPUs. To ensure that each Q token can attend to all preceding KV tokens, the relevant Q or KV shards must be exchanged among GPUs. To reduce synchronization overhead, these transfers are overlapped with partial attention computations, with the goal of fully hiding the data-transfer latency. The figure below shows an example of prefill with CP=2.

[Figure: prefill with CP=2, overlapping KV/Q shard transfers with partial attention]
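The overlap of shard transfers with partial attention can be pictured as a ring schedule. The sketch below is only an illustration in PyTorch of that idea, not the actual vLLM implementation: `partial_attention` (which would also need to apply the correct causal mask for each shard) and `merge_partial` (a merge sketch appears in the Decode section) are hypothetical helpers, and the load-balanced chunk ordering from the previous section is omitted for brevity.

```python
import torch
import torch.distributed as dist

def ring_prefill(q, k, v):
    """Sketch of ring-style CP prefill: pass KV shards around the ring while
    computing partial attention on the shard currently held."""
    rank, world = dist.get_rank(), dist.get_world_size()
    cur_k, cur_v = k, v
    out = lse = None
    for step in range(world):
        reqs = []
        if step < world - 1:
            # Start moving the current KV shard to the next rank while we
            # compute partial attention on it below (compute/comm overlap).
            recv_k, recv_v = torch.empty_like(k), torch.empty_like(v)
            reqs = dist.batch_isend_irecv([
                dist.P2POp(dist.isend, cur_k, (rank + 1) % world),
                dist.P2POp(dist.isend, cur_v, (rank + 1) % world),
                dist.P2POp(dist.irecv, recv_k, (rank - 1) % world),
                dist.P2POp(dist.irecv, recv_v, (rank - 1) % world),
            ])
        # Hypothetical helper: attention of q against one KV shard, returning
        # the partial output and its log-sum-exp statistics.
        step_out, step_lse = partial_attention(q, cur_k, cur_v)
        if out is None:
            out, lse = step_out, step_lse
        else:
            out, lse = merge_partial(out, lse, step_out, step_lse)
        for req in reqs:
            req.wait()
        if step < world - 1:
            cur_k, cur_v = recv_k, recv_v
    return out
```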

The choice between passing KV or Q shards depends on the relative sizes of the Q and KV tensors. For full prefill, passing KV shards is generally preferred, as the number of queries per KV head typically exceeds two in most models. Conversely, for chunked prefill, passing Q shards may be more efficient if the KV cache length is significantly greater than the number of Q tokens.
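One way to frame that choice is as a simple byte-count comparison between the tensors that would have to move between CP ranks. The sketch below is an assumption about how such a heuristic could look, not the actual policy used in the PRs:

```python
def should_pass_kv(num_q_tokens: int, num_kv_tokens: int,
                   num_q_heads: int, num_kv_heads: int) -> bool:
    """Pass whichever tensor moves fewer elements between CP ranks."""
    q_elems = num_q_tokens * num_q_heads          # queries to exchange
    kv_elems = 2 * num_kv_tokens * num_kv_heads   # keys and values to exchange
    return kv_elems < q_elems

# Full prefill: num_q_tokens == num_kv_tokens, and with GQA there are usually
# more than two query heads per KV head, so passing KV is cheaper.
# Chunked prefill: num_kv_tokens >> num_q_tokens can flip the decision toward
# passing Q instead.
```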

Decode

During decoding, newly generated key-value (KV) pairs are distributed across CP ranks in a round-robin fashion. This keeps the KV cache balanced across ranks without duplication, and each rank can compute a correct partial attention over its local KV shard. Once each CP rank completes its local partial attention, the partial results are all-gathered and merged so that every CP rank ends up with the same full attention output. The figure below shows the decoding process with CP=2.

[Figure: decode with CP=2, round-robin KV placement and merged partial attention]
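The merge of partial attention results can use the standard log-sum-exp (LSE) trick from chunked/flash-style attention. A minimal sketch with assumed shapes and a hypothetical `merge_partial` helper (not vLLM code):

```python
import torch

def merge_partial(out_a, lse_a, out_b, lse_b):
    """Merge two partial attention results computed over disjoint KV shards.

    out_*: [num_tokens, num_heads, head_dim] partial attention outputs
    lse_*: [num_tokens, num_heads] log-sum-exp of the attention logits
    """
    lse = torch.logaddexp(lse_a, lse_b)          # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of shard A
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of shard B
    return w_a * out_a + w_b * out_b, lse
```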

Block Table

When tokens are distributed across context-parallel (CP) ranks, gaps may appear in the block table. After compaction, tokens that are stored physically next to each other may not be logically consecutive. This is still correct for CP because only the relative order of tokens needs to be preserved for mapping purposes, not their absolute positions in the block table. The figure below shows how KVs are stored in the CP case.

[Figure: KV cache / block table layout under CP after compaction]
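For the decode-time round-robin placement, one way to picture the compaction is a simple position-to-slot mapping. The helper below is hypothetical and only meant to illustrate why preserving relative order is sufficient:

```python
def local_slot(token_pos: int, cp_world_size: int, cp_rank: int):
    """Return the compact local KV-cache slot for token_pos on cp_rank,
    or None if the token is stored on another rank."""
    if token_pos % cp_world_size != cp_rank:
        return None
    return token_pos // cp_world_size

# Example with cp_world_size = 2:
#   rank 0 stores tokens 0, 2, 4, ... in local slots 0, 1, 2, ...
#   rank 1 stores tokens 1, 3, 5, ... in local slots 0, 1, 2, ...
# Physically adjacent slots are not logically consecutive tokens, but the
# relative order within each rank is preserved, which is all attention needs.
```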

PRs

#26057
#26058
#26059

Feedback Period.

No response

CC List.

@luccafong @houseroad @minosfuture

Any Other Things.

Related RFCs
#25749
#22693

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
