
Conversation

@gshtras
Collaborator

@gshtras gshtras commented Jun 17, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

An extension of #16756 for the V1 unified attention backend (and its split-attention fallback).
Requires #19158 (full graph capture for this backend) to actually perform the fusion.

Fixes the fusion path to support a torch.zeros-initialized output tensor (it was torch.empty before #19784).
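
For context, a minimal plain-PyTorch sketch (not vLLM code) of why that change matters to the fusion pattern:

    import torch

    # The attention output buffer used to be allocated with torch.empty, which
    # appears in the compiled graph as aten.empty.memory_format; since #19784
    # it is allocated with torch.zeros, which lowers to the zero-filled
    # aten.full.default node (the "full_default" node in the graphs below).
    # The fusion pattern now has to match that node as the attention output
    # buffer as well.
    num_tokens, num_heads, head_size = 16, 32, 128
    out_old = torch.empty(num_tokens, num_heads, head_size, dtype=torch.bfloat16)
    out_new = torch.zeros(num_tokens, num_heads, head_size, dtype=torch.bfloat16)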

Test Plan

To enable the feature in V1, full CUDA graph capture is required:
-O '{"pass_config":{"enable_attn_fusion":true,"enable_noop":true},"full_cuda_graph":true}'

Test Result

Graph before fusion:

     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:228 in forward, code: value = value.view(-1, self.num_kv_heads, self.head_size)
    view_9: "bf16[s0, 8, 128]" = torch.ops.aten.reshape.default(getitem_4, [-1, 8, 128]);  getitem_4 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:224 in forward, code: output = output.view(-1, self.num_heads, self.head_size)
    full_default: "bf16[s0, 32, 128]" = torch.ops.aten.full.default([arg1_1, 32, 128], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:243 in forward, code: torch.ops.vllm.unified_attention_with_output(
    auto_functionalized_1 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.unified_attention_with_output.default, query = cat, key = cat_1, value = view_9, output = full_default, layer_name = 'model.layers.0.self_attn.attn', output_scale = None);  cat = cat_1 = view_9 = full_default = None
    getitem_12: "bf16[s0, 32, 128]" = auto_functionalized_1[1];  auto_functionalized_1 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_1: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1280 in scaled_fp8_quant, code: torch.ops._C.static_scaled_fp8_quant(output, input, scale)
    view_16: "bf16[s0, 4096]" = torch.ops.aten.reshape.default(getitem_12, [-1, 4096]);  getitem_12 = None
    auto_functionalized_2 = torch.ops.higher_order.auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, result = empty_1, input = view_16, scale = arg9_1);  empty_1 = view_16 = None
    getitem_14: "f8e4m3fnuz[s0, 4096]" = auto_functionalized_2[1];  auto_functionalized_2 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_2: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/model_executor/layers/quantization/utils/w8a8_utils.py:165 in rocm_per_tensor_w8a8_scaled_mm, code: output = torch._scaled_mm(qinput,
    _scaled_mm_1: "bf16[s0, 4096]" = torch.ops.aten._scaled_mm.default(getitem_14, arg10_1, arg9_1, arg11_1, None, None, torch.bfloat16);  getitem_14 = arg10_1 = arg9_1 = arg11_1 = None

Graph after fusion:

     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:228 in forward, code: value = value.view(-1, self.num_kv_heads, self.head_size)
    view_9: "bf16[s0, 8, 128]" = torch.ops.aten.reshape.default(getitem_4, [-1, 8, 128]);  getitem_4 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/attention/layer.py:224 in forward, code: output = output.view(-1, self.num_heads, self.head_size)
    full_default: "bf16[s0, 32, 128]" = torch.ops.aten.full.default([arg1_1, 32, 128], 0.0, dtype = torch.bfloat16, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False);  full_default = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_1: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
    # No stacktrace found for following nodes
    reshape_default_62: "f8e4m3fnuz[s0, 32, 128]" = torch.ops.aten.reshape.default(empty_1, [-1, 32, 128]);  empty_1 = None
    auto_functionalized_191 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.unified_attention_with_output.default, query = cat, key = cat_1, value = view_9, output = reshape_default_62, layer_name = 'model.layers.0.self_attn.attn', output_scale = arg9_1);  cat = cat_1 = view_9 = reshape_default_62 = None
    getitem_639: "f8e4m3fnuz[s0, 32, 128]" = auto_functionalized_191[1];  auto_functionalized_191 = None
    reshape_default_63: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.reshape.default(getitem_639, [-1, 4096]);  getitem_639 = None
    
     # File: /projects/ROCm/vllm_upstream/vllm/_custom_ops.py:1261 in scaled_fp8_quant, code: output = torch.empty(shape, device=input.device, dtype=out_dtype)
    empty_2: "f8e4m3fnuz[s0, 4096]" = torch.ops.aten.empty.memory_format([arg1_1, 4096], dtype = torch.float8_e4m3fnuz, device = device(type='cuda', index=0), pin_memory = False)
    
     # File: /projects/ROCm/vllm_upstream/vllm/model_executor/layers/quantization/utils/w8a8_utils.py:165 in rocm_per_tensor_w8a8_scaled_mm, code: output = torch._scaled_mm(qinput,
    _scaled_mm_1: "bf16[s0, 4096]" = torch.ops.aten._scaled_mm.default(reshape_default_63, arg10_1, arg9_1, arg11_1, None, None, torch.bfloat16);  reshape_default_63 = arg10_1 = arg9_1 = arg11_1 = None

Performance

On MI300X

VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 python3 benchmarks/benchmark_latency.py --model amd/Llama-3.1-70B-Instruct-FP8-KV --batch-size 1 --input-len 128 --output-len 2048 --tensor-parallel-size 8 --compilation-config '{"full_cuda_graph": true,"custom_ops":["+rms_norm","+silu_and_mul"],"pass_config":{"enable_noop":true,"enable_fusion":true,"enable_attn_fusion":true}}' --dtype float16 --num-iters-warmup 3 --num-iters 5 --trust-remote-code

Without "enable_attn_fusion":true -> 17.5s
With "enable_attn_fusion":true -> 16.5s

Co-authored-by: Luka Govedič [email protected]

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @gshtras, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization by enabling fused 8-bit floating-point (FP8) output quantization for attention operations within the V1 attention backends, specifically targeting ROCm platforms. The changes integrate FP8 conversion directly into the Triton kernels, allowing for potential improvements in memory efficiency and computational speed. The overall attention pipeline has been updated to support and utilize this new lower-precision output format.

Highlights

  • FP8 Output Fusion: Introduced the capability for fused 8-bit floating-point (FP8) output quantization directly within the Triton attention kernels (kernel_paged_attention_2d, _fwd_kernel, kernel_unified_attention_2d, reduce_segments). This involves adding out_scale and USE_FP8 parameters to these kernels and implementing the scaling and clamping logic in their epilogues (a simplified epilogue sketch follows this list).
  • API Integration: Updated the Python attention wrappers (chunked_prefill_paged_decode, context_attention_fwd, unified_attention) to accept an output_scale parameter, which is then passed down to the Triton kernels to enable or disable FP8 output based on its presence.
  • Backend Support: Modified TritonAttentionImpl to declare support for fused output quantization via a new fused_output_quant_supported method and removed the NotImplementedError check, allowing output_scale to be passed through the attention pipeline.
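
As a rough illustration of the first highlight, a simplified Triton epilogue sketch (the helper name and parameter order are illustrative; out_scale, USE_FP8, and the FP8 min/max constants follow the summary above, and this is not the exact vLLM kernel code):

    import triton
    import triton.language as tl

    @triton.jit
    def _store_attn_output(acc, out_ptr, offsets, mask, out_scale,
                           USE_FP8: tl.constexpr,
                           FP8_MIN: tl.constexpr, FP8_MAX: tl.constexpr):
        # When an output scale is provided, quantize in the epilogue so the
        # kernel writes FP8 directly instead of a bf16 intermediate.
        if USE_FP8:
            acc = acc / tl.load(out_scale)
            acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
        tl.store(out_ptr + offsets, acc, mask=mask)
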
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the rocm (Related to AMD ROCm) and v1 labels Jun 17, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces FP8 output fusion for V1 attention backends, which is a valuable feature for performance and memory optimization. The changes primarily involve adding out_scale parameters and FP8-specific logic (scaling and clamping) to various Triton kernels and their calling functions.

Key observations:

  • The core FP8 logic seems correctly implemented in the kernels.
  • The use of query.dtype for intermediate buffers like tmp_output in chunked_prefill_paged_decode.py is a good choice for maintaining precision.
  • The fused_output_quant_supported method in TritonAttentionImpl is currently broad; it might need refinement later if specific FP8 configurations are not universally supported by this backend.

Suggestions for improvement:

  • PR Description: The pull request description is currently a template. Please fill it out with the purpose, test plan, and test results to provide context for reviewers and future reference. This is especially important for a feature like FP8 fusion which can have numerical implications.
  • Code Duplication: The FP8 output scaling and clamping logic (acc = acc / tl.load(out_scale); acc = tl.clamp(acc, FP8_MIN, FP8_MAX)) is repeated in several Triton kernels (kernel_paged_attention_2d, _fwd_kernel, kernel_unified_attention_2d, reduce_segments). For better maintainability, consider refactoring this common logic into a shared Triton JIT utility function if feasible. For example:
    @triton.jit
    def scale_and_clamp_fp8(acc, out_scale_ptr, fp8_min, fp8_max):
        scaled_acc = acc / tl.load(out_scale_ptr)
        return tl.clamp(scaled_acc, fp8_min, fp8_max)
    This is a medium-severity suggestion for future maintainability; the current approach is acceptable for this PR.

Overall, the changes look reasonable for enabling FP8 output fusion. Thorough testing will be crucial to validate correctness and performance.

Signed-off-by: Gregory Shtrasberg <[email protected]>
@gshtras gshtras force-pushed the attention_fusion_v1 branch from c27a54b to 9417465 on June 17, 2025 20:38
@gshtras gshtras marked this pull request as ready for review June 30, 2025 17:45
Collaborator

@ProExpertProg ProExpertProg left a comment


I don't think non-LLM tests (with custom TestModels) need to use the check function; it is only for LLM-based tests that have to do the check during compilation.

ProExpertProg and others added 4 commits June 30, 2025 22:22
Signed-off-by: Luka Govedič <[email protected]>

Signed-off-by: Gregory Shtrasberg <[email protected]>
@gshtras gshtras changed the title from "[Draft][torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends" to "[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends" Jul 2, 2025
@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 2, 2025
@gshtras gshtras requested a review from yewentao256 as a code owner September 3, 2025 22:11
@mergify mergify bot removed the needs-rebase label Sep 3, 2025
Signed-off-by: Gregory Shtrasberg <[email protected]>
Collaborator

@ProExpertProg ProExpertProg left a comment


A couple of minor notes! Glad we don't have to do the complicated model loading anymore, and great job reusing the existing test. And we will remove the V0 test soon as well!

"model_name, model_class",
CUDA_MODELS if current_platform.is_cuda() else ROCM_MODELS)
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if
current_platform.is_cuda() else [_Backend.ROCM_FLASH])
Collaborator


Don't we want to test the triton backend as well?

Collaborator Author


It is tested with the split_attention parameter. The two approaches share the same backend class.

Collaborator


I see, this actually dispatches to the Triton backend. We should clean up the attention backend selection logic on ROCm.

Collaborator


Also, I just remembered: we could test the Triton backend on CUDA as well; it would run in CI automatically, which would be nice. Could you add the Triton backend to the list of CUDA backends?

Collaborator


We can add the Triton backend here as a follow-up.

Comment on lines +182 to +183
self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device)
Collaborator


Why is this necessary? Where would ROCm actually do this? Because I think it might break the Blackwell FI?

Collaborator Author


Using on-CPU tensors in the reshape_and_cache kernel causes a crash. In production the default device is set to CUDA before the tensors are created, but not in the test.

Collaborator


Why not just set the default device at the start of the test then?
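
A minimal illustration of that suggestion (a pytest-style fixture; the name default_to_cuda is hypothetical and this is not code from the PR):

    import pytest
    import torch

    @pytest.fixture(autouse=True)
    def default_to_cuda():
        # Make tensor allocations default to the GPU before the attention
        # layers (and their _k_scale/_v_scale tensors) are constructed,
        # mirroring what production does.
        with torch.device("cuda"):
            yield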

Collaborator

@ProExpertProg ProExpertProg left a comment


Looks good apart from the remaining comments. Is it possible to see unit tests running in AMD CI somewhere?

@mergify

mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @gshtras.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 9, 2025
@gshtras
Collaborator Author

gshtras commented Sep 9, 2025

Looks good apart from the remaining comments. Is it possible to see unit tests running in AMD CI somewhere?

Verified this locally for now, until we have AMD tests running again for PRs.

@mergify mergify bot removed the needs-rebase label Sep 9, 2025
@ProExpertProg ProExpertProg moved this from In progress to In review in torch.compile integration Sep 9, 2025
Signed-off-by: Luka Govedič <[email protected]>
@ProExpertProg ProExpertProg enabled auto-merge (squash) September 10, 2025 16:02
@simon-mo simon-mo disabled auto-merge September 10, 2025 20:59
@simon-mo simon-mo merged commit 9a16130 into vllm-project:main Sep 10, 2025
47 of 49 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in torch.compile integration Sep 10, 2025
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
…ttention backends (vllm-project#19767)

Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…ttention backends (vllm-project#19767)

Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…ttention backends (vllm-project#19767)

Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…ttention backends (vllm-project#19767)

Signed-off-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>

Labels

ready (ONLY add when PR is ready to merge/full CI is needed), rocm (Related to AMD ROCm), torch.compile, v1

Projects

Status: Done
