diff --git a/tests/config/test_multimodal_config.py b/tests/config/test_multimodal_config.py
new file mode 100644
index 000000000000..b1a09d88ed9d
--- /dev/null
+++ b/tests/config/test_multimodal_config.py
@@ -0,0 +1,25 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.attention.backends.registry import _Backend
+from vllm.config.multimodal import MultiModalConfig
+
+
+def test_mm_encoder_attn_backend_str_conversion():
+    config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
+    assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
+
+
+def test_mm_encoder_attn_backend_invalid():
+    with pytest.raises(ValueError):
+        MultiModalConfig(mm_encoder_attn_backend="not_a_backend")
+
+
+def test_mm_encoder_attn_backend_hash_updates():
+    base_hash = MultiModalConfig().compute_hash()
+    overridden_hash = MultiModalConfig(
+        mm_encoder_attn_backend=_Backend.FLASH_ATTN
+    ).compute_hash()
+    assert base_hash != overridden_hash
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a028be6ce7f8..a3444c1ac82c 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -16,6 +16,7 @@
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.multimodal import MultiModalConfig
 from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
@@ -443,6 +444,7 @@ def __init__(
         # This has no effect, it is only here to make it easier to swap
         # between Attention and MultiHeadAttention
         prefix: str = "",
+        multimodal_config: MultiModalConfig | None = None,
     ) -> None:
         super().__init__()
         self.num_heads = num_heads
@@ -462,7 +464,14 @@ def __init__(
         dtype = torch.get_default_dtype()

         # Determine the attention backend
-        backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
+        attn_backend_override = None
+        if multimodal_config is not None:
+            attn_backend_override = multimodal_config.mm_encoder_attn_backend
+        backend = get_vit_attn_backend(
+            head_size=head_size,
+            dtype=dtype,
+            attn_backend_override=attn_backend_override,
+        )

         # Some auto-selected backends can be upgraded
         # to upstream flash attention if available.
diff --git a/vllm/config/model.py b/vllm/config/model.py
index c99451aa2a1b..7bf8b4bfc15a 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -50,6 +50,7 @@

     import vllm.model_executor.layers.quantization as me_quant
     import vllm.model_executor.models as me_models
+    from vllm.attention.backends.registry import _Backend
     from vllm.config.load import LoadConfig
     from vllm.config.parallel import ParallelConfig
     from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -57,6 +58,7 @@
 else:
     PretrainedConfig = Any
+    _Backend = Any

     me_quant = LazyLoader(
         "model_executor", globals(), "vllm.model_executor.layers.quantization"
     )
@@ -307,6 +309,7 @@ class ModelConfig:
     mm_processor_cache_type: InitVar[MMCacheType | None] = None
     mm_shm_cache_max_object_size_mb: InitVar[int | None] = None
     mm_encoder_tp_mode: InitVar[MMEncoderTPMode | None] = None
+    mm_encoder_attn_backend: InitVar[_Backend | str | None] = None
     interleave_mm_strings: InitVar[bool | None] = None
     skip_mm_profiling: InitVar[bool | None] = None
     video_pruning_rate: InitVar[float | None] = None
@@ -424,6 +427,7 @@ def __post_init__(
         mm_processor_cache_type: MMCacheType | None,
         mm_shm_cache_max_object_size_mb: int | None,
         mm_encoder_tp_mode: MMEncoderTPMode | None,
+        mm_encoder_attn_backend: _Backend | str | None,
         interleave_mm_strings: bool | None,
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,
@@ -733,6 +737,7 @@ def _task_to_convert(task: TaskOption) -> ConvertType:
             mm_processor_cache_type=mm_processor_cache_type,
             mm_shm_cache_max_object_size_mb=mm_shm_cache_max_object_size_mb,
             mm_encoder_tp_mode=mm_encoder_tp_mode,
+            mm_encoder_attn_backend=mm_encoder_attn_backend,
             interleave_mm_strings=interleave_mm_strings,
             skip_mm_profiling=skip_mm_profiling,
             video_pruning_rate=video_pruning_rate,
diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py
index 6c3e2b9b867f..e80d072dab45 100644
--- a/vllm/config/multimodal.py
+++ b/vllm/config/multimodal.py
@@ -3,13 +3,18 @@

 import hashlib
 from collections.abc import Mapping
-from typing import Any, Literal, TypeAlias
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias

 from pydantic import ConfigDict, Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass

 from vllm.config.utils import config

+if TYPE_CHECKING:
+    from vllm.attention.backends.registry import _Backend
+else:
+    _Backend = Any
+

 @dataclass
 class BaseDummyOptions:
@@ -112,6 +117,10 @@ class MultiModalConfig:
     DP (which is controlled by `--data-parallel-size`).
     This is only supported on a per-model basis and falls back to
     `"weights"` if the encoder does not support DP."""
+    mm_encoder_attn_backend: _Backend | None = None
+    """Optional override for the multi-modal encoder attention backend when
+    using vision transformers. Accepts any value from
+    `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`)."""
     interleave_mm_strings: bool = False
     """Enable fully interleaved support for multimodal prompts, while using
     --chat-template-content-format=string."""
@@ -148,6 +157,29 @@ def _validate_limit_per_prompt(
                 value[k] = BaseDummyOptions(**v)
         return value

+    @field_validator("mm_encoder_attn_backend", mode="before")
+    @classmethod
+    def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None:
+        from vllm.attention.backends.registry import (
+            _Backend as BackendEnum,
+        )
+        from vllm.attention.backends.registry import (
+            backend_name_to_enum,
+        )
+
+        if value is None or isinstance(value, BackendEnum):
+            return value
+
+        if isinstance(value, str):
+            candidate = backend_name_to_enum(value.upper())
+            if candidate is not None:
+                return candidate
+
+        valid_backends = ", ".join(sorted(BackendEnum.__members__.keys()))
+        raise ValueError(
+            f"Invalid mm encoder attention backend. Expected one of: {valid_backends}."
+        )
+
     @model_validator(mode="after")
     def _validate_multimodal_config(self):
         if self.mm_processor_cache_type != "shm" and (
@@ -172,9 +204,11 @@ def compute_hash(self) -> str:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # this config will not affect the computation graph.
-        factors: list[Any] = []
+        factors: list[Any] = [
+            self.mm_encoder_attn_backend.name
+            if self.mm_encoder_attn_backend is not None
+            else None
+        ]

         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 70791375be39..b50fbe130b1f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -32,6 +32,7 @@
 from typing_extensions import TypeIs, deprecated

 import vllm.envs as envs
+from vllm.attention.backends.registry import _Backend
 from vllm.config import (
     CacheConfig,
     CompilationConfig,
@@ -451,6 +452,9 @@ class EngineArgs:
     mm_shm_cache_max_object_size_mb: int = (
         MultiModalConfig.mm_shm_cache_max_object_size_mb
     )
     mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
+    mm_encoder_attn_backend: _Backend | str | None = (
+        MultiModalConfig.mm_encoder_attn_backend
+    )
     io_processor_plugin: str | None = None
     skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
     video_pruning_rate: float = MultiModalConfig.video_pruning_rate
@@ -914,6 +918,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         multimodal_group.add_argument(
             "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]
         )
+        multimodal_group.add_argument(
+            "--mm-encoder-attn-backend",
+            **multimodal_kwargs["mm_encoder_attn_backend"],
+        )
         multimodal_group.add_argument(
             "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"]
         )
@@ -1160,6 +1168,7 @@ def create_model_config(self) -> ModelConfig:
             mm_processor_cache_type=self.mm_processor_cache_type,
             mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb,
             mm_encoder_tp_mode=self.mm_encoder_tp_mode,
+            mm_encoder_attn_backend=self.mm_encoder_attn_backend,
             pooler_config=self.pooler_config,
             override_pooler_config=self.override_pooler_config,
             logits_processor_pattern=self.logits_processor_pattern,
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index bd7f37b07de3..4557ef71e3c2 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -256,6 +256,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -288,7 +289,9 @@ def __init__(
         )
         # Select attention backend
         self.attn_backend = get_vit_attn_backend(
-            self.hidden_size_per_attention_head, torch.get_default_dtype()
+            self.hidden_size_per_attention_head,
+            torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -510,6 +513,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -521,6 +525,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
         self.mlp = DotsSwiGLUFFN(
@@ -561,6 +566,7 @@ def __init__(
         require_post_norm: bool | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -571,7 +577,9 @@ def __init__(
         head_dim = config.embed_dim // config.num_attention_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -591,6 +599,7 @@ def __init__(
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{i}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for i in range(num_layers)
             ]
@@ -750,11 +759,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.config.vision_config = vision_config
         else:
             vision_config = self.config.vision_config
+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_tower = DotsVisionTransformer(
             vision_config,
             quant_config=self.quant_config,
             prefix=maybe_prefix(prefix, "vision_tower"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.language_model: Qwen2ForCausalLM = init_vllm_registered_model(
             vllm_config=vllm_config,
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index e5badc0a28f6..e6ac0e6a0b99 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -164,6 +164,7 @@ def __init__(
         projection_size: int,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -196,6 +197,7 @@ def __init__(
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -367,6 +369,7 @@ def __init__(
         norm_layer: Callable[[int], nn.Module] | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -382,6 +385,7 @@ def __init__(
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
+            attn_backend_override=attn_backend_override,
         )

         self.mlp = Ernie4_5_VisionMLP(
@@ -458,6 +462,7 @@ def __init__(
         norm_eps: float = 1e-6,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         patch_size = vision_config.patch_size
@@ -493,6 +498,7 @@ def __init__(
                     norm_layer=norm_layer,
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
@@ -504,7 +510,9 @@ def __init__(
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1327,11 +1335,17 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.config = config
         self.multimodal_config = multimodal_config

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.vision_model = Ernie4_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_model"),
+            attn_backend_override=attn_backend_override,
         )

         self.language_model = Ernie4_5_VLMoeForCausalLM(
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 132f26253b36..38512f22ba8a 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -247,6 +247,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -287,6 +288,7 @@ def __init__(
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -417,6 +419,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -430,6 +433,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
@@ -696,6 +700,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -731,6 +736,7 @@ def __init__(
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=self.use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
@@ -759,7 +765,9 @@ def __init__(
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1437,12 +1445,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )

         if config.model_type == "glm4v":
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 292a07c00d07..acfd51a6d0cc 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -353,6 +353,7 @@ def __init__(
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -392,7 +393,9 @@
         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
-            head_size=self.head_dim, dtype=torch.get_default_dtype()
+            head_size=self.head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -521,6 +524,7 @@ def __init__(
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -529,6 +533,7 @@ def __init__(
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            attn_backend_override=attn_backend_override,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = SiglipMLP(
@@ -573,6 +578,7 @@ def __init__(
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -585,6 +591,7 @@ def __init__(
                     config,
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{layer_idx}",
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(config.num_hidden_layers)
             ]
@@ -666,6 +673,7 @@ def __init__(
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -676,6 +684,7 @@ def __init__(
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
+            attn_backend_override=attn_backend_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -747,6 +756,7 @@ def __init__(
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -754,6 +764,7 @@ def __init__(
             config,
             quant_config=quant_config,
             prefix=f"{prefix}.vision_model",
+            attn_backend_override=attn_backend_override,
         )

         self.quant_config = quant_config
@@ -1296,10 +1307,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.multimodal_config = multimodal_config

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = KeyeSiglipVisionModel(
             config.vision_config,
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            attn_backend_override=attn_backend_override,
         )

         self.mlp_AR = self._build_projector(
diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py
index 758611afb9a4..f6461ae9a412 100644
--- a/vllm/model_executor/models/ovis2_5.py
+++ b/vllm/model_executor/models/ovis2_5.py
@@ -10,6 +10,7 @@
 import torch.nn as nn
 from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

+from vllm.attention.backends.registry import _Backend
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.linear import ReplicatedLinear
@@ -105,6 +106,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -113,6 +115,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.vit",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         # reserved tokens for INDICATOR_IDS
         head_dim = visual_vocab_size - len(INDICATOR_IDS)
@@ -132,6 +135,7 @@ def _init_backbone(
         self,
         config: PretrainedConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         model_type = config.model_type
         if model_type == "siglip2_navit":
@@ -140,6 +144,7 @@ def _init_backbone(
                 quant_config=quant_config,
                 prefix=prefix,
                 use_data_parallel=use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         raise ValueError(f"Unsupported visual tokenizer model_type: {model_type}")
@@ -457,6 +462,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
         self.config: PretrainedConfig = config

         self.llm = init_vllm_registered_model(
@@ -464,11 +470,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             prefix=maybe_prefix(prefix, "llm"),
         )

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual_tokenizer = VisualTokenizer(
             config=config.vit_config,
             visual_vocab_size=config.visual_vocab_size,
             quant_config=quant_config,
             prefix=f"{prefix}.visual_tokenizer",
+            attn_backend_override=attn_backend_override,
         )

         self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index b0dcb55898b7..e49387648ae3 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -637,6 +637,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -669,7 +670,9 @@ def __init__(
         use_upstream_fa = False

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if (
             self.attn_backend != _Backend.FLASH_ATTN
@@ -1226,12 +1229,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if multimodal_config.get_limit_per_prompt(
             "image"
         ) or multimodal_config.get_limit_per_prompt("video"):
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen2_5_VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=self.quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         else:
             self.visual = None
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 2f74d4489cee..827b7f4aa26f 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -320,6 +320,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -355,6 +356,7 @@ def __init__(
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -497,6 +499,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -512,6 +515,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Qwen2VisionMLP(
             dim,
@@ -662,6 +666,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()

@@ -703,6 +708,7 @@ def __init__(
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
@@ -716,7 +722,9 @@ def __init__(
             use_data_parallel=use_data_parallel,
         )
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1356,12 +1364,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if multimodal_config.get_limit_per_prompt(
             "image"
         ) or multimodal_config.get_limit_per_prompt("video"):
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen2VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )
         else:
             self.visual = None
diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
index 08bccee9e2d1..1176c559bffe 100755
--- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py
+++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -296,6 +296,7 @@ def __init__(
         norm_eps: float = 1e-6,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size
@@ -367,7 +368,9 @@ def __init__(
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1144,11 +1147,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)

+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = Qwen3Omni_VisionTransformer(
             vision_config=thinker_config.vision_config,
             norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
+            attn_backend_override=attn_backend_override,
         )

         self.quant_config = quant_config
diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py
index 6955fc80af6e..0ece93791954 100644
--- a/vllm/model_executor/models/qwen3_vl.py
+++ b/vllm/model_executor/models/qwen3_vl.py
@@ -300,6 +300,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = vision_config.hidden_size
@@ -359,7 +360,9 @@ def __init__(
         )

         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         use_upstream_fa = False
         if (
@@ -379,7 +382,6 @@ def __init__(
             raise RuntimeError(
                 f"Qwen3-VL does not support {self.attn_backend} backend now."
             )
-
         self.blocks = nn.ModuleList(
             [
                 Qwen3_VisionBlock(
@@ -1214,12 +1216,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
         ) and not multimodal_config.get_limit_per_prompt("video"):
             self.visual = None
         else:
+            attn_backend_override = (
+                multimodal_config.mm_encoder_attn_backend
+                if multimodal_config is not None
+                else None
+            )
             self.visual = Qwen3_VisionTransformer(
                 config.vision_config,
                 norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                 quant_config=quant_config,
                 prefix=maybe_prefix(prefix, "visual"),
                 use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
             )

         self.language_model = Qwen3LLMForCausalLM(
diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py
index e7af0e7a7ae4..0e8dbcd61522 100644
--- a/vllm/model_executor/models/siglip2navit.py
+++ b/vllm/model_executor/models/siglip2navit.py
@@ -208,6 +208,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -248,7 +249,9 @@
         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
-            head_size=self.head_dim, dtype=torch.get_default_dtype()
+            head_size=self.head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False

@@ -372,6 +375,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -381,6 +385,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = Siglip2MLP(
@@ -434,6 +439,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -444,6 +450,7 @@ def __init__(
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{idx}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for idx in range(config.num_hidden_layers)
             ]
@@ -618,6 +625,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -629,6 +637,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

@@ -657,6 +666,7 @@ def __init__(
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()

@@ -665,6 +675,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.vision_model",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )

     def forward(
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index bd5a6cf018d2..8bbb06f72772 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -78,10 +78,18 @@ def get_vision_encoder_info(hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
     raise NotImplementedError(msg)


-def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
+def get_vit_attn_backend(
+    head_size: int,
+    dtype: torch.dtype,
+    *,
+    attn_backend_override: _Backend | None = None,
+) -> _Backend:
     """
     Get the available attention backend for Vision Transformer.
     """
+    if attn_backend_override is not None:
+        return attn_backend_override
+
     # Lazy import to avoid circular dependency
     from vllm.attention.selector import get_env_variable_attn_backend
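
Illustrative usage sketch (not part of the diff): the Python lines mirror the new tests/config/test_multimodal_config.py, and the CLI line uses the --mm-encoder-attn-backend flag registered in vllm/engine/arg_utils.py; the served model name is a placeholder.

    # Programmatic override of the ViT encoder attention backend.
    from vllm.attention.backends.registry import _Backend
    from vllm.config.multimodal import MultiModalConfig

    # String values are upper-cased and validated against _Backend by the
    # new field validator; unknown names raise ValueError.
    config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
    assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN

    # Equivalent CLI usage:
    #   vllm serve <model> --mm-encoder-attn-backend FLASH_ATTN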