
Conversation

@bradleyhd
Contributor

Summary:
In #26104, some changes were made in layer.py that resulted in always trying to switch to FA backend for ViT, even when VLLM_ATTENTION_BACKEND is set.

This broke Meta's internal AMD pipelines, as this is neither desired nor expected behavior. With this change, the models touched in the offending PR can explicitly opt in to this behavior.

Differential Revision: D84946967
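
For context, a minimal sketch of the opt-in shape this change introduces, as discussed in the reviews below (illustrative only: the actual signature in layer.py is not reproduced here, the False default is an assumption, and the _Backend / is_fa_backend definitions are stand-ins):

```python
from enum import Enum, auto


class _Backend(Enum):  # stand-in for vllm.attention.backends.registry._Backend
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def is_fa_backend(backend: _Backend) -> bool:  # stand-in helper
    return backend is _Backend.FLASH_ATTN


def maybe_get_vit_flash_attn_backend(attn_backend: _Backend,
                                     try_switch_to_fa: bool = False) -> _Backend:
    # Only attempt the FLASH_ATTN upgrade when the caller opts in, so an
    # explicitly configured backend (e.g. via VLLM_ATTENTION_BACKEND) is honored.
    if try_switch_to_fa and not is_fa_backend(attn_backend):
        attn_backend = _Backend.FLASH_ATTN
    return attn_backend


assert maybe_get_vit_flash_attn_backend(_Backend.TORCH_SDPA) is _Backend.TORCH_SDPA
assert maybe_get_vit_flash_attn_backend(
    _Backend.TORCH_SDPA, try_switch_to_fa=True) is _Backend.FLASH_ATTN
```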

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small but essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the qwen (Related to Qwen models) label Oct 17, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request makes the FlashAttention backend upgrade for Vision Transformer (ViT) models an opt-in behavior, addressing an issue where it was unconditionally attempted, causing problems on AMD platforms. The change is implemented by introducing a try_switch_to_fa flag in maybe_get_vit_flash_attn_backend and updating the call sites in various models.

The overall approach is sound and correctly addresses the reported issue. However, I've identified a critical bug in the new implementation that could lead to crashes on platforms not supporting FlashAttention, like XPU. I've also pointed out a high-severity maintainability issue regarding the modification of function parameters, which could make the code harder to reason about. Addressing these points will improve the robustness and clarity of the code.

Comment on lines 91 to 92
if try_switch_to_fa and not is_fa_backend(attn_backend):
    attn_backend = _Backend.FLASH_ATTN
Contributor


critical

The current logic unconditionally switches the backend to FLASH_ATTN if try_switch_to_fa is true. This can cause a crash on platforms that do not support FlashAttention, such as XPU, because the subsequent import of vllm.vllm_flash_attn will fail. The switch should be guarded to only occur on supported platforms (CUDA and ROCm).

Suggested change
-if try_switch_to_fa and not is_fa_backend(attn_backend):
-    attn_backend = _Backend.FLASH_ATTN
+if try_switch_to_fa and not is_fa_backend(attn_backend) and (
+        current_platform.is_cuda() or current_platform.is_rocm()):
+    attn_backend = _Backend.FLASH_ATTN

if attn_backend == _Backend.FLASH_ATTN:
    # Always try upstream on ROCM.
    logger.info_once("maybe_get_vit_flash_attn_backend: forcing upstream FlashAttn on ROCM.")
    try_use_upstream_fa = True
Contributor


high

Modifying an input parameter try_use_upstream_fa directly is confusing and can lead to unexpected side effects. It's better to use a local variable to track the state within the function. For example, you could introduce use_upstream_fa = try_use_upstream_fa at the beginning of the function and then modify and use use_upstream_fa.
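
A minimal sketch of the suggested pattern, with stand-in names and a simplified signature (this is not the actual vLLM implementation):

```python
def choose_vit_backend(attn_backend: str, try_use_upstream_fa: bool, on_rocm: bool):
    # Copy the input flag into a local variable so the parameter itself is
    # never reassigned; callers keep seeing exactly the value they passed in.
    use_upstream_fa = try_use_upstream_fa

    if on_rocm and attn_backend == "FLASH_ATTN":
        # Always try upstream FA on ROCm (mirrors the quoted snippet above).
        use_upstream_fa = True

    # ...the rest of the selection logic would consult use_upstream_fa...
    return attn_backend, use_upstream_fa
```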


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 17, 2025
Summary:

In vllm-project#26104, some changes were made in layer.py that resulted in always trying to switch to FA backend for ViT, even when  `VLLM_ATTENTION_BACKEND` is set.

This broke Meta's internal AMD pipelines as it is not desired nor expected behavior. With this change, the models that were changed in the offending PR can explicitly opt-in to this behavior.

Differential Revision: D84946967
@bradleyhd
Contributor Author

Updated to try and mimic #26104 as closely as possible to make this an equivalent change. Not sure the behavior in the original PR is good / should be preserved, though.

@zhewenl zhewenl added the rocm (Related to AMD ROCm), ci/build, and ci-failure (Issue about an unexpected test failure in CI) labels Oct 17, 2025
@zhewenl
Collaborator

zhewenl commented Oct 17, 2025

This PR also fixes existing AMD failures (example):

(EngineCore_DP0 pid=50574)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layers/cross_attention.py", line 168, in __init__
(EngineCore_DP0 pid=50574)     super().__init__(
(EngineCore_DP0 pid=50574)   File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 236, in __init__
(EngineCore_DP0 pid=50574)     self.impl = impl_cls(
(EngineCore_DP0 pid=50574)                 ^^^^^^^^^
(EngineCore_DP0 pid=50574)   File "/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/triton_attn.py", line 248, in __init__
(EngineCore_DP0 pid=50574)     raise NotImplementedError(
(EngineCore_DP0 pid=50574) NotImplementedError: Encoder self-attention and encoder/decoder cross-attention are not implemented for TritonAttentionImpl
[rank0]:[W1017 04:54:49.728888986 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

cc @Alexei-V-Ivanov-AMD

@DarkLight1337
Member

cc @tjtanaa

@LucasWilkinson
Collaborator

LucasWilkinson commented Oct 20, 2025

This logic is very confusing now; it would be good to get more context here and try to refactor this a bit more aggressively.

cc @wwl2755 @tjtanaa

Seems like the original intention of using upstream FA is #24347, i.e. use it for models with a head dim that is not supported by vLLM-FA but is supported by upstream FA: anything that's a multiple of 8 but not a multiple of 32.

Can we just make all this logic: if on CUDA and the head dim is not a multiple of 32, use upstream FA; otherwise just use get_attn_backend (like pre-#24347)?
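
A rough sketch of the rule being proposed here (purely illustrative; in the real code the non-FA path would go through vLLM's usual backend selection, e.g. get_attn_backend, which is not reproduced here):

```python
def prefers_upstream_fa(head_size: int, is_cuda: bool) -> bool:
    # Upstream flash-attn covers head dims that are a multiple of 8 but not of
    # 32, which vLLM-FA does not; everything else can use the normal backend
    # selection, as before #24347.
    return is_cuda and head_size % 8 == 0 and head_size % 32 != 0


assert prefers_upstream_fa(72, is_cuda=True)       # 72 = 8 * 9, not a multiple of 32
assert not prefers_upstream_fa(128, is_cuda=True)  # 128 is a multiple of 32
```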

@tjtanaa
Collaborator

tjtanaa commented Oct 21, 2025

> This logic is very confusing now; it would be good to get more context here and try to refactor this a bit more aggressively.
>
> cc @wwl2755 @tjtanaa
>
> Seems like the original intention of using upstream FA is #24347, i.e. use it for models with a head dim that is not supported by vLLM-FA but is supported by upstream FA: anything that's a multiple of 8 but not a multiple of 32.
>
> Can we just make all this logic: if on CUDA and the head dim is not a multiple of 32, use upstream FA; otherwise just use get_attn_backend (like pre-#24347)?

To add on to @LucasWilkinson's feedback:

The logic that we should update is:

vllm/vllm/platforms/rocm.py

Lines 204 to 211 in c3a2c6a

def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
    from vllm.attention.backends.registry import _Backend
    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
        return _Backend.ROCM_AITER_FA
    if on_gfx9():
        return _Backend.FLASH_ATTN
    return _Backend.TORCH_SDPA

Right now, as long as we are on AMD Instinct, we assume that the ck-flash-attention library is installed. If we want to enable torch.sdpa, we can start by modifying this part first: if ck-flash-attention is not installed, or the head_dim is not supported by the specified backend, we fall back to torch.sdpa (see the sketch below).
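
A sketch of that fallback, under stated assumptions (the availability probe and the head-size check below are illustrative guesses, not the current rocm.py implementation, and the backend enum values are written as plain strings):

```python
import importlib.util


def _ck_flash_attn_available() -> bool:
    # Hypothetical helper: treat an importable flash_attn module as "installed".
    return importlib.util.find_spec("flash_attn") is not None


def pick_rocm_vit_backend(head_size: int, on_gfx9: bool, use_aiter: bool) -> str:
    if use_aiter and on_gfx9:
        return "ROCM_AITER_FA"
    if on_gfx9 and _ck_flash_attn_available() and head_size % 8 == 0:
        return "FLASH_ATTN"
    # Fall back to torch.sdpa when CK flash-attention is missing or the head
    # size is unsupported by the FA backends.
    return "TORCH_SDPA"
```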

Another thing that I notice is that the VLLM_ATTENTION_BACKEND semantics should be meant for the text-model backbone.

The set of attention backends supported by ViT is TORCH_SDPA, FLASH_ATTN, and ROCM_AITER_FA only.

However, the attention backends for the LLM backbone are TRITON_ATTN, ROCM_ATTN, ROCM_AITER_FA, or ROCM_AITER_UNIFIED_ATTN.

So, I would suggest reserving the VLLM_ATTENTION_BACKEND environment variable for LLM attention backend selection.

Moreover, on the MI300 series, flash attention / AITER flash attention is recommended for ViT as it is the fastest. When torch.sdpa is selected, it is extremely slow, as it uses a for loop to compute the attention output in the majority of the vision models.

@DarkLight1337
Member

Heads up that we have decoupled the two backends in #27061

@bradleyhd
Contributor Author

Heads up that we have decoupled the two backends in #27061

@DarkLight1337 thanks. Is maybe_get_vit_flash_attn_backend still needed in light of this PR?

@DarkLight1337
Member

cc @ywang96 @Isotr0py

@ywang96
Member

ywang96 commented Oct 22, 2025

@LucasWilkinson @tjtanaa @bradleyhd FYI, in parallel to this PR I've also made #27061, which decouples the ViT attn backend from the LM attn backend (which is probably something we should've done from the get-go).

@bradleyhd
Contributor Author

@tjtanaa curious, is upstream FA in this case expected to be FAv3? (looking at #24347)

@bradleyhd
Contributor Author

@ywang96 #27061 only works if we override to ROCM_AITER_FA, because it is exempt from the logic in maybe_get_vit_flash_attn_backend. If we set it to TORCH_SDPA, it just gets overwritten with FA because we have a module named flash_attn installed.
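
An illustrative sketch of the failure mode being described (not the actual code path; the import probe is an assumption about the mechanism, with backends written as plain strings):

```python
import importlib.util

attn_backend = "TORCH_SDPA"  # what the user explicitly asked for

if importlib.util.find_spec("flash_attn") is not None:
    # The mere presence of a flash_attn module triggers the upgrade,
    # silently discarding the explicit TORCH_SDPA override.
    attn_backend = "FLASH_ATTN"

print(attn_backend)
```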

bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
Summary:
Pull Request resolved: vllm-project#27124

In vllm-project#26104, some changes were made in layer.py that resulted in always trying to switch to FA backend for ViT, even when  `VLLM_ATTENTION_BACKEND` is set.

This broke Meta's internal AMD pipelines as it is not desired nor expected behavior. With this change, the models that were changed in the offending PR can explicitly opt-in to this behavior.

Reviewed By: Prowindy

Differential Revision: D84946967
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
@bradleyhd bradleyhd changed the title make flash_attn ViT upgrade opt-in honor --mm_encoder_attn_backend when used Oct 22, 2025
@bradleyhd
Contributor Author

Alright folks, I've updated this to make use of the new --mm_encoder_attn_backend. When supplied, it won't auto-upgrade to FA. We need this ASAP to unblock, as it allows us to specify torch_sdpa usage. We can and should circle back for a more comprehensive refactor here.
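
A hypothetical usage sketch for illustration, assuming (unverified here) that the --mm_encoder_attn_backend CLI flag maps to an engine argument of the same name accepted by the LLM constructor; the model name is just a placeholder:

```python
from vllm import LLM

# Hypothetical: pin the ViT/encoder attention backend to torch.sdpa while the
# LM backbone keeps its own backend selection. The keyword argument name is
# assumed to mirror the --mm_encoder_attn_backend CLI flag.
llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    mm_encoder_attn_backend="TORCH_SDPA",
)
```

On the command line this corresponds to the --mm_encoder_attn_backend flag mentioned above.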

bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
Member

@ywang96 ywang96 left a comment


LGTM - I think we do need to think about how to deal with the override in a better way (whether we should honor it truly, with the risk of failure, or handle the fallback automatically).

@ywang96 ywang96 added the ready (ONLY add when PR is ready to merge/full CI is needed) label Oct 23, 2025
@DarkLight1337 DarkLight1337 changed the title honor --mm_encoder_attn_backend when used [Bugfix] Honor --mm_encoder_attn_backend when used Oct 23, 2025
@DarkLight1337 DarkLight1337 merged commit 570c3e1 into vllm-project:main Oct 23, 2025
56 checks passed
albertoperdomo2 pushed a commit to albertoperdomo2/vllm that referenced this pull request Oct 23, 2025
Co-authored-by: Bradley D <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Alberto Perdomo <[email protected]>
kingsmad pushed a commit to kingsmad/vllm that referenced this pull request Oct 25, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025