diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py
index b02c1b565671..9906e49bb110 100644
--- a/tests/compile/piecewise/test_full_cudagraph.py
+++ b/tests/compile/piecewise/test_full_cudagraph.py
@@ -3,12 +3,11 @@
 import contextlib
 import os
 import weakref
-from dataclasses import dataclass
-from typing import Optional
 
 import pytest
 
 from tests.utils import wait_for_gpu_memory_to_clear
+from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
@@ -33,89 +32,6 @@ def temporary_environ(env_vars):
             os.environ[k] = v
 
 
-@dataclass
-class BackendConfig:
-    name: str
-    env_vars: dict
-    comp_config: dict
-    specific_gpu_arch: Optional[tuple] = None
-
-
-# Define all backend configurations of full cudagraph to be tested
-backend_configs = {
-    # FA3 on Hopper
-    "FA3":
-    BackendConfig(name="FA3",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "3",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashMLA on Hopper
-    "FlashMLA":
-    BackendConfig(name="FlashMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashAttention MLA on Hopper
-    "FlashAttentionMLA":
-    BackendConfig(name="FlashAttentionMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_DECODE_ONLY",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # Cutlass MLA on Blackwell
-    "CutlassMLA":
-    BackendConfig(
-        name="CutlassMLA",
-        env_vars={
-            "VLLM_USE_V1": "1",
-            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS":
-            "1",  # TODO: remove this when hang issue is fixed
-        },
-        comp_config={
-            "cudagraph_mode": "FULL_AND_PIECEWISE",
-            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
-        },
-        specific_gpu_arch=(10, 0)),
-    # FA2
-    "FA2":
-    BackendConfig(name="FA2",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "2",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  }),
-    # Triton Attention
-    "TritonAttn":
-    BackendConfig(name="TritonAttn",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  }),
-    # FlashInfer
-    "FlashInfer":
-    BackendConfig(name="FlashInfer",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-}
-
 test_params_full_cudagraph = []
 
 # deepseek-ai/DeepSeek-V2-Lite with MLA
diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 7afd6251bbbd..17d3f0b37768 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -4,7 +4,7 @@
 import vllm
 
 from vllm.compilation.counter import compilation_counter
-from vllm.config import CompilationConfig, VllmConfig
+from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.utils import _is_torch_equal_or_newer
 
 
@@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
 def test_no_compilation(vllm_runner, monkeypatch):
     # Disable multiprocessing so that the counter is in the same process
     monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
-
     with (
             compilation_counter.expect(num_graphs_seen=0,
                                        dynamo_as_is_count=0),
@@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch):
                      enforce_eager=True,
                      gpu_memory_utilization=0.4) as _):
         pass
+
+
+def test_splitting_ops_dynamic():
+    # Default config
+    config = VllmConfig()
+    assert config.compilation_config.cudagraph_mode == \
+        CUDAGraphMode.FULL_AND_PIECEWISE
+    assert config.compilation_config.splitting_ops_contain_attention()
+
+    # When use_inductor_graph_partition=True
+    if _is_torch_equal_or_newer('2.9.0.dev'):
+        # inductor graph partition is only available in PyTorch 2.9+.
+        # this is a fast config check so we are not using pytest.skip.
+        config = VllmConfig(compilation_config=CompilationConfig(
+            use_inductor_graph_partition=True,
+            splitting_ops=["silly_attention"]))
+        # should ignore splitting_ops
+        assert config.compilation_config.splitting_ops == []
+
+    # When the attn_fusion pass is enabled.
+    config = VllmConfig(compilation_config=CompilationConfig(
+        pass_config={
+            "enable_attn_fusion": True,
+            "enable_noop": True
+        },
+        custom_ops=["+quant_fp8"],
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+    ))
+    assert config.compilation_config.splitting_ops == []
+    # cudagraph mode also falls back to FULL
+    assert config.compilation_config.cudagraph_mode == \
+        CUDAGraphMode.FULL
+
+    # splitting_ops cannot contain attention ops when the attn_fusion
+    # pass is enabled.
+    with pytest.raises(AssertionError):
+        config = VllmConfig(compilation_config=CompilationConfig(
+            pass_config={
+                "enable_attn_fusion": True,
+                "enable_noop": True
+            },
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+            # workaround for accessing all attention ops
+            splitting_ops=CompilationConfig()._attention_ops,
+        ))
+
+    # When both use_inductor_graph_partition and attn_fusion are enabled.
+    if _is_torch_equal_or_newer('2.9.0.dev'):
+        config = VllmConfig(compilation_config=CompilationConfig(
+            use_inductor_graph_partition=True,
+            pass_config={
+                "enable_attn_fusion": True,
+                "enable_noop": True
+            },
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        ))
+        assert config.compilation_config.splitting_ops == []
+        # enable_attn_fusion is directly supported under
+        # use_inductor_graph_partition=True, and cudagraph_mode
+        # is unchanged.
+        assert config.compilation_config.cudagraph_mode == \
+            CUDAGraphMode.PIECEWISE
diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py
index 6f8c5ea50ef0..01b5de83a59a 100644
--- a/tests/v1/attention/utils.py
+++ b/tests/v1/attention/utils.py
@@ -3,7 +3,7 @@
 """Utility functions for attention-related v1 tests."""
 
 from dataclasses import dataclass
-from typing import Union
+from typing import Optional, Union
 
 import pytest
 import torch
@@ -260,3 +260,88 @@ def create_dummy_kv_cache(block_size: int,
                           dtype=dtype,
                           device=device)
     return kv_cache
+
+
+@dataclass
+class BackendConfig:
+    name: str
+    env_vars: dict
+    comp_config: dict  # compilation config
+    specific_gpu_arch: Optional[tuple] = None
+
+
+# Define all backend configurations of full cudagraph to be tested
+full_cg_backend_configs = {
+    # FA3 on Hopper
+    "FA3":
+    BackendConfig(name="FA3",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FlashMLA on Hopper
+    "FlashMLA":
+    BackendConfig(name="FlashMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # Cutlass MLA on Blackwell
+    "CutlassMLA":
+    BackendConfig(
+        name="CutlassMLA",
+        env_vars={
+            "VLLM_USE_V1": "1",
+            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
+            "FORCE_NUM_KV_SPLITS":
+            "1",  # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+        },
+        specific_gpu_arch=(10, 0)),
+    # FlashAttention MLA on Hopper
+    "FlashAttentionMLA":
+    BackendConfig(name="FlashAttentionMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_DECODE_ONLY",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FA2
+    "FA2":
+    BackendConfig(name="FA2",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+    # Triton Attention
+    "TritonAttn":
+    BackendConfig(name="TritonAttn",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+    # FlashInfer
+    "FlashInfer":
+    BackendConfig(name="FlashInfer",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+}
diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py
index 64f2fa462802..b6b85e4440d0 100644
--- a/tests/v1/cudagraph/test_cudagraph_dispatch.py
+++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py
@@ -45,39 +45,22 @@ def _create_vllm_config(compilation_config: CompilationConfig,
 class TestCudagraphDispatcher:
 
     @pytest.mark.parametrize(
-        "params",
+        "case_id,cudagraph_mode_str,compilation_level",
         [
             # Test case 0: Full CG for mixed batches, no separate routine
-            {
-                "case_id": 0,
-                "cudagraph_mode": "FULL",
-                "compilation_level": CompilationLevel.NO_COMPILATION,
-            },
+            (0, "FULL", CompilationLevel.NO_COMPILATION),
             # Test case 1: Full CG for uniform batches, piecewise for mixed
-            {
-                "case_id": 1,
-                "cudagraph_mode": "FULL_AND_PIECEWISE",
-                "compilation_level": CompilationLevel.PIECEWISE,
-            },
+            (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
             # Test case 2: Full CG for uniform batches, no CG for mixed
-            {
-                "case_id": 2,
-                "cudagraph_mode": "FULL_DECODE_ONLY",
-                "compilation_level": CompilationLevel.NO_COMPILATION,
-            },
+            (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
             # Test case 3: Piecewise for all
-            {
-                "case_id": 3,
-                "cudagraph_mode": "PIECEWISE",
-                "compilation_level": CompilationLevel.PIECEWISE,
-            },
+            (3, "PIECEWISE", CompilationLevel.PIECEWISE),
         ])
-    def test_dispatcher(self, params):
+    def test_dispatcher(self, case_id, cudagraph_mode_str, compilation_level):
         # Setup dispatcher
-        comp_config = CompilationConfig(
-            cudagraph_mode=params["cudagraph_mode"],
-            level=params["compilation_level"],
-            cudagraph_capture_sizes=[1, 8])
+        comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
+                                        level=compilation_level,
+                                        cudagraph_capture_sizes=[1, 8])
 
         config = _create_vllm_config(comp_config, max_num_seqs=8)
         dispatcher = CudagraphDispatcher(config)
@@ -86,11 +69,11 @@ def test_dispatcher(self, params):
                                              uniform_decode_query_len=1)
 
         # Verify the key is initialized correctly
-        if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
+        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
-        if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
+        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
@@ -99,10 +82,10 @@ def test_dispatcher(self, params):
         # 1. non-uniform batch, size in cudagraph size list
         desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
         rt_mode, key = dispatcher.dispatch(desc_full_exact)
-        if params["cudagraph_mode"] == "FULL":
+        if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_full_exact
-        elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
+        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
             assert rt_mode == CUDAGraphMode.PIECEWISE
             assert key == desc_full_exact
         else:
@@ -111,15 +94,13 @@ def test_dispatcher(self, params):
         # 2. uniform decode batch, size in cudagraph size list
         desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
         rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
-        if params["cudagraph_mode"] == "FULL":
+        if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_uniform_exact.non_uniform
-        elif params["cudagraph_mode"] in [
-                "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
-        ]:
+        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_uniform_exact
-        elif params["cudagraph_mode"] == "PIECEWISE":
+        elif cudagraph_mode_str == "PIECEWISE":
             assert rt_mode == CUDAGraphMode.PIECEWISE
             assert key == desc_uniform_exact.non_uniform
         else:
@@ -131,6 +112,16 @@
             assert rt_mode == CUDAGraphMode.NONE
             assert key is None
 
+        # 4. Cascade attention should have a fallback mode
+        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
+        rt_mode, key = dispatcher.dispatch(desc_full_exact,
+                                           use_cascade_attn=True)
+        if "PIECEWISE" in cudagraph_mode_str:  # string containment check
+            assert rt_mode == CUDAGraphMode.PIECEWISE
+            assert key == desc_full_exact.non_uniform
+        else:
+            assert rt_mode == CUDAGraphMode.NONE
+
 
 @pytest.mark.skipif(not current_platform.is_cuda(),
                     reason="Skip if not cuda")
 class TestCUDAGraphWrapper:
diff --git a/tests/v1/cudagraph/test_cudagraph_mode.py b/tests/v1/cudagraph/test_cudagraph_mode.py
index 41a9493cbe58..c4116247bb7c 100644
--- a/tests/v1/cudagraph/test_cudagraph_mode.py
+++ b/tests/v1/cudagraph/test_cudagraph_mode.py
@@ -4,12 +4,11 @@
 import os
 import weakref
 from contextlib import ExitStack
-from dataclasses import dataclass
-from typing import Optional
 
 import pytest
 
 from tests.utils import wait_for_gpu_memory_to_clear
+from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
@@ -34,74 +33,6 @@ def temporary_environ(env_vars):
             os.environ[k] = v
 
 
-@dataclass
-class BackendConfig:
-    name: str
-    env_vars: dict
-    comp_config: dict
-    specific_gpu_arch: Optional[tuple] = None
-
-
-# Define all backend configurations of full cudagraph to be tested
-backend_configs = {
-    # FA3 on Hopper
-    "FA3":
-    BackendConfig(name="FA3",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "3",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashMLA on Hopper
-    "FlashMLA":
-    BackendConfig(name="FlashMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashAttention MLA on Hopper
-    "FlashAttentionMLA":
-    BackendConfig(name="FlashAttentionMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_DECODE_ONLY",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FA2
-    "FA2":
-    BackendConfig(name="FA2",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "2",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-    # Triton Attention
-    "TritonAttn":
-    BackendConfig(name="TritonAttn",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-    # FlashInfer
-    "FlashInfer":
-    BackendConfig(name="FlashInfer",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-}
-
 # test attention backend and cudagraph_mode combo
 # (backend_name, cudagraph_mode, supported)
 combo_cases_1 = [
@@ -114,9 +45,10 @@ class BackendConfig:
 ]
 
 
-@pytest.mark.parametrize("combo_case", combo_cases_1)
-def test_backend_and_cudagraph_mode_combo(combo_case):
-    backend_name, cudagraph_mode, supported = combo_case
+@pytest.mark.parametrize("backend_name, cudagraph_mode, supported",
+                         combo_cases_1)
+def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode,
+                                          supported):
     if backend_name == "FlashInfer":
         try:
             import flashinfer  # noqa: F401
@@ -142,7 +74,7 @@ def test_backend_and_cudagraph_mode_combo(combo_case):
             compilation_config=CompilationConfig(
                 level=3, cudagraph_mode=cudagraph_mode))
         llm.generate(["Hello, my name is"] * 10)
-
+    # when the above code raises, `llm` may be undefined, so we need to catch that
     try:
         llm = weakref.proxy(llm)
         del llm
@@ -173,7 +105,8 @@
 ]
 
 
-@pytest.mark.parametrize("combo_case", combo_cases_2)
-def test_cudagraph_compilation_combo(combo_case):
-    backend_name, cudagraph_mode, compilation_level, supported\
-        = combo_case
+@pytest.mark.parametrize(
+    "backend_name,cudagraph_mode,compilation_level,supported",
+    combo_cases_2)
+def test_cudagraph_compilation_combo(backend_name, cudagraph_mode,
+                                     compilation_level, supported):
@@ -192,6 +125,7 @@ def test_cudagraph_compilation_combo(combo_case):
             compilation_config=CompilationConfig(
                 level=compilation_level, cudagraph_mode=cudagraph_mode))
         llm.generate(["Hello, my name is"] * 10)
+    # when the above code raises, `llm` may be undefined, so we need to catch that
     try:
         llm = weakref.proxy(llm)
         del llm
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 17fc727b8fc7..335bbda5e4eb 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -340,15 +340,15 @@ def call_module(self, target: torch.fx.node.Target,
                 num_graphs=len(self.compile_submod_names),
                 runtime_shape=None)
             # Lazy import here to avoid circular import
-            from .cuda_piecewise_backend import PiecewiseBackend
+            from .piecewise_backend import PiecewiseBackend
 
             piecewise_backend = PiecewiseBackend(
                 submod, self.vllm_config, index,
                 len(self.compile_submod_names), sym_shape_indices,
                 compiled_graph_for_dynamic_shape, self.vllm_backend)
 
-            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                    and
+            if (self.compilation_config.cudagraph_mode.\
+                    has_piecewise_cudagraphs() and
                     not self.compilation_config.use_inductor_graph_partition):
                 # We're using Dynamo-based piecewise splitting, so we wrap
                 # the whole subgraph with a static graph wrapper.
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 6e9a36a2b0b9..fa38cfe49a91 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -336,7 +336,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
     from vllm.config import CUDAGraphMode
 
     compilation_config = vllm_config.compilation_config
-    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+    if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
             and compilation_config.use_inductor_graph_partition):
         from torch._inductor.utils import CUDAGraphWrapperMetadata
 
@@ -365,7 +365,7 @@ def customized_cudagraph_wrapper(f,
 
     yield
 
-    if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+    if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
             and compilation_config.use_inductor_graph_partition):
         torch._inductor.utils.set_customized_partition_wrappers(None)
 
diff --git a/vllm/compilation/cuda_piecewise_backend.py b/vllm/compilation/piecewise_backend.py
similarity index 100%
rename from vllm/compilation/cuda_piecewise_backend.py
rename to vllm/compilation/piecewise_backend.py
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 2da9d8f4f3ea..3f19796fd102 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -458,15 +458,22 @@ def __post_init__(self):
                 "to True to enable.")
             current_platform.check_and_update_config(self)
 
-        # final check of cudagraph mode after platform-specific update
+        # Do this after all the updates to compilation_config.level
+        if envs.VLLM_USE_V1 and \
+            self.compilation_config.level == CompilationLevel.PIECEWISE:
+            self.compilation_config.set_splitting_ops_for_v1()
+
+        # final check of cudagraph mode after all possible updates
         if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
-            if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
+            if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\
                 and self.model_config is not None and \
-                not self.model_config.disable_cascade_attn:
-                logger.info("CUDAGraphMode.FULL is not supported with "
-                            "cascade attention currently. Disabling cascade"
-                            "attention.")
-                self.model_config.disable_cascade_attn = True
+                not self.model_config.disable_cascade_attn and\
+                not self.compilation_config.cudagraph_mode.\
+                has_piecewise_cudagraphs():
+                logger.warning_once(
+                    "No piecewise cudagraph for executing cascade attention. "
+                    "Will fall back to eager execution if a batch runs "
+                    "into cascade attention.")
 
         if self.compilation_config.cudagraph_mode\
             .requires_piecewise_compilation():
@@ -476,6 +483,12 @@ def __post_init__(self):
                 "when cudagraph_mode piecewise cudagraphs is used, "\
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
 
+        # finally, migrate the deprecated flags
+        self.compilation_config.use_cudagraph = self.compilation_config.\
+            cudagraph_mode != CUDAGraphMode.NONE
+        self.compilation_config.full_cuda_graph = self.compilation_config.\
+            cudagraph_mode.has_full_cudagraphs()
+
         if self.parallel_config.enable_dbo:
             a2a_backend = envs.VLLM_ALL2ALL_BACKEND
             assert a2a_backend in \
@@ -486,14 +499,14 @@ def __post_init__(self):
                 "variable to deepep_low_latency or deepep_high_throughput and "\
                 "install the DeepEP kernels."
 
+            if not self.model_config.disable_cascade_attn:
+                self.model_config.disable_cascade_attn = True
+                logger.warning_once(
+                    "Disabling cascade attention when DBO is enabled.")
+
         if not self.instance_id:
             self.instance_id = random_uuid()[:5]
 
-        # Do this after all the updates to compilation_config.level
-        if envs.VLLM_USE_V1 and \
-            self.compilation_config.level == CompilationLevel.PIECEWISE:
-            self.compilation_config.set_splitting_ops_for_v1()
-
         if (envs.VLLM_USE_V1 and
                 not self.scheduler_config.disable_hybrid_kv_cache_manager):
             # logger should only print warning message for hybrid models. As we
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 50fde9461a13..9735db98567d 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -61,9 +61,17 @@ def max_cudagraph_mode(self) -> 'CUDAGraphMode':
     def has_full_cudagraphs(self) -> bool:
         return self.max_cudagraph_mode() == CUDAGraphMode.FULL
 
+    def has_piecewise_cudagraphs(self) -> bool:
+        return self.requires_piecewise_compilation()
+
     def separate_routine(self) -> bool:
         return isinstance(self.value, tuple)
 
+    def valid_runtime_modes(self) -> bool:
+        return self in [
+            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
+        ]
+
 
 @config
 @dataclass
@@ -269,7 +277,8 @@ class CompilationConfig:
     Note that this is orthogonal to the cudagraph capture logic outside of
     compilation.
     Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode=PIECEWISE
+    instead.
     """
     cudagraph_num_of_warmups: int = 0
     """Number of warmup runs for cudagraph.
@@ -294,7 +303,8 @@ class CompilationConfig:
    flag cannot be used together with splitting_ops. This may provide
    performance benefits for smaller models.
    Warning: This flag is deprecated and will be removed in the next major or
-    minor release, i.e. v0.11.0 or v1.0.0. Please use cudagraph_mode instead.
+    minor release, i.e. v0.11.0 or v1.0.0. Please use
+    cudagraph_mode=FULL_AND_PIECEWISE instead.
    """
 
     use_inductor_graph_partition: bool = False
@@ -464,7 +474,8 @@ def __post_init__(self) -> None:
             if not self.use_cudagraph:
                 logger.warning("use_cudagraph is deprecated, use "
                                "cudagraph_mode=NONE instead.")
-                if self.cudagraph_mode is not None:
+                if self.cudagraph_mode is not None and \
+                    self.cudagraph_mode != CUDAGraphMode.NONE:
                     raise ValueError(
                         "use_cudagraph and cudagraph_mode are mutually"
                         " exclusive, prefer cudagraph_mode since "
@@ -473,7 +484,8 @@ def __post_init__(self) -> None:
             if self.full_cuda_graph:
                 logger.warning("full_cuda_graph is deprecated, use "
                                "cudagraph_mode=FULL instead.")
-                if self.cudagraph_mode is not None:
+                if self.cudagraph_mode is not None and \
+                    not self.cudagraph_mode.has_full_cudagraphs():
                     raise ValueError("full_cuda_graph and cudagraph_mode are "
                                      "mutually exclusive, prefer cudagraph_mode "
                                      "since full_cuda_graph is deprecated.")
@@ -570,48 +582,75 @@ def set_splitting_ops_for_v1(self):
             "set_splitting_ops_for_v1 should only be called when "
             "level is CompilationLevel.PIECEWISE")
 
-        use_inductor_graph_partition_msg = (
-            "When use_inductor_graph_partition=True, splitting_ops "
-            "are ignored and set to an empty list. Instead, "
-            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
-            "used to annotate custom ops for graph partition.")
+        if self.use_inductor_graph_partition:
+            self.set_splitting_ops_for_inductor_graph_partition()
+            return
+
+        if self.pass_config.enable_attn_fusion:
+            # here use_inductor_graph_partition is False
+            self.set_splitting_ops_for_attn_fusion()
+            return
 
         if self.splitting_ops is None:
-            if self.use_inductor_graph_partition:
-                # When using inductor graph partition, we set splitting_ops
-                # to be empty and rely on torch._C.Tag.cudagraph_unsafe to
-                # annotate custom ops as splitting ops.
-                logger.warning_once(use_inductor_graph_partition_msg)
-                self.splitting_ops = []
-            else:
-                # NOTE: When using full cudagraph, instead of setting an empty
-                # list and capture the full cudagraph inside the flattened fx
-                # graph, we keep the piecewise fx graph structure but capture
-                # the full cudagraph outside the fx graph. This reduces some
-                # cpu overhead when the runtime batch_size is not cudagraph
-                # captured. see https://github.com/vllm-project/vllm/pull/20059
-                # for details. make a copy to avoid mutating the class-level
-                # list via reference.
-                self.splitting_ops = list(self._attention_ops)
+            # NOTE: When using full cudagraph, instead of setting an empty
+            # list and capture the full cudagraph inside the flattened fx
+            # graph, we keep the piecewise fx graph structure but capture
+            # the full cudagraph outside the fx graph. This reduces some
+            # cpu overhead when the runtime batch_size is not cudagraph
+            # captured. see https://github.com/vllm-project/vllm/pull/20059
+            # for details. Make a copy to avoid mutating the class-level
+            # list via reference.
+            self.splitting_ops = list(self._attention_ops)
         elif len(self.splitting_ops) == 0:
             logger.warning_once(
-                "Using piecewise compilation with empty "
-                "splitting_ops and use_inductor_graph_partition"
-                f"={self.use_inductor_graph_partition}.")
-            if (self.cudagraph_mode == CUDAGraphMode.PIECEWISE
-                    and not self.use_inductor_graph_partition):
+                "Using piecewise compilation with empty splitting_ops")
+            if self.cudagraph_mode == CUDAGraphMode.PIECEWISE:
+                logger.warning_once(
+                    "Piecewise compilation with empty splitting_ops does "
+                    "not contain piecewise cudagraphs. Setting "
+                    "cudagraph_mode to NONE. "
+                    "Hint: If you are using attention backends "
+                    "that support cudagraph, consider manually setting "
+                    "cudagraph_mode to FULL or FULL_DECODE_ONLY to enable "
+                    "full cudagraphs.")
+                self.cudagraph_mode = CUDAGraphMode.NONE
+            elif self.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
                 logger.warning_once(
-                    "When compilation level is piecewise with empty "
-                    "splitting_ops, PIECEWISE cudagraph_mode will be "
-                    "treated as FULL cudagraph_mode. Please ensure you are "
-                    "using attention backends that support cudagraph or set "
-                    "cudagraph_mode to NONE explicitly if encountering "
-                    "any problems.")
+                    "Piecewise compilation with empty splitting_ops does "
+                    "not contain piecewise cudagraphs. Setting "
+                    "cudagraph_mode to FULL.")
                 self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
-        elif self.use_inductor_graph_partition:
+
+    def set_splitting_ops_for_inductor_graph_partition(self):
+        assert self.use_inductor_graph_partition
+        use_inductor_graph_partition_msg = (
+            "When use_inductor_graph_partition=True, splitting_ops "
+            "are ignored and set to an empty list. Instead, "
+            "\"tags=(torch._C.Tag.cudagraph_unsafe, ),\" is "
+            "used to annotate custom ops for graph partition.")
+        if self.splitting_ops is not None and \
+            len(self.splitting_ops) > 0:
             logger.warning_once(use_inductor_graph_partition_msg)
+        self.splitting_ops = []
+
+    def set_splitting_ops_for_attn_fusion(self):
+        assert self.pass_config.enable_attn_fusion
+        if self.splitting_ops is None:
             self.splitting_ops = []
+            if self.cudagraph_mode.has_piecewise_cudagraphs():
+                logger.warning_once(
+                    "enable_attn_fusion is incompatible with piecewise "
+                    "cudagraph when use_inductor_graph_partition is off. "
+                    "In this case, splitting_ops will be set to an empty "
+                    "list, and cudagraph_mode will be set to FULL. "
+                    "Please ensure you are using attention backends that "
+                    "support cudagraph or set cudagraph_mode to NONE "
+                    "explicitly if encountering any problems.")
+                self.cudagraph_mode = CUDAGraphMode.FULL
+
+        assert not self.splitting_ops_contain_attention(), (
+            "attention ops should not be in splitting_ops "
+            "when enable_attn_fusion is True")
 
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 3b535423f7bc..2bf4e1804521 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -246,8 +246,7 @@ class ForwardContext:
     ubatch_slices: Optional[UBatchSlices] = None
 
     def __post_init__(self):
-        assert self.cudagraph_runtime_mode in [
-            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
+        assert self.cudagraph_runtime_mode.valid_runtime_modes(), \
             f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
 
 
diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py
index ea4fba8eeea6..2dbe2bfb8082 100644
--- a/vllm/v1/cudagraph_dispatcher.py
+++ b/vllm/v1/cudagraph_dispatcher.py
@@ -22,10 +22,10 @@ class CudagraphDispatcher:
 
     At runtime, the dispatch method generates the runtime cudagraph mode
     (FULL, PIECEWISE, or NONE for no cudagraph) and the valid key (batch
     descriptor)
-    based on the input key. After dispatching (communicate via forward context),
-    the cudagraph wrappers will trust the dispatch key to do either capturing
-    or replaying (if mode matched), or pass through to the underlying runnable
-    without cudagraph (if mode no match or mode is NONE).
+    based on the input key. After dispatching (communicated via forward
+    context), the cudagraph wrappers will trust the dispatch key to either
+    capture or replay (if the mode matches), or pass through to the underlying
+    runnable without cudagraph (if the mode does not match or mode is NONE).
     """
 
     def __init__(self, vllm_config: VllmConfig):
@@ -57,19 +57,15 @@ def __init__(self, vllm_config: VllmConfig):
     def add_cudagraph_key(self, runtime_mode: CUDAGraphMode,
                           batch_descriptor: BatchDescriptor):
         assert runtime_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
-            f"Invalid cudagraph runtime mode: {runtime_mode}"
+            f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
         self.cudagraph_keys[runtime_mode].add(batch_descriptor)
 
     def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
                                   uniform_decode_query_len: int):
         # This should be called only after attention backend is initialized.
 
-        # Note: we create all valid keys possible for cudagraph but do not
-        # guarantee all keys would be used. For example, we create keys for
-        # piecewise cudagraphs when it is piecewise compilation, which is always
-        # valid, but for attention backend support unified routine, we may not
-        # trigger capturing/replaying the piecewise cudagraphs depending on
-        # CompilationConfig.cudagraph_mode. In addition, if we allow lazy
+        # Note: we create all valid keys for cudagraph here but do not
+        # guarantee all keys would be used. For example, if we allow lazy
         # capturing in future PR, some keys may never be triggered.
         if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
             for bs in self.compilation_config.cudagraph_capture_sizes:
@@ -94,10 +90,13 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode,
         self.keys_initialized = True
 
     def dispatch(
-            self, batch_descriptor: BatchDescriptor
+        self,
+        batch_descriptor: BatchDescriptor,
+        use_cascade_attn: bool = False
     ) -> tuple[CUDAGraphMode, Optional[BatchDescriptor]]:
         """
-        Given a batch descriptor, dispatch to a cudagraph mode.
+        Given conditions (e.g., batch descriptor and whether cascade attention
+        is used), dispatch to a cudagraph runtime mode and the valid batch descriptor.
         A new batch descriptor is returned as we might dispatch a uniform batch
         to a graph that supports a more general batch (uniform to non-uniform).
         """
@@ -107,14 +106,16 @@ def dispatch(
                 "initialized. No cudagraph will be used.")
             return CUDAGraphMode.NONE, None
 
-        # check if key exists for full cudagraph
-        if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
-            return CUDAGraphMode.FULL, batch_descriptor
-
-        # otherwise, check if non-uniform key exists
         non_uniform_key = batch_descriptor.non_uniform
-        if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
-            return CUDAGraphMode.FULL, non_uniform_key
+        # if a batch uses cascade attention, bypass checking full cudagraphs
+        if not use_cascade_attn:
+            # check if key exists for full cudagraph
+            if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, batch_descriptor
+
+            # otherwise, check if non-uniform key exists
+            if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
+                return CUDAGraphMode.FULL, non_uniform_key
 
         # also check if non-uniform key exists for more "general"
         # piecewise cudagraph
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index cbf439aa697b..5ca01051452e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -923,11 +923,13 @@ def _prepare_inputs(
     ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
                Optional[SpecDecodeMetadata], np.ndarray,
                Optional[CommonAttentionMetadata], int, Optional[UBatchSlices],
-               Optional[torch.Tensor]]:
+               Optional[torch.Tensor], bool]:
         """
         :return: tuple[
             attn_metadata: layer-to-attention_metadata mapping,
-            logits_indices, spec_decode_metadata
+            logits_indices, spec_decode_metadata,
+            num_scheduled_tokens, spec_decode_common_attn_metadata,
+            max_num_scheduled_tokens, use_cascade_attn
         ]
         """
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -1131,6 +1133,7 @@ def _prepare_inputs(
         attn_metadata: PerLayerAttnMetadata = {}
         if ubatch_slices is not None:
             attn_metadata = [dict() for _ in range(len(ubatch_slices))]
+        use_cascade_attn = False
 
         # Used in the below loop.
         query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@@ -1247,9 +1250,15 @@ def _prepare_inputs(
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
                     **extra_attn_metadata_args)
+                use_cascade_attn |= getattr(attn_metadata_i, "use_cascade",
+                                            False)
 
                 for layer_name in attn_group.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
 
+        # disable cascade attention when DBO is enabled
+        if ubatch_slices is not None:
+            use_cascade_attn = False
+
         # Hot-Swap lora model
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
@@ -1257,7 +1266,7 @@ def _prepare_inputs(
         return (attn_metadata, logits_indices, spec_decode_metadata,
                 num_scheduled_tokens, spec_decode_common_attn_metadata,
                 max_num_scheduled_tokens, ubatch_slices,
-                num_tokens_after_padding)
+                num_tokens_after_padding, use_cascade_attn)
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -2247,8 +2256,8 @@ def execute_model(
         # Prepare the decoder inputs.
         (attn_metadata, logits_indices, spec_decode_metadata,
          num_scheduled_tokens_np, spec_decode_common_attn_metadata,
-         max_query_len, ubatch_slices, num_tokens_after_padding
-         ) = self._prepare_inputs(scheduler_output)
+         max_query_len, ubatch_slices, num_tokens_after_padding,
+         use_cascade_attn) = self._prepare_inputs(scheduler_output)
 
         (
             num_scheduled_tokens,
@@ -2269,7 +2278,8 @@ def execute_model(
             batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                                uniform_decode=uniform_decode)
             cudagraph_runtime_mode, batch_descriptor = \
-                self.cudagraph_dispatcher.dispatch(batch_descriptor)
+                self.cudagraph_dispatcher.dispatch(batch_descriptor,
+                                                   use_cascade_attn)
 
         # This is currently to get around the assert in the DPMetadata
         # where it wants `num_tokens_across_dp` to align with `num_tokens`
@@ -2697,16 +2707,15 @@ def reload_weights(self) -> None:
             "Cannot reload weights before model is loaded."
         model_loader = get_model_loader(self.load_config)
         logger.info("Reloading weights inplace...")
-        model = self.get_model()
-        model_loader.load_weights(model, model_config=self.model_config)
+        model_loader.load_weights(self.get_model(),
+                                  model_config=self.model_config)
 
     def save_tensorized_model(
         self,
         tensorizer_config: "TensorizerConfig",
     ) -> None:
-        model = self.get_model()
         TensorizerLoader.save_model(
-            model,
+            self.get_model(),
             tensorizer_config=tensorizer_config,
             model_config=self.model_config,
         )
@@ -2922,9 +2931,8 @@ def _dummy_run(
             (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
-            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
-        }
+        assert cudagraph_runtime_mode is None or \
+            cudagraph_runtime_mode.valid_runtime_modes()
 
         # If cudagraph_mode.decode_mode() == FULL and
         # cudagraph_mode.separate_routine(). This means that we are using
@@ -3108,7 +3116,8 @@ def _dummy_run(
             # filter out the valid batch descriptor
             _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
                 BatchDescriptor(num_tokens=num_tokens,
-                                uniform_decode=uniform_decode))
+                                uniform_decode=uniform_decode)) \
+                if not is_profile else (CUDAGraphMode.NONE, None)
             if cudagraph_runtime_mode is not None:
                 # we allow forcing NONE when the dispatcher disagrees to support
                 # warm ups for cudagraph capture
@@ -3448,8 +3457,8 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                             cudagraph_runtime_mode: CUDAGraphMode,
                             uniform_decode: bool):
         assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \
-            cudagraph_runtime_mode in [CUDAGraphMode.FULL,
-                                       CUDAGraphMode.PIECEWISE]
+            cudagraph_runtime_mode.valid_runtime_modes(), \
+            f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
 
         # Only rank 0 should print progress bar during capture
         if is_global_first_rank():
@@ -3580,6 +3589,12 @@ def create_attn_groups(
         self.calculate_reorder_batch_threshold()
 
     def initialize_cudagraph_capture(self) -> None:
+        """
+        Resolve the cudagraph_mode when there are multiple attention
+        backends with potentially conflicting CUDA graph support.
+        Then initialize the cudagraph_dispatcher based on the resolved
+        cudagraph_mode.
+        """
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_builder_name = None
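
Note on the dispatch change above: the sketch below is an illustrative, self-contained re-implementation of the dispatch ordering this diff introduces in CudagraphDispatcher.dispatch(), not vLLM's actual classes. The CUDAGraphMode and BatchDescriptor stand-ins are simplified assumptions; the point is only that cascade attention bypasses the FULL-cudagraph keys, then the piecewise keys are tried, and finally NONE (eager) is returned.

# Illustrative sketch only: simplified stand-ins, not vLLM's real classes.
from dataclasses import dataclass
from enum import Enum


class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    uniform_decode: bool = False

    @property
    def non_uniform(self) -> "BatchDescriptor":
        # relax a uniform-decode batch to the more general non-uniform key
        return BatchDescriptor(self.num_tokens, uniform_decode=False)


def dispatch(keys, batch: BatchDescriptor, use_cascade_attn: bool = False):
    """Mirror of the new dispatch order: cascade attention bypasses the
    FULL-cudagraph keys, then PIECEWISE is tried, then NONE (eager)."""
    non_uniform = batch.non_uniform
    if not use_cascade_attn:
        if batch in keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, batch
        if non_uniform in keys[CUDAGraphMode.FULL]:
            return CUDAGraphMode.FULL, non_uniform
    if non_uniform in keys[CUDAGraphMode.PIECEWISE]:
        return CUDAGraphMode.PIECEWISE, non_uniform
    return CUDAGraphMode.NONE, None


# FULL_AND_PIECEWISE-style keys for a single capture size of 8
keys = {
    CUDAGraphMode.FULL: {BatchDescriptor(8, uniform_decode=True)},
    CUDAGraphMode.PIECEWISE: {BatchDescriptor(8)},
}
print(dispatch(keys, BatchDescriptor(8, uniform_decode=True)))
# full cudagraph is used for the uniform-decode batch
print(dispatch(keys, BatchDescriptor(8, uniform_decode=True),
               use_cascade_attn=True))
# with cascade attention, dispatch falls back to the piecewise key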