
Conversation

@s3woz (Contributor) commented on Sep 26, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model

Purpose

This PR implements Automatic Prefix Caching (APC) for Mamba2 hybrid models.
Logic before this PR:

  • The Mamba2 implementation uses a single cache block for the "current state", updated in place.

This PR introduces APC logic for Mamba2 by storing states at input block boundaries, and resuming the computations from them when the cache is hit. The chart below shows timing results from vllm bench latency --model ibm-granite/granite-4.0-tiny-preview --num-iters 10 --num-iters-warmup 2 for three cases:

  1. vLLM main
  2. This PR with APC off (--no-enable-prefix-caching)
  3. This PR with APC on (--enable-prefix-caching)
[Chart: vllm bench latency timing for the three cases, across prefill lengths and decode lengths (e.g. decode 1, decode 1024)]

As prefill length increases, the APC-on mode provides clear benefits for prefill-dominated cases (e.g. the decode 1 series). For decode-heavy cases (e.g. decode 1024), performance is suboptimal due to how vLLM currently builds the metadata for KVCache groups. Our internal evaluations show that the decode-speed overhead can be eliminated if the additional metadata produced in the APC-on mode is cached across the groups. A general implementation of such functionality is pending in #22788.
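
For intuition about the boundary-aligned caching described above, here is a minimal standalone sketch (not the vLLM implementation; the helper name and block size are illustrative): with states stored only at block boundaries, a cache hit lets prefill resume from the largest boundary it fully covers.

# Conceptual sketch only; not vLLM code.
def resume_point(num_cached_tokens: int, block_size: int) -> int:
    # Mamba2 states are stored at block boundaries, so computation can only
    # resume from the last boundary fully covered by the cache hit.
    return (num_cached_tokens // block_size) * block_size

# e.g. with block_size = 1024, a 2500-token prefix hit resumes prefill at token 2048.
assert resume_point(2500, 1024) == 2048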

Technical considerations:

  • Current vLLM logic: the Cache Manager assumes that page sizes are equal across the different attention implementations. block_size is determined as the smallest attention block size for which the attention page size >= the Mamba2 page size, and Mamba2 state blocks are padded to match.
  • This PR introduces an additional condition: due to Mamba2 kernel specifics, intermediate states can be obtained efficiently during prefill only at mamba_chunk_size boundaries (typically 256 tokens). Thus, block_size is assumed to be a multiple of 256 (see the sketch after this list).
  • To obtain the intermediate Mamba states at mamba_chunk_size boundaries while keeping the Mamba kernels fast, the kernels need to process sequences in a chunk-aligned manner, with chunk boundaries aligned to the absolute sequence length. The kernels are modified accordingly in this PR, pulling in changes from [Kernel] Chunk-aligned mamba2 #24683.
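
A minimal sketch of the block-size constraint above (the helper and byte values are illustrative assumptions, not vLLM's actual selection code):

def pick_block_size(attn_bytes_per_token: int,
                    mamba_page_bytes: int,
                    mamba_chunk_size: int = 256) -> int:
    # Smallest multiple of mamba_chunk_size whose attention page
    # (block_size * per-token KV bytes) covers the Mamba2 state page.
    block_size = mamba_chunk_size
    while block_size * attn_bytes_per_token < mamba_page_bytes:
        block_size += mamba_chunk_size
    return block_size

# e.g. 2 KiB of attention KV per token vs. a 400 KiB Mamba2 state page -> 256
print(pick_block_size(2 * 1024, 400 * 1024))  # 256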

Pending enhancements:

  1. Ensure that Mamba Metadata is properly cached when [Attention] Cache attention metadata builds across hybrid KV-cache groups #22788 is merged.
  2. Potential early memory freeing: for running requests, free all blocks that are no longer needed (adjust remove_skipped_blocks in class MambaManager, and max_memory_usage_bytes).
  3. Currently padding is applied only to mamba cache pages. Depending on scenarios, it might be more memory-efficient to pad attention cache pages instead.
  4. Different strategies could be implemented to choose which states to store. Currently all states are stored: as requests arrive, a number of blocks proportional to the sequence length is allocated, holding every state at a block boundary (where cache hits may occur) plus the current state. Potential future enhancements include sparser strategies (reducing memory usage and I/O), such as a "store only the last state" strategy: allocate just two blocks, the last block-aligned intermediate state (to allow cache hits) and the current state. A rough sketch of both strategies follows this list.
  5. Support speculative decoding.
  6. Remove the overhead of Mamba2 state caching from the PyTorch level by moving it into the Triton kernels.
  7. Extend test cases from E2E tests to dedicated "cache logic" tests.
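
A rough sketch of the allocation strategies mentioned in item 4 (cdiv and the helpers are illustrative, not the MambaManager code):

def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division

def blocks_store_all_states(seq_len: int, block_size: int) -> int:
    # One block per block-boundary state (potential cache-hit point),
    # plus one block for the current in-progress state.
    return cdiv(seq_len, block_size) + 1

def blocks_store_last_state_only(seq_len: int, block_size: int) -> int:
    # Only two blocks: the last block-aligned state and the current state.
    return 2

print(blocks_store_all_states(5000, 1024))       # 6
print(blocks_store_last_state_only(5000, 1024))  # 2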

@tdoublep @bohnstingl

Test Plan

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
import time
MODEL = "ibm-granite/granite-4.0-tiny-preview"
PROMPT_MULTIPLE = 310
sampling_params = SamplingParams(temperature=0.0)
prefix = ( # examples/offline_inference/prefix_caching.py
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. ")
prefix2 = ("Based on these information, fulfill "
            "the following paragraph: ")
prompt = PROMPT_MULTIPLE * prefix + prefix2 + "Hello, my name is"
print('Prompt length:', len(prompt))
for APC in [False, True]:
    engine = LLM(model=MODEL, enable_prefix_caching=APC, 
        gpu_memory_utilization=0.4, disable_log_stats=False)
    for i in range(3):
        if i == 0:
            print('Warm-up')
        if i == 1:
            print('Measuring')
            start_time = time.time()
        outputs = engine.generate(prompt, sampling_params)
        print('APC:', APC, i, f"Generated text: {outputs[0].outputs[0].text!r}")
        for m in engine.llm_engine.get_metrics():
            if 'vllm:prefix_cache_hits' in m.name:
                print(m.name, m.value)
    print('APC:', APC, "loop took --- %s seconds ---" % (time.time() - start_time))
    del engine
    cleanup_dist_env_and_memory()

Test Result

Warm-up
Adding requests: 100%|---------| 1/1 [00:00<00:00,  9.89it/s]
Processed prompts: 100%|------| 1/1 [00:08<00:00,  8.13s/it, est. speed input: 4540.29 toks/s, output: 1.97 toks/s]
APC: False 0 Generated text: ' NAME_1. I am an expert school principal, skilled in effectively managing'
vllm:prefix_cache_hits 0
Measuring
Adding requests: 100%|---------| 1/1 [00:00<00:00, 12.45it/s]
Processed prompts: 100%|----| 1/1 [00:00<00:00,  1.79it/s, est. speed input: 66076.01 toks/s, output: 28.64 toks/s]
APC: False 1 Generated text: ' NAME_1. I am an expert school principal, skilled in effectively managing'
vllm:prefix_cache_hits 0
Adding requests: 100%|------| 1/1 [00:00<00:00, 11.90it/s]
Processed prompts: 100%|----| 1/1 [00:00<00:00,  1.78it/s, est. speed input: 65945.06 toks/s, output: 28.59 toks/s]
APC: False 2 Generated text: ' NAME_1. I am an expert school principal, skilled in effectively managing'
vllm:prefix_cache_hits 0
APC: False loop took --- 1.2919602394104004 seconds ---

Warm-up
Adding requests: 100%|-------------| 1/1 [00:00<00:00,  9.78it/s]
Processed prompts: 100%|----------| 1/1 [00:08<00:00,  8.19s/it, est. speed input: 4505.37 toks/s, output: 1.95 toks/s]
APC: True 0 Generated text: ' NAME_1. I am the candidate for the position of a Math teacher.'
vllm:prefix_cache_hits 0
Measuring
Adding requests: 100%|-------------| 1/1 [00:00<00:00, 11.40it/s]
Processed prompts: 100%|----------| 1/1 [00:00<00:00,  4.72it/s, est. speed input: 174700.34 toks/s, output: 75.73 toks/s]
APC: True 1 Generated text: ' NAME_1, and I am an expert school principal, skilled in effectively'
vllm:prefix_cache_hits 36864
Adding requests: 100%|-------------| 1/1 [00:00<00:00, 11.25it/s]
Processed prompts: 100%|----------| 1/1 [00:00<00:00,  5.86it/s, est. speed input: 217302.08 toks/s, output: 94.19 toks/s]
APC: True 2 Generated text: ' NAME_1, and I am an expert school principal, skilled in effectively'
vllm:prefix_cache_hits 73728
APC: True loop took --- 0.5674521923065186 seconds ---

s3woz and others added 30 commits August 29, 2025 12:35
@tdoublep (Member) left a comment:
Thanks for the big effort and perseverance. I think this is now in a shape where we can merge it. It should already give good speedups for prefill-dominated latency benchmarks.

@s3woz @bohnstingl -- could you please create a new Issue to track the remaining work items?

  • Implement policy for freeing mamba blocks to fix performance in throughput benchmarks
  • Relax constraint that mamba block size must be multiple of chunk size
  • Give user flexibility to set mamba caching granularity
  • Support mamba prefix caching and spec decode
  • Fuse logic for SSM state writing into kernels
  • Test TP>1 behaviour
  • Cache meta-data builds across KV cache groups (#22788)
  • Additional cleanup in causal_conv1d kernels (e.g., strip out unused logic)
  • Enable prefix caching for Mamba1
  • Enable prefix caching for ShortConv
  • Enable prefix caching for LinearAttention
  • Enable prefix caching for GDN

There is quite a bit of interest from the community in helping with these follow-ups, so it would be good to merge this and start parallelizing the work.


mask = (idx_tokens_conv < state_len)[:, None] & \
(idx_feats < dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
Member:
is there an issue we could link to here?

Member:
This is also just duplicated from the equivalent code on main: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py#L185

I strongly suspect we can remove that, but prefer to do a big cleanup of this kernel as a follow-up.

Contributor:
I removed the debugging statements and haven't seen any negative side effects yet. However, I haven't changed it in this PR; this could be something for a follow-up PR.

stride_istate_dim = 0
stride_istate_token = 0
num_cache_lines = 0
BLOCK_M = 8
Member:
where does this number come from?

Comment on lines 1059 to 1060
initial_state_idx: (batch,), dtype int32
The pointer into cache_indices, which signifies the cache block containing the initial state.
Member:
Is this right?

Suggested change:
- initial_state_idx: (batch,), dtype int32
-     The pointer into cache_indices, which signifies the cache block containing the initial state.
+ initial_state_idx: (batch,), dtype int32
+     The pointer into initial_states, which signifies the cache block containing the initial state.

Contributor:
Thank you for the catch. I think the description and the naming were a bit off. The tensor initial_state_idx indexes into conv_state_indices and thereby points to the location of the initial states. I updated the description there. Please let me know if it makes more sense now.

@tlrmchlsmth (Member) left a comment:
Looks pretty clean -- a lot less invasive than I thought it would be!

KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_VARLEN=query_start_loc is not None,
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
Member:
Do we assume this is always true now?

Member:
Yes. We've massively cleaned up the mamba2 kernels to remove this unused logic, but the causal_conv1d kernel could use another pass through it imo. We can do it as a follow-up.

Member:
(e.g., stuff like IS_VARLEN can also be stripped out)

Contributor:
As @tdoublep mentioned, the kernels have already been cleaned up quite a lot, but they are still not perfect, especially the conv1d kernel. They can be simplified quite a bit, I believe.

@tdoublep merged commit ea507c3 into vllm-project:main on Oct 4, 2025
57 checks passed
@tdoublep (Member) commented on Oct 4, 2025:
Tracking issue for follow-ups: #26201

@heheda12345 (Collaborator) left a comment:
Sorry for my late review. I've added some small comments. Can you address them in a future PR?

# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
seq_lens_pending = (
torch.roll(common_attn_metadata.query_start_loc, -1, -1) -
Collaborator:
I think we usually do query_start_loc[1:] - query_start_loc[:-1], like

query_lens = query_start_loc[1:] - query_start_loc[:-1]
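
For reference, both formulations give the same per-sequence query lengths; a tiny standalone check (illustrative values only):

import torch

qsl = torch.tensor([0, 5, 9, 16])                # query_start_loc for 3 sequences
roll_way = (torch.roll(qsl, -1, -1) - qsl)[:-1]  # formulation used in this PR
diff_way = qsl[1:] - qsl[:-1]                    # the more common idiom
assert torch.equal(roll_way, diff_way)           # tensor([5, 4, 7]) either way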

Member:
We actually don't even need to compute query_lens

# current_first == current_last if no block crossing occurs, and
# only one state will be stored
# 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]:
current_last_idx = cdiv(context_lens + seq_lens_pending,
Collaborator:
Suggested change:
- current_last_idx = cdiv(context_lens + seq_lens_pending,
+ current_last_idx = cdiv(common_attn_metadata.seq_lens,

Is this simplification correct?

Member:
Yes

Comment on lines +213 to +214
last_state_idx = \
last_state_idx.clamp(min=0)
Collaborator:
don't need line break here?

state_indices_tensor: torch.Tensor # shape: [batch,]
current_last_idx: torch.Tensor
current_first_idx_p: torch.Tensor
last_state_idx: torch.Tensor
Collaborator:
Can you mark the shape of these tensors? And I think it's better to add some comments. I guess:
last_state_idx -> the chunk id of the last computed token
current_first_idx -> the chunk id of the first scheduled token
current_last_idx -> the chunk id of the last scheduled token
And I prefer the following names (as we are using variable names like num_computed_tokens_cpu, num_scheduled_tokens):
chunk_id_last_computed_token
chunk_id_first_scheduled_token
chunk_id_last_scheduled_token

@tdoublep (Member), Oct 4, 2025:
I like this suggestion - will do (although they are block indexes rather than chunk indices)
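
A hedged sketch of what the rename could look like (field names are illustrative only, using block indices per the note above; not the merged code):

from dataclasses import dataclass
import torch

@dataclass
class Mamba2MetadataSketch:  # illustrative subset of the metadata fields
    state_indices_tensor: torch.Tensor             # [batch]
    block_idx_last_computed_token: torch.Tensor    # [batch]
    block_idx_first_scheduled_token: torch.Tensor  # [batch]
    block_idx_last_scheduled_token: torch.Tensor   # [batch]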

seq_lens_pending = (
torch.roll(common_attn_metadata.query_start_loc, -1, -1) -
common_attn_metadata.query_start_loc)[:-1]
context_lens = common_attn_metadata.seq_lens - \
Collaborator:
is it common_attn_metadata.num_computed_tokens_cpu?

Member:
Yes (except we want it on device)

common_attn_metadata.query_start_loc)[:-1]
context_lens = common_attn_metadata.seq_lens - \
seq_lens_pending
last_computed_offset = \
Collaborator:
do you need line break here?

Member:
No

Comment on lines +697 to +713
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
# 1024 // 256 - 1 --> chunks[3]
# But when last chunk wasn't block aligned:
# - last_computed_offset_p[seq_idx] // chunk_size
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
torch.concat([torch.zeros(1, \
dtype=last_chunk_indices_p.dtype, \
device=last_chunk_indices_p.device), \
last_chunk_indices_p + 1])[seq_idx] \
+ chunk_stride - 1 \
- last_computed_offset_p[seq_idx] // chunk_size
Collaborator:
Actually I can't understand this part (and also the comments in 702-705).

Member:
I've re-written the code in a way that I believe is easier to understand what is happening. We plan to fuse this code into the kernels at a later stage but I would like to make sure the code on main right now is easy to follow.

# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
torch.concat([torch.zeros(1, \
Collaborator:
why do you need an additional torch.zeros?

Member:
Re-written this

# hit_length = len(hit_blocks_other_attn[0])
# * self.other_block_size
# so we insert dummy blocks at the beginning:
if i > 0:
Collaborator:
I think we don't need if i>0

@tdoublep (Member) commented on Oct 4, 2025:
@heheda12345 Opened this PR to address your review comments: #26222

@ZJY0516 (Contributor) commented on Oct 5, 2025:
This PR makes tests/kernels/mamba/test_causal_conv1d.py::test_causal_conv1d_update fail. Fix in #26250.

@ZJY0516 mentioned this pull request on Oct 5, 2025.
tomeras91 pushed a commit to tomeras91/vllm that referenced this pull request on Oct 6, 2025
karan pushed a commit to karan/vllm that referenced this pull request on Oct 6, 2025
southfreebird pushed a commit to southfreebird/vllm that referenced this pull request on Oct 7, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request on Oct 10, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request on Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request on Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request on Oct 24, 2025
Labels: ready, v1