Commit cf6b734

bradleyhd authored and facebook-github-bot committed
make flash_attn ViT upgrade opt-in
Summary: In #26104, changes to layer.py caused the ViT attention path to always try to switch to the FlashAttention (FA) backend, even when `VLLM_ATTENTION_BACKEND` is set. This broke Meta's internal AMD pipelines, as it is neither desired nor expected behavior. With this change, the models touched by the offending PR explicitly opt in to this behavior.

Differential Revision: D84946967
1 parent acedc74 commit cf6b734
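
For orientation, here is a minimal standalone sketch of the opt-in gate this commit introduces. The enum and helper names below are simplified stand-ins, not the actual vLLM symbols; the real logic is in the vllm/attention/layer.py diff further down.

from enum import Enum, auto


class Backend(Enum):
    # Simplified stand-in for vLLM's backend enum.
    TORCH_SDPA = auto()
    FLASH_ATTN = auto()
    ROCM_AITER_FA = auto()


FA_BACKENDS = {Backend.FLASH_ATTN, Backend.ROCM_AITER_FA}


def resolve_vit_backend(backend: Backend, try_switch_to_fa: bool = False) -> Backend:
    # Only upgrade to FlashAttention when the caller explicitly opts in;
    # otherwise the previously selected backend is left untouched.
    if try_switch_to_fa and backend not in FA_BACKENDS:
        backend = Backend.FLASH_ATTN
    return backend


# With the fix, a backend pinned (e.g. via VLLM_ATTENTION_BACKEND) stays pinned:
assert resolve_vit_backend(Backend.TORCH_SDPA) is Backend.TORCH_SDPA
# Models that opt in still get the upgrade:
assert resolve_vit_backend(Backend.TORCH_SDPA, try_switch_to_fa=True) is Backend.FLASH_ATTN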

File tree

7 files changed (+38 −43 lines)


vllm/attention/layer.py

Lines changed: 26 additions & 37 deletions
@@ -65,7 +65,7 @@ def check_xformers_availability():
     return USE_XFORMERS_OPS
 
 
-def check_upstream_fa_availability(dtype: torch.dtype):
+def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
     if (
         dtype in (torch.float16, torch.bfloat16)
         and current_platform.is_cuda()
@@ -80,26 +80,31 @@ def check_upstream_fa_availability(dtype: torch.dtype):
         return find_spec("flash_attn") is not None
     return False
 
+def is_fa_backend(backend: _Backend) -> bool:
+    return backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}
 
 def maybe_get_vit_flash_attn_backend(
-    attn_backend: _Backend, use_upstream_fa: bool
-) -> tuple[_Backend, Callable]:
-    if (
-        attn_backend != _Backend.FLASH_ATTN
-        and attn_backend != _Backend.ROCM_AITER_FA
-        and check_upstream_fa_availability(torch.get_default_dtype())
-    ):
-        attn_backend = _Backend.FLASH_ATTN
-        use_upstream_fa = True
+    attn_backend: _Backend,
+    try_switch_to_fa: bool = False,
+    try_use_upstream_fa: bool = False) -> tuple[_Backend, Callable]:
 
-    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
-        use_upstream_fa = True
-
-    if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
+    if try_switch_to_fa and not is_fa_backend(attn_backend):
+        attn_backend = _Backend.FLASH_ATTN
+
+    if current_platform.is_rocm() and \
+        attn_backend == _Backend.FLASH_ATTN:
+        # Always try upstream on ROCM.
+        logger.info_once("maybe_get_vit_flash_attn_backend: forcing upstream FlashAttn on ROCM.")
+        try_use_upstream_fa = True
+
+    if is_fa_backend(attn_backend):
         if attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            if use_upstream_fa:
+            if try_use_upstream_fa:
+                assert check_upstream_fa_availability( \
+                    torch.get_default_dtype()), \
+                    "Upstream FlashAttn is not available."
                 from flash_attn import flash_attn_varlen_func
             else:
                 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -108,7 +113,6 @@ def maybe_get_vit_flash_attn_backend(
 
     return attn_backend, flash_attn_varlen_func
 
-
 class Attention(nn.Module, AttentionLayerBase):
     """Attention layer.
 
@@ -428,11 +432,6 @@ def __init__(
         # Determine the attention backend
         backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)
 
-        # Some auto-selected backends can be upgraded
-        # to upstream flash attention if available.
-        # If vllm native fa is selected, we use it directly.
-        use_upstream_fa = False
-
         if current_platform.is_xpu():
             # currently, only torch_sdpa is supported on xpu
             self.attn_backend = _Backend.TORCH_SDPA
@@ -450,30 +449,20 @@ def __init__(
             else _Backend.TORCH_SDPA
         )
 
-        self.attn_backend, self._flash_attn_varlen_func = (
-            maybe_get_vit_flash_attn_backend(
+        self.attn_backend, self._flash_attn_varlen_func \
+            = maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                use_upstream_fa,
+                try_switch_to_fa=False,
+                try_use_upstream_fa=False,
             )
-        )
 
         if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability():
            self.attn_backend = _Backend.TORCH_SDPA
 
-        self.is_flash_attn_backend = self.attn_backend in {
-            _Backend.FLASH_ATTN,
-            _Backend.ROCM_AITER_FA,
-        }
-
-        # this condition is just to make sure that the
-        # use_upstream_fa in the log is correct
-        if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
-            use_upstream_fa = True
+        self.is_flash_attn_backend = is_fa_backend(self.attn_backend)
 
         logger.info_once(
-            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
-            f"use_upstream_fa: {use_upstream_fa}"
-        )
+            f"MultiHeadAttention attn_backend: {self.attn_backend}")
 
     def forward(
         self,
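
The model diffs below all apply the same one-line change at their call sites. As an illustrative comparison of the two call patterns after this change (the helper is stubbed here so the snippet runs standalone; the real function lives in vllm/attention/layer.py):

from typing import Any, Callable, Optional


def maybe_get_vit_flash_attn_backend(attn_backend: Any,
                                     try_switch_to_fa: bool = False,
                                     try_use_upstream_fa: bool = False
                                     ) -> tuple[Any, Optional[Callable]]:
    # Stub standing in for the real helper, which may swap the backend and
    # select a flash_attn_varlen_func implementation.
    return attn_backend, None


# MultiHeadAttention in layer.py: keep whatever backend was already selected.
backend, fn = maybe_get_vit_flash_attn_backend(
    "TORCH_SDPA",
    try_switch_to_fa=False,
    try_use_upstream_fa=False,
)

# ViT modules (dots_ocr, ernie45_vl, glm4_1v, qwen2_5_vl, qwen2_vl,
# siglip2navit): explicitly opt in to the FlashAttention upgrade.
backend, fn = maybe_get_vit_flash_attn_backend(
    "TORCH_SDPA",
    try_switch_to_fa=True,
    try_use_upstream_fa=False,
)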

vllm/model_executor/models/dots_ocr.py

Lines changed: 2 additions & 1 deletion
@@ -295,7 +295,8 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
+                try_use_upstream_fa=self.use_upstream_fa,
             )
         )
         if self.attn_backend not in {

vllm/model_executor/models/ernie45_vl.py

Lines changed: 2 additions & 1 deletion
@@ -203,7 +203,8 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
+                try_use_upstream_fa=self.use_upstream_fa,
             )
         )

vllm/model_executor/models/glm4_1v.py

Lines changed: 2 additions & 1 deletion
@@ -293,7 +293,8 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
+                try_use_upstream_fa=self.use_upstream_fa,
             )
         )

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 2 additions & 1 deletion
@@ -345,7 +345,8 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
+                try_use_upstream_fa=self.use_upstream_fa,
             )
         )
         self.is_flash_attn_backend = self.attn_backend in {

vllm/model_executor/models/qwen2_vl.py

Lines changed: 2 additions & 1 deletion
@@ -361,7 +361,8 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
+                try_use_upstream_fa=self.use_upstream_fa,
             )
         )

vllm/model_executor/models/siglip2navit.py

Lines changed: 2 additions & 1 deletion
@@ -255,7 +255,8 @@ def __init__(
         self.attn_backend, self.flash_attn_varlen_func = (
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
-                self.use_upstream_fa,
+                try_switch_to_fa=True,
+                try_use_upstream_fa=self.use_upstream_fa,
             )
         )
