From 48ce1aadfa05b110910108e06f66934683435258 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 9 Jul 2025 17:50:10 +0000 Subject: [PATCH 1/4] [Misc] Log the reason for falling back to FlexAttention Signed-off-by: DarkLight1337 --- vllm/attention/selector.py | 49 +++++++++++++---- vllm/platforms/cuda.py | 57 ++++++++++++-------- vllm/v1/attention/backends/cpu_attn.py | 4 ++ vllm/v1/attention/backends/flash_attn.py | 4 ++ vllm/v1/attention/backends/flashinfer.py | 4 ++ vllm/v1/attention/backends/flex_attention.py | 4 ++ vllm/v1/attention/backends/mla/common.py | 4 ++ vllm/v1/attention/backends/rocm_aiter_fa.py | 4 ++ vllm/v1/attention/backends/triton_attn.py | 4 ++ 9 files changed, 102 insertions(+), 32 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index df14aea729f3..4c4c9b0154b9 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -3,6 +3,7 @@ import os from contextlib import contextmanager +from dataclasses import dataclass from functools import cache from typing import Generator, Optional, Union @@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: return forced_attn_backend -def supports_head_size( +@dataclass(frozen=True) +class _IsSupported: + can_import: bool + head_size: bool + dtype: bool + + def __bool__(self) -> bool: + return self.head_size and self.dtype + + +def is_attn_backend_supported( attn_backend: Union[str, type[AttentionBackend]], head_size: int, -) -> bool: + dtype: torch.dtype, + *, + allow_import_error: bool = True, +) -> _IsSupported: if isinstance(attn_backend, str): try: attn_backend = resolve_obj_by_qualname(attn_backend) except ImportError: - return False + if not allow_import_error: + raise + + return _IsSupported(can_import=False, head_size=False, dtype=False) assert isinstance(attn_backend, type) # TODO: Update the interface once V0 is removed if get_supported_head_sizes := getattr(attn_backend, "get_supported_head_sizes", None): - return head_size in get_supported_head_sizes() - if validate_head_size := getattr(attn_backend, "validate_head_size", None): + is_head_size_supported = head_size in get_supported_head_sizes() + elif validate_head_size := getattr(attn_backend, "validate_head_size", + None): try: validate_head_size(head_size) - return True + is_head_size_supported = True except Exception: - return False + is_head_size_supported = False + else: + raise NotImplementedError(f"{attn_backend.__name__} does not support " + "head size validation") + + if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", + None): + is_dtype_supported = dtype in get_supported_dtypes() + else: + raise NotImplementedError(f"{attn_backend.__name__} does not support " + "dtype validation") - raise NotImplementedError(f"{attn_backend.__name__} does not support " - "head size validation") + return _IsSupported( + can_import=True, + head_size=is_head_size_supported, + dtype=is_dtype_supported, + ) def get_attn_backend( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b53d7e71a03e..a661229b2399 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -253,39 +253,50 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info_once("Using Flash Attention backend on V1 engine.") return FLASH_ATTN_V1 - from vllm.attention.selector import supports_head_size + from vllm.attention.selector import is_attn_backend_supported # Default backends for V1 engine - # FP32 is only supported by FlexAttention - if dtype not in (torch.float16, 
torch.bfloat16): - logger.info_once( - "Using FlexAttention backend for %s on V1 engine.", - dtype, - ) - return FLEX_ATTENTION_V1 - # Prefer FlashInfer for Blackwell GPUs if installed - if cls.is_device_capability(100) and \ - supports_head_size(FLASHINFER_V1, head_size): - try: - import flashinfer # noqa: F401 + if cls.is_device_capability(100): + if is_default_backend_supported := is_attn_backend_supported( + FLASHINFER_V1, head_size, dtype): logger.info_once( - "Using FlashInfer backend on V1 engine by default for " - "Blackwell (SM 10.0) GPUs.") + "Using FlashInfer backend on V1 engine by default " + "for Blackwell (SM 10.0) GPUs.") return FLASHINFER_V1 - except ImportError: - logger.info_once( + + if not is_default_backend_supported.can_import: + logger.warning_once( "FlashInfer failed to import for V1 engine on " "Blackwell (SM 10.0) GPUs; it is recommended to " "install FlashInfer for better performance.") - pass + # FlashAttention is the default for SM 8.0+ GPUs - if cls.has_device_capability(80) and \ - supports_head_size(FLASH_ATTN_V1, head_size): - logger.info_once("Using Flash Attention backend on V1 engine.") - return FLASH_ATTN_V1 + if cls.has_device_capability(80): + if is_default_backend_supported := is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, + allow_import_error=False): + logger.info_once("Using Flash Attention backend on " + "V1 engine.") + return FLASH_ATTN_V1 + + # FlexAttention is the default for older GPUs + else: + logger.info_once("Using FlexAttention backend on V1 engine.") + return FLEX_ATTENTION_V1 + + assert not is_default_backend_supported + + default_not_supported_reason = {} + if not is_default_backend_supported.head_size: + default_not_supported_reason["head_size"] = head_size + if not is_default_backend_supported.dtype: + default_not_supported_reason["dtype"] = dtype - logger.info_once("Using FlexAttention backend on V1 engine.") + logger.info_once( + "Using FlexAttention backend for %s on V1 engine.", + str(default_not_supported_reason), + ) return FLEX_ATTENTION_V1 # Backends for V0 engine diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index d6270fbf3196..f1c6bdfc1c94 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -37,6 +37,10 @@ class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] + @classmethod def validate_head_size(cls, head_size: int) -> None: attn_impl = _get_paged_attn_impl() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fbc13c06c65a..552c2caf2fa8 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -44,6 +44,10 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 860309faa905..4c8a734d867a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -38,6 +38,10 @@ class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> 
list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index a8c5f464aa32..f0f54c28831f 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -42,6 +42,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] + @classmethod def validate_head_size(cls, head_size: int) -> None: return # FlexAttention supports any head size diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f2aaf59a40f8..e66286d0dde8 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -254,6 +254,10 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 6a78b03dce86..dd86e56885ed 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -314,6 +314,10 @@ class AiterFlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index cdaff2f6a40f..7dc90a6a97e7 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -190,6 +190,10 @@ class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] From aee14797f606f043fab962c7ba079cb6cfe82c00 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 9 Jul 2025 17:54:27 +0000 Subject: [PATCH 2/4] Rename Signed-off-by: DarkLight1337 --- vllm/platforms/cuda.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a661229b2399..80503d33eb69 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -287,15 +287,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, assert not is_default_backend_supported - default_not_supported_reason = {} + use_flex_attention_reason = {} if not is_default_backend_supported.head_size: - default_not_supported_reason["head_size"] = head_size + use_flex_attention_reason["head_size"] = head_size if not is_default_backend_supported.dtype: - default_not_supported_reason["dtype"] = dtype + use_flex_attention_reason["dtype"] = dtype logger.info_once( "Using FlexAttention backend for %s on V1 
engine.", - str(default_not_supported_reason), + str(use_flex_attention_reason), ) return FLEX_ATTENTION_V1 From 2111a24907f0c7ce4761f105a12cc0179fbaabf1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Wed, 9 Jul 2025 17:55:42 +0000 Subject: [PATCH 3/4] Format Signed-off-by: DarkLight1337 --- vllm/platforms/cuda.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 80503d33eb69..2155a524d282 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -295,7 +295,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info_once( "Using FlexAttention backend for %s on V1 engine.", - str(use_flex_attention_reason), + ", ".join(f"{k}={v}" + for k, v in use_flex_attention_reason.items()), ) return FLEX_ATTENTION_V1 From e84a7a66c9d990e312ed5ebc1d2039f7f90e8135 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 10 Jul 2025 02:05:55 +0800 Subject: [PATCH 4/4] Update selector.py --- vllm/attention/selector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 4c4c9b0154b9..4d4886d02b78 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -87,7 +87,7 @@ class _IsSupported: dtype: bool def __bool__(self) -> bool: - return self.head_size and self.dtype + return self.can_import and self.head_size and self.dtype def is_attn_backend_supported(
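
For reference, the selection flow this series ends up with can be sketched outside of
vLLM as a small standalone script. This is an illustrative summary only, not the actual
selector code: `SupportCheck`, `probe_backend` and `pick_backend` are hypothetical
stand-ins for `_IsSupported`, `is_attn_backend_supported` and `get_attn_backend_cls`,
with the supported head sizes and dtypes copied from the FlashAttention backend above.

import logging
from dataclasses import dataclass

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("backend_selector")


@dataclass(frozen=True)
class SupportCheck:
    # Mirrors _IsSupported: truthy only when every criterion passes.
    can_import: bool
    head_size: bool
    dtype: bool

    def __bool__(self) -> bool:
        return self.can_import and self.head_size and self.dtype


def probe_backend(supported_head_sizes: list[int],
                  supported_dtypes: list[torch.dtype], head_size: int,
                  dtype: torch.dtype) -> SupportCheck:
    # Stand-in for is_attn_backend_supported(): report each criterion
    # separately so the caller can log *why* the default was rejected.
    return SupportCheck(
        can_import=True,
        head_size=head_size in supported_head_sizes,
        dtype=dtype in supported_dtypes,
    )


def pick_backend(head_size: int, dtype: torch.dtype) -> str:
    check = probe_backend([32, 64, 96, 128, 160, 192, 224, 256],
                          [torch.float16, torch.bfloat16], head_size, dtype)
    if check:
        return "FLASH_ATTN"

    # Collect only the failing criteria, as cuda.py now does before
    # falling back to FlexAttention.
    reason = {}
    if not check.head_size:
        reason["head_size"] = head_size
    if not check.dtype:
        reason["dtype"] = dtype
    logger.info("Using FlexAttention backend for %s.",
                ", ".join(f"{k}={v}" for k, v in reason.items()))
    return "FLEX_ATTENTION"


# Example: FP32 with an unsupported head size falls back and logs both reasons.
print(pick_backend(head_size=80, dtype=torch.float32))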