90 changes: 88 additions & 2 deletions tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math

import pytest
import torch
import torch.nn.functional as F
@@ -9,14 +11,98 @@
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen)
from vllm.platforms import current_platform
from vllm.v1.attention.backends.mamba2_attn import (
_query_start_loc_to_chunk_indices_offsets)

# Added by the IBM Team, 2024

# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py


# helper function that was removed from vLLM in https://github.com/vllm-project/vllm/pull/24683, reproduced here for the tests
def _query_start_loc_to_chunk_indices_offsets(
query_start_loc: torch.Tensor, chunk_size: int,
total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
lengths, shape (num_seqs + 1,).
The first element should be 0. Each entry represents the starting
index of a sequence in the flattened token array.
chunk_size (int): The size of each physical mamba chunk
(number of tokens per chunk).
total_seqlens (int): The total number of tokens in the batch.

Returns:
tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- chunk_indices (torch.Tensor): 1D tensor of indices
indicating the physical chunk for each logical chunk.
- chunk_offsets (torch.Tensor): 1D tensor of offsets
indicating the starting index of each logical chunk within
its physical chunk.

    This function computes the chunk indices and offsets for the given
    query_start_loc and chunk_size. Both returned tensors are integer tensors
    of length N, where N is the number of logical (pseudo) chunks.
A logical chunk is a sequence of tokens that are all part of the same
sequence and are all in the same physical mamba chunk.
In other words, a logical chunk changes every time we cross a sequence
boundary or a physical mamba chunk boundary.
Logical chunks are needed to handle batched requests with initial states
(see _state_passing_fwd and _chunk_scan_fwd).
The chunk_indices tensor contains the index of the physical chunk for each
logical chunk.
The chunk_offsets tensor contains the offset (AKA starting index) of the
logical chunk in the physical chunk.

Example:
query_start_loc = [0, 5, 10]
chunk_size = 8
total_seqlens = 10
-> chunk_indices = [0, 0, 1]
-> chunk_offsets = [0, 5, 0]

In this example, we have 2 sequences, each with 5 tokens. The physical
chunk size is 8 tokens.
We have three logical chunks:
- the first logical chunk starts at token 0 in the first physical chunk
and contains all 5 tokens from the first sequence
- the second logical chunk starts at token 5 in the first physical chunk
      and contains the first 3 tokens from the second sequence
- the third logical chunk starts at token 0 in the second physical chunk
and contains the remaining 2 tokens from the second sequence
"""

cu_seqlens = query_start_loc[1:] # remove prepended 0

    # the outputs are longer than the number of physical chunks: every
    # sequence start that is not aligned to chunk_size inserts one extra
    # logical chunk
    N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
                                                 > 0).sum()
chunk_indices = torch.arange(N,
dtype=torch.int,
device=query_start_loc.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=query_start_loc.device)

    p = 0  # number of extra logical chunks inserted so far
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):

        # if the sequence start s is not aligned to chunk_size, this
        # sequence inserts one extra logical chunk
        p += (s % chunk_size > 0)

        # compute the logical chunk range [_s, _e) covered by this sequence
        # - the (e % chunk_size > 0) term extends _e by one chunk when the
        #   sequence end is not chunk-aligned
        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
                                                             > 0)

# adjust indices and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size

return chunk_indices, chunk_offsets
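

# A minimal sanity check (a hypothetical sketch, not part of the upstream
# change; the function name is illustrative): it replays the docstring
# example above and verifies the outputs.
def _check_docstring_example():
    query_start_loc = torch.tensor([0, 5, 10], dtype=torch.int32)
    chunk_indices, chunk_offsets = _query_start_loc_to_chunk_indices_offsets(
        query_start_loc, chunk_size=8, total_seqlens=10)
    # two sequences of 5 tokens each with chunk_size=8 yield three logical
    # chunks: indices [0, 0, 1] and offsets [0, 5, 0]
    assert chunk_indices.tolist() == [0, 0, 1]
    assert chunk_offsets.tolist() == [0, 5, 0]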


# this is the segsum implementation taken from above
def segsum(x):
"""Calculates segment sum."""
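

# The body of segsum is collapsed in this diff view. For orientation only, a
# minimal sketch assuming the upstream ssd_minimal.py approach (name here is
# illustrative): it builds out[..., i, j] = sum(x[..., j+1 : i+1]) for j <= i,
# with -inf above the diagonal, via masking and cumsum rather than
# differencing two cumulative sums.
def _segsum_sketch(x: torch.Tensor) -> torch.Tensor:
    T = x.size(-1)
    # broadcast x along a new trailing axis: element [..., d, e] == x[..., d]
    x_rep = x.unsqueeze(-1).expand(*x.shape, T)
    # keep only strictly-lower-triangular entries (d > e), zero elsewhere
    mask = torch.tril(
        torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=-1)
    x_rep = x_rep.masked_fill(~mask, 0)
    # cumulative sum over d gives sum(x[e+1 : i+1]) at position [i, e]
    x_segsum = torch.cumsum(x_rep, dim=-2)
    # entries above the diagonal are invalid segments: mask them to -inf
    mask = torch.tril(
        torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=0)
    return x_segsum.masked_fill(~mask, float("-inf"))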