7 changes: 5 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -25,7 +25,8 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -153,7 +154,9 @@ def _get_sliding_window_configs(

class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER if get_flash_attn_version() == 2 \
else AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
357 changes: 323 additions & 34 deletions vllm/v1/attention/backends/flashinfer.py

Large diffs are not rendered by default.
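The flashinfer.py diff is not rendered above. From the review thread further down, FlashInfer appears to fall into the PURE_DECODE_ONLY category in this PR, meaning the backend must distinguish pure-decode batches from mixed prefill-decode batches before taking the full-cudagraph path. A minimal, hypothetical sketch of that gate (function names and signatures are illustrative, not FlashInfer's actual code):

```python
# Hypothetical sketch: a pure-decode batch contributes exactly one query
# token per request (max_query_len == 1), so comparing the scheduled token
# count against the request count is enough to decide whether the
# full-cudagraph path is even a candidate.
def is_pure_decode(num_reqs: int, num_scheduled_tokens: int) -> bool:
    return num_scheduled_tokens == num_reqs


def use_full_cudagraph(num_reqs: int, num_scheduled_tokens: int) -> bool:
    # A PURE_DECODE_ONLY backend must fall back to running without
    # cudagraph for mixed prefill-decode batches.
    return is_pure_decode(num_reqs, num_scheduled_tokens)
```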

4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/flashmla.py
@@ -18,6 +18,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)
@@ -54,7 +55,8 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # Decode-only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -17,6 +17,7 @@
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec

# yapf: enable
@@ -64,7 +65,8 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.PURE_DECODE_ONLY

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -18,7 +18,8 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec

@@ -57,7 +58,8 @@ class TritonAttentionMetadata:

class TritonAttentionMetadataBuilder(
AttentionMetadataBuilder[TritonAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = True
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.ALWAYS

def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
18 changes: 17 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import enum
import functools
from abc import abstractmethod
from dataclasses import dataclass, make_dataclass
@@ -65,9 +66,24 @@ class CommonAttentionMetadata:
M = TypeVar("M")


class AttentionCGSupport(enum.Enum):
    """Constants describing the cudagraph support of an attention backend.
    Cascade attention is not considered here, as it is currently never
    supported with cudagraph."""

    NEVER = 0
    """No cudagraph support."""
    PURE_DECODE_ONLY = 1
    """Cudagraph supported for pure-decode batches only; mixed
    prefill-decode batches must run without cudagraph."""
    ALWAYS = 2
    """Cudagraph always supported."""


class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER

@abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
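For orientation, a minimal sketch of how a backend's metadata builder declares its support level with the new enum; MyMetadata and MyMetadataBuilder are hypothetical names, not part of this PR:

```python
from typing import ClassVar

from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder)


class MyMetadataBuilder(AttentionMetadataBuilder["MyMetadata"]):
    # Hypothetical backend that can only replay full cudagraphs for
    # pure-decode batches (max_query_len == 1).
    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY
```

The runner then compares this value against AttentionCGSupport.NEVER and AttentionCGSupport.PURE_DECODE_ONLY instead of the old full_cudagraph_supported boolean, as shown in the gpu_model_runner.py hunk below.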
24 changes: 17 additions & 7 deletions vllm/v1/worker/gpu_model_runner.py
@@ -47,7 +47,7 @@
is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata,
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata,
make_local_attention_virtual_batches)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -2619,12 +2619,22 @@ def _initialize_single_attn_backend(
self.device,
)

if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
raise ValueError(
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
if self.full_cuda_graph:
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.NEVER:
raise ValueError(f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off "
f"CompilationConfig.full_cuda_graph or use a "
f" different attention backend.")
if attn_metadata_builder_i.attn_cudagraph_support == \
AttentionCGSupport.PURE_DECODE_ONLY:
# Limit the max cudagraph size to the max number of
# sequences for pure decode only cudagraph backend,
# whose max_query_len is 1.
self.cudagraph_batch_sizes = [
size for size in self.cudagraph_batch_sizes
if size <= self.scheduler_config.max_num_seqs
]
return attn_backend_i, attn_metadata_builder_i

def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
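The effect of the PURE_DECODE_ONLY filtering in the hunk above is easiest to see with concrete numbers; the sizes below are made up for illustration, not vLLM defaults:

```python
# Any capture size larger than max_num_seqs can never correspond to a
# pure-decode batch (one token per sequence), so capturing it would be
# wasted work and it is dropped from the list.
cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_seqs = 128  # scheduler_config.max_num_seqs (illustrative value)

cudagraph_batch_sizes = [
    size for size in cudagraph_batch_sizes if size <= max_num_seqs
]
print(cudagraph_batch_sizes)  # [1, 2, 4, 8, 16, 32, 64, 128]
```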
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_worker.py
@@ -321,11 +321,16 @@ def compile_or_warm_up_model(self) -> None:
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
# Activate building attn_metadata for this dummy run to avoid
# potential illegal memory access during full cudagraph replay.
attn_cudagraph = self.compilation_config.full_cuda_graph and\

Inline review discussion on this line:

SageMoore (Contributor):
I'm not sure I understand why you need this. As I understand it, this code is specifically warming up shapes that are not in the cudagraph capture list? Is this required because you modified the list in the GPUModelRunner? I see there's some discussion about a hang when you don't pass attention metadata into the dummy_run?

fhl2000 (Contributor, Author), Jul 24, 2025:
Hey @SageMoore, thank you for the questions!

> Is this required because you modified the list in the GPUModelRunner?

I think they are not related. To explain a bit more: this line of code runs after all cudagraph shapes in the modified list have been captured in gpu_model_runner, so this dummy_run with num_tokens=max_num_reqs is <= the max captured size of that list. Recall that a dummy_run for a backend with attn_cudagraph_support=PURE_DECODE_ONLY only runs pure-decode batches, so it only enters cudagraph replay if it lands exactly on a captured size; otherwise it runs without cudagraph. However, when it does hit a replay, FlashInfer may be trapped in an infinite loop if the contents of the persistent buffers are incorrect.

SageMoore (Contributor), Jul 29, 2025:
OK, please let me know if I'm understanding correctly. You are saying that, if max_num_reqs is a shape that has already been full-cudagraph captured, we need to make sure that the _dummy_run goes through the process of creating an AttentionMetadata because, even though the persistent buffers are guaranteed to exist, they can contain incorrect data, which can cause the graph replay to hang when running with the FlashInfer backend?

Collaborator:
Ah, OK, that makes sense to me, I think; basically, for all dummy runs after capture we need to build the metadata, since they will result in a graph replay.

fhl2000 (Contributor, Author):
> You are saying that, if max_num_reqs is a shape that has already been full-cudagraph captured, we need to make sure that the _dummy_run goes through the process of creating an AttentionMetadata because, even though the persistent buffers are guaranteed to exist, they can contain incorrect data, which can cause the graph replay to hang when running with the FlashInfer backend?

Exactly.

not self.model_config.enforce_eager

# We skip EPLB here since we don't want to record dummy metrics
hidden_states, last_hidden_states = \
self.model_runner._dummy_run(
num_tokens=max_num_reqs,
capture_attn_cudagraph=attn_cudagraph,
skip_eplb=True,
)
if self.model_runner.is_pooling_model:
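To summarize the change and the discussion above, the warm-up dummy run builds attention metadata whenever a full-cudagraph replay can be triggered, so the persistent buffers hold valid contents when a captured graph is replayed. A standalone sketch of that reasoning (plain variables stand in for the worker's config objects and are illustrative only):

```python
# Illustrative values; in the worker these come from compilation_config and
# model_config respectively.
full_cuda_graph = True
enforce_eager = False


def will_replay_cudagraph(num_tokens: int, captured_sizes: list[int]) -> bool:
    # A dummy run whose token count matches a captured size replays the
    # corresponding full cudagraph instead of running eagerly.
    return num_tokens in captured_sizes


# Build attention metadata for the warm-up run whenever a replay is possible;
# with stale persistent-buffer contents, the FlashInfer replay was observed
# to spin in an infinite loop.
capture_attn_cudagraph = full_cuda_graph and not enforce_eager

print(will_replay_cudagraph(128, [1, 2, 4, 8, 16, 32, 64, 128]))  # True
print(capture_attn_cudagraph)  # True
```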