Skip to content

Conversation

charlifu
Copy link
Contributor

@charlifu charlifu commented Jun 4, 2025

This PR adds full graph capture for TritonAttentionBackend.

  • add exemption for TritonAttentionBackend in model runner.
  • Avoid requirement for aot_scheduling in metadata build function by overwirte the build function and init function.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @charlifu, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team,

Gemini or gemini-code-assist here to provide a summary of this pull request. This PR aims to enable full CUDA graph capture support specifically for the TritonAttentionBackend. This is achieved by making necessary modifications to the attention metadata builder for Triton and updating the model runner to allow Triton when full graph capture is enabled. The changes involve overriding the __init__ and build methods in the TritonAttentionMetadataBuilder to handle the specific requirements for graph capture, such as slot mapping initialization and incorporating logic for local and cascade attention within the metadata structure. Additionally, the validation in the GPUModelRunner is updated to explicitly permit TritonAttentionBackend alongside FlashAttention v3 when full_cuda_graph is configured.

Highlights

  • Full CUDA Graph Support: Adds support for full CUDA graph capture when using the TritonAttentionBackend, which can improve performance by reducing CPU overhead.
  • TritonAttentionMetadataBuilder Overhaul: The __init__ and build methods of TritonAttentionMetadataBuilder are significantly modified to correctly generate attention metadata compatible with full graph capture, including handling slot mapping, local attention, and cascade attention.
  • Model Runner Update: The GPUModelRunner is updated to recognize TritonAttentionBackend as a valid option when the full_cuda_graph compilation flag is enabled.

Changelog

  • vllm/v1/attention/backends/triton_attn.py
    • Overrode __init__ method in TritonAttentionMetadataBuilder to store necessary configuration directly.
    • Overrode build method in TritonAttentionMetadataBuilder to construct FlashAttentionMetadata for full graph capture, including initializing slot mapping with -1 for unused entries (lines 66-68), and incorporating logic for local attention (lines 72-97) and cascade attention (lines 99-116).
  • vllm/v1/worker/gpu_model_runner.py
    • Modified the initialize_attn_backend method to allow TritonAttentionBackend when full_cuda_graph is enabled, in addition to FlashAttention v3 (lines 2048-2050).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.


Tokens flow like waves,
Graph captures the kernel's dance,
Speed on ROCm wakes.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

github-actions bot commented Jun 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jun 4, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request successfully adds full graph capture support for TritonAttentionBackend. The changes in gpu_model_runner.py correctly exempt TritonAttentionBackend from the FA3-only restriction for full CUDA graphs. The core logic in triton_attn.py achieves its goal by overriding __init__ and build in TritonAttentionMetadataBuilder to avoid dependencies on AOT scheduling, effectively passing None for AOT-related scheduler metadata.

The main area for potential improvement is the significant code duplication introduced in TritonAttentionMetadataBuilder.build, which mirrors much of FlashAttentionMetadataBuilder.build. Addressing this could enhance long-term maintainability.

Summary of Findings

  • Code Duplication in TritonAttentionMetadataBuilder: The __init__ and build methods in TritonAttentionMetadataBuilder are largely duplicated from FlashAttentionMetadataBuilder to avoid AOT scheduling dependencies. This could lead to maintenance issues if the base class logic changes. Consider refactoring to reduce duplication or adding a TODO for future improvement.

Merge Readiness

The pull request effectively enables full graph capture for TritonAttentionBackend. The changes are functional and address the described objectives. However, there's a medium-severity concern regarding code duplication in TritonAttentionMetadataBuilder, which could impact future maintainability. It would be beneficial to discuss potential refactoring strategies for this duplication. I am unable to approve pull requests, but I recommend addressing or acknowledging the code duplication concern before merging. Other reviewers should make the final decision.

Comment on lines 32 to 134

def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
super().__init__(runner, kv_cache_spec, block_table)
model_config = runner.model_config
compilation_config = runner.vllm_config.compilation_config

self.runner = runner
self.use_full_cuda_graph = compilation_config.full_cuda_graph
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table

# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
self.aot_schedule = False

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]

block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table.slot_mapping[num_actual_tokens:].fill_(-1)

slot_mapping = block_table.slot_mapping[:num_actual_tokens]

# for local attention
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
self.runner.device, non_blocking=True)
local_max_query_len = seqlens_q_local_np.max()
local_max_seq_len = virt_k_seqlens_np.max()

local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=None,
)

use_cascade = common_prefix_len > 0

if use_cascade:
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=self.runner.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.runner.device)
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
self.runner.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
prefix_scheduler_metadata = None

attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
local_attn_metadata=local_attn_metadata,
prefix_scheduler_metadata=prefix_scheduler_metadata,
)
return attn_metadata
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The __init__ and build methods in TritonAttentionMetadataBuilder are substantially similar to those in the parent FlashAttentionMetadataBuilder. While this achieves the goal of avoiding the AOT scheduling dependency from FlashAttention for Triton, it introduces significant code duplication. For instance, the logic for handling local_attn_metadata and cascade attention appears to be largely identical.

This duplication could pose a maintenance challenge:

  • If common logic within FlashAttentionMetadataBuilder.build (e.g., handling of local attention, cascade attention, slot mapping, block table) is updated or bug-fixed, these changes would need to be manually mirrored in TritonAttentionMetadataBuilder.build.
  • It increases the overall codebase size with redundant logic.

Could we explore ways to reduce this duplication for better long-term maintainability? For example:

  1. Could FlashAttentionMetadataBuilder.build be refactored to make the AOT scheduling part more modular or optional? Perhaps by passing a scheduler_fn or by having its internal schedule helper function return None if self.aot_schedule is False, and then the main build method handles None scheduler metadata appropriately.
  2. Could common sections (like local attention setup, cascade setup) be extracted into protected helper methods in FlashAttentionMetadataBuilder that TritonAttentionMetadataBuilder could then call, overriding only the parts related to AOT scheduling?

This would allow TritonAttentionMetadataBuilder to inherit more of the common logic while still achieving its specific goal. If this level of refactoring is out of scope for this PR, adding a TODO to track this potential future improvement would be valuable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think honestly we should further decouple the builders so this is fine.

gshtras added a commit to ROCm/vllm that referenced this pull request Jun 5, 2025
@houseroad houseroad added the rocm Related to AMD ROCm label Jun 8, 2025
@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 9, 2025
@gshtras gshtras requested a review from ProExpertProg June 9, 2025 16:37
Copy link
Collaborator

@houseroad houseroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the support.

Could you list the test plan and results? Also could you add some unittest to cover this case?

Copy link
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: we will have to align this PR with #18581 to make sure we have a consistent AttentionMetadataBuilder API for building for full CUDA graphs. Personally I like using a for_cudagraph_capture: bool = False flag to reduce the amount of different build functions

@ProExpertProg
Copy link
Collaborator

I'll respond in more detail tomorrow on my PR but I think a big downside is touching all of the build functions. I also think this is fundamentally different from the old way with direct met data passthrough construction - here the intention of the two different build methods is clearly different, and one can call the other.

Happy to discuss more tomorrow, sorry for the late response. But yeah let's figure this out before we merge either PR.

@mergify mergify bot added the needs-rebase label Jun 13, 2025
@ProExpertProg
Copy link
Collaborator

@charlifu #18581 merged, lmk if you need any help merging in the changes with the refactor, hopefully not bad at all

auto-merge was automatically disabled June 13, 2025 19:23

Head branch was pushed to by a user without write access

@mergify mergify bot removed the needs-rebase label Jun 13, 2025
Copy link
Collaborator

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good!

# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
self.aot_sliding_window: Optional[tuple[int, int]] = None
self.aot_schedule = False
Copy link
Collaborator

@LucasWilkinson LucasWilkinson Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Seems like aot_schedule and aot_sliding_window are not used? can these be removed?

runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
Copy link
Collaborator

@LucasWilkinson LucasWilkinson Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: seems like headdim is not used, can this be removed?

self.use_full_cuda_graph = compilation_config.full_cuda_graph
self.num_heads_q = model_config.get_num_attention_heads(
runner.parallel_config)
self.num_heads_kv = model_config.get_num_kv_heads(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: seems like num_heads_q and num_heads_kv are not used can these be removed?

compilation_config = runner.vllm_config.compilation_config

self.runner = runner
self.use_full_cuda_graph = compilation_config.full_cuda_graph
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: seems like use_full_cuda_graph is unused can this be removed?

Copy link
Collaborator

@LucasWilkinson LucasWilkinson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall this seems very close to ready; There seems to be a few unused variables/attributes added, you please audit for unused variables please

Signed-off-by: charlifu <[email protected]>
@charlifu
Copy link
Contributor Author

overall this seems very close to ready; There seems to be a few unused variables/attributes added, you please audit for unused variables please

Unused vars removed. Not sure if we might need them in the future. But we can always add them back.

@charlifu charlifu requested a review from LucasWilkinson June 17, 2025 12:53
gshtras added a commit to ROCm/vllm that referenced this pull request Jun 17, 2025
@LucasWilkinson LucasWilkinson merged commit a44b1c9 into vllm-project:main Jun 17, 2025
68 checks passed
eliotwang pushed a commit to eliotwang/vllm that referenced this pull request Sep 1, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants