[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends #19767
Conversation
Signed-off-by: Gregory Shtrasberg <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Summary of Changes
Hello @gshtras, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant optimization by enabling fused 8-bit floating-point (FP8) output quantization for attention operations within the V1 attention backends, specifically targeting ROCm platforms. The changes integrate FP8 conversion directly into the Triton kernels, allowing for potential improvements in memory efficiency and computational speed. The overall attention pipeline has been updated to support and utilize this new lower-precision output format.
Highlights
- FP8 Output Fusion: Introduced the capability for fused 8-bit floating-point (FP8) output quantization directly within the Triton attention kernels (`kernel_paged_attention_2d`, `_fwd_kernel`, `kernel_unified_attention_2d`, `reduce_segments`). This involves adding `out_scale` and `USE_FP8` parameters to these kernels and implementing the scaling and clamping logic in their epilogues (see the sketch after this list).
- API Integration: Updated the Python attention wrappers (`chunked_prefill_paged_decode`, `context_attention_fwd`, `unified_attention`) to accept an `output_scale` parameter, which is then passed down to the Triton kernels to enable or disable FP8 output based on its presence.
- Backend Support: Modified `TritonAttentionImpl` to declare support for fused output quantization via a new `fused_output_quant_supported` method and removed the `NotImplementedError` check, allowing `output_scale` to be passed through the attention pipeline.
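To make the "scaling and clamping in the epilogue" concrete, here is a minimal standalone sketch of the pattern. This is not the vLLM kernel code itself: the kernel name, the launch harness, and the choice of the ROCm `float8_e4m3fnuz` dtype are illustrative assumptions, while the `out_scale`/`USE_FP8` parameter names and the divide-then-clamp epilogue mirror what this PR adds.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def epilogue_store(
    acc_ptr,       # fp32 accumulator (stand-in for the attention output accumulator)
    out_ptr,       # output buffer: fp8 when USE_FP8, otherwise fp16/bf16
    out_scale,     # pointer to a single fp32 output scale
    n_elements,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
    USE_FP8: tl.constexpr,
    BLOCK: tl.constexpr,
):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    acc = tl.load(acc_ptr + offs, mask=mask, other=0.0)
    if USE_FP8:
        # Fused output quantization: divide by the scale, then clamp to the
        # representable FP8 range before the store casts to the FP8 dtype.
        acc = acc / tl.load(out_scale)
        acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
    tl.store(out_ptr + offs, acc.to(out_ptr.dtype.element_ty), mask=mask)


if __name__ == "__main__":
    fp8 = torch.float8_e4m3fnuz  # ROCm FP8 dtype, chosen only for illustration
    info = torch.finfo(fp8)
    acc = torch.randn(1024, device="cuda", dtype=torch.float32)
    out = torch.empty_like(acc, dtype=fp8)
    scale = torch.tensor([0.5], device="cuda", dtype=torch.float32)
    grid = (triton.cdiv(acc.numel(), 256), )
    epilogue_store[grid](acc, out, scale, acc.numel(),
                         FP8_MIN=info.min, FP8_MAX=info.max,
                         USE_FP8=True, BLOCK=256)
```

Skipping the separate quantization kernel and writing FP8 directly from the attention epilogue is what saves the extra memory round-trip.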
Code Review
This pull request introduces FP8 output fusion for V1 attention backends, which is a valuable feature for performance and memory optimization. The changes primarily involve adding out_scale parameters and FP8-specific logic (scaling and clamping) to various Triton kernels and their calling functions.
Key observations:
- The core FP8 logic seems correctly implemented in the kernels.
- The use of `query.dtype` for intermediate buffers like `tmp_output` in `chunked_prefill_paged_decode.py` is a good choice for maintaining precision.
- The `fused_output_quant_supported` method in `TritonAttentionImpl` is currently broad; it might need refinement later if specific FP8 configurations are not universally supported by this backend.
Suggestions for improvement:
- PR Description: The pull request description is currently a template. Please fill it out with the purpose, test plan, and test results to provide context for reviewers and future reference. This is especially important for a feature like FP8 fusion which can have numerical implications.
- Code Duplication: The FP8 output scaling and clamping logic (`acc = acc / tl.load(out_scale); acc = tl.clamp(acc, FP8_MIN, FP8_MAX)`) is repeated in several Triton kernels (`kernel_paged_attention_2d`, `_fwd_kernel`, `kernel_unified_attention_2d`, `reduce_segments`). For better maintainability, consider refactoring this common logic into a shared Triton JIT utility function if feasible. For example:

```python
@triton.jit
def scale_and_clamp_fp8(acc, out_scale_ptr, fp8_min, fp8_max):
    scaled_acc = acc / tl.load(out_scale_ptr)
    return tl.clamp(scaled_acc, fp8_min, fp8_max)
```

  This is a medium-severity suggestion for future maintainability; the current approach is acceptable for this PR.
Overall, the changes look reasonable for enabling FP8 output fusion. Thorough testing will be crucial to validate correctness and performance.
Signed-off-by: Gregory Shtrasberg <[email protected]>
Force-pushed from c27a54b to 9417465
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
…ized with .zeros from vllm-project#19784 Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
ProExpertProg left a comment
I don't think non-LLM tests (with custom TestModels) need to use the check function; this is only for LLM-based tests that have to do the check during compilation.
Signed-off-by: Luka Govedič <[email protected]> Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
ProExpertProg left a comment
A couple of minor notes! Glad we don't have to do the complicated model loading anymore, and great job reusing the existing test. And we will remove the V0 test soon as well!
| "model_name, model_class", | ||
| CUDA_MODELS if current_platform.is_cuda() else ROCM_MODELS) | ||
| @pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if | ||
| current_platform.is_cuda() else [_Backend.ROCM_FLASH]) |
Don't we want to test the Triton backend as well?
It is tested with the split_attention parameter; the two approaches share the same backend class.
I see, this actually dispatches to the Triton backend. We should clean up the attention backend selection logic on ROCm.
Also, I just remembered: we could test the Triton backend on CUDA as well; it would run in CI automatically, which would be nice. Could you add the Triton backend to the list of CUDA backends?
We can add the Triton backend here as a follow-up, roughly as sketched below.
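For reference, a sketch of what that follow-up might look like. The Triton backend enum name and the import paths are assumptions for illustration, not taken from this PR; only `_Backend.FLASHINFER`, `_Backend.ROCM_FLASH`, and `current_platform.is_cuda()` appear in the test under review.

```python
import pytest

from vllm.platforms import current_platform
from vllm.platforms.interface import _Backend  # import path assumed


# Hypothetical follow-up: also exercise the Triton backend on CUDA so the
# fusion test runs in standard CI. _Backend.TRITON_ATTN_VLLM_V1 is an assumed
# enum member name standing in for whatever the selector actually exposes.
@pytest.mark.parametrize(
    "backend",
    [_Backend.FLASHINFER, _Backend.TRITON_ATTN_VLLM_V1]
    if current_platform.is_cuda() else [_Backend.ROCM_FLASH])
def test_attention_fusion(backend):
    ...
```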
```python
self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device)
```
Why is this necessary? Where would ROCm actually do this? Because I think it might break the Blackwell FI?
Using on-CPU tensors in the reshape_and_cache kernel causes a crash. In production the default device is set to CUDA before the tensors are created, but not in the test.
Why not just set the default device at the start of the test then?
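A minimal sketch of that suggestion (the test name and body are illustrative, not the actual test touched by this PR):

```python
import torch


def test_attn_fp8_fusion():  # hypothetical test name
    # Creating tensors on the GPU by default mirrors production, where the
    # default device is already CUDA before the k/v scale tensors are built,
    # so reshape_and_cache never sees CPU tensors.
    torch.set_default_device("cuda")
    ...
```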
… empty tensor creation on the fusion pattern Signed-off-by: Gregory Shtrasberg <[email protected]>
ProExpertProg left a comment
Looks good apart from the remaining comments. Is it possible to see unit tests running in AMD CI somewhere?
Signed-off-by: Gregory Shtrasberg <[email protected]>
This pull request has merge conflicts that must be resolved before it can be merged.
Verified this locally, until we have AMD tests running again for PRs.
Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
…ttention backends (vllm-project#19767) Signed-off-by: Gregory Shtrasberg <[email protected]> Signed-off-by: Luka Govedič <[email protected]> Co-authored-by: Luka Govedič <[email protected]> Co-authored-by: Luka Govedič <[email protected]>
…ttention backends (vllm-project#19767) Signed-off-by: Gregory Shtrasberg <[email protected]> Signed-off-by: Luka Govedič <[email protected]> Co-authored-by: Luka Govedič <[email protected]> Co-authored-by: Luka Govedič <[email protected]> Signed-off-by: xuebwang-amd <[email protected]>
Essential Elements of an Effective PR Description Checklist
- supported_models.md and examples for a new model.

Purpose
An extension of #16756 for V1 unified attention (and its fallback split attention) backend.
Requires #19158 (full graph capture for this backend) to actually perform the fusion.
Fixes the fusion path to support a torch.zeros-initialized output tensor (it used to be torch.empty before #19784).
Test Plan
To enable the feature in V1, full CUDA graph capture is required:
-O '{"pass_config":{"enable_attn_fusion":true,"enable_noop":true},"full_cuda_graph":true}'
Test Result
Graph before fusion:
Graph after fusion
Performance
On MI300X
Without "enable_attn_fusion":true -> 17.5s
With "enable_attn_fusion":true -> 16.5s
Co-authored-by: Luka Govedič <[email protected]>