From d4eba739a44c218dfbd8b456b2c9b84dfc98092b Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Fri, 5 Sep 2025 22:51:58 +0000 Subject: [PATCH 01/11] add fa3 in vit Signed-off-by: wwl2755 --- vllm/attention/layer.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index bb05b468fd10..2d063c72ee80 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -360,12 +360,15 @@ def __init__( # currently, only torch_sdpa is supported on rocm self.attn_backend = _Backend.TORCH_SDPA else: + self.attn_backend = backend if backend in { _Backend.TORCH_SDPA, _Backend.TORCH_SDPA_VLLM_V1, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1, _Backend.ROCM_AITER_FA, + _Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1, } else current_platform.get_vit_attn_backend() if (self.attn_backend == _Backend.XFORMERS @@ -392,7 +395,37 @@ def forward( key = torch.repeat_interleave(key, num_repeat, dim=2) value = torch.repeat_interleave(value, num_repeat, dim=2) - if self.attn_backend == _Backend.XFORMERS: + if self.attn_backend in { + _Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1, + }: + if self.head_size % 32 != 0: + # import from upstream flash_attn + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + + cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, + step=q_len, + dtype=torch.int32, + device=query.device) + cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, + step=kv_len, + dtype=torch.int32, + device=key.device) + + out = flash_attn_varlen_func( + query.flatten(0, 1), + key.flatten(0, 1), + value.flatten(0, 1), + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=q_len, + max_seqlen_k=kv_len, + softmax_scale=self.scale, + ) + out = out.reshape(bsz, q_len, -1) + elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward(query, From 18c07175cdb771616106fd29a2a190ae09bd3d1f Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Sun, 7 Sep 2025 22:08:56 +0000 Subject: [PATCH 02/11] refactor get_vit_attention_backend() Signed-off-by: wwl2755 --- vllm/attention/layer.py | 34 +++++++++++++--------- vllm/model_executor/models/ernie45_vl.py | 10 +++++-- vllm/model_executor/models/glm4_1v.py | 10 +++++-- vllm/model_executor/models/keye.py | 8 +++-- vllm/model_executor/models/qwen2_5_vl.py | 10 +++++-- vllm/model_executor/models/qwen2_vl.py | 10 +++++-- vllm/model_executor/models/siglip2navit.py | 8 +++-- vllm/model_executor/models/vision.py | 11 +++---- vllm/platforms/cuda.py | 29 +++++++++++------- vllm/platforms/interface.py | 4 +-- vllm/platforms/rocm.py | 19 ++++++------ 11 files changed, 96 insertions(+), 57 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 2d063c72ee80..d473520c93b5 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op @@ -349,13 +350,11 @@ def __init__( f"divisible by num_kv_heads ({self.num_kv_heads})" self.num_queries_per_kv = self.num_heads // self.num_kv_heads - dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, - dtype, - 
kv_cache_dtype=None, - block_size=16, - is_attention_free=False) - backend = backend_name_to_enum(attn_backend.get_name()) + # dtype = torch.get_default_dtype() + + # Determine the attention backend + backend, use_upstream_fa = get_vit_attn_backend(head_size=head_size) + if current_platform.is_rocm(): # currently, only torch_sdpa is supported on rocm self.attn_backend = _Backend.TORCH_SDPA @@ -375,6 +374,20 @@ def __init__( and not check_xformers_availability()): self.attn_backend = _Backend.TORCH_SDPA + if self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1 + }: + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + self._flash_attn_varlen_func = flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + self._flash_attn_varlen_func = flash_attn_varlen_func + + logger.info_once( + f"MultiHeadAttention attn_backend: {self.attn_backend}, " + f"use_upstream_fa: {use_upstream_fa}") + def forward( self, query: torch.Tensor, @@ -399,11 +412,6 @@ def forward( _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1, }: - if self.head_size % 32 != 0: - # import from upstream flash_attn - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, @@ -414,7 +422,7 @@ def forward( dtype=torch.int32, device=key.device) - out = flash_attn_varlen_func( + out = self._flash_attn_varlen_func( query.flatten(0, 1), key.flatten(0, 1), value.flatten(0, 1), diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 97aace5a20c3..45071cda653f 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -170,7 +170,8 @@ def __init__( prefix=f"{prefix}.proj") # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -233,7 +234,10 @@ def forward( if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -457,7 +461,7 @@ def __init__( ), "vit's config.hidden must be equal to config.embed_dim" self.ln = nn.LayerNorm(hidden_size, eps=1e-6) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim) @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 539381b61800..3f6c91951ff9 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -260,7 +260,8 @@ def __init__( ) # Detect attention implementation. 
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, @@ -310,7 +311,10 @@ def forward( if self.attn_backend == _Backend.FLASH_ATTN: # from vllm_flash_attn.flash_attn_interface import ( # flash_attn_varlen_func) - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -715,7 +719,7 @@ def __init__( self.post_layernorm = RMSNorm(vision_config.hidden_size, eps=vision_config.rms_norm_eps) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim) @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 710b805acb3e..089a32aeba57 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -374,7 +374,8 @@ def __init__( ) # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + head_size=self.head_dim) if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now.") @@ -428,7 +429,10 @@ def forward( ) if self.attn_backend == _Backend.FLASH_ATTN: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 8aa777557029..2ebc9826ad36 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -298,7 +298,8 @@ def __init__( disable_tp=use_data_parallel) # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -359,7 +360,10 @@ def forward( if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -628,7 +632,7 @@ def __init__( prefix=f"{prefix}.merger", use_data_parallel=use_data_parallel, ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim) @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 90a1ad2a658a..fbb52721bb52 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -314,7 +314,8 @@ def __init__( prefix=f"{prefix}.proj") # Detect attention implementation. 
- self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + head_size=self.hidden_size_per_attention_head) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -374,7 +375,10 @@ def forward( if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) @@ -628,7 +632,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.merger", ) - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim) @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index c6244fb3b3e6..2c8b965d27a4 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -236,7 +236,8 @@ def __init__( self.use_rope = config.use_rope # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + head_size=self.head_dim) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.ROCM_AITER_FA @@ -280,7 +281,10 @@ def forward( if self.attn_backend == _Backend.ROCM_AITER_FA: from aiter import flash_attn_varlen_func else: - from flash_attn import flash_attn_varlen_func + if self.use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func attn_output = flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(seq_length, -1) diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index c16aa5ac608f..447d3bc31519 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -68,17 +68,18 @@ def get_vision_encoder_info( raise NotImplementedError(msg) -def get_vit_attn_backend(support_fa: bool = False) -> _Backend: +def get_vit_attn_backend(head_size: int) -> tuple[_Backend, bool]: """ Get the available attention backend for Vision Transformer. + + Returns: + Tuple of (backend, use_upstream_fa) """ - # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn. 
- selected_backend: Optional[_Backend] = get_env_variable_attn_backend() if selected_backend is not None: - return selected_backend + return selected_backend, False - return current_platform.get_vit_attn_backend(support_fa) + return current_platform.get_vit_attn_backend(head_size) def resolve_visual_encoder_outputs( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e40b6eb2b5a4..a2bf5d0e5871 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -209,18 +209,25 @@ def get_current_memory_usage(cls, return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - if cls.has_device_capability(80) and support_fa: - from transformers.utils import is_flash_attn_2_available - if is_flash_attn_2_available(): - return _Backend.FLASH_ATTN - logger.warning_once( - "Current `vllm-flash-attn` has a bug inside vision " - "module, so we use xformers backend instead. You can " - "run `pip install flash-attn` to use flash-attention " - "backend.") + def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]: + if cls.has_device_capability(80): + if head_size % 32 == 0: + # Use vllm-flash-attn + return _Backend.FLASH_ATTN, False + if head_size % 32 != 0: + from transformers.utils import is_flash_attn_2_available + if is_flash_attn_2_available(): + # Use upstream FA + return _Backend.FLASH_ATTN, True + else: + # Fallback to XFORMERS + logger.warning_once( + "Using xformers for ViT attention backend. " + "To use flash attention for ViT" + "please install flash_attn") + return _Backend.XFORMERS, False # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + return _Backend.XFORMERS, False @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 59aa46818569..5c4cade7f69d 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -192,8 +192,8 @@ def device_id_to_physical_device_id(cls, device_id: int): return device_id @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - return _Backend.TORCH_SDPA + def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]: + return _Backend.TORCH_SDPA, False @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index f4d136c5e0aa..78080cce221c 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -175,16 +175,15 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend: - if support_fa: - if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA - and on_gfx9()): - # Note: AITER FA is only supported for Qwen-VL models. - # TODO: Add support for other VL models in their model class. - return _Backend.ROCM_AITER_FA - if on_gfx9(): - return _Backend.FLASH_ATTN - return _Backend.TORCH_SDPA + def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]: + if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA + and on_gfx9()): + # Note: AITER FA is only supported for Qwen-VL models. + # TODO: Add support for other VL models in their model class. 
+ return _Backend.ROCM_AITER_FA, False + if on_gfx9(): + return _Backend.FLASH_ATTN, False + return _Backend.TORCH_SDPA, False @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, From 125b83f22044cc3452f2d334373e29d0f2880a23 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Mon, 8 Sep 2025 02:41:50 +0000 Subject: [PATCH 03/11] add dtype checking and fix tests Signed-off-by: wwl2755 --- tests/kernels/attention/test_mha_attn.py | 35 ++++++++++++++++++++++ vllm/attention/layer.py | 7 +++-- vllm/model_executor/models/ernie45_vl.py | 6 ++-- vllm/model_executor/models/glm4_1v.py | 6 ++-- vllm/model_executor/models/keye.py | 2 +- vllm/model_executor/models/qwen2_5_vl.py | 6 ++-- vllm/model_executor/models/qwen2_vl.py | 6 ++-- vllm/model_executor/models/siglip2navit.py | 2 +- vllm/model_executor/models/vision.py | 5 ++-- vllm/platforms/cuda.py | 6 +++- vllm/platforms/interface.py | 3 +- vllm/platforms/rocm.py | 3 +- 12 files changed, 70 insertions(+), 17 deletions(-) diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index c01ea32994da..cb5a72878a71 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -36,12 +36,18 @@ def test_mha_attn_platform(device: str): torch.set_default_dtype(torch.float16) if device == "cpu": +<<<<<<< HEAD with patch("vllm.attention.selector.current_platform", CpuPlatform()), \ patch("vllm.platforms.current_platform", CpuPlatform()): +======= + with patch("vllm.model_executor.models.vision.current_platform", + CpuPlatform()): +>>>>>>> add dtype checking and fix tests attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1 elif device == "hip": +<<<<<<< HEAD with patch("vllm.attention.selector.current_platform", RocmPlatform()), \ patch("vllm.platforms.current_platform", RocmPlatform()), \ @@ -58,6 +64,35 @@ def test_mha_attn_platform(device: str): with patch("vllm.attention.selector.current_platform", CudaPlatform()), \ patch("vllm.platforms.current_platform", CudaPlatform()): +======= + with patch("vllm.model_executor.models.vision.current_platform", + RocmPlatform()): + attn = MultiHeadAttention(16, 64, scale=1) + assert attn.attn_backend == _Backend.TORCH_SDPA + else: + # Test CUDA with head_size=64 (divisible by 32) + # - should use vLLM FlashAttention + with patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()): + attn = MultiHeadAttention(16, 64, scale=1) + assert attn.attn_backend == _Backend.FLASH_ATTN + + # Test CUDA with head_size=72 (not divisible by 32) + # - upstream FA available + with patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ + patch("transformers.utils.is_flash_attn_2_available", + return_value=True): + attn = MultiHeadAttention(16, 72, scale=1) + assert attn.attn_backend == _Backend.FLASH_ATTN + + # Test CUDA with head_size=72 (not divisible by 32) + # - upstream FA not available + with patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ + patch("transformers.utils.is_flash_attn_2_available", + return_value=False): +>>>>>>> add dtype checking and fix tests attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.XFORMERS diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index d473520c93b5..1d375edcf482 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -350,10 +350,13 @@ def __init__( f"divisible by num_kv_heads ({self.num_kv_heads})" 
self.num_queries_per_kv = self.num_heads // self.num_kv_heads - # dtype = torch.get_default_dtype() + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() # Determine the attention backend - backend, use_upstream_fa = get_vit_attn_backend(head_size=head_size) + backend, use_upstream_fa = get_vit_attn_backend(head_size=head_size, + dtype=dtype) if current_platform.is_rocm(): # currently, only torch_sdpa is supported on rocm diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 45071cda653f..e9cf31e1ddca 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -171,7 +171,8 @@ def __init__( # Detect attention implementation. self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( - head_size=self.hidden_size_per_attention_head) + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -461,7 +462,8 @@ def __init__( ), "vit's config.hidden must be equal to config.embed_dim" self.ln = nn.LayerNorm(hidden_size, eps=1e-6) - self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim) + self.attn_backend, _ = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 3f6c91951ff9..f04f918c389f 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -261,7 +261,8 @@ def __init__( # Detect attention implementation. self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( - head_size=self.hidden_size_per_attention_head) + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, @@ -719,7 +720,8 @@ def __init__( self.post_layernorm = RMSNorm(vision_config.hidden_size, eps=vision_config.rms_norm_eps) - self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim) + self.attn_backend, _ = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 089a32aeba57..b7095f8ca7aa 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -375,7 +375,7 @@ def __init__( # Detect attention implementation. self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( - head_size=self.head_dim) + head_size=self.head_dim, dtype=torch.get_default_dtype()) if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now.") diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 2ebc9826ad36..801103a032a4 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -299,7 +299,8 @@ def __init__( # Detect attention implementation. 
self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( - head_size=self.hidden_size_per_attention_head) + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -632,7 +633,8 @@ def __init__( prefix=f"{prefix}.merger", use_data_parallel=use_data_parallel, ) - self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim) + self.attn_backend, _ = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index fbb52721bb52..a9ac8c909f1f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -315,7 +315,8 @@ def __init__( # Detect attention implementation. self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( - head_size=self.hidden_size_per_attention_head) + head_size=self.hidden_size_per_attention_head, + dtype=torch.get_default_dtype()) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -632,7 +633,8 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.merger", ) - self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim) + self.attn_backend, _ = get_vit_attn_backend( + head_size=head_dim, dtype=torch.get_default_dtype()) @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 2c8b965d27a4..aa2aa2e2b8d3 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -237,7 +237,7 @@ def __init__( # Detect attention implementation. self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( - head_size=self.head_dim) + head_size=self.head_dim, dtype=torch.get_default_dtype()) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.ROCM_AITER_FA diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 447d3bc31519..82b14c65f262 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -68,7 +68,8 @@ def get_vision_encoder_info( raise NotImplementedError(msg) -def get_vit_attn_backend(head_size: int) -> tuple[_Backend, bool]: +def get_vit_attn_backend(head_size: int, + dtype: torch.dtype) -> tuple[_Backend, bool]: """ Get the available attention backend for Vision Transformer. 
@@ -79,7 +80,7 @@ def get_vit_attn_backend(head_size: int) -> tuple[_Backend, bool]: if selected_backend is not None: return selected_backend, False - return current_platform.get_vit_attn_backend(head_size) + return current_platform.get_vit_attn_backend(head_size, dtype) def resolve_visual_encoder_outputs( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a2bf5d0e5871..e006bbcc5605 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -209,7 +209,11 @@ def get_current_memory_usage(cls, return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]: + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> tuple[_Backend, bool]: + if dtype not in (torch.float16, torch.bfloat16): + return _Backend.XFORMERS, False + if cls.has_device_capability(80): if head_size % 32 == 0: # Use vllm-flash-attn diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 5c4cade7f69d..1d5074928b74 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -192,7 +192,8 @@ def device_id_to_physical_device_id(cls, device_id: int): return device_id @classmethod - def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]: + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> tuple[_Backend, bool]: return _Backend.TORCH_SDPA, False @classmethod diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 78080cce221c..6203ca749a65 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -175,7 +175,8 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]: + def get_vit_attn_backend(cls, head_size: int, + dtype: torch.dtype) -> tuple[_Backend, bool]: if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()): # Note: AITER FA is only supported for Qwen-VL models. 
From 22205bd0d8f854a3915cb566712a76638ae8c5df Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Thu, 11 Sep 2025 20:16:31 +0000 Subject: [PATCH 04/11] more robust check Signed-off-by: wwl2755 --- tests/kernels/attention/test_mha_attn.py | 25 ----------------- vllm/platforms/cuda.py | 35 +++++++++++++----------- 2 files changed, 19 insertions(+), 41 deletions(-) diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index cb5a72878a71..706ffb4630d3 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -36,35 +36,11 @@ def test_mha_attn_platform(device: str): torch.set_default_dtype(torch.float16) if device == "cpu": -<<<<<<< HEAD - with patch("vllm.attention.selector.current_platform", - CpuPlatform()), \ - patch("vllm.platforms.current_platform", CpuPlatform()): -======= with patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()): ->>>>>>> add dtype checking and fix tests attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1 elif device == "hip": -<<<<<<< HEAD - with patch("vllm.attention.selector.current_platform", - RocmPlatform()), \ - patch("vllm.platforms.current_platform", RocmPlatform()), \ - patch("vllm.attention.layer.current_platform", RocmPlatform()): - attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA - else: - with patch("vllm.attention.selector.current_platform", - CudaPlatform()), \ - patch("vllm.platforms.current_platform", CudaPlatform()): - attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.XFORMERS - - with patch("vllm.attention.selector.current_platform", - CudaPlatform()), \ - patch("vllm.platforms.current_platform", CudaPlatform()): -======= with patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()): attn = MultiHeadAttention(16, 64, scale=1) @@ -92,7 +68,6 @@ def test_mha_attn_platform(device: str): CudaPlatform()), \ patch("transformers.utils.is_flash_attn_2_available", return_value=False): ->>>>>>> add dtype checking and fix tests attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.XFORMERS diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e006bbcc5605..652be4f5861a 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -215,23 +215,26 @@ def get_vit_attn_backend(cls, head_size: int, return _Backend.XFORMERS, False if cls.has_device_capability(80): - if head_size % 32 == 0: - # Use vllm-flash-attn + FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 + from vllm.attention.selector import is_attn_backend_supported + is_default_fa_supported = is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, allow_import_error=False) + from transformers.utils import is_flash_attn_2_available + is_upstream_fa_supported = is_flash_attn_2_available() + if is_default_fa_supported: return _Backend.FLASH_ATTN, False - if head_size % 32 != 0: - from transformers.utils import is_flash_attn_2_available - if is_flash_attn_2_available(): - # Use upstream FA - return _Backend.FLASH_ATTN, True - else: - # Fallback to XFORMERS - logger.warning_once( - "Using xformers for ViT attention backend. 
" - "To use flash attention for ViT" - "please install flash_attn") - return _Backend.XFORMERS, False - # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS, False + elif is_upstream_fa_supported: + return _Backend.FLASH_ATTN, True + else: + # Fallback to XFORMERS + logger.warning_once( + "Using xformers for ViT attention backend. " + "To use flash attention for ViT" + "please install flash_attn") + return _Backend.XFORMERS, False + else: + # Fallback for Volta/Turing GPUs or FA not supported + return _Backend.XFORMERS, False @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, From 7a4121ca3aaeda7f80809e144b1bf07f8f7740b9 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Thu, 11 Sep 2025 22:38:45 +0000 Subject: [PATCH 05/11] fix comment(1) Signed-off-by: wwl2755 --- tests/kernels/attention/test_mha_attn.py | 15 ++------------- vllm/attention/layer.py | 24 ++++++++++++++++++++---- vllm/model_executor/models/vision.py | 7 +++---- vllm/platforms/cuda.py | 18 +++++------------- vllm/platforms/interface.py | 4 ++-- vllm/platforms/rocm.py | 8 ++++---- 6 files changed, 36 insertions(+), 40 deletions(-) diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 706ffb4630d3..ed7a7a764789 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -54,20 +54,9 @@ def test_mha_attn_platform(device: str): assert attn.attn_backend == _Backend.FLASH_ATTN # Test CUDA with head_size=72 (not divisible by 32) - # - upstream FA available + # - should use xformers with patch("vllm.model_executor.models.vision.current_platform", - CudaPlatform()), \ - patch("transformers.utils.is_flash_attn_2_available", - return_value=True): - attn = MultiHeadAttention(16, 72, scale=1) - assert attn.attn_backend == _Backend.FLASH_ATTN - - # Test CUDA with head_size=72 (not divisible by 32) - # - upstream FA not available - with patch("vllm.model_executor.models.vision.current_platform", - CudaPlatform()), \ - patch("transformers.utils.is_flash_attn_2_available", - return_value=False): + CudaPlatform()): attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.XFORMERS diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1d375edcf482..77a2a88287d3 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -56,6 +56,14 @@ def check_xformers_availability(): return USE_XFORMERS_OPS +def check_upstream_fa_availability(dtype: torch.dtype): + if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda( + ) and current_platform.has_device_capability(80): + from transformers.utils import is_flash_attn_2_available + return is_flash_attn_2_available() + return False + + class Attention(nn.Module, AttentionLayerBase): """Attention layer. @@ -355,14 +363,22 @@ def __init__( dtype = torch.get_default_dtype() # Determine the attention backend - backend, use_upstream_fa = get_vit_attn_backend(head_size=head_size, - dtype=dtype) + backend = get_vit_attn_backend(head_size=head_size, dtype=dtype) + + # Some auto-selected backends can be upgraded + # to upstream flash attention if available. + # If vllm native fa is selected, we use it directly. 
+ use_upstream_fa = False + if check_upstream_fa_availability( + dtype) and backend != _Backend.FLASH_ATTN: + backend = _Backend.FLASH_ATTN + use_upstream_fa = True if current_platform.is_rocm(): # currently, only torch_sdpa is supported on rocm self.attn_backend = _Backend.TORCH_SDPA else: - + self.attn_backend = backend if backend in { _Backend.TORCH_SDPA, _Backend.TORCH_SDPA_VLLM_V1, @@ -371,7 +387,7 @@ def __init__( _Backend.ROCM_AITER_FA, _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1, - } else current_platform.get_vit_attn_backend() + } else _Backend.TORCH_SDPA if (self.attn_backend == _Backend.XFORMERS and not check_xformers_availability()): diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 82b14c65f262..1a2138a3df19 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -68,8 +68,7 @@ def get_vision_encoder_info( raise NotImplementedError(msg) -def get_vit_attn_backend(head_size: int, - dtype: torch.dtype) -> tuple[_Backend, bool]: +def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: """ Get the available attention backend for Vision Transformer. @@ -78,7 +77,7 @@ def get_vit_attn_backend(head_size: int, """ selected_backend: Optional[_Backend] = get_env_variable_attn_backend() if selected_backend is not None: - return selected_backend, False + return selected_backend return current_platform.get_vit_attn_backend(head_size, dtype) @@ -124,4 +123,4 @@ def resolve_visual_encoder_outputs( uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1) if post_layer_norm is not None and uses_last_layer: hs_pool[-1] = post_layer_norm(encoder_outputs) - return torch.cat(hs_pool, dim=-1) \ No newline at end of file + return torch.cat(hs_pool, dim=-1) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 652be4f5861a..77c9a012b2d3 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -210,31 +210,23 @@ def get_current_memory_usage(cls, @classmethod def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> tuple[_Backend, bool]: + dtype: torch.dtype) -> _Backend: if dtype not in (torch.float16, torch.bfloat16): - return _Backend.XFORMERS, False + return _Backend.XFORMERS if cls.has_device_capability(80): FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501 from vllm.attention.selector import is_attn_backend_supported is_default_fa_supported = is_attn_backend_supported( FLASH_ATTN_V1, head_size, dtype, allow_import_error=False) - from transformers.utils import is_flash_attn_2_available - is_upstream_fa_supported = is_flash_attn_2_available() if is_default_fa_supported: - return _Backend.FLASH_ATTN, False - elif is_upstream_fa_supported: - return _Backend.FLASH_ATTN, True + return _Backend.FLASH_ATTN else: # Fallback to XFORMERS - logger.warning_once( - "Using xformers for ViT attention backend. 
" - "To use flash attention for ViT" - "please install flash_attn") - return _Backend.XFORMERS, False + return _Backend.XFORMERS else: # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS, False + return _Backend.XFORMERS @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 1d5074928b74..054d08c3a85b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -193,8 +193,8 @@ def device_id_to_physical_device_id(cls, device_id: int): @classmethod def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> tuple[_Backend, bool]: - return _Backend.TORCH_SDPA, False + dtype: torch.dtype) -> _Backend: + return _Backend.TORCH_SDPA @classmethod def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 6203ca749a65..bb8bff48c7b9 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -176,15 +176,15 @@ class RocmPlatform(Platform): @classmethod def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> tuple[_Backend, bool]: + dtype: torch.dtype) -> _Backend: if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()): # Note: AITER FA is only supported for Qwen-VL models. # TODO: Add support for other VL models in their model class. - return _Backend.ROCM_AITER_FA, False + return _Backend.ROCM_AITER_FA if on_gfx9(): - return _Backend.FLASH_ATTN, False - return _Backend.TORCH_SDPA, False + return _Backend.FLASH_ATTN + return _Backend.TORCH_SDPA @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, From b32f431a02207388da8fbf6d94ae43ec4d680d81 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Thu, 11 Sep 2025 23:37:46 +0000 Subject: [PATCH 06/11] fix test Signed-off-by: wwl2755 --- tests/kernels/attention/test_mha_attn.py | 32 ++++++++++++++++++------ 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index ed7a7a764789..04160ea592d4 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -36,30 +36,48 @@ def test_mha_attn_platform(device: str): torch.set_default_dtype(torch.float16) if device == "cpu": - with patch("vllm.model_executor.models.vision.current_platform", + with patch("vllm.attention.layer.current_platform", CpuPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()): attn = MultiHeadAttention(16, 64, scale=1) - assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1 + assert attn.attn_backend == _Backend.TORCH_SDPA elif device == "hip": - with patch("vllm.model_executor.models.vision.current_platform", + with patch("vllm.attention.layer.current_platform", RocmPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.TORCH_SDPA else: # Test CUDA with head_size=64 (divisible by 32) - # - should use vLLM FlashAttention - with patch("vllm.model_executor.models.vision.current_platform", + # - should upstream FlashAttention + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()): attn = MultiHeadAttention(16, 64, scale=1) assert attn.attn_backend == _Backend.FLASH_ATTN # Test CUDA with head_size=72 (not 
divisible by 32) + # - with upstream FA not available # - should use xformers - with patch("vllm.model_executor.models.vision.current_platform", - CudaPlatform()): + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ + patch("vllm.attention.layer.check_upstream_fa_availability", + return_value=False): attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.XFORMERS + # Test CUDA with head_size=72 (not divisible by 32) + # - with upstream FA available + # - should use upstream FA + with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ + patch("vllm.model_executor.models.vision.current_platform", + CudaPlatform()), \ + patch("vllm.attention.layer.check_upstream_fa_availability", + return_value=True): + attn = MultiHeadAttention(16, 72, scale=1) + assert attn.attn_backend == _Backend.FLASH_ATTN + def ref_attention( query: torch.Tensor, From e93c7f60413f8923f6107543f55bebce68850458 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Thu, 11 Sep 2025 23:46:02 +0000 Subject: [PATCH 07/11] fix comment (2) Signed-off-by: wwl2755 --- vllm/attention/layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 77a2a88287d3..1e94ad0782a4 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -451,7 +451,6 @@ def forward( max_seqlen_k=kv_len, softmax_scale=self.scale, ) - out = out.reshape(bsz, q_len, -1) elif self.attn_backend == _Backend.XFORMERS: from xformers import ops as xops From 762cf3d6b1e98ab80e1b6bbdd64fb1fc4fd9c4f3 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Fri, 12 Sep 2025 00:48:06 +0000 Subject: [PATCH 08/11] fix interface Signed-off-by: wwl2755 --- vllm/attention/layer.py | 4 ++-- vllm/model_executor/models/ernie45_vl.py | 15 +++++++++++++-- vllm/model_executor/models/glm4_1v.py | 14 ++++++++++++-- vllm/model_executor/models/keye.py | 11 ++++++++++- vllm/model_executor/models/qwen2_5_vl.py | 16 ++++++++++++++-- vllm/model_executor/models/qwen2_vl.py | 16 ++++++++++++++-- vllm/model_executor/models/siglip2navit.py | 10 +++++++++- 7 files changed, 74 insertions(+), 12 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1e94ad0782a4..44cb2c7c6b64 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -369,8 +369,8 @@ def __init__( # to upstream flash attention if available. # If vllm native fa is selected, we use it directly. use_upstream_fa = False - if check_upstream_fa_availability( - dtype) and backend != _Backend.FLASH_ATTN: + if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + dtype): backend = _Backend.FLASH_ATTN use_upstream_fa = True diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index e9cf31e1ddca..bcff65a717ab 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -34,6 +34,7 @@ from einops import rearrange, repeat from transformers import BatchFeature +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -170,9 +171,16 @@ def __init__( prefix=f"{prefix}.proj") # Detect attention implementation. 
- self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype()) + + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -462,8 +470,11 @@ def __init__( ), "vit's config.hidden must be equal to config.embed_dim" self.ln = nn.LayerNorm(hidden_size, eps=1e-6) - self.attn_backend, _ = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index f04f918c389f..279f458dfa6c 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -44,6 +44,7 @@ Glm4vVideoProcessor) from transformers.video_utils import VideoMetadata +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import (get_tensor_model_parallel_world_size, parallel_state) @@ -260,9 +261,15 @@ def __init__( ) # Detect attention implementation. - self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, @@ -720,8 +727,11 @@ def __init__( self.post_layernorm = RMSNorm(vision_config.hidden_size, eps=vision_config.rms_norm_eps) - self.attn_backend, _ = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability(torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index b7095f8ca7aa..04824db1b6dd 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -17,6 +17,7 @@ BaseModelOutputWithPooling) from transformers.utils import torch_int +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -374,8 +375,16 @@ def __init__( ) # Detect attention implementation. 
- self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=self.head_dim, dtype=torch.get_default_dtype()) + + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now.") diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 801103a032a4..98f9c0cf4c16 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -298,9 +299,16 @@ def __init__( disable_tp=use_data_parallel) # Detect attention implementation. - self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -633,8 +641,12 @@ def __init__( prefix=f"{prefix}.merger", use_data_parallel=use_data_parallel, ) - self.attn_backend, _ = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a9ac8c909f1f..89af79c3b5fd 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( Qwen2VLVideoProcessor) +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -314,9 +315,16 @@ def __init__( prefix=f"{prefix}.proj") # Detect attention implementation. 
- self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -633,8 +641,12 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.merger", ) - self.attn_backend, _ = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype()) + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index aa2aa2e2b8d3..a86700fe68dd 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -13,6 +13,7 @@ from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig +from vllm.attention.layer import check_upstream_fa_availability from vllm.config import QuantizationConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -236,8 +237,15 @@ def __init__( self.use_rope = config.use_rope # Detect attention implementation. - self.attn_backend, self.use_upstream_fa = get_vit_attn_backend( + self.attn_backend = get_vit_attn_backend( head_size=self.head_dim, dtype=torch.get_default_dtype()) + self.use_upstream_fa = False + if self.attn_backend != _Backend.FLASH_ATTN and \ + check_upstream_fa_availability( + torch.get_default_dtype()): + self.attn_backend = _Backend.FLASH_ATTN + self.use_upstream_fa = True + if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.ROCM_AITER_FA From 27fb88fe0f1c1d99b4ccfbe3c084116267ceac04 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Fri, 12 Sep 2025 04:19:41 +0000 Subject: [PATCH 09/11] fix CI failures Signed-off-by: wwl2755 --- tests/entrypoints/openai/test_vision.py | 2 +- tests/kernels/attention/test_mha_attn.py | 8 ++++++-- vllm/model_executor/models/vision.py | 4 +++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 29a3b40d2d86..3e4a689151a0 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -34,7 +34,7 @@ ], [ "The image shows a Venn diagram with three over", - "The image shows a Venn diagram with three intersect", + "This image shows a Venn diagram with three over", ], [ "This image displays a gradient of colors ranging from", diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index 04160ea592d4..d37b968ed979 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -49,7 +49,7 @@ def test_mha_attn_platform(device: str): assert attn.attn_backend == _Backend.TORCH_SDPA else: # Test CUDA with head_size=64 (divisible by 32) - # - should upstream FlashAttention + # - should use vLLM's FlashAttention with patch("vllm.attention.layer.current_platform", CudaPlatform()), \ 
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()): @@ -74,7 +74,11 @@ def test_mha_attn_platform(device: str): patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()), \ patch("vllm.attention.layer.check_upstream_fa_availability", - return_value=True): + return_value=True), \ + patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (), + { + 'flash_attn_varlen_func': lambda *args, **kwargs: None + })()}): attn = MultiHeadAttention(16, 72, scale=1) assert attn.attn_backend == _Backend.FLASH_ATTN diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 1a2138a3df19..d6a450ac1413 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -7,7 +7,6 @@ import torch from transformers import PretrainedConfig -from vllm.attention.selector import get_env_variable_attn_backend from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform @@ -75,6 +74,9 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: Returns: Tuple of (backend, use_upstream_fa) """ + # Lazy import to avoid circular dependency + from vllm.attention.selector import get_env_variable_attn_backend + selected_backend: Optional[_Backend] = get_env_variable_attn_backend() if selected_backend is not None: return selected_backend From 6cb38d184f3124bba534488cb18a4ecc564a9fc5 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Fri, 12 Sep 2025 04:23:46 +0000 Subject: [PATCH 10/11] fix wrong comment Signed-off-by: wwl2755 --- vllm/model_executor/models/vision.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index d6a450ac1413..81f86db7e187 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -70,9 +70,6 @@ def get_vision_encoder_info( def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend: """ Get the available attention backend for Vision Transformer. - - Returns: - Tuple of (backend, use_upstream_fa) """ # Lazy import to avoid circular dependency from vllm.attention.selector import get_env_variable_attn_backend From db8acbe0473142cac3398848380e52b376530a70 Mon Sep 17 00:00:00 2001 From: wwl2755 Date: Fri, 12 Sep 2025 06:36:40 +0000 Subject: [PATCH 11/11] #suppress-bc-linter Signed-off-by: wwl2755 --- tests/entrypoints/openai/test_vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 3e4a689151a0..72819f31de20 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -38,7 +38,7 @@ ], [ "This image displays a gradient of colors ranging from", - "The image displays a gradient of colors ranging from", + "This image displays a gradient of colors forming a spectrum", ], ]