Merged
Commits
37 commits
a012257
renaming to piecewise_backend.py
fhl2000 Aug 16, 2025
c37583c
dispatch cascade attention to NONE or PIECEWISE runtime mode;clean up…
fhl2000 Aug 17, 2025
5e998c2
Merge remote-tracking branch 'origin/main' into post_issues_20059
fhl2000 Aug 17, 2025
648fbb3
apply suggestion from bot
fhl2000 Aug 17, 2025
3c88284
Merge branch 'main' into post_issues_20059
fhl2000 Aug 17, 2025
868c85f
fix bug when attn_metadata have no use_cascade
fhl2000 Aug 17, 2025
3bef9e4
simple dispatching test
fhl2000 Aug 17, 2025
05cf012
minor comment tweak
fhl2000 Aug 17, 2025
20d8afb
address comments part1
fhl2000 Aug 22, 2025
9648a84
Merge branch 'main' into post_issues_20059
fhl2000 Aug 23, 2025
da79494
fix comments part2
fhl2000 Aug 23, 2025
01083c3
fix double validation of deprecated flag
fhl2000 Aug 27, 2025
ca3c2c9
Merge branch 'main' into post_issues_20059
fhl2000 Aug 28, 2025
59db6b1
Merge branch 'main' into post_issues_20059
fhl2000 Aug 29, 2025
11421de
Merge branch 'main' into post_issues_20059
fhl2000 Sep 4, 2025
2bf5569
fix pre-commit
fhl2000 Sep 4, 2025
bd1762a
Merge remote-tracking branch 'origin/main' into post_issues_20059
fhl2000 Sep 9, 2025
2a50ecc
Merge branch 'main' into post_issues_20059
fhl2000 Sep 10, 2025
4aef453
Merge branch 'main' into post_issues_20059
fhl2000 Sep 14, 2025
ff3c671
Merge remote-tracking branch 'origin/main' into post_issues_20059
fhl2000 Sep 19, 2025
27eecc2
disable cascade_attn when DBO
fhl2000 Sep 19, 2025
8d3ecc8
Merge branch 'main' into post_issues_20059
fhl2000 Sep 20, 2025
48a8c7f
pre-commit
fhl2000 Sep 20, 2025
f3e08f3
remove piecewise cudagraph wrapper when no needed
fhl2000 Sep 23, 2025
3faff97
simplify set_splitting_ops_for_v1
fhl2000 Sep 23, 2025
f09e47f
resolve merged from main
fhl2000 Sep 23, 2025
b8894f2
Merge branch 'main' into post_issues_20059
fhl2000 Sep 24, 2025
92cbd4f
pre-commit
fhl2000 Sep 24, 2025
df90576
modify comments for full_cuda_graph
fhl2000 Sep 24, 2025
6176761
Merge branch 'main' into post_issues_20059
ProExpertProg Sep 24, 2025
4679802
address comment;add test for splitting_ops
fhl2000 Sep 24, 2025
5475e9e
fix profile_run log
fhl2000 Sep 25, 2025
413079b
temporary disable cascade attention
fhl2000 Sep 25, 2025
254bfd3
Merge branch 'main' into post_issues_20059
fhl2000 Sep 25, 2025
f584663
Merge branch 'main' into post_issues_20059
fhl2000 Sep 25, 2025
d8a1ad7
recover
fhl2000 Sep 25, 2025
891723a
Merge remote-tracking branch 'origin/main' into post_issues_20059
fhl2000 Sep 26, 2025
86 changes: 1 addition & 85 deletions tests/compile/piecewise/test_full_cudagraph.py
@@ -3,12 +3,11 @@
import contextlib
import os
import weakref
from dataclasses import dataclass
from typing import Optional

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform
@@ -33,89 +32,6 @@ def temporary_environ(env_vars):
os.environ[k] = v


@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
},
specific_gpu_arch=(10, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
comp_config={
"cudagraph_mode": "FULL",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}

test_params_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA
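
For context, a minimal sketch (not part of the diff) of how the shared configs imported above can be exercised. It is written as if it lived in this test file, so it reuses backend_configs, temporary_environ, LLM, and SamplingParams from the module; the model name is just a placeholder.

def run_full_cudagraph_smoke(backend_name: str = "FA2"):
    cfg = backend_configs[backend_name]
    # Apply the backend-specific environment variables only for this run.
    with temporary_environ(cfg.env_vars):
        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",  # placeholder model
                  compilation_config=cfg.comp_config,
                  gpu_memory_utilization=0.4)
        return llm.generate(["Hello, world"],
                            SamplingParams(temperature=0.0, max_tokens=8))
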
67 changes: 65 additions & 2 deletions tests/compile/test_config.py
@@ -4,7 +4,7 @@

import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.utils import _is_torch_equal_or_newer


@@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
def test_no_compilation(vllm_runner, monkeypatch):
# Disable multiprocessing so that the counter is in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')

with (
compilation_counter.expect(num_graphs_seen=0,
dynamo_as_is_count=0),
@@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch):
enforce_eager=True,
gpu_memory_utilization=0.4) as _):
pass


def test_splitting_ops_dynamic():
# Default config
config = VllmConfig()
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.FULL_AND_PIECEWISE
assert config.compilation_config.splitting_ops_contain_attention()

# When use_inductor_graph_partition=True
if _is_torch_equal_or_newer('2.9.0.dev'):
# inductor graph partition is only available in PyTorch 2.9+.
# this is a fast config check so we are not using pytest.skip.
config = VllmConfig(compilation_config=CompilationConfig(
use_inductor_graph_partition=True,
splitting_ops=["silly_attention"]))
# should ignore splitting_ops
assert config.compilation_config.splitting_ops == []

# When attn_fusion pass enabled.
config = VllmConfig(compilation_config=CompilationConfig(
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
))
assert config.compilation_config.splitting_ops == []
# cudagraph mode also falls back to FULL
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.FULL

# splitting_ops cannot contain attention ops when the attn_fusion
# pass is enabled.
with pytest.raises(AssertionError):
config = VllmConfig(compilation_config=CompilationConfig(
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
# workaround for accessing all attention ops
splitting_ops=CompilationConfig()._attention_ops,
))

# When both use_inductor_graph_partition and attn_fusion pass enabled.
if _is_torch_equal_or_newer('2.9.0.dev'):
config = VllmConfig(compilation_config=CompilationConfig(
use_inductor_graph_partition=True,
pass_config={
"enable_attn_fusion": True,
"enable_noop": True
},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
))
assert config.compilation_config.splitting_ops == []
# enable_attn_fusion is directly supported under
# use_inductor_graph_partition=True, and cudagraph_mode
# is unchanged.
assert config.compilation_config.cudagraph_mode == \
CUDAGraphMode.PIECEWISE
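
A minimal sketch (not part of the diff) of the behaviour the new test checks: the resolved splitting_ops and cudagraph_mode can be read straight off a constructed VllmConfig.

from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig

# Default config: splitting ops cover attention and the mode stays
# FULL_AND_PIECEWISE, matching the first assertions in the test above.
config = VllmConfig()
print(config.compilation_config.cudagraph_mode)       # CUDAGraphMode.FULL_AND_PIECEWISE
print(config.compilation_config.splitting_ops_contain_attention())  # True

# With the attn_fusion pass enabled, splitting_ops resolves to [] and
# PIECEWISE falls back to FULL, as asserted above.
config = VllmConfig(compilation_config=CompilationConfig(
    pass_config={"enable_attn_fusion": True, "enable_noop": True},
    custom_ops=["+quant_fp8"],
    cudagraph_mode=CUDAGraphMode.PIECEWISE,
))
print(config.compilation_config.splitting_ops)         # []
print(config.compilation_config.cudagraph_mode)        # CUDAGraphMode.FULL
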
87 changes: 86 additions & 1 deletion tests/v1/attention/utils.py
@@ -3,7 +3,7 @@
"""Utility functions for attention-related v1 tests."""

from dataclasses import dataclass
from typing import Union
from typing import Optional, Union

import pytest
import torch
@@ -260,3 +260,88 @@ def create_dummy_kv_cache(block_size: int,
dtype=dtype,
device=device)
return kv_cache


@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict # compilation config
specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested
full_cg_backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "3",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(10, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
"VLLM_FLASH_ATTN_VERSION": "2",
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}
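
A minimal sketch (not part of the diff) of how a consumer of this module might keep only the configs the current GPU supports; it assumes current_platform.get_device_capability() returns a value comparable to the (major, minor) tuples stored in specific_gpu_arch.

from tests.v1.attention.utils import BackendConfig, full_cg_backend_configs
from vllm.platforms import current_platform


def supported_full_cg_configs() -> dict[str, BackendConfig]:
    # Keep configs that are not pinned to a GPU arch, or whose pin matches ours.
    capability = current_platform.get_device_capability()
    return {
        name: cfg
        for name, cfg in full_cg_backend_configs.items()
        if cfg.specific_gpu_arch is None or cfg.specific_gpu_arch == capability
    }
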
61 changes: 26 additions & 35 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -45,39 +45,22 @@ def _create_vllm_config(compilation_config: CompilationConfig,
class TestCudagraphDispatcher:

@pytest.mark.parametrize(
"params",
"case_id,cudagraph_mode_str,compilation_level",
[
# Test case 0: Full CG for mixed batches, no separate routine
{
"case_id": 0,
"cudagraph_mode": "FULL",
"compilation_level": CompilationLevel.NO_COMPILATION,
},
(0, "FULL", CompilationLevel.NO_COMPILATION),
# Test case 1: Full CG for uniform batches, piecewise for mixed
{
"case_id": 1,
"cudagraph_mode": "FULL_AND_PIECEWISE",
"compilation_level": CompilationLevel.PIECEWISE,
},
(1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
# Test case 2: Full CG for uniform batches, no CG for mixed
{
"case_id": 2,
"cudagraph_mode": "FULL_DECODE_ONLY",
"compilation_level": CompilationLevel.NO_COMPILATION,
},
(2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
# Test case 3: Piecewise for all
{
"case_id": 3,
"cudagraph_mode": "PIECEWISE",
"compilation_level": CompilationLevel.PIECEWISE,
},
(3, "PIECEWISE", CompilationLevel.PIECEWISE),
])
def test_dispatcher(self, params):
def test_dispatcher(self, cudagraph_mode_str, compilation_level):
# Setup dispatcher
comp_config = CompilationConfig(
cudagraph_mode=params["cudagraph_mode"],
level=params["compilation_level"],
cudagraph_capture_sizes=[1, 8])
comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
level=compilation_level,
cudagraph_capture_sizes=[1, 8])

config = _create_vllm_config(comp_config, max_num_seqs=8)
dispatcher = CudagraphDispatcher(config)
@@ -86,11 +69,11 @@ def test_dispatcher(self, params):
uniform_decode_query_len=1)

# Verify the key is initialized correctly
if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
@@ -99,10 +82,10 @@
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact)
if params["cudagraph_mode"] == "FULL":
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact
else:
@@ -111,15 +94,13 @@
# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
if params["cudagraph_mode"] == "FULL":
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.non_uniform
elif params["cudagraph_mode"] in [
"FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
]:
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif params["cudagraph_mode"] == "PIECEWISE":
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.non_uniform
else:
@@ -131,6 +112,16 @@ def test_dispatcher(self, params):
assert rt_mode == CUDAGraphMode.NONE
assert key is None

# 4. Cascade attention should have a fallback mode
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
rt_mode, key = dispatcher.dispatch(desc_full_exact,
use_cascade_attn=True)
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.non_uniform
else:
assert rt_mode == CUDAGraphMode.NONE


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
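
To close, a minimal sketch (not part of the diff) of the cascade-attention fallback exercised in case 4 above. It is written as if it lived in this test module so it can reuse the module's imports and the _create_vllm_config helper; the first argument of initialize_cudagraph_keys is an assumption inferred from the truncated call shown earlier in test_dispatcher.

def check_cascade_attn_fallback():
    # Same setup as test_dispatcher, pinned to the "FULL" cudagraph mode.
    comp_config = CompilationConfig(cudagraph_mode="FULL",
                                    level=CompilationLevel.NO_COMPILATION,
                                    cudagraph_capture_sizes=[1, 8])
    config = _create_vllm_config(comp_config, max_num_seqs=8)
    dispatcher = CudagraphDispatcher(config)
    dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL,
                                         uniform_decode_query_len=1)

    desc = BatchDescriptor(num_tokens=8, uniform_decode=False)
    # Without cascade attention this batch dispatches to a FULL cudagraph...
    rt_mode, _ = dispatcher.dispatch(desc)
    assert rt_mode == CUDAGraphMode.FULL
    # ...but cascade attention is incompatible with full cudagraphs, so the
    # dispatcher falls back to eager (NONE) execution, as the test asserts.
    rt_mode, _ = dispatcher.dispatch(desc, use_cascade_attn=True)
    assert rt_mode == CUDAGraphMode.NONE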