From dddb650c10f6f2015cf84bd2983ff2a5b12a417a Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 05:45:32 -0400 Subject: [PATCH 01/37] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_state.py | 83 +++++++++++++++++-- .../layers/mamba/ops/ssd_combined.py | 10 ++- .../layers/mamba/ops/ssd_state_passing.py | 23 +++-- vllm/v1/attention/backends/mamba2_attn.py | 4 + vllm/v1/core/sched/scheduler.py | 2 +- 5 files changed, 106 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index ad58a9918f03..a9b07660589a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -14,6 +14,22 @@ from .mamba_ssm import softplus +@triton.jit +def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, + BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): + left: tl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = tl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 @triton.autotune( configs=[ @@ -35,6 +51,7 @@ def _chunk_cumsum_fwd_kernel( dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + cu_seqlens_ptr, # Matrix dimension batch, seqlen, @@ -42,6 +59,7 @@ def _chunk_cumsum_fwd_kernel( chunk_size, dt_min, dt_max, + num_seqs, # Strides stride_dt_batch, stride_dt_seqlen, @@ -68,7 +86,23 @@ def _chunk_cumsum_fwd_kernel( # https://github.com/triton-lang/triton/issues/1058 pid_c = tl.program_id(axis=1).to(tl.int64) pid_h = tl.program_id(axis=2) - dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + + + seq_idx = find_seq_idx(cu_seqlens_ptr, pid_c, num_seqs, chunk_size, True) + + chunk_start_idx = tl.load(cu_seqlens_ptr + seq_idx) // chunk_size + seq_idx + + chunk_local_idx = pid_c - chunk_start_idx + + cur_batch_in_all_start_idx = tl.load(cu_seqlens_ptr + seq_idx) + cur_batch_in_all_stop_idx = tl.load(cu_seqlens_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_idx - cur_batch_in_all_start_idx + + # skip any unncessary work + if chunk_local_idx * chunk_size >= cur_batch_query_len: + return + + dt_ptr += pid_b * stride_dt_batch + (cur_batch_in_all_start_idx + chunk_local_idx * chunk_size) * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -81,7 +115,7 @@ def _chunk_cumsum_fwd_kernel( offs_c[None, :] * stride_dt_out_csize) dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit = min(chunk_size, cur_batch_query_len - chunk_local_idx * chunk_size) dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & @@ -102,13 +136,13 @@ def _chunk_cumsum_fwd_kernel( 0.0) tl.store(dt_out_ptrs, dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) tl.store(dA_cs_ptrs, dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) @triton.autotune( @@ -196,6 
+230,7 @@ def _chunk_state_fwd_kernel( states_ptr, dt_ptr, dA_cumsum_ptr, + cu_seqlens_ptr, seq_idx_ptr, # Matrix dimensions hdim, @@ -204,6 +239,7 @@ def _chunk_state_fwd_kernel( batch, seqlen, nheads_ngroups_ratio, + num_seqs, # Strides stride_x_batch, stride_x_seqlen, @@ -241,13 +277,26 @@ def _chunk_state_fwd_kernel( num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + ( + + seq_idx = find_seq_idx(cu_seqlens_ptr, pid_c, num_seqs, chunk_size, True) + chunk_start_idx = tl.load(cu_seqlens_ptr + seq_idx) // chunk_size + seq_idx + chunk_local_idx = pid_c - chunk_start_idx + cur_batch_in_all_start_idx = tl.load(cu_seqlens_ptr + seq_idx) + cur_batch_in_all_stop_idx = tl.load(cu_seqlens_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_idx - cur_batch_in_all_start_idx + + # skip any unncessary work + if chunk_local_idx * chunk_size >= cur_batch_query_len: + return + + seqlen_offset = cur_batch_in_all_start_idx + chunk_local_idx*chunk_size + b_ptr += pid_b * stride_b_batch + seqlen_offset * stride_b_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_b_head - x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + x_ptr += pid_b * stride_x_batch + seqlen_offset * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + seq_idx_ptr += pid_b * stride_seq_idx_batch + seqlen_offset * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -263,7 +312,8 @@ def _chunk_state_fwd_kernel( if HAS_SEQ_IDX: seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit = min(chunk_size, cur_batch_query_len - chunk_local_idx * chunk_size) + if HAS_SEQ_IDX: seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen) @@ -556,26 +606,35 @@ def _chunk_state_varlen_kernel( def _chunk_cumsum_fwd(dt, A, chunk_size, + cu_seqlens, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): + batch, seqlen, nheads = dt.shape assert A.shape == (nheads, ) if dt_bias is not None: assert dt_bias.shape == (nheads, ) - nchunks = math.ceil(seqlen / chunk_size) + num_seqs = len(cu_seqlens)-1 + nchunks = seqlen // chunk_size + num_seqs + print("dt.shape: ", dt.shape) + print("A.shape: ", A.shape) + print("nchunks: ", nchunks) + print("type(cu_seqlens): ", type(cu_seqlens)) dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + print("dt_out.shape: ", dt_out.shape) dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) + print("dA_cumsum.shape: ", dA_cumsum.shape) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): @@ -585,12 +644,14 @@ def _chunk_cumsum_fwd(dt, dt_bias, dt_out, dA_cumsum, + cu_seqlens, batch, seqlen, nheads, chunk_size, dt_limit[0], dt_limit[1], + num_seqs, dt.stride(0), dt.stride(1), dt.stride(2), @@ -615,6 +676,7 @@ def _chunk_state_fwd(B, x, dt, dA_cumsum, + cu_seqlens, 
seq_idx=None, states=None, states_in_fp32=True): @@ -634,6 +696,7 @@ def _chunk_state_fwd(B, states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) + print("[_chunk_state_fwd] states.shape: ", states.shape) grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) @@ -644,6 +707,7 @@ def _chunk_state_fwd(B, states, dt, dA_cumsum, + cu_seqlens, seq_idx, headdim, dstate, @@ -651,6 +715,7 @@ def _chunk_state_fwd(B, batch, seqlen, nheads // ngroups, + len(cu_seqlens)-1, x.stride(0), x.stride(1), x.stride(2), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index fcc5c905bf77..c11f1a4c4b24 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -85,21 +85,27 @@ def _mamba_chunk_scan_combined_fwd(x, # - see the blog and paper for a visualization of the submatrices # which we refer to in the comments below + num_seqs = len(cu_seqlens) - 1 # 1. Compute chunked cumsum of A * dt # - here dt may go through a softplus activation dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, + cu_seqlens, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) + print("dA_cumsum.shape: ", dA_cumsum.shape) + print("dt.shape: ", dt.shape) + # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, x, dt, dA_cumsum, + cu_seqlens, seq_idx=seq_idx, states_in_fp32=True) @@ -117,6 +123,7 @@ def _mamba_chunk_scan_combined_fwd(x, states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, + cu_seqlens, initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, seq_idx=seq_idx, @@ -213,7 +220,8 @@ def mamba_chunk_scan_combined(x, out: Preallocated output tensor state_dtype: The data type of the ssm state """ - + print("-------------------------") + print("cu_seqlens: ", cu_seqlens) if not return_varlen_states: cu_seqlens = None else: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index d61c3a8cdbe9..56c61ad3cdd8 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -10,6 +10,7 @@ from vllm.triton_utils import tl, triton +from .ssd_chunk_state import find_seq_idx @triton.autotune( configs=[ @@ -33,11 +34,13 @@ def _state_passing_fwd_kernel( seq_idx_ptr, chunk_offsets_ptr, chunk_meta_num, + cu_seqlens_ptr, # Matrix dimensions dim, nchunks, seqlen, chunk_size, + num_seqs, # Strides stride_states_batch, stride_states_chunk, @@ -102,16 +105,21 @@ def _state_passing_fwd_kernel( prev_seq_idx_chunk_end = 0 logical_chunk_idx = 0 for c in range(nchunks): + + # now a chunk can only contain one sequence + seq_idx_chunk = find_seq_idx(cu_seqlens_ptr, c, num_seqs, chunk_size, True) + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) scale_mask = True if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right # boundary. # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. 
- seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( - (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) + seq_idx_chunk_end = seq_idx_chunk + if HAS_INITSTATES: if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: # this means in the current chunk the rightmost flushed seq @@ -130,9 +138,8 @@ def _state_passing_fwd_kernel( # - and subtract the cumsum just before that position from the total cumsum # - first, update the logical chunk index (add the number of sequences in the current physical chunk): # sequence index at the start of the current chunk - seq_idx_chunk_start = tl.load(seq_idx_ptr + - min(c * chunk_size, seqlen) * - stride_seq_idx_seqlen) + seq_idx_chunk_start = seq_idx_chunk + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start # - load the chunk offset: c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, @@ -168,6 +175,7 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, dA_cumsum, + cu_seqlens, initial_states=None, seq_idx=None, chunk_size=None, @@ -207,6 +215,9 @@ def _state_passing_fwd( device=states.device, dtype=torch.float32) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) + + print("[_state_passing_fwd] seq_idx.shape: ", seq_idx.shape) + print("[_state_passing_fwd] chunk_offsets: ", chunk_offsets) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states, @@ -217,10 +228,12 @@ def _state_passing_fwd( seq_idx, chunk_offsets, len(chunk_offsets) if chunk_offsets is not None else 0, + cu_seqlens, dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size, + len(cu_seqlens)-1, states.stride(0), states.stride(1), states.stride(2), diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 359bad1ea9de..971f32848460 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -127,6 +127,10 @@ class Mamba2AttentionMetadata: chunk_indices_p: Optional[torch.Tensor] chunk_offsets_p: Optional[torch.Tensor] + # tpa + chunk_seqlen_start_p: Optional[torch.Tensor] + chunk_seqlen_end_p: Optional[torch.Tensor] + state_indices_tensor: torch.Tensor # shape: [batch,] # The following attributes are for triton implementation of causal_conv1d diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ed7c16dc520f..d66fa8967978 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -598,7 +598,7 @@ def schedule(self) -> SchedulerOutput: structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) - + print(scheduler_output) # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store # 2. 
Wrap up all the KV cache load / save ops into an opaque object From 2a7b2166c223f61e819f8a8302a004ed2a961c65 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 11:34:10 -0400 Subject: [PATCH 02/37] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 4 + .../layers/mamba/ops/ssd_bmm.py | 25 ++-- .../layers/mamba/ops/ssd_chunk_scan.py | 123 +++++------------- .../layers/mamba/ops/ssd_chunk_state.py | 114 ++++------------ .../layers/mamba/ops/ssd_combined.py | 37 +++--- .../layers/mamba/ops/ssd_state_passing.py | 59 ++------- vllm/v1/attention/backends/mamba2_attn.py | 65 ++++++++- 7 files changed, 168 insertions(+), 259 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 04ebdbca85e5..222f89f2c35b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -487,6 +487,8 @@ def forward_cuda( seq_idx_p = attn_metadata.seq_idx_p chunk_indices_p = attn_metadata.chunk_indices_p chunk_offsets_p = attn_metadata.chunk_offsets_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_p = attn_metadata.last_chunk_p else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state @@ -671,6 +673,8 @@ def forward_cuda( chunk_indices=chunk_indices_p, chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, + cu_chunk_seqlens=cu_chunk_seqlen_p, + last_chunk=last_chunk_p, initial_states=initial_states, return_varlen_states=True, return_final_states=False, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 11ca1255ebfb..e7980af51a74 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -97,6 +97,7 @@ def _bmm_chunk_fwd_kernel( b_ptr, out_ptr, seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions seqlen, chunk_size, @@ -135,8 +136,12 @@ def _bmm_chunk_fwd_kernel( if IS_CAUSAL: if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: return - a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head - b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + a_ptr += pid_b * stride_a_batch + chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head + b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * chunk_size * stride_b_seqlen + pid_h * stride_b_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen @@ -147,7 +152,7 @@ def _bmm_chunk_fwd_kernel( offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit = chunk_seqlen_end acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): @@ -165,15 +170,7 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - if HAS_SEQ_IDX: - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, - 
mask=offs_n < chunk_size_limit, - other=-2) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head @@ -188,6 +185,7 @@ def _bmm_chunk_fwd_kernel( def _bmm_chunk_fwd(a, b, chunk_size, + cu_chunk_seqlens, seq_idx=None, causal=False, output_dtype=None): @@ -214,7 +212,7 @@ def _bmm_chunk_fwd(a, a = a.contiguous() if b.stride(-1) != 1 and b.stride(1) != 1: b = b.contiguous() - nchunks = math.ceil(seqlen / chunk_size) + nchunks = len(cu_chunk_seqlens)-1 # Allocates output. out_dtype = a.dtype if output_dtype is None else output_dtype out = torch.empty( @@ -236,6 +234,7 @@ def _bmm_chunk_fwd(a, b, out, seq_idx, + cu_chunk_seqlens, seqlen, chunk_size, k, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index fb8350e191c9..9f23b4103a97 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -125,6 +125,7 @@ def _chunk_scan_fwd_kernel( chunk_indices_ptr, chunk_offsets_ptr, chunk_meta_num, + cu_chunk_seqlens_ptr, # Matrix dimensions chunk_size, hdim, @@ -190,12 +191,11 @@ def _chunk_scan_fwd_kernel( pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch - if not HAS_INITSTATES: - c_idx = pid_c - c_off = 0 - else: - c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) - c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + # logical chunks = physical chunks + # always start from beginning + c_idx = pid_c + c_off = 0 pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) @@ -203,10 +203,14 @@ def _chunk_scan_fwd_kernel( pid_n = tl.program_id(axis=0) % num_pid_n cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + x_ptr += pid_b * stride_x_batch + chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + ( + C_ptr += pid_b * stride_C_batch + chunk_seqlen_start * stride_C_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_C_head # M-block offsets and prev states @@ -216,94 +220,30 @@ def _chunk_scan_fwd_kernel( prev_states_hdim = stride_states_hdim prev_states_dstate = stride_states_dstate - chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + chunk_size_limit = chunk_seqlen_end + if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen - # - we only need seq_idx_prev to be aligned to chunk boundary + # current sequence index + seq_idx = tl.load(seq_idx_ptr) + + # previous sequence index seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0) - if HAS_INITSTATES: - # if there are init states, we only need seq_idx_m to point - # what is the current seq_idx - - # get current seq idx - if (pid_m * 
BLOCK_SIZE_M + c_off) < chunk_size_limit: - seq_idx_m = tl.load( - seq_idx_ptr + - (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) - - # - recall that in ssd_state_passing, for the case c_off == 0 - # i.e., the very first sequence, we made states_ptr hold its initial state - # so this edge case is taken care of - if ((c_off == 0) and - (seq_idx_prev != seq_idx_m - ) # if a seq is changed exactly on boundary - or (c_off > 0) # implies a new example (pseudo chunk) - ): - - # - replace prev_states_ptr with init_states - prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head - prev_states_hdim = stride_init_states_hdim # override strides - prev_states_dstate = stride_init_states_dstate + if HAS_INITSTATES and (seq_idx != seq_idx_prev): + # - replace prev_states_ptr with init_states + prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # - handle chunk state limit - if HAS_INITSTATES: - - # have to split this if otherwise compilation will have problems - dA_cs_m_boundary = 0.0 - - # get the c_idx for the next (logica) chunk - c_idx_n = tl.load( - chunk_indices_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=-1 # to trigger different chunk - ) - - # - there are things to consider - # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct - # contribution of past states - # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to - # encroach into the next sequence, where c_off_n is the offset of the next - # (logical) chunk. - # An equivalent check for B is c_idx == c_idx_n, where there is repetition in - # (logical) chunk indices. - - if (c_idx == c_idx_n) or c_off > 0: - - # get the next offset - c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=chunk_size) - - # in this case, adjust down the chunk_size_limit - if c_idx == c_idx_n: - chunk_size_limit = min(c_off_n, chunk_size_limit) - - # get the cs at the offset boundary - # - c_off == 0 is a passthrough - # - We need dA_cs at the boundary, defined by c_off - no need - # to increase pointer by pid_m (it is a constant offset, - # i.e. the same for all blocks) - dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, - mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), - other=0.0).to(tl.float32) - - if HAS_SEQ_IDX: - # - handle seq idx when HAS_INITSTATES==False - if not HAS_INITSTATES: - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Without the if (pid_c > -1), with Triton 2.1.0, I get @@ -323,12 +263,12 @@ def _chunk_scan_fwd_kernel( if not HAS_INITSTATES: # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), + scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) else: # - if there is initstates, we will rely on prev_states, no zeroing # required. 
- scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + scale_m = tl.exp(dA_cs_m) else: scale_m = tl.exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: @@ -416,7 +356,7 @@ def _chunk_scan_fwd_kernel( acc += x_residual * D if HAS_Z: - out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptr += pid_b * stride_out_batch + chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) tl.store(out_x_ptrs, @@ -424,7 +364,7 @@ def _chunk_scan_fwd_kernel( mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) - z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptr += pid_b * stride_z_batch + chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) z = tl.load(z_ptrs, @@ -433,7 +373,7 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) - out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptr += pid_b * stride_out_batch + chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) tl.store(out_ptrs, @@ -449,6 +389,7 @@ def _chunk_scan_fwd( dA_cumsum, C, states, + cu_chunk_seqlens, D=None, z=None, seq_idx=None, @@ -495,8 +436,7 @@ def _chunk_scan_fwd( grid = lambda META: ( triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - headdim, META['BLOCK_SIZE_N']), batch * nchunks - if chunk_offsets is None else len(chunk_offsets), nheads) + headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)) _chunk_scan_fwd_kernel[grid]( @@ -515,6 +455,7 @@ def _chunk_scan_fwd( chunk_indices, chunk_offsets, len(chunk_indices) if chunk_indices is not None else 0, + cu_chunk_seqlens, chunk_size, headdim, dstate, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 3379054c17b9..f68c8ca7d5e2 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -14,22 +14,6 @@ from .mamba_ssm import softplus -@triton.jit -def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, - BLOCK_Q: tl.constexpr, use_q_block_mode: tl.constexpr): - left: tl.int32 = 0 - right = num_seqs - while left < right: - mid = (left + right) // 2 - val = tl.load(query_start_len_ptr + mid) - mid_val = val // BLOCK_Q + mid if use_q_block_mode else val - - if mid_val <= target_idx: - left = mid + 1 - else: - right = mid - - return left - 1 @triton.autotune( configs=[ @@ -51,7 +35,7 @@ def _chunk_cumsum_fwd_kernel( dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, - cu_seqlens_ptr, + cu_chunk_seqlens_ptr, # Matrix dimension batch, seqlen, @@ -59,7 +43,6 @@ def _chunk_cumsum_fwd_kernel( chunk_size, dt_min, dt_max, - num_seqs, # Strides stride_dt_batch, stride_dt_seqlen, @@ -87,22 +70,10 @@ def _chunk_cumsum_fwd_kernel( pid_c = tl.program_id(axis=1).to(tl.int64) pid_h = tl.program_id(axis=2) + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - seq_idx = find_seq_idx(cu_seqlens_ptr, pid_c, num_seqs, chunk_size, True) - - chunk_start_idx = 
tl.load(cu_seqlens_ptr + seq_idx) // chunk_size + seq_idx - - chunk_local_idx = pid_c - chunk_start_idx - - cur_batch_in_all_start_idx = tl.load(cu_seqlens_ptr + seq_idx) - cur_batch_in_all_stop_idx = tl.load(cu_seqlens_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_idx - cur_batch_in_all_start_idx - - # skip any unncessary work - if chunk_local_idx * chunk_size >= cur_batch_query_len: - return - - dt_ptr += pid_b * stride_dt_batch + (cur_batch_in_all_start_idx + chunk_local_idx * chunk_size) * stride_dt_seqlen + dt_ptr += pid_b * stride_dt_batch + chunk_seqlen_start * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -115,11 +86,10 @@ def _chunk_cumsum_fwd_kernel( offs_c[None, :] * stride_dt_out_csize) dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) - chunk_size_limit = min(chunk_size, cur_batch_query_len - chunk_local_idx * chunk_size) dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & - (offs_c[None, :] < chunk_size_limit), + (offs_c[None, :] < chunk_seqlen_end), other=0.0).to(tl.float32) if HAS_DT_BIAS: dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, @@ -132,17 +102,17 @@ def _chunk_cumsum_fwd_kernel( # dt = tl.clamp(dt, dt_min, dt_max) dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) dt = tl.where( - (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_seqlen_end), dt, 0.0) tl.store(dt_out_ptrs, dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) tl.store(dA_cs_ptrs, dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) @triton.autotune( @@ -230,8 +200,8 @@ def _chunk_state_fwd_kernel( states_ptr, dt_ptr, dA_cumsum_ptr, - cu_seqlens_ptr, seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions hdim, dstate, @@ -239,7 +209,6 @@ def _chunk_state_fwd_kernel( batch, seqlen, nheads_ngroups_ratio, - num_seqs, # Strides stride_x_batch, stride_x_seqlen, @@ -278,25 +247,15 @@ def _chunk_state_fwd_kernel( pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - seq_idx = find_seq_idx(cu_seqlens_ptr, pid_c, num_seqs, chunk_size, True) - chunk_start_idx = tl.load(cu_seqlens_ptr + seq_idx) // chunk_size + seq_idx - chunk_local_idx = pid_c - chunk_start_idx - cur_batch_in_all_start_idx = tl.load(cu_seqlens_ptr + seq_idx) - cur_batch_in_all_stop_idx = tl.load(cu_seqlens_ptr + seq_idx + 1) - cur_batch_query_len = cur_batch_in_all_stop_idx - cur_batch_in_all_start_idx - # skip any unncessary work - if chunk_local_idx * chunk_size >= cur_batch_query_len: - return + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - seqlen_offset = cur_batch_in_all_start_idx + chunk_local_idx*chunk_size - b_ptr += pid_b * stride_b_batch + seqlen_offset * stride_b_seqlen + ( + b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * stride_b_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_b_head - x_ptr += pid_b * stride_x_batch + seqlen_offset * stride_x_seqlen + pid_h * stride_x_head + x_ptr += pid_b * stride_x_batch + chunk_seqlen_start 
* stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + seqlen_offset * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -306,17 +265,13 @@ def _chunk_state_fwd_kernel( b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) - dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - chunk_size_limit = min(chunk_size, cur_batch_query_len - chunk_local_idx * chunk_size) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_last = tl.load(seq_idx_ptr + - (chunk_size_limit - 1) * stride_seq_idx_seqlen) + chunk_size_limit = chunk_seqlen_end acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): @@ -331,17 +286,11 @@ def _chunk_state_fwd_kernel( dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - if HAS_SEQ_IDX: - seq_idx_k = tl.load(seq_idx_ptrs, - mask=offs_k < chunk_size_limit - k, - other=-1) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - if not HAS_SEQ_IDX: - scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k - else: - scale = tl.where(seq_idx_k == seq_idx_last, - tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -349,8 +298,7 @@ def _chunk_state_fwd_kernel( b_ptrs += BLOCK_SIZE_K * stride_b_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - if HAS_SEQ_IDX: - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head @@ -606,35 +554,27 @@ def _chunk_state_varlen_kernel( def _chunk_cumsum_fwd(dt, A, chunk_size, - cu_seqlens, + cu_chunk_seqlens, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): - batch, seqlen, nheads = dt.shape assert A.shape == (nheads, ) if dt_bias is not None: assert dt_bias.shape == (nheads, ) - num_seqs = len(cu_seqlens)-1 - nchunks = seqlen // chunk_size + num_seqs - print("dt.shape: ", dt.shape) - print("A.shape: ", A.shape) - print("nchunks: ", nchunks) - print("type(cu_seqlens): ", type(cu_seqlens)) + nchunks = cu_chunk_seqlens.shape[0]-1 dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - print("dt_out.shape: ", dt_out.shape) dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32) - print("dA_cumsum.shape: ", dA_cumsum.shape) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) with torch.cuda.device(dt.device.index): @@ -644,14 +584,13 @@ def _chunk_cumsum_fwd(dt, dt_bias, dt_out, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, batch, seqlen, nheads, chunk_size, dt_limit[0], dt_limit[1], - num_seqs, dt.stride(0), dt.stride(1), dt.stride(2), @@ -676,7 +615,7 @@ def 
_chunk_state_fwd(B, x, dt, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, seq_idx=None, states=None, states_in_fp32=True): @@ -696,7 +635,9 @@ def _chunk_state_fwd(B, states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) + print("[_chunk_state_fwd] states.shape: ", states.shape) + grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) @@ -707,7 +648,7 @@ def _chunk_state_fwd(B, states, dt, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, seq_idx, headdim, dstate, @@ -715,7 +656,6 @@ def _chunk_state_fwd(B, batch, seqlen, nheads // ngroups, - len(cu_seqlens)-1, x.stride(0), x.stride(1), x.stride(2), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index c11f1a4c4b24..c3f8db39e7ae 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -39,6 +39,8 @@ def _mamba_chunk_scan_combined_fwd(x, chunk_indices=None, chunk_offsets=None, cu_seqlens=None, + cu_chunk_seqlens=None, + last_chunk=None, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, @@ -85,27 +87,23 @@ def _mamba_chunk_scan_combined_fwd(x, # - see the blog and paper for a visualization of the submatrices # which we refer to in the comments below - num_seqs = len(cu_seqlens) - 1 # 1. Compute chunked cumsum of A * dt # - here dt may go through a softplus activation dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, - cu_seqlens, + cu_chunk_seqlens, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) - print("dA_cumsum.shape: ", dA_cumsum.shape) - print("dt.shape: ", dt.shape) - # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, x, dt, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, seq_idx=seq_idx, states_in_fp32=True) @@ -123,7 +121,7 @@ def _mamba_chunk_scan_combined_fwd(x, states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, initial_states=rearrange(initial_states, "... p n -> ... 
(p n)") if initial_states is not None else None, seq_idx=seq_idx, @@ -138,6 +136,7 @@ def _mamba_chunk_scan_combined_fwd(x, CB = _bmm_chunk_fwd(C, B, chunk_size, + cu_chunk_seqlens, seq_idx=seq_idx, output_dtype=torch.float32) @@ -158,6 +157,7 @@ def _mamba_chunk_scan_combined_fwd(x, dA_cumsum, C, states, + cu_chunk_seqlens, D=D, z=z, seq_idx=seq_idx, @@ -170,16 +170,11 @@ def _mamba_chunk_scan_combined_fwd(x, return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - varlen_states = chunk_state_varlen( - B.squeeze(0), - x.squeeze(0), - dt.squeeze(0), - dA_cumsum.squeeze(0), - cu_seqlens, - states.squeeze(0), - initial_states=initial_states, - ) - return out_x, dt, dA_cumsum, states, final_states, varlen_states + print("last_chunk: ", last_chunk) + print(states.shape) + varlen_states = states[last_chunk] + print(varlen_states.shape) + return out_x, dt, dA_cumsum, states, final_states, states def mamba_chunk_scan_combined(x, @@ -196,6 +191,8 @@ def mamba_chunk_scan_combined(x, chunk_indices=None, chunk_offsets=None, cu_seqlens=None, + cu_chunk_seqlens=None, + last_chunk=None, dt_softplus=False, dt_limit=(0.0, float("inf")), out=None, @@ -216,12 +213,12 @@ def mamba_chunk_scan_combined(x, initial_states: (batch, nheads, headdim, dstate) seq_idx: (batch, seqlen) cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + cu_chunk_seqlens: (num_chunks + 1) dt_softplus: Whether to apply softplus to dt out: Preallocated output tensor state_dtype: The data type of the ssm state """ - print("-------------------------") - print("cu_seqlens: ", cu_seqlens) + if not return_varlen_states: cu_seqlens = None else: @@ -241,6 +238,8 @@ def mamba_chunk_scan_combined(x, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk=last_chunk, dt_softplus=dt_softplus, dt_limit=dt_limit, out=out, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 56c61ad3cdd8..4d2bea947be5 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -10,7 +10,6 @@ from vllm.triton_utils import tl, triton -from .ssd_chunk_state import find_seq_idx @triton.autotune( configs=[ @@ -34,13 +33,12 @@ def _state_passing_fwd_kernel( seq_idx_ptr, chunk_offsets_ptr, chunk_meta_num, - cu_seqlens_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions dim, nchunks, seqlen, chunk_size, - num_seqs, # Strides stride_states_batch, stride_states_chunk, @@ -102,12 +100,12 @@ def _state_passing_fwd_kernel( tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - prev_seq_idx_chunk_end = 0 - logical_chunk_idx = 0 + + prev_seq_idx = 0 for c in range(nchunks): - # now a chunk can only contain one sequence - seq_idx_chunk = find_seq_idx(cu_seqlens_ptr, c, num_seqs, chunk_size, True) + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + c) new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) @@ -115,51 +113,24 @@ def _state_passing_fwd_kernel( scale_mask = True if HAS_SEQ_IDX: - # - the seq to pass forward is the one that is flushed to the right - # boundary. - # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. 
- seq_idx_chunk_end = seq_idx_chunk + seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) if HAS_INITSTATES: - if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: + if IS_CONT_BATCHED and prev_seq_idx != seq_idx: # this means in the current chunk the rightmost flushed seq # has changed. # - so we do not propagate the state from previous chunk # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch # - update state with seq_idx_new's init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) - - # - we need to consider the cumsum only of the last sequence in the chunk - # - find its starting position (given by c_off of the logical chunk index) - # - and subtract the cumsum just before that position from the total cumsum - # - first, update the logical chunk index (add the number of sequences in the current physical chunk): - # sequence index at the start of the current chunk - seq_idx_chunk_start = seq_idx_chunk - - logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start - # - load the chunk offset: - c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, - mask=logical_chunk_idx < chunk_meta_num, - other=0) - # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything - if c_off > 0: - # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset - dA_cs_boundary = tl.load( - dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + - (c_off - 1) * stride_dA_cs_csize, - mask=(c_off - 1) > -1 and c_off < chunk_size, - other=0.0) - dA_cs -= dA_cs_boundary - - # - increment logical chunk index for every physical chunk - logical_chunk_idx += 1 else: - scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end - prev_seq_idx_chunk_end = seq_idx_chunk_end + scale_mask = seq_idx == prev_seq_idx + + prev_seq_idx = seq_idx scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states @@ -175,7 +146,7 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, dA_cumsum, - cu_seqlens, + cu_chunk_seqlens, initial_states=None, seq_idx=None, chunk_size=None, @@ -215,9 +186,6 @@ def _state_passing_fwd( device=states.device, dtype=torch.float32) grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) - - print("[_state_passing_fwd] seq_idx.shape: ", seq_idx.shape) - print("[_state_passing_fwd] chunk_offsets: ", chunk_offsets) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states, @@ -228,12 +196,11 @@ def _state_passing_fwd( seq_idx, chunk_offsets, len(chunk_offsets) if chunk_offsets is not None else 0, - cu_seqlens, + cu_chunk_seqlens, dim, nchunks, seqlen if seq_idx is not None else 0, chunk_size, - len(cu_seqlens)-1, states.stride(0), states.stride(1), states.stride(2), diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 971f32848460..91fb63dc486c 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -14,7 +14,7 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec - +from vllm.utils import cdiv def _query_start_loc_to_chunk_indices_offsets( query_start_loc: torch.Tensor, chunk_size: int, @@ -128,8 +128,8 
@@ class Mamba2AttentionMetadata: chunk_offsets_p: Optional[torch.Tensor] # tpa - chunk_seqlen_start_p: Optional[torch.Tensor] - chunk_seqlen_end_p: Optional[torch.Tensor] + cu_chunk_seqlen_p: Optional[torch.Tensor] + last_chunk_p: Optional[torch.Tensor] state_indices_tensor: torch.Tensor # shape: [batch,] @@ -165,6 +165,10 @@ def build(self, has_initial_states_p = None prep_initial_states = False + + cu_chunk_seqlen_p = None + last_chunk_p = None + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( @@ -172,6 +176,11 @@ def build(self, common_attn_metadata, decode_threshold=self.reorder_batch_threshold)) + print("num_decodes: ", num_decodes) + print("num_prefills: ", num_prefills) + print("num_decode_tokens: ", num_decode_tokens) + print("num_prefill_tokens: ", num_prefill_tokens) + # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: #[batch,] @@ -182,9 +191,11 @@ def build(self, has_initial_states_p = has_initial_states_cpu.to( query_start_loc.device) + query_start_loc_p = common_attn_metadata.query_start_loc[ -num_prefills - 1:] - num_decode_tokens + seq_idx_p = torch.repeat_interleave(torch.arange( num_prefills, dtype=torch.int32, @@ -193,6 +204,52 @@ def build(self, output_size=num_prefill_tokens) seq_idx_p.unsqueeze_(0) + + num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] + query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ + -num_prefills - 1:] - num_decode_tokens + + print("num_computed_tokens_p: ", num_computed_tokens_p) + print("query_start_loc_p: ", query_start_loc_p) + + cu_chunk_seqlen = [] + last_chunk = [] + seqlen_pos = 0 + for req_idx in range(num_prefills): + this_num_computed = num_computed_tokens_p[req_idx].item() + this_new_tokens = query_start_loc_p_cpu[req_idx+1].item() - query_start_loc_p_cpu[req_idx].item() + print(req_idx, this_num_computed, this_new_tokens) + + # if computed tokens are not chunk-aligned, use the first + # chunk to finish it off + # TODO(tdoublep): I guess we need block size actually? + if this_num_computed % self.chunk_size != 0: + cu_chunk_seqlen.append(seqlen_pos) + # how many tokens to finish the chunk? + chunk_len = cdiv(this_num_computed, self.chunk_size)*self.chunk_size - this_num_computed + # we can only use at most this_new_tokens + chunk_len = min(chunk_len, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + n_chunks = cdiv(this_new_tokens, self.chunk_size) + for chunk in range(n_chunks): + cu_chunk_seqlen.append(seqlen_pos) + chunk_len = min(self.chunk_size, this_new_tokens) + seqlen_pos += chunk_len + this_new_tokens -= chunk_len + + assert this_new_tokens == 0 + last_chunk.append(len(cu_chunk_seqlen)-1) + + cu_chunk_seqlen.append(seqlen_pos) + + cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, device=query_start_loc.device, dtype=torch.int32) + last_chunk_p = torch.as_tensor(last_chunk, device=query_start_loc.device, dtype=torch.int32) + + print("cu_chunk_seqlen: ", cu_chunk_seqlen) + print("cu_chunk_seqlen_p: ", cu_chunk_seqlen_p) + # We compute metadata for chunked prefill once at the top level # model forward and reuse them in mamba layers. If not needed, # they will be ignored inside mamba kernels. 
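# NOTE: a minimal host-side sketch (plain Python, no vLLM imports) of the
# chunk-alignment loop added in build() above. It shows how cu_chunk_seqlen
# and last_chunk are derived so that no physical chunk ever spans two
# sequences. The inputs `chunk_size`, `num_computed_tokens` and
# `num_new_tokens` are illustrative stand-ins for the real metadata tensors.
def build_cu_chunk_seqlen(chunk_size, num_computed_tokens, num_new_tokens):
    cdiv = lambda a, b: -(-a // b)  # ceiling division, as in vllm.utils.cdiv
    cu_chunk_seqlen, last_chunk, pos = [], [], 0
    for computed, new in zip(num_computed_tokens, num_new_tokens):
        # if the already-computed prefix is not chunk-aligned, spend the first
        # (partial) chunk of this request on finishing that physical chunk
        if computed % chunk_size != 0:
            cu_chunk_seqlen.append(pos)
            fill = min(cdiv(computed, chunk_size) * chunk_size - computed, new)
            pos += fill
            new -= fill
        # the remaining new tokens fill whole chunks (the last may be partial)
        for _ in range(cdiv(new, chunk_size)):
            cu_chunk_seqlen.append(pos)
            step = min(chunk_size, new)
            pos += step
            new -= step
        assert new == 0
        # index of the chunk holding this request's final tokens
        last_chunk.append(len(cu_chunk_seqlen) - 1)
    cu_chunk_seqlen.append(pos)
    return cu_chunk_seqlen, last_chunk

# Example: chunk_size=4, two prefills with (computed, new) = (6, 5) and (0, 3)
# gives cu_chunk_seqlen == [0, 2, 5, 8] and last_chunk == [1, 2].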
@@ -224,5 +281,7 @@ def build(self, chunk_indices_p=chunk_indices_p, chunk_offsets_p=chunk_offsets_p, state_indices_tensor=state_indices_tensor, + cu_chunk_seqlen_p=cu_chunk_seqlen_p, + last_chunk_p=last_chunk_p, ) return attn_metadata From 664a21a31c56c457bcba2fc785dc4bfa0bc2f9ab Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 14:06:48 -0400 Subject: [PATCH 03/37] fix bug Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_bmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index e7980af51a74..4debe5375f78 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -141,7 +141,7 @@ def _bmm_chunk_fwd_kernel( chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) a_ptr += pid_b * stride_a_batch + chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head - b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * chunk_size * stride_b_seqlen + pid_h * stride_b_head + b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen From 6c475d62e6f4e93b7f5bcbc091301ba362bce697 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 14:14:08 -0400 Subject: [PATCH 04/37] fix bug Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index c3f8db39e7ae..174ba9c0ae26 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -174,7 +174,7 @@ def _mamba_chunk_scan_combined_fwd(x, print(states.shape) varlen_states = states[last_chunk] print(varlen_states.shape) - return out_x, dt, dA_cumsum, states, final_states, states + return out_x, dt, dA_cumsum, states, final_states, varlen_states def mamba_chunk_scan_combined(x, From 0d00c69acd045cd3f5800927841c837d54eaf13d Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 14:19:11 -0400 Subject: [PATCH 05/37] fix bug Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 174ba9c0ae26..14f4903ac7df 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -172,7 +172,7 @@ def _mamba_chunk_scan_combined_fwd(x, assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" print("last_chunk: ", last_chunk) print(states.shape) - varlen_states = states[last_chunk] + varlen_states = states[:, last_chunk, ...] 
print(varlen_states.shape) return out_x, dt, dA_cumsum, states, final_states, varlen_states From b7ae698358db8290f8e04471544bd48579e82d60 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 17:34:37 -0400 Subject: [PATCH 06/37] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 3 + .../layers/mamba/ops/ssd_bmm.py | 4 +- .../layers/mamba/ops/ssd_chunk_scan.py | 103 +++++++++++------- .../layers/mamba/ops/ssd_chunk_state.py | 32 +++--- .../layers/mamba/ops/ssd_combined.py | 19 +++- .../layers/mamba/ops/ssd_state_passing.py | 70 ++++-------- 6 files changed, 118 insertions(+), 113 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 222f89f2c35b..34c5a9280752 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -684,6 +684,9 @@ def forward_cuda( self.head_dim), state_dtype=ssm_state.dtype) + print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) + print("varlen_state: ", varlen_state[0,0,0,:10]) + # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor ssm_state[state_indices_tensor_p] = varlen_state diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 4debe5375f78..3a245b127f01 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -142,8 +142,6 @@ def _bmm_chunk_fwd_kernel( a_ptr += pid_b * stride_a_batch + chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -152,7 +150,7 @@ def _bmm_chunk_fwd_kernel( offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) - chunk_size_limit = chunk_seqlen_end + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 9f23b4103a97..65777d4e0789 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -13,7 +13,7 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') - +''' @triton.autotune( configs=[ triton.Config( @@ -107,6 +107,7 @@ ], key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) +''' @triton.jit def _chunk_scan_fwd_kernel( # Pointers to matrices @@ -216,28 +217,19 @@ def _chunk_scan_fwd_kernel( # M-block offsets and prev states # - logic in next block may override these if there is an active offset offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head - prev_states_hdim = stride_states_hdim - prev_states_dstate = stride_states_dstate - - chunk_size_limit = chunk_seqlen_end + #prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head + #prev_states_hdim = stride_states_hdim + 
#prev_states_dstate = stride_states_dstate - if HAS_SEQ_IDX: - seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - # current sequence index - seq_idx = tl.load(seq_idx_ptr) - # previous sequence index - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen + seq_idx = tl.load(seq_idx_ptr) + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, - other=0) + other=-1) - if HAS_INITSTATES and (seq_idx != seq_idx_prev): - # - replace prev_states_ptr with init_states - prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head - prev_states_hdim = stride_init_states_hdim # override strides - prev_states_dstate = stride_init_states_dstate offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, @@ -256,31 +248,39 @@ def _chunk_scan_fwd_kernel( C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * prev_states_hdim + - offs_k_dstate[:, None] * prev_states_dstate) - if HAS_SEQ_IDX: - - if not HAS_INITSTATES: - # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), - 0.0) - else: - # - if there is initstates, we will rely on prev_states, no zeroing - # required. - scale_m = tl.exp(dA_cs_m) - else: - scale_m = tl.exp(dA_cs_m) + scale_m = tl.exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - prev_states = tl.load(prev_states_ptrs, + + if seq_idx != seq_idx_prev: + if HAS_INITSTATES: + # load from init states + init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + + pid_h * stride_init_states_head \ + + offs_n[None, :] * stride_init_states_hdim \ + + offs_k_dstate[:, None] * stride_init_states_dstate + prev_states = tl.load(init_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + else: + # Set to zero + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + else: + # Load from previous chunk + states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ + + pid_h * stride_states_head \ + + offs_n[None, :] * stride_states_hdim \ + + offs_k_dstate[:, None] * stride_states_dstate + prev_states = tl.load(states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: @@ -290,11 +290,31 @@ def _chunk_scan_fwd_kernel( (offs_k_dstate[None, :] < dstate - k), other=0.0) # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - prev_states = tl.load( - prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) & - (offs_n[None, :] < hdim), - other=0.0) + if seq_idx != seq_idx_prev: + if HAS_INITSTATES: + # load from init states + init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + + pid_h * stride_init_states_head \ + + offs_n[None, :] * stride_init_states_hdim \ + + offs_k_dstate[:, None] * stride_init_states_dstate + prev_states = tl.load(init_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + else: + # Set to zero 
+ prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) + else: + # Load from previous chunk + states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ + + pid_h * stride_states_head \ + + offs_n[None, :] * stride_states_hdim \ + + offs_k_dstate[:, None] * stride_states_dstate + prev_states = tl.load(states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K @@ -412,6 +432,8 @@ def _chunk_scan_fwd( assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) + print("out.shape: ", out.shape) + if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) @@ -511,5 +533,8 @@ def _chunk_scan_fwd( HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, HAS_INITSTATES=initial_states is not None, + BLOCK_SIZE_M=64, + BLOCK_SIZE_N=64, + BLOCK_SIZE_K=32, ) return out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index f68c8ca7d5e2..49c4678b4a87 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -15,18 +15,6 @@ from .mamba_ssm import softplus -@triton.autotune( - configs=[ - triton.Config({'BLOCK_SIZE_H': 1}), - triton.Config({'BLOCK_SIZE_H': 2}), - triton.Config({'BLOCK_SIZE_H': 4}), - triton.Config({'BLOCK_SIZE_H': 8}), - triton.Config({'BLOCK_SIZE_H': 16}), - triton.Config({'BLOCK_SIZE_H': 32}), - triton.Config({'BLOCK_SIZE_H': 64}), - ], - key=['chunk_size', 'nheads'], -) @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -73,6 +61,7 @@ def _chunk_cumsum_fwd_kernel( chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + dt_ptr += pid_b * stride_dt_batch + chunk_seqlen_start * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -86,10 +75,11 @@ def _chunk_cumsum_fwd_kernel( offs_c[None, :] * stride_dt_out_csize) dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & - (offs_c[None, :] < chunk_seqlen_end), + (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32) if HAS_DT_BIAS: dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, @@ -102,17 +92,17 @@ def _chunk_cumsum_fwd_kernel( # dt = tl.clamp(dt, dt_min, dt_max) dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) dt = tl.where( - (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_seqlen_end), dt, + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0) tl.store(dt_out_ptrs, dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) tl.store(dA_cs_ptrs, dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) @triton.autotune( @@ -200,8 +190,8 @@ def _chunk_state_fwd_kernel( states_ptr, dt_ptr, dA_cumsum_ptr, - seq_idx_ptr, cu_chunk_seqlens_ptr, + 
seq_idx_ptr, # Matrix dimensions hdim, dstate, @@ -271,7 +261,7 @@ def _chunk_state_fwd_kernel( dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - chunk_size_limit = chunk_seqlen_end + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): @@ -577,6 +567,10 @@ def _chunk_cumsum_fwd(dt, dtype=torch.float32) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) + + print("dt_out.shape: ", dt_out.shape) + print("dA_cumsum.shape: ", dA_cumsum.shape) + with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt, @@ -606,6 +600,7 @@ def _chunk_cumsum_fwd(dt, dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_H=1, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out @@ -637,6 +632,7 @@ def _chunk_state_fwd(B, dtype=states_dtype) print("[_chunk_state_fwd] states.shape: ", states.shape) + print("[_chunk_state_fwd] cu_chunk_seqlens: ", cu_chunk_seqlens) grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 14f4903ac7df..2c158f04b8cf 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -107,6 +107,10 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) + print("after chunk_state_fwd: ") + print("states.shape: ", states.shape) + print("states: ", states[0,0,0,0,:10]) + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states @@ -118,7 +122,7 @@ def _mamba_chunk_scan_combined_fwd(x, # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. - states, final_states = _state_passing_fwd( + states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, cu_chunk_seqlens, @@ -129,8 +133,13 @@ def _mamba_chunk_scan_combined_fwd(x, out_dtype=state_dtype if state_dtype is not None else C.dtype, is_cont_batched=cu_seqlens is not None, chunk_offsets=chunk_offsets) - states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate) - for t in [states, final_states]) + + print("after state passing: ") + print("states.shape: ", states.shape) + + print("states: ", states[0 ,0, 0,:10]) + + states = rearrange(states, "... (p n) -> ... p n", n=dstate) # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, @@ -172,8 +181,10 @@ def _mamba_chunk_scan_combined_fwd(x, assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" print("last_chunk: ", last_chunk) print(states.shape) - varlen_states = states[:, last_chunk, ...] + varlen_states = states[:, last_chunk, ...].clone() print(varlen_states.shape) + print("varlen_states: ", varlen_states[0,0,0,:10]) + final_states = states[:, -1, ...] 
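For reference, the varlen state gather in this hunk can be written as a small PyTorch sketch; the helper name, and the assumption that chunk boundaries are aligned to sequence boundaries so the state stored after a sequence's last chunk is its final SSM state, are illustrative rather than part of the patch:

import torch

def gather_varlen_states(states: torch.Tensor,
                         last_chunk_indices: torch.Tensor) -> torch.Tensor:
    # states: (1, nchunks, nheads, headdim, dstate) after state passing
    # last_chunk_indices: (num_prefills,) index of each sequence's final chunk
    # Chunk boundaries are sequence-aligned, so the state written after a
    # sequence's last chunk is already that sequence's final SSM state.
    return states[:, last_chunk_indices].clone().squeeze(0)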
return out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 4d2bea947be5..b084ed317ee6 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -27,7 +27,6 @@ def _state_passing_fwd_kernel( # Pointers to matrices states_ptr, out_ptr, - final_states_ptr, dA_cs_ptr, initstates_ptr, seq_idx_ptr, @@ -48,9 +47,6 @@ def _state_passing_fwd_kernel( stride_out_chunk, stride_out_head, stride_out_dim, - stride_final_states_batch, - stride_final_states_head, - stride_final_states_dim, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, @@ -73,11 +69,11 @@ def _state_passing_fwd_kernel( dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head if HAS_INITSTATES: initstates_ptr += pid_h * stride_initstates_head if not IS_CONT_BATCHED: initstates_ptr += pid_b * stride_initstates_batch + initstates_ptr += offs_m * stride_initstates_dim if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch @@ -85,59 +81,40 @@ def _state_passing_fwd_kernel( offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) states_ptrs = states_ptr + offs_m * stride_states_dim out_ptrs = out_ptr + offs_m * stride_out_dim - final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim - # - states will be the past state of the sequence that continues on the current check - if not HAS_INITSTATES: - states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - else: - initstates_ptr += offs_m * stride_initstates_dim - initstates_ptrs = initstates_ptr - # - for cont batches, for the first chunk mean it will be the first batch's - # init state - states = tl.load(initstates_ptrs, mask=offs_m < dim, + if HAS_INITSTATES: + initstates_ptrs = initstates_ptr + 0 * stride_initstates_batch + states = tl.load(initstates_ptrs, + mask=offs_m < dim, other=0.0).to(tl.float32) - - tl.store(out_ptrs, states, mask=offs_m < dim) - out_ptrs += stride_out_chunk + else: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) prev_seq_idx = 0 for c in range(nchunks): - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + c) new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale_mask = True - if HAS_SEQ_IDX: - seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) + seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) + # we are started a new sequence + if prev_seq_idx != seq_idx: if HAS_INITSTATES: - if IS_CONT_BATCHED and prev_seq_idx != seq_idx: - # this means in the current chunk the rightmost flushed seq - # has changed. 
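The rewritten state-passing loop is easier to follow against a plain PyTorch reference. This is a sketch only (one head, one state column, invented helper name), assuming out[c] is meant to hold the state after chunk c and that the first chunk belongs to sequence 0:

import torch

def state_passing_ref(chunk_states, dA_cs_last, seq_idx_per_chunk, init_states=None):
    # chunk_states: (nchunks, dim) local states from _chunk_state_fwd
    # dA_cs_last:   (nchunks,) last cumsum entry of each chunk
    # seq_idx_per_chunk: (nchunks,) sequence owning each chunk
    out = torch.empty_like(chunk_states)
    prev_seq = int(seq_idx_per_chunk[0])
    state = (init_states[prev_seq] if init_states is not None
             else torch.zeros_like(chunk_states[0]))
    for c in range(chunk_states.shape[0]):
        seq = int(seq_idx_per_chunk[c])
        if seq != prev_seq:
            # New sequence: restart from its initial state, or from zero.
            state = (init_states[seq] if init_states is not None
                     else torch.zeros_like(state))
            prev_seq = seq
        state = torch.exp(dA_cs_last[c]) * state + chunk_states[c]
        out[c] = state
    return out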
- # - so we do not propagate the state from previous chunk - # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch - - # - update state with seq_idx_new's init state - states = tl.load(initstates_ptrs, - mask=offs_m < dim, - other=0.0).to(tl.float32) + initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch + states = tl.load(initstates_ptrs, + mask=offs_m < dim, + other=0.0).to(tl.float32) else: - scale_mask = seq_idx == prev_seq_idx + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - prev_seq_idx = seq_idx + prev_seq_idx = seq_idx - scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) - states = scale * states + new_states - if c < nchunks - 1: - tl.store(out_ptrs, states, mask=offs_m < dim) - else: - tl.store(final_states_ptrs, states, mask=offs_m < dim) + states = tl.exp(dA_cs) * states + new_states + tl.store(out_ptrs, states, mask=offs_m < dim) states_ptrs += stride_states_chunk dA_cs_ptr += stride_dA_cs_chunk out_ptrs += stride_out_chunk @@ -155,6 +132,7 @@ def _state_passing_fwd( chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape + assert batch == 1 if chunk_size is None: chunk_size = dA_cumsum.shape[-1] else: @@ -182,15 +160,12 @@ def _state_passing_fwd( out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) - final_states = torch.empty((batch, nheads, dim), - device=states.device, - dtype=torch.float32) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( states, out, - final_states, dA_cumsum, initial_states, seq_idx, @@ -209,9 +184,6 @@ def _state_passing_fwd( out.stride(1), out.stride(2), out.stride(3), - final_states.stride(0), - final_states.stride(1), - final_states.stride(2), dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), @@ -225,4 +197,4 @@ def _state_passing_fwd( HAS_SEQ_IDX=seq_idx is not None, IS_CONT_BATCHED=is_cont_batched, ) - return out, final_states + return out From 9b24bce7b1a328aa4e056939f97a5a6efe9942b9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 11 Sep 2025 17:56:57 -0400 Subject: [PATCH 07/37] Fix bugs Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 6 +++--- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 49c4678b4a87..edfa4b12bd1a 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -256,13 +256,13 @@ def _chunk_state_fwd_kernel( offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) + (chunk_size_limit - 1) * stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): x = tl.load(x_ptrs, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 2c158f04b8cf..5c62e01f6dd7 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py 
+++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -97,6 +97,9 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=dt_softplus, dt_limit=dt_limit) + print("dA_cumsum: ", dA_cumsum[0,0,0,:10]) + print("dt: ", dt[0,0,0,:10]) + # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, @@ -107,8 +110,6 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) - print("after chunk_state_fwd: ") - print("states.shape: ", states.shape) print("states: ", states[0,0,0,0,:10]) # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries From e850661da32c7c8c86f95a4f929e1d5ae3c89ee6 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 00:32:10 -0400 Subject: [PATCH 08/37] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 27 +++++++++++-------- .../layers/mamba/ops/ssd_combined.py | 6 +---- .../layers/mamba/ops/ssd_state_passing.py | 14 +++++----- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 65777d4e0789..0aa7f2b28159 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -233,7 +233,7 @@ def _chunk_scan_fwd_kernel( offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, - mask=offs_m < chunk_size, + mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -248,7 +248,8 @@ def _chunk_scan_fwd_kernel( C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - scale_m = tl.exp(dA_cs_m) + scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + if BLOCK_SIZE_DSTATE <= 128: C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & @@ -284,6 +285,7 @@ def _chunk_scan_fwd_kernel( prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: + offset_tpa = 0 for k in range(0, dstate, BLOCK_SIZE_K): C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & @@ -296,9 +298,10 @@ def _chunk_scan_fwd_kernel( init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + pid_h * stride_init_states_head \ + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate + + offs_k_dstate[:, None] * stride_init_states_dstate \ + + offset_tpa prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & + mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) else: @@ -309,16 +312,18 @@ def _chunk_scan_fwd_kernel( states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ + pid_h * stride_states_head \ + offs_n[None, :] * stride_states_hdim \ - + offs_k_dstate[:, None] * stride_states_dstate + + offs_k_dstate[:, None] * stride_states_dstate \ + + offset_tpa prev_states = tl.load(states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & + mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K - prev_states_ptrs += BLOCK_SIZE_K + offset_tpa += BLOCK_SIZE_K + acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off @@ -332,16 +337,16 
@@ def _chunk_scan_fwd_kernel( (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) for k in range(0, K_MAX, BLOCK_SIZE_K): cb = tl.load(cb_ptrs, - mask=(offs_m[:, None] < chunk_size) & - (offs_k[None, :] < chunk_size - k), + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size - k, + mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 5c62e01f6dd7..84ac31542506 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -136,8 +136,6 @@ def _mamba_chunk_scan_combined_fwd(x, chunk_offsets=chunk_offsets) print("after state passing: ") - print("states.shape: ", states.shape) - print("states: ", states[0 ,0, 0,:10]) states = rearrange(states, "... (p n) -> ... p n", n=dstate) @@ -181,9 +179,7 @@ def _mamba_chunk_scan_combined_fwd(x, else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" print("last_chunk: ", last_chunk) - print(states.shape) - varlen_states = states[:, last_chunk, ...].clone() - print(varlen_states.shape) + varlen_states = states[:, last_chunk, ...].clone().squeeze(0) print("varlen_states: ", varlen_states[0,0,0,:10]) final_states = states[:, -1, ...] 
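The scale_m change in this patch and the branches that pick the carried state are two views of the same rule: a chunk that starts a new sequence must not see the previous sequence's state. A minimal sketch of the inter-chunk term for one head (the causal intra-chunk CB term is added separately; helper name and shapes are illustrative):

import torch

def interchunk_term(C, prev_state, dA_cs, same_seq, init_state=None):
    # C: (chunk_len, dstate); prev_state/init_state: (hdim, dstate)
    # dA_cs: (chunk_len,) cumulative dA within the chunk
    if not same_seq:
        # Zeroing the carried state here is equivalent to zeroing scale_m.
        prev_state = init_state if init_state is not None \
            else torch.zeros_like(prev_state)
    return torch.exp(dA_cs)[:, None] * (C @ prev_state.T)  # (chunk_len, hdim)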
return out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index b084ed317ee6..e7d00a8fdd89 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -69,11 +69,6 @@ def _state_passing_fwd_kernel( dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head - if HAS_INITSTATES: - initstates_ptr += pid_h * stride_initstates_head - if not IS_CONT_BATCHED: - initstates_ptr += pid_b * stride_initstates_batch - initstates_ptr += offs_m * stride_initstates_dim if HAS_SEQ_IDX: seq_idx_ptr += pid_b * stride_seq_idx_batch @@ -83,7 +78,10 @@ def _state_passing_fwd_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + 0 * stride_initstates_batch + initstates_ptrs = initstates_ptr + 0 * stride_initstates_batch \ + + pid_h * stride_initstates_head \ + + offs_m * stride_initstates_dim + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) @@ -104,7 +102,9 @@ def _state_passing_fwd_kernel( # we are started a new sequence if prev_seq_idx != seq_idx: if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \ + + pid_h * stride_initstates_head \ + + offs_m * stride_initstates_dim states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) From 0d5c3ae9559716882f44f4c8aa7ad7739df5f1cf Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 01:02:40 -0400 Subject: [PATCH 09/37] revert some changes Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index edfa4b12bd1a..47077872356e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -96,13 +96,13 @@ def _chunk_cumsum_fwd_kernel( 0.0) tl.store(dt_out_ptrs, dt, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) dA = dt * A[:, None] dA_cs = tl.cumsum(dA, axis=1) tl.store(dA_cs_ptrs, dA_cs, - mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit)) + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size)) @triton.autotune( From 31e05fae6baef01fda20df7577d581e48aeade9c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 01:03:04 -0400 Subject: [PATCH 10/37] fmt Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 47077872356e..12196806e272 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -6,7 +6,6 @@ # ruff: noqa: E501 -import math import torch @@ -61,7 +60,6 @@ def _chunk_cumsum_fwd_kernel( chunk_seqlen_start = 
tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - dt_ptr += pid_b * stride_dt_batch + chunk_seqlen_start * stride_dt_seqlen dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk @@ -237,7 +235,6 @@ def _chunk_state_fwd_kernel( pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) @@ -259,7 +256,8 @@ def _chunk_state_fwd_kernel( chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size_limit - 1) * stride_dA_cs_csize).to(tl.float32) + (chunk_size_limit - 1) * stride_dA_cs_csize).to( + tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize @@ -552,7 +550,7 @@ def _chunk_cumsum_fwd(dt, assert A.shape == (nheads, ) if dt_bias is not None: assert dt_bias.shape == (nheads, ) - nchunks = cu_chunk_seqlens.shape[0]-1 + nchunks = cu_chunk_seqlens.shape[0] - 1 dt_out = torch.empty(batch, nheads, nchunks, From a8aff97cc0180fd9547b24f9d1036dc0e74b40ce Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 03:07:00 -0400 Subject: [PATCH 11/37] workign Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 15 +++++---------- .../layers/mamba/ops/ssd_chunk_state.py | 15 +++++++++++++-- .../layers/mamba/ops/ssd_combined.py | 2 ++ 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 0aa7f2b28159..0637a7cda4dd 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -13,7 +13,6 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') -''' @triton.autotune( configs=[ triton.Config( @@ -107,7 +106,6 @@ ], key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], ) -''' @triton.jit def _chunk_scan_fwd_kernel( # Pointers to matrices @@ -233,7 +231,7 @@ def _chunk_scan_fwd_kernel( offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, - mask=offs_m < chunk_size_limit, + mask=offs_m < chunk_size, other=0.0).to(tl.float32) acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -337,16 +335,16 @@ def _chunk_scan_fwd_kernel( (pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) for k in range(0, K_MAX, BLOCK_SIZE_K): cb = tl.load(cb_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k[None, :] < chunk_size_limit - k), + mask=(offs_m[:, None] < chunk_size) & + (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) dA_cs_k = tl.load(dA_cumsum_ptrs, - mask=offs_k < chunk_size_limit - k, + mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. # So we don't need masking wrt seq_idx here. 
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) - dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) cb *= dt_k if IS_CAUSAL: @@ -538,8 +536,5 @@ def _chunk_scan_fwd( HAS_SEQ_IDX=seq_idx is not None, IS_TRITON_22=TRITON_22, HAS_INITSTATES=initial_states is not None, - BLOCK_SIZE_M=64, - BLOCK_SIZE_N=64, - BLOCK_SIZE_K=32, ) return out_x diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 12196806e272..7715a5107467 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -13,7 +13,18 @@ from .mamba_ssm import softplus - +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), + triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), + ], + key=['chunk_size', 'nheads'], +) @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -255,6 +266,7 @@ def _chunk_state_fwd_kernel( chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + # should this be limit or not? dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size_limit - 1) * stride_dA_cs_csize).to( tl.float32) @@ -598,7 +610,6 @@ def _chunk_cumsum_fwd(dt, dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, - BLOCK_SIZE_H=1, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 84ac31542506..a89bcb14d90b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -148,6 +148,8 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, output_dtype=torch.float32) + print("CB: ", CB[0,0,0,0,:10]) + # 5. Scan and compute the diagonal blocks, taking into # account past causal states. 
# - if initial states are provided, then states information will be From 67db9b4175b0e861792914e580706f20ecce4cac Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 03:45:51 -0400 Subject: [PATCH 12/37] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 5 ++++- .../layers/mamba/ops/ssd_combined.py | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 34c5a9280752..0ae15cc1aa2a 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -682,7 +682,10 @@ def forward_cuda( dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), - state_dtype=ssm_state.dtype) + state_dtype=ssm_state.dtype, + layer=self.prefix, + ) + print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) print("varlen_state: ", varlen_state[0,0,0,:10]) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index a89bcb14d90b..71f43b445e73 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -44,7 +44,8 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, - out=None): + out=None, + layer=None): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape @@ -97,7 +98,16 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=dt_softplus, dt_limit=dt_limit) + + print("layer: ", layer) + + + dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main" % (layer)) + print("dA_cumsum: ", dA_cumsum[0,0,0,:10]) + print("dA_cumsum_ref: ", dA_cumsum_ref[0,0,0,:10]) + torch.testing.assert_close(dA_cumsum, dA_cumsum_ref, atol=0.0, rtol=0.0) + print("dt: ", dt[0,0,0,:10]) # 2. 
Compute the state for each intra-chunk @@ -208,7 +218,8 @@ def mamba_chunk_scan_combined(x, out=None, return_final_states=False, return_varlen_states=False, - state_dtype=None): + state_dtype=None, + layer=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -253,7 +264,8 @@ def mamba_chunk_scan_combined(x, dt_softplus=dt_softplus, dt_limit=dt_limit, out=out, - state_dtype=state_dtype) + state_dtype=state_dtype, + layer=layer) if not return_varlen_states: if not return_final_states: return From d841e827592644152266387da2b52abd79cb2a0a Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 03:58:03 -0400 Subject: [PATCH 13/37] working changes Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 71f43b445e73..73a6507c4cb9 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -103,12 +103,11 @@ def _mamba_chunk_scan_combined_fwd(x, dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main" % (layer)) - - print("dA_cumsum: ", dA_cumsum[0,0,0,:10]) - print("dA_cumsum_ref: ", dA_cumsum_ref[0,0,0,:10]) torch.testing.assert_close(dA_cumsum, dA_cumsum_ref, atol=0.0, rtol=0.0) - print("dt: ", dt[0,0,0,:10]) + dt_ref = torch.load("dump/dt_%s_main" % (layer)) + torch.testing.assert_close(dt, dt_ref, atol=0.0, rtol=0.0) + # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) @@ -120,7 +119,9 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) - print("states: ", states[0,0,0,0,:10]) + states_ref = torch.load("dump/states_%s_main" % (layer)) + torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) From 908aecb646b084ad71d70f816a1f02d8adbd18ee Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 04:30:25 -0400 Subject: [PATCH 14/37] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_state.py | 6 +++-- .../layers/mamba/ops/ssd_combined.py | 23 +++++++++++++------ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 7715a5107467..d15147dd5e29 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -13,6 +13,7 @@ from .mamba_ssm import softplus +''' @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_H': 1}), @@ -25,6 +26,7 @@ ], key=['chunk_size', 'nheads'], ) +''' @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -266,9 +268,8 @@ def _chunk_state_fwd_kernel( chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - # should this be limit or not? 
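The question in the removed comment ("should this be limit or not?") comes down to what dA_cs_last must represent. Because the cumsum kernel zeroes dt past a chunk's real length, dA_cumsum is flat beyond chunk_size_limit, so indexing at chunk_size - 1 and at chunk_size_limit - 1 should give the same value. A reference for what the chunk-state kernel accumulates per head and chunk (helper name and shapes are illustrative):

import torch

def chunk_state_ref(x, B, dt, dA_cs):
    # x: (chunk_len, hdim), B: (chunk_len, dstate), dt/dA_cs: (chunk_len,)
    # dt is zero past the real chunk length, so those positions contribute nothing.
    decay = torch.exp(dA_cs[-1] - dA_cs)                    # (chunk_len,)
    return torch.einsum("k,k,kp,kn->pn", decay, dt, x, B)   # (hdim, dstate)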
dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size_limit - 1) * stride_dA_cs_csize).to( + (chunk_size - 1) * stride_dA_cs_csize).to( tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize @@ -610,6 +611,7 @@ def _chunk_cumsum_fwd(dt, dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_H=1, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 73a6507c4cb9..6ef63e07b417 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -100,12 +100,15 @@ def _mamba_chunk_scan_combined_fwd(x, print("layer: ", layer) + has_init = initial_states is not None + print("has_init: ", has_init) - - dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main" % (layer)) + dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main_%d" % (layer, has_init)) + torch.cuda.synchronize() torch.testing.assert_close(dA_cumsum, dA_cumsum_ref, atol=0.0, rtol=0.0) - dt_ref = torch.load("dump/dt_%s_main" % (layer)) + dt_ref = torch.load("dump/dt_%s_main_%d" % (layer, has_init)) + torch.cuda.synchronize() torch.testing.assert_close(dt, dt_ref, atol=0.0, rtol=0.0) @@ -119,7 +122,8 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) - states_ref = torch.load("dump/states_%s_main" % (layer)) + states_ref = torch.load("dump/states_%s_main_%d" % (layer, has_init)) + torch.cuda.synchronize() torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) @@ -146,11 +150,16 @@ def _mamba_chunk_scan_combined_fwd(x, is_cont_batched=cu_seqlens is not None, chunk_offsets=chunk_offsets) - print("after state passing: ") - print("states: ", states[0 ,0, 0,:10]) - states = rearrange(states, "... (p n) -> ... p n", n=dstate) + ''' + print("after state passing: ") + states_ref = torch.load("dump/final_states_%s_main_%d" % (layer, has_init)).unsqueeze(0) + print("states.shape: ", states.shape) + print("states_ref.shape: ", states_ref.shape) + torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) + ''' + # 4. 
Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, B, From af7a2465d5b6559fe2e67f3a40386f7e68edc353 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 04:57:14 -0400 Subject: [PATCH 15/37] working changes Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 84 +------------------ .../layers/mamba/ops/ssd_chunk_state.py | 73 ---------------- .../layers/mamba/ops/ssd_combined.py | 13 ++- .../layers/mamba/ops/ssd_state_passing.py | 5 -- 4 files changed, 13 insertions(+), 162 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 0637a7cda4dd..81814b5ce8bc 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -15,86 +15,6 @@ @triton.autotune( configs=[ - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, - num_stages=3, - num_warps=8), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 64 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 64 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, - num_stages=5, - num_warps=2), - triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, - num_stages=5, - num_warps=2), triton.Config( { 'BLOCK_SIZE_M': 64, @@ -246,9 +166,11 @@ def _chunk_scan_fwd_kernel( C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + #scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + scale_m = tl.exp(dA_cs_m) if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index d15147dd5e29..e11dab8c5c34 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -13,20 +13,12 @@ from .mamba_ssm import softplus -''' @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_H': 1}), - triton.Config({'BLOCK_SIZE_H': 2}), - triton.Config({'BLOCK_SIZE_H': 4}), triton.Config({'BLOCK_SIZE_H': 8}), - triton.Config({'BLOCK_SIZE_H': 16}), - triton.Config({'BLOCK_SIZE_H': 32}), - triton.Config({'BLOCK_SIZE_H': 64}), ], key=['chunk_size', 'nheads'], ) -''' @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -118,70 +110,6 @@ def _chunk_cumsum_fwd_kernel( @triton.autotune( configs=[ - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 64 - }, - num_stages=3, - 
num_warps=8), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 256, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, - num_stages=4, - num_warps=4), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 32 - }, - num_stages=5, - num_warps=2), - triton.Config( - { - 'BLOCK_SIZE_M': 32, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32 - }, - num_stages=5, - num_warps=2), triton.Config( { 'BLOCK_SIZE_M': 64, @@ -611,7 +539,6 @@ def _chunk_cumsum_fwd(dt, dA_cumsum.stride(3), dt_softplus, HAS_DT_BIAS=dt_bias is not None, - BLOCK_SIZE_H=1, BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), ) return dA_cumsum, dt_out diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 6ef63e07b417..fcf439b7a496 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -152,13 +152,11 @@ def _mamba_chunk_scan_combined_fwd(x, states = rearrange(states, "... (p n) -> ... p n", n=dstate) - ''' print("after state passing: ") states_ref = torch.load("dump/final_states_%s_main_%d" % (layer, has_init)).unsqueeze(0) print("states.shape: ", states.shape) print("states_ref.shape: ", states_ref.shape) torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) - ''' # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, @@ -168,7 +166,8 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, output_dtype=torch.float32) - print("CB: ", CB[0,0,0,0,:10]) + CB_ref = torch.load("dump/CB_%s_main_%d" % (layer, has_init)) + torch.testing.assert_close(CB, CB_ref, atol=0.0, rtol=0.0) # 5. Scan and compute the diagonal blocks, taking into # account past causal states. 
@@ -196,6 +195,14 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, out=out, ) + + out_x_ref = torch.load("dump/out_x_%s_main_%d" % (layer, has_init)) + torch.testing.assert_close(out_x, out_x_ref, atol=0.0, rtol=0.0) + + out_ref = torch.load("dump/out_%s_main_%d" % (layer, has_init)) + torch.testing.assert_close(out, out_ref, atol=0.0, rtol=0.0) + + if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index e7d00a8fdd89..a345fad6795c 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -14,11 +14,6 @@ @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE': 64}), - triton.Config({'BLOCK_SIZE': 128}), - triton.Config({'BLOCK_SIZE': 256}), - triton.Config({'BLOCK_SIZE': 512}), - triton.Config({'BLOCK_SIZE': 1024}), - triton.Config({'BLOCK_SIZE': 2048}), ], key=['dim'], ) From 7ce2b5972568447eddc95110f63ec316f05576d9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 05:42:59 -0400 Subject: [PATCH 16/37] Some test cases working Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 4 ++-- .../layers/mamba/ops/ssd_chunk_scan.py | 10 +++++++--- .../layers/mamba/ops/ssd_chunk_state.py | 8 ++++---- .../layers/mamba/ops/ssd_combined.py | 18 ++++++++++++------ vllm/v1/attention/backends/mamba2_attn.py | 18 +++++++++--------- vllm/v1/core/sched/scheduler.py | 2 +- 6 files changed, 35 insertions(+), 25 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 0ae15cc1aa2a..25ac56b72740 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -687,8 +687,8 @@ def forward_cuda( ) - print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) - print("varlen_state: ", varlen_state[0,0,0,:10]) + #print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) + #print("varlen_state: ", varlen_state[0,0,0,:10]) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 81814b5ce8bc..dc573bd01e68 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -188,9 +188,11 @@ def _chunk_scan_fwd_kernel( mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Set to zero prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Load from previous chunk states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ @@ -201,8 +203,8 @@ def _chunk_scan_fwd_kernel( mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) - prev_states = prev_states.to(C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: offset_tpa = 0 @@ -224,9 +226,11 @@ def _chunk_scan_fwd_kernel( mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Set to zero prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) 
+ prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Load from previous chunk states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ @@ -238,8 +242,8 @@ def _chunk_scan_fwd_kernel( mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) - prev_states = prev_states.to(C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K offset_tpa += BLOCK_SIZE_K @@ -357,7 +361,7 @@ def _chunk_scan_fwd( assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) - print("out.shape: ", out.shape) + #print("out.shape: ", out.shape) if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index e11dab8c5c34..0e029b4de199 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -507,8 +507,8 @@ def _chunk_cumsum_fwd(dt, grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) - print("dt_out.shape: ", dt_out.shape) - print("dA_cumsum.shape: ", dA_cumsum.shape) + #print("dt_out.shape: ", dt_out.shape) + #print("dA_cumsum.shape: ", dA_cumsum.shape) with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( @@ -569,8 +569,8 @@ def _chunk_state_fwd(B, device=x.device, dtype=states_dtype) - print("[_chunk_state_fwd] states.shape: ", states.shape) - print("[_chunk_state_fwd] cu_chunk_seqlens: ", cu_chunk_seqlens) + #print("[_chunk_state_fwd] states.shape: ", states.shape) + #print("[_chunk_state_fwd] cu_chunk_seqlens: ", cu_chunk_seqlens) grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index fcf439b7a496..71afe952788b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -98,7 +98,7 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=dt_softplus, dt_limit=dt_limit) - + ''' print("layer: ", layer) has_init = initial_states is not None print("has_init: ", has_init) @@ -110,6 +110,7 @@ def _mamba_chunk_scan_combined_fwd(x, dt_ref = torch.load("dump/dt_%s_main_%d" % (layer, has_init)) torch.cuda.synchronize() torch.testing.assert_close(dt, dt_ref, atol=0.0, rtol=0.0) + ''' # 2. Compute the state for each intra-chunk @@ -122,10 +123,11 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) + ''' states_ref = torch.load("dump/states_%s_main_%d" % (layer, has_init)) torch.cuda.synchronize() torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) - + ''' # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) @@ -152,11 +154,13 @@ def _mamba_chunk_scan_combined_fwd(x, states = rearrange(states, "... (p n) -> ... p n", n=dstate) + ''' print("after state passing: ") states_ref = torch.load("dump/final_states_%s_main_%d" % (layer, has_init)).unsqueeze(0) print("states.shape: ", states.shape) print("states_ref.shape: ", states_ref.shape) torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) + ''' # 4. 
Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, @@ -166,8 +170,10 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, output_dtype=torch.float32) + ''' CB_ref = torch.load("dump/CB_%s_main_%d" % (layer, has_init)) torch.testing.assert_close(CB, CB_ref, atol=0.0, rtol=0.0) + ''' # 5. Scan and compute the diagonal blocks, taking into # account past causal states. @@ -195,21 +201,21 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, out=out, ) - + ''' out_x_ref = torch.load("dump/out_x_%s_main_%d" % (layer, has_init)) torch.testing.assert_close(out_x, out_x_ref, atol=0.0, rtol=0.0) out_ref = torch.load("dump/out_%s_main_%d" % (layer, has_init)) torch.testing.assert_close(out, out_ref, atol=0.0, rtol=0.0) - + ''' if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - print("last_chunk: ", last_chunk) + #print("last_chunk: ", last_chunk) varlen_states = states[:, last_chunk, ...].clone().squeeze(0) - print("varlen_states: ", varlen_states[0,0,0,:10]) + #print("varlen_states: ", varlen_states[0,0,0,:10]) final_states = states[:, -1, ...] return out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 91fb63dc486c..dd5961bc0553 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -176,10 +176,10 @@ def build(self, common_attn_metadata, decode_threshold=self.reorder_batch_threshold)) - print("num_decodes: ", num_decodes) - print("num_prefills: ", num_prefills) - print("num_decode_tokens: ", num_decode_tokens) - print("num_prefill_tokens: ", num_prefill_tokens) + #print("num_decodes: ", num_decodes) + #print("num_prefills: ", num_prefills) + #print("num_decode_tokens: ", num_decode_tokens) + #print("num_prefill_tokens: ", num_prefill_tokens) # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: @@ -209,8 +209,8 @@ def build(self, query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ -num_prefills - 1:] - num_decode_tokens - print("num_computed_tokens_p: ", num_computed_tokens_p) - print("query_start_loc_p: ", query_start_loc_p) + #print("num_computed_tokens_p: ", num_computed_tokens_p) + #print("query_start_loc_p: ", query_start_loc_p) cu_chunk_seqlen = [] last_chunk = [] @@ -218,7 +218,7 @@ def build(self, for req_idx in range(num_prefills): this_num_computed = num_computed_tokens_p[req_idx].item() this_new_tokens = query_start_loc_p_cpu[req_idx+1].item() - query_start_loc_p_cpu[req_idx].item() - print(req_idx, this_num_computed, this_new_tokens) + #print(req_idx, this_num_computed, this_new_tokens) # if computed tokens are not chunk-aligned, use the first # chunk to finish it off @@ -247,8 +247,8 @@ def build(self, cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, device=query_start_loc.device, dtype=torch.int32) last_chunk_p = torch.as_tensor(last_chunk, device=query_start_loc.device, dtype=torch.int32) - print("cu_chunk_seqlen: ", cu_chunk_seqlen) - print("cu_chunk_seqlen_p: ", cu_chunk_seqlen_p) + #print("cu_chunk_seqlen: ", cu_chunk_seqlen) + #print("cu_chunk_seqlen_p: ", cu_chunk_seqlen_p) # We compute metadata for chunked prefill once at the top level # model forward and reuse them in mamba layers. 
If not needed, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 49b6d99e4ab1..101867f5cfc5 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -598,7 +598,7 @@ def schedule(self) -> SchedulerOutput: structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) - print(scheduler_output) + #print(scheduler_output) # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store # 2. Wrap up all the KV cache load / save ops into an opaque object From f950f2eb7d4eb849878aa80b0b24a588689fa427 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 07:58:45 -0400 Subject: [PATCH 17/37] Fix IMA Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 141 ++++++++++-------- 1 file changed, 78 insertions(+), 63 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index dc573bd01e68..49daec377d2e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -52,6 +52,7 @@ def _chunk_scan_fwd_kernel( batch, seqlen, nheads_ngroups_ratio, + nchunks, # Strides stride_cb_batch, stride_cb_chunk, @@ -107,6 +108,7 @@ def _chunk_scan_fwd_kernel( IS_TRITON_22: tl.constexpr, HAS_INITSTATES: tl.constexpr, ): + pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch @@ -156,84 +158,92 @@ def _chunk_scan_fwd_kernel( acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Without the if (pid_c > -1), with Triton 2.1.0, I get - # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. 
- # With Triton 2.2.0, this works - if IS_TRITON_22 or c_idx > -1: - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k_dstate = tl.arange( - 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + - offs_k_dstate[None, :] * stride_C_dstate) + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + + + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + #scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + scale_m = tl.exp(dA_cs_m) - #scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) - scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) + + if seq_idx != seq_idx_prev: + if HAS_INITSTATES: + # load from init states + init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + + pid_h * stride_init_states_head \ + + offs_n[None, :] * stride_init_states_hdim \ + + offs_k_dstate[:, None] * stride_init_states_dstate + prev_states = tl.load(init_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + else: + # Set to zero + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + else: + if c_idx > 0: + tl.device_assert(c_idx < nchunks) + # Load from praevious chunk + states_ptrs = states_ptr + (c_idx-1) * stride_states_chunk \ + + pid_h * stride_states_head \ + + offs_n[None, :] * stride_states_hdim \ + + offs_k_dstate[:, None] * stride_states_dstate + prev_states = tl.load(states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + else: + # Set to zero + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) - if BLOCK_SIZE_DSTATE <= 128: + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + offset_tpa = 0 + for k in range(0, dstate, BLOCK_SIZE_K): C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate), + (offs_k_dstate[None, :] < dstate - k), other=0.0) - - + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) if seq_idx != seq_idx_prev: if HAS_INITSTATES: # load from init states init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ + pid_h * stride_init_states_head \ + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate + + offs_k_dstate[:, None] * stride_init_states_dstate \ + + offset_tpa prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & + mask=(offs_k_dstate[:, None] < dstate-k) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) else: # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - # Load from previous chunk 
- states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ - + pid_h * stride_states_head \ - + offs_n[None, :] * stride_states_hdim \ - + offs_k_dstate[:, None] * stride_states_dstate - prev_states = tl.load(states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - - acc = tl.dot(C, prev_states) * scale_m[:, None] - else: - offset_tpa = 0 - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load(C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate - k), - other=0.0) - # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - if seq_idx != seq_idx_prev: - if HAS_INITSTATES: - # load from init states - init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ - + pid_h * stride_init_states_head \ - + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate \ - + offset_tpa - prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate-k) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: + if c_idx > 0: # Load from previous chunk - states_ptrs = states_ptr + seq_idx_prev * stride_states_batch \ + states_ptrs = states_ptr + (c_idx-1) * stride_states_chunk \ + pid_h * stride_states_head \ + offs_n[None, :] * stride_states_hdim \ + offs_k_dstate[:, None] * stride_states_dstate \ @@ -243,12 +253,16 @@ def _chunk_scan_fwd_kernel( (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) + else: + # Set to zero + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) + prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - offset_tpa += BLOCK_SIZE_K + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + offset_tpa += BLOCK_SIZE_K - acc *= scale_m[:, None] + acc *= scale_m[:, None] offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + @@ -322,9 +336,11 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) + out_ptr += pid_b * stride_out_batch + chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & @@ -361,8 +377,6 @@ def _chunk_scan_fwd( assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) assert states.shape == (batch, nchunks, nheads, headdim, dstate) - #print("out.shape: ", out.shape) - if seq_idx is not None: assert seq_idx.shape == (batch, seqlen) @@ -413,6 +427,7 @@ def _chunk_scan_fwd( batch, seqlen, nheads // ngroups, + nchunks, cb.stride(0), cb.stride(1), cb.stride(2), From 75e01c87e1871bb3a5e94c8f55268c8b6953c3fc Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 13:56:41 -0400 Subject: [PATCH 18/37] Add back autotune config Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 80 +++++++++++++++++++ .../layers/mamba/ops/ssd_chunk_state.py | 70 ++++++++++++++++ .../layers/mamba/ops/ssd_state_passing.py | 5 ++ 3 files changed, 155 insertions(+) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py 
b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 49daec377d2e..4e116ae9fc8b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -15,6 +15,86 @@ @triton.autotune( configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), triton.Config( { 'BLOCK_SIZE_M': 64, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 0e029b4de199..3f4a5b4e6006 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -15,7 +15,13 @@ @triton.autotune( configs=[ + triton.Config({'BLOCK_SIZE_H': 1}), + triton.Config({'BLOCK_SIZE_H': 2}), + triton.Config({'BLOCK_SIZE_H': 4}), triton.Config({'BLOCK_SIZE_H': 8}), + triton.Config({'BLOCK_SIZE_H': 16}), + triton.Config({'BLOCK_SIZE_H': 32}), + triton.Config({'BLOCK_SIZE_H': 64}), ], key=['chunk_size', 'nheads'], ) @@ -110,6 +116,70 @@ def _chunk_cumsum_fwd_kernel( @triton.autotune( configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 64 + }, + num_stages=3, + num_warps=8), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), + triton.Config( + { + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32 + }, + num_stages=5, + num_warps=2), triton.Config( { 'BLOCK_SIZE_M': 64, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index a345fad6795c..e7d00a8fdd89 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ 
b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -14,6 +14,11 @@ @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE': 64}), + triton.Config({'BLOCK_SIZE': 128}), + triton.Config({'BLOCK_SIZE': 256}), + triton.Config({'BLOCK_SIZE': 512}), + triton.Config({'BLOCK_SIZE': 1024}), + triton.Config({'BLOCK_SIZE': 2048}), ], key=['dim'], ) From 2698f2eca296767ccab8c6def3428dc4bef6511c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 14:47:49 -0400 Subject: [PATCH 19/37] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 25ac56b72740..222f89f2c35b 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -682,13 +682,7 @@ def forward_cuda( dt_limit=(0.0, float("inf")), out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1, self.head_dim), - state_dtype=ssm_state.dtype, - layer=self.prefix, - ) - - - #print("preallocated_ssm_out_p: ", preallocated_ssm_out_p[0,:10]) - #print("varlen_state: ", varlen_state[0,0,0,:10]) + state_dtype=ssm_state.dtype) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor From d3f05b7a8682c5e21bc22416f83b4a6e6c84f0e7 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 14:49:45 -0400 Subject: [PATCH 20/37] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_bmm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 3a245b127f01..786721733af7 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -168,7 +168,6 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out = acc.to(out_ptr.dtype.element_ty) out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head From df635038e9b724c894ac4fdb20fb0bc44cc7fc8f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 14:55:24 -0400 Subject: [PATCH 21/37] cleanup Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_state.py | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 3f4a5b4e6006..6f710c76f5f3 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -13,6 +13,7 @@ from .mamba_ssm import softplus + @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_H': 1}), @@ -263,15 +264,12 @@ def _chunk_state_fwd_kernel( b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen) dt_ptrs = dt_ptr + offs_k * stride_dt_csize - - chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - dA_cs_last = tl.load(dA_cumsum_ptr + - (chunk_size - 1) * stride_dA_cs_csize).to( - tl.float32) - + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): x = 
tl.load(x_ptrs, @@ -287,9 +285,7 @@ def _chunk_state_fwd_kernel( other=0.0).to(tl.float32) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k - b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -297,7 +293,6 @@ def _chunk_state_fwd_kernel( b_ptrs += BLOCK_SIZE_K * stride_b_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - states = acc.to(states_ptr.dtype.element_ty) states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head @@ -577,9 +572,6 @@ def _chunk_cumsum_fwd(dt, grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) - #print("dt_out.shape: ", dt_out.shape) - #print("dA_cumsum.shape: ", dA_cumsum.shape) - with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt, @@ -638,10 +630,6 @@ def _chunk_state_fwd(B, states = torch.empty((batch, nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype) - - #print("[_chunk_state_fwd] states.shape: ", states.shape) - #print("[_chunk_state_fwd] cu_chunk_seqlens: ", cu_chunk_seqlens) - grid = lambda META: ( triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv( dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads) From c5edccdae6fbb740c7c40259ca4ddc26f595cdeb Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:00:17 -0400 Subject: [PATCH 22/37] cleanup Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_state.py | 2 - .../layers/mamba/ops/ssd_combined.py | 54 ++----------------- 2 files changed, 3 insertions(+), 53 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 6f710c76f5f3..4d4e593b21c4 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -246,10 +246,8 @@ def _chunk_state_fwd_kernel( num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - b_ptr += pid_b * stride_b_batch + chunk_seqlen_start * stride_b_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_b_head x_ptr += pid_b * stride_x_batch + chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 71afe952788b..48d4c7e6da09 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -44,8 +44,7 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, - out=None, - layer=None): + out=None): assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" batch, seqlen, nheads, headdim = x.shape _, _, ngroups, dstate = B.shape @@ -98,21 +97,6 @@ def _mamba_chunk_scan_combined_fwd(x, dt_softplus=dt_softplus, dt_limit=dt_limit) - ''' - print("layer: ", layer) - has_init = initial_states is not None - print("has_init: ", has_init) - - dA_cumsum_ref = torch.load("dump/dA_cumsum_%s_main_%d" % (layer, has_init)) - torch.cuda.synchronize() - torch.testing.assert_close(dA_cumsum, dA_cumsum_ref, atol=0.0, rtol=0.0) - - dt_ref = torch.load("dump/dt_%s_main_%d" % (layer, 
has_init)) - torch.cuda.synchronize() - torch.testing.assert_close(dt, dt_ref, atol=0.0, rtol=0.0) - ''' - - # 2. Compute the state for each intra-chunk # (right term of low-rank factorization of off-diagonal blocks; B terms) states = _chunk_state_fwd(B, @@ -123,12 +107,6 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, states_in_fp32=True) - ''' - states_ref = torch.load("dump/states_%s_main_%d" % (layer, has_init)) - torch.cuda.synchronize() - torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) - ''' - # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states @@ -151,17 +129,8 @@ def _mamba_chunk_scan_combined_fwd(x, out_dtype=state_dtype if state_dtype is not None else C.dtype, is_cont_batched=cu_seqlens is not None, chunk_offsets=chunk_offsets) - states = rearrange(states, "... (p n) -> ... p n", n=dstate) - ''' - print("after state passing: ") - states_ref = torch.load("dump/final_states_%s_main_%d" % (layer, has_init)).unsqueeze(0) - print("states.shape: ", states.shape) - print("states_ref.shape: ", states_ref.shape) - torch.testing.assert_close(states, states_ref, atol=0.0, rtol=0.0) - ''' - # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, B, @@ -170,11 +139,6 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=seq_idx, output_dtype=torch.float32) - ''' - CB_ref = torch.load("dump/CB_%s_main_%d" % (layer, has_init)) - torch.testing.assert_close(CB, CB_ref, atol=0.0, rtol=0.0) - ''' - # 5. Scan and compute the diagonal blocks, taking into # account past causal states. # - if initial states are provided, then states information will be @@ -201,21 +165,11 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, out=out, ) - ''' - out_x_ref = torch.load("dump/out_x_%s_main_%d" % (layer, has_init)) - torch.testing.assert_close(out_x, out_x_ref, atol=0.0, rtol=0.0) - - out_ref = torch.load("dump/out_%s_main_%d" % (layer, has_init)) - torch.testing.assert_close(out, out_ref, atol=0.0, rtol=0.0) - ''' - if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" - #print("last_chunk: ", last_chunk) varlen_states = states[:, last_chunk, ...].clone().squeeze(0) - #print("varlen_states: ", varlen_states[0,0,0,:10]) final_states = states[:, -1, ...] 
return out_x, dt, dA_cumsum, states, final_states, varlen_states @@ -241,8 +195,7 @@ def mamba_chunk_scan_combined(x, out=None, return_final_states=False, return_varlen_states=False, - state_dtype=None, - layer=None): + state_dtype=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -287,8 +240,7 @@ def mamba_chunk_scan_combined(x, dt_softplus=dt_softplus, dt_limit=dt_limit, out=out, - state_dtype=state_dtype, - layer=layer) + state_dtype=state_dtype) if not return_varlen_states: if not return_final_states: return From 712ced11f6ae00dce9433c76a25db343d0942f0c Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:06:08 -0400 Subject: [PATCH 23/37] cleanup Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_state_passing.py | 11 ++-------- vllm/v1/attention/backends/mamba2_attn.py | 21 +------------------ 2 files changed, 3 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index e7d00a8fdd89..c1207424a9a1 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -78,7 +78,7 @@ def _state_passing_fwd_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + 0 * stride_initstates_batch \ + initstates_ptrs = initstates_ptr + stride_initstates_batch \ + pid_h * stride_initstates_head \ + offs_m * stride_initstates_dim @@ -90,16 +90,12 @@ def _state_passing_fwd_kernel( prev_seq_idx = 0 for c in range(nchunks): - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + c) - new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) - - # we are started a new sequence + # we have started a new sequence if prev_seq_idx != seq_idx: if HAS_INITSTATES: initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \ @@ -112,7 +108,6 @@ def _state_passing_fwd_kernel( states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) prev_seq_idx = seq_idx - states = tl.exp(dA_cs) * states + new_states tl.store(out_ptrs, states, mask=offs_m < dim) states_ptrs += stride_states_chunk @@ -132,7 +127,6 @@ def _state_passing_fwd( chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape - assert batch == 1 if chunk_size is None: chunk_size = dA_cumsum.shape[-1] else: @@ -160,7 +154,6 @@ def _state_passing_fwd( out = torch.empty((batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype) - grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads) with torch.cuda.device(states.device.index): _state_passing_fwd_kernel[grid]( diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index dd5961bc0553..010c8f25946c 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -126,8 +126,6 @@ class Mamba2AttentionMetadata: seq_idx_p: Optional[torch.Tensor] chunk_indices_p: Optional[torch.Tensor] chunk_offsets_p: Optional[torch.Tensor] - - # tpa cu_chunk_seqlen_p: Optional[torch.Tensor] last_chunk_p: Optional[torch.Tensor] @@ -164,8 +162,6 @@ def build(self, # currently we really only support the FlashAttention backend has_initial_states_p = None prep_initial_states = False - - cu_chunk_seqlen_p = None last_chunk_p = None @@ -176,11 +172,6 @@ def build(self, common_attn_metadata, 
decode_threshold=self.reorder_batch_threshold)) - #print("num_decodes: ", num_decodes) - #print("num_prefills: ", num_prefills) - #print("num_decode_tokens: ", num_decode_tokens) - #print("num_prefill_tokens: ", num_prefill_tokens) - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only if num_prefills > 0: #[batch,] @@ -191,11 +182,9 @@ def build(self, has_initial_states_p = has_initial_states_cpu.to( query_start_loc.device) - query_start_loc_p = common_attn_metadata.query_start_loc[ -num_prefills - 1:] - num_decode_tokens - seq_idx_p = torch.repeat_interleave(torch.arange( num_prefills, dtype=torch.int32, @@ -204,25 +193,20 @@ def build(self, output_size=num_prefill_tokens) seq_idx_p.unsqueeze_(0) - num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ -num_prefills - 1:] - num_decode_tokens - #print("num_computed_tokens_p: ", num_computed_tokens_p) - #print("query_start_loc_p: ", query_start_loc_p) - + # TODO (tdoublep): Optimize the code cu_chunk_seqlen = [] last_chunk = [] seqlen_pos = 0 for req_idx in range(num_prefills): this_num_computed = num_computed_tokens_p[req_idx].item() this_new_tokens = query_start_loc_p_cpu[req_idx+1].item() - query_start_loc_p_cpu[req_idx].item() - #print(req_idx, this_num_computed, this_new_tokens) # if computed tokens are not chunk-aligned, use the first # chunk to finish it off - # TODO(tdoublep): I guess we need block size actually? if this_num_computed % self.chunk_size != 0: cu_chunk_seqlen.append(seqlen_pos) # how many tokens to finish the chunk? @@ -247,9 +231,6 @@ def build(self, cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, device=query_start_loc.device, dtype=torch.int32) last_chunk_p = torch.as_tensor(last_chunk, device=query_start_loc.device, dtype=torch.int32) - #print("cu_chunk_seqlen: ", cu_chunk_seqlen) - #print("cu_chunk_seqlen_p: ", cu_chunk_seqlen_p) - # We compute metadata for chunked prefill once at the top level # model forward and reuse them in mamba layers. If not needed, # they will be ignored inside mamba kernels. From dc85f7ea1748be77d1fb599aeb50d9237160224f Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:06:47 -0400 Subject: [PATCH 24/37] cleanup Signed-off-by: Thomas Parnell --- vllm/v1/core/sched/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 101867f5cfc5..d1a6dd73e85c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -598,7 +598,7 @@ def schedule(self) -> SchedulerOutput: structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) - #print(scheduler_output) + # NOTE(Kuntai): this function is designed for multiple purposes: # 1. Plan the KV cache store # 2. 
Wrap up all the KV cache load / save ops into an opaque object From 5e827a635cebf277986a88f5c4555bd68b0ec2f6 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:15:14 -0400 Subject: [PATCH 25/37] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_state_passing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index c1207424a9a1..93cb1b485c53 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -78,7 +78,7 @@ def _state_passing_fwd_kernel( out_ptrs = out_ptr + offs_m * stride_out_dim if HAS_INITSTATES: - initstates_ptrs = initstates_ptr + stride_initstates_batch \ + initstates_ptrs = initstates_ptr \ + pid_h * stride_initstates_head \ + offs_m * stride_initstates_dim From 42e4b27d9e1754808a8f7355d4dba7a75a6b9452 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:36:25 -0400 Subject: [PATCH 26/37] cleanup Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_chunk_scan.py | 138 +++++------------- 1 file changed, 39 insertions(+), 99 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 4e116ae9fc8b..0aa12f7f138b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -13,6 +13,7 @@ TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + @triton.autotune( configs=[ triton.Config( @@ -188,48 +189,43 @@ def _chunk_scan_fwd_kernel( IS_TRITON_22: tl.constexpr, HAS_INITSTATES: tl.constexpr, ): - pid_bc = tl.program_id(axis=1).to(tl.int64) pid_c = pid_bc // batch pid_b = pid_bc - pid_c * batch - - # logical chunks = physical chunks - # always start from beginning - c_idx = pid_c - c_off = 0 - pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + ( + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + ( pid_h // nheads_ngroups_ratio) * stride_cb_head - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - x_ptr += pid_b * stride_x_batch + chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head C_ptr += pid_b * stride_C_batch + chunk_seqlen_start * stride_C_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_C_head # M-block offsets and prev states # - logic in next block may override these if there is an active offset - offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - #prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head - #prev_states_hdim = stride_states_hdim - #prev_states_dstate = stride_states_dstate - - chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start - + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) 
seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen seq_idx = tl.load(seq_idx_ptr) seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=c_idx >= 1, + mask=pid_c >= 1, other=-1) + if HAS_INITSTATES: + prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim + prev_states_dstate = stride_init_states_dstate + else: + prev_states_ptr = states_ptr + (pid_c-1) * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, @@ -241,110 +237,56 @@ def _chunk_scan_fwd_kernel( offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 offs_k_dstate = tl.arange( 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) - #scale_m = tl.where(seq_idx == seq_idx_prev, tl.exp(dA_cs_m), 0.0) scale_m = tl.exp(dA_cs_m) - if BLOCK_SIZE_DSTATE <= 128: - C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) - if seq_idx != seq_idx_prev: - if HAS_INITSTATES: - # load from init states - init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ - + pid_h * stride_init_states_head \ - + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate - prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) + if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 0: + prev_states_ptrs = prev_states_ptr \ + + offs_n[None, :] * prev_states_hdim \ + + offs_k_dstate[:, None] * prev_states_dstate + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - if c_idx > 0: - tl.device_assert(c_idx < nchunks) - # Load from praevious chunk - states_ptrs = states_ptr + (c_idx-1) * stride_states_chunk \ - + pid_h * stride_states_head \ - + offs_n[None, :] * stride_states_hdim \ - + offs_k_dstate[:, None] * stride_states_dstate - prev_states = tl.load(states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] else: - offset_tpa = 0 + prev_states_ptrs = prev_states_ptr \ + + offs_n[None, :] * prev_states_hdim \ + + offs_k_dstate[:, None] * prev_states_dstate for k in range(0, dstate, BLOCK_SIZE_K): C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), 
other=0.0) - # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) - if seq_idx != seq_idx_prev: - if HAS_INITSTATES: - # load from init states - init_states_ptrs = initstates_ptr + seq_idx * stride_init_states_batch \ - + pid_h * stride_init_states_head \ - + offs_n[None, :] * stride_init_states_hdim \ - + offs_k_dstate[:, None] * stride_init_states_dstate \ - + offset_tpa - prev_states = tl.load(init_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate-k) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) + if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 0: + prev_states = tl.load(prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - if c_idx > 0: - # Load from previous chunk - states_ptrs = states_ptr + (c_idx-1) * stride_states_chunk \ - + pid_h * stride_states_head \ - + offs_n[None, :] * stride_states_hdim \ - + offs_k_dstate[:, None] * stride_states_dstate \ - + offset_tpa - prev_states = tl.load(states_ptrs, - mask=(offs_k_dstate[:, None] < dstate-k) & - (offs_n[None, :] < hdim), - other=0.0) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - # Set to zero - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=tl.float32) - prev_states = prev_states.to(C_ptr.dtype.element_ty) - + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K - offset_tpa += BLOCK_SIZE_K - + prev_states_ptrs += BLOCK_SIZE_K acc *= scale_m[:, None] - offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + offs_k = tl.arange(0, BLOCK_SIZE_K) cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + @@ -381,7 +323,7 @@ def _chunk_scan_fwd_kernel( dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_D: @@ -416,11 +358,9 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) - out_ptr += pid_b * stride_out_batch + chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) - tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & From d8591820e0e1ea842e3b90688b490829b860c4d3 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:39:59 -0400 Subject: [PATCH 27/37] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 0aa12f7f138b..de920f59ec2d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -133,7 +133,6 @@ def _chunk_scan_fwd_kernel( batch, seqlen, nheads_ngroups_ratio, - nchunks, # Strides stride_cb_batch, stride_cb_chunk, @@ -447,7 +446,6 @@ def _chunk_scan_fwd( batch, seqlen, nheads // ngroups, - 
nchunks, cb.stride(0), cb.stride(1), cb.stride(2), From 56b37c22e2608314ef61be789655b5578df81298 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:41:32 -0400 Subject: [PATCH 28/37] cleanup Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 4d4e593b21c4..eca98ff73e8b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -569,7 +569,6 @@ def _chunk_cumsum_fwd(dt, dtype=torch.float32) grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H'])) - with torch.cuda.device(dt.device.index): _chunk_cumsum_fwd_kernel[grid_chunk_cs]( dt, From e21b4e633e03865b9b6c7469455d7005a9e23af9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 12 Sep 2025 15:58:20 -0400 Subject: [PATCH 29/37] lint Signed-off-by: Thomas Parnell --- .../layers/mamba/ops/ssd_bmm.py | 4 +-- .../layers/mamba/ops/ssd_chunk_scan.py | 32 +++++++++++-------- .../layers/mamba/ops/ssd_chunk_state.py | 1 - .../layers/mamba/ops/ssd_combined.py | 5 ++- .../layers/mamba/ops/ssd_state_passing.py | 9 +++--- vllm/v1/attention/backends/mamba2_attn.py | 23 +++++++++---- 6 files changed, 41 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 786721733af7..260f1e5239af 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -6,8 +6,6 @@ # ruff: noqa: E501,SIM102 -import math - import torch from vllm.triton_utils import tl, triton @@ -209,7 +207,7 @@ def _bmm_chunk_fwd(a, a = a.contiguous() if b.stride(-1) != 1 and b.stride(1) != 1: b = b.contiguous() - nchunks = len(cu_chunk_seqlens)-1 + nchunks = len(cu_chunk_seqlens) - 1 # Allocates output. 
out_dtype = a.dtype if output_dtype is None else output_dtype out = torch.empty( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index de920f59ec2d..207c440b0ff6 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -212,15 +212,16 @@ def _chunk_scan_fwd_kernel( seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_seqlen_start * stride_seq_idx_seqlen seq_idx = tl.load(seq_idx_ptr) seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=pid_c >= 1, - other=-1) + mask=pid_c >= 1, + other=-1) if HAS_INITSTATES: prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head prev_states_hdim = stride_init_states_hdim prev_states_dstate = stride_init_states_dstate else: - prev_states_ptr = states_ptr + (pid_c-1) * stride_states_chunk + pid_h * stride_states_head + prev_states_ptr = states_ptr + ( + pid_c - 1) * stride_states_chunk + pid_h * stride_states_head prev_states_hdim = stride_states_hdim prev_states_dstate = stride_states_dstate @@ -254,12 +255,13 @@ def _chunk_scan_fwd_kernel( + offs_n[None, :] * prev_states_hdim \ + offs_k_dstate[:, None] * prev_states_dstate prev_states = tl.load(prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate) & - (offs_n[None, :] < hdim), - other=0.0) + mask=(offs_k_dstate[:, None] < dstate) & + (offs_n[None, :] < hdim), + other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty) + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), + dtype=C_ptr.dtype.element_ty) acc = tl.dot(C, prev_states) * scale_m[:, None] @@ -273,13 +275,15 @@ def _chunk_scan_fwd_kernel( (offs_k_dstate[None, :] < dstate - k), other=0.0) if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 0: - prev_states = tl.load(prev_states_ptrs, - mask=(offs_k_dstate[:, None] < dstate - k) & - (offs_n[None, :] < hdim), - other=0.0) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) & + (offs_n[None, :] < hdim), + other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) else: - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty) + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), + dtype=C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K prev_states_ptrs += BLOCK_SIZE_K @@ -418,8 +422,8 @@ def _chunk_scan_fwd( else: out_x = None - grid = lambda META: ( - triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( + grid = lambda META: (triton.cdiv( + chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index eca98ff73e8b..448c7970b64b 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -6,7 +6,6 @@ # ruff: noqa: E501 - import torch from vllm.triton_utils import tl, triton diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 48d4c7e6da09..e04ff3da991d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ 
b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -14,8 +14,7 @@ from .ssd_bmm import _bmm_chunk_fwd from .ssd_chunk_scan import _chunk_scan_fwd -from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, - chunk_state_varlen) +from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd from .ssd_state_passing import _state_passing_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -165,12 +164,12 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, out=out, ) + final_states = states[:, -1, ...] if cu_seqlens is None: return out_x, dt, dA_cumsum, states, final_states else: assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" varlen_states = states[:, last_chunk, ...].clone().squeeze(0) - final_states = states[:, -1, ...] return out_x, dt, dA_cumsum, states, final_states, varlen_states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 93cb1b485c53..3a3de30ba2f5 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -82,8 +82,7 @@ def _state_passing_fwd_kernel( + pid_h * stride_initstates_head \ + offs_m * stride_initstates_dim - states = tl.load(initstates_ptrs, - mask=offs_m < dim, + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) else: states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) @@ -94,15 +93,15 @@ def _state_passing_fwd_kernel( new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - seq_idx = tl.load(seq_idx_ptr + chunk_seqlen_start * stride_seq_idx_seqlen) + seq_idx = tl.load(seq_idx_ptr + + chunk_seqlen_start * stride_seq_idx_seqlen) # we have started a new sequence if prev_seq_idx != seq_idx: if HAS_INITSTATES: initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \ + pid_h * stride_initstates_head \ + offs_m * stride_initstates_dim - states = tl.load(initstates_ptrs, - mask=offs_m < dim, + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) else: states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 010c8f25946c..9a060bff6d1f 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -9,12 +9,13 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadataBuilder) from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.utils import cdiv + def _query_start_loc_to_chunk_indices_offsets( query_start_loc: torch.Tensor, chunk_size: int, @@ -193,7 +194,9 @@ def build(self, output_size=num_prefill_tokens) seq_idx_p.unsqueeze_(0) - num_computed_tokens_p = common_attn_metadata.num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] + num_computed_tokens_p = \ + common_attn_metadata.num_computed_tokens_cpu[ + num_reqs - num_prefills:num_reqs] query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ -num_prefills - 1:] - num_decode_tokens @@ -203,14 +206,16 @@ def build(self, seqlen_pos = 0 for req_idx 
in range(num_prefills):
                 this_num_computed = num_computed_tokens_p[req_idx].item()
-                this_new_tokens = query_start_loc_p_cpu[req_idx+1].item() - query_start_loc_p_cpu[req_idx].item()
+                this_new_tokens = query_start_loc_p_cpu[req_idx + 1].item(
+                ) - query_start_loc_p_cpu[req_idx].item()
 
                 # if computed tokens are not chunk-aligned, use the first
                 # chunk to finish it off
                 if this_num_computed % self.chunk_size != 0:
                     cu_chunk_seqlen.append(seqlen_pos)
                     # how many tokens to finish the chunk?
-                    chunk_len = cdiv(this_num_computed, self.chunk_size)*self.chunk_size - this_num_computed
+                    chunk_len = cdiv(this_num_computed, self.chunk_size
+                                     ) * self.chunk_size - this_num_computed
                     # we can only use at most this_new_tokens
                     chunk_len = min(chunk_len, this_new_tokens)
                     seqlen_pos += chunk_len
@@ -224,12 +229,16 @@ def build(self,
                     this_new_tokens -= chunk_len
 
                 assert this_new_tokens == 0
-                last_chunk.append(len(cu_chunk_seqlen)-1)
+                last_chunk.append(len(cu_chunk_seqlen) - 1)
                 cu_chunk_seqlen.append(seqlen_pos)
 
-            cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen, device=query_start_loc.device, dtype=torch.int32)
-            last_chunk_p = torch.as_tensor(last_chunk, device=query_start_loc.device, dtype=torch.int32)
+            cu_chunk_seqlen_p = torch.as_tensor(cu_chunk_seqlen,
+                                                device=query_start_loc.device,
+                                                dtype=torch.int32)
+            last_chunk_p = torch.as_tensor(last_chunk,
+                                           device=query_start_loc.device,
+                                           dtype=torch.int32)
 
             # We compute metadata for chunked prefill once at the top level
             # model forward and reuse them in mamba layers. If not needed,

From 1b0f793a21f4f6fdaa285857e79a30518d2e55eb Mon Sep 17 00:00:00 2001
From: Thomas Parnell
Date: Thu, 18 Sep 2025 05:05:29 -0400
Subject: [PATCH 30/37] Fix bug in scan kernel when reading previous state.

Signed-off-by: Thomas Parnell
---
 .../layers/mamba/ops/ssd_chunk_scan.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
index 207c440b0ff6..186f771a1018 100644
--- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -215,7 +215,7 @@ def _chunk_scan_fwd_kernel(
                            mask=pid_c >= 1,
                            other=-1)
 
-    if HAS_INITSTATES:
+    if HAS_INITSTATES and (seq_idx != seq_idx_prev):
         prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head
         prev_states_hdim = stride_init_states_hdim
         prev_states_dstate = stride_init_states_dstate
@@ -250,7 +250,12 @@ def _chunk_scan_fwd_kernel(
                     (offs_k_dstate[None, :] < dstate),
                     other=0.0)
 
-        if (seq_idx != seq_idx_prev and HAS_INITSTATES) or pid_c > 
0: + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), + dtype=C_ptr.dtype.element_ty) + else: prev_states = tl.load( prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) - else: - prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), - dtype=C_ptr.dtype.element_ty) acc += tl.dot(C, prev_states) C_ptrs += BLOCK_SIZE_K prev_states_ptrs += BLOCK_SIZE_K From df8f0464d7f2908924d580b4018841ef2468e1a0 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Thu, 18 Sep 2025 14:49:25 -0400 Subject: [PATCH 31/37] Remove BLOCK_H=1 from list of tuneable configurations. Co-authored-by: Chih-Chieh-Yang Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 448c7970b64b..4ad3b348658d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -15,7 +15,6 @@ @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_H': 1}), triton.Config({'BLOCK_SIZE_H': 2}), triton.Config({'BLOCK_SIZE_H': 4}), triton.Config({'BLOCK_SIZE_H': 8}), From 2cb4252d4c20a1973b95145d0b0c057d4d05181e Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Sep 2025 16:41:58 -0400 Subject: [PATCH 32/37] Fix a few merge errors Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_bmm.py | 3 +-- vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py | 2 +- vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py | 4 ++-- vllm/model_executor/layers/mamba/ops/ssd_combined.py | 4 +--- vllm/model_executor/layers/mamba/ops/ssd_state_passing.py | 2 +- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index cf0496e57007..1e607e36bc57 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -200,7 +200,6 @@ def _bmm_chunk_fwd(a, b = b.contiguous() nchunks = len(cu_chunk_seqlens) - 1 - # Allocates output. 
out_dtype = a.dtype if output_dtype is None else output_dtype out = torch.empty((nchunks, ngroups, chunk_size, chunk_size), @@ -219,7 +218,7 @@ def _bmm_chunk_fwd(a, b_ptr=b, out_ptr=out, seq_idx_ptr=seq_idx, - cu_chunk_seqlens=cu_chunk_seqlens, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, seqlen=seqlen, chunk_size=chunk_size, K=k, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index 9fc4043b4caa..be0e8bcfb37f 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -416,7 +416,7 @@ def _chunk_scan_fwd( chunk_indices_ptr=chunk_indices, chunk_offsets_ptr=chunk_offsets, chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0, - cu_chunk_seqlens=cu_chunk_seqlens, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, chunk_size=chunk_size, hdim=headdim, dstate=dstate, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index db0d266d34df..71b664f0bfee 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -556,7 +556,7 @@ def _chunk_cumsum_fwd(dt, dt_bias_ptr=dt_bias, dt_out_ptr=dt_out, dA_cumsum_ptr=dA_cumsum, - cu_chunk_seqlens=cu_chunk_seqlens, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, seqlen=seqlen, nheads=nheads, chunk_size=chunk_size, @@ -616,7 +616,7 @@ def _chunk_state_fwd(B, states_ptr=states, dt_ptr=dt, dA_cumsum_ptr=dA_cumsum, - cu_chunk_seqlens=cu_chunk_seqlens, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, seq_idx_ptr=seq_idx, hdim=headdim, dstate=dstate, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index d8412f7101ba..60a1ec0110bd 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -163,9 +163,7 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, ) - varlen_states = states[:, last_chunk, ...].clone().squeeze(0) - - return varlen_states + return states[last_chunk] def mamba_chunk_scan_combined_varlen( diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 8311c7380bf5..c776f15e3a8e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -140,7 +140,7 @@ def _state_passing_fwd( chunk_offsets_ptr=chunk_offsets, chunk_meta_num=len(chunk_offsets) if chunk_offsets is not None else 0, - cu_chunk_seqlens=cu_chunk_seqlens, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, dim=dim, nchunks=nchunks, seqlen=seqlen if seq_idx is not None else 0, From c2a4f8e4bb68de523dbcde25bc45ddf8a855b349 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Sep 2025 16:49:03 -0400 Subject: [PATCH 33/37] Remove chunk_offsets and chunk_indices Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 4 - .../layers/mamba/ops/ssd_chunk_scan.py | 16 --- .../layers/mamba/ops/ssd_combined.py | 13 +-- .../layers/mamba/ops/ssd_state_passing.py | 6 -- vllm/v1/attention/backends/mamba2_attn.py | 102 +----------------- 5 files changed, 3 insertions(+), 138 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 31a4fe784ae7..3de89067b6a1 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ 
b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -502,8 +502,6 @@ def forward_cuda( prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p - chunk_indices_p = attn_metadata.chunk_indices_p - chunk_offsets_p = attn_metadata.chunk_offsets_p query_start_loc_p = attn_metadata.query_start_loc_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p last_chunk_p = attn_metadata.last_chunk_p @@ -636,8 +634,6 @@ def forward_cuda( z=None, dt_bias=self.dt_bias, seq_idx=seq_idx_p, - chunk_indices=chunk_indices_p, - chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, cu_chunk_seqlens=cu_chunk_seqlen_p, last_chunk=last_chunk_p, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index be0e8bcfb37f..a32568c53b04 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -120,9 +120,6 @@ def _chunk_scan_fwd_kernel( states_ptr, D_ptr, initstates_ptr, - chunk_indices_ptr, - chunk_offsets_ptr, - chunk_meta_num, cu_chunk_seqlens_ptr, # Matrix dimensions chunk_size: tl.constexpr, @@ -361,8 +358,6 @@ def _chunk_scan_fwd( seq_idx, D=None, z=None, - chunk_indices=None, - chunk_offsets=None, initial_states=None, ): assert seq_idx is not None, "this implementation requires seq_idx" @@ -382,14 +377,6 @@ def _chunk_scan_fwd( assert states.shape == (nchunks, nheads, headdim, dstate) assert seq_idx.shape == (seqlen, ) - if initial_states is not None: - # with initial states, we need to take care of how - # seq_idx crosses the boundaries - assert chunk_indices is not None and chunk_offsets is not None, \ - "chunk_indices and chunk_offsets should have been set" - else: - chunk_indices, chunk_offsets = None, None - grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton .cdiv(headdim, META['BLOCK_SIZE_N']), nchunks, nheads) @@ -413,9 +400,6 @@ def _chunk_scan_fwd( states_ptr=states, D_ptr=D, initstates_ptr=initial_states, - chunk_indices_ptr=chunk_indices, - chunk_offsets_ptr=chunk_offsets, - chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0, cu_chunk_seqlens_ptr=cu_chunk_seqlens, chunk_size=chunk_size, hdim=headdim, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 60a1ec0110bd..0b880bda3046 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -36,8 +36,6 @@ def _mamba_chunk_scan_combined_fwd(x, dt_bias=None, initial_states=None, seq_idx=None, - chunk_indices=None, - chunk_offsets=None, cu_seqlens=None, cu_chunk_seqlens=None, last_chunk=None, @@ -108,7 +106,7 @@ def _mamba_chunk_scan_combined_fwd(x, # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. + # ii) seq_idx to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. 
# - We will also make sure that the dA_cumsum is taken only from the start of the @@ -124,8 +122,7 @@ def _mamba_chunk_scan_combined_fwd(x, if initial_states is not None else None, # (batch, nheads, headdim*dstate) seq_idx=seq_idx, - out_dtype=state_dtype if state_dtype is not None else C.dtype, - chunk_offsets=chunk_offsets) + out_dtype=state_dtype if state_dtype is not None else C.dtype) states = rearrange(states, "... (p n) -> ... p n", n=dstate) # 4. Compute batched matrix multiply for C_j^T B_i terms @@ -158,8 +155,6 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx, D=D, z=z, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, initial_states=initial_states, ) @@ -182,8 +177,6 @@ def mamba_chunk_scan_combined_varlen( z=None, dt_bias=None, initial_states=None, - chunk_indices=None, - chunk_offsets=None, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, @@ -226,8 +219,6 @@ def mamba_chunk_scan_combined_varlen( dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, cu_seqlens=cu_seqlens, cu_chunk_seqlens=cu_chunk_seqlens, last_chunk=last_chunk, diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index c776f15e3a8e..69560ab695a1 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -30,8 +30,6 @@ def _state_passing_fwd_kernel( dA_cs_ptr, initstates_ptr, seq_idx_ptr, - chunk_offsets_ptr, - chunk_meta_num, cu_chunk_seqlens_ptr, # Matrix dimensions dim: tl.constexpr, @@ -111,7 +109,6 @@ def _state_passing_fwd( dA_cumsum, cu_chunk_seqlens, seq_idx, - chunk_offsets, initial_states=None, out_dtype=None, ): @@ -137,9 +134,6 @@ def _state_passing_fwd( dA_cs_ptr=dA_cumsum, initstates_ptr=initial_states, seq_idx_ptr=seq_idx, - chunk_offsets_ptr=chunk_offsets, - chunk_meta_num=len(chunk_offsets) - if chunk_offsets is not None else 0, cu_chunk_seqlens_ptr=cu_chunk_seqlens, dim=dim, nchunks=nchunks, diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index a07e522762df..4e6998194db4 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math from dataclasses import dataclass from typing import Optional @@ -18,91 +17,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec -def _query_start_loc_to_chunk_indices_offsets( - query_start_loc: torch.Tensor, chunk_size: int, - total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]: - """ - Args: - query_start_loc (torch.Tensor): 1D tensor of cumulative sequence - lengths, shape (num_seqs + 1,). - The first element should be 0. Each entry represents the starting - index of a sequence in the flattened token array. - chunk_size (int): The size of each physical mamba chunk - (number of tokens per chunk). - total_seqlens (int): The total number of tokens in the batch. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - chunk_indices (torch.Tensor): 1D tensor of indices - indicating the physical chunk for each logical chunk. - - chunk_offsets (torch.Tensor): 1D tensor of offsets - indicating the starting index of each logical chunk within - its physical chunk. 
- - This function computes the chunk indices and offsets for the given - query_start_loc and chunk_size. Both are tensors of integers with length N, - where N is the number of logical (pseudo) chunks. - A logical chunk is a sequence of tokens that are all part of the same - sequence and are all in the same physical mamba chunk. - In other words, a logical chunk changes every time we cross a sequence - boundary or a physical mamba chunk boundary. - Logical chunks are needed to handle batched requests with initial states - (see _state_passing_fwd and _chunk_scan_fwd). - The chunk_indices tensor contains the index of the physical chunk for each - logical chunk. - The chunk_offsets tensor contains the offset (AKA starting index) of the - logical chunk in the physical chunk. - - Example: - query_start_loc = [0, 5, 10] - chunk_size = 8 - total_seqlens = 10 - -> chunk_indices = [0, 0, 1] - -> chunk_offsets = [0, 5, 0] - - In this example, we have 2 sequences, each with 5 tokens. The physical - chunk size is 8 tokens. - We have three logical chunks: - - the first logical chunk starts at token 0 in the first physical chunk - and contains all 5 tokens from the first sequence - - the second logical chunk starts at token 5 in the first physical chunk - and contains first 3 tokens from the second sequence - - the third logical chunk starts at token 0 in the second physical chunk - and contains the remaining 2 tokens from the second sequence - """ - - cu_seqlens = query_start_loc[1:] # remove prepended 0 - - # outputs will have length expansion of chunks that do not divide - # chunk_size - N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size - > 0).sum() - chunk_indices = torch.arange(N, - dtype=torch.int, - device=query_start_loc.device) - chunk_offsets = torch.zeros((N, ), - dtype=torch.int, - device=query_start_loc.device) - - p = 0 # num of insertions - for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): - - # if does not divide chunk_size, then there is one chunk insertion - p += (s % chunk_size > 0) - - # get the dimensions - # - the + 1 for _e is to shift the boundary by one chunk - # - this shifting is not needed if chunk_size divides e - _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size - > 0) - - # adjust indices and offsets - chunk_indices[_s:_e] -= p - chunk_offsets[_s] = s % chunk_size - - return chunk_indices, chunk_offsets - - class Mamba2AttentionBackend(AttentionBackend): @staticmethod @@ -126,8 +40,6 @@ class Mamba2AttentionMetadata: # the batch has no prefill request. 
has_initial_states_p: Optional[torch.Tensor] seq_idx_p: Optional[torch.Tensor] - chunk_indices_p: Optional[torch.Tensor] - chunk_offsets_p: Optional[torch.Tensor] cu_chunk_seqlen_p: Optional[torch.Tensor] last_chunk_p: Optional[torch.Tensor] @@ -158,7 +70,6 @@ def build(self, seq_lens = common_attn_metadata.seq_lens seq_idx_p = None - chunk_indices_p, chunk_offsets_p = None, None # Need flags to indicate if there are initial states # currently we really only support the FlashAttention backend has_initial_states_p = None @@ -176,7 +87,7 @@ def build(self, common_attn_metadata, decode_threshold=self.reorder_batch_threshold)) - # Compute seq_idx, chunk_indices and chunk_offsets for prefill only + # Compute seq_idx for prefill only if num_prefills > 0: #[batch,] has_initial_states_cpu = ( @@ -243,15 +154,6 @@ def build(self, device=query_start_loc_p.device, dtype=torch.int32) - # We compute metadata for chunked prefill once at the top level - # model forward and reuse them in mamba layers. If not needed, - # they will be ignored inside mamba kernels. - if prep_initial_states: - chunk_indices_p, chunk_offsets_p = ( - _query_start_loc_to_chunk_indices_offsets( - query_start_loc_p, self.chunk_size, - num_prefill_tokens)) - nums_dict, batch_ptr, token_chunk_offset_ptr = \ compute_causal_conv1d_metadata(query_start_loc_p) @@ -274,8 +176,6 @@ def build(self, chunk_size=self.chunk_size, has_initial_states_p=has_initial_states_p, seq_idx_p=seq_idx_p, - chunk_indices_p=chunk_indices_p, - chunk_offsets_p=chunk_offsets_p, state_indices_tensor=state_indices_tensor, cu_chunk_seqlen_p=cu_chunk_seqlen_p, last_chunk_p=last_chunk_p, From 76ce99c0e0701f06f77e64f3569cfd773cd75fc9 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Sep 2025 17:03:57 -0400 Subject: [PATCH 34/37] clean up seq_idx Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/mamba/ops/ssd_bmm.py | 8 -------- .../layers/mamba/ops/ssd_chunk_scan.py | 10 +++++----- .../layers/mamba/ops/ssd_chunk_state.py | 8 -------- .../model_executor/layers/mamba/ops/ssd_combined.py | 13 +++---------- .../layers/mamba/ops/ssd_state_passing.py | 8 +++----- vllm/v1/attention/backends/mamba2_attn.py | 13 ++++++------- 6 files changed, 17 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 1e607e36bc57..41d3eba96e72 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -94,7 +94,6 @@ def _bmm_chunk_fwd_kernel( a_ptr, b_ptr, out_ptr, - seq_idx_ptr, cu_chunk_seqlens_ptr, # Matrix dimensions seqlen, @@ -111,7 +110,6 @@ def _bmm_chunk_fwd_kernel( stride_out_head: tl.int64, stride_outm: tl.int64, stride_outn: tl.constexpr, - stride_seq_idx_seqlen: tl.constexpr, # Meta-parameters IS_CAUSAL: tl.constexpr, dot_dtype: tl.constexpr, @@ -177,14 +175,12 @@ def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, - seq_idx, causal=False, output_dtype=None): """ Argument: a: (seqlen, ngroups, k) b: (seqlen, ngroups, k) - seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are guaranteed to be correct. 
Return: @@ -192,8 +188,6 @@ def _bmm_chunk_fwd(a, """ seqlen, ngroups, k = a.shape assert b.shape == a.shape - assert seq_idx is not None - assert seq_idx.shape == (seqlen, ) if a.stride(-1) != 1 and a.stride(0) != 1: a = a.contiguous() if b.stride(-1) != 1 and b.stride(0) != 1: @@ -217,7 +211,6 @@ def _bmm_chunk_fwd(a, a_ptr=a, b_ptr=b, out_ptr=out, - seq_idx_ptr=seq_idx, cu_chunk_seqlens_ptr=cu_chunk_seqlens, seqlen=seqlen, chunk_size=chunk_size, @@ -233,7 +226,6 @@ def _bmm_chunk_fwd(a, stride_out_head=out.stride(1), stride_outm=out.stride(-2), stride_outn=out.stride(-1), - stride_seq_idx_seqlen=seq_idx.stride(0), IS_CAUSAL=causal, dot_dtype=dot_dtype, ) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index a32568c53b04..e1e77e14f69d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -147,7 +147,7 @@ def _chunk_scan_fwd_kernel( stride_dA_cs_chunk: tl.int64, stride_dA_cs_head: tl.int64, stride_dA_cs_csize: tl.constexpr, - stride_seq_idx_seqlen: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, stride_C_seqlen: tl.int64, stride_C_head: tl.int64, stride_C_dstate: tl.constexpr, @@ -191,9 +191,9 @@ def _chunk_scan_fwd_kernel( # - logic in next block may override these if there is an active offset offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - seq_idx_ptr += chunk_seqlen_start * stride_seq_idx_seqlen + seq_idx_ptr += pid_c * stride_seq_idx_chunk seq_idx = tl.load(seq_idx_ptr) - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1) @@ -375,7 +375,7 @@ def _chunk_scan_fwd( assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == (nheads, nchunks, chunk_size) assert states.shape == (nchunks, nheads, headdim, dstate) - assert seq_idx.shape == (seqlen, ) + assert seq_idx.shape == (nchunks, ) grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton .cdiv(headdim, META['BLOCK_SIZE_N']), nchunks, nheads) @@ -425,7 +425,7 @@ def _chunk_scan_fwd( stride_dA_cs_chunk=dA_cumsum.stride(1), stride_dA_cs_head=dA_cumsum.stride(0), stride_dA_cs_csize=dA_cumsum.stride(2), - stride_seq_idx_seqlen=seq_idx.stride(0), + stride_seq_idx_chunk=seq_idx.stride(0), stride_C_seqlen=C.stride(0), stride_C_head=C.stride(1), stride_C_dstate=C.stride(2), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 71b664f0bfee..3a3e0f293459 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -192,7 +192,6 @@ def _chunk_state_fwd_kernel( dt_ptr, dA_cumsum_ptr, cu_chunk_seqlens_ptr, - seq_idx_ptr, # Matrix dimensions hdim: tl.constexpr, dstate: tl.constexpr, @@ -216,7 +215,6 @@ def _chunk_state_fwd_kernel( stride_dA_cs_head: tl.int64, stride_dA_cs_chunk: tl.int64, stride_dA_cs_csize: tl.constexpr, - stride_seq_idx_seqlen: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -585,7 +583,6 @@ def _chunk_state_fwd(B, dt, dA_cumsum, cu_chunk_seqlens, - seq_idx=None, states=None, states_in_fp32=True): seqlen, nheads, headdim = x.shape @@ -596,9 +593,6 @@ def _chunk_state_fwd(B, assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape - assert seq_idx is not None - assert seq_idx.shape == (seqlen, ) - if states is not 
None: assert states.shape == (nchunks, nheads, headdim, dstate) else: @@ -617,7 +611,6 @@ def _chunk_state_fwd(B, dt_ptr=dt, dA_cumsum_ptr=dA_cumsum, cu_chunk_seqlens_ptr=cu_chunk_seqlens, - seq_idx_ptr=seq_idx, hdim=headdim, dstate=dstate, chunk_size=chunk_size, @@ -639,7 +632,6 @@ def _chunk_state_fwd(B, stride_dA_cs_head=dA_cumsum.stride(0), stride_dA_cs_chunk=dA_cumsum.stride(1), stride_dA_cs_csize=dA_cumsum.stride(2), - stride_seq_idx_seqlen=seq_idx.stride(0), ) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 0b880bda3046..e4353757480e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -55,7 +55,7 @@ def _mamba_chunk_scan_combined_fwd(x, if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads, ) if seq_idx is not None: - assert seq_idx.shape == (seqlen, ) + assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1, ) if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: @@ -100,20 +100,14 @@ def _mamba_chunk_scan_combined_fwd(x, dt, dA_cumsum, cu_chunk_seqlens, - seq_idx=seq_idx, states_in_fp32=True) # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - # - for handling chunked prefill, this requires i) initial_states + # - for handling chunked prefill, this requires i) initial_states and # ii) seq_idx to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. - # - We will also make sure that the dA_cumsum is taken only from the start of the - # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) - # - this will ensure that states will be updated with the rightmost flushed seq_idx - # of the previous chunk. This implies that the first chunk of states is either 0 - # or equal to init_states of the first example. states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, # (nheads, nchunks, chunk_size) @@ -130,7 +124,6 @@ def _mamba_chunk_scan_combined_fwd(x, B, chunk_size, cu_chunk_seqlens, - seq_idx=seq_idx, output_dtype=torch.float32) # 5. 
Scan and compute the diagonal blocks, taking into @@ -189,7 +182,7 @@ def mamba_chunk_scan_combined_varlen( B: (seqlen, ngroups, dstate) C: (seqlen, ngroups, dstate) chunk_size: int - seq_idx: (seqlen) + seq_idx: (nchunks) cu_seqlens: (batch + 1) out: (seqlen, nheads, headdim) preallocated output tensor D: (nheads, headdim) or (nheads,) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 69560ab695a1..f09af262cfc2 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -49,7 +49,7 @@ def _state_passing_fwd_kernel( stride_initstates_batch: tl.int64, stride_initstates_head: tl.int64, stride_initstates_dim: tl.constexpr, - stride_seq_idx_seqlen: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, # Meta-parameters HAS_INITSTATES: tl.constexpr, BLOCK_SIZE: tl.constexpr, @@ -78,12 +78,10 @@ def _state_passing_fwd_kernel( prev_seq_idx = 0 for c in range(nchunks): - chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + c) new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - seq_idx = tl.load(seq_idx_ptr + - chunk_seqlen_start * stride_seq_idx_seqlen) + seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk) # we have started a new sequence if prev_seq_idx != seq_idx: if HAS_INITSTATES: @@ -151,7 +149,7 @@ def _state_passing_fwd( stride_initstates_batch=initial_states_strides[0], stride_initstates_head=initial_states_strides[1], stride_initstates_dim=initial_states_strides[2], - stride_seq_idx_seqlen=seq_idx.stride(0), + stride_seq_idx_chunk=seq_idx.stride(0), HAS_INITSTATES=initial_states is not None, ) return out diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 4e6998194db4..9e73ae7e2473 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -100,13 +100,6 @@ def build(self, query_start_loc_p = common_attn_metadata.query_start_loc[ -num_prefills - 1:] - num_decode_tokens - seq_idx_p = torch.repeat_interleave(torch.arange( - num_prefills, - dtype=torch.int32, - device=query_start_loc_p.device), - query_start_loc_p.diff(), - output_size=num_prefill_tokens) - num_computed_tokens_p = \ common_attn_metadata.num_computed_tokens_cpu[ num_reqs - num_prefills:num_reqs] @@ -115,6 +108,7 @@ def build(self, # TODO (tdoublep): Optimize the code cu_chunk_seqlen = [] + seq_idx = [] last_chunk = [] seqlen_pos = 0 for req_idx in range(num_prefills): @@ -125,6 +119,7 @@ def build(self, # if computed tokens are not chunk-aligned, use the first # chunk to finish it off if this_num_computed % self.chunk_size != 0: + seq_idx.append(req_idx) cu_chunk_seqlen.append(seqlen_pos) # how many tokens to finish the chunk? 
chunk_len = cdiv(this_num_computed, self.chunk_size @@ -136,6 +131,7 @@ def build(self, n_chunks = cdiv(this_new_tokens, self.chunk_size) for chunk in range(n_chunks): + seq_idx.append(req_idx) cu_chunk_seqlen.append(seqlen_pos) chunk_len = min(self.chunk_size, this_new_tokens) seqlen_pos += chunk_len @@ -146,6 +142,9 @@ def build(self, cu_chunk_seqlen.append(seqlen_pos) + seq_idx_p = torch.as_tensor(seq_idx, + device=query_start_loc_p.device, + dtype=torch.int32) cu_chunk_seqlen_p = torch.as_tensor( cu_chunk_seqlen, device=query_start_loc_p.device, From 51b756b33c11e6af91c67edeae18f9676d8f11bc Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 29 Sep 2025 14:24:39 -0400 Subject: [PATCH 35/37] Fix plamo2 Signed-off-by: Thomas Parnell --- vllm/model_executor/models/plamo2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 03265b13de50..f31c5952e09b 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -260,9 +260,9 @@ def forward_cuda( prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p - chunk_indices_p = attn_metadata.chunk_indices_p - chunk_offsets_p = attn_metadata.chunk_offsets_p query_start_loc_p = attn_metadata.query_start_loc_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_p = attn_metadata.last_chunk_p # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -368,9 +368,9 @@ def forward_cuda( self.num_heads // self.tp_size, self.head_dim), dt_bias=self.dt_bias, seq_idx=seq_idx_p, - chunk_indices=chunk_indices_p, - chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, + cu_chunk_seqlens=cu_chunk_seqlen_p, + last_chunk=last_chunk_p, initial_states=initial_states, dt_softplus=True, dt_limit=(0.0, float("inf")), From 37ffa9250fcce1d3273d5cfd7c87e595a6670693 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 29 Sep 2025 14:43:54 -0400 Subject: [PATCH 36/37] Review comments Signed-off-by: Thomas Parnell --- .../layers/mamba/mamba_mixer2.py | 4 +- .../layers/mamba/ops/ssd_bmm.py | 2 + .../layers/mamba/ops/ssd_combined.py | 14 ++++--- vllm/v1/attention/backends/mamba2_attn.py | 42 +++++++++++++------ 4 files changed, 42 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 3de89067b6a1..bfb0666d361f 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -504,7 +504,7 @@ def forward_cuda( seq_idx_p = attn_metadata.seq_idx_p query_start_loc_p = attn_metadata.query_start_loc_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p - last_chunk_p = attn_metadata.last_chunk_p + last_chunk_indices_p = attn_metadata.last_chunk_indices_p # 1. 
Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -636,7 +636,7 @@ def forward_cuda( seq_idx=seq_idx_p, cu_seqlens=query_start_loc_p, cu_chunk_seqlens=cu_chunk_seqlen_p, - last_chunk=last_chunk_p, + last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, dt_softplus=True, dt_limit=(0.0, float("inf")), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 41d3eba96e72..15a72fc61261 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -181,6 +181,8 @@ def _bmm_chunk_fwd(a, Argument: a: (seqlen, ngroups, k) b: (seqlen, ngroups, k) + chunk_size: int + cu_chunk_seq_lens: (nchunks+1,) causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are guaranteed to be correct. Return: diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index e4353757480e..f3eb61d5840e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -38,7 +38,7 @@ def _mamba_chunk_scan_combined_fwd(x, seq_idx=None, cu_seqlens=None, cu_chunk_seqlens=None, - last_chunk=None, + last_chunk_indices=None, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None): @@ -151,7 +151,7 @@ def _mamba_chunk_scan_combined_fwd(x, initial_states=initial_states, ) - return states[last_chunk] + return states[last_chunk_indices] def mamba_chunk_scan_combined_varlen( @@ -163,7 +163,7 @@ def mamba_chunk_scan_combined_varlen( chunk_size, cu_seqlens, cu_chunk_seqlens, - last_chunk, + last_chunk_indices, seq_idx, out, D=None, @@ -182,8 +182,10 @@ def mamba_chunk_scan_combined_varlen( B: (seqlen, ngroups, dstate) C: (seqlen, ngroups, dstate) chunk_size: int - seq_idx: (nchunks) - cu_seqlens: (batch + 1) + cu_seqlens: (batch + 1,) + cu_chunk_seqlens: (nchunks + 1,) + last_chunk_indices: (batch,) + seq_idx: (nchunks,) out: (seqlen, nheads, headdim) preallocated output tensor D: (nheads, headdim) or (nheads,) z: (seqlen, nheads, headdim) @@ -214,7 +216,7 @@ def mamba_chunk_scan_combined_varlen( seq_idx=seq_idx, cu_seqlens=cu_seqlens, cu_chunk_seqlens=cu_chunk_seqlens, - last_chunk=last_chunk, + last_chunk_indices=last_chunk_indices, dt_softplus=dt_softplus, dt_limit=dt_limit, state_dtype=state_dtype) diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 9e73ae7e2473..e4f16f37a430 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -40,8 +40,16 @@ class Mamba2AttentionMetadata: # the batch has no prefill request. has_initial_states_p: Optional[torch.Tensor] seq_idx_p: Optional[torch.Tensor] + + # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for + # each chunk, its offests into the varlen sequence dimension. It is defined + # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to + # cu_chunk_seqlen_p[i+1]. cu_chunk_seqlen_p: Optional[torch.Tensor] - last_chunk_p: Optional[torch.Tensor] + + # last_chunk_indices_p is a tensor of shape (batch,) that contains the + # index of the last chunk for every sequence in the (prefill) batch. 
+ last_chunk_indices_p: Optional[torch.Tensor] state_indices_tensor: torch.Tensor # shape: [batch,] @@ -66,16 +74,16 @@ def build(self, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False) -> Mamba2AttentionMetadata: num_reqs = common_attn_metadata.num_reqs - query_start_loc_p = None seq_lens = common_attn_metadata.seq_lens + query_start_loc_p = None seq_idx_p = None + cu_chunk_seqlen_p = None + last_chunk_indices_p = None + # Need flags to indicate if there are initial states - # currently we really only support the FlashAttention backend has_initial_states_p = None prep_initial_states = False - cu_chunk_seqlen_p = None - last_chunk_p = None # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None @@ -106,10 +114,19 @@ def build(self, query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[ -num_prefills - 1:] - num_decode_tokens - # TODO (tdoublep): Optimize the code + # The code below carefully constructs the chunks such that: + # 1. Chunks contain tokens from a *single* sequence only. + # 2. For every sequence, we are guaranteed that we can + # retrieve the mamba state *every* chunk_size tokens. + # Constraint (1) dramatically simplifies the mamba2 kernels. + # Constraint (2) dramatically simplifies the implementation + # of prefix caching for mamba2 (wip). We need to take care + # of the interaction with chunked prefill in order to + # satisfy constraint (2). + # TODO (tdoublep): This code could probably be optimized. cu_chunk_seqlen = [] seq_idx = [] - last_chunk = [] + last_chunk_indices = [] seqlen_pos = 0 for req_idx in range(num_prefills): this_num_computed = num_computed_tokens_p[req_idx].item() @@ -138,7 +155,7 @@ def build(self, this_new_tokens -= chunk_len assert this_new_tokens == 0 - last_chunk.append(len(cu_chunk_seqlen) - 1) + last_chunk_indices.append(len(cu_chunk_seqlen) - 1) cu_chunk_seqlen.append(seqlen_pos) @@ -149,9 +166,10 @@ def build(self, cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32) - last_chunk_p = torch.as_tensor(last_chunk, - device=query_start_loc_p.device, - dtype=torch.int32) + last_chunk_indices_p = torch.as_tensor( + last_chunk_indices, + device=query_start_loc_p.device, + dtype=torch.int32) nums_dict, batch_ptr, token_chunk_offset_ptr = \ compute_causal_conv1d_metadata(query_start_loc_p) @@ -177,7 +195,7 @@ def build(self, seq_idx_p=seq_idx_p, state_indices_tensor=state_indices_tensor, cu_chunk_seqlen_p=cu_chunk_seqlen_p, - last_chunk_p=last_chunk_p, + last_chunk_indices_p=last_chunk_indices_p, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, From 29b42ccb770e2d8cc0a0ad1958a4d407e84c7475 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 29 Sep 2025 15:31:25 -0400 Subject: [PATCH 37/37] Fix plamo2 again Signed-off-by: Thomas Parnell --- vllm/model_executor/models/plamo2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index f31c5952e09b..8234d40e94ab 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -262,7 +262,7 @@ def forward_cuda( seq_idx_p = attn_metadata.seq_idx_p query_start_loc_p = attn_metadata.query_start_loc_p cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p - last_chunk_p = attn_metadata.last_chunk_p + last_chunk_indices_p = attn_metadata.last_chunk_indices_p # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -370,7 +370,7 @@ def forward_cuda( seq_idx=seq_idx_p, cu_seqlens=query_start_loc_p, cu_chunk_seqlens=cu_chunk_seqlen_p, - last_chunk=last_chunk_p, + last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, dt_softplus=True, dt_limit=(0.0, float("inf")),
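
Reviewer note on the chunk metadata introduced above: the loop added to Mamba2AttentionMetadataBuilder.build() in PATCH 36/37 can be prototyped in isolation, which makes constraints (1) and (2) easy to check. The sketch below is illustrative only: build_chunk_metadata and the standalone cdiv helper are names chosen here (not vLLM APIs), plain Python lists stand in for the torch tensors, and the length of the partial first chunk is completed by assumption, since the exact expression is truncated in the hunk shown earlier.

    def cdiv(a: int, b: int) -> int:
        # ceiling division for non-negative ints
        return -(-a // b)

    def build_chunk_metadata(num_computed_tokens, num_new_tokens, chunk_size):
        # Mirrors the prefill loop in Mamba2AttentionMetadataBuilder.build():
        # chunks never cross a sequence boundary, and each sequence is
        # re-aligned to chunk_size boundaries so its mamba state can be
        # recovered every chunk_size tokens.
        cu_chunk_seqlen = []      # (nchunks+1,) offsets into the flattened prefill tokens
        seq_idx = []              # (nchunks,) owning request index per chunk
        last_chunk_indices = []   # (batch,) index of the last chunk of each request
        seqlen_pos = 0
        for req_idx, (computed, new_tokens) in enumerate(
                zip(num_computed_tokens, num_new_tokens)):
            if computed % chunk_size != 0:
                # First chunk only tops up the partially filled physical chunk.
                # Assumed completion of the truncated expression in the patch:
                remainder = cdiv(computed, chunk_size) * chunk_size - computed
                chunk_len = min(remainder, new_tokens)
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                seqlen_pos += chunk_len
                new_tokens -= chunk_len
            # Remaining tokens are split into chunk_size-sized pieces.
            for _ in range(cdiv(new_tokens, chunk_size)):
                seq_idx.append(req_idx)
                cu_chunk_seqlen.append(seqlen_pos)
                chunk_len = min(chunk_size, new_tokens)
                seqlen_pos += chunk_len
                new_tokens -= chunk_len
            assert new_tokens == 0
            last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
        cu_chunk_seqlen.append(seqlen_pos)
        return cu_chunk_seqlen, seq_idx, last_chunk_indices

    # Example: chunk_size=8, request 0 resumes with 5 computed tokens and adds
    # 10 new ones, request 1 is fresh with 4 new tokens.
    assert build_chunk_metadata([5, 0], [10, 4], 8) == (
        [0, 3, 10, 14],   # cu_chunk_seqlen
        [0, 0, 1],        # seq_idx (one entry per chunk)
        [1, 2],           # last_chunk_indices
    )

Because every chunk belongs to exactly one request, the kernels no longer need the chunk_indices/chunk_offsets logical-chunk machinery removed earlier in this series, seq_idx shrinks from per-token to per-chunk, and the final state of each request is gathered simply as states[last_chunk_indices].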