diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a3444c1ac82c..bd38f3679ece 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -93,12 +93,15 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, use_upstream_fa: bool + attn_backend: _Backend, + use_upstream_fa: bool, + attn_backend_override: _Backend | None = None, ) -> tuple[_Backend, Callable]: if ( attn_backend != _Backend.FLASH_ATTN and attn_backend != _Backend.ROCM_AITER_FA and check_upstream_fa_availability(torch.get_default_dtype()) + and attn_backend_override is None ): attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True @@ -499,6 +502,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 4557ef71e3c2..6d462ad8ae62 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -299,6 +299,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) if self.attn_backend not in { diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 372675178ccc..86536b21c33f 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -206,6 +206,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 38512f22ba8a..bed7c81335e0 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -296,6 +296,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 827b7f4aa26f..94436fe009f1 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -364,6 +364,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) ) diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 0e8dbcd61522..bab5c1d82ded 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -259,6 +259,7 @@ def __init__( maybe_get_vit_flash_attn_backend( self.attn_backend, self.use_upstream_fa, + attn_backend_override=attn_backend_override, ) )