Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cache
from typing import Generator, Optional, Union

Expand Down Expand Up @@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
return forced_attn_backend


def supports_head_size(
@dataclass(frozen=True)
class _IsSupported:
can_import: bool
head_size: bool
dtype: bool

def __bool__(self) -> bool:
return self.can_import and self.head_size and self.dtype


def is_attn_backend_supported(
attn_backend: Union[str, type[AttentionBackend]],
head_size: int,
) -> bool:
dtype: torch.dtype,
*,
allow_import_error: bool = True,
) -> _IsSupported:
if isinstance(attn_backend, str):
try:
attn_backend = resolve_obj_by_qualname(attn_backend)
except ImportError:
return False
if not allow_import_error:
raise

return _IsSupported(can_import=False, head_size=False, dtype=False)

assert isinstance(attn_backend, type)

# TODO: Update the interface once V0 is removed
if get_supported_head_sizes := getattr(attn_backend,
"get_supported_head_sizes", None):
return head_size in get_supported_head_sizes()
if validate_head_size := getattr(attn_backend, "validate_head_size", None):
is_head_size_supported = head_size in get_supported_head_sizes()
elif validate_head_size := getattr(attn_backend, "validate_head_size",
None):
try:
validate_head_size(head_size)
return True
is_head_size_supported = True
except Exception:
return False
is_head_size_supported = False
else:
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"head size validation")

if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes",
None):
is_dtype_supported = dtype in get_supported_dtypes()
else:
raise NotImplementedError(f"{attn_backend.__name__} does not support "
"dtype validation")

raise NotImplementedError(f"{attn_backend.__name__} does not support "
"head size validation")
return _IsSupported(
can_import=True,
head_size=is_head_size_supported,
dtype=is_dtype_supported,
)


def get_attn_backend(
Expand Down
57 changes: 35 additions & 22 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,43 +259,56 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1

from vllm.attention.selector import supports_head_size
from vllm.attention.selector import is_attn_backend_supported

# Default backends for V1 engine
# FP32 is only supported by FlexAttention
if dtype not in (torch.float16, torch.bfloat16):
logger.info_once(
"Using FlexAttention backend for %s on V1 engine.",
dtype,
)
return FLEX_ATTENTION_V1

# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100) and \
supports_head_size(FLASHINFER_V1, head_size):
try:
import flashinfer # noqa: F401

if cls.is_device_capability(100):
if is_default_backend_supported := is_attn_backend_supported(
FLASHINFER_V1, head_size, dtype):
from vllm.v1.attention.backends.utils import (
set_kv_cache_layout)

logger.info_once(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs.")
set_kv_cache_layout("HND")

return FLASHINFER_V1
except ImportError:
logger.info_once(

if not is_default_backend_supported.can_import:
logger.warning_once(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.")
pass

# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80) and \
supports_head_size(FLASH_ATTN_V1, head_size):
logger.info_once("Using Flash Attention backend on V1 engine.")
return FLASH_ATTN_V1
if cls.has_device_capability(80):
if is_default_backend_supported := is_attn_backend_supported(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable is_default_backend_supported is reused for both FlashInfer (line 261) and FlashAttention (line 276) checks. This reassignment can reduce clarity regarding which backend's support is being evaluated. Consider using a more specific variable name, such as is_flash_attn_supported, for the FlashAttention check to improve readability and explicitly indicate the context of the support check.

FLASH_ATTN_V1, head_size, dtype,
allow_import_error=False):
logger.info_once("Using Flash Attention backend on "
"V1 engine.")
return FLASH_ATTN_V1

# FlexAttention is the default for older GPUs
else:
logger.info_once("Using FlexAttention backend on V1 engine.")
return FLEX_ATTENTION_V1

assert not is_default_backend_supported

use_flex_attention_reason = {}
if not is_default_backend_supported.head_size:
use_flex_attention_reason["head_size"] = head_size
if not is_default_backend_supported.dtype:
use_flex_attention_reason["dtype"] = dtype

logger.info_once("Using FlexAttention backend on V1 engine.")
logger.info_once(
"Using FlexAttention backend for %s on V1 engine.",
", ".join(f"{k}={v}"
for k, v in use_flex_attention_reason.items()),
)
return FLEX_ATTENTION_V1

# Backends for V0 engine
Expand Down
2 changes: 1 addition & 1 deletion vllm/reasoning/hunyuan_a13b_reasoning_parser.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import re
from collections.abc import Sequence
from typing import Optional, Union

import regex as re
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/cpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]

@classmethod
def validate_head_size(cls, head_size: int) -> None:
attn_impl = _get_paged_attn_impl()
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class FlashAttentionBackend(AttentionBackend):

accept_output_buffer: bool = True

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True
cached_sm100a_supported: Optional[bool] = None

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
# https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16, torch.float32]

@classmethod
def validate_head_size(cls, head_size: int) -> None:
return # FlexAttention supports any head size
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ def get_kv_cache_shape(
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [576]
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,10 @@ class AiterFlashAttentionBackend(AttentionBackend):

accept_output_buffer: bool = True

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
Expand Down
4 changes: 4 additions & 0 deletions vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ class TritonAttentionBackend(AttentionBackend):

accept_output_buffer: bool = True

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
Expand Down