@minosfuture minosfuture commented Sep 17, 2025

This PR implements the causal mask for interleaved context parallelism (CP) so that query lengths greater than 1 are supported.

The solution follows the discussion between @LucasWilkinson, @youkaichao, and @youzhedian on Slack.

Key illustration by @LucasWilkinson:

Normal:

k_toks >   0 1 2 3 4 5
q_toks v  _____________
       2 | 1 1 1
       3 | 1 1 1 1
       4 | 1 1 1 1 1
       5 | 1 1 1 1 1 1


DCP Rank 0:

k_toks >   0 2 4
q_toks v  _______
       2 | 1 1
       3 | 1 1
       4 | 1 1 1
       5 | 1 1 1 


DCP Rank 1:

k_toks >   1 3 5
q_toks v   ______
       2 | 1
       3 | 1 1
       4 | 1 1
       5 | 1 1 1

In the DCP case, the k/v tokens are distributed across ranks in an interleaved fashion (see vllm-project/vllm#23734).
Therefore, in the example above, rank 0 holds kv tokens 0, 2, 4 and rank 1 holds kv tokens 1, 3, 5. The mask is no longer a bottom-right triangle.
To determine the causal mask, FA therefore needs to know the CP world size and the CP rank.
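
The rule implied by the illustration above can be written as a small predicate. This is a minimal sketch and not the kernel code from this PR: the helper names `cp_causal_visible` and `count_visible` are hypothetical, and it assumes that local key `j` on rank `r` corresponds to global token `j * cp_world_size + r`, and that the queries are the last `seqlen_q` tokens of the full sequence.

```cpp
// Hypothetical helper (not from the PR): decides whether a locally stored key
// is visible to a query row under interleaved context parallelism.
//   q_idx           : query row index in [0, seqlen_q)
//   k_local_idx     : index of the key inside this rank's shard
//   cp_tot_seqlen_k : total key length summed over all CP ranks
inline bool cp_causal_visible(int q_idx, int k_local_idx, int seqlen_q,
                              int cp_tot_seqlen_k, int cp_world_size, int cp_rank) {
    // Assumption: queries are the last seqlen_q tokens of the full sequence.
    int const q_pos = q_idx + cp_tot_seqlen_k - seqlen_q;
    // Assumption: interleaved layout, local key j is global token j * W + rank.
    int const k_pos = k_local_idx * cp_world_size + cp_rank;
    return k_pos <= q_pos;
}

// Counts the 1s in one row of the mask for a shard of seqlen_k_local keys.
inline int count_visible(int q_idx, int seqlen_q, int seqlen_k_local,
                         int cp_tot_seqlen_k, int cp_world_size, int cp_rank) {
    int n = 0;
    for (int j = 0; j < seqlen_k_local; ++j) {
        if (cp_causal_visible(q_idx, j, seqlen_q, cp_tot_seqlen_k,
                              cp_world_size, cp_rank)) {
            ++n;
        }
    }
    return n;
}
```

With `seqlen_q = 4` and `cp_tot_seqlen_k = 6`, the row counts for `cp_world_size = 2` match the two DCP masks shown above, and `cp_world_size = 1` falls back to the usual bottom-right-aligned causal mask.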
The block tiling implementation also needs to be updated. As illustrated below, we now need to process block tile (0,1) in the CP case, whereas it could previously be skipped in the normal case.
[image: block tiling illustration]
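
To make the tiling consequence concrete, here is a hedged sketch of how an upper bound on the key blocks visited by a query block could be computed under interleaving. It is derived from the mask rule in the description, not copied from the PR: `ceil_div_int`, `local_seqlen_k`, and `cp_n_block_max` are hypothetical names, and only `kBlockM`, `kBlockN`, `cp_world_size`, and `cp_tot_seqlen_k` echo identifiers from the diff.

```cpp
#include <algorithm>

// Hypothetical helpers for illustration; not the code merged in this PR.
constexpr int ceil_div_int(int a, int b) { return (a + b - 1) / b; }

// Number of keys stored on cp_rank when tokens are dealt out round-robin.
inline int local_seqlen_k(int cp_tot_seqlen_k, int cp_world_size, int cp_rank) {
    return cp_tot_seqlen_k > cp_rank
               ? ceil_div_int(cp_tot_seqlen_k - cp_rank, cp_world_size)
               : 0;
}

// Exclusive upper bound on the local key blocks that query block m_block
// has to visit under the interleaved causal mask.
inline int cp_n_block_max(int m_block, int kBlockM, int kBlockN, int seqlen_q,
                          int cp_tot_seqlen_k, int cp_world_size, int cp_rank) {
    // Last query row covered by this block, clamped to the sequence.
    int const last_row = std::min((m_block + 1) * kBlockM, seqlen_q) - 1;
    // Global position of that row (queries are the tail of the sequence).
    int const q_pos = last_row + cp_tot_seqlen_k - seqlen_q;
    // Local keys j with j * W + rank <= q_pos.
    int const visible =
        q_pos >= cp_rank ? (q_pos - cp_rank) / cp_world_size + 1 : 0;
    int const seqlen_k = local_seqlen_k(cp_tot_seqlen_k, cp_world_size, cp_rank);
    return ceil_div_int(std::min(visible, seqlen_k), kBlockN);
}
```

One configuration in which an extra tile appears: with `kBlockM = kBlockN = 2`, 4 queries, and 4 local keys, the non-CP case (`cp_world_size = 1`, `cp_tot_seqlen_k = 4`) gives a bound of 1 for query block 0, so tile (0,1) is skipped. With `cp_world_size = 2` and `cp_tot_seqlen_k = 8`, the bound becomes 2 on both ranks, so tile (0,1) has to be processed. Whether this is the exact case shown in the image is an assumption.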

Tests

Added and passed unit tests for CP.

bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 &&
params.is_local &&

Collaborator

do you mind removing the unrelated formatting changes? trying to stay as close to upstream as possible when possible

Author

Removed.

: std::max(n_block_min,
cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN));
cute::ceil_div(m_idx_max +
params.cp_world_size * seqlen_k -

Collaborator

can we use cp_tot_seqlen_k to skip the mul here? should branch in the non-cp case to save the mul?

Collaborator

we could make cp_tot_seqlen_k == seqlen_k in the params.cp_world_size == 1 case

Author

updated
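
The suggestion in this thread can be sketched as follows: if `cp_tot_seqlen_k` is made equal to `seqlen_k` whenever `cp_world_size == 1`, the block-bound expression can use `cp_tot_seqlen_k` directly, with no multiply and no extra branch at the use site. The helper name `resolve_cp_tot_seqlen_k` is hypothetical, and falling back to `seqlen_k` when the pointer is null is an assumption of this sketch rather than something the thread specifies.

```cpp
// Hypothetical helper (not from the PR): choose the total key length once,
// so later bounds can use it without multiplying by cp_world_size.
inline int resolve_cp_tot_seqlen_k(int const* cp_tot_seqused_k, int bidb,
                                   int seqlen_k, int cp_world_size) {
    // Without CP, the total key length is just the local key length.
    // Assumption: a null pointer is treated the same way.
    if (cp_world_size == 1 || cp_tot_seqused_k == nullptr) {
        return seqlen_k;
    }
    // With CP, read the per-batch total key length across all ranks.
    return cp_tot_seqused_k[bidb];
}
```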

hopper/seqlen.h Outdated
, cp_world_size(cp_world_size)
, cp_tot_seqlen_k(cp_tot_seqused_k == nullptr
? 0
: cp_tot_seqused_k[bidb])

@LucasWilkinson LucasWilkinson left a comment

Awesome work! Left a few comments, but it's looking really good!

@LucasWilkinson LucasWilkinson left a comment

LGTM

@LucasWilkinson LucasWilkinson merged commit 8f468e7 into vllm-project:main Oct 5, 2025
1 check passed