diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py
index 29a3b40d2d86..72819f31de20 100644
--- a/tests/entrypoints/openai/test_vision.py
+++ b/tests/entrypoints/openai/test_vision.py
@@ -34,11 +34,11 @@
     ],
     [
         "The image shows a Venn diagram with three over",
-        "The image shows a Venn diagram with three intersect",
+        "This image shows a Venn diagram with three over",
     ],
     [
         "This image displays a gradient of colors ranging from",
-        "The image displays a gradient of colors ranging from",
+        "This image displays a gradient of colors forming a spectrum",
     ],
 ]
 
diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py
index c01ea32994da..d37b968ed979 100644
--- a/tests/kernels/attention/test_mha_attn.py
+++ b/tests/kernels/attention/test_mha_attn.py
@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform",
-                   CpuPlatform()), \
-             patch("vllm.platforms.current_platform", CpuPlatform()):
+        with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
+            assert attn.attn_backend == _Backend.TORCH_SDPA
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform",
-                   RocmPlatform()), \
-             patch("vllm.platforms.current_platform", RocmPlatform()), \
-             patch("vllm.attention.layer.current_platform", RocmPlatform()):
+        with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform",
-                   CudaPlatform()), \
-             patch("vllm.platforms.current_platform", CudaPlatform()):
+        # Test CUDA with head_size=64 (divisible by 32)
+        # - should use vLLM's FlashAttention
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.XFORMERS
+            assert attn.attn_backend == _Backend.FLASH_ATTN
 
-        with patch("vllm.attention.selector.current_platform",
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA not available
+        # - should use xformers
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
                    CudaPlatform()), \
-             patch("vllm.platforms.current_platform", CudaPlatform()):
+             patch("vllm.attention.layer.check_upstream_fa_availability",
+                   return_value=False):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
 
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - with upstream FA available
+        # - should use upstream FA
+        with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
+             patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+             patch("vllm.attention.layer.check_upstream_fa_availability",
+                   return_value=True), \
+             patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
+                        {
+                            'flash_attn_varlen_func': lambda *args, **kwargs: None
+                        })()}):
+            attn = MultiHeadAttention(16, 72, scale=1)
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+
 
 def ref_attention(
     query: torch.Tensor,
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index bb05b468fd10..44cb2c7c6b64 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
 
@@ -55,6 +56,14 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS
 
 
+def check_upstream_fa_availability(dtype: torch.dtype):
+    if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
+    ) and current_platform.has_device_capability(80):
+        from transformers.utils import is_flash_attn_2_available
+        return is_flash_attn_2_available()
+    return False
+
+
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -349,29 +358,55 @@ def __init__(
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype=None,
-                                        block_size=16,
-                                        is_attention_free=False)
-        backend = backend_name_to_enum(attn_backend.get_name())
+
+        # Determine the attention backend
+        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+
+        # Some auto-selected backends can be upgraded
+        # to upstream flash attention if available.
+        # If vllm native fa is selected, we use it directly.
+        use_upstream_fa = False
+        if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
+                dtype):
+            backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+
         if current_platform.is_rocm():
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA
         else:
+
             self.attn_backend = backend if backend in {
                 _Backend.TORCH_SDPA,
                 _Backend.TORCH_SDPA_VLLM_V1,
                 _Backend.XFORMERS,
                 _Backend.PALLAS_VLLM_V1,
                 _Backend.ROCM_AITER_FA,
-            } else current_platform.get_vit_attn_backend()
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+            } else _Backend.TORCH_SDPA
 
             if (self.attn_backend == _Backend.XFORMERS
                     and not check_xformers_availability()):
                 self.attn_backend = _Backend.TORCH_SDPA
 
+            if self.attn_backend in {
+                    _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+            }:
+                if use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                    self._flash_attn_varlen_func = flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func
+                    self._flash_attn_varlen_func = flash_attn_varlen_func
+
+        logger.info_once(
+            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
+            f"use_upstream_fa: {use_upstream_fa}")
+
     def forward(
         self,
         query: torch.Tensor,
@@ -392,7 +427,31 @@ def forward(
             key = torch.repeat_interleave(key, num_repeat, dim=2)
             value = torch.repeat_interleave(value, num_repeat, dim=2)
 
-        if self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend in {
+                _Backend.FLASH_ATTN,
+                _Backend.FLASH_ATTN_VLLM_V1,
+        }:
+
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = self._flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 97aace5a20c3..bcff65a717ab 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -34,6 +34,7 @@
 from einops import rearrange, repeat
 from transformers import BatchFeature
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -170,7 +171,16 @@ def __init__(
                                       prefix=f"{prefix}.proj")
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -233,7 +243,10 @@ def forward(
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
@@ -457,7 +470,11 @@ def __init__(
         ), "vit's config.hidden must be equal to config.embed_dim"
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
 
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 539381b61800..279f458dfa6c 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -44,6 +44,7 @@
                           Glm4vVideoProcessor)
 from transformers.video_utils import VideoMetadata
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               parallel_state)
@@ -260,7 +261,15 @@ def __init__(
         )
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN,
                 _Backend.TORCH_SDPA,
@@ -310,7 +319,10 @@ def forward(
         if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
@@ -715,7 +727,11 @@ def __init__(
         self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                       eps=vision_config.rms_norm_eps)
 
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 710b805acb3e..04824db1b6dd 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -17,6 +17,7 @@
                                            BaseModelOutputWithPooling)
 from transformers.utils import torch_int
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -374,7 +375,16 @@ def __init__(
         )
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
+
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now.")
@@ -428,7 +438,10 @@ def forward(
         )
 
         if self.attn_backend == _Backend.FLASH_ATTN:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 8aa777557029..98f9c0cf4c16 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -38,6 +38,7 @@
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
@@ -298,7 +299,16 @@ def __init__(
                                         disable_tp=use_data_parallel)
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -359,7 +369,10 @@ def forward(
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
@@ -628,7 +641,12 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 90a1ad2a658a..89af79c3b5fd 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -41,6 +41,7 @@
 from transformers.models.qwen2_vl.video_processing_qwen2_vl import (
     Qwen2VLVideoProcessor)
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
@@ -314,7 +315,16 @@ def __init__(
                                       prefix=f"{prefix}.proj")
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -374,7 +384,10 @@ def forward(
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
@@ -628,7 +641,12 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
index c6244fb3b3e6..a86700fe68dd 100644
--- a/vllm/model_executor/models/siglip2navit.py
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -13,6 +13,7 @@
 from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import QuantizationConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -236,7 +237,15 @@ def __init__(
         self.use_rope = config.use_rope
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                 _Backend.ROCM_AITER_FA
@@ -280,7 +289,10 @@ def forward(
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             attn_output = flash_attn_varlen_func(
                 queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                 max_seqlen).reshape(seq_length, -1)
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index c16aa5ac608f..81f86db7e187 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -7,7 +7,6 @@
 import torch
 from transformers import PretrainedConfig
 
-from vllm.attention.selector import get_env_variable_attn_backend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 
@@ -68,17 +67,18 @@ def get_vision_encoder_info(
     raise NotImplementedError(msg)
 
 
-def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
     """
     Get the available attention backend for Vision Transformer.
     """
-    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
+    # Lazy import to avoid circular dependency
+    from vllm.attention.selector import get_env_variable_attn_backend
     selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
     if selected_backend is not None:
         return selected_backend
 
-    return current_platform.get_vit_attn_backend(support_fa)
+    return current_platform.get_vit_attn_backend(head_size, dtype)
 
 
 def resolve_visual_encoder_outputs(
@@ -122,4 +122,4 @@ def resolve_visual_encoder_outputs(
         uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
         if post_layer_norm is not None and uses_last_layer:
             hs_pool[-1] = post_layer_norm(encoder_outputs)
-    return torch.cat(hs_pool, dim=-1)
\ No newline at end of file
+    return torch.cat(hs_pool, dim=-1)
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index e40b6eb2b5a4..77c9a012b2d3 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -209,18 +209,24 @@ def get_current_memory_usage(cls,
         return torch.cuda.max_memory_allocated(device)
 
     @classmethod
-    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
-        if cls.has_device_capability(80) and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
+    def get_vit_attn_backend(cls, head_size: int,
+                             dtype: torch.dtype) -> _Backend:
+        if dtype not in (torch.float16, torch.bfloat16):
+            return _Backend.XFORMERS
+
+        if cls.has_device_capability(80):
+            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            from vllm.attention.selector import is_attn_backend_supported
+            is_default_fa_supported = is_attn_backend_supported(
+                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False)
+            if is_default_fa_supported:
                 return _Backend.FLASH_ATTN
-            logger.warning_once(
-                "Current `vllm-flash-attn` has a bug inside vision "
-                "module, so we use xformers backend instead. You can "
-                "run `pip install flash-attn` to use flash-attention "
-                "backend.")
-        # Fallback for Volta/Turing GPUs or FA not supported
-        return _Backend.XFORMERS
+            else:
+                # Fallback to XFORMERS
+                return _Backend.XFORMERS
+        else:
+            # Fallback for Volta/Turing GPUs or FA not supported
+            return _Backend.XFORMERS
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 59aa46818569..054d08c3a85b 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -192,7 +192,8 @@ def device_id_to_physical_device_id(cls, device_id: int):
             return device_id
 
     @classmethod
-    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+    def get_vit_attn_backend(cls, head_size: int,
+                             dtype: torch.dtype) -> _Backend:
         return _Backend.TORCH_SDPA
 
     @classmethod
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index f4d136c5e0aa..bb8bff48c7b9 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -175,15 +175,15 @@ class RocmPlatform(Platform):
     ]
 
     @classmethod
-    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
-        if support_fa:
-            if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
-                    and on_gfx9()):
-                # Note: AITER FA is only supported for Qwen-VL models.
-                # TODO: Add support for other VL models in their model class.
-                return _Backend.ROCM_AITER_FA
-            if on_gfx9():
-                return _Backend.FLASH_ATTN
+    def get_vit_attn_backend(cls, head_size: int,
+                             dtype: torch.dtype) -> _Backend:
+        if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
+                and on_gfx9()):
+            # Note: AITER FA is only supported for Qwen-VL models.
+            # TODO: Add support for other VL models in their model class.
+            return _Backend.ROCM_AITER_FA
+        if on_gfx9():
+            return _Backend.FLASH_ATTN
         return _Backend.TORCH_SDPA
 
     @classmethod
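
Note (not part of the patch): the per-model selection logic that this diff repeats across the vision encoders reduces to the sketch below. It is illustrative only, uses only functions introduced or touched by this diff, and assumes a CUDA build of vLLM where those modules are importable; the helper name `select_vit_backend` is hypothetical.

```python
# Illustrative sketch only -- not part of the patch.
import torch

from vllm.attention.layer import check_upstream_fa_availability
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend


def select_vit_backend(head_size: int) -> tuple[_Backend, bool]:
    """Mirror the selection pattern added to the ViT attention modules."""
    dtype = torch.get_default_dtype()
    # The platform picks first: FLASH_ATTN only when vLLM's own FA kernel
    # supports this head_size/dtype, otherwise XFORMERS (CUDA) or TORCH_SDPA.
    backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
    use_upstream_fa = False
    # If vLLM's FA was rejected (e.g. head_size=72 in the test above) but the
    # upstream flash-attn package is installed, upgrade to upstream FA.
    if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(dtype):
        backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
    return backend, use_upstream_fa
```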