Skip to content

Conversation

ganyi1996ppo
Copy link

@ganyi1996ppo ganyi1996ppo commented Sep 26, 2025

Purpose

The current AiterFlashAttentionImpl fetches K/V every run, which creates unnecessary memory pressure and non-trivial latency—especially with long prompts. This PR:

  • Removes redundant KV fetches from the attention backend
  • Introduces a phase-aware execution (decode, pure prefill, chunk prefill) and reorders inputs to [decode:chunk_prefill:pure_prefill] for token-contiguous memory access.
  • Rewrites the “fetch KV” Triton kernel for better occupancy in chunked prefill and similar scenarios.

Design and implementation

image

Phase-aware path:

  • decode
  • chunk prefill (cp)
  • pure prefill (pp)

Input reordering to [decode:cp:pp] ensures tokens are contiguous in memory, improving kernel locality and occupancy. The reorder occurs in both Scheduler's scheduling phase and ModelRunner's state updating phase. We add this split_prefill_from_chunk to the SchedulerConfig to control this behavior, which will be turned on if both VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_MHA are set.

Compared with the old one, this solution is more memory efficient and fast, especially on the long prompt scenario. Here is the Performance Measured on Qwen3, Mi308:

Long prompt, short output (2k prompt, 16 output): ~4.x throughput improvement.
Short prompt, long output (128 prompt, 1k output): ~2.x throughput improvement.
Extremely long prompt (192k prompt, 2k output): ~5.x throughput improvement.

Test Plan

acc : lm_eval test for accuracy verification
perf : vllm bench test

Test Result

2k prompt 16 output case:

old impl
image

new impl
image

128 prompt 1k output case:

old impl
image

new impl
image

acc verification

We test this PR on Qwen3-30B-A3B-FP8 on gsm8k with lm_eval, and here is the result:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8097|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.8901|±  |0.0086|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new, more performant MHA backend implementation for ROCm. The changes include removing redundant KV fetches, introducing phase-aware execution (decode, pure prefill, chunk prefill), reordering inputs for better memory access, and rewriting the Triton kernel for fetching KV cache. The performance improvements demonstrated are significant. However, I have identified a few critical bugs in the implementation that need to be addressed. These include incorrect scaling in a Triton kernel, an invalid tensor view operation that will lead to a runtime error, and a logical error in the batch reordering logic. Addressing these issues is crucial for the correctness and stability of the new backend.

@ganyi1996ppo ganyi1996ppo changed the title [ROCm][Perf] New design on MHA backend Implementation [ROCm][Perf] New design on ROCm AITER MHA backend Implementation Sep 26, 2025
@wuhuikx
Copy link

wuhuikx commented Sep 27, 2025

Could you please help clarify, which Qwen3 model and datatype are you using? Could you please also append the accuracy results?

@wuhuikx
Copy link

wuhuikx commented Sep 28, 2025

cc @wuhuikx @sunway513

Signed-off-by: ganyi <[email protected]>
Signed-off-by: ganyi <[email protected]>
Signed-off-by: ganyi <[email protected]>
@ganyi1996ppo
Copy link
Author

ganyi1996ppo commented Sep 29, 2025

Could you please help clarify, which Qwen3 model and datatype are you using? Could you please also append the accuracy results?

Thanks for the suggestion, just update the PR description with model and accuracy verification.

@ganyi1996ppo
Copy link
Author

hi @gshtras , can you please take a look on this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
rocm Related to AMD ROCm v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants