From d168bac75411def79d7dbbd226cafe7a01af1cf4 Mon Sep 17 00:00:00 2001
From: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Date: Thu, 23 Oct 2025 16:13:36 +0000
Subject: [PATCH 1/4] move resolving cudagraph_mode before metadata_builder init

Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 38 ++++++++++++++++++++----------
 1 file changed, 25 insertions(+), 13 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ebc8cfe92deb..c8578f1b5995 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3751,8 +3751,6 @@ def capture_model(self) -> int:
                 "ensure `cudagraph_mode` was not manually set to `NONE`"
             )
             return 0
-        else:
-            self.initialize_cudagraph_capture()

         compilation_counter.num_gpu_runner_capture_triggers += 1

@@ -3926,7 +3924,7 @@ class AttentionGroupKey(NamedTuple):

         def get_attn_backends_for_group(
             kv_cache_group_spec: KVCacheGroupSpec,
-        ) -> dict[AttentionGroupKey, list[str]]:
+        ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
             layers = get_layers_from_vllm_config(
                 self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names
             )
@@ -3955,7 +3953,10 @@ def get_attn_backends_for_group(
                     attn_backend, layer_kv_cache_spec
                 )
                 attn_backend_layers[key].append(layer_name)
-            return {attn_backends[k]: v for k, v in attn_backend_layers.items()}
+            return (
+                {attn_backends[k]: v for k, v in attn_backend_layers.items()},
+                set(group_key.attn_backend for group_key in attn_backends.values()),
+            )

         def create_attn_groups(
             attn_backends_map: dict[AttentionGroupKey, list[str]],
@@ -3976,14 +3977,25 @@ def create_attn_groups(
                 attn_groups.append(attn_group)
             return attn_groups

+        attention_backend_maps = []
+        attention_backend_set: set[type[AttentionBackend]] = set()
         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
-            self.attn_groups.append(create_attn_groups(attn_backends))
+            attention_backend_maps.append(attn_backends[0])
+            attention_backend_set.union(attn_backends[1])
+
+        # Resolve cudagraph_mode before actually initializing metadata_builders
+        self._check_and_update_cudagraph_mode(attention_backend_set)
+
+        for attn_backends_map in attention_backend_maps:
+            self.attn_groups.append(create_attn_groups(attn_backends_map))

         # Calculate reorder batch threshold (if needed)
         self.calculate_reorder_batch_threshold()

-    def initialize_cudagraph_capture(self) -> None:
+    def _check_and_update_cudagraph_mode(
+        self, attention_backends: set[type[AttentionBackend]]
+    ) -> None:
         """
         Resolve the cudagraph_mode when there are multiple attention
         backends with potential conflicting CUDA graph support.
@@ -3993,11 +4005,11 @@ def initialize_cudagraph_capture(self) -> None:
         min_cg_support = AttentionCGSupport.ALWAYS
         min_cg_builder_name = None

-        for attn_group in self._attn_group_iterator():
-            builder = attn_group.get_metadata_builder()
-            if builder.cudagraph_support.value < min_cg_support.value:
-                min_cg_support = builder.cudagraph_support
-                min_cg_builder_name = builder.__class__.__name__
+        for attn_backend in attention_backends:
+            builder_cls = attn_backend.get_builder_cls()
+            if builder_cls.cudagraph_support.value < min_cg_support.value:
+                min_cg_support = builder_cls.cudagraph_support
+                min_cg_builder_name = builder_cls.__name__
         # Flexible resolve the cudagraph mode
         cudagraph_mode = self.compilation_config.cudagraph_mode
         # check cudagraph for mixed batch is supported
@@ -4100,8 +4112,8 @@ def initialize_cudagraph_capture(self) -> None:
                 "and make sure compilation mode is VLLM_COMPILE"
             )

-        # Trigger cudagraph dispatching keys initialization here (after
-        # initializing attn backends).
+        # Trigger cudagraph dispatching keys initialization after the
+        # cudagraph mode has been resolved.
         self.cudagraph_dispatcher.initialize_cudagraph_keys(
             self.compilation_config.cudagraph_mode, self.uniform_decode_query_len
         )

From 38f73df1283ff269a7c538eaa07a2133d9942df7 Mon Sep 17 00:00:00 2001
From: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Date: Thu, 23 Oct 2025 16:37:36 +0000
Subject: [PATCH 2/4] fix

Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c8578f1b5995..6d124f0f4fbc 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3982,7 +3982,7 @@ def create_attn_groups(
         for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
             attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
             attention_backend_maps.append(attn_backends[0])
-            attention_backend_set.union(attn_backends[1])
+            attention_backend_set.update(attn_backends[1])

         # Resolve cudagraph_mode before actually initializing metadata_builders
         self._check_and_update_cudagraph_mode(attention_backend_set)
@@ -4003,13 +4003,13 @@ def _check_and_update_cudagraph_mode(
         cudagraph_mode.
""" min_cg_support = AttentionCGSupport.ALWAYS - min_cg_builder_name = None + min_cg_backend_name = None for attn_backend in attention_backends: builder_cls = attn_backend.get_builder_cls() if builder_cls.cudagraph_support.value < min_cg_support.value: min_cg_support = builder_cls.cudagraph_support - min_cg_builder_name = builder_cls.__name__ + min_cg_backend_name = attn_backend.__name__ # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported @@ -4019,7 +4019,7 @@ def _check_and_update_cudagraph_mode( ): msg = ( f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " + f"with {min_cg_backend_name} backend (support: " f"{min_cg_support})" ) if min_cg_support == AttentionCGSupport.NEVER: @@ -4050,7 +4050,7 @@ def _check_and_update_cudagraph_mode( ): msg = ( f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " + f"with {min_cg_backend_name} backend (support: " f"{min_cg_support})" ) if self.compilation_config.mode == CompilationMode.VLLM_COMPILE and ( @@ -4084,7 +4084,7 @@ def _check_and_update_cudagraph_mode( msg = ( f"CUDAGraphMode.{cudagraph_mode.name} is not supported" f" with spec-decode for attention backend " - f"{min_cg_builder_name} (support: {min_cg_support})" + f"{min_cg_backend_name} (support: {min_cg_support})" ) if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" @@ -4106,7 +4106,7 @@ def _check_and_update_cudagraph_mode( ): raise ValueError( f"CUDAGraphMode.{cudagraph_mode.name} is not " - f"supported with {min_cg_builder_name} backend (" + f"supported with {min_cg_backend_name} backend (" f"support:{min_cg_support}) " "; please try cudagraph_mode=PIECEWISE, " "and make sure compilation mode is VLLM_COMPILE" From 16ac4c1995c530bedcbecfe9c1b5e097d91642be Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 23 Oct 2025 16:52:46 +0000 Subject: [PATCH 3/4] add comments Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- tests/compile/test_fusions_e2e.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/compile/test_fusions_e2e.py b/tests/compile/test_fusions_e2e.py index 50271e2a4d70..d66c60ccb5b2 100644 --- a/tests/compile/test_fusions_e2e.py +++ b/tests/compile/test_fusions_e2e.py @@ -132,6 +132,9 @@ def test_attn_quant( mode = CUDAGraphMode.FULL_AND_PIECEWISE splitting_ops: list[str] | None = None else: + # FIXME: Llama-4-Scout-17B-16E-Instruct-FP8 + FlashInfer + Blackwell end at + # CUDAGraphMode.NONE here because it derives an attention backend that + # does not support full cudagraphs mode = CUDAGraphMode.FULL_DECODE_ONLY splitting_ops = [] From a15bcf4ae6c1f204aab2e8170fe8d91c25949164 Mon Sep 17 00:00:00 2001 From: fhl2000 <63384265+fhl2000@users.noreply.github.com> Date: Thu, 23 Oct 2025 17:34:02 +0000 Subject: [PATCH 4/4] fix doc Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> --- docs/design/cuda_graphs.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/design/cuda_graphs.md b/docs/design/cuda_graphs.md index e511eb25cb7a..b56cf61e782c 100644 --- a/docs/design/cuda_graphs.md +++ b/docs/design/cuda_graphs.md @@ -167,7 +167,7 @@ class AttentionCGSupport(enum.Enum): """NO CUDA Graphs support""" ``` -Suppose we have hybrid attention backends (e.g., in mamba mixer models). 
In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code of [initialize_cudagraph_capture][vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_cudagraph_capture]. +Suppose we have hybrid attention backends (e.g., in mamba mixer models). In that case, we seek the minimum capability of all backends to determine the final capability of the model, and we might resolve the incompatible CUDA Graphs mode by downgrading the mode to the best fit one. For example, downgrading `FULL` mode to `FULL_AND_PIECEWISE` mode if the minimum capability is `UNIFORM_BATCH`, or `PIECEWISE` mode if the minimum capability is `NEVER` for -O3 compilation mode. For the complete fallback policy, please see the code for [this][vllm.v1.worker.gpu_model_runner.GPUModelRunner._check_and_update_cudagraph_mode]. The following table lists backends that support full CUDA Graphs at the time of writing.
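To make the downgrade policy above concrete, here is a minimal, self-contained sketch. It is not the actual vLLM implementation: the `resolve_cudagraph_mode` helper and the trimmed-down enums below are illustrative assumptions standing in for the logic that `_check_and_update_cudagraph_mode` applies to the real backend classes.

```python
import enum


class AttentionCGSupport(enum.IntEnum):
    """Simplified mirror of the capability levels listed above."""

    NEVER = 0
    UNIFORM_SINGLE_TOKEN_DECODE = 1
    UNIFORM_BATCH = 2
    ALWAYS = 3


class CUDAGraphMode(enum.Enum):
    NONE = enum.auto()
    PIECEWISE = enum.auto()
    FULL = enum.auto()
    FULL_AND_PIECEWISE = enum.auto()


def resolve_cudagraph_mode(
    requested: CUDAGraphMode,
    backend_support: list[AttentionCGSupport],
    piecewise_compilation: bool,
) -> CUDAGraphMode:
    """Downgrade `requested` to the best mode the least-capable backend can serve."""
    # The model's effective capability is the minimum across all attention backends.
    min_support = min(backend_support, default=AttentionCGSupport.ALWAYS)

    if requested != CUDAGraphMode.FULL or min_support == AttentionCGSupport.ALWAYS:
        return requested
    if min_support >= AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE:
        # Full graphs only cover uniform (decode-like) batches; mixed
        # prefill-decode batches still need piecewise graphs around attention.
        return CUDAGraphMode.FULL_AND_PIECEWISE
    # min_support == NEVER: piecewise graphs are only possible when attention
    # stays in the splitting ops under -O3 (VLLM_COMPILE) compilation.
    return CUDAGraphMode.PIECEWISE if piecewise_compilation else CUDAGraphMode.NONE


# Example: FULL with a UNIFORM_BATCH-only backend degrades to FULL_AND_PIECEWISE.
assert (
    resolve_cudagraph_mode(
        CUDAGraphMode.FULL,
        [AttentionCGSupport.ALWAYS, AttentionCGSupport.UNIFORM_BATCH],
        piecewise_compilation=True,
    )
    == CUDAGraphMode.FULL_AND_PIECEWISE
)
```

The real resolution path additionally accounts for spec-decode support and the compilation mode before settling on the final `CUDAGraphMode`, as the error messages in the patches above show.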