vllm/v1/attention/backends/rocm_aiter_fa.py (20 additions, 12 deletions)

@@ -214,12 +214,14 @@ class AiterFlashAttentionMetadata:
     # |-- query_len ---|

     num_actual_tokens: int  # Number of tokens excluding padding.
+    num_actual_kv_tokens: int
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
     seq_lens: torch.Tensor
     slot_mapping: torch.Tensor
     block_table: torch.Tensor
+    cu_seq_lens: Optional[torch.Tensor]

     # For cascade attention.
     use_cascade: bool
@@ -272,6 +274,20 @@ def build(self,
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
+        if max_query_len > 1:
+            # Pre-compute the cumulative sequence lengths needed for prefill
+            # attention here to avoid recomputing them for every layer.
+            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=seq_lens.device)
+            torch.cumsum(seq_lens,
+                         dim=0,
+                         dtype=cu_seq_lens.dtype,
+                         out=cu_seq_lens[1:])
+            num_actual_kv_tokens = int(cu_seq_lens[-1].item())
+        else:
+            cu_seq_lens = None
+            num_actual_kv_tokens = 0

         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
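For intuition, here is a minimal, self-contained sketch of the prefix-sum layout that `build()` now computes once per batch (the toy `seq_lens` values are assumptions for illustration): entry `i` of `cu_seq_lens` holds the total KV length of sequences `0..i-1`, so the final entry equals `num_actual_kv_tokens`.

```python
import torch

# Toy batch (illustrative values): three sequences with KV lengths 4, 2, 5.
seq_lens = torch.tensor([4, 2, 5], dtype=torch.int32)

# Same construction as in build(): a zero-filled tensor with one extra
# leading slot, populated in place via cumsum over the sequence lengths.
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])

print(cu_seq_lens)  # tensor([ 0,  4,  6, 11], dtype=torch.int32)

# The last prefix sum is the total KV token count for the batch.
num_actual_kv_tokens = int(cu_seq_lens[-1].item())  # 11
```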
@@ -281,12 +297,14 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,

         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
+            num_actual_kv_tokens=num_actual_kv_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            cu_seq_lens=cu_seq_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
@@ -475,16 +493,6 @@ def forward(
         block_table = attn_metadata.block_table

         if max_seqlen_q > 1:
-
-            cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
-                                      dtype=torch.int32,
-                                      device=query.device)
-
-            torch.cumsum(seqused_k,
-                         dim=0,
-                         dtype=cu_seq_lens.dtype,
-                         out=cu_seq_lens[1:])
-
             torch.ops.vllm.flash_attn_varlen_func(
                 query[:num_actual_tokens],
                 key_cache,
@@ -497,10 +505,10 @@
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=cu_seq_lens,
+                cu_seqlens_k=attn_metadata.cu_seq_lens,
                 k_scale=layer._k_scale,
                 v_scale=layer._v_scale,
-                total_tokens=attn_metadata.total_tokens,
+                total_tokens=attn_metadata.num_actual_kv_tokens,
             )

             _, num_heads, head_size = query.shape
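To make the saving concrete: before this change, every layer's `forward()` rebuilt the same `cu_seq_lens` tensor from `seqused_k`; after it, the builder runs the cumsum once and all layers share the cached `attn_metadata.cu_seq_lens`. A minimal sketch of the before/after pattern (the layer count and the `cu_from` helper are illustrative, not from the diff):

```python
import torch

NUM_LAYERS = 32  # illustrative; the real count depends on the model


def cu_from(seq_lens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the cumsum hoisted into build()."""
    cu = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32,
                     device=seq_lens.device)
    torch.cumsum(seq_lens, dim=0, dtype=cu.dtype, out=cu[1:])
    return cu


seq_lens = torch.tensor([4, 2, 5], dtype=torch.int32)

# Before: one redundant kernel launch per layer on identical inputs.
per_layer = [cu_from(seq_lens) for _ in range(NUM_LAYERS)]

# After: computed once at metadata-build time; every layer reuses it.
shared = cu_from(seq_lens)
assert all(torch.equal(shared, cu) for cu in per_layer)
```

Note that the call site also switches `total_tokens` from `attn_metadata.total_tokens` to `num_actual_kv_tokens`, so the count the kernel receives matches the precomputed prefix sums (the last entry of `cu_seq_lens`).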