From 042858de42fd097a500899b635872234dc269cc3 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 22 Apr 2025 17:46:33 +0000
Subject: [PATCH 1/2] llama4 fa3 fix

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/flash_attn.py | 68 +++++++++++++++---------
 1 file changed, 42 insertions(+), 26 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index c039cd8067f3..ce4b0a19d8ff 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -109,6 +109,7 @@ class LocalAttentionMetadata:
         local_block_table: torch.Tensor
         local_max_query_len: int
         local_max_seq_len: int
+        local_scheduler_metadata: Optional[torch.Tensor]
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
@@ -286,7 +287,9 @@ def __init__(self, runner: "GPUModelRunner"):
 
         self.runner = runner
         self.aot_schedule = (get_flash_attn_version() == 3)
-        self.num_heads = model_config.get_num_attention_heads(
+        self.num_heads_q = model_config.get_num_attention_heads(
+            runner.parallel_config)
+        self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
@@ -308,6 +311,23 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
+        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
+                     causal):
+            if self.aot_schedule:
+                return get_scheduler_metadata(
+                    batch_size=num_reqs,
+                    max_seqlen_q=max_query_len,
+                    max_seqlen_k=max_seq_len,
+                    cache_seqlens=seqlens,
+                    num_heads_q=self.num_heads_q,
+                    num_heads_kv=self.num_heads_kv,
+                    headdim=self.headdim,
+                    page_size=self.page_size,
+                    cu_seqlens_q=cu_query_lens,
+                    causal=causal,
+                )
+            return None
+
         # for local attention
         local_attn_metadata = None
         if self.runner.attention_chunk_size is not None:
@@ -319,36 +339,30 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                 block_table,
                 self.runner.block_size,
             )
+            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
+                self.runner.device, non_blocking=True)
+            local_max_query_len = seqlens_q_local_np.max()
+            local_max_seq_len = virt_k_seqlens_np.max()
+            local_scheduler_metadata = schedule(
+                cu_query_lens=local_query_start_loc,
+                max_query_len=local_max_query_len,
+                seqlens=local_seqused_k,
+                max_seq_len=local_max_seq_len,
+                causal=True)
+
             local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                local_query_start_loc=torch.from_numpy(
-                    virt_q_cu_seqlens_np).to(self.runner.device,
-                                             non_blocking=True),
-                local_seqused_k=torch.from_numpy(virt_k_seqlens_np).to(
-                    self.runner.device, non_blocking=True),
+                local_query_start_loc=local_query_start_loc,
+                local_seqused_k=local_seqused_k,
                 local_block_table=virt_block_table,
-                local_max_query_len=seqlens_q_local_np.max(),
-                local_max_seq_len=virt_k_seqlens_np.max(),
+                local_max_query_len=local_max_query_len,
+                local_max_seq_len=local_max_seq_len,
+                local_scheduler_metadata=local_scheduler_metadata,
             )
 
         use_cascade = common_prefix_len > 0
 
-        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
-                     causal):
-            if self.aot_schedule:
-                return get_scheduler_metadata(
-                    batch_size=num_reqs,
-                    max_seqlen_q=max_query_len,
-                    max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
-                    num_heads_q=self.num_heads,
-                    num_heads_kv=self.num_heads,
-                    headdim=self.headdim,
-                    page_size=self.page_size,
-                    cu_seqlens_q=cu_query_lens,
-                    causal=causal,
-                )
-            return None
-
         if use_cascade:
             cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
                                                 dtype=torch.int32,
@@ -541,12 +555,14 @@ def forward(
                 max_seqlen_q = local_metadata.local_max_query_len
                 max_seqlen_k = local_metadata.local_max_seq_len
                 block_table = local_metadata.local_block_table
+                scheduler_metadata = local_metadata.local_scheduler_metadata
             else:
                 cu_seqlens_q = attn_metadata.query_start_loc
                 seqused_k = attn_metadata.seq_lens
                 max_seqlen_q = attn_metadata.max_query_len
                 max_seqlen_k = attn_metadata.max_seq_len
                 block_table = attn_metadata.block_table
+                scheduler_metadata = attn_metadata.scheduler_metadata
 
             descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
 
@@ -565,7 +581,7 @@ def forward(
                 window_size=self.sliding_window,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
-                scheduler_metadata=attn_metadata.scheduler_metadata,
+                scheduler_metadata=scheduler_metadata,
                 fa_version=self.vllm_flash_attn_version,
                 q_descale=layer._q_scale.expand(descale_shape),
                 k_descale=layer._k_scale.expand(descale_shape),

From 0e8e6340c558ef2c468d06bc43877591283d0044 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Tue, 22 Apr 2025 18:05:06 +0000
Subject: [PATCH 2/2] fix

Signed-off-by: Lucas Wilkinson
---
 vllm/v1/attention/backends/flash_attn.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index ce4b0a19d8ff..d4277086552d 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -311,11 +311,11 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
-        def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
-                     causal):
+        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
+                     max_seq_len, causal):
             if self.aot_schedule:
                 return get_scheduler_metadata(
-                    batch_size=num_reqs,
+                    batch_size=batch_size,
                     max_seqlen_q=max_query_len,
                     max_seqlen_k=max_seq_len,
                     cache_seqlens=seqlens,
@@ -346,6 +346,7 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
             local_max_query_len = seqlens_q_local_np.max()
             local_max_seq_len = virt_k_seqlens_np.max()
             local_scheduler_metadata = schedule(
+                batch_size=local_query_start_loc.shape[0] - 1,
                 cu_query_lens=local_query_start_loc,
                 max_query_len=local_max_query_len,
                 seqlens=local_seqused_k,
@@ -375,12 +376,14 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
             suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
                 self.runner.device)
             prefix_scheduler_metadata = schedule(
+                batch_size=num_reqs,
                 cu_query_lens=cu_prefix_query_lens,
                 max_query_len=num_actual_tokens,
                 seqlens=prefix_kv_lens,
                 max_seq_len=common_prefix_len,
                 causal=False)
-            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+            scheduler_metadata = schedule(batch_size=num_reqs,
+                                          cu_query_lens=query_start_loc,
                                           max_query_len=max_query_len,
                                           seqlens=suffix_kv_lens,
                                           max_seq_len=max_seq_len -
@@ -391,7 +394,8 @@ def schedule(cu_query_lens, max_query_len, seqlens, max_seq_len,
             prefix_kv_lens = None
             suffix_kv_lens = None
             prefix_scheduler_metadata = None
-            scheduler_metadata = schedule(cu_query_lens=query_start_loc,
+            scheduler_metadata = schedule(batch_size=num_reqs,
+                                          cu_query_lens=query_start_loc,
                                           max_query_len=max_query_len,
                                           seqlens=seq_lens,
                                           max_seq_len=max_seq_len,
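
Note (illustrative, not part of the patches): the second commit's point is that the AOT-schedule helper must take an explicit batch_size, because the local chunked-attention path builds a "virtual batch" whose size is cu_seqlens.shape[0] - 1 rather than num_reqs. The sketch below is a minimal, self-contained Python illustration of that pattern under stated assumptions: the schedule() helper and the dict it returns are stand-ins, not the vllm_flash_attn get_scheduler_metadata() API, and the tensors are made-up toy inputs.

    # Minimal sketch of the batch_size-aware schedule() pattern from PATCH 2/2.
    from typing import Optional

    import torch


    def schedule(aot_schedule: bool, batch_size: int,
                 cu_query_lens: torch.Tensor, max_query_len: int,
                 seqlens: torch.Tensor, max_seq_len: int,
                 causal: bool) -> Optional[dict]:
        """Return scheduler metadata only when AOT scheduling (FA3) is on."""
        if not aot_schedule:
            return None
        # Stand-in for vllm_flash_attn's scheduler-metadata call; the real one
        # also needs num_heads_q / num_heads_kv / headdim / page_size.
        return {
            "batch_size": batch_size,
            "max_seqlen_q": max_query_len,
            "max_seqlen_k": max_seq_len,
            "cache_seqlens": seqlens,
            "cu_seqlens_q": cu_query_lens,
            "causal": causal,
        }


    # Global path: one entry per request, so batch_size is num_reqs.
    num_reqs = 2
    query_start_loc = torch.tensor([0, 3, 5], dtype=torch.int32)
    seq_lens = torch.tensor([7, 9], dtype=torch.int32)
    global_md = schedule(True, num_reqs, query_start_loc, 3, seq_lens, 9, True)

    # Local chunked-attention path: the virtual batch can be larger than
    # num_reqs, so the batch size comes from the virtual cu_seqlens instead.
    local_query_start_loc = torch.tensor([0, 2, 3, 5], dtype=torch.int32)
    local_seqused_k = torch.tensor([4, 4, 2], dtype=torch.int32)
    local_md = schedule(True, local_query_start_loc.shape[0] - 1,
                        local_query_start_loc, 2, local_seqused_k, 4, True)
    assert local_md["batch_size"] == 3  # three virtual batches, two requests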