@@ -214,12 +214,14 @@ class AiterFlashAttentionMetadata:
     # |-- query_len ---|
 
     num_actual_tokens: int  # Number of tokens excluding padding.
+    num_actual_kv_tokens: int
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
+    cu_seq_lens: Optional[torch.Tensor]
 
     # For cascade attention.
     use_cascade: bool
@@ -272,6 +274,20 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        if max_query_len > 1:
+            # Pre-compute the cumulative KV sequence lengths needed for
+            # prefill attention, so they are not recomputed in every layer.
+            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=seq_lens.device)
+            torch.cumsum(seq_lens,
+                         dim=0,
+                         dtype=cu_seq_lens.dtype,
+                         out=cu_seq_lens[1:])
+            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
+        else:
+            cu_seq_lens = None
+            num_actual_kv_tokens = 0
 
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
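The pattern added in `build()` above — allocating a zero-initialized tensor one element longer than `seq_lens` and writing `torch.cumsum` into its `[1:]` slice — produces an exclusive prefix sum, the `cu_seqlens` layout that varlen attention kernels expect. A minimal, self-contained sketch with illustrative values (not vLLM code):

```python
import torch

# Per-request KV sequence lengths for a batch of three requests.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_seq_lens[0] stays 0; cumsum fills the remaining slots in place.
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])

print(cu_seq_lens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)

# The last entry is the total number of KV tokens across the batch.
num_actual_kv_tokens = int(cu_seq_lens[-1].item())
print(num_actual_kv_tokens)  # 15
```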
@@ -281,12 +297,14 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
 
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
+            num_actual_kv_tokens=num_actual_kv_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            cu_seq_lens=cu_seq_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
@@ -475,16 +493,6 @@ def forward(
         block_table = attn_metadata.block_table
 
         if max_seqlen_q > 1:
-
-            cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=query.device)
-
-            torch.cumsum(seqused_k,
-                         dim=0,
-                         dtype=cu_seq_lens.dtype,
-                         out=cu_seq_lens[1:])
-
             torch.ops.vllm.flash_attn_varlen_func(
                 query[:num_actual_tokens],
                 key_cache,
@@ -497,10 +505,10 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=attn_metadata.cu_seq_lens,
                 k_scale=layer._k_scale,
                 v_scale=layer._v_scale,
-                total_tokens=attn_metadata.total_tokens,
+                total_tokens=attn_metadata.num_actual_kv_tokens,
             )
 
         _, num_heads, head_size = query.shape
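With this change, the tensor allocation and cumsum that previously ran inside every layer's `forward()` run once per batch in `build()`, and each layer simply reads the cached tensor off the metadata. A rough sketch of the resulting call pattern, using hypothetical stand-in names (`Metadata`, `build`, the layer count) rather than the real vLLM classes:

```python
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class Metadata:  # stand-in for AiterFlashAttentionMetadata
    cu_seq_lens: Optional[torch.Tensor]
    num_actual_kv_tokens: int

def build(seq_lens: torch.Tensor) -> Metadata:
    # Runs once per scheduled batch.
    cu = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
    torch.cumsum(seq_lens, dim=0, dtype=cu.dtype, out=cu[1:])
    return Metadata(cu, int(cu[-1].item()))

meta = build(torch.tensor([5, 3, 7], dtype=torch.int32))
for _layer in range(32):  # every attention layer reuses the cached tensor
    cu_seqlens_k = meta.cu_seq_lens  # no per-layer allocation or cumsum
```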