Merged

33 commits
4f6e1b4
init
BoyuanFeng Sep 4, 2025
1c1b600
nit
BoyuanFeng Sep 4, 2025
50d1dda
nit
BoyuanFeng Sep 4, 2025
7218e2b
nit
BoyuanFeng Sep 5, 2025
71209e2
cleanup
BoyuanFeng Sep 5, 2025
202b6f3
add doc
BoyuanFeng Sep 5, 2025
0b1e18a
improve warn/error msg
BoyuanFeng Sep 5, 2025
b66568b
match new torch api
BoyuanFeng Sep 5, 2025
87c74dd
skip cudagraph for get_input_embedding
BoyuanFeng Sep 9, 2025
c0bd3fb
Update vllm/compilation/backends.py
BoyuanFeng Sep 11, 2025
e16e23a
Apply suggestions from code review
BoyuanFeng Sep 11, 2025
892ab46
more docs
BoyuanFeng Sep 11, 2025
eabb1b6
nit
BoyuanFeng Sep 11, 2025
04e9801
Update vllm/v1/cudagraph_dispatcher.py
BoyuanFeng Sep 12, 2025
6cf5bd5
add piecewise test
BoyuanFeng Sep 14, 2025
70f45da
lint
BoyuanFeng Sep 14, 2025
7eb5d57
Merge branch 'main' into bf/cg-partition
BoyuanFeng Sep 15, 2025
3a6abd8
Merge branch 'main' into bf/cg-partition
BoyuanFeng Sep 15, 2025
4cce30c
add custom compile config test
BoyuanFeng Sep 15, 2025
d3809fb
more tests for splitting_ops
BoyuanFeng Sep 16, 2025
d7a73db
add tests for attention_quant_pattern
BoyuanFeng Sep 17, 2025
289a60e
rearch is_attention_compiled_piecewise
BoyuanFeng Sep 17, 2025
29ae5f0
Merge branch 'main' into bf/cg-partition
BoyuanFeng Sep 18, 2025
b5972fa
move set/unset wrapper to support_torch_compile for frame-specific
BoyuanFeng Sep 18, 2025
7570f4b
update test_attention_quant_pattern
BoyuanFeng Sep 18, 2025
c7ff7c4
Update vllm/config/compilation.py
BoyuanFeng Sep 18, 2025
4a38b36
more tests
BoyuanFeng Sep 18, 2025
d4269d9
move wrapper set/unset to context manager
BoyuanFeng Sep 18, 2025
20b9ef1
nit
BoyuanFeng Sep 18, 2025
e055458
update test
BoyuanFeng Sep 19, 2025
45b7588
Merge branch 'main' into bf/cg-partition
BoyuanFeng Sep 19, 2025
91c03a4
move maybe_use_cudagraph_partition_wrapper to decorators.py
BoyuanFeng Sep 19, 2025
19787d3
test inductor graph partition only when >= torch2.9
BoyuanFeng Sep 19, 2025
71 changes: 60 additions & 11 deletions tests/compile/piecewise/test_simple.py
@@ -15,6 +15,7 @@
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import is_torch_equal_or_newer

# This import automatically registers `torch.ops.silly.attention`
from ..silly_attention import get_global_counter, reset_global_counter
@@ -50,16 +51,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1

def _run_simple_model(
splitting_ops,
use_inductor_graph_partition,
use_inductor,
expected_num_piecewise_graphs_seen,
expected_num_piecewise_capturable_graphs_seen,
expected_num_backend_compilations,
expected_num_cudagraph_captured,
):
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
use_inductor=use_inductor,
splitting_ops=["silly.attention"],
splitting_ops=splitting_ops,
use_inductor_graph_partition=use_inductor_graph_partition,
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
@@ -70,11 +76,11 @@ def test_simple_piecewise_compile(use_inductor):

with compilation_counter.expect(
num_graphs_seen=1, # one graph for the model
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
num_piecewise_graphs_seen=expected_num_piecewise_graphs_seen,
num_piecewise_capturable_graphs_seen=
expected_num_piecewise_capturable_graphs_seen,
num_backend_compilations=expected_num_backend_compilations,
num_cudagraph_captured=expected_num_cudagraph_captured,
), set_forward_context(None,
vllm_config=vllm_config): # background context
# warm up with background context
@@ -104,3 +110,46 @@ def test_simple_piecewise_compile(use_inductor):
output = model(input)
assert get_global_counter() == 2
assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))


@pytest.mark.parametrize("use_inductor", [True, False])
@torch.inference_mode()
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1
_run_simple_model(
splitting_ops=["silly.attention"],
use_inductor_graph_partition=False,
use_inductor=use_inductor,
expected_num_piecewise_graphs_seen=5, # 2 * num_layers + 1
expected_num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
expected_num_backend_compilations=
3, # num_piecewise_capturable_graphs_seen
expected_num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
)


@torch.inference_mode()
@pytest.mark.parametrize("splitting_ops", [["silly.attention"], []])
def test_simple_inductor_graph_partition(splitting_ops):
assert VLLM_USE_V1
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")

_run_simple_model(
# inductor graph partition automatically resets splitting_ops
# to be an empty list
splitting_ops=splitting_ops,
use_inductor_graph_partition=True,
use_inductor=True,
expected_num_piecewise_graphs_seen=
1, # since not splitting at fx graph level
expected_num_piecewise_capturable_graphs_seen=
1, # since not splitting at fx graph level
expected_num_backend_compilations=
1, # since not splitting at fx graph level
expected_num_cudagraph_captured=
6, # inductor graph partition still captures 6
# graphs, same as fx graph partition.
)
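
For reference, the expected counts asserted above fall out of a simple piece of arithmetic: splitting at each attention op produces 2 * num_layers + 1 pieces, of which num_layers + 1 contain no attention and are therefore capturable. A minimal sketch of that arithmetic (the helper below is illustrative only, not part of the test suite):

def expected_piecewise_counts(num_layers: int, num_cudagraph_sizes: int):
    # Splitting the FX graph at every attention op leaves one subgraph
    # before, between, and after the attention calls.
    num_piecewise_graphs = 2 * num_layers + 1
    # Only the attention-free pieces are capturable as CUDA graphs.
    num_capturable = num_layers + 1
    # Each capturable piece is compiled once ...
    num_backend_compilations = num_capturable
    # ... and captured once per configured cudagraph batch size.
    num_cudagraph_captured = num_cudagraph_sizes * num_capturable
    return (num_piecewise_graphs, num_capturable,
            num_backend_compilations, num_cudagraph_captured)

# With 2 layers and cudagraph_capture_sizes=[1, 2] this gives (5, 3, 3, 6),
# matching the numbers asserted in test_simple_piecewise_compile.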
1 change: 1 addition & 0 deletions tests/compile/silly_attention.py
@@ -60,4 +60,5 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
mutates_args=["out"],
fake_impl=silly_attention_fake,
target_lib=silly_lib,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)
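
The tags=(torch._C.Tag.cudagraph_unsafe, ) argument is how an op advertises that Inductor's graph partitioner must split around it. Below is a minimal sketch of tagging an op through torch.library directly; the namespace, op name, and implementation are made up for illustration, and it assumes Library.define accepts the tags keyword on the PyTorch version in use (the test above goes through vLLM's direct_register_custom_op helper instead):

import torch
from torch.library import Library

# Hypothetical namespace/op, tagged cudagraph-unsafe so that Inductor's
# graph partitioner splits the graph around every call site.
toy_lib = Library("toy_ops", "FRAGMENT")
toy_lib.define(
    "toy_attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()",
    tags=(torch._C.Tag.cudagraph_unsafe,),
)

def toy_attention_impl(q: torch.Tensor, k: torch.Tensor,
                       v: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in body; a real attention kernel would go here.
    out.copy_(q + k + v)

toy_lib.impl("toy_attention", toy_attention_impl, "CUDA")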
59 changes: 58 additions & 1 deletion tests/compile/test_full_graph.py
@@ -3,16 +3,21 @@

from __future__ import annotations

import logging
import tempfile
from typing import Any, Optional, Union

import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from tests.v1.attention.utils import _Backend
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
PassConfig)
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer

from ..utils import create_new_process_for_each_test

Expand Down Expand Up @@ -107,18 +112,70 @@ def test_full_graph(
(CompilationConfig(level=CompilationLevel.PIECEWISE,
debug_dump_path=tempfile.gettempdir()),
("facebook/opt-125m", {})),
] + [
# inductor graph partition
(
CompilationConfig(
level=CompilationLevel.PIECEWISE,
# inductor graph partition uses
# torch._C.Tag.cudagraph_unsafe to specify splitting ops
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
compile_sizes=[1, 2]),
model) for model in models_list(all=False)
if is_torch_equal_or_newer("2.9.0.dev")
])
# only test some of the models
@create_new_process_for_each_test()
def test_custom_compile_config(
compilation_config: CompilationConfig,
model_info: tuple[str, dict[str, Any]],
):
if (compilation_config.use_inductor_graph_partition
and not is_torch_equal_or_newer("2.9.0.dev")):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")

model, model_kwargs = model_info
print(f"MODEL={model}")
run_model(compilation_config, model, model_kwargs)


def test_inductor_graph_partition_attn_fusion(caplog_vllm):
if not is_torch_equal_or_newer("2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")

model = "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_inductor_graph_partition=True,
cudagraph_mode=CUDAGraphMode.PIECEWISE,
custom_ops=["+quant_fp8"],
pass_config=PassConfig(enable_attn_fusion=True, enable_noop=True),
)
model_kwargs = {
"kv_cache_dtype": "fp8",
"max_model_len": 1024,
}
with caplog_vllm.at_level(
logging.DEBUG), global_force_attn_backend_context_manager(
_Backend.FLASHINFER):
run_model(compilation_config, model, model_kwargs)

try:
Reviewer comment (Collaborator): Can we just disable the cache here instead?

assert ("Fused quantization onto 48 attention nodes"
in caplog_vllm.text), caplog_vllm.text
except AssertionError:
# Note: this message is only emitted when the compilation actually
# goes through the custom pass. Due to multiple layers of caching on
# the PyTorch side, compilation of a graph may be served from cache,
# in which case the custom pass is skipped entirely. When that
# happens, we take this branch and assert that the pass was not
# triggered.
assert "Fused quantization" not in caplog_vllm.text


def run_model(compile_config: Union[int, CompilationConfig], model: str,
model_kwargs: dict[str, Any]):
prompts = [
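Regarding the reviewer's question above, one way the flaky assertion could be avoided is to force recompilation instead of tolerating the cached path. A minimal sketch, assuming the VLLM_DISABLE_COMPILE_CACHE and TORCHINDUCTOR_FORCE_DISABLE_CACHES environment variables behave as their names suggest (neither knob is introduced by this PR):

import pytest

@pytest.fixture
def disable_compile_caches(monkeypatch):
    # Assumed knobs: bypass vLLM's compile cache and Inductor's caches so
    # the attention-fusion pass always runs and always emits its log line.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    monkeypatch.setenv("TORCHINDUCTOR_FORCE_DISABLE_CACHES", "1")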
16 changes: 15 additions & 1 deletion tests/compile/test_fusion_attn.py
@@ -27,6 +27,7 @@
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer
from vllm.v1.kv_cache_interface import AttentionSpec

FP8_DTYPE = current_platform.fp8_dtype()
@@ -339,6 +340,10 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
@pytest.mark.parametrize(
"split_attention",
[False, True] if current_platform.is_rocm() else [False])
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
"use_inductor_graph_partition",
[False] if current_platform.is_rocm() else [False, True])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@@ -352,9 +357,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
dtype: torch.dtype, model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend, split_attention: bool,
monkeypatch, dist_init):
use_inductor_graph_partition: bool,
monkeypatch, dist_init, caplog_vllm):
"""Test AttentionStaticQuantPattern fusion pass"""

if use_inductor_graph_partition and not is_torch_equal_or_newer(
"2.9.0.dev"):
pytest.skip("inductor graph partition is only available "
"in PyTorch 2.9+")

monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
@@ -372,6 +383,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+quant_fp8"],
use_inductor_graph_partition=use_inductor_graph_partition,
),
cache_config=CacheConfig(cache_dtype="fp8"))

@@ -444,6 +456,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
backend=test_backend,
fullgraph=True)
assert model_compiled.attn._o_scale_float is None

result_fused_1 = model_compiled(q, k, v)

if backend == _Backend.FLASHINFER:
Expand All @@ -453,6 +466,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
# _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)

assert model_compiled.attn._o_scale_float is not None

torch.testing.assert_close(result_unfused,
2 changes: 2 additions & 0 deletions vllm/attention/layer.py
@@ -575,6 +575,7 @@ def unified_attention_fake(
mutates_args=[],
fake_impl=unified_attention_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)


@@ -625,4 +626,5 @@ def unified_attention_with_output_fake(
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch._C.Tag.cudagraph_unsafe, ),
)
10 changes: 8 additions & 2 deletions vllm/compilation/backends.py
@@ -326,6 +326,7 @@ def call_module(self, target: torch.fx.node.Target,
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time

compiled_graph_for_dynamic_shape = self.vllm_backend.\
compiler_manager.compile(
submod,
@@ -336,15 +337,20 @@
num_graphs=len(self.compile_submod_names),
runtime_shape=None)
# Lazy import here to avoid circular import
from .cuda_graph import CUDAGraphOptions
from .cuda_piecewise_backend import PiecewiseBackend

piecewise_backend = PiecewiseBackend(
submod, self.vllm_config, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_dynamic_shape, self.vllm_backend)

if self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and
not self.compilation_config.use_inductor_graph_partition):
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.
from .cuda_graph import CUDAGraphOptions

# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
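In short, the FX-level static-graph wrapper is only installed when vLLM performs the piecewise splitting itself; with Inductor graph partition enabled, the wrapping is deferred to the hook installed in decorators.py below. A small illustrative summary of that decision (not actual vLLM code):

from vllm.config import CUDAGraphMode

def cudagraph_wrapping_strategy(cudagraph_mode: CUDAGraphMode,
                                use_inductor_graph_partition: bool) -> str:
    # Mirrors the condition added to call_module() above.
    if cudagraph_mode == CUDAGraphMode.NONE:
        return "no cudagraph wrapping"
    if use_inductor_graph_partition:
        # Splitting happens inside Inductor after lowering; each partition
        # is wrapped via the customized partition wrapper hook.
        return "wrap per Inductor partition"
    # Dynamo-based piecewise splitting: wrap each capturable FX submodule.
    return "wrap per FX piecewise submodule"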
57 changes: 55 additions & 2 deletions vllm/compilation/decorators.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import inspect
from typing import Callable, Optional, TypeVar, Union, overload
from unittest.mock import patch
@@ -14,7 +15,7 @@
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
from vllm.utils import resolve_obj_by_qualname, supports_dynamo

from .monitor import start_monitoring_torch_compile

@@ -301,8 +302,11 @@ def patched_inline_call(parent, func, args, kwargs):

with patch.object(InliningInstructionTranslator, 'inline_call',
patched_inline_call), torch._dynamo.config.patch(
**dynamo_config_patches):
**dynamo_config_patches
), maybe_use_cudagraph_partition_wrapper(
self.vllm_config):
output = self.compiled_callable(*args, **kwargs)

return output

# usually, capturing the model once is enough, and then we can
@@ -314,3 +318,52 @@ def patched_inline_call(parent, func, args, kwargs):

cls.__call__ = __call__
return cls


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
"""
Context manager to set/unset customized cudagraph partition wrappers.

If we're using Inductor-based graph partitioning, we currently have the
whole `fx.Graph` before Inductor lowering, and the piecewise
splitting happens after all graph passes and fusions. Here, we add
a custom hook for Inductor to wrap each partition with our static
graph wrapper class to maintain more control over static graph
capture and replay.
"""
from vllm.config import CUDAGraphMode

compilation_config = vllm_config.compilation_config
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and compilation_config.use_inductor_graph_partition):
from torch._inductor.utils import CUDAGraphWrapperMetadata

from vllm.compilation.cuda_graph import CUDAGraphOptions
from vllm.platforms import current_platform

static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls())

def customized_cudagraph_wrapper(f,
metadata: CUDAGraphWrapperMetadata):
partition_id = metadata.partition_index
num_partitions = metadata.num_partitions
return static_graph_wrapper_class(
runnable=f,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=partition_id == 0,
gc_disable=partition_id != 0,
weak_ref_output=partition_id == num_partitions - 1,
))

torch._inductor.utils.set_customized_partition_wrappers(
customized_cudagraph_wrapper)

yield

if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and compilation_config.use_inductor_graph_partition):
torch._inductor.utils.set_customized_partition_wrappers(None)
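
For context, a rough sketch of how the installed hook ends up being used: Inductor calls the registered wrapper once per generated partition, passing metadata with the partition index and total partition count, and then runs whatever callable the wrapper returns. The function below is conceptual only; the real call sites live inside Inductor's generated code, and constructing CUDAGraphWrapperMetadata by hand like this is an assumption.

from torch._inductor.utils import CUDAGraphWrapperMetadata

def apply_partition_wrappers(partition_fns, wrapper):
    # Conceptual only: Inductor, not vLLM, performs this step internally.
    num_partitions = len(partition_fns)
    wrapped = []
    for i, fn in enumerate(partition_fns):
        metadata = CUDAGraphWrapperMetadata(num_partitions=num_partitions,
                                            partition_index=i)
        wrapped.append(wrapper(fn, metadata))
    return wrapped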