19 changes: 16 additions & 3 deletions vllm/compilation/decorators.py
@@ -10,7 +10,8 @@
from torch._dynamo.symbolic_convert import InliningInstructionTranslator

from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.compilation.wrapper import (CudaGraphWrapper,
TorchCompileWrapperWithCustomDispatcher)
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
@@ -140,11 +141,15 @@ def _support_torch_compile(
if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
# support decorating multiple times
return cls
if CudaGraphWrapper in cls.__bases__:
# support decorating multiple times
return cls
Comment on lines +144 to +146
Severity: medium

The checks for TorchCompileWrapperWithCustomDispatcher and CudaGraphWrapper in the base classes are separate but have identical logic. This can be combined into a single if statement using an or operator to improve readability and reduce code duplication.

Suggested change
if CudaGraphWrapper in cls.__bases__:
# support decorating multiple times
return cls
if (TorchCompileWrapperWithCustomDispatcher in cls.__bases__
or CudaGraphWrapper in cls.__bases__):
# support decorating multiple times
return cls


# take care of method resolution order
# make sure super().__init__ is called on the base class
# other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,
CudaGraphWrapper)

old_init = cls.__init__

@@ -158,7 +163,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
if self.do_not_compile:
if vllm_config.compilation_config.simple_cuda_graph:
CudaGraphWrapper.__init__(self)
return

compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)
@@ -169,9 +177,14 @@ def __call__(self, *args, **kwargs):
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
if self.do_not_compile or torch.compiler.is_compiling():
if torch.compiler.is_compiling():
return self.forward(*args, **kwargs)

if self.do_not_compile:
if not self.vllm_config.compilation_config.simple_cuda_graph:
return self.forward(*args, **kwargs)
return self.forward_graph(*args, **kwargs)

# the first compilation needs to have dynamic shapes marked
if len(self.compiled_codes) < 1:
sig = inspect.signature(self.__class__.forward)
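Taken together, the decorator changes above give a decorated model three call-time paths: compiled dispatch, plain eager execution, and eager execution wrapped in a full CUDA graph. A condensed, self-contained sketch of that routing follows; it is a toy class, not the actual decorator, and `forward_graph` stands in for `CudaGraphWrapper.forward_graph` from the diff below.

```python
import torch


class ToyDispatch:
    """Toy mirror of the __call__ routing added in decorators.py."""

    def __init__(self, do_not_compile: bool, simple_cuda_graph: bool):
        self.do_not_compile = do_not_compile
        self.simple_cuda_graph = simple_cuda_graph

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2  # eager forward

    def forward_graph(self, x: torch.Tensor) -> torch.Tensor:
        # In vLLM this would capture or replay a full CUDA graph (see below).
        return self.forward(x)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if torch.compiler.is_compiling():
            # Already inside a compilation (e.g. TPU compiles in the runner).
            return self.forward(x)
        if self.do_not_compile:
            if not self.simple_cuda_graph:
                return self.forward(x)    # plain eager
            return self.forward_graph(x)  # full CUDA graph, no torch.compile
        return self.forward(x)            # compiled dispatch path in vLLM
```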
89 changes: 89 additions & 0 deletions vllm/compilation/wrapper.py
@@ -4,6 +4,7 @@
import os
import sys
from abc import abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from types import CodeType
from typing import Callable, Optional
@@ -133,3 +134,91 @@ def dispatch_to_code(self, index: int):
self.__class__.forward.__code__ = self.compiled_codes[index]
yield
self.__class__.forward.__code__ = self.original_code_object


class CudaGraphWrapper:

def __init__(self):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config

# configs
self.cudagraph_capture_sizes = set(
self.compilation_config.cudagraph_capture_sizes)
self.cudagraph_num_of_warmups = (
self.compilation_config.cudagraph_num_of_warmups)
assert self.compilation_config.simple_cuda_graph
assert self.compilation_config.full_cuda_graph

# states
# batch size -> graph
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle()
# batch size -> hidden states
self.hidden_states: dict[int, torch.Tensor] = {}
# batch size -> number of warmups
self.num_warmups: dict[int, int] = defaultdict(int)
# Special flag to handle the first memory profiling run.
self.first_run_finished = False

def capture_graph(self, *args, **kwargs) -> None:
batch_size = self._get_batch_size(*args, **kwargs)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
hidden_states = self.forward(*args, **kwargs)
self.hidden_states[batch_size] = hidden_states
self.graphs[batch_size] = graph

def forward_graph(self, *args, **kwargs) -> torch.Tensor:
if not self.first_run_finished:
# Memory profiling run.
self.first_run_finished = True
return self.forward(*args, **kwargs)

batch_size = self._get_batch_size(*args, **kwargs)
if batch_size not in self.cudagraph_capture_sizes:
# Run in eager mode.
return self.forward(*args, **kwargs)

if self.num_warmups[batch_size] < self.cudagraph_num_of_warmups:
# Warmup mode. Run in eager mode.
self.num_warmups[batch_size] += 1
return self.forward(*args, **kwargs)

if batch_size not in self.graphs:
# Capture the graph.
self.capture_graph(*args, **kwargs)
return self.hidden_states[batch_size]

# Run the graph and return the hidden states.
graph = self.graphs[batch_size]
graph.replay()
hidden_states = self.hidden_states[batch_size]
return hidden_states

@abstractmethod
def forward(self, *args, **kwargs):
...

def _get_batch_size(self, *args, **kwargs) -> int:
# NOTE(woosuk): Ensure that the keyword arguments here match those
# in the model's forward method signature.
input_ids = kwargs.get("input_ids")
if input_ids is not None:
return input_ids.shape[0]
input_embeds = kwargs.get("inputs_embeds")
if input_embeds is not None:
return input_embeds.shape[0]
intermediate_tensors = kwargs.get("intermediate_tensors")
if intermediate_tensors is not None:
return intermediate_tensors.shape[0]
# NOTE(woosuk): We don't use the `positions` tensor for batch size
# because its first dimension may not be the batch dimension for some
# models such as Qwen2.5-VL.
if len(args) > 0:
# For LoRA models, kwargs could be empty.
# FIXME(woosuk): This is a hack. We should find a more robust way
# to get the batch size.
return args[0].shape[0]
raise ValueError("No batch size found in arguments")
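
The wrapper above follows the standard PyTorch capture-and-replay pattern: a captured graph reads from and writes to fixed buffers, so replaying it overwrites the same output tensor that was produced at capture time. That is why forward_graph returns the cached hidden_states[batch_size] after graph.replay(); the model's persistent input buffers are assumed to have been refilled by the model runner before the call. Below is a minimal standalone illustration of the pattern, outside vLLM; the static_input/static_output names are illustrative and a CUDA device is assumed.

```python
import torch

# A toy model standing in for the eager forward that gets captured.
model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream before capture (as recommended by PyTorch docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture: the output produced here becomes the fixed output buffer.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy fresh data into the fixed input, replay, read the fixed output.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
result = static_output.clone()  # static_output now holds the new result
```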
19 changes: 17 additions & 2 deletions vllm/config.py
@@ -4040,6 +4040,9 @@ class CompilationConfig:
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
simple_cuda_graph: bool = False
"""Whether to use simple CUDA graph, which uses full CUDA graphs without
torch.compile. This can speed up the startup time of vLLM."""

pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
@@ -4164,6 +4167,13 @@ def __post_init__(self) -> None:
if isinstance(self.pass_config, dict):
self.pass_config = PassConfig(**self.pass_config)

if self.simple_cuda_graph:
# When using simple CUDA graph, we skip torch.compile.
self.level = CompilationLevel.NO_COMPILATION
# Simple CUDA graph does not support piecewise CUDA graphs.
self.full_cuda_graph = True
self.use_cudagraph = True

def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.")
@@ -4518,9 +4528,14 @@ def __post_init__(self):
not self.model_config.enforce_eager:
# By default, V1 uses piecewise CUDA graphs. If full_cuda_graph
# is set to True, full CUDA graphs will be used.
self.use_cudagraph = True
Review comment on the line above (severity: critical)

The use_cudagraph attribute belongs to compilation_config, not VllmConfig directly. This line incorrectly creates a new attribute on the VllmConfig instance. It should be set on self.compilation_config.

Suggested change
self.use_cudagraph = True
self.compilation_config.use_cudagraph = True

self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()
if self.compilation_config.simple_cuda_graph:
self.compilation_config.full_cuda_graph = True
self.compilation_config.level = CompilationLevel.NO_COMPILATION
else:
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()

self._set_cudagraph_sizes()

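Per the __post_init__ hunks above, setting the single flag is enough: torch.compile is skipped and full CUDA graphs are enabled automatically. The sketch below shows that normalization by constructing CompilationConfig directly; the user-facing CLI or engine-argument plumbing is not part of this diff and is not shown.

```python
from vllm.config import CompilationConfig, CompilationLevel

config = CompilationConfig(simple_cuda_graph=True)

# CompilationConfig.__post_init__ (see the hunk above) normalizes the rest:
assert config.level == CompilationLevel.NO_COMPILATION  # torch.compile skipped
assert config.full_cuda_graph                           # no piecewise splitting
assert config.use_cudagraph
```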
14 changes: 9 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -214,11 +214,15 @@ def __init__(
block_sizes=[self.cache_config.block_size],
)

self.use_cuda_graph = (
self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and self.vllm_config.compilation_config.use_cudagraph
and not self.model_config.enforce_eager)
self.use_cuda_graph = True
if self.model_config.enforce_eager:
self.use_cuda_graph = False
if not self.compilation_config.use_cudagraph:
self.use_cuda_graph = False
if (self.compilation_config.level != CompilationLevel.PIECEWISE
and not self.compilation_config.simple_cuda_graph):
self.use_cuda_graph = False
Comment on lines +217 to +224
Severity: medium

The logic for determining self.use_cuda_graph is implemented with a series of if statements that conditionally set it to False. This can be simplified into a single, more readable boolean expression.

Suggested change
self.use_cuda_graph = True
if self.model_config.enforce_eager:
self.use_cuda_graph = False
if not self.compilation_config.use_cudagraph:
self.use_cuda_graph = False
if (self.compilation_config.level != CompilationLevel.PIECEWISE
and not self.compilation_config.simple_cuda_graph):
self.use_cuda_graph = False
self.use_cuda_graph = (
not self.model_config.enforce_eager
and self.compilation_config.use_cudagraph and
(self.compilation_config.level == CompilationLevel.PIECEWISE
or self.compilation_config.simple_cuda_graph))


# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
# self.cudagraph_batch_sizes sorts in ascending order.