[ROCm][Aiter] Add triton fp8 bmm kernel for mla #23264
Conversation
Signed-off-by: Divakar Verma <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. 🚀
Code Review
This pull request introduces a new Triton kernel for FP8 batched matrix multiplication (`bmm`) within the MLA backend on ROCm, aimed at improving performance. The changes add a new environment variable, `VLLM_ROCM_USE_AITER_FP8BMM`, to control this feature, and conditionally use the new kernel in place of `torch.bmm` for specific operations in the MLA implementation.
My review has identified one critical issue: the new environment variable is not included in the computation graph hash, which can lead to incorrect caching behavior. This must be addressed to ensure correctness.
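For illustration, here is a minimal sketch of the kind of fix the review is asking for. The list and helper names are hypothetical, not vLLM's actual cache-key code; the point is that any env var that changes which kernels get compiled must feed into the compilation-cache key:

```python
import hashlib
import os

# Hypothetical list: env vars whose values change the compiled graph.
# If VLLM_ROCM_USE_AITER_FP8BMM were missing here, a graph cached with
# the flag off could be reused after the flag is turned on (or vice versa).
_GRAPH_AFFECTING_ENV_VARS = [
    "VLLM_ROCM_USE_AITER_FP8BMM",  # the flag added by this PR
]

def env_hash_factor() -> str:
    """Fold graph-affecting env var values into one cache-key factor."""
    factors = [f"{k}={os.environ.get(k, '')}" for k in _GRAPH_AFFECTING_ENV_VARS]
    return hashlib.sha256("|".join(factors).encode()).hexdigest()
```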
Signed-off-by: Divakar Verma <[email protected]>
```python
# triton kernel to avoid runtime compilation for unseen batch sizes
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
# On DS-R1, this step adds roughly 50s to the model loading time.
max_batch_size = 1024  # [ToDo] Find the optimal upper limit
```
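A minimal sketch of what such a pre-compilation pass can look like (the function name, shapes, and dtypes are illustrative, not the PR's actual code): calling the kernel once per batch size during weight loading forces Triton to JIT-compile and cache each specialization up front instead of on the first live request.

```python
import torch

def precompile_fp8_bmm(kernel, weight: torch.Tensor, k: int,
                       max_batch_size: int = 1024) -> None:
    """Invoke `kernel` once per batch size so every Triton specialization
    is compiled during model loading rather than on a live request."""
    for bs in range(1, max_batch_size + 1):
        # Dummy activations; only the shape matters for compilation.
        x = torch.randn(weight.shape[0], bs, k,
                        dtype=torch.bfloat16, device=weight.device)
        kernel(x, weight)  # output discarded; compile-and-cache side effect
```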
I'm not crazy about adding this much overhead to the model loading time. CC @mgoin @LucasWilkinson I don't know how much pre-compilation we consider "acceptable". @divakar-amd how much does this improve performance?
Without pre-compilation, the performance is worse than torch.bmm for the first run. I added a plot above showing the performance difference if pre-compilation is not used. @SageMoore
Also, pre-compilation will add to the model loading time only if AITER is enabled.
I think this is generally fine. As you pointed out offline, we already do pre-compilation for other AITER kernels, and that takes less time than torch.compile, which seems like a reasonable upper bound. I do agree that the TTFT improvements are nice.
Signed-off-by: Divakar Verma <[email protected]>
Co-authored-by: ShaoChunLee <[email protected]>
Purpose
Replace `torch.bmm` with the aiter `batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant` kernel for MLA. A minimal dispatch sketch is shown below.
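The sketch is illustrative only: the wrapper and the kernel injection are hypothetical, the real integration lives in the MLA backend, and the exact aiter signature may differ.

```python
import os
from typing import Callable, Optional

import torch

# Flag name is from this PR; parsing shown here is illustrative.
_USE_AITER_FP8BMM = os.environ.get("VLLM_ROCM_USE_AITER_FP8BMM", "0") == "1"

def mla_bmm(x: torch.Tensor, w: torch.Tensor,
            fp8_bmm: Optional[Callable] = None) -> torch.Tensor:
    """Batched matmul for MLA: use the aiter FP8 Triton kernel when the
    flag is set and a kernel is available, otherwise fall back to torch.bmm."""
    if _USE_AITER_FP8BMM and fp8_bmm is not None:
        # Per its name, the aiter kernel quantizes activations per token
        # group and weights per batched tensor before the FP8 GEMM.
        return fp8_bmm(x, w)
    return torch.bmm(x, w)
```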
Test Plan
Test Result
Correctness Result
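As an illustration of the kind of check behind a correctness result like this (not the PR's actual test; shapes and tolerances are made up, and tolerances must be loose enough to absorb FP8 quantization error):

```python
import torch

def fp8_bmm_matches_reference(fp8_bmm, b=128, m=16, k=512, n=128,
                              rtol=2e-2, atol=2e-2) -> bool:
    """Compare the FP8 kernel's output against a bf16 torch.bmm reference."""
    x = torch.randn(b, m, k, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(b, k, n, dtype=torch.bfloat16, device="cuda")
    ref = torch.bmm(x, w)
    out = fp8_bmm(x, w)
    return torch.allclose(out.to(ref.dtype), ref, rtol=rtol, atol=atol)
```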
Performance test on DeepSeek-R1 with full-cudagraph capture mode
Performance with and without this kernel

Pre-compilation for the kernel
If the kernel is not pre-compiled, the graph below shows the performance difference between the first run of the kernel (aiter_BMM_run1) and a subsequent run (aiter_BMM_run2). Adding a pre-compilation step during weight loading resolves this issue.
