diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 6dd09fad7a90..bfb0666d361f 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -502,9 +502,9 @@ def forward_cuda( prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p - chunk_indices_p = attn_metadata.chunk_indices_p - chunk_offsets_p = attn_metadata.chunk_offsets_p query_start_loc_p = attn_metadata.query_start_loc_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_indices_p = attn_metadata.last_chunk_indices_p # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -634,9 +634,9 @@ def forward_cuda( z=None, dt_bias=self.dt_bias, seq_idx=seq_idx_p, - chunk_indices=chunk_indices_p, - chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, + cu_chunk_seqlens=cu_chunk_seqlen_p, + last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, dt_softplus=True, dt_limit=(0.0, float("inf")), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py index 601b71ab2a51..15a72fc61261 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_bmm.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_bmm.py @@ -6,8 +6,6 @@ # ruff: noqa: E501,SIM102 -import math - import torch from vllm.triton_utils import tl, triton @@ -96,7 +94,7 @@ def _bmm_chunk_fwd_kernel( a_ptr, b_ptr, out_ptr, - seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions seqlen, chunk_size: tl.constexpr, @@ -112,7 +110,6 @@ def _bmm_chunk_fwd_kernel( stride_out_head: tl.int64, stride_outm: tl.int64, stride_outn: tl.constexpr, - stride_seq_idx_seqlen: tl.constexpr, # Meta-parameters IS_CAUSAL: tl.constexpr, dot_dtype: tl.constexpr, @@ -129,10 +126,12 @@ def _bmm_chunk_fwd_kernel( if IS_CAUSAL: if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: return - a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head - b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head - seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head + b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -141,7 +140,7 @@ def _bmm_chunk_fwd_kernel( offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -162,16 +161,6 @@ def _bmm_chunk_fwd_kernel( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - # Zero out the results that are not from the same request - # in the varlen batch - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) - seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, - mask=offs_n < chunk_size_limit, - other=-2) - acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) - out = acc.to(out_ptr.dtype.element_ty) out_ptr += 
pid_c * stride_out_chunk + pid_h * stride_out_head out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + @@ -182,12 +171,18 @@ def _bmm_chunk_fwd_kernel( (offs_n[None, :] < chunk_size)) -def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None): +def _bmm_chunk_fwd(a, + b, + chunk_size, + cu_chunk_seqlens, + causal=False, + output_dtype=None): """ Argument: a: (seqlen, ngroups, k) b: (seqlen, ngroups, k) - seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + chunk_size: int + cu_chunk_seq_lens: (nchunks+1,) causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are guaranteed to be correct. Return: @@ -195,14 +190,12 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None): """ seqlen, ngroups, k = a.shape assert b.shape == a.shape - assert seq_idx is not None - assert seq_idx.shape == (seqlen, ) if a.stride(-1) != 1 and a.stride(0) != 1: a = a.contiguous() if b.stride(-1) != 1 and b.stride(0) != 1: b = b.contiguous() - nchunks = math.ceil(seqlen / chunk_size) + nchunks = len(cu_chunk_seqlens) - 1 # Allocates output. out_dtype = a.dtype if output_dtype is None else output_dtype out = torch.empty((nchunks, ngroups, chunk_size, chunk_size), @@ -220,7 +213,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None): a_ptr=a, b_ptr=b, out_ptr=out, - seq_idx_ptr=seq_idx, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, seqlen=seqlen, chunk_size=chunk_size, K=k, @@ -235,7 +228,6 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None): stride_out_head=out.stride(1), stride_outm=out.stride(-2), stride_outn=out.stride(-1), - stride_seq_idx_seqlen=seq_idx.stride(0), IS_CAUSAL=causal, dot_dtype=dot_dtype, ) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py index add72617fcea..e1e77e14f69d 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py @@ -120,9 +120,7 @@ def _chunk_scan_fwd_kernel( states_ptr, D_ptr, initstates_ptr, - chunk_indices_ptr, - chunk_offsets_ptr, - chunk_meta_num, + cu_chunk_seqlens_ptr, # Matrix dimensions chunk_size: tl.constexpr, hdim: tl.constexpr, @@ -149,7 +147,7 @@ def _chunk_scan_fwd_kernel( stride_dA_cs_chunk: tl.int64, stride_dA_cs_head: tl.int64, stride_dA_cs_csize: tl.constexpr, - stride_seq_idx_seqlen: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, stride_C_seqlen: tl.int64, stride_C_head: tl.int64, stride_C_dstate: tl.constexpr, @@ -175,170 +173,107 @@ def _chunk_scan_fwd_kernel( HAS_INITSTATES: tl.constexpr, ): pid_c = tl.program_id(axis=1).to(tl.int64) - if not HAS_INITSTATES: - c_idx = pid_c - c_off = 0 - else: - c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) - c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) - pid_h = tl.program_id(axis=2) num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - cb_ptr += c_idx * stride_cb_chunk + (pid_h // + cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head - x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head - dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head - dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - C_ptr += c_idx * chunk_size * stride_C_seqlen + ( + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end 
= tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += chunk_seqlen_start * stride_C_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_C_head # M-block offsets and prev states # - logic in next block may override these if there is an active offset - offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) - prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head - prev_states_hdim = stride_states_hdim - prev_states_dstate = stride_states_dstate - - chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) - - seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen - # - we only need seq_idx_prev to be aligned to chunk boundary - seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, - mask=c_idx >= 1, - other=0) - - if HAS_INITSTATES: - # if there are init states, we only need seq_idx_m to point - # what is the current seq_idx - - # get current seq idx - if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: - seq_idx_m = tl.load( - seq_idx_ptr + - (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, ) - - # - recall that in ssd_state_passing, for the case c_off == 0 - # i.e., the very first sequence, we made states_ptr hold its initial state - # so this edge case is taken care of - if ((c_off == 0) and (seq_idx_prev != seq_idx_m - ) # if a seq is changed exactly on boundary - or (c_off > 0) # implies a new example (pseudo chunk) - ): + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + seq_idx_ptr += pid_c * stride_seq_idx_chunk + seq_idx = tl.load(seq_idx_ptr) + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk, + mask=pid_c >= 1, + other=-1) + + if HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head + prev_states_hdim = stride_init_states_hdim + prev_states_dstate = stride_init_states_dstate + else: + prev_states_ptr = states_ptr + ( + pid_c - 1) * stride_states_chunk + pid_h * stride_states_head + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate - # - replace prev_states_ptr with init_states - prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head - prev_states_hdim = stride_init_states_hdim # override strides - prev_states_dstate = stride_init_states_dstate + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) - # - handle chunk state limit - if HAS_INITSTATES: - # have to split this if otherwise compilation will have problems - dA_cs_m_boundary = 0.0 - - # get the c_idx for the next (logica) chunk - c_idx_n = tl.load( - chunk_indices_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=-1 # to trigger different chunk - ) - - # - there are things to consider - # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct - # contribution of past states - # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to - # encroach into the next sequence, where c_off_n is the offset of the next - # (logical) chunk. 
- # An equivalent check for B is c_idx == c_idx_n, where there is repetition in - # (logical) chunk indices. - - if (c_idx == c_idx_n) or c_off > 0: - - # get the next offset - c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1), - mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, - other=chunk_size) - - # in this case, adjust down the chunk_size_limit - if c_idx == c_idx_n: - chunk_size_limit = min(c_off_n, chunk_size_limit) - - # get the cs at the offset boundary - # - c_off == 0 is a passthrough - # - We need dA_cs at the boundary, defined by c_off - no need - # to increase pointer by pid_m (it is a constant offset, - # i.e. the same for all blocks) - dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, - mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), - other=0.0).to(tl.float32) - else: - # - handle seq idx when HAS_INITSTATES==False - seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, - mask=offs_m < chunk_size_limit, - other=-1) - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - # Without the if (pid_c > -1), with Triton 2.1.0, I get - # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. - # With Triton 2.2.0, this works - if IS_TRITON_22 or c_idx > -1: - # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 - offs_k_dstate = tl.arange( - 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) - C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + - offs_k_dstate[None, :] * stride_C_dstate) + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - prev_states_ptrs = prev_states_ptr + ( - offs_n[None, :] * prev_states_hdim + - offs_k_dstate[:, None] * prev_states_dstate) + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + + offs_k_dstate[None, :] * stride_C_dstate) + + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate), + other=0.0) - if not HAS_INITSTATES: - # - this is for continuous batching where there is no init states - scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + # if no init states AND starting a new sequence, we need zeros + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), + dtype=C_ptr.dtype.element_ty) else: - # - if there is initstates, we will rely on prev_states, no zeroing - # required. 
- scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) - - if BLOCK_SIZE_DSTATE <= 128: - C = tl.load(C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate), - other=0.0) - + # otherwise read the previous state + prev_states_ptrs = prev_states_ptr \ + + offs_n[None, :] * prev_states_hdim \ + + offs_k_dstate[:, None] * prev_states_dstate prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc = tl.dot(C, prev_states) * scale_m[:, None] - else: - for k in range(0, dstate, BLOCK_SIZE_K): - C = tl.load(C_ptrs, - mask=(offs_m[:, None] < chunk_size_limit) & - (offs_k_dstate[None, :] < dstate - k), - other=0.0) - # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + + acc = tl.dot(C, prev_states) * scale_m[:, None] + + else: + prev_states_ptrs = prev_states_ptr \ + + offs_n[None, :] * prev_states_hdim \ + + offs_k_dstate[:, None] * prev_states_dstate + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) & + (offs_k_dstate[None, :] < dstate - k), + other=0.0) + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), + dtype=C_ptr.dtype.element_ty) + else: prev_states = tl.load( prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) prev_states = prev_states.to(C_ptr.dtype.element_ty) - acc += tl.dot(C, prev_states) - C_ptrs += BLOCK_SIZE_K - prev_states_ptrs += BLOCK_SIZE_K - acc *= scale_m[:, None] + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] - offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + offs_k = tl.arange(0, BLOCK_SIZE_K) cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + @@ -375,7 +310,7 @@ def _chunk_scan_fwd_kernel( dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) if HAS_D: @@ -393,7 +328,7 @@ def _chunk_scan_fwd_kernel( acc += x_residual * D if HAS_Z: - z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) z = tl.load(z_ptrs, @@ -402,7 +337,7 @@ def _chunk_scan_fwd_kernel( other=0.0).to(tl.float32) acc *= z * tl.sigmoid(z) - out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) tl.store(out_ptrs, @@ -418,12 +353,11 @@ def _chunk_scan_fwd( dA_cumsum, C, states, + cu_chunk_seqlens, out, seq_idx, D=None, z=None, - chunk_indices=None, - chunk_offsets=None, initial_states=None, ): assert seq_idx is not None, "this implementation requires seq_idx" @@ -441,20 +375,10 @@ def _chunk_scan_fwd( assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == (nheads, nchunks, chunk_size) assert states.shape == (nchunks, nheads, headdim, dstate) - assert seq_idx.shape == (seqlen, ) - - if 
initial_states is not None: - # with initial states, we need to take care of how - # seq_idx crosses the boundaries - assert chunk_indices is not None and chunk_offsets is not None, \ - "chunk_indices and chunk_offsets should have been set" - else: - chunk_indices, chunk_offsets = None, None + assert seq_idx.shape == (nchunks, ) - grid = lambda META: ( - triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv( - headdim, META['BLOCK_SIZE_N']), nchunks - if chunk_offsets is None else len(chunk_offsets), nheads) + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton + .cdiv(headdim, META['BLOCK_SIZE_N']), nchunks, nheads) z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)) @@ -476,9 +400,7 @@ def _chunk_scan_fwd( states_ptr=states, D_ptr=D, initstates_ptr=initial_states, - chunk_indices_ptr=chunk_indices, - chunk_offsets_ptr=chunk_offsets, - chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, chunk_size=chunk_size, hdim=headdim, dstate=dstate, @@ -503,7 +425,7 @@ def _chunk_scan_fwd( stride_dA_cs_chunk=dA_cumsum.stride(1), stride_dA_cs_head=dA_cumsum.stride(0), stride_dA_cs_csize=dA_cumsum.stride(2), - stride_seq_idx_seqlen=seq_idx.stride(0), + stride_seq_idx_chunk=seq_idx.stride(0), stride_C_seqlen=C.stride(0), stride_C_head=C.stride(1), stride_C_dstate=C.stride(2), diff --git a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py index 8ee41f2cbc1b..3a3e0f293459 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py @@ -6,8 +6,6 @@ # ruff: noqa: E501 -import math - import torch from vllm.triton_utils import tl, triton @@ -34,6 +32,7 @@ def _chunk_cumsum_fwd_kernel( dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr, + cu_chunk_seqlens_ptr, # Matrix dimension seqlen, nheads: tl.constexpr, @@ -61,7 +60,11 @@ def _chunk_cumsum_fwd_kernel( # https://github.com/triton-lang/triton/issues/1058 pid_c = tl.program_id(axis=0).to(tl.int64) pid_h = tl.program_id(axis=1) - dt_ptr += pid_c * chunk_size * stride_dt_seqlen + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + dt_ptr += chunk_seqlen_start * stride_dt_seqlen dt_out_ptr += pid_c * stride_dt_out_chunk dA_cumsum_ptr += pid_c * stride_dA_cs_chunk @@ -74,7 +77,7 @@ def _chunk_cumsum_fwd_kernel( offs_c[None, :] * stride_dt_out_csize) dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize) - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & @@ -188,7 +191,7 @@ def _chunk_state_fwd_kernel( states_ptr, dt_ptr, dA_cumsum_ptr, - seq_idx_ptr, + cu_chunk_seqlens_ptr, # Matrix dimensions hdim: tl.constexpr, dstate: tl.constexpr, @@ -212,7 +215,6 @@ def _chunk_state_fwd_kernel( stride_dA_cs_head: tl.int64, stride_dA_cs_chunk: tl.int64, stride_dA_cs_csize: tl.constexpr, - stride_seq_idx_seqlen: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -223,14 +225,14 @@ def _chunk_state_fwd_kernel( num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) pid_m = tl.program_id(axis=0) // num_pid_n pid_n = tl.program_id(axis=0) % num_pid_n - b_ptr += pid_c * chunk_size * stride_b_seqlen + ( + chunk_seqlen_start = 
tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + b_ptr += chunk_seqlen_start * stride_b_seqlen + ( pid_h // nheads_ngroups_ratio) * stride_b_head - x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head - seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -243,10 +245,7 @@ def _chunk_state_fwd_kernel( (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32) dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize - seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen - chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) - seq_idx_last = tl.load(seq_idx_ptr + - (chunk_size_limit - 1) * stride_seq_idx_seqlen) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, chunk_size_limit, BLOCK_SIZE_K): @@ -261,15 +260,9 @@ def _chunk_state_fwd_kernel( dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - - seq_idx_k = tl.load(seq_idx_ptrs, - mask=offs_k < chunk_size_limit - k, - other=-1) dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32) - - scale = tl.where(seq_idx_k == seq_idx_last, - tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0) + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k b *= scale[:, None] b = b.to(x_ptr.dtype.element_ty) acc += tl.dot(x, b) @@ -278,7 +271,6 @@ def _chunk_state_fwd_kernel( b_ptrs += BLOCK_SIZE_K * stride_b_seqlen dt_ptrs += BLOCK_SIZE_K * stride_dt_csize dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize - seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen states = acc.to(states_ptr.dtype.element_ty) @@ -534,6 +526,7 @@ def _chunk_state_varlen_kernel( def _chunk_cumsum_fwd(dt, A, chunk_size, + cu_chunk_seqlens, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))): @@ -541,7 +534,7 @@ def _chunk_cumsum_fwd(dt, assert A.shape == (nheads, ) if dt_bias is not None: assert dt_bias.shape == (nheads, ) - nchunks = math.ceil(seqlen / chunk_size) + nchunks = cu_chunk_seqlens.shape[0] - 1 dt_out = torch.empty(nheads, nchunks, chunk_size, @@ -561,6 +554,7 @@ def _chunk_cumsum_fwd(dt, dt_bias_ptr=dt_bias, dt_out_ptr=dt_out, dA_cumsum_ptr=dA_cumsum, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, seqlen=seqlen, nheads=nheads, chunk_size=chunk_size, @@ -588,7 +582,7 @@ def _chunk_state_fwd(B, x, dt, dA_cumsum, - seq_idx=None, + cu_chunk_seqlens, states=None, states_in_fp32=True): seqlen, nheads, headdim = x.shape @@ -599,9 +593,6 @@ def _chunk_state_fwd(B, assert dt.shape == (nheads, nchunks, chunk_size) assert dA_cumsum.shape == dt.shape - assert seq_idx is not None - assert seq_idx.shape == (seqlen, ) - if states is not None: assert states.shape == (nchunks, nheads, headdim, dstate) else: @@ -619,7 +610,7 @@ def _chunk_state_fwd(B, states_ptr=states, dt_ptr=dt, dA_cumsum_ptr=dA_cumsum, - seq_idx_ptr=seq_idx, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, hdim=headdim, dstate=dstate, chunk_size=chunk_size, @@ -641,7 +632,6 @@ def _chunk_state_fwd(B, stride_dA_cs_head=dA_cumsum.stride(0), stride_dA_cs_chunk=dA_cumsum.stride(1), stride_dA_cs_csize=dA_cumsum.stride(2), - 
stride_seq_idx_seqlen=seq_idx.stride(0), ) return states diff --git a/vllm/model_executor/layers/mamba/ops/ssd_combined.py b/vllm/model_executor/layers/mamba/ops/ssd_combined.py index 37d6c2870812..f3eb61d5840e 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_combined.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_combined.py @@ -14,8 +14,7 @@ from .ssd_bmm import _bmm_chunk_fwd from .ssd_chunk_scan import _chunk_scan_fwd -from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd, - chunk_state_varlen) +from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd from .ssd_state_passing import _state_passing_fwd TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') @@ -37,9 +36,9 @@ def _mamba_chunk_scan_combined_fwd(x, dt_bias=None, initial_states=None, seq_idx=None, - chunk_indices=None, - chunk_offsets=None, cu_seqlens=None, + cu_chunk_seqlens=None, + last_chunk_indices=None, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None): @@ -56,7 +55,7 @@ def _mamba_chunk_scan_combined_fwd(x, if D is not None: assert D.shape == (nheads, headdim) or D.shape == (nheads, ) if seq_idx is not None: - assert seq_idx.shape == (seqlen, ) + assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1, ) if B.stride(-1) != 1: B = B.contiguous() if C.stride(-1) != 1: @@ -89,6 +88,7 @@ def _mamba_chunk_scan_combined_fwd(x, dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, + cu_chunk_seqlens, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit) @@ -99,36 +99,31 @@ def _mamba_chunk_scan_combined_fwd(x, x, dt, dA_cumsum, - seq_idx=seq_idx, + cu_chunk_seqlens, states_in_fp32=True) # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) - # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. + # - for handling chunked prefill, this requires i) initial_states and + # ii) seq_idx to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. - # - We will also make sure that the dA_cumsum is taken only from the start of the - # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) - # - this will ensure that states will be updated with the rightmost flushed seq_idx - # of the previous chunk. This implies that the first chunk of states is either 0 - # or equal to init_states of the first example. states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), dA_cumsum, # (nheads, nchunks, chunk_size) + cu_chunk_seqlens, initial_states=rearrange(initial_states, "... p n -> ... (p n)") if initial_states is not None else None, # (batch, nheads, headdim*dstate) seq_idx=seq_idx, - out_dtype=state_dtype if state_dtype is not None else C.dtype, - chunk_offsets=chunk_offsets) + out_dtype=state_dtype if state_dtype is not None else C.dtype) states = rearrange(states, "... (p n) -> ... p n", n=dstate) # 4. Compute batched matrix multiply for C_j^T B_i terms CB = _bmm_chunk_fwd(C, B, chunk_size, - seq_idx=seq_idx, + cu_chunk_seqlens, output_dtype=torch.float32) # 5. 
Scan and compute the diagonal blocks, taking into @@ -148,26 +143,15 @@ def _mamba_chunk_scan_combined_fwd(x, dA_cumsum, C, states, + cu_chunk_seqlens, out, # in-place update seq_idx, D=D, z=z, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, initial_states=initial_states, ) - varlen_states = chunk_state_varlen( - B, - x, - dt, - dA_cumsum, - cu_seqlens, - states, - initial_states=initial_states, - ) - - return varlen_states + return states[last_chunk_indices] def mamba_chunk_scan_combined_varlen( @@ -178,14 +162,14 @@ def mamba_chunk_scan_combined_varlen( C, chunk_size, cu_seqlens, + cu_chunk_seqlens, + last_chunk_indices, seq_idx, out, D=None, z=None, dt_bias=None, initial_states=None, - chunk_indices=None, - chunk_offsets=None, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, @@ -198,8 +182,10 @@ def mamba_chunk_scan_combined_varlen( B: (seqlen, ngroups, dstate) C: (seqlen, ngroups, dstate) chunk_size: int - seq_idx: (seqlen) - cu_seqlens: (batch + 1) + cu_seqlens: (batch + 1,) + cu_chunk_seqlens: (nchunks + 1,) + last_chunk_indices: (batch,) + seq_idx: (nchunks,) out: (seqlen, nheads, headdim) preallocated output tensor D: (nheads, headdim) or (nheads,) z: (seqlen, nheads, headdim) @@ -228,9 +214,9 @@ def mamba_chunk_scan_combined_varlen( dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets, cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, dt_softplus=dt_softplus, dt_limit=dt_limit, state_dtype=state_dtype) diff --git a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py index 71a8a4b0a1c8..f09af262cfc2 100644 --- a/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +++ b/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py @@ -30,8 +30,7 @@ def _state_passing_fwd_kernel( dA_cs_ptr, initstates_ptr, seq_idx_ptr, - chunk_offsets_ptr, - chunk_meta_num, + cu_chunk_seqlens_ptr, # Matrix dimensions dim: tl.constexpr, nchunks, @@ -50,94 +49,52 @@ def _state_passing_fwd_kernel( stride_initstates_batch: tl.int64, stride_initstates_head: tl.int64, stride_initstates_dim: tl.constexpr, - stride_seq_idx_seqlen: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, # Meta-parameters HAS_INITSTATES: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): pid_h = tl.program_id(axis=1) pid_m = tl.program_id(axis=0) + states_ptr += pid_h * stride_states_head dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_h * stride_out_head - if HAS_INITSTATES: - initstates_ptr += pid_h * stride_initstates_head offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) states_ptrs = states_ptr + offs_m * stride_states_dim out_ptrs = out_ptr + offs_m * stride_out_dim - # - states will be the past state of the sequence that continues on the current check - if not HAS_INITSTATES: - states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - else: - initstates_ptr += offs_m * stride_initstates_dim - initstates_ptrs = initstates_ptr - # - for cont batches, for the first chunk mean it will be the first batch's - # init state + if HAS_INITSTATES: + initstates_ptrs = initstates_ptr \ + + pid_h * stride_initstates_head \ + + offs_m * stride_initstates_dim + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + else: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - tl.store(out_ptrs, states, mask=offs_m < dim) - out_ptrs += stride_out_chunk - 
prev_seq_idx_chunk_end = 0 - logical_chunk_idx = 0 - for c in range(nchunks - 1): + prev_seq_idx = 0 + for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale_mask = True - # - the seq to pass forward is the one that is flushed to the right - # boundary. - # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. - seq_idx_chunk_end = tl.load(seq_idx_ptr + - (min((c + 1) * chunk_size, seqlen) - 1) * - stride_seq_idx_seqlen) - - if HAS_INITSTATES: - if prev_seq_idx_chunk_end != seq_idx_chunk_end: - # this means in the current chunk the rightmost flushed seq - # has changed. - # - so we do not propagate the state from previous chunk - # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch - - # - update state with seq_idx_new's init state + seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk) + # we have started a new sequence + if prev_seq_idx != seq_idx: + if HAS_INITSTATES: + initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \ + + pid_h * stride_initstates_head \ + + offs_m * stride_initstates_dim states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + else: + states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32) - # - we need to consider the cumsum only of the last sequence in the chunk - # - find its starting position (given by c_off of the logical chunk index) - # - and subtract the cumsum just before that position from the total cumsum - # - first, update the logical chunk index (add the number of sequences in the current physical chunk): - # sequence index at the start of the current chunk - seq_idx_chunk_start = tl.load(seq_idx_ptr + - min(c * chunk_size, seqlen) * - stride_seq_idx_seqlen) - logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start - # - load the chunk offset: - c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, - mask=logical_chunk_idx < chunk_meta_num, - other=0) - # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything - if c_off > 0: - # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset - dA_cs_boundary = tl.load( - dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + - (c_off - 1) * stride_dA_cs_csize, - mask=(c_off - 1) > -1 and c_off < chunk_size, - other=0.0) - dA_cs -= dA_cs_boundary - - # - increment logical chunk index for every physical chunk - logical_chunk_idx += 1 - else: - scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end - prev_seq_idx_chunk_end = seq_idx_chunk_end - - scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) - states = scale * states + new_states + prev_seq_idx = seq_idx + states = tl.exp(dA_cs) * states + new_states tl.store(out_ptrs, states, mask=offs_m < dim) states_ptrs += stride_states_chunk @@ -148,8 +105,8 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, dA_cumsum, + cu_chunk_seqlens, seq_idx, - chunk_offsets, initial_states=None, out_dtype=None, ): @@ -175,9 +132,7 @@ def _state_passing_fwd( dA_cs_ptr=dA_cumsum, initstates_ptr=initial_states, seq_idx_ptr=seq_idx, - chunk_offsets_ptr=chunk_offsets, - chunk_meta_num=len(chunk_offsets) - if chunk_offsets is not None else 0, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, dim=dim, nchunks=nchunks, seqlen=seqlen if seq_idx is not None else 0, @@ -194,7 +149,7 @@ def _state_passing_fwd( 
stride_initstates_batch=initial_states_strides[0], stride_initstates_head=initial_states_strides[1], stride_initstates_dim=initial_states_strides[2], - stride_seq_idx_seqlen=seq_idx.stride(0), + stride_seq_idx_chunk=seq_idx.stride(0), HAS_INITSTATES=initial_states is not None, ) return out diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index 03265b13de50..8234d40e94ab 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -260,9 +260,9 @@ def forward_cuda( prep_initial_states = attn_metadata.prep_initial_states chunk_size = attn_metadata.chunk_size seq_idx_p = attn_metadata.seq_idx_p - chunk_indices_p = attn_metadata.chunk_indices_p - chunk_offsets_p = attn_metadata.chunk_offsets_p query_start_loc_p = attn_metadata.query_start_loc_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_indices_p = attn_metadata.last_chunk_indices_p # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -368,9 +368,9 @@ def forward_cuda( self.num_heads // self.tp_size, self.head_dim), dt_bias=self.dt_bias, seq_idx=seq_idx_p, - chunk_indices=chunk_indices_p, - chunk_offsets=chunk_offsets_p, cu_seqlens=query_start_loc_p, + cu_chunk_seqlens=cu_chunk_seqlen_p, + last_chunk_indices=last_chunk_indices_p, initial_states=initial_states, dt_softplus=True, dt_limit=(0.0, float("inf")), diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 6f16fda962ae..e4f16f37a430 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math from dataclasses import dataclass from typing import Optional @@ -8,6 +7,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.config import VllmConfig +from vllm.utils import cdiv from vllm.v1.attention.backends.mamba_attn import ( BaseMambaAttentionMetadataBuilder) from vllm.v1.attention.backends.utils import (PAD_SLOT_ID, @@ -17,91 +17,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec -def _query_start_loc_to_chunk_indices_offsets( - query_start_loc: torch.Tensor, chunk_size: int, - total_seqlens: int) -> tuple[torch.Tensor, torch.Tensor]: - """ - Args: - query_start_loc (torch.Tensor): 1D tensor of cumulative sequence - lengths, shape (num_seqs + 1,). - The first element should be 0. Each entry represents the starting - index of a sequence in the flattened token array. - chunk_size (int): The size of each physical mamba chunk - (number of tokens per chunk). - total_seqlens (int): The total number of tokens in the batch. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - chunk_indices (torch.Tensor): 1D tensor of indices - indicating the physical chunk for each logical chunk. - - chunk_offsets (torch.Tensor): 1D tensor of offsets - indicating the starting index of each logical chunk within - its physical chunk. - - This function computes the chunk indices and offsets for the given - query_start_loc and chunk_size. Both are tensors of integers with length N, - where N is the number of logical (pseudo) chunks. - A logical chunk is a sequence of tokens that are all part of the same - sequence and are all in the same physical mamba chunk. - In other words, a logical chunk changes every time we cross a sequence - boundary or a physical mamba chunk boundary. 
-    Logical chunks are needed to handle batched requests with initial states
-    (see _state_passing_fwd and _chunk_scan_fwd).
-    The chunk_indices tensor contains the index of the physical chunk for each
-    logical chunk.
-    The chunk_offsets tensor contains the offset (AKA starting index) of the
-    logical chunk in the physical chunk.
-
-    Example:
-    query_start_loc = [0, 5, 10]
-    chunk_size = 8
-    total_seqlens = 10
-    -> chunk_indices = [0, 0, 1]
-    -> chunk_offsets = [0, 5, 0]
-
-    In this example, we have 2 sequences, each with 5 tokens. The physical
-    chunk size is 8 tokens.
-    We have three logical chunks:
-    - the first logical chunk starts at token 0 in the first physical chunk
-      and contains all 5 tokens from the first sequence
-    - the second logical chunk starts at token 5 in the first physical chunk
-      and contains first 3 tokens from the second sequence
-    - the third logical chunk starts at token 0 in the second physical chunk
-      and contains the remaining 2 tokens from the second sequence
-    """
-
-    cu_seqlens = query_start_loc[1:]  # remove prepended 0
-
-    # outputs will have length expansion of chunks that do not divide
-    # chunk_size
-    N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
-                                                 > 0).sum()
-    chunk_indices = torch.arange(N,
-                                 dtype=torch.int,
-                                 device=query_start_loc.device)
-    chunk_offsets = torch.zeros((N, ),
-                                dtype=torch.int,
-                                device=query_start_loc.device)
-
-    p = 0  # num of insertions
-    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
-
-        # if does not divide chunk_size, then there is one chunk insertion
-        p += (s % chunk_size > 0)
-
-        # get the dimensions
-        # - the + 1 for _e is to shift the boundary by one chunk
-        # - this shifting is not needed if chunk_size divides e
-        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
-                                                             > 0)
-
-        # adjust indices and offsets
-        chunk_indices[_s:_e] -= p
-        chunk_offsets[_s] = s % chunk_size
-
-    return chunk_indices, chunk_offsets
-
-
 class Mamba2AttentionBackend(AttentionBackend):

     @staticmethod
@@ -125,8 +40,16 @@ class Mamba2AttentionMetadata:
     # the batch has no prefill request.
     has_initial_states_p: Optional[torch.Tensor]
     seq_idx_p: Optional[torch.Tensor]
-    chunk_indices_p: Optional[torch.Tensor]
-    chunk_offsets_p: Optional[torch.Tensor]
+
+    # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
+    # each chunk, its offset into the varlen sequence dimension. It is defined
+    # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
+    # cu_chunk_seqlen_p[i+1].
+    cu_chunk_seqlen_p: Optional[torch.Tensor]
+
+    # last_chunk_indices_p is a tensor of shape (batch,) that contains the
+    # index of the last chunk for every sequence in the (prefill) batch.
+    last_chunk_indices_p: Optional[torch.Tensor]

     state_indices_tensor: torch.Tensor  # shape: [batch,]

@@ -151,13 +74,14 @@ def build(self,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> Mamba2AttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
-        query_start_loc_p = None
         seq_lens = common_attn_metadata.seq_lens

+        query_start_loc_p = None
         seq_idx_p = None
-        chunk_indices_p, chunk_offsets_p = None, None
+        cu_chunk_seqlen_p = None
+        last_chunk_indices_p = None
+
         # Need flags to indicate if there are initial states
-        # currently we really only support the FlashAttention backend
         has_initial_states_p = None
         prep_initial_states = False

@@ -171,7 +95,7 @@ def build(self,
                     common_attn_metadata,
                     decode_threshold=self.reorder_batch_threshold))

-        # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
+        # Compute seq_idx for prefill only
        if num_prefills > 0:
            #[batch,]
            has_initial_states_cpu = (
@@ -184,21 +108,68 @@ def build(self,
             query_start_loc_p = common_attn_metadata.query_start_loc[
                 -num_prefills - 1:] - num_decode_tokens

-            seq_idx_p = torch.repeat_interleave(torch.arange(
-                num_prefills,
-                dtype=torch.int32,
-                device=query_start_loc_p.device),
-                                                query_start_loc_p.diff(),
-                                                output_size=num_prefill_tokens)
-
-            # We compute metadata for chunked prefill once at the top level
-            # model forward and reuse them in mamba layers. If not needed,
-            # they will be ignored inside mamba kernels.
-            if prep_initial_states:
-                chunk_indices_p, chunk_offsets_p = (
-                    _query_start_loc_to_chunk_indices_offsets(
-                        query_start_loc_p, self.chunk_size,
-                        num_prefill_tokens))
+            num_computed_tokens_p = \
+                common_attn_metadata.num_computed_tokens_cpu[
+                    num_reqs - num_prefills:num_reqs]
+            query_start_loc_p_cpu = common_attn_metadata.query_start_loc_cpu[
+                -num_prefills - 1:] - num_decode_tokens
+
+            # The code below carefully constructs the chunks such that:
+            # 1. Chunks contain tokens from a *single* sequence only.
+            # 2. For every sequence, we are guaranteed that we can
+            #    retrieve the mamba state *every* chunk_size tokens.
+            # Constraint (1) dramatically simplifies the mamba2 kernels.
+            # Constraint (2) dramatically simplifies the implementation
+            # of prefix caching for mamba2 (wip). We need to take care
+            # of the interaction with chunked prefill in order to
+            # satisfy constraint (2).
+            # TODO (tdoublep): This code could probably be optimized.
+            cu_chunk_seqlen = []
+            seq_idx = []
+            last_chunk_indices = []
+            seqlen_pos = 0
+            for req_idx in range(num_prefills):
+                this_num_computed = num_computed_tokens_p[req_idx].item()
+                this_new_tokens = query_start_loc_p_cpu[req_idx + 1].item(
+                ) - query_start_loc_p_cpu[req_idx].item()
+
+                # if computed tokens are not chunk-aligned, use the first
+                # chunk to finish it off
+                if this_num_computed % self.chunk_size != 0:
+                    seq_idx.append(req_idx)
+                    cu_chunk_seqlen.append(seqlen_pos)
+                    # how many tokens to finish the chunk?
+                    chunk_len = cdiv(this_num_computed, self.chunk_size
+                                     ) * self.chunk_size - this_num_computed
+                    # we can only use at most this_new_tokens
+                    chunk_len = min(chunk_len, this_new_tokens)
+                    seqlen_pos += chunk_len
+                    this_new_tokens -= chunk_len
+
+                n_chunks = cdiv(this_new_tokens, self.chunk_size)
+                for chunk in range(n_chunks):
+                    seq_idx.append(req_idx)
+                    cu_chunk_seqlen.append(seqlen_pos)
+                    chunk_len = min(self.chunk_size, this_new_tokens)
+                    seqlen_pos += chunk_len
+                    this_new_tokens -= chunk_len
+
+                assert this_new_tokens == 0
+                last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
+
+            cu_chunk_seqlen.append(seqlen_pos)
+
+            seq_idx_p = torch.as_tensor(seq_idx,
+                                        device=query_start_loc_p.device,
+                                        dtype=torch.int32)
+            cu_chunk_seqlen_p = torch.as_tensor(
+                cu_chunk_seqlen,
+                device=query_start_loc_p.device,
+                dtype=torch.int32)
+            last_chunk_indices_p = torch.as_tensor(
+                last_chunk_indices,
+                device=query_start_loc_p.device,
+                dtype=torch.int32)

            nums_dict, batch_ptr, token_chunk_offset_ptr = \
                compute_causal_conv1d_metadata(query_start_loc_p)

@@ -222,9 +193,9 @@ def build(self,
             chunk_size=self.chunk_size,
             has_initial_states_p=has_initial_states_p,
             seq_idx_p=seq_idx_p,
-            chunk_indices_p=chunk_indices_p,
-            chunk_offsets_p=chunk_offsets_p,
             state_indices_tensor=state_indices_tensor,
+            cu_chunk_seqlen_p=cu_chunk_seqlen_p,
+            last_chunk_indices_p=last_chunk_indices_p,
             nums_dict=nums_dict,
             batch_ptr=batch_ptr,
             token_chunk_offset_ptr=token_chunk_offset_ptr,
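
Editor's note: the chunk construction in Mamba2AttentionMetadataBuilder.build() above is easiest to see on a small example. The sketch below is not part of the patch; it mirrors the chunk-building loop with plain Python lists so the resulting cu_chunk_seqlen_p, seq_idx_p and last_chunk_indices_p values can be inspected directly. The helper name and the toy numbers are illustrative only.

def build_chunk_metadata(num_computed, new_tokens, chunk_size):
    # num_computed[i]: tokens of prefill request i already processed
    # new_tokens[i]:   tokens of prefill request i scheduled in this step
    cu_chunk_seqlen, seq_idx, last_chunk_indices = [], [], []
    pos = 0  # running offset into the flattened prefill token dimension
    for req, (computed, remaining) in enumerate(zip(num_computed, new_tokens)):
        # partial first chunk that re-aligns this request to a chunk_size
        # boundary (constraint 2 above)
        if computed % chunk_size != 0:
            seq_idx.append(req)
            cu_chunk_seqlen.append(pos)
            take = min(chunk_size - computed % chunk_size, remaining)
            pos += take
            remaining -= take
        # remaining chunks hold tokens from this request only (constraint 1)
        while remaining > 0:
            seq_idx.append(req)
            cu_chunk_seqlen.append(pos)
            take = min(chunk_size, remaining)
            pos += take
            remaining -= take
        last_chunk_indices.append(len(cu_chunk_seqlen) - 1)
    cu_chunk_seqlen.append(pos)
    return cu_chunk_seqlen, seq_idx, last_chunk_indices

# chunk_size=8; request 0 resumes a chunked prefill (5 tokens already
# computed, 10 new tokens); request 1 is a fresh 6-token prefill.
print(build_chunk_metadata([5, 0], [10, 6], 8))
# -> ([0, 3, 10, 16], [0, 0, 1], [1, 2])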
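
Editor's note: similarly, the rewritten _state_passing_fwd kernel can be summarized by the reference below (also not part of the patch, and only a sketch of the semantics). out[c] holds the SSM state at the end of chunk c; whenever the per-chunk seq_idx changes, the recurrence restarts from that sequence's initial state, or from zeros when no initial state is supplied. This is what lets _mamba_chunk_scan_combined_fwd return states[last_chunk_indices] as the final per-sequence states.

import numpy as np

def state_passing_ref(chunk_states, dA_cs_last, seq_idx, initial_states=None):
    # chunk_states:   (nchunks, dim) per-chunk contributions from _chunk_state_fwd
    # dA_cs_last:     (nchunks,) dA_cumsum at the final position of each chunk
    # seq_idx:        (nchunks,) request index owning each chunk (non-decreasing)
    # initial_states: (num_seqs, dim) or None
    nchunks, dim = chunk_states.shape
    out = np.empty_like(chunk_states)
    prev_seq = -1
    running = np.zeros(dim)
    for c in range(nchunks):
        if seq_idx[c] != prev_seq:  # first chunk of a new sequence
            running = (initial_states[seq_idx[c]].astype(np.float64)
                       if initial_states is not None else np.zeros(dim))
            prev_seq = seq_idx[c]
        running = np.exp(dA_cs_last[c]) * running + chunk_states[c]
        out[c] = running  # state at the *end* of chunk c
    return out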