File tree Expand file tree Collapse file tree 1 file changed +22
-3
lines changed Expand file tree Collapse file tree 1 file changed +22
-3
lines changed Original file line number Diff line number Diff line change @@ -251,9 +251,28 @@ def forward(
251251 _Backend .FLASH_ATTN ,
252252 _Backend .FLASH_ATTN_VLLM_V1 ,
253253 }:
254- from vllm .vllm_flash_attn import flash_attn_func
255-
256- out = flash_attn_func (query , key , value , softmax_scale = self .scale )
254+ from vllm .vllm_flash_attn import flash_attn_varlen_func
255+
256+ cu_seqlens_q = torch .arange (0 , (bsz + 1 ) * q_len ,
257+ step = q_len ,
258+ dtype = torch .int32 ,
259+ device = query .device )
260+ cu_seqlens_k = torch .arange (0 , (bsz + 1 ) * kv_len ,
261+ step = kv_len ,
262+ dtype = torch .int32 ,
263+ device = key .device )
264+
265+ out = flash_attn_varlen_func (
266+ query .flatten (0 , 1 ),
267+ key .flatten (0 , 1 ),
268+ value .flatten (0 , 1 ),
269+ cu_seqlens_q = cu_seqlens_q ,
270+ cu_seqlens_k = cu_seqlens_k ,
271+ max_seqlen_q = q_len ,
272+ max_seqlen_k = kv_len ,
273+ softmax_scale = self .scale ,
274+ )
275+ out = out .reshape (bsz , q_len , - 1 )
257276 elif self .attn_backend == _Backend .XFORMERS :
258277 from xformers import ops as xops
259278
You can’t perform that action at this time.
0 commit comments