25 changes: 7 additions & 18 deletions vllm/v1/attention/backends/flash_attn.py
@@ -158,12 +158,13 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,

         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
-        if self.use_full_cuda_graph and not self.aot_schedule:
-            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
-                             "which requires FlashAttention 3.")
-        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
-                                              dtype=torch.int32,
-                                              device=self.runner.device)
+        if self.use_full_cuda_graph:
+            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
+            # yet. This is because the scheduler and kernel need to always use
+            # the same num_splits (which acts as an upper bound with the
+            # dynamic split scheduler) which is currently heuristically decided
+            # by the kernel launching code.
+            self.aot_schedule = False
Comment on lines +161 to +167

Contributor

high

The previous logic (lines 161-163 removed) stated that "Full CUDA graph mode requires AOT scheduling". However, the new code (lines 161-167 added) explicitly sets self.aot_schedule = False when self.use_full_cuda_graph is enabled. This is a direct contradiction of the previous requirement.

While the new comment clarifies that "AOT scheduling not supported in full cuda graph mode yet", it's important to ensure that any other parts of the codebase or documentation that refer to full_cuda_graph requiring aot_schedule are updated to reflect this change. The previous ValueError message was misleading given the current state of support.
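
To make the behavior change concrete, here is a minimal sketch of the two gating strategies as standalone variables (simplified for illustration, not the actual builder class), assuming FlashAttention 3 is detected so aot_schedule starts out True:

```python
aot_schedule = True          # FlashAttention 3 detected
use_full_cuda_graph = True   # compilation_config.full_cuda_graph

# Old behavior: the combination was rejected up front.
# if use_full_cuda_graph and not aot_schedule:
#     raise ValueError("Full CUDA graph mode requires AOT scheduling, ...")

# New behavior: full CUDA graph mode silently opts out of AOT scheduling,
# so the old "full_cuda_graph requires aot_schedule" invariant no longer holds.
if use_full_cuda_graph:
    aot_schedule = False
```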


         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
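
The NOTE(lucas) comment in the added block above describes a split-count mismatch: a captured CUDA graph freezes the kernel launch, so the number of KV splits has to act as a fixed upper bound, while the AOT scheduler's num_splits is currently chosen heuristically by the kernel launching code and can differ per batch. A rough sketch of that mismatch, using a made-up heuristic (pick_num_splits is hypothetical, not FlashAttention's actual logic):

```python
def pick_num_splits(max_seq_len: int, num_sms: int = 108) -> int:
    # Hypothetical stand-in for the kernel-side heuristic: more splits for
    # longer sequences, capped by the number of SMs.
    return min(max(1, max_seq_len // 1024), num_sms)

# At CUDA graph capture time the launch configuration, and with it the split
# count, is frozen for the captured batch shape.
captured_splits = pick_num_splits(max_seq_len=4096)     # -> 4

# At replay time an AOT scheduler would recompute metadata for the new batch;
# if its heuristic lands on a different split count, the frozen kernel launch
# and the scheduler metadata disagree. Until both sides agree on one upper
# bound, AOT scheduling stays disabled under full CUDA graph mode.
replayed_splits = pick_num_splits(max_seq_len=16384)    # -> 16
assert replayed_splits != captured_splits
```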
@@ -299,18 +300,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             max_seq_len=max_seq_len,
             causal=True)

-        if self.use_full_cuda_graph:
-            assert scheduler_metadata is not None
-            n = scheduler_metadata.shape[0]
-            self.scheduler_metadata[:n].copy_(scheduler_metadata,
-                                              non_blocking=True)
-            # NOTE(woosuk): We should zero out the rest of the scheduler
-            # metadata to guarantee the correctness. Otherwise, some thread
-            # blocks may use the invalid scheduler metadata and overwrite the
-            # output buffer.
-            self.scheduler_metadata[n:] = 0
-            scheduler_metadata = self.scheduler_metadata[:n]
-
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
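
The block removed above implemented the persistent-buffer pattern that the NOTE(woosuk) comment explains: variable-length scheduler metadata is copied into a fixed-size tensor read by the captured graph, and the unused tail is zeroed so replays never see stale entries from a previous, larger batch. A self-contained sketch of that pattern (not vLLM code; the buffer size, dtype, and refill helper are illustrative):

```python
import torch

MAX_ENTRIES = 1024
# Fixed-size buffer whose storage stays stable across CUDA graph replays.
persistent = torch.zeros(MAX_ENTRIES, dtype=torch.int32)

def refill(fresh: torch.Tensor) -> torch.Tensor:
    """Copy this step's metadata into the buffer and zero the stale tail."""
    n = fresh.shape[0]
    persistent[:n].copy_(fresh, non_blocking=True)
    # Without zeroing, entries left over from an earlier step could be read
    # by some thread blocks and corrupt the output buffer.
    persistent[n:] = 0
    return persistent[:n]

metadata = refill(torch.arange(10, dtype=torch.int32))
```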