diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index abe05174507f..e8bffbef4415 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -214,12 +214,14 @@ class AiterFlashAttentionMetadata:
     # |-- query_len ---|
 
     num_actual_tokens: int  # Number of tokens excluding padding.
+    num_actual_kv_tokens: int
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
+    cu_seq_lens: Optional[torch.Tensor]
 
     # For cascade attention.
     use_cascade: bool
@@ -272,6 +274,20 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        if max_query_len > 1:
+            # We pre-compute cumulative seq len needed for prefill attention
+            # here to avoid recomputing it for every layer
+            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=seq_lens.device)
+            torch.cumsum(seq_lens,
+                         dim=0,
+                         dtype=cu_seq_lens.dtype,
+                         out=cu_seq_lens[1:])
+            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
+        else:
+            cu_seq_lens = None
+            num_actual_kv_tokens = 0
 
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
@@ -281,12 +297,14 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
 
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
+            num_actual_kv_tokens=num_actual_kv_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            cu_seq_lens=cu_seq_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
@@ -475,16 +493,6 @@ def forward(
         block_table = attn_metadata.block_table
 
         if max_seqlen_q > 1:
-
-            cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=query.device)
-
-            torch.cumsum(seqused_k,
-                         dim=0,
-                         dtype=cu_seq_lens.dtype,
-                         out=cu_seq_lens[1:])
-
             torch.ops.vllm.flash_attn_varlen_func(
                 query[:num_actual_tokens],
                 key_cache,
@@ -497,10 +505,10 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=attn_metadata.cu_seq_lens,
                 k_scale=layer._k_scale,
                 v_scale=layer._v_scale,
-                total_tokens=attn_metadata.total_tokens,
+                total_tokens=attn_metadata.num_actual_kv_tokens,
             )
 
             _, num_heads, head_size = query.shape
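
For context, the hunk in `build()` hoists the cumulative KV sequence-length computation out of `forward()`, so it runs once per scheduling step instead of once per attention layer, and its last entry doubles as the `num_actual_kv_tokens` value passed as `total_tokens`. The sketch below shows the same pattern in isolation; the helper name `build_cu_seq_lens` and the toy inputs are illustrative and not part of vLLM's API.

```python
# Standalone sketch of the precomputation this diff moves from forward()
# into the metadata build. Names outside torch are hypothetical.
import torch


def build_cu_seq_lens(seq_lens: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Compute the [batch+1] int32 prefix-sum of KV sequence lengths.

    Returns the cumulative tensor expected by varlen attention kernels
    (cu_seqlens_k) together with the total number of KV tokens.
    """
    cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                              dtype=torch.int32,
                              device=seq_lens.device)
    # Write the running sum into the tail of the preallocated tensor,
    # leaving cu_seq_lens[0] == 0 as the first offset.
    torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype,
                 out=cu_seq_lens[1:])
    return cu_seq_lens, int(cu_seq_lens[-1].item())


if __name__ == "__main__":
    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
    cu, total = build_cu_seq_lens(seq_lens)
    print(cu.tolist())  # [0, 5, 8, 15]
    print(total)        # 15
```

Doing this once in the metadata builder avoids a per-layer device allocation plus cumsum during prefill; the decode-only case (`max_query_len <= 1`) skips it entirely by storing `cu_seq_lens = None`.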