
 # ruff: noqa: E501,SIM102

-import math
-
 import torch

 from vllm.triton_utils import tl, triton
@@ -96,7 +94,7 @@ def _bmm_chunk_fwd_kernel(
     a_ptr,
     b_ptr,
     out_ptr,
-    seq_idx_ptr,
+    cu_chunk_seqlens_ptr,
     # Matrix dimensions
     seqlen,
     chunk_size: tl.constexpr,
@@ -112,7 +110,6 @@ def _bmm_chunk_fwd_kernel(
     stride_out_head: tl.int64,
     stride_outm: tl.int64,
     stride_outn: tl.constexpr,
-    stride_seq_idx_seqlen: tl.constexpr,
     # Meta-parameters
     IS_CAUSAL: tl.constexpr,
     dot_dtype: tl.constexpr,
@@ -129,10 +126,12 @@ def _bmm_chunk_fwd_kernel(
     if IS_CAUSAL:
         if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
             return
-    a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
-    b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head

-    seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
+    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
+    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
+
+    a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
+    b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head

     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
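
The new cu_chunk_seqlens_ptr argument points at cumulative chunk boundaries, so each program loads its chunk's true start and end instead of assuming chunks begin at pid_c * chunk_size. Because every chunk then lies entirely inside one request, the per-token seq_idx masking removed further down is presumably no longer needed. A minimal host-side sketch of how such a boundary array could be built (the helper name and construction are illustrative, not taken from this PR):

    import torch

    def build_cu_chunk_seqlens(seq_lens, chunk_size):
        # Split each sequence into chunks of at most `chunk_size` tokens so no
        # chunk crosses a sequence boundary, then take the cumulative sum.
        chunk_lens = []
        for sl in seq_lens:
            while sl > 0:
                chunk_lens.append(min(chunk_size, sl))
                sl -= chunk_size
        return torch.tensor([0] + chunk_lens, dtype=torch.int32).cumsum(0)

    cu = build_cu_chunk_seqlens([5, 7], chunk_size=4)  # values: [0, 4, 5, 9, 12]
    # Inside the kernel, chunk pid_c then spans
    #   chunk_seqlen_start = cu[pid_c], chunk_seqlen_end = cu[pid_c + 1]
    # and chunk_size_limit = end - start is the (possibly ragged) chunk length.
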
@@ -141,7 +140,7 @@ def _bmm_chunk_fwd_kernel(
                       offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
                       offs_n[None, :] * stride_b_seqlen)
-    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
+    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

     acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

@@ -162,16 +161,6 @@ def _bmm_chunk_fwd_kernel(
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

-    # Zero out the results that are not from the same request
-    # in the varlen batch
-    seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
-                        mask=offs_m < chunk_size_limit,
-                        other=-1)
-    seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
-                        mask=offs_n < chunk_size_limit,
-                        other=-2)
-    acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
-
     out = acc.to(out_ptr.dtype.element_ty)
     out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
     out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
@@ -182,27 +171,31 @@ def _bmm_chunk_fwd_kernel(
              (offs_n[None, :] < chunk_size))


-def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
+def _bmm_chunk_fwd(a,
+                   b,
+                   chunk_size,
+                   cu_chunk_seqlens,
+                   causal=False,
+                   output_dtype=None):
     """
     Argument:
         a: (seqlen, ngroups, k)
         b: (seqlen, ngroups, k)
-        seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
+        chunk_size: int
+        cu_chunk_seqlens: (nchunks+1,)
         causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
             guaranteed to be correct.
     Return:
         out: (nchunks, ngroups, chunk_size, chunk_size)
     """
     seqlen, ngroups, k = a.shape
     assert b.shape == a.shape
-    assert seq_idx is not None
-    assert seq_idx.shape == (seqlen, )
     if a.stride(-1) != 1 and a.stride(0) != 1:
         a = a.contiguous()
     if b.stride(-1) != 1 and b.stride(0) != 1:
         b = b.contiguous()

-    nchunks = math.ceil(seqlen / chunk_size)
+    nchunks = len(cu_chunk_seqlens) - 1
     # Allocates output.
     out_dtype = a.dtype if output_dtype is None else output_dtype
     out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
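
With the new signature above, a call site passes cumulative chunk boundaries instead of a per-token seq_idx. An illustrative invocation (shapes, dtypes, and values are made up for the example, reusing the boundary array from the earlier sketch):

    import torch

    seqlen, ngroups, k, chunk_size = 12, 1, 16, 4
    a = torch.randn(seqlen, ngroups, k, device="cuda", dtype=torch.float16)
    b = torch.randn(seqlen, ngroups, k, device="cuda", dtype=torch.float16)
    # Boundaries for two packed sequences of lengths 5 and 7, chunked at 4 tokens.
    cu_chunk_seqlens = torch.tensor([0, 4, 5, 9, 12],
                                    dtype=torch.int32, device="cuda")

    out = _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False)
    # out has shape (nchunks, ngroups, chunk_size, chunk_size) == (4, 1, 4, 4);
    # entries beyond a ragged chunk's true length should not be relied on.
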
@@ -220,7 +213,7 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
             a_ptr=a,
             b_ptr=b,
             out_ptr=out,
-            seq_idx_ptr=seq_idx,
+            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
             seqlen=seqlen,
             chunk_size=chunk_size,
             K=k,
@@ -235,7 +228,6 @@ def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
             stride_out_head=out.stride(1),
             stride_outm=out.stride(-2),
             stride_outn=out.stride(-1),
-            stride_seq_idx_seqlen=seq_idx.stride(0),
             IS_CAUSAL=causal,
             dot_dtype=dot_dtype,
         )
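
For a quick numerical sanity check of the rewritten kernel, each output chunk can be compared against a plain per-chunk matmul of a and b. This is an illustrative check under the example above (causal=False), not a test shipped with this change:

    for c in range(len(cu_chunk_seqlens) - 1):
        s, e = int(cu_chunk_seqlens[c]), int(cu_chunk_seqlens[c + 1])
        # Reference: out[c, g, i, j] = sum_k a[s + i, g, k] * b[s + j, g, k]
        ref = torch.einsum("igk,jgk->gij", a[s:e].float(), b[s:e].float())
        torch.testing.assert_close(out[c, :, :e - s, :e - s].float(), ref,
                                   atol=1e-2, rtol=1e-2)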