[V1] [Hybrid] Mamba2 Automatic Prefix Caching #25752
Conversation
Thanks for the big effort and perseverance. I think this is now in a shape where we can merge it. It should already give good speedups for prefill-dominated latency benchmarks.
@s3woz @bohnstingl -- could you please create a new Issue to track the remaining work items?
- Implement policy for freeing mamba blocks to fix performance in throughput benchmarks
- Relax constraint that mamba block size must be multiple of chunk size
- Give user flexibility to set mamba caching granularity
- Support mamba prefix caching and spec decode
- Fuse logic for SSM state writing into kernels
- Test TP>1 behaviour
- Cache meta-data builds across KV cache groups (#22788)
- Additional cleanup in causal_conv1d kernels (e.g., strip out unused logic)
- Enable prefix caching for Mamba1
- Enable prefix caching for ShortConv
- Enable prefix caching for LinearAttention
- Enable prefix caching for GDN
There is quite a bit of interest from the community in helping with these follow-ups, so it would be good to merge this and start parallelizing the work.
```python
mask = (idx_tokens_conv < state_len)[:, None] & \
       (idx_feats < dim)[None, :]
tl.debug_barrier()  # NOTE: use this due to bug in Triton compiler
```
Is there an issue we could link to here?
This is also just duplicated from the equivalent code on main: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py#L185
I strongly suspect we can remove that, but prefer to do a big cleanup of this kernel as a follow-up.
I removed the debugging statements and haven't seen any negative side effects yet. I haven't changed it in this PR, though; this could be something for a follow-up PR.
```python
stride_istate_dim = 0
stride_istate_token = 0
num_cache_lines = 0
BLOCK_M = 8
```
Where does this number come from?
It is hard-coded on main; we just moved the definition earlier: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py#L618
```python
initial_state_idx: (batch,), dtype int32
    The pointer into cache_indices, which signifies the cache block containing the initial state.
```
Is this right?
```diff
 initial_state_idx: (batch,), dtype int32
-    The pointer into cache_indices, which signifies the cache block containing the initial state.
+    The pointer into initial_states, which signifies the cache block containing the initial state.
```
Thank you for the catch. I think the description and the naming were a bit off. The tensor initial_state_idx indexes into conv_state_indices, which in turn points to the location of the initial states. I updated the description there; please let me know if it makes more sense now.
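For clarity, a minimal sketch of the double indirection being described (illustrative values only; the tensor names follow the kernel arguments above, everything else is made up):

```python
import torch

# Hypothetical sizes: 3 sequences in the batch, a state cache with 8 lines.
conv_state_indices = torch.tensor([5, 2, 7])  # (batch,) cache line per sequence
initial_state_idx = torch.tensor([0, 2, 1])   # (batch,) index into conv_state_indices

# The cache block holding the initial state of sequence b is found by
# indexing conv_state_indices with initial_state_idx:
b = 2
init_cache_line = conv_state_indices[initial_state_idx[b]]  # -> tensor(2)
```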
tlrmchlsmth left a comment:
Looks pretty clean -- a lot less invasive than I thought it would be!
```python
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_VARLEN=query_start_loc is not None,
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
```
Do we assume this is always true now?
Yes. We've massively cleaned up the mamba2 kernels to remove this unused logic, but the causal_conv1d kernel could use another pass through it, imo. We can do it as a follow-up.
(e.g., stuff like IS_VARLEN can also be stripped out)
As @tdoublep mentioned, the kernels have already been cleaned up quite a lot, but they are still not perfect, especially the conv1d kernel. They can be simplified quite a bit, I believe.
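As a concrete illustration of the kind of simplification being discussed, here is a toy Triton kernel (not the vLLM causal_conv1d kernel; names and shapes are invented) in which an always-true flag like IS_CONTINUOUS_BATCHING has simply been dropped, so the cache line is always resolved through the state-indices tensor:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def gather_state_kernel(states_ptr, out_ptr, state_indices_ptr, dim,
                        BLOCK: tl.constexpr):
    pid_b = tl.program_id(0)                         # one program per sequence
    cache_line = tl.load(state_indices_ptr + pid_b)  # unconditional indirection
    offs = tl.arange(0, BLOCK)
    mask = offs < dim
    vals = tl.load(states_ptr + cache_line * dim + offs, mask=mask)
    tl.store(out_ptr + pid_b * dim + offs, vals, mask=mask)


# Usage: gather each sequence's cached state into a dense batch.
states = torch.randn(8, 64, device="cuda")            # 8 cache lines, dim 64
idx = torch.tensor([5, 2, 7], device="cuda", dtype=torch.int32)
out = torch.empty(3, 64, device="cuda")
gather_state_kernel[(3,)](states, out, idx, 64, BLOCK=64)
```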
Tracking issue for follow-ups: #26201
heheda12345 left a comment:
Sorry for my late review. I've added some small comments. Can you address them in a future PR?
```python
# Additional cache-related varaiables:
mamba_block_size = self.kv_cache_spec.block_size
seq_lens_pending = (
    torch.roll(common_attn_metadata.query_start_loc, -1, -1) -
```
I think we usually do `query_start_loc[1:] - query_start_loc[:-1]`, like in vllm/vllm/v1/attention/backends/utils.py (line 752 in d3d649e):

```python
query_lens = query_start_loc[1:] - query_start_loc[:-1]
```
We actually don't even need to compute query_lens
```python
# current_first == current_last if no block crossing occurs, and
# only one state will be stored
# 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]:
current_last_idx = cdiv(context_lens + seq_lens_pending,
```
```diff
-current_last_idx = cdiv(context_lens + seq_lens_pending,
+current_last_idx = cdiv(common_attn_metadata.seq_lens,
```
Is this simplification correct?
Yes
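For the record, a quick sketch of why the two expressions agree (assuming, as in the surrounding builder code, that seq_lens is context_lens plus the pending tokens; the numbers below are made up and cdiv is a local stand-in):

```python
import torch

def cdiv(a: torch.Tensor, b: int) -> torch.Tensor:
    # ceiling division on non-negative integer tensors
    return (a + b - 1) // b

mamba_block_size = 1024
context_lens = torch.tensor([0, 1000, 2048])   # tokens already computed
seq_lens_pending = torch.tensor([16, 24, 1])   # tokens scheduled this step
seq_lens = context_lens + seq_lens_pending     # what common_attn_metadata.seq_lens holds

# The suggested simplification is an identity, not a behaviour change:
assert torch.equal(cdiv(context_lens + seq_lens_pending, mamba_block_size),
                   cdiv(seq_lens, mamba_block_size))
```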
```python
last_state_idx = \
    last_state_idx.clamp(min=0)
```
We don't need a line break here, do we?
```python
state_indices_tensor: torch.Tensor  # shape: [batch,]
current_last_idx: torch.Tensor
current_first_idx_p: torch.Tensor
last_state_idx: torch.Tensor
```
Can you mark the shape of these tensors? I also think it's better to add some comments. I guess:
- last_state_idx -> the chunk id of the last computed token
- current_first_idx -> the chunk id of the first scheduled token
- current_last_idx -> the chunk id of the last scheduled token

And I prefer the following names (as we are using variable names like num_computed_tokens_cpu, num_scheduled_tokens):
- chunk_id_last_computed_token
- chunk_id_first_scheduled_token
- chunk_id_last_scheduled_token
I like this suggestion - will do (although they are block indices rather than chunk indices).
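A sketch of what the annotated fields might look like after that rename (field names other than state_indices_tensor follow the reviewer's suggestion above, with "block" in place of "chunk"; this is not the final code):

```python
from dataclasses import dataclass
import torch

@dataclass
class Mamba2AttentionMetadata:  # illustrative subset only
    # cache line (block) assigned to each request, shape: (batch,)
    state_indices_tensor: torch.Tensor
    # block index of the last computed token, shape: (batch,)
    block_idx_last_computed_token: torch.Tensor
    # block index of the first scheduled token, shape: (batch,)
    block_idx_first_scheduled_token: torch.Tensor
    # block index of the last scheduled token, shape: (batch,)
    block_idx_last_scheduled_token: torch.Tensor
```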
```python
seq_lens_pending = (
    torch.roll(common_attn_metadata.query_start_loc, -1, -1) -
    common_attn_metadata.query_start_loc)[:-1]
context_lens = common_attn_metadata.seq_lens - \
```
Is it `common_attn_metadata.num_computed_tokens_cpu`?
Yes (except we want it on device)
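For illustration, one possible device-side version (a sketch that assumes num_computed_tokens_cpu is a CPU integer tensor on the metadata object, as the name suggests):

```python
# Reuse the already-known computed-token counts instead of re-deriving them,
# copying them to the same device as seq_lens.
context_lens = common_attn_metadata.num_computed_tokens_cpu.to(
    device=common_attn_metadata.seq_lens.device, non_blocking=True)
```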
```python
    common_attn_metadata.query_start_loc)[:-1]
context_lens = common_attn_metadata.seq_lens - \
    seq_lens_pending
last_computed_offset = \
```
Do you need the line break here?
No
```python
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
# 1024 // 256 - 1 --> chunks[3]
# But when last chunk wasn't block aligned:
# - last_computed_offset_p[seq_idx] // chunk_size
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
    torch.concat([torch.zeros(1, \
        dtype=last_chunk_indices_p.dtype, \
        device=last_chunk_indices_p.device), \
        last_chunk_indices_p + 1])[seq_idx] \
    + chunk_stride - 1 \
    - last_computed_offset_p[seq_idx] // chunk_size
```
Actually, I can't understand this part (and also the comments in lines 702-705...).
I've rewritten the code in a way that I believe makes it easier to understand what is happening. We plan to fuse this code into the kernels at a later stage, but I would like to make sure the code on main right now is easy to follow.
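To make the original comment above easier to follow, here is a small worked example of the arithmetic it describes (same example numbers as the comment, mamba_block_size = 1024 and chunk_size = 256; this is a standalone sketch, not the rewritten production code):

```python
mamba_block_size = 1024
chunk_size = 256
chunk_stride = mamba_block_size // chunk_size      # 4 chunks per mamba block

def first_aligned_chunk_offset(last_computed_offset: int) -> int:
    # Offset (in chunks, within this step) of the first chunk whose end lands
    # on a mamba block boundary, given how far past the previous boundary the
    # last step stopped.
    return chunk_stride - 1 - last_computed_offset // chunk_size

# Reproduces the examples from the original comment:
assert first_aligned_chunk_offset(1000) == 0   # 3 chunks already completed -> chunk[0]
assert first_aligned_chunk_offset(513) == 1    # 2 completed -> chunk[1] (skip 1)
assert first_aligned_chunk_offset(256) == 2    # 1 completed -> chunk[2] (skip 2)
assert first_aligned_chunk_offset(10) == 3     # 0 completed -> chunk[3] (skip 3)
```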
```python
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
    torch.concat([torch.zeros(1, \
```
Why do you need the additional torch.zeros?
Rewrote this.
```python
# hit_length = len(hit_blocks_other_attn[0])
#              * self.other_block_size
# so we insert dummy blocks at the beginning:
if i > 0:
```
I think we don't need the `if i > 0` check.
@heheda12345 Opened this PR to address your review comments: #26222
Purpose
This PR implements Automatic Prefix Caching (APC) for Mamba2 hybrid models.
Logic before this PR:
This PR introduces APC logic for Mamba2 by storing states at input block boundaries and resuming the computation from them when the cache is hit. The chart below shows timing results from `vllm bench latency --model ibm-granite/granite-4.0-tiny-preview --num-iters 10 --num-iters-warmup 2` for three cases, including `--no-enable-prefix-caching` and `--enable-prefix-caching`.

As prefill length increases, the APC-on mode provides clear benefits for long prefill-dominated cases (e.g. see decode 1). For decode-heavy cases (e.g. decode 1024), the performance is suboptimal due to how vLLM currently implements the metadata building for KVCache groups. Our internal evaluations show that the decode speed overhead can be eliminated if the additional metadata produced in the APC-on mode is cached between the groups. A general implementation of such functionality is pending in #22788.
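Conceptually, the block-boundary caching described above boils down to the following toy sketch (self-contained and illustrative only, not vLLM code; the real implementation works on SSM/conv states and hashed block IDs rather than raw token tuples):

```python
def toy_mixer(state: int, token: int) -> int:
    # stand-in for the Mamba2 state update
    return state * 31 + token

def prefill(tokens: list[int], cache: dict, block_size: int) -> int:
    # 1. Find the longest prefix of full input blocks whose state is cached.
    hit_blocks, state = 0, 0
    for n in range(len(tokens) // block_size, 0, -1):
        key = tuple(tokens[:n * block_size])
        if key in cache:
            hit_blocks, state = n, cache[key]
            break

    # 2. Resume from the first uncached token, storing the state at every
    #    input block boundary so later requests can resume from it.
    for pos in range(hit_blocks * block_size, len(tokens)):
        state = toy_mixer(state, tokens[pos])
        if (pos + 1) % block_size == 0:
            cache[tuple(tokens[:pos + 1])] = state
    return state

cache: dict = {}
prompt = list(range(40))
s1 = prefill(prompt, cache, block_size=16)            # cold: computes all 40 tokens
s2 = prefill(prompt + [99], cache, block_size=16)     # warm: resumes from token 32
assert prefill(prompt, cache, block_size=16) == s1    # prefix hit gives identical state
```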
Technical considerations:
Pending enhancements:
`remove_skipped_blocks` in `class MambaManager`, and `max_memory_usage_bytes`). @tdoublep @bohnstingl
Test Plan
Test Result