[ROCm][Aiter] Add triton fp8 bmm kernel for mla #22759
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add a ready label to the PR or enable auto-merge. 🚀
Code Review
This pull request introduces a Triton FP8 BMM kernel for MLA on ROCm to enhance performance. The changes include adding an environment variable to toggle this feature and updating the MLA attention implementation to leverage the new kernel for matrix multiplications. The implementation also involves quantizing weights to FP8 and includes a warmup loop for the Triton kernel. My review has identified two high-severity issues: the use of a hardcoded FP8 dtype instead of a platform-specific one, and a hardcoded warmup range for the Triton kernel that is insufficient for default configurations, potentially leading to performance degradation.
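A minimal sketch of how those two points might be addressed, assuming vLLM's `current_platform.fp8_dtype()` helper; the `warmup_fp8_bmm` function, the kernel entry point `bmm_fn`, the `[B, K, N]` weight layout, and the set of warmup sizes are illustrative placeholders, not the PR's actual code:

```python
import torch
from vllm.platforms import current_platform

# Query the FP8 dtype from the platform (float8_e4m3fnuz on MI300-class ROCm,
# float8_e4m3fn elsewhere) instead of hardcoding one.
fp8_dtype = current_platform.fp8_dtype()


def warmup_fp8_bmm(bmm_fn, w_fp8: torch.Tensor, capture_sizes, device="cuda"):
    """Run the Triton kernel once for every token count that can occur at
    runtime (e.g. the cudagraph capture sizes from the scheduler config), so
    auto-tuning/compilation never happens on the hot path. `bmm_fn` and the
    [B, K, N] weight layout are hypothetical placeholders."""
    batch, k_dim, _ = w_fp8.shape
    for num_tokens in capture_sizes:
        a = torch.randn(batch, num_tokens, k_dim,
                        dtype=torch.bfloat16, device=device)
        bmm_fn(a, w_fp8)
```

The key point is that the warmup sizes come from the runtime configuration rather than a fixed hardcoded range, so default configurations are fully covered.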
This pull request has merge conflicts that must be resolved before it can be merged.
Re-created another PR with some updates: #23264
Purpose
Replace torch.bmm with the aiter batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant kernel for MLA; a rough sketch of the computation is shown below.
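For reference, a sketch of what the fused kernel computes, written with plain PyTorch ops. The actual PR dispatches to aiter's fused Triton kernel; the group size, FP8 dtype, and tensor layouts below are assumptions for illustration.

```python
import torch


def ref_fp8_bmm(a: torch.Tensor, w_fp8: torch.Tensor, w_scale: torch.Tensor,
                group_size: int = 128,
                fp8_dtype: torch.dtype = torch.float8_e4m3fnuz) -> torch.Tensor:
    """Reference for an FP8 BMM with per-token-group activation quantization
    and a per-batched-tensor weight scale.

    a:       [B, M, K] bf16 activations
    w_fp8:   [B, K, N] FP8-quantized weights
    w_scale: [B, 1, 1] per-batched-tensor weight scale
    """
    bsz, m, k = a.shape
    fp8_max = torch.finfo(fp8_dtype).max

    # Quantize activations per token group: each contiguous block of
    # `group_size` elements along K shares one scale.
    a_grp = a.view(bsz, m, k // group_size, group_size).float()
    a_scale = (a_grp.abs().amax(dim=-1, keepdim=True) / fp8_max).clamp_min(1e-12)
    a_q = (a_grp / a_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)

    # Dequantize and fall back to a plain bmm; the fused Triton kernel keeps
    # the matmul in FP8 and folds both scales into the epilogue instead.
    a_dq = (a_q.float() * a_scale).view(bsz, m, k)
    w_dq = w_fp8.float() * w_scale.float()
    return torch.bmm(a_dq, w_dq).to(a.dtype)
```

Per the review summary above, the weight quantization to FP8 happens once up front, so only the per-token-group activation quantization sits on the per-step path.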
Test Plan
Test Result
Correctness Result
Performance test on DeepSeek-R1 with full-cudagraph capture mode
(Optional) Documentation Update