diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 12b2c78f1a74..7589905ac927 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -595,21 +595,32 @@ def forward_cuda(
         if prefix_caching_enabled:
             # If prefix caching is enabled, retrieve the relevant variables
             # for prefill and decode
-            last_state_idx_d, last_state_idx_p = torch.split(
-                attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
+            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_computed_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
             )
-            current_last_idx_d, current_last_idx_p = torch.split(
-                attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
+            block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
+                torch.split(
+                    attn_metadata.block_idx_last_scheduled_token,
+                    [num_decodes, num_prefills],
+                    dim=0,
+                )
             )
             # Prefill-only variables:
-            current_first_idx_p = attn_metadata.current_first_idx_p
-            context_lens_p = attn_metadata.context_lens_p
-            last_computed_offset_p = attn_metadata.last_computed_offset_p
+            block_idx_first_scheduled_token_p = (
+                attn_metadata.block_idx_first_scheduled_token_p
+            )
+            num_computed_tokens_p = attn_metadata.num_computed_tokens_p
         else:
-            last_state_idx_d, last_state_idx_p = None, None
-            current_last_idx_d, current_last_idx_p = None, None
-            current_first_idx_p = None
-            context_lens_p = None
+            block_idx_last_computed_token_d = None
+            block_idx_last_computed_token_p = None
+            block_idx_last_scheduled_token_d = None
+            block_idx_last_scheduled_token_p = None
+            block_idx_first_scheduled_token_p = None
+            num_computed_tokens_p = None

         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
@@ -637,7 +648,8 @@ def forward_cuda(
             # to by "state_indices_tensor_p".
             # In particular, it will always write the state at the
             # sequence end.
-            # In addition, "current_first_idx_p" and "current_last_idx_p"
+            # In addition, "block_idx_first_scheduled_token_p" and
+            # "block_idx_last_scheduled_token_p"
             # are provided (which are pointers into
             # "state_indices_tensor_p"), it will write additional cache
             # states aligned at "block_size_to_align".
@@ -652,10 +664,10 @@ def forward_cuda(
                 conv_states=conv_state,
                 has_initial_state=has_initial_states_p,
                 cache_indices=state_indices_tensor_p,
-                current_first_idx=current_first_idx_p,
-                current_last_idx=current_last_idx_p,
-                initial_state_idx=last_state_idx_p,
-                context_lens=context_lens_p,
+                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
+                initial_state_idx=block_idx_last_computed_token_p,
+                num_computed_tokens=num_computed_tokens_p,
                 block_size_to_align=mamba_block_size,
                 metadata=attn_metadata,
                 query_start_loc=query_start_loc_p,
@@ -669,7 +681,7 @@ def forward_cuda(
             kernel_ssm_indices = state_indices_tensor_p
             if prefix_caching_enabled:
                 kernel_ssm_indices = state_indices_tensor_p.gather(
-                    1, last_state_idx_p.unsqueeze(1)
+                    1, block_idx_last_computed_token_p.unsqueeze(1)
                 ).squeeze(1)
             initial_states = torch.where(
                 has_initial_states_p[:, None, None, None],
@@ -703,52 +715,76 @@ def forward_cuda(
             )

             if prefix_caching_enabled:
-                # Save states for sequences with more than just the final state:
-                n_blocks_to_fill = current_last_idx_p - current_first_idx_p
-                for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
+                # The chunk_stride is the number of chunks per mamba block
+                # e.g., if mamba_block_size = 512 and chunk_size = 256,
+                # then chunk_stride = 2
+                chunk_stride = mamba_block_size // chunk_size
+
+                # Save state for sequences with more than just final state
+                for seq_idx in range(num_prefills):
+                    # Block index for the first scheduled token
+                    block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
+                        seq_idx
+                    ]
+
+                    # Block index for the last scheduled token
+                    block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
+                        seq_idx
+                    ]
+
+                    # Number of blocks that need to be written
+                    n_blocks_to_fill = (
+                        block_idx_last_scheduled_token - block_idx_first_scheduled_token
+                    )
+
+                    # Skip sequences that don't have any blocks to fill
+                    if n_blocks_to_fill == 0:
+                        continue
+
+                    # Look up the state indices
                     cache_blocks_to_fill = state_indices_tensor_p[
                         seq_idx,
-                        current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
-                        + n_blocks_to_fill[seq_idx],
+                        block_idx_first_scheduled_token:block_idx_last_scheduled_token,
                     ]
-                    # chunks = [0 1 2 3 4 5 6 ...]
-                    # First aligned chunk would typically be:
-                    # mamba_block_size = 1024, chunk_size = 256
-                    # 1024 // 256 - 1 --> chunks[3]
-                    # But when last chunk wasn't block aligned:
-                    # - last_computed_offset_p[seq_idx] // chunk_size
-                    # e.g. 1000 // 256 -> 3 completed --> store chunk[0]
-                    # e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
-                    # e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
-                    # e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
-                    chunk_stride = mamba_block_size // chunk_size
-                    first_aligned_chunk = (
-                        torch.concat(
-                            [
-                                torch.zeros(
-                                    1,
-                                    dtype=last_chunk_indices_p.dtype,
-                                    device=last_chunk_indices_p.device,
-                                ),
-                                last_chunk_indices_p + 1,
-                            ]
-                        )[seq_idx]
-                        + chunk_stride
-                        - 1
-                        - last_computed_offset_p[seq_idx] // chunk_size
+
+                    # First chunk index for this sequence
+                    if seq_idx == 0:
+                        first_chunk = 0
+                    else:
+                        first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]
+
+                    # First chunk that is aligned on the mamba block boundary
+                    first_aligned_chunk = first_chunk + chunk_stride - 1
+
+                    # Calculate the number of computed tokens that were not
+                    # already cached
+                    num_unaligned_computed_tokens = (
+                        num_computed_tokens_p[seq_idx] % mamba_block_size
                     )
+
+                    if num_unaligned_computed_tokens > 0:
+                        # If the number of computed tokens is not block aligned,
+                        # then we need to shift the index accordingly
+                        first_aligned_chunk -= (
+                            num_unaligned_computed_tokens // chunk_size
+                        )
+
+                    # Get states to write
                     from_where = varlen_states[
                         first_aligned_chunk : first_aligned_chunk
-                        + n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
+                        + n_blocks_to_fill * chunk_stride : chunk_stride
                     ]
+
+                    # Write the states
                     ssm_state[cache_blocks_to_fill] = from_where

-                # For all seqs, store the last state (Note: might be partial):
+                # For all seqs, store the last state (note: might be partial):
                 ssm_state[
                     state_indices_tensor_p.gather(
-                        1, current_last_idx_p.unsqueeze(1)
+                        1, block_idx_last_scheduled_token_p.unsqueeze(1)
                     ).squeeze(1)
                 ] = varlen_states[last_chunk_indices_p]
+
             else:
                 # update ssm states
                 # - varlen state is a (num_prefills, nheads, headdim, dstate)
@@ -759,14 +795,17 @@ def forward_cuda(
         if has_decode:
             if prefix_caching_enabled:
                 state_indices_tensor_d_input = state_indices_tensor_d.gather(
-                    1, last_state_idx_d.unsqueeze(1)
+                    1, block_idx_last_computed_token_d.unsqueeze(1)
                 ).squeeze(1)
                 state_indices_tensor_d_output = state_indices_tensor_d.gather(
-                    1, current_last_idx_d.unsqueeze(1)
+                    1, block_idx_last_scheduled_token_d.unsqueeze(1)
                 ).squeeze(1)
-                # Note:
-                # for decode always: current_first_idx_d == current_last_idx_d
-                # at block boundaries: current_first_idx_d > last_state_idx_d
+                # for decode:
+                # block_idx_first_scheduled_token_d ==
+                # block_idx_last_scheduled_token_d
+                # at block boundaries:
+                # block_idx_first_scheduled_token_d >
+                # block_idx_last_computed_token_d
             else:
                 # Without caching, read and write in-place to the same blocks:
                 state_indices_tensor_d_input = state_indices_tensor_d
@@ -780,8 +819,8 @@ def forward_cuda(
                 self.conv1d.bias,
                 self.activation,
                 conv_state_indices=state_indices_tensor_d,
-                current_last_idx=current_last_idx_d,
-                initial_state_idx=last_state_idx_d,
+                block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
+                initial_state_idx=block_idx_last_computed_token_d,
             )

             hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index a6d5d4d17970..ec486d3b9267 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -27,10 +27,10 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     query_start_loc_ptr,
     batch_ptr,
     token_chunk_offset_ptr,
-    current_first_idx,  # (batch,)
-    current_last_idx,  # (batch,)
+    block_idx_first_scheduled_token,  # (batch,)
+    block_idx_last_scheduled_token,  # (batch,)
     initial_state_idx,  # (batch,)
-    context_lens,  # (batch,)
+    num_computed_tokens,  # (batch,)
     o_ptr,  # (dim, seqlen) - actually pointing to x_ptr
     # Matrix dimensions
     dim: tl.constexpr,
@@ -94,9 +94,9 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
         # In particular, if prefix caching is enabled, the program write additional cache states to "cache_indices_ptr"

         # Get the length of the completed sequence so far and compute the offset.
-        current_first_index = tl.load(current_first_idx + idx_seq)
-        current_last_index = tl.load(current_last_idx + idx_seq)
-        sequence_completed_index = tl.load(context_lens + idx_seq)
+        current_first_index = tl.load(block_idx_first_scheduled_token + idx_seq)
+        current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
+        sequence_completed_index = tl.load(num_computed_tokens + idx_seq)

         # Compute the offset where the first stride_block_m-aligned first full block is
         # Value in "token-space"
@@ -476,10 +476,10 @@ def causal_conv1d_fn(
     has_initial_state: Optional[torch.Tensor] = None,
     activation: Optional[str] = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
-    current_first_idx: Optional[torch.Tensor] = None,
-    current_last_idx: Optional[torch.Tensor] = None,
+    block_idx_first_scheduled_token: Optional[torch.Tensor] = None,
+    block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
     initial_state_idx: Optional[torch.Tensor] = None,
-    context_lens: Optional[torch.Tensor] = None,
+    num_computed_tokens: Optional[torch.Tensor] = None,
     block_size_to_align=0,
     metadata=None,
     validate_data=False,
@@ -523,13 +523,13 @@ def causal_conv1d_fn(
         for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
         in this case, the kernel will not process entries at indices 0 and 3
-    current_first_idx: (batch,), dtype int32
+    block_idx_first_scheduled_token: (batch,), dtype int32
         The pointer into cache_indices, where the first cache block to be filled is located.
-    current_last_idx: (batch,), dtype int32
+    block_idx_last_scheduled_token: (batch,), dtype int32
         The pointer into cache_indices, where the last cache block to be filled is located.
     initial_state_idx: (batch,), dtype int32
         The pointer into cache_indices, where the cache block containing the initial state is located.
-    context_lens: (batch,), dtype int32
+    num_computed_tokens: (batch,), dtype int32
         The number of tokens already completed for each sequence
     block_size_to_align: int
         The block size to align the cached states to
@@ -708,10 +708,10 @@ def grid(META):
         query_start_loc,
         batch_ptr,
         token_chunk_offset_ptr,
-        current_first_idx,
-        current_last_idx,
+        block_idx_first_scheduled_token,
+        block_idx_last_scheduled_token,
         initial_state_idx,
-        context_lens,
+        num_computed_tokens,
         out,
         # Matrix dimensions
         dim,
@@ -735,7 +735,7 @@ def grid(META):
         HAS_BIAS=bias is not None,
         KERNEL_WIDTH=width,
         SILU_ACTIVATION=activation in ["silu", "swish"],
-        IS_APC_ENABLED=current_last_idx is not None,
+        IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
         USE_PAD_SLOT=pad_slot_id is not None,
         NP2_STATELEN=np2_statelen,
         # launch_cooperative_grid=True
@@ -756,7 +756,7 @@ def _causal_conv1d_update_kernel(
     conv_state_indices_ptr,
     num_accepted_tokens_ptr,
     query_start_loc_ptr,  # (batch + 1)
-    current_last_idx,  # (batch,)
+    block_idx_last_scheduled_token,  # (batch,)
     initial_state_idx,  # (batch,)
     o_ptr,  # (batch, dim, seqlen)
     # Matrix dimensions
@@ -802,7 +802,7 @@ def _causal_conv1d_update_kernel(
     if IS_APC_ENABLED:
         # Get the state from the initial_state_idx
         conv_state_init = tl.load(initial_state_idx + idx_seq)
-        current_last_index = tl.load(current_last_idx + idx_seq)
+        current_last_index = tl.load(block_idx_last_scheduled_token + idx_seq)
     else:
         conv_state_init = 0
         current_last_index = 0
@@ -1078,7 +1078,7 @@ def causal_conv1d_update(
     query_start_loc: Optional[torch.Tensor] = None,
     max_query_len: int = -1,
     pad_slot_id: int = PAD_SLOT_ID,
-    current_last_idx: Optional[torch.Tensor] = None,
+    block_idx_last_scheduled_token: Optional[torch.Tensor] = None,
     initial_state_idx: Optional[torch.Tensor] = None,
     validate_data=False,
 ):
@@ -1097,7 +1097,7 @@ def causal_conv1d_update(
         If not None, the conv_state is a larger tensor along the batch dim,
         and we are selecting the batch coords specified by conv_state_indices.
         Useful for a continuous batching scenario.
-    current_last_idx: (batch,), dtype int32
+    block_idx_last_scheduled_token: (batch,), dtype int32
         The pointer into conv_state_indices, where the last cache block to be filled is located.
     initial_state_idx: (batch,), dtype int32
         The pointer into conv_state_indices, where the cache block containing the initial state is located.
@@ -1201,7 +1201,7 @@ def grid(META):
         conv_state_indices,
         num_accepted_tokens,
         query_start_loc,
-        current_last_idx,
+        block_idx_last_scheduled_token,
         initial_state_idx,
         out,
         # Matrix dimensions
@@ -1230,7 +1230,7 @@ def grid(META):
         KERNEL_WIDTH=width,
         SILU_ACTIVATION=activation in ["silu", "swish"],
         IS_VARLEN=query_start_loc is not None,
-        IS_APC_ENABLED=current_last_idx is not None,
+        IS_APC_ENABLED=block_idx_last_scheduled_token is not None,
         IS_SPEC_DECODING=num_accepted_tokens is not None,
         NP2_STATELEN=np2_statelen,
         USE_PAD_SLOT=pad_slot_id is not None,
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index ae8a0e92daf4..10f09442d82e 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -122,11 +122,10 @@ class Mamba2AttentionMetadata:
     last_chunk_indices_p: Optional[torch.Tensor]

     state_indices_tensor: torch.Tensor  # shape: [batch,]
-    current_last_idx: torch.Tensor
-    current_first_idx_p: torch.Tensor
-    last_state_idx: torch.Tensor
-    context_lens_p: torch.Tensor
-    last_computed_offset_p: torch.Tensor
+    block_idx_last_scheduled_token: torch.Tensor  # shape: [batch,]
+    block_idx_first_scheduled_token_p: torch.Tensor  # shape: [batch,]
+    block_idx_last_computed_token: torch.Tensor  # shape: [batch,]
+    num_computed_tokens_p: torch.Tensor  # shape: [batch,]

     # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
@@ -160,12 +159,12 @@ def __init__(
             dtype=torch.int32,
             device=device,
         )
-        self.current_last_idx = torch.empty(
+        self.block_idx_last_scheduled_token = torch.empty(
             (self.decode_cudagraph_max_bs,),
             dtype=torch.int32,
             device=device,
         )
-        self.last_state_idx = torch.empty(
+        self.block_idx_last_computed_token = torch.empty(
             (self.decode_cudagraph_max_bs,),
             dtype=torch.int32,
             device=device,
@@ -192,43 +191,38 @@ def build(

         # for causal_conv1d
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
-        context_lens, context_lens_p = None, None
-        current_first_idx, current_first_idx_p = None, None
-        last_computed_offset, last_computed_offset_p = None, None
+        num_computed_tokens, num_computed_tokens_p = None, None
+        block_idx_first_scheduled_token = None
+        block_idx_first_scheduled_token_p = None

         if self.vllm_config.cache_config.enable_prefix_caching:
             # Return a tensor of shape (#requests, #max blocks)
             state_indices_tensor = common_attn_metadata.block_table_tensor

-            # Additional cache-related varaiables:
             mamba_block_size = self.kv_cache_spec.block_size
-            seq_lens_pending = (
-                torch.roll(common_attn_metadata.query_start_loc, -1, -1)
-                - common_attn_metadata.query_start_loc
-            )[:-1]
-            context_lens = common_attn_metadata.seq_lens - seq_lens_pending
-            last_computed_offset = context_lens % mamba_block_size
-            # Indices: last_computed <= current_first <= current_last
-            # Cases:
-            # last_computed == current_first if last state was partially
-            # computed and needs to be updated
-            # current_first == current_last if no block crossing occurs, and
-            # only one state will be stored
-            # 0th based indexing leads to "-1" -> e.g. 16 computed -> state[15]:
-            current_last_idx = (
-                cdiv(context_lens + seq_lens_pending, mamba_block_size) - 1
+            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
+                self.device
+            )
+            # Block index of the last computed token
+            block_idx_last_computed_token = (
+                cdiv(num_computed_tokens, mamba_block_size) - 1
+            )
+            # which is <= block index for the first scheduled token
+            block_idx_first_scheduled_token = (
+                cdiv(num_computed_tokens + 1, mamba_block_size) - 1
+            )
+            # which is <= block index of the last scheduled token
+            block_idx_last_scheduled_token = (
+                cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
             )
-            current_first_idx = cdiv(context_lens + 1, mamba_block_size) - 1
-            last_state_idx = cdiv(context_lens, mamba_block_size) - 1
             # -1 in case it's non-computed and causes later issues with indexing
-            last_state_idx = last_state_idx.clamp(min=0)
-
+            block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
         else:
             # Always return just a single block per each request:
             state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
             # Additional cache-related varaiables:
-            current_last_idx = None
-            last_state_idx = None
+            block_idx_last_scheduled_token = None
+            block_idx_last_computed_token = None

         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
             split_decodes_and_prefills(
@@ -256,18 +250,15 @@ def build(
         )

         if self.vllm_config.cache_config.enable_prefix_caching:
-            assert context_lens is not None
-            context_lens_p = context_lens[num_reqs - num_prefills : num_reqs]
-            assert last_computed_offset is not None
-            last_computed_offset_p = last_computed_offset[
+            assert num_computed_tokens is not None
+            num_computed_tokens_p = num_computed_tokens[
                 num_reqs - num_prefills : num_reqs
             ]
-            assert current_first_idx is not None
-            current_first_idx_p = current_first_idx[
+            assert block_idx_first_scheduled_token is not None
+            block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
                 num_reqs - num_prefills : num_reqs
             ]
-
-            num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[
+            num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
                 num_reqs - num_prefills : num_reqs
             ]
             query_start_loc_p_cpu = (
@@ -290,7 +281,7 @@ def build(
             last_chunk_indices = []
             seqlen_pos = 0
             for req_idx in range(num_prefills):
-                this_num_computed = num_computed_tokens_p[req_idx].item()
+                this_num_computed = num_computed_tokens_p_cpu[req_idx].item()
                 this_new_tokens = (
                     query_start_loc_p_cpu[req_idx + 1].item()
                     - query_start_loc_p_cpu[req_idx].item()
                 )
@@ -338,7 +329,10 @@ def build(
                 compute_causal_conv1d_metadata(query_start_loc_p)
             )

-        elif num_decodes <= self.decode_cudagraph_max_bs:
+        elif (
+            num_decodes <= self.decode_cudagraph_max_bs
+            and self.compilation_config.full_cuda_graph
+        ):
             # Pad state tensor for CUDA graph
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
             self.state_indices_tensor[:num_decodes].copy_(
@@ -348,17 +342,21 @@ def build(
                 state_indices_tensor, non_blocking=True
             )
             state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
             state_indices_tensor[num_decodes:] = PAD_SLOT_ID

             if self.vllm_config.cache_config.enable_prefix_caching:
-                self.current_last_idx[:num_decodes].copy_(
-                    current_last_idx, non_blocking=True
+                self.block_idx_last_scheduled_token[:num_decodes].copy_(
+                    block_idx_last_scheduled_token, non_blocking=True
                 )
-                current_last_idx = self.current_last_idx[:num_input_tokens]
-                current_last_idx[num_decodes:] = 0
+                block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
+                    :num_input_tokens
+                ]
+                block_idx_last_scheduled_token[num_decodes:] = 0

-                self.last_state_idx[:num_decodes].copy_(
-                    last_state_idx, non_blocking=True
+                self.block_idx_last_computed_token[:num_decodes].copy_(
+                    block_idx_last_computed_token, non_blocking=True
                 )
-                last_state_idx = self.last_state_idx[:num_input_tokens]
-                last_state_idx[num_decodes:] = 0
+                block_idx_last_computed_token = self.block_idx_last_computed_token[
+                    :num_input_tokens
+                ]
+                block_idx_last_computed_token[num_decodes:] = 0

         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
@@ -377,10 +375,9 @@ def build(
             nums_dict=nums_dict,
             batch_ptr=batch_ptr,
             token_chunk_offset_ptr=token_chunk_offset_ptr,
-            current_last_idx=current_last_idx,
-            current_first_idx_p=current_first_idx_p,
-            last_state_idx=last_state_idx,
-            context_lens_p=context_lens_p,
-            last_computed_offset_p=last_computed_offset_p,
+            block_idx_last_scheduled_token=block_idx_last_scheduled_token,
+            block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
+            block_idx_last_computed_token=block_idx_last_computed_token,
+            num_computed_tokens_p=num_computed_tokens_p,
         )
         return attn_metadata
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index 0f71796014db..d624ff1b3dcc 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -584,8 +584,7 @@ def find_longest_cache_hit(
                 # hit_length = len(hit_blocks_other_attn[0])
                 # * self.other_block_size
                 # so we insert dummy blocks at the beginning:
-                if i > 0:
-                    computed.extend([block_pool.null_block] * i)
+                computed.extend([block_pool.null_block] * i)
                 computed.append(cached)
                 break  # we just need the last match - early stopping
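
Note on the renamed metadata fields: the three block indices introduced in `mamba2_attn.py` are derived from the number of already-computed tokens, the scheduled sequence length, and the mamba block size via ceiling division. The following standalone sketch is illustrative only (not part of the patch); it uses scalar ints instead of tensors and a local `cdiv` helper standing in for the ceiling-division utility used in the diff.

```python
# Illustrative sketch only -- scalar ints instead of the tensors used in the
# metadata builder; `cdiv` is ceiling division as used in the diff.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)


def block_indices(num_computed_tokens: int, seq_len: int, mamba_block_size: int):
    # Block holding the last token whose state is already cached
    # (clamped to 0 when nothing has been computed yet).
    last_computed = max(cdiv(num_computed_tokens, mamba_block_size) - 1, 0)
    # Block holding the first token scheduled in this step.
    first_scheduled = cdiv(num_computed_tokens + 1, mamba_block_size) - 1
    # Block holding the last scheduled token (seq_len = computed + scheduled).
    last_scheduled = cdiv(seq_len, mamba_block_size) - 1
    return last_computed, first_scheduled, last_scheduled


# 700 computed + 300 scheduled tokens, block size 512:
# the partially filled block 1 is both read and extended.
print(block_indices(700, 1000, 512))   # (1, 1, 1)
# 1024 computed + 6 scheduled tokens: the new tokens start a fresh block.
print(block_indices(1024, 1030, 512))  # (1, 2, 2)
```

As the comments in the diff note, block_idx_last_computed_token <= block_idx_first_scheduled_token <= block_idx_last_scheduled_token: the last computed and first scheduled indices coincide when the last computed block is only partially filled, and the first and last scheduled indices coincide when the newly scheduled tokens do not cross a mamba block boundary.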