153 changes: 96 additions & 57 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -595,21 +595,32 @@ def forward_cuda(
if prefix_caching_enabled:
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
-            last_state_idx_d, last_state_idx_p = torch.split(
-                attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
-            )
+            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_computed_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
+            )
-            current_last_idx_d, current_last_idx_p = torch.split(
-                attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
-            )
+            block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_scheduled_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
+            )
# Prefill-only variables:
-            current_first_idx_p = attn_metadata.current_first_idx_p
-            context_lens_p = attn_metadata.context_lens_p
-            last_computed_offset_p = attn_metadata.last_computed_offset_p
+            block_idx_first_scheduled_token_p = (
+                attn_metadata.block_idx_first_scheduled_token_p
+            )
+            num_computed_tokens_p = attn_metadata.num_computed_tokens_p
else:
-            last_state_idx_d, last_state_idx_p = None, None
-            current_last_idx_d, current_last_idx_p = None, None
-            current_first_idx_p = None
-            context_lens_p = None
+            block_idx_last_computed_token_d = None
+            block_idx_last_computed_token_p = None
+            block_idx_last_scheduled_token_d = None
+            block_idx_last_scheduled_token_p = None
+            block_idx_first_scheduled_token_p = None
+            num_computed_tokens_p = None

# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
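As background for the hunk above: the split sizes `[num_decodes, num_prefills]` indicate that decode requests precede prefill requests in the batch, so each per-request metadata tensor is partitioned with a single `torch.split`. A minimal standalone sketch with invented values:

```python
import torch

# Toy batch: 2 decode requests followed by 3 prefill requests.
num_decodes, num_prefills = 2, 3
block_idx_last_computed_token = torch.tensor([0, 1, 2, 3, 4])

# Decode entries come first, so one split along dim 0 yields both halves
# as views of the original tensor, without copying.
d, p = torch.split(
    block_idx_last_computed_token, [num_decodes, num_prefills], dim=0
)
assert d.tolist() == [0, 1]
assert p.tolist() == [2, 3, 4]
```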
@@ -637,7 +648,8 @@ def forward_cuda(
# to by "state_indices_tensor_p".
# In particular, it will always write the state at the
# sequence end.
-                # In addition, "current_first_idx_p" and "current_last_idx_p"
+                # In addition, "block_idx_first_scheduled_token_p" and
+                # "block_idx_last_scheduled_token_p"
# are provided (which are pointers into
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
@@ -652,10 +664,10 @@
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
-                current_first_idx=current_first_idx_p,
-                current_last_idx=current_last_idx_p,
-                initial_state_idx=last_state_idx_p,
-                context_lens=context_lens_p,
+                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
+                initial_state_idx=block_idx_last_computed_token_p,
+                num_computed_tokens=num_computed_tokens_p,
block_size_to_align=mamba_block_size,
metadata=attn_metadata,
query_start_loc=query_start_loc_p,
@@ -669,7 +681,7 @@ def forward_cuda(
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
kernel_ssm_indices = state_indices_tensor_p.gather(
-                    1, last_state_idx_p.unsqueeze(1)
+                    1, block_idx_last_computed_token_p.unsqueeze(1)
).squeeze(1)
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
@@ -703,52 +715,76 @@ def forward_cuda(
)

if prefix_caching_enabled:
-            # Save states for sequences with more than just the final state:
-            n_blocks_to_fill = current_last_idx_p - current_first_idx_p
-            for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
-                cache_blocks_to_fill = state_indices_tensor_p[
-                    seq_idx,
-                    current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
-                    + n_blocks_to_fill[seq_idx],
-                ]
-                # chunks = [0 1 2 3 4 5 6 ...]
-                # First aligned chunk would typically be:
-                #  mamba_block_size = 1024, chunk_size = 256
-                #  1024 // 256 - 1 --> chunks[3]
-                # But when last chunk wasn't block aligned:
-                # - last_computed_offset_p[seq_idx] // chunk_size
-                #  e.g. 1000 // 256 -> 3 completed --> store chunk[0]
-                #  e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
-                #  e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
-                #  e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
-                chunk_stride = mamba_block_size // chunk_size
-                first_aligned_chunk = (
-                    torch.concat(
-                        [
-                            torch.zeros(
-                                1,
-                                dtype=last_chunk_indices_p.dtype,
-                                device=last_chunk_indices_p.device,
-                            ),
-                            last_chunk_indices_p + 1,
-                        ]
-                    )[seq_idx]
-                    + chunk_stride
-                    - 1
-                    - last_computed_offset_p[seq_idx] // chunk_size
-                )
-                from_where = varlen_states[
-                    first_aligned_chunk : first_aligned_chunk
-                    + n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
-                ]
-                ssm_state[cache_blocks_to_fill] = from_where
-
-            # For all seqs, store the last state (Note: might be partial):
+            # The chunk_stride is the number of chunks per mamba block
+            # e.g., if mamba_block_size = 512 and chunk_size = 256,
+            # then chunk_stride = 2
+            chunk_stride = mamba_block_size // chunk_size
+
+            # Save state for sequences with more than just final state
+            for seq_idx in range(num_prefills):
+                # Block index for the first scheduled token
+                block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
+                    seq_idx
+                ]
+
+                # Block index for the last scheduled token
+                block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
+                    seq_idx
+                ]
+
+                # Number of blocks that need to be written
+                n_blocks_to_fill = (
+                    block_idx_last_scheduled_token - block_idx_first_scheduled_token
+                )
+
+                # Skip sequences that don't have any blocks to fill
+                if n_blocks_to_fill == 0:
+                    continue
+
+                # Look up the state indices
+                cache_blocks_to_fill = state_indices_tensor_p[
+                    seq_idx,
+                    block_idx_first_scheduled_token:block_idx_last_scheduled_token,
+                ]
+
+                # First chunk index for this sequence
+                if seq_idx == 0:
+                    first_chunk = 0
+                else:
+                    first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]
+
+                # First chunk that is aligned on the mamba block boundary
+                first_aligned_chunk = first_chunk + chunk_stride - 1
+
+                # Calculate the number of computed tokens that were not
+                # already cached
+                num_unaligned_computed_tokens = (
+                    num_computed_tokens_p[seq_idx] % mamba_block_size
+                )
+
+                if num_unaligned_computed_tokens > 0:
+                    # If the number of computed tokens is not block aligned,
+                    # then we need to shift the index accordingly
+                    first_aligned_chunk -= (
+                        num_unaligned_computed_tokens // chunk_size
+                    )
+
+                # Get states to write
+                from_where = varlen_states[
+                    first_aligned_chunk : first_aligned_chunk
+                    + n_blocks_to_fill * chunk_stride : chunk_stride
+                ]
+
+                # Write the states
+                ssm_state[cache_blocks_to_fill] = from_where
+
+            # For all seqs, store the last state (note: might be partial):
             ssm_state[
                 state_indices_tensor_p.gather(
-                    1, current_last_idx_p.unsqueeze(1)
+                    1, block_idx_last_scheduled_token_p.unsqueeze(1)
                 ).squeeze(1)
             ] = varlen_states[last_chunk_indices_p]

else:
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
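To make the alignment arithmetic in the new loop concrete, here is a worked example with invented numbers (per-sequence chunk indexing simplified to start at 0):

```python
# Toy numbers, not from this PR: 512-token mamba blocks, 256-token SSM
# chunks, one prefill sequence with 768 tokens already computed and cached.
mamba_block_size, chunk_size = 512, 256
chunk_stride = mamba_block_size // chunk_size           # 2 chunks per block

first_chunk = 0                                         # first sequence in the batch
first_aligned_chunk = first_chunk + chunk_stride - 1    # chunk that closes a full block

num_computed_tokens = 768
num_unaligned = num_computed_tokens % mamba_block_size  # 256 tokens past a boundary
if num_unaligned > 0:
    # Part of a block is already covered by the running state, so the next
    # block boundary arrives earlier in the chunk stream.
    first_aligned_chunk -= num_unaligned // chunk_size

# Chunks 0, 2, 4, ... (every chunk_stride-th chunk from first_aligned_chunk)
# now end exactly on mamba block boundaries and get written to the cache.
assert first_aligned_chunk == 0
```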
@@ -759,14 +795,17 @@
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
-                    1, last_state_idx_d.unsqueeze(1)
+                    1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
state_indices_tensor_d_output = state_indices_tensor_d.gather(
-                    1, current_last_idx_d.unsqueeze(1)
+                    1, block_idx_last_scheduled_token_d.unsqueeze(1)
).squeeze(1)
# Note:
-                # for decode always: current_first_idx_d == current_last_idx_d
-                # at block boundaries: current_first_idx_d > last_state_idx_d
+                # for decode:
+                #   block_idx_first_scheduled_token_d ==
+                #     block_idx_last_scheduled_token_d
+                # at block boundaries:
+                #   block_idx_first_scheduled_token_d >
+                #     block_idx_last_computed_token_d
else:
# Without caching, read and write in-place to the same blocks:
state_indices_tensor_d_input = state_indices_tensor_d
@@ -780,8 +819,8 @@
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d,
-                current_last_idx=current_last_idx_d,
-                initial_state_idx=last_state_idx_d,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
+                initial_state_idx=block_idx_last_computed_token_d,
)
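The gather/squeeze idiom in the decode branch above is row-wise column selection: it maps each request's slot pointer to a flat cache block id. A self-contained sketch with a made-up block table:

```python
import torch

# Hypothetical block table: 2 decode requests, 3 cache-block slots each.
state_indices_tensor_d = torch.tensor([[7, 8, 9],
                                       [4, 5, 6]])
block_idx_last_computed_token_d = torch.tensor([0, 1])   # read slot per request
block_idx_last_scheduled_token_d = torch.tensor([1, 1])  # write slot per request

# gather(1, idx.unsqueeze(1)).squeeze(1) picks one column per row.
inp = state_indices_tensor_d.gather(
    1, block_idx_last_computed_token_d.unsqueeze(1)
).squeeze(1)
out = state_indices_tensor_d.gather(
    1, block_idx_last_scheduled_token_d.unsqueeze(1)
).squeeze(1)
assert inp.tolist() == [7, 5]  # blocks holding the input states
assert out.tolist() == [8, 5]  # blocks receiving the updated states
```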

hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
44 changes: 22 additions & 22 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -27,10 +27,10 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
query_start_loc_ptr,
batch_ptr,
token_chunk_offset_ptr,
-    current_first_idx,  # (batch,)
-    current_last_idx,  # (batch,)
+    block_idx_first_scheduled_token,  # (batch,)
+    block_idx_last_scheduled_token,  # (batch,)
     initial_state_idx,  # (batch,)
-    context_lens,  # (batch,)
+    num_computed_tokens,  # (batch,)
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
dim: tl.constexpr,
@@ -94,9 +94,9 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
    # In particular, if prefix caching is enabled, the program writes additional cache states to "cache_indices_ptr"

# Get the length of the completed sequence so far and compute the offset.
-    current_first_index = tl.load(current_first_idx + idx_seq)
-    current_last_index = tl.load(current_last_idx + idx_seq)
-    sequence_completed_index = tl.load(context_lens + idx_seq)
+    current_first_index = tl.load(block_idx_first_scheduled_token + idx_seq)
+    current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
+    sequence_completed_index = tl.load(num_computed_tokens + idx_seq)

# Compute the offset where the first stride_block_m-aligned first full block is
# Value in "token-space"
@@ -476,10 +476,10 @@ def causal_conv1d_fn(
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
-    current_first_idx: Optional[torch.Tensor] = None,
-    current_last_idx: Optional[torch.Tensor] = None,
+    block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
+    block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
     initial_state_idx: Optional[torch.Tensor] = None,
-    context_lens: Optional[torch.Tensor] = None,
+    num_computed_tokens: Optional[torch.Tensor] = None,
block_size_to_align=0,
metadata=None,
validate_data=False,
@@ -523,13 +523,13 @@
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
-    current_first_idx: (batch,), dtype int32
+    block_idx_first_scheduled_token: (batch,), dtype int32
        The pointer into cache_indices, where the first cache block to be filled is located.
-    current_last_idx: (batch,), dtype int32
+    block_idx_last_scheduled_token: (batch,), dtype int32
        The pointer into cache_indices, where the last cache block to be filled is located.
     initial_state_idx: (batch,), dtype int32
        The pointer into cache_indices, where the cache block containing the initial state is located.
-    context_lens: (batch,), dtype int32
+    num_computed_tokens: (batch,), dtype int32
The number of tokens already completed for each sequence
block_size_to_align: int
The block size to align the cached states to
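To tie these parameters together, here is a hypothetical single-request example; all values are invented, and the described write pattern is a reading of the docstring above rather than a verified trace of the kernel:

```python
import torch

# One prefill request with a 4-slot block table.
cache_indices = torch.tensor([[3, 11, 12, 13]])

initial_state_idx = torch.tensor([0])                # slot 0 -> block 3 holds the
                                                     # previously computed state
block_idx_first_scheduled_token = torch.tensor([1])  # first block to fill this step
block_idx_last_scheduled_token = torch.tensor([3])   # block receiving the final state
num_computed_tokens = torch.tensor([112])            # tokens already completed

# Under these assumptions the kernel reads its initial conv state from
# block 3, writes block-aligned states into blocks 11 and 12, and stores
# the final (possibly partial) state in block 13.
```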
@@ -708,10 +708,10 @@ def grid(META):
query_start_loc,
batch_ptr,
token_chunk_offset_ptr,
-        current_first_idx,
-        current_last_idx,
+        block_idx_first_scheduled_token,
+        block_idx_last_scheduled_token,
         initial_state_idx,
-        context_lens,
+        num_computed_tokens,
out,
# Matrix dimensions
dim,
@@ -735,7 +735,7 @@ def grid(META):
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
-        IS_APC_ENABLED=current_last_idx is not None,
+        IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
# launch_cooperative_grid=True
@@ -756,7 +756,7 @@ def _causal_conv1d_update_kernel(
conv_state_indices_ptr,
num_accepted_tokens_ptr,
query_start_loc_ptr, # (batch + 1)
-    current_last_idx,  # (batch,)
+    block_idx_last_scheduled_token,  # (batch,)
initial_state_idx, # (batch,)
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
@@ -802,7 +802,7 @@ def _causal_conv1d_update_kernel(
if IS_APC_ENABLED:
# Get the state from the initial_state_idx
conv_state_init = tl.load(initial_state_idx + idx_seq)
-        current_last_index = tl.load(current_last_idx + idx_seq)
+        current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
else:
conv_state_init = 0
current_last_index = 0
@@ -1078,7 +1078,7 @@ def causal_conv1d_update(
query_start_loc: Optional[torch.Tensor] = None,
max_query_len: int = -1,
pad_slot_id: int = PAD_SLOT_ID,
-    current_last_idx: Optional[torch.Tensor] = None,
+    block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
initial_state_idx: Optional[torch.Tensor] = None,
validate_data=False,
):
Expand All @@ -1097,7 +1097,7 @@ def causal_conv1d_update(
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
-    current_last_idx: (batch,), dtype int32
+    block_idx_last_scheduled_token: (batch,), dtype int32
The pointer into conv_state_indices, where the last cache block to be filled is located.
initial_state_idx: (batch,), dtype int32
The pointer into conv_state_indices, where the cache block containing the initial state is located.
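For decode, the note in mamba_mixer2.py above implies a simple invariant: each step writes a single block, and the write slot moves past the read slot only when the sequence crosses a mamba block boundary. A sketch with invented values:

```python
import torch

# Two hypothetical decode requests with 3-slot block tables.
conv_state_indices = torch.tensor([[2, 9, 0],
                                   [5, 6, 7]])

# Request 0 is mid-block: it reads and writes the same slot (block 9).
# Request 1 just crossed a block boundary: it reads slot 1 (block 6)
# and writes slot 2 (block 7).
initial_state_idx = torch.tensor([1, 1])
block_idx_last_scheduled_token = torch.tensor([1, 2])
```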
@@ -1201,7 +1201,7 @@ def grid(META):
conv_state_indices,
num_accepted_tokens,
query_start_loc,
-        current_last_idx,
+        block_idx_last_scheduled_token,
initial_state_idx,
out,
# Matrix dimensions
@@ -1230,7 +1230,7 @@ def grid(META):
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_VARLEN=query_start_loc is not None,
-        IS_APC_ENABLED=current_last_idx is not None,
+        IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,