Conversation

WoosukKwon
Collaborator

This PR implements the "simple CUDA graph" option, which enables using CUDA graphs without torch.compile. It is not the default, but users can enable it when startup time is critical.
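
For context, here is a minimal usage sketch of how a user might turn the option on, assuming the simple_cuda_graph flag added to CompilationConfig in this PR is exposed through the engine's usual compilation_config argument. The model name and exact wiring are illustrative, not taken from this PR.

# Hypothetical usage sketch: enable the simple CUDA graph option via
# CompilationConfig. The flag name follows this PR's description; the final
# API may differ.
from vllm import LLM
from vllm.config import CompilationConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model only
    enforce_eager=False,  # CUDA graphs require eager-only mode to be off
    compilation_config=CompilationConfig(simple_cuda_graph=True),
)
outputs = llm.generate(["Hello, my name is"])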

@github-actions

github-actions bot commented Jul 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist bot left a comment


Summary of Changes

Hello @WoosukKwon, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new, simplified approach to leveraging CUDA graphs for performance optimization, specifically designed to bypass torch.compile. The primary goal is to significantly reduce application startup times by enabling direct CUDA graph capture and replay for model forward passes, offering users an alternative compilation strategy for critical scenarios.

Highlights

  • New CUDA Graph Option: Introduced a simple_cuda_graph option within CompilationConfig that enables the use of full CUDA graphs without relying on torch.compile. This is intended to reduce startup time.
  • CUDA Graph Wrapper Implementation: Implemented a new CudaGraphWrapper class responsible for managing CUDA graph capture, warmup, and replay (an illustrative sketch of this pattern follows the list). This wrapper handles different batch sizes and ensures proper execution flow for graphed operations.
  • Conditional Compilation Logic: Modified the _support_torch_compile decorator and its associated __init__ and __call__ methods to conditionally initialize and invoke the CudaGraphWrapper when the simple_cuda_graph option is enabled and torch.compile is not active.
  • Configuration and Integration: Updated CompilationConfig and VllmConfig to correctly set compilation levels and CUDA graph usage based on the simple_cuda_graph flag. The GpuModelRunner's CUDA graph activation logic was also refined to incorporate this new option.
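
To make the capture/replay flow concrete, below is a minimal sketch of the pattern such a wrapper typically follows. This is illustrative only: the class name, the single-tensor call signature, and the warmup count are assumptions, not the actual CudaGraphWrapper API from this PR.

import torch

class SimpleCudaGraphWrapper:
    """Illustrative sketch: capture one CUDA graph per batch size, then replay it."""

    def __init__(self, runnable, num_warmup_iters: int = 3):
        self.runnable = runnable
        self.num_warmup_iters = num_warmup_iters
        self.graphs = {}          # batch size -> torch.cuda.CUDAGraph
        self.static_inputs = {}   # batch size -> input buffer baked into the graph
        self.static_outputs = {}  # batch size -> output buffer written by the graph

    def capture(self, x: torch.Tensor) -> None:
        bs = x.shape[0]
        # Warm up on a side stream so lazy initialization (cuBLAS workspaces,
        # autotuning, etc.) does not get baked into the capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(self.num_warmup_iters):
                self.runnable(x)
        torch.cuda.current_stream().wait_stream(s)

        static_x = x.clone()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = self.runnable(static_x)
        self.graphs[bs] = graph
        self.static_inputs[bs] = static_x
        self.static_outputs[bs] = static_out

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        bs = x.shape[0]
        if bs not in self.graphs:
            # Fall back to eager execution for batch sizes that were never captured.
            return self.runnable(x)
        # Copy the new inputs into the captured buffer, replay the graph, and
        # return the captured output buffer (overwritten by the replay).
        self.static_inputs[bs].copy_(x)
        self.graphs[bs].replay()
        return self.static_outputs[bs]

The real wrapper in this PR additionally has to handle padded batch sizes and model-specific inputs, but the capture-then-replay structure is the same idea.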

@mergify bot added the v1 label Jul 1, 2025
Signed-off-by: Woosuk Kwon <[email protected]>
Contributor

@gemini-code-assist bot left a comment


Code Review

This PR implements the simple CUDA graph option, which enables using CUDA graphs without torch.compile. The changes look promising and cover the necessary areas from configuration to execution. There are a few critical issues related to incorrect configuration parameter access that would lead to runtime errors. Additionally, I've provided some suggestions to improve code clarity and maintainability.

not self.model_config.enforce_eager:
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
# is set to True, full CUDA graphs will be used.
self.use_cudagraph = True

critical

The use_cudagraph attribute belongs to compilation_config, not VllmConfig directly. This line incorrectly creates a new attribute on the VllmConfig instance. It should be set on self.compilation_config.

Suggested change
self.use_cudagraph = True
self.compilation_config.use_cudagraph = True

Comment on lines +144 to +146
if CudaGraphWrapper in cls.__bases__:
# support decorating multiple times
return cls

medium

The checks for TorchCompileWrapperWithCustomDispatcher and CudaGraphWrapper in the base classes are separate but have identical logic. This can be combined into a single if statement using an or operator to improve readability and reduce code duplication.

Suggested change
if CudaGraphWrapper in cls.__bases__:
# support decorating multiple times
return cls
if (TorchCompileWrapperWithCustomDispatcher in cls.__bases__
or CudaGraphWrapper in cls.__bases__):
# support decorating multiple times
return cls

Comment on lines +217 to +224
self.use_cuda_graph = True
if self.model_config.enforce_eager:
self.use_cuda_graph = False
if not self.compilation_config.use_cudagraph:
self.use_cuda_graph = False
if (self.compilation_config.level != CompilationLevel.PIECEWISE
and not self.compilation_config.simple_cuda_graph):
self.use_cuda_graph = False

medium

The logic for determining self.use_cuda_graph is implemented with a series of if statements that conditionally set it to False. This can be simplified into a single, more readable boolean expression.

Suggested change
self.use_cuda_graph = True
if self.model_config.enforce_eager:
self.use_cuda_graph = False
if not self.compilation_config.use_cudagraph:
self.use_cuda_graph = False
if (self.compilation_config.level != CompilationLevel.PIECEWISE
and not self.compilation_config.simple_cuda_graph):
self.use_cuda_graph = False
self.use_cuda_graph = (
not self.model_config.enforce_eager
and self.compilation_config.use_cudagraph and
(self.compilation_config.level == CompilationLevel.PIECEWISE
or self.compilation_config.simple_cuda_graph))

@WoosukKwon
Collaborator Author

cc @SageMoore @zou3519 @youkaichao

@zou3519
Collaborator

zou3519 commented Jul 1, 2025

cc @BoyuanFeng

@zou3519
Collaborator

zou3519 commented Jul 1, 2025

@WoosukKwon do you want to ship this or is this more of a proof-of-concept? On the torch.compile side Boyuan and I will dig into why simple cudagraphs are faster and see if there's anything we can improve about the existing situation.

@WoosukKwon
Collaborator Author

@zou3519 I was thinking that this would be a useful feature for some users, as long as there's a gap in the startup time.

@lionelvillard
Contributor

@WoosukKwon how much faster (to build) is the simple CUDA graph compared to the existing solution?

@BoyuanFeng
Contributor

BoyuanFeng commented Jul 2, 2025

@lionelvillard I did a benchmark on the start time of LLAMA-3.1-70B. Simple CUDAGraph gives a slightly faster CUDAGraph capture time than Full CUDAGraph. Most of the saving comes from skipping compile.

[image: startup-time benchmark chart]

@zou3519
Collaborator

zou3519 commented Jul 2, 2025

Btw, @WoosukKwon, I think #20059 is also implementing a full cudagraph wrapper that can be used without torch.compile

@yinghai
Contributor

yinghai commented Jul 6, 2025

@BoyuanFeng what's the TP in the experiments that you ran?

@yinghai
Contributor

yinghai commented Jul 6, 2025

I think #20059 is also implementing a full cudagraph wrapper that can be used without torch.compile

I don't think that's the case. It relies on torch Dynamo to do the graph capturing, then decides whether to work on the whole graph module or on the split piecewise graph modules for the different CUDA graph captures. I think it's a quite specific case for kernels that support CUDA graphs well for decode but not for prefill. Anyway, it adds to the startup time rather than reducing it.

In a sense, #20059 is going in the opposite direction (making graph compilation more complicated) from this PR (making graph compilation simpler).

@yinghai
Contributor

yinghai commented Jul 6, 2025

Actually, after some revision I think the work in #20059 is now somewhat orthogonal to the work here. It might be good to see if the CUDA graph wrapper can be reused. But the work there is mainly to tell, at the forward-metadata preparation phase, whether the workload is decode-only and to hint the forward runner/attention backend about it. That forward runner can be a torch.compile wrapper or a plain CUDA graph wrapper. The work here is to allow a plain CUDA graph wrapper without involving the torch Dynamo part.
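
To illustrate the decode-only dispatch idea described above (a purely hypothetical sketch; none of these names come from #20059 or this PR):

# Hypothetical illustration: flag decode-only batches at metadata-preparation
# time so the runner can pick the full-graph or piecewise path.
from dataclasses import dataclass

@dataclass
class ForwardMetadata:
    num_prefill_tokens: int
    num_decode_tokens: int

    @property
    def is_decode_only(self) -> bool:
        return self.num_prefill_tokens == 0

def run_forward(metadata: ForwardMetadata, full_graph_runner, piecewise_runner, inputs):
    # Decode-only batches can use a fully captured CUDA graph; mixed
    # prefill/decode batches fall back to the piecewise path.
    runner = full_graph_runner if metadata.is_decode_only else piecewise_runner
    return runner(inputs)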

@ProExpertProg
Collaborator

@yinghai I've been helping shepherd #20059, and I agree that was the initial approach, but I've asked the author to re-architect the PR to include the "simple capture" that this PR is doing. Even if simple capture is not in that PR, the infrastructure from that PR can be reused for it.

@zou3519
Collaborator

zou3519 commented Jul 7, 2025

@BoyuanFeng what's the TP in the experiments that you ran?

@yinghai --tensor-parallel-size 8

@zou3519
Collaborator

zou3519 commented Jul 7, 2025

My question for this group is: let's assume that we get torch.compile warm start time down to close to zero. Do we still want a simple cudagraphs feature?

@ProExpertProg
Collaborator

ProExpertProg commented Jul 7, 2025

@zou3519 I believe yes; currently DBO requires turning off torch.compile. In the future, there might be other cases where torch.compile is not supported. Obviously we want to support torch.compile in all cases long-term, but I think it's good to have this as an orthogonal feature. #20059 should give us that in a simple way, though.

@yinghai
Contributor

yinghai commented Jul 7, 2025

let's assume that we get torch.compile warm start time down to close to zero. Do we still want a simple cudagraphs feature?

Yes please. Giving folks a choice would be good.

@yinghai
Contributor

yinghai commented Jul 7, 2025

@yinghai I've been helping shepherd #20059 and I agree that was the initial approach but I have asked the author to re-architect the PR to include "simple capture" that this PR is doing. Even if simple capture is not in that PR, the infrastructure from the PR can be reused for it.

Yeah, reading through the changes and discussion in that PR I realized it. Thanks a lot for that!
