
Conversation

@tjtanaa (Collaborator) commented Oct 2, 2025

Purpose

A code refactor caused the ViT flash attention dispatcher logic to enter the wrong code path on the ROCm platform, importing `flash_attn_varlen_func` from `vllm.vllm_flash_attn`.

This PR also fixes the incorrect usage of `aiter.flash_attn_varlen_func` in the `MultiHeadAttention` class introduced in #23978.
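
For reference, the corrected dispatch order looks roughly like the sketch below. This is an illustration only, not the actual dispatcher code in vLLM: the helper name `select_vit_flash_attn_func` is hypothetical, the upstream `flash_attn` fallback is assumed, and the real logic also consults backend enums and environment toggles.

```python
# Illustrative sketch of the corrected ViT flash-attention dispatch order;
# select_vit_flash_attn_func is a hypothetical helper, not the vLLM function.
from vllm.platforms import current_platform


def select_vit_flash_attn_func():
    if current_platform.is_rocm():
        # On ROCm, prefer the aiter kernel when it can be imported and fall
        # back to upstream flash-attn otherwise. vllm.vllm_flash_attn must not
        # be imported on this path, which was the bug this PR fixes.
        try:
            from aiter import flash_attn_varlen_func
        except ImportError:
            from flash_attn import flash_attn_varlen_func
        return flash_attn_varlen_func
    # On CUDA, the bundled vLLM flash-attention build is the intended path.
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    return flash_attn_varlen_func
```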

Test Plan

Evaluate the accuracy of all models that use this ViT flash attention dispatcher logic on the ChartQA dataset.
NOTE: These accuracy numbers by no means indicate the actual model performance on the benchmark, and they were not evaluated with the same procedure used in the official releases.

The `MultiHeadAttention` bugfix is validated with OpenGVLab/InternVL3_5-8B.
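
As a rough illustration of that validation, a minimal multimodal smoke test like the one below exercises the ViT attention path. This is not the exact command used for the ChartQA runs; the prompt template and the `VLLM_ROCM_USE_AITER` toggle in the comments are assumptions.

```python
# Minimal multimodal smoke test (not the full ChartQA evaluation): feed one
# synthetic image through the model so the ViT attention path is exercised.
# Assumed backend toggle: VLLM_ROCM_USE_AITER=0/1 python smoke_vit_attn.py
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="OpenGVLab/InternVL3_5-8B", trust_remote_code=True)

# Placeholder image; the real runs use ChartQA charts.
image = Image.new("RGB", (448, 448), color="white")
# Prompt template is an assumption; use the model's chat template in practice.
prompt = "<image>\nDescribe this image in one sentence."

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```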

Test Result

Flash Attention Backend Comparison: AIter vs Non-AIter

| Model | Backend | Explicit Prompt Relaxed Correctness | Anywhere in Answer Relaxed Correctness |
|---|---|---|---|
| Qwen/Qwen2.5-VL-72B-Instruct | AIter | 0.8672 | 0.8860 |
| Qwen/Qwen2.5-VL-72B-Instruct | No AIter | 0.8624 | 0.8848 |
| Qwen/Qwen3-VL-235B-A22B-Instruct | No AIter | 0.8648 | 0.8656 |
| Qwen/Qwen3-VL-235B-A22B-Instruct | AIter | 0.8656 | 0.8680 |
| zai-org/GLM-4.5V-FP8 | No AIter | 0.5088 | 0.5716 |
| zai-org/GLM-4.5V-FP8 | AIter | 0.4952 | 0.5580 |
| baidu/ERNIE-4.5-VL-28B-A3B-PT | No AIter | 0.8424 | 0.8828 |
| baidu/ERNIE-4.5-VL-28B-A3B-PT | AIter | 0.8444 | 0.8768 |
| AIDC-AI/Ovis2.5-9B | No AIter | 0.8656 | 0.8764 |
| AIDC-AI/Ovis2.5-9B | AIter | 0.8652 | 0.8784 |
| OpenGVLab/InternVL3_5-8B | No AIter | 0.892 | 0.892 |
| OpenGVLab/InternVL3_5-8B | AIter | 0.8964 | 0.8964 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added qwen Related to Qwen models rocm Related to AMD ROCm labels Oct 2, 2025
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request refactors the Vision Transformer (ViT) flash attention dispatcher logic to centralize it and fix a bug on the ROCm platform. The changes are consistent across multiple model files, replacing duplicated logic with a call to a new utility function maybe_get_vit_flash_attn_backend. This is a good improvement for maintainability. However, I've found a critical issue in the implementation of check_upstream_fa_availability which could lead to runtime errors.
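
For context, an availability probe of this kind usually has to guard against the package being absent. The sketch below is illustrative only and is not the `check_upstream_fa_availability` implementation under review; the dtype gate shown is an assumption.

```python
# Illustrative sketch only -- not the check_upstream_fa_availability code under
# review. The point is that the probe must not raise when flash_attn is absent.
import importlib.util

import torch


def check_upstream_fa_availability(dtype: torch.dtype) -> bool:
    # The dtype gate is an assumption; upstream flash-attn supports fp16/bf16.
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    # find_spec avoids importing the package just to test for its presence.
    return importlib.util.find_spec("flash_attn") is not None
```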

Signed-off-by: tjtanaa <[email protected]>
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 2, 2025 16:32
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 2, 2025
@DarkLight1337
Copy link
Member

cc @Isotr0py

auto-merge was automatically disabled October 2, 2025 17:17

Head branch was pushed to by a user without write access

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 2, 2025 17:23
Signed-off-by: tjtanaa <[email protected]>
auto-merge was automatically disabled October 3, 2025 04:00

Head branch was pushed to by a user without write access

@vllm-bot vllm-bot merged commit 9c5ee91 into vllm-project:main Oct 3, 2025
50 of 55 checks passed
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
tomeras91 pushed a commit to tomeras91/vllm that referenced this pull request Oct 6, 2025
karan pushed a commit to karan/vllm that referenced this pull request Oct 6, 2025
southfreebird pushed a commit to southfreebird/vllm that referenced this pull request Oct 7, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 17, 2025
Summary:
In vllm-project#26104, changes to layer.py resulted in always trying to switch to the FA backend for ViT, even when `VLLM_ATTENTION_BACKEND` is set.

This broke Meta's internal AMD pipelines, as it is neither desired nor expected behavior. With this change, the models touched by the offending PR can explicitly opt in to this behavior.

Differential Revision: D84946967
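
For context, the opt-in behavior described in this commit amounts to letting an explicit `VLLM_ATTENTION_BACKEND` setting win over the automatic ViT flash-attn switch. The sketch below is a hypothetical illustration, not the code in the referenced commit; the function name and the `TORCH_SDPA` fallback are assumptions.

```python
# Hypothetical illustration of the opt-in gating described above; neither the
# function name nor the TORCH_SDPA fallback comes from the referenced commit.
import os


def resolve_vit_attn_backend(model_opts_in: bool, auto_choice: str) -> str:
    # An explicitly configured backend always wins.
    explicit = os.environ.get("VLLM_ATTENTION_BACKEND")
    if explicit:
        return explicit
    # Only models that opt in get the automatic switch to flash-attn.
    return auto_choice if model_opts_in else "TORCH_SDPA"


# Example: resolve_vit_attn_backend(model_opts_in=True, auto_choice="FLASH_ATTN")
```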
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 17, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
Summary:
Pull Request resolved: vllm-project#27124

In vllm-project#26104, changes to layer.py resulted in always trying to switch to the FA backend for ViT, even when `VLLM_ATTENTION_BACKEND` is set.

This broke Meta's internal AMD pipelines, as it is neither desired nor expected behavior. With this change, the models touched by the offending PR can explicitly opt in to this behavior.

Reviewed By: Prowindy

Differential Revision: D84946967
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 22, 2025
bradleyhd added a commit to bradleyhd/vllm that referenced this pull request Oct 23, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
