diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py
index b1c46190403d..5dc9c0f9fcb7 100644
--- a/vllm/model_executor/layers/mamba/mamba2_metadata.py
+++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py
@@ -24,21 +24,23 @@ class Mamba2Metadata:
     chunk_offsets: torch.Tensor
 
 
-def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
+def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
+                                              chunk_size: int,
+                                              total_seqlens: int):
 
-    # convert seq_idx to chunk indices and offsets
-    # - derive the cu_seqlens
-    _, cu_seqlens = torch.where(seq_idx.diff())
-    cu_seqlens += 1
+    cu_seqlens = query_start_loc[1:]  # remove prepended 0
 
     # outputs will have length expansion of chunks that do not divide
     # chunk_size
-    N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
-                                                     > 0).sum()
-    chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
-    chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
+    N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
+                                                 > 0).sum()
+    chunk_indices = torch.arange(N,
+                                 dtype=torch.int,
+                                 device=query_start_loc.device)
+    chunk_offsets = torch.zeros((N, ),
+                                dtype=torch.int,
+                                device=query_start_loc.device)
 
-    cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
     p = 0  # num of insertions
     for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
 
@@ -80,13 +82,15 @@ def prepare_mamba2_metadata(
     seq_idx = None
     chunk_indices, chunk_offsets = None, None
     if has_prefill:
-        seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
-        for i, (srt, end) in enumerate(
-                zip(
-                    attn_metadata.query_start_loc,
-                    attn_metadata.query_start_loc[1:],
-                )):
-            seq_idx[srt:end] = i
+        seqlens = attn_metadata.query_start_loc.diff()
+        total_seqlens = len(input_ids)
+
+        seq_idx = torch.repeat_interleave(torch.arange(
+            len(attn_metadata.query_start_loc) - 1,
+            dtype=torch.int32,
+            device=attn_metadata.query_start_loc.device),
+                                          seqlens,
+                                          output_size=total_seqlens)
         seq_idx.unsqueeze_(0)
 
         # compute metadata for chunked prefill.
@@ -97,8 +101,11 @@ def prepare_mamba2_metadata(
         # compute them once at the top level model forward and reuse
         # them in mamba layers. If not needed, they will be ignored
        # inside mamba kernels.
-        chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
-            seq_idx, chunk_size)
+        chunk_indices, chunk_offsets = \
+            _query_start_loc_to_chunk_indices_offsets(
+                attn_metadata.query_start_loc,
+                chunk_size=chunk_size,
+                total_seqlens=total_seqlens)
 
     return Mamba2Metadata(has_prefill=has_prefill,
                           has_initial_states=has_initial_states,
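
The substance of the change is twofold: `seq_idx` is now built with a single `torch.repeat_interleave` launch instead of a per-sequence Python loop (each `seq_idx[srt:end] = i` slice bound is a 0-dim tensor, so every iteration forces a device-to-host sync), and the chunk helper takes `query_start_loc` directly instead of recovering `cu_seqlens` from `seq_idx` via `torch.where`. Below is a minimal CPU-only sketch of the `repeat_interleave` construction; the tensor values (sequence lengths 5, 3, 4) are made up for illustration, and `query_start_loc` here is a hand-built stand-in for `attn_metadata.query_start_loc`:

```python
import torch

# Stand-in for attn_metadata.query_start_loc: cumulative start offsets
# of three packed prefill sequences of (made-up) lengths 5, 3 and 4.
query_start_loc = torch.tensor([0, 5, 8, 12], dtype=torch.int32)
total_seqlens = 12  # stands in for len(input_ids)

# Old construction (removed by the diff): a Python loop whose slice
# bounds are 0-dim tensors, so on GPU each iteration would sync.
seq_idx_loop = torch.zeros(total_seqlens, dtype=torch.int32)
for i, (srt, end) in enumerate(zip(query_start_loc,
                                   query_start_loc[1:])):
    seq_idx_loop[srt:end] = i

# New construction: one kernel launch. Passing output_size= spares
# repeat_interleave the device-to-host copy it would otherwise need
# to size its output from seqlens.
seqlens = query_start_loc.diff()
seq_idx = torch.repeat_interleave(
    torch.arange(len(query_start_loc) - 1, dtype=torch.int32),
    seqlens,
    output_size=total_seqlens)

assert torch.equal(seq_idx, seq_idx_loop)
print(seq_idx)
# tensor([0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int32)
```

The helper rename follows the same motive: `query_start_loc` already *is* the cumulative-lengths vector (with a prepended 0), so the old `torch.where(seq_idx.diff())` reconstruction and the `.tolist()` host copy after it can simply be dropped.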