From e15d2afa6f2e98fcf6f0d466621bd9e56bf83b98 Mon Sep 17 00:00:00 2001
From: Huamin Li <3ericli@gmail.com>
Date: Tue, 30 Sep 2025 12:38:09 -0700
Subject: [PATCH] bring _query_start_loc_to_chunk_indices_offsets back

Signed-off-by: Huamin Li <3ericli@gmail.com>
---
 tests/kernels/mamba/test_mamba_ssm_ssd.py | 90 ++++++++++++++++++++++-
 1 file changed, 88 insertions(+), 2 deletions(-)

diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py
index 927af32588e6..58fdf03e3f6f 100644
--- a/tests/kernels/mamba/test_mamba_ssm_ssd.py
+++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import math
+
 import pytest
 import torch
 import torch.nn.functional as F
@@ -9,14 +11,98 @@
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
     mamba_chunk_scan_combined_varlen)
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.mamba2_attn import (
-    _query_start_loc_to_chunk_indices_offsets)
 
 # Added by the IBM Team, 2024
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
 
 
+# helper function that got removed from https://github.com/vllm-project/vllm/pull/24683
+def _query_start_loc_to_chunk_indices_offsets(
+        query_start_loc: torch.Tensor, chunk_size: int,
+        total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
+            lengths, shape (num_seqs + 1,).
+            The first element should be 0. Each entry represents the starting
+            index of a sequence in the flattened token array.
+        chunk_size (int): The size of each physical mamba chunk
+            (number of tokens per chunk).
+        total_seqlens (int): The total number of tokens in the batch.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+        - chunk_indices (torch.Tensor): 1D tensor of indices
+            indicating the physical chunk for each logical chunk.
+        - chunk_offsets (torch.Tensor): 1D tensor of offsets
+            indicating the starting index of each logical chunk within
+            its physical chunk.
+
+    This function computes the chunk indices and offsets for the given
+    query_start_loc and chunk_size. Both are tensors of integers with length
+    N, where N is the number of logical (pseudo) chunks.
+    A logical chunk is a sequence of tokens that are all part of the same
+    sequence and are all in the same physical mamba chunk.
+    In other words, a logical chunk changes every time we cross a sequence
+    boundary or a physical mamba chunk boundary.
+    Logical chunks are needed to handle batched requests with initial states
+    (see _state_passing_fwd and _chunk_scan_fwd).
+    The chunk_indices tensor contains the index of the physical chunk for
+    each logical chunk.
+    The chunk_offsets tensor contains the offset (AKA starting index) of the
+    logical chunk in the physical chunk.
+
+    Example:
+    query_start_loc = [0, 5, 10]
+    chunk_size = 8
+    total_seqlens = 10
+    -> chunk_indices = [0, 0, 1]
+    -> chunk_offsets = [0, 5, 0]
+
+    In this example, we have 2 sequences, each with 5 tokens. The physical
+    chunk size is 8 tokens.
+    We have three logical chunks:
+    - the first logical chunk starts at token 0 in the first physical chunk
+      and contains all 5 tokens from the first sequence
+    - the second logical chunk starts at token 5 in the first physical chunk
+      and contains the first 3 tokens from the second sequence
+    - the third logical chunk starts at token 0 in the second physical chunk
+      and contains the remaining 2 tokens from the second sequence
+    """
+
+    cu_seqlens = query_start_loc[1:]  # remove prepended 0
+
+    # outputs will have length expansion of chunks that do not divide
+    # chunk_size
+    N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
+                                                 > 0).sum()
+    chunk_indices = torch.arange(N,
+                                 dtype=torch.int,
+                                 device=query_start_loc.device)
+    chunk_offsets = torch.zeros((N, ),
+                                dtype=torch.int,
+                                device=query_start_loc.device)
+
+    p = 0  # num of insertions
+    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
+
+        # if s does not divide chunk_size, then there is one chunk insertion
+        p += (s % chunk_size > 0)
+
+        # get the dimensions
+        # - the + 1 for _e is to shift the boundary by one chunk
+        # - this shifting is not needed if chunk_size divides e
+        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
+                                                             > 0)
+
+        # adjust indices and offsets
+        chunk_indices[_s:_e] -= p
+        chunk_offsets[_s] = s % chunk_size
+
+    return chunk_indices, chunk_offsets
+
+
 # this is the segsum implementation taken from above
 def segsum(x):
     """Calculates segment sum."""
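
For reference, the restored helper can be sanity-checked against the docstring
example once the patch is applied. The snippet below is not part of the patch;
it is a minimal sketch that assumes the patched test module is importable as
tests.kernels.mamba.test_mamba_ssm_ssd (a hypothetical import path; the
function could also be pasted directly into a script):

    import torch

    from tests.kernels.mamba.test_mamba_ssm_ssd import (
        _query_start_loc_to_chunk_indices_offsets)

    # two sequences of 5 tokens each, physical chunk size of 8 tokens
    query_start_loc = torch.tensor([0, 5, 10], dtype=torch.int)
    chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets(
        query_start_loc, chunk_size=8, total_seqlens=10)

    # three logical chunks: two in physical chunk 0 (offsets 0 and 5),
    # one at the start of physical chunk 1
    assert chunk_indices.tolist() == [0, 0, 1]
    assert chunk_offsets.tolist() == [0, 5, 0]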