Merged
Changes from all 40 commits:
dddb650  working changes (tdoublep, Sep 11, 2025)
1dc7a04  Merge branch 'main' into tpa-aligned-mamba (tdoublep, Sep 11, 2025)
2a7b216  working changes (tdoublep, Sep 11, 2025)
664a21a  fix bug (tdoublep, Sep 11, 2025)
6c475d6  fix bug (tdoublep, Sep 11, 2025)
0d00c69  fix bug (tdoublep, Sep 11, 2025)
b7ae698  working changes (tdoublep, Sep 11, 2025)
9b24bce  Fix bugs (tdoublep, Sep 11, 2025)
e850661  working changes (tdoublep, Sep 12, 2025)
0d5c3ae  revert some changes (tdoublep, Sep 12, 2025)
31e05fa  fmt (tdoublep, Sep 12, 2025)
a8aff97  workign (tdoublep, Sep 12, 2025)
67db9b4  working changes (tdoublep, Sep 12, 2025)
d841e82  working changes (tdoublep, Sep 12, 2025)
908aecb  working changes (tdoublep, Sep 12, 2025)
af7a246  working changes (tdoublep, Sep 12, 2025)
7ce2b59  Some test cases working (tdoublep, Sep 12, 2025)
f950f2e  Fix IMA (tdoublep, Sep 12, 2025)
75e01c8  Add back autotune config (tdoublep, Sep 12, 2025)
2698f2e  cleanup (tdoublep, Sep 12, 2025)
d3f05b7  cleanup (tdoublep, Sep 12, 2025)
df63503  cleanup (tdoublep, Sep 12, 2025)
c5edccd  cleanup (tdoublep, Sep 12, 2025)
712ced1  cleanup (tdoublep, Sep 12, 2025)
dc85f7e  cleanup (tdoublep, Sep 12, 2025)
5e827a6  cleanup (tdoublep, Sep 12, 2025)
42e4b27  cleanup (tdoublep, Sep 12, 2025)
d859182  cleanup (tdoublep, Sep 12, 2025)
56b37c2  cleanup (tdoublep, Sep 12, 2025)
e21b4e6  lint (tdoublep, Sep 12, 2025)
1b0f793  Fix bug in scan kernel when to reading previous state. (tdoublep, Sep 18, 2025)
df8f046  Remove BLOCK_H=1 from list of tuneable configurations. (tdoublep, Sep 18, 2025)
878190e  merge main (tdoublep, Sep 19, 2025)
a0d277d  Resolve (many) conflicts (tdoublep, Sep 26, 2025)
2cb4252  Fix a few merge errors (tdoublep, Sep 26, 2025)
c2a4f8e  Remove chunk_offsets and chunk_indices (tdoublep, Sep 26, 2025)
76ce99c  clean up seq_idx (tdoublep, Sep 26, 2025)
51b756b  Fix plamo2 (tdoublep, Sep 29, 2025)
37ffa92  Review comments (tdoublep, Sep 29, 2025)
29b42cc  Fix plamo2 again (tdoublep, Sep 29, 2025)
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -502,9 +502,9 @@ def forward_cuda(
         prep_initial_states = attn_metadata.prep_initial_states
         chunk_size = attn_metadata.chunk_size
         seq_idx_p = attn_metadata.seq_idx_p
-        chunk_indices_p = attn_metadata.chunk_indices_p
-        chunk_offsets_p = attn_metadata.chunk_offsets_p
         query_start_loc_p = attn_metadata.query_start_loc_p
+        cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
+        last_chunk_indices_p = attn_metadata.last_chunk_indices_p
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -634,9 +634,9 @@ def forward_cuda(
             z=None,
             dt_bias=self.dt_bias,
             seq_idx=seq_idx_p,
-            chunk_indices=chunk_indices_p,
-            chunk_offsets=chunk_offsets_p,
             cu_seqlens=query_start_loc_p,
+            cu_chunk_seqlens=cu_chunk_seqlen_p,
+            last_chunk_indices=last_chunk_indices_p,
             initial_states=initial_states,
             dt_softplus=True,
             dt_limit=(0.0, float("inf")),
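The metadata swap above replaces the chunk_indices_p/chunk_offsets_p bookkeeping with request-aligned chunk boundaries: cu_chunk_seqlen_p holds cumulative token offsets of chunks that never straddle a request, and last_chunk_indices_p marks each request's final chunk (presumably so its state can be written back to the cache). Below is a minimal sketch of how such metadata could be derived from query_start_loc_p and chunk_size; the helper name and construction are illustrative, not the PR's actual metadata-building code:

import torch

def build_cu_chunk_seqlens(query_start_loc: torch.Tensor, chunk_size: int):
    """query_start_loc: (num_seqs + 1,) cumulative token offsets per request."""
    cu_chunk_seqlen = [0]
    last_chunk_indices = []
    for s in range(query_start_loc.numel() - 1):
        start, end = int(query_start_loc[s]), int(query_start_loc[s + 1])
        # Cut this request into pieces of at most `chunk_size` tokens so
        # that no chunk ever spans two requests.
        for pos in range(start, end, chunk_size):
            cu_chunk_seqlen.append(min(pos + chunk_size, end))
        # Index of this request's final chunk.
        last_chunk_indices.append(len(cu_chunk_seqlen) - 2)
    return (torch.tensor(cu_chunk_seqlen, dtype=torch.int32),
            torch.tensor(last_chunk_indices, dtype=torch.int32))

For example, two requests of 5 and 7 tokens with chunk_size=4 yield cu_chunk_seqlen = [0, 4, 5, 9, 12] and last_chunk_indices = [1, 3]: chunk 1 is the one-token tail of the first request rather than a chunk shared with the second.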
42 changes: 17 additions & 25 deletions vllm/model_executor/layers/mamba/ops/ssd_bmm.py
@@ -6,8 +6,6 @@
 
 # ruff: noqa: E501,SIM102
 
-import math
-
 import torch
 
 from vllm.triton_utils import tl, triton
@@ -96,7 +94,7 @@ def _bmm_chunk_fwd_kernel(
        a_ptr,
        b_ptr,
        out_ptr,
-        seq_idx_ptr,
+        cu_chunk_seqlens_ptr,
        # Matrix dimensions
        seqlen,
        chunk_size: tl.constexpr,
@@ -112,7 +110,6 @@ def _bmm_chunk_fwd_kernel(
        stride_out_head: tl.int64,
        stride_outm: tl.int64,
        stride_outn: tl.constexpr,
-        stride_seq_idx_seqlen: tl.constexpr,
        # Meta-parameters
        IS_CAUSAL: tl.constexpr,
        dot_dtype: tl.constexpr,
@@ -129,10 +126,12 @@ def _bmm_chunk_fwd_kernel(
     if IS_CAUSAL:
         if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
             return
-    a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
-    b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
 
-    seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
+    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
+    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
+
+    a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
+    b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -141,7 +140,7 @@ def _bmm_chunk_fwd_kernel(
                       offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
                       offs_n[None, :] * stride_b_seqlen)
-    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
+    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
 
     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
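With request-aligned chunks of variable length, a chunk can be shorter than chunk_size even in the middle of the token stream, so the limit must come from the chunk's own boundaries rather than from the old fixed pid_c * chunk_size stride. A quick sanity check with hypothetical values:

# Hypothetical values: requests of 5 and 7 tokens, chunk_size = 4, giving
# request-aligned chunks [0,4), [4,5), [5,9), [9,12).
cu_chunk_seqlens = [0, 4, 5, 9, 12]
seqlen, chunk_size, pid_c = 12, 4, 1  # chunk 1 is the 1-token tail of request 0
chunk_seqlen_start = cu_chunk_seqlens[pid_c]              # 4
chunk_seqlen_end = cu_chunk_seqlens[pid_c + 1]            # 5
new_limit = chunk_seqlen_end - chunk_seqlen_start         # 1
old_limit = min(chunk_size, seqlen - pid_c * chunk_size)  # 4: wrong for aligned chunks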

@@ -162,16 +161,6 @@ def _bmm_chunk_fwd_kernel(
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
 
-    # Zero out the results that are not from the same request
-    # in the varlen batch
-    seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
-                        mask=offs_m < chunk_size_limit,
-                        other=-1)
-    seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
-                        mask=offs_n < chunk_size_limit,
-                        other=-2)
-    acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
-
     out = acc.to(out_ptr.dtype.element_ty)
     out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
     out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
@@ -182,27 +171,31 @@
                          (offs_n[None, :] < chunk_size))
 
 
-def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
+def _bmm_chunk_fwd(a,
+                   b,
+                   chunk_size,
+                   cu_chunk_seqlens,
+                   causal=False,
+                   output_dtype=None):
     """
     Argument:
         a: (seqlen, ngroups, k)
         b: (seqlen, ngroups, k)
-        seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
+        chunk_size: int
+        cu_chunk_seqlens: (nchunks+1,)
         causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
             guaranteed to be correct.
     Return:
         out: (nchunks, ngroups, chunk_size, chunk_size)
     """
     seqlen, ngroups, k = a.shape
     assert b.shape == a.shape
-    assert seq_idx is not None
-    assert seq_idx.shape == (seqlen, )
     if a.stride(-1) != 1 and a.stride(0) != 1:
         a = a.contiguous()
     if b.stride(-1) != 1 and b.stride(0) != 1:
         b = b.contiguous()
 
-    nchunks = math.ceil(seqlen / chunk_size)
+    nchunks = len(cu_chunk_seqlens) - 1
     # Allocates output.
     out_dtype = a.dtype if output_dtype is None else output_dtype
     out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
@@ -220,7 +213,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
        a_ptr=a,
        b_ptr=b,
        out_ptr=out,
-        seq_idx_ptr=seq_idx,
+        cu_chunk_seqlens_ptr=cu_chunk_seqlens,
        seqlen=seqlen,
        chunk_size=chunk_size,
        K=k,
@@ -235,7 +228,6 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
        stride_out_head=out.stride(1),
        stride_outm=out.stride(-2),
        stride_outn=out.stride(-1),
-        stride_seq_idx_seqlen=seq_idx.stride(0),
        IS_CAUSAL=causal,
        dot_dtype=dot_dtype,
    )
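Since every token of a chunk now belongs to a single request, the cross-request zeroing the kernel previously did with seq_idx masks becomes unnecessary: each output is simply a per-chunk Gram matrix. Below is a minimal PyTorch reference for the kernel's contract, a sketch for intuition only (the bmm_chunk_ref name is made up here), not the Triton implementation:

import torch

def bmm_chunk_ref(a, b, cu_chunk_seqlens, chunk_size):
    """a, b: (seqlen, ngroups, k) -> (nchunks, ngroups, chunk_size, chunk_size)."""
    nchunks = len(cu_chunk_seqlens) - 1
    out = torch.zeros(nchunks, a.shape[1], chunk_size, chunk_size, dtype=a.dtype)
    for c in range(nchunks):
        s, e = int(cu_chunk_seqlens[c]), int(cu_chunk_seqlens[c + 1])
        # All tokens in [s, e) come from one request, so no seq_idx
        # masking is needed; the padded tail of each chunk stays zero.
        ac = a[s:e].permute(1, 0, 2)  # (ngroups, L, k)
        bc = b[s:e].permute(1, 2, 0)  # (ngroups, k, L)
        out[c, :, :e - s, :e - s] = torch.bmm(ac, bc)
    return out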