
Conversation

@tdoublep (Member) commented Sep 11, 2025

Purpose

This PR changes the way that the mamba2 kernels split the batch into "chunks". The change ensures that (a) no chunk ever contains more than one sequence, and (b) all intermediate states are computed at the chunk boundaries within each sequence.

This change is useful for three reasons:

  1. It dramatically simplifies the kernels due to (a).
  2. It enables a much easier implementation of prefix caching for mamba, due to (b).
  3. It can improve performance, even without prefix caching, because we can entirely skip the final call to the "varlen" kernel that is used to align the final states for each sequence.

The downside is that it introduces some "virtual" padding inside the chunks. We don't actually pad anything in GPU memory; we just potentially need to use a larger grid when launching kernels and may do some redundant compute. However, this padding is bounded to at most one chunk per sequence, and my initial experiments suggest the overhead is small. In fact, we see a significant speedup overall because we skip the call to the final "varlen" kernel. We follow a very similar approach for handling varlen batches in the Triton attention kernels, so this kind of technique is not without precedent.

TODO:


A simple example for two sequences A and B is shown below. A0 and B0 represent the chunks that were prefilled at the previous step, and A1 and B1 are the new chunks we want to prefill in this iteration.

[Figure: chunk layout for sequences A and B, showing how the new tokens A1 and B1 are split into chunk-aligned pieces (e.g. A1.a, A1.c), with grey regions indicating the "virtual" padding within chunks.]

The idea is that, for sequence A, we first take enough tokens from the new part (A1) to ensure that, when taken together with the precomputed part (A0), the state is chunk-aligned. Then we fill chunks with new tokens (from A1) until we run out, at which point we pad to the chunk boundary. Then we repeat for B.
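
Here is a minimal sketch of that partitioning logic for a single sequence (the function name and structure are hypothetical, not the actual vLLM code):

```python
# Hypothetical sketch of the chunk-aligned partitioning described above.
# Illustrative only; not the actual vLLM kernel metadata code.

def partition_new_tokens(num_precomputed: int, num_new: int, chunk_size: int) -> list[int]:
    """Split one sequence's new tokens into per-chunk counts so that every
    chunk boundary lands on the sequence's own chunk grid."""
    chunks = []
    # First, take enough new tokens to complete the partially filled chunk
    # left over from the precomputed part (A0), if any.
    remainder = num_precomputed % chunk_size
    if remainder != 0:
        first = min(chunk_size - remainder, num_new)
        chunks.append(first)
        num_new -= first
    # Then fill whole chunks with new tokens; the final chunk may be partial,
    # in which case its trailing slots are "virtually" padded (masked) at launch.
    while num_new > 0:
        take = min(chunk_size, num_new)
        chunks.append(take)
        num_new -= take
    return chunks

# Example: chunk_size=256, sequence A has 100 precomputed tokens and 600 new ones.
# -> [156, 256, 188]: the first chunk aligns the state to a 256-token boundary,
#    and the last chunk computes 188 real tokens and masks the remaining 68 slots.
print(partition_new_tokens(100, 600, 256))
```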

Test Plan

See correctness + benchmarking below.

Test Result

See correctness + benchmarking below.



@mergify mergify bot added the v1 label Sep 11, 2025
@tdoublep (Member Author) commented:

Server:

vllm serve ibm-granite/granite-4.0-tiny-preview --enforce-eager

Client:

lm_eval --model local-completions --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500     \
    --model_args model=ibm-granite/granite-4.0-tiny-preview,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3,tokenized_requests=False

Results (main):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.608|±  |0.0219|
|     |       |strict-match    |     5|exact_match|↑  |0.584|±  |0.0221|

Results (tpa-aligned-mamba):

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.616|±  |0.0218|
|     |       |strict-match    |     5|exact_match|↑  |0.590|±  |0.0220|


@tdoublep (Member Author) commented:

Server:

vllm serve ibm-granite/granite-4.0-tiny-preview

Benchmark:

vllm bench serve \
        --model ibm-granite/granite-4.0-tiny-preview \
        --dataset-name sharegpt \
        --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
        --ignore_eos

Branch main (second run):

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  32.64     
Total input tokens:                      235252    
Total generated tokens:                  222931    
Request throughput (req/s):              30.12     
Output token throughput (tok/s):         6830.50   
Total Token throughput (tok/s):          14038.50  
---------------Time to First Token----------------
Mean TTFT (ms):                          5419.54   
Median TTFT (ms):                        5404.06   
P99 TTFT (ms):                           9049.08   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          111.52    
Median TPOT (ms):                        72.42     
P99 TPOT (ms):                           303.04    
---------------Inter-token Latency----------------
Mean ITL (ms):                           53.58     
Median ITL (ms):                         35.74     
P99 ITL (ms):                            245.94    
==================================================

Branch tpa-aligned-mamba (second run):

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  32.34     
Total input tokens:                      233074    
Total generated tokens:                  223781    
Request throughput (req/s):              30.39     
Output token throughput (tok/s):         6918.85   
Total Token throughput (tok/s):          14125.03  
---------------Time to First Token----------------
Mean TTFT (ms):                          4103.87   
Median TTFT (ms):                        4083.23   
P99 TTFT (ms):                           7084.88   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          105.28    
Median TPOT (ms):                        83.67     
P99 TPOT (ms):                           248.41    
---------------Inter-token Latency----------------
Mean ITL (ms):                           56.46     
Median ITL (ms):                         36.00     
P99 ITL (ms):                            320.39    
==================================================

@tdoublep added the ready label (ONLY add when PR is ready to merge/full CI is needed) Sep 26, 2025
@tdoublep (Member Author) commented:

More benchmarking data, this time for NVIDIA-Nemotron-Nano-12B-v2.

Server

vllm serve nvidia/NVIDIA-Nemotron-Nano-12B-v2 

Client

vllm bench serve \
    --model nvidia/NVIDIA-Nemotron-Nano-12B-v2 \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --ignore_eos

Results from main:

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  54.60     
Total input tokens:                      218758    
Total generated tokens:                  201157    
Request throughput (req/s):              18.00     
Output token throughput (tok/s):         3684.31   
Peak output token throughput (tok/s):    7049.00   
Peak concurrent requests:                983.00    
Total Token throughput (tok/s):          7691.00   
---------------Time to First Token----------------
Mean TTFT (ms):                          13463.72  
Median TTFT (ms):                        9844.10   
P99 TTFT (ms):                           35302.82  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          145.05    
Median TPOT (ms):                        109.78    
P99 TPOT (ms):                           698.06    
---------------Inter-token Latency----------------
Mean ITL (ms):                           88.24     
Median ITL (ms):                         70.37     
P99 ITL (ms):                            467.66    
==================================================

Results from tpa-aligned-mamba:

============ Serving Benchmark Result ============
Successful requests:                     983       
Benchmark duration (s):                  47.08     
Total input tokens:                      219876    
Total generated tokens:                  200385    
Request throughput (req/s):              20.88     
Output token throughput (tok/s):         4256.43   
Peak output token throughput (tok/s):    6920.00   
Peak concurrent requests:                983.00    
Total Token throughput (tok/s):          8926.88   
---------------Time to First Token----------------
Mean TTFT (ms):                          9829.15   
Median TTFT (ms):                        6547.68   
P99 TTFT (ms):                           27941.63  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          113.15    
Median TPOT (ms):                        92.88     
P99 TPOT (ms):                           444.97    
---------------Inter-token Latency----------------
Mean ITL (ms):                           75.63     
Median ITL (ms):                         66.86     
P99 ITL (ms):                            399.50    
==================================================

@tomeras91 (Contributor) left a comment:

Looks great overall! Really simplifies the code.

I added a few nit comments about the need for comments and documenting expected shapes.

I also think it's worth adding a general comment somewhere that in this implementation we're assuming each chunk has only a single sequence, since this is a significant change from the original implementation.

@tlrmchlsmth tlrmchlsmth self-assigned this Sep 29, 2025
@tomeras91 (Contributor) left a comment:

Thanks @tdoublep. LGTM now. The descriptions and shapes really help.

@tlrmchlsmth (Member) left a comment:

PR looks great at first pass. Love to see more red than green.

@tlrmchlsmth (Member) commented Sep 29, 2025:

In the figure in the PR description, why does A1.a fall at the beginning of the chunk rather than the end? I thought A0 should be ahead of it rather than behind

@tdoublep (Member Author) commented Sep 29, 2025:

> In the figure in the PR description, why does A1.a fall at the beginning of the chunk rather than the end? I thought A0 should be ahead of it rather than behind

@tlrmchlsmth A0 isn't actually added to the chunk, it has already been prefilled and doesn't need to be computed again. We just need to partition A1 in such a way that len(A0)+len(A1.a)=chunk_size so that the intermediate states we get at the output of the first chunk correspond to the actual chunk boundaries within the sequence. That's why the part of the chunk after A1.a is grey to indicate that it gets padded (not actually padding in memory, only compute).
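
To make the alignment arithmetic concrete, here is a toy example (the numbers and names are illustrative only, not taken from the runs above):

```python
# Illustrative only: len(A1.a) is chosen so that A0 plus A1.a ends exactly on a
# chunk boundary of the sequence. The trailing len(A0) slots of that launched
# chunk carry no real tokens and are masked out (the "grey" region in the figure).
chunk_size = 256
len_A0 = 96                             # tokens prefilled in a previous step (assumed < chunk_size here)
len_A1_a = chunk_size - len_A0          # 160 new tokens go into the first chunk
assert len_A0 + len_A1_a == chunk_size  # state after A1.a is chunk-aligned
masked_slots = chunk_size - len_A1_a    # == len_A0 slots masked out in that chunk
```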

@tlrmchlsmth (Member) commented:

Do the padded regions get loaded at all?

> In the figure in the PR description, why does A1.a fall at the beginning of the chunk rather than the end? I thought A0 should be ahead of it rather than behind

> @tlrmchlsmth A0 isn't actually added to the chunk, it has already been prefilled and doesn't need to be computed again. We just need to partition A1 in such a way that len(A0)+len(A1.a)=chunk_size so that the intermediate states we get at the output of the first chunk correspond to the actual chunk boundaries within the sequence. That's why the part of the chunk after A1.a is grey to indicate that it gets padded (not actually padding in memory, only compute).

Makes sense. So then the A0-sized padded region could overlap with another chunk, or it could fall off the end of the KV cache tensor, right? Do we mask off the loads of the padded region as well?

@tdoublep (Member Author) commented Sep 29, 2025:

> Do the padded regions get loaded at all?

No, padding is maybe the wrong word. There isn't any actual padding of tensors in memory here.

Masking would probably be a better word. If we have 5 chunks like in the above example, we would launch a Triton kernel with a grid size of (5, ...): in the first chunk we mask out the last len(A0) slots, in the second chunk we mask out nothing, in the third chunk we mask out chunk_size - len(A1.c) slots, and so on.

We are basically trading off a bit of extra compute in order to get intermediate states exactly where we want them within each sequence. In practice it isn't really a trade-off: it strips out so much complexity that it's a net win.
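
As a rough illustration of what that per-chunk masking amounts to (a hypothetical sketch in plain NumPy, not the actual Triton kernel; the counts are made-up numbers):

```python
import numpy as np

# Hypothetical sketch of per-chunk token masks for a 5-chunk launch like the example above.
chunk_size = 8
# Number of real tokens in each launched chunk; the remaining slots are masked out.
valid_tokens_per_chunk = np.array([5, 8, 3, 8, 6])

slot = np.arange(chunk_size)  # token slot index within a chunk
# mask[c, i] is True iff slot i of chunk c holds a real token. Masked slots are
# skipped on load/store; they only cost a little redundant work in the launched grid.
mask = slot[None, :] < valid_tokens_per_chunk[:, None]
print(mask.astype(int))
```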

@tdoublep (Member Author) commented:

> So then the A0-sized padded region could overlap with another chunk, or it could fall off the end of the KV cache tensor, right?

Yes. If we didn't introduce the padding/masking, it would lead to (a) having multiple sequences within the same chunk and (b) needing the whole mapping between "logical" and "physical" chunks to track where everything is.

> Do we mask off the loads of the padded region as well?

Yes, we mask off the loads exactly (example: https://github.com/tdoublep/vllm/blob/tpa-aligned-mamba/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py#L231)

@tdoublep merged commit fea3e47 into vllm-project:main Sep 29, 2025
52 checks passed