Commit fea3e47
[Kernel] Chunk-aligned mamba2 (#24683)
1 parent: 61a3431
File tree: 8 files changed, +250 -434 lines

vllm/model_executor/layers/mamba/mamba_mixer2.py
Lines changed: 4 additions & 4 deletions

@@ -502,9 +502,9 @@ def forward_cuda(
         prep_initial_states = attn_metadata.prep_initial_states
         chunk_size = attn_metadata.chunk_size
         seq_idx_p = attn_metadata.seq_idx_p
-        chunk_indices_p = attn_metadata.chunk_indices_p
-        chunk_offsets_p = attn_metadata.chunk_offsets_p
         query_start_loc_p = attn_metadata.query_start_loc_p
+        cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
+        last_chunk_indices_p = attn_metadata.last_chunk_indices_p
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -634,9 +634,9 @@ def forward_cuda(
             z=None,
             dt_bias=self.dt_bias,
             seq_idx=seq_idx_p,
-            chunk_indices=chunk_indices_p,
-            chunk_offsets=chunk_offsets_p,
             cu_seqlens=query_start_loc_p,
+            cu_chunk_seqlens=cu_chunk_seqlen_p,
+            last_chunk_indices=last_chunk_indices_p,
             initial_states=initial_states,
             dt_softplus=True,
             dt_limit=(0.0, float("inf")),
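The new metadata fields replace the per-chunk index/offset pair with a cumulative view of chunk boundaries. As a rough illustration of what chunk alignment implies (a hypothetical sketch, not vLLM's actual metadata builder; `build_chunk_metadata` and its exact inputs/outputs are assumptions), each request's tokens are split at `chunk_size` boundaries so that no chunk straddles a request:

```python
import math
import torch

def build_chunk_metadata(seq_lens: list[int], chunk_size: int):
    """Hypothetical helper: split each request at chunk_size boundaries
    and return (cu_chunk_seqlen, last_chunk_indices)."""
    cu_chunk_seqlen = [0]
    last_chunk_indices = []
    for seq_len in seq_lens:
        for i in range(math.ceil(seq_len / chunk_size)):
            cu_chunk_seqlen.append(cu_chunk_seqlen[-1] +
                                   min(chunk_size, seq_len - i * chunk_size))
        # index of the last chunk belonging to this request
        last_chunk_indices.append(len(cu_chunk_seqlen) - 2)
    return (torch.tensor(cu_chunk_seqlen, dtype=torch.int32),
            torch.tensor(last_chunk_indices, dtype=torch.int32))

# Two requests of 5 and 3 tokens with chunk_size=4 give chunks of
# 4, 1 and 3 tokens: cu_chunk_seqlen = [0, 4, 5, 8], and the last
# chunks of the two requests are chunks 1 and 2.
print(build_chunk_metadata([5, 3], chunk_size=4))
```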

vllm/model_executor/layers/mamba/ops/ssd_bmm.py
Lines changed: 17 additions & 25 deletions

@@ -6,8 +6,6 @@
 
 # ruff: noqa: E501,SIM102
 
-import math
-
 import torch
 
 from vllm.triton_utils import tl, triton
@@ -96,7 +94,7 @@ def _bmm_chunk_fwd_kernel(
     a_ptr,
     b_ptr,
     out_ptr,
-    seq_idx_ptr,
+    cu_chunk_seqlens_ptr,
     # Matrix dimensions
     seqlen,
     chunk_size: tl.constexpr,
@@ -112,7 +110,6 @@ def _bmm_chunk_fwd_kernel(
     stride_out_head: tl.int64,
     stride_outm: tl.int64,
     stride_outn: tl.constexpr,
-    stride_seq_idx_seqlen: tl.constexpr,
     # Meta-parameters
     IS_CAUSAL: tl.constexpr,
     dot_dtype: tl.constexpr,
@@ -129,10 +126,12 @@ def _bmm_chunk_fwd_kernel(
     if IS_CAUSAL:
         if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
             return
-    a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
-    b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
 
-    seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
+    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
+    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
+
+    a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
+    b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head
 
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -141,7 +140,7 @@ def _bmm_chunk_fwd_kernel(
                       offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
                       offs_n[None, :] * stride_b_seqlen)
-    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
+    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
 
     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
@@ -162,16 +161,6 @@ def _bmm_chunk_fwd_kernel(
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
 
-    # Zero out the results that are not from the same request
-    # in the varlen batch
-    seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
-                        mask=offs_m < chunk_size_limit,
-                        other=-1)
-    seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
-                        mask=offs_n < chunk_size_limit,
-                        other=-2)
-    acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
-
     out = acc.to(out_ptr.dtype.element_ty)
     out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
     out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
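The deleted block above was the only consumer of `seq_idx` in this kernel: with fixed-stride chunking, a chunk could straddle a request boundary, so cross-request entries of the chunk-local `a @ b^T` had to be zeroed. Under the chunk-aligned layout every chunk holds tokens from exactly one request, which is why the mask can be dropped. A small illustration with assumed toy values (not code from the diff):

```python
# With fixed-stride chunking, chunk 1 below mixes requests 0 and 1,
# so out[i, j] entries with seq_idx[i] != seq_idx[j] had to be zeroed.
seq_lens, chunk_size = [5, 3], 4
seq_idx = [0] * 5 + [1] * 3  # token -> request id
old_chunks = [seq_idx[i:i + chunk_size]
              for i in range(0, sum(seq_lens), chunk_size)]
assert old_chunks == [[0, 0, 0, 0], [0, 1, 1, 1]]

# The chunk-aligned layout splits at request boundaries instead:
cu_chunk_seqlens = [0, 4, 5, 8]  # chunks req0[0:4], req0[4:5], req1[0:3]
# every chunk is single-request, so no cross-request zeroing is needed
```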
@@ -182,27 +171,31 @@ def _bmm_chunk_fwd_kernel(
182171
(offs_n[None, :] < chunk_size))
183172

184173

185-
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
174+
def _bmm_chunk_fwd(a,
175+
b,
176+
chunk_size,
177+
cu_chunk_seqlens,
178+
causal=False,
179+
output_dtype=None):
186180
"""
187181
Argument:
188182
a: (seqlen, ngroups, k)
189183
b: (seqlen, ngroups, k)
190-
seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
184+
chunk_size: int
185+
cu_chunk_seq_lens: (nchunks+1,)
191186
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
192187
guaranteed to be correct.
193188
Return:
194189
out: (nchunks, ngroups, chunk_size, chunk_size)
195190
"""
196191
seqlen, ngroups, k = a.shape
197192
assert b.shape == a.shape
198-
assert seq_idx is not None
199-
assert seq_idx.shape == (seqlen, )
200193
if a.stride(-1) != 1 and a.stride(0) != 1:
201194
a = a.contiguous()
202195
if b.stride(-1) != 1 and b.stride(0) != 1:
203196
b = b.contiguous()
204197

205-
nchunks = math.ceil(seqlen / chunk_size)
198+
nchunks = len(cu_chunk_seqlens) - 1
206199
# Allocates output.
207200
out_dtype = a.dtype if output_dtype is None else output_dtype
208201
out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
@@ -220,7 +213,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
220213
a_ptr=a,
221214
b_ptr=b,
222215
out_ptr=out,
223-
seq_idx_ptr=seq_idx,
216+
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
224217
seqlen=seqlen,
225218
chunk_size=chunk_size,
226219
K=k,
@@ -235,7 +228,6 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
235228
stride_out_head=out.stride(1),
236229
stride_outm=out.stride(-2),
237230
stride_outn=out.stride(-1),
238-
stride_seq_idx_seqlen=seq_idx.stride(0),
239231
IS_CAUSAL=causal,
240232
dot_dtype=dot_dtype,
241233
)
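Taken together, the kernel changes amount to indexing chunk `c` by `[cu_chunk_seqlens[c], cu_chunk_seqlens[c+1])` instead of `[c * chunk_size, (c+1) * chunk_size)`. A minimal PyTorch reference of the updated `_bmm_chunk_fwd` semantics (an assumed sketch for clarity, not the Triton kernel; it ignores the causal fast path, which only skips redundant tiles):

```python
import torch

def bmm_chunk_fwd_reference(a, b, chunk_size, cu_chunk_seqlens):
    """Sketch of _bmm_chunk_fwd: per-chunk batched a @ b^T."""
    seqlen, ngroups, k = a.shape
    nchunks = len(cu_chunk_seqlens) - 1
    out = torch.zeros(nchunks, ngroups, chunk_size, chunk_size, dtype=a.dtype)
    for c in range(nchunks):
        start = int(cu_chunk_seqlens[c])
        end = int(cu_chunk_seqlens[c + 1])
        limit = end - start  # replaces min(chunk_size, seqlen - c * chunk_size)
        for g in range(ngroups):
            # all tokens in [start, end) belong to one request, so no
            # seq_idx-based zeroing is required
            out[c, g, :limit, :limit] = a[start:end, g] @ b[start:end, g].T
    return out
```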
