Merged
117 commits
e6a41ba
Initial APC for mamba
s3woz Aug 29, 2025
87dd0a0
Conv kernel state handling
bohnstingl Aug 29, 2025
de7d2d6
Resolve conflicts
tdoublep Sep 4, 2025
224c9e1
Get things working with latest code
tdoublep Sep 4, 2025
f9e0de1
Merge branch 'main' into mamba2_prefix_caching_clean
largraf Sep 9, 2025
dddb650
working changes
tdoublep Sep 11, 2025
1dc7a04
Merge branch 'main' into tpa-aligned-mamba
tdoublep Sep 11, 2025
2a7b216
working changes
tdoublep Sep 11, 2025
a30bb5e
Initial varlen for APC
s3woz Sep 11, 2025
664a21a
fix bug
tdoublep Sep 11, 2025
6c475d6
fix bug
tdoublep Sep 11, 2025
0d00c69
fix bug
tdoublep Sep 11, 2025
b7ae698
working changes
tdoublep Sep 11, 2025
9b24bce
Fix bugs
tdoublep Sep 11, 2025
e850661
working changes
tdoublep Sep 12, 2025
0d5c3ae
revert some changes
tdoublep Sep 12, 2025
31e05fa
fmt
tdoublep Sep 12, 2025
a8aff97
workign
tdoublep Sep 12, 2025
67db9b4
working changes
tdoublep Sep 12, 2025
d841e82
working changes
tdoublep Sep 12, 2025
908aecb
working changes
tdoublep Sep 12, 2025
af7a246
working changes
tdoublep Sep 12, 2025
7ce2b59
Some test cases working
tdoublep Sep 12, 2025
f950f2e
Fix IMA
tdoublep Sep 12, 2025
039267d
Small cleanup
s3woz Sep 12, 2025
fe095c4
State fix
bohnstingl Sep 12, 2025
8a6336e
Merge with tdoublep/tpa-aligned-mamba
s3woz Sep 12, 2025
75e01c8
Add back autotune config
tdoublep Sep 12, 2025
2698f2e
cleanup
tdoublep Sep 12, 2025
d3f05b7
cleanup
tdoublep Sep 12, 2025
df63503
cleanup
tdoublep Sep 12, 2025
c5edccd
cleanup
tdoublep Sep 12, 2025
712ced1
cleanup
tdoublep Sep 12, 2025
dc85f7e
cleanup
tdoublep Sep 12, 2025
5e827a6
cleanup
tdoublep Sep 12, 2025
42e4b27
cleanup
tdoublep Sep 12, 2025
d859182
cleanup
tdoublep Sep 12, 2025
56b37c2
cleanup
tdoublep Sep 12, 2025
e21b4e6
lint
tdoublep Sep 12, 2025
88c157a
Merge with tdoublep/tpa-aligned-mamba
s3woz Sep 12, 2025
45bdc9c
Merge with tdoublep/tpa-aligned-mamba
s3woz Sep 12, 2025
90dd7f5
Initial conv state fixes
s3woz Sep 12, 2025
d8e00e3
Conv1D fix
bohnstingl Sep 15, 2025
5785a85
Conv and SSD state storing fixes
s3woz Sep 15, 2025
d5cab4c
Corrected decode. APC should work OK.
s3woz Sep 15, 2025
e711fc0
Cleanup.
s3woz Sep 16, 2025
120fbb7
CUDA graphs fixes.
s3woz Sep 16, 2025
bc3c122
Merge with main.
s3woz Sep 17, 2025
22c47f9
Precommit fixes.
s3woz Sep 17, 2025
f56fe5c
Precommit fixes.
s3woz Sep 17, 2025
15bb921
Precommit fixes.
s3woz Sep 17, 2025
5890677
Fix CUDA graph issue
tdoublep Sep 17, 2025
df23ee4
pre-commit
tdoublep Sep 17, 2025
4ef4023
Tests
bohnstingl Sep 18, 2025
ff49343
Precommit fixes.
s3woz Sep 18, 2025
c4255bb
Precommit fixes.
s3woz Sep 18, 2025
7649489
Fix bug in scan kernel when to reading previous state.
tdoublep Sep 18, 2025
1bb59d7
Remove BLOCK_H=1 from list of tuneable configurations.
tdoublep Sep 18, 2025
25f8a27
Reworked testcase test_multiple_prompts_partial_cached_output_logprobs
bohnstingl Sep 19, 2025
8515ee2
Fixed indexing for SSM state storing when bs>1
s3woz Sep 19, 2025
ebba273
Fixed indexing for SSM state storing when bs>1
s3woz Sep 19, 2025
48faed8
Fix test_specific_prompts_output_logprobs
tdoublep Sep 19, 2025
0ce539e
Fix tests
tdoublep Sep 19, 2025
c06246b
Fused causal_conv1d.
bohnstingl Sep 22, 2025
77ab2fe
Merge
s3woz Sep 23, 2025
4bb28c0
Precommit
s3woz Sep 23, 2025
1c7e947
Support for disabling prefix caching
s3woz Sep 23, 2025
7cdae60
Metadata optimization for apc=off
s3woz Sep 25, 2025
f2beb4d
Evaluating other models. Lightweight model for testing.
s3woz Sep 29, 2025
1f40794
Cleanup conv1D kernel and stripped APC testcases
bohnstingl Sep 30, 2025
7119d48
Merge remote-tracking branch 'origin/main' into mamba2_prefix_caching…
s3woz Sep 30, 2025
e5d5519
Pre-commit fixes.
s3woz Sep 30, 2025
148ea61
Addressing test failures.
s3woz Oct 1, 2025
ce5144b
Fixed issue with conv1D
bohnstingl Oct 1, 2025
e3b8cfb
Reintegrated conv1D update changes
bohnstingl Oct 1, 2025
0bb5197
Precommit fixes
s3woz Oct 1, 2025
c911c88
Moved APC tests: test_hybrid.py; pre-commit clean
bohnstingl Oct 1, 2025
cf5c4c7
Merge: Deleted old test file
bohnstingl Oct 1, 2025
5345d66
Merge branch 'mamba2_prefix_caching_clean' of https://github.com/s3wo…
largraf Oct 1, 2025
618fe53
Precommit fixes
s3woz Oct 2, 2025
9dd6b81
Precommit fixes
s3woz Oct 2, 2025
63e9217
Precommit fixes
s3woz Oct 2, 2025
c0eed4a
Precommit fixes
s3woz Oct 2, 2025
1425b73
Precommit fixes and documentation
bohnstingl Oct 2, 2025
6e8faf9
Precommit fixes
s3woz Oct 2, 2025
0141f15
Fixed test_hybrid.py for models w/o mamba_block-size
bohnstingl Oct 3, 2025
8d0077a
Addressing feedback and cleanup.
s3woz Oct 3, 2025
37dacff
Pre-commit fixes.
s3woz Oct 3, 2025
17986f8
Pre-commit fixes.
s3woz Oct 3, 2025
f71ad6d
Integrated code review comments.
bohnstingl Oct 3, 2025
a22c8ab
Pre-commit fixes.
s3woz Oct 3, 2025
b49e33f
Adjusted mamba_mixer2.py to new conv1D naming
bohnstingl Oct 3, 2025
3333e77
Merge branch 'mamba2_prefix_caching_clean' of https://github.com/s3wo…
bohnstingl Oct 3, 2025
0e574a7
Fix assertion for block_size_to_align
bohnstingl Oct 3, 2025
ae21a8c
Integrated code review comments
bohnstingl Oct 3, 2025
60cc1cd
cache_enabled -> enable_prefix_caching
tdoublep Oct 3, 2025
0d1b054
Reduce diff
tdoublep Oct 3, 2025
e9f2257
Removed unused code
tdoublep Oct 3, 2025
e733552
Update comment
tdoublep Oct 3, 2025
6ebc97f
Remove duplicate code
tdoublep Oct 3, 2025
19b33a0
reduce diff
tdoublep Oct 3, 2025
cd57200
Merge branch 'main' into mamba2_prefix_caching_clean
tdoublep Oct 3, 2025
9ab2dfc
Remove cache spec from mamba metadata
tdoublep Oct 3, 2025
f6293da
remove enable_prefix_caching from MambaSpec
tdoublep Oct 3, 2025
0806c06
Disable prefix caching by default for hybrid models
tdoublep Oct 3, 2025
ac31d48
Add logging about disabling cascade attn
tdoublep Oct 3, 2025
e57cf9c
rename seq_lens_completed -> context_lens
tdoublep Oct 3, 2025
a49f94d
Consistent naming between mamba_mixer2 and mamba2 metadata
tdoublep Oct 3, 2025
785ef4a
Remove FCG-handling for prefill-only tensors
tdoublep Oct 3, 2025
d68e491
minor cleanup
tdoublep Oct 3, 2025
5e8a63a
Add comment
tdoublep Oct 3, 2025
36cc1ce
Enable other mamba2 models
tdoublep Oct 3, 2025
b006aba
Fix computation of prefill-only tensors
tdoublep Oct 3, 2025
b256275
Ensure that mamba_block_size always get set
tdoublep Oct 3, 2025
745af73
Fixing argument description of conv1D
bohnstingl Oct 3, 2025
dcfc5ad
Remove assert for models that support prefix caching
tdoublep Oct 3, 2025
a0a4c40
remove debug print
tdoublep Oct 4, 2025
414 changes: 413 additions & 1 deletion tests/models/language/generation/test_hybrid.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion vllm/config/cache.py
@@ -92,7 +92,8 @@ class CacheConfig:
mamba_page_size_padded: Optional[int] = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""

mamba_block_size: Optional[int] = None
"""Size of a contiguous cache block in number of tokens for mamba cache."""
mamba_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (both the conv as well as the
ssm state). If set to 'auto', the data type will be inferred from the model
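The docstring above defines `mamba_block_size` as the mamba cache block granularity in tokens. As a rough illustration (not part of the PR), only states at block boundaries are shareable across requests, so the number of reusable cached states for a prompt is an integer division; the helper name and the "full blocks only" reading below are my assumptions:

```python
# Hypothetical helper (not from the PR): how many full mamba blocks,
# and therefore reusable block-aligned states, a prompt of
# `num_tokens` tokens covers for a given mamba_block_size.
def num_cacheable_mamba_blocks(num_tokens: int, mamba_block_size: int) -> int:
    # Assumption: only states at block boundaries are shareable across
    # requests; a partial trailing block yields no reusable state.
    return num_tokens // mamba_block_size


if __name__ == "__main__":
    for n in (10, 1000, 1024, 2500):
        print(n, "->", num_cacheable_mamba_blocks(n, mamba_block_size=1024))
    # 10 -> 0, 1000 -> 0, 1024 -> 1, 2500 -> 2
```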
7 changes: 6 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1563,7 +1563,12 @@ def _set_default_args(self, usage_context: UsageContext,
self.enable_prefix_caching = False

if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
if model_config.is_hybrid:
self.enable_prefix_caching = False
else:
self.enable_prefix_caching = True
else:

pooling_type = model_config.pooler_config.pooling_type
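The hunk above changes the default for `--enable-prefix-caching`. A small standalone sketch of that defaulting rule, covering only the branch shown in this snippet (`resolve_enable_prefix_caching` and its arguments are illustrative names, not vLLM API):

```python
from typing import Optional


def resolve_enable_prefix_caching(explicit: Optional[bool],
                                  is_hybrid: bool) -> bool:
    """Sketch of the defaulting branch shown above (illustrative only).

    If the user did not set the flag, hybrid mamba/attention models
    default to False because APC support for them is still experimental;
    other models default to True. An explicit setting always wins.
    """
    if explicit is not None:
        return explicit
    return not is_hybrid


assert resolve_enable_prefix_caching(None, is_hybrid=True) is False
assert resolve_enable_prefix_caching(None, is_hybrid=False) is True
assert resolve_enable_prefix_caching(True, is_hybrid=True) is True
```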
113 changes: 105 additions & 8 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -489,6 +489,9 @@ def forward_cuda(
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = forward_context.attn_metadata

assert self.cache_config is not None
mamba_block_size = self.cache_config.mamba_block_size
prefix_caching_enabled = self.cache_config.enable_prefix_caching
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
@@ -573,6 +576,25 @@ def forward_cuda(
dim=0,
)

if prefix_caching_enabled:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
last_state_idx_d, last_state_idx_p = torch.split(
attn_metadata.last_state_idx, [num_decodes, num_prefills],
dim=0)
current_last_idx_d, current_last_idx_p = torch.split(
attn_metadata.current_last_idx, [num_decodes, num_prefills],
dim=0)
# Prefill-only variables:
current_first_idx_p = attn_metadata.current_first_idx_p
context_lens_p = attn_metadata.context_lens_p
last_computed_offset_p = attn_metadata.last_computed_offset_p
else:
last_state_idx_d, last_state_idx_p = None, None
current_last_idx_d, current_last_idx_p = None, None
current_first_idx_p = None
context_lens_p = None

# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
@@ -592,8 +614,17 @@
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
# - It will read the initial states for every sequence,
# that has "has_initial_states_p" == True,
# from "cache_indices", using "state_indices_tensor_p".
# - It updates the "conv_state" cache in positions pointed
# to by "state_indices_tensor_p".
# In particular, it will always write the state at the
# sequence end.
# In addition, if "current_first_idx_p" and "current_last_idx_p"
# are provided (both are pointers into
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
hidden_states_B_C_p = causal_conv1d_fn(
@@ -604,6 +635,11 @@
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
current_first_idx=current_first_idx_p,
current_last_idx=current_last_idx_p,
initial_state_idx=last_state_idx_p,
context_lens=context_lens_p,
block_size_to_align=mamba_block_size,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
@@ -614,9 +650,13 @@
# 3. State Space Model sequence transformation
initial_states = None
if (has_initial_states_p is not None and prep_initial_states):
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, last_state_idx_p.unsqueeze(1)).squeeze(1)
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
ssm_state[kernel_ssm_indices], 0)

# NOTE: final output is an in-place update of out tensor
varlen_states = mamba_chunk_scan_combined_varlen(
Expand All @@ -638,26 +678,82 @@ def forward_cuda(
cu_chunk_seqlens=cu_chunk_seqlen_p,
last_chunk_indices=last_chunk_indices_p,
initial_states=initial_states,
return_intermediate_states=prefix_caching_enabled,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)

# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_states
if prefix_caching_enabled:
# Save states for sequences with more than just the final state:
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
cache_blocks_to_fill = state_indices_tensor_p[
seq_idx, current_first_idx_p[seq_idx]:
current_first_idx_p[seq_idx] +
n_blocks_to_fill[seq_idx]]
# chunks = [0 1 2 3 4 5 6 ...]
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
# 1024 // 256 - 1 --> chunks[3]
# But when last chunk wasn't block aligned:
# - last_computed_offset_p[seq_idx] // chunk_size
# e.g. 1000 // 256 -> 3 completed --> store chunk[0]
# e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
torch.concat([torch.zeros(1, \
dtype=last_chunk_indices_p.dtype, \
device=last_chunk_indices_p.device), \
last_chunk_indices_p + 1])[seq_idx] \
+ chunk_stride - 1 \
- last_computed_offset_p[seq_idx] // chunk_size
from_where = varlen_states[
first_aligned_chunk:first_aligned_chunk +
n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
ssm_state[cache_blocks_to_fill] = from_where

#For all seqs, store the last state (Note: might be partial):
ssm_state[state_indices_tensor_p.gather(1,
current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
varlen_states[last_chunk_indices_p]
else:
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
# tensor
ssm_state[state_indices_tensor_p] = varlen_states

Review thread on the torch.concat above:
Collaborator: why do you need an additional torch.zeros?
Member: Re-written this

Review thread on lines +697 to +713:
Collaborator: Actually I can't understand this part (And also the comments in 702-705...
Member: I've re-written the code in a way that I believe is easier to understand what is happening. We plan to fuse this code into the kernels at a later stage but I would like to make sure the code on main right now is easy to follow.

# Process decode requests
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = \
state_indices_tensor_d.gather(1,
last_state_idx_d.unsqueeze(1)).squeeze(1)
state_indices_tensor_d_output = \
state_indices_tensor_d.gather(1,
current_last_idx_d.unsqueeze(1)).squeeze(1)
#Note:
# for decode always: current_first_idx_d == current_last_idx_d
# at block boundaries: current_first_idx_d > last_state_idx_d
else:
# Without caching, read and write in-place to the same blocks:
state_indices_tensor_d_input = state_indices_tensor_d
state_indices_tensor_d_output = state_indices_tensor_d

# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d)
conv_state_indices=state_indices_tensor_d,
current_last_idx=current_last_idx_d,
initial_state_idx=last_state_idx_d,
)

hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)
@@ -689,7 +785,8 @@
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
)
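
The comment block in the prefill hunk describes when `causal_conv1d_fn` writes conv states back to the cache: always at the sequence end, and additionally at every `mamba_block_size` boundary when the block-index pointers are supplied. The sketch below is one reading of that description in plain Python; the helper and its arguments are hypothetical and are not the kernel's API:

```python
def conv_state_checkpoints(context_len: int, num_new_tokens: int,
                           mamba_block_size: int) -> list[int]:
    """Token offsets at which the conv state would be written back:
    every mamba block boundary covered by this prefill chunk, plus the
    final position (which may fall inside a block). Hypothetical helper,
    one interpretation of the comment, not the kernel itself."""
    end = context_len + num_new_tokens
    checkpoints = [b for b in range(mamba_block_size, end + 1, mamba_block_size)
                   if b > context_len]
    if not checkpoints or checkpoints[-1] != end:
        checkpoints.append(end)  # state at the sequence end is always written
    return checkpoints


print(conv_state_checkpoints(0, 2500, 1024))    # [1024, 2048, 2500]
print(conv_state_checkpoints(1000, 600, 1024))  # [1024, 1600]
```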
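The chunk-selection arithmetic that the reviewers questioned can be checked in isolation. This sketch reproduces the worked numbers from the in-code comment (mamba_block_size=1024, chunk_size=256); it drops the per-sequence offset into the batch (the torch.concat over last_chunk_indices_p) and counts chunks from the sequence's first chunk, so the values are illustrative only:

```python
import torch

mamba_block_size = 1024
chunk_size = 256
chunk_stride = mamba_block_size // chunk_size  # 4 scan chunks per mamba block

# Which chunk (counting from this sequence's first chunk) produces the
# first block-aligned state, given how many tokens of the current block
# were already computed in a previous step?
for last_computed_offset in (10, 256, 513, 1000):
    completed = last_computed_offset // chunk_size
    first_aligned_chunk = chunk_stride - 1 - completed
    print(last_computed_offset, "->", first_aligned_chunk)
# 10   -> 3  (store chunk[3], skip 3)
# 256  -> 2  (store chunk[2], skip 2)
# 513  -> 1  (store chunk[1], skip 1)
# 1000 -> 0  (store chunk[0])

# Every chunk_stride-th state after the first aligned one is also
# block-aligned, which is exactly the strided slice in the diff:
n_blocks_to_fill = 2                     # hypothetical
varlen_states = torch.arange(12)         # stand-in for per-chunk states
first_aligned_chunk = 3
aligned = varlen_states[first_aligned_chunk:
                        first_aligned_chunk +
                        n_blocks_to_fill * chunk_stride:chunk_stride]
print(aligned.tolist())                  # [3, 7]
```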
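On the decode path with prefix caching enabled, each request's previous state is read from one slot of `state_indices_tensor_d` and the updated state is written to a possibly different slot; the two differ exactly when a block boundary is crossed. A minimal sketch of those two `gather` calls with made-up block tables:

```python
import torch

# Hypothetical block tables: each row lists the cache slots assigned to
# one decode request (names follow the diff).
state_indices_tensor_d = torch.tensor([[7, 8, 9],
                                       [4, 5, 6]])
last_state_idx_d = torch.tensor([1, 0])    # where the previous state lives
current_last_idx_d = torch.tensor([2, 0])  # request 0 just crossed a block
                                           # boundary, request 1 did not

read_slots = state_indices_tensor_d.gather(
    1, last_state_idx_d.unsqueeze(1)).squeeze(1)
write_slots = state_indices_tensor_d.gather(
    1, current_last_idx_d.unsqueeze(1)).squeeze(1)

print(read_slots.tolist())   # [8, 4] -> slots the kernels read from
print(write_slots.tolist())  # [9, 4] -> slots the updated states go to
# With prefix caching disabled, both gathers are skipped and each request
# reads and writes the same slot in place.
```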