[ROCm][Perf] New design on ROCm AITER MHA backend Implementation #25763
Conversation
Signed-off-by: ganyi <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger a full CI run by default. Instead, it would only run `fastcheck` CI, which runs only a small and essential subset of CI tests to quickly catch errors. You can ask your reviewers to trigger select CI tests on top of `fastcheck` CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add a `ready` label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces a new, more performant MHA backend implementation for ROCm. The changes include removing redundant KV fetches, introducing phase-aware execution (decode, pure prefill, chunk prefill), reordering inputs for better memory access, and rewriting the Triton kernel for fetching KV cache. The performance improvements demonstrated are significant. However, I have identified a few critical bugs in the implementation that need to be addressed. These include incorrect scaling in a Triton kernel, an invalid tensor view operation that will lead to a runtime error, and a logical error in the batch reordering logic. Addressing these issues is crucial for the correctness and stability of the new backend.
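For context on the "invalid tensor view operation" class of bug the review flags: `torch.Tensor.view` requires strides compatible with the requested shape and raises at runtime otherwise, whereas `reshape` silently falls back to a copy. A minimal illustration (not the actual kernel code from this PR):

```python
import torch

# A transpose makes the tensor non-contiguous without copying data.
x = torch.randn(4, 8).transpose(0, 1)  # shape (8, 4), non-contiguous

try:
    x.view(32)  # raises: view size is not compatible with input's size and stride
except RuntimeError as e:
    print(f"view failed: {e}")

y = x.reshape(32)            # works: copies when a true view is impossible
z = x.contiguous().view(32)  # equivalent, with the copy made explicit
```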
Signed-off-by: ganyi <[email protected]>
Signed-off-by: ganyi <[email protected]>
Could you please clarify which Qwen3 model and datatype you are using? Could you please also append the accuracy results?
Signed-off-by: ganyi <[email protected]>
Signed-off-by: ganyi <[email protected]>
Signed-off-by: ganyi <[email protected]>
Thanks for the suggestion, I've just updated the PR description with the model and accuracy verification.
Hi @gshtras, can you please take a look at this PR?
Purpose
The current `AiterFlashAttentionImpl` fetches K/V on every run, which creates unnecessary memory pressure and non-trivial latency, especially with long prompts. This PR:

- removes the redundant KV fetches,
- introduces phase-aware execution (decode, pure prefill, chunked prefill),
- reorders inputs for better memory access, and
- rewrites the Triton kernel that fetches from the KV cache.

A toy illustration of the per-step K/V gather that the old path pays for follows the list below.
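For intuition, here is a minimal, hypothetical sketch of what gathering K/V out of a paged cache costs (single head, single sequence; `gather_kv` and its shapes are illustrative, not the PR's actual Triton kernel):

```python
import torch

def gather_kv(kv_cache: torch.Tensor, block_table: torch.Tensor,
              seq_len: int, block_size: int) -> torch.Tensor:
    """Copy one sequence's cached K (or V) blocks into a contiguous
    [seq_len, head_dim] buffer. Paying this copy for every sequence on
    every forward pass is the memory traffic this PR eliminates."""
    n_blocks = (seq_len + block_size - 1) // block_size
    blocks = kv_cache[block_table[:n_blocks]]   # [n_blocks, block_size, head_dim]
    return blocks.reshape(-1, kv_cache.shape[-1])[:seq_len]

# Toy paged cache: 16 blocks of 4 tokens each, head_dim 8, single head.
cache = torch.randn(16, 4, 8)
table = torch.tensor([3, 7, 1])    # the blocks this sequence owns
k = gather_kv(cache, table, seq_len=10, block_size=4)
print(k.shape)                      # torch.Size([10, 8])
```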
Design and implementation

- Phase-aware path: decode, pure prefill, and chunked prefill each take a dedicated execution path.
- Input reordering: reordering inputs to [decode:cp:pp] ensures tokens are contiguous in memory, improving kernel locality and occupancy. The reorder happens in both the Scheduler's scheduling phase and the ModelRunner's state-updating phase. We add a `split_prefill_from_chunk` flag to the `SchedulerConfig` to control this behavior; it is turned on when both `VLLM_ROCM_USE_AITER` and `VLLM_ROCM_USE_AITER_MHA` are set. A sketch of the resulting phase-aware dispatch follows below.

Compared with the old implementation, this design is more memory-efficient and faster, especially in long-prompt scenarios.
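Below is a minimal sketch of how [decode:cp:pp] reordering plus phase-aware dispatch can fit together. All names (`reorder_requests`, `run_attention`, the kernel stand-ins) are hypothetical; the real backend operates on flattened token buffers and calls the AITER kernels:

```python
import torch

# Stand-ins for the real AITER kernels so the sketch runs end to end.
paged_attention = chunked_prefill_attention = prefill_attention = lambda q: q

def reorder_requests(is_decode: list[bool], is_chunked: list[bool]):
    """Permutation putting requests in [decode : chunked prefill : pure
    prefill] order, so each phase operates on a contiguous slice."""
    decode = [i for i, d in enumerate(is_decode) if d]
    chunked = [i for i, (d, c) in enumerate(zip(is_decode, is_chunked)) if not d and c]
    pure = [i for i, (d, c) in enumerate(zip(is_decode, is_chunked)) if not d and not c]
    return decode + chunked + pure, (len(decode), len(chunked))

def run_attention(q: torch.Tensor, num_decode: int, num_chunked: int) -> torch.Tensor:
    """Phase-aware dispatch over a reordered batch (one row per request
    here for simplicity; the real code works on token buffers)."""
    d, c = num_decode, num_chunked
    out = torch.empty_like(q)
    # Decode: attend straight against the paged KV cache, no K/V gather.
    out[:d] = paged_attention(q[:d])
    # Chunked prefill: the only phase that still fetches earlier-chunk K/V.
    out[d:d + c] = chunked_prefill_attention(q[d:d + c])
    # Pure prefill: this step's K/V is already materialized, no fetch needed.
    out[d + c:] = prefill_attention(q[d + c:])
    return out

# Example: 4 requests reordered into [decode, decode, chunked, pure].
perm, (nd, nc) = reorder_requests(is_decode=[False, True, False, True],
                                  is_chunked=[True, False, False, False])
q = torch.randn(4, 8)[perm]   # reorder the batch to match the permutation
out = run_attention(q, nd, nc)
```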
Here is the performance measured on Qwen3, MI308:

- Long prompt, short output (`2k prompt, 16 output`): ~4.x throughput improvement.
- Short prompt, long output (`128 prompt, 1k output`): ~2.x throughput improvement.
- Extremely long prompt (`192k prompt, 2k output`): ~5.x throughput improvement.

Test Plan
- acc: `lm_eval` test for accuracy verification (sketched below)
- perf: `vllm bench` test
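For the accuracy step, a hedged sketch using the lm-evaluation-harness Python API (assuming `lm_eval` is installed with vLLM support; the model args are illustrative, not copied from the actual run):

```python
from lm_eval import simple_evaluate

# gsm8k accuracy check against a vLLM-backed Qwen3 model.
results = simple_evaluate(
    model="vllm",
    model_args="pretrained=Qwen/Qwen3-30B-A3B-FP8,tensor_parallel_size=1",
    tasks=["gsm8k"],
)
print(results["results"]["gsm8k"])
```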
Test Result
2k prompt 16 output case:

old impl

[benchmark screenshot]

new impl

[benchmark screenshot]

128 prompt 1k output case:

old impl

[benchmark screenshot]

new impl

[benchmark screenshot]
acc verification

We tested this PR with Qwen3-30B-A3B-FP8 on gsm8k using lm_eval; here is the result:
Essential Elements of an Effective PR Description Checklist

- The purpose of the PR
- The test plan
- The test results
- (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.