diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index e31a78ba33ba..9bb8087a511d 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -503,7 +503,7 @@ def __post_init__(self):
         if self.compilation_config.pass_config.enable_sequence_parallelism:
             self.compilation_config.custom_ops.append("+rms_norm")
 
-        if current_platform.is_cuda_alike() or current_platform.is_xpu():
+        if current_platform.support_static_graph_mode():
             # if cudagraph_mode is not explicitly set by users, set default
             # value
             if self.compilation_config.cudagraph_mode is None:
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 05f129f513a0..7baa5a9742f4 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -498,6 +498,10 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
     def support_hybrid_kv_cache(cls) -> bool:
         return True
 
+    @classmethod
+    def support_static_graph_mode(cls) -> bool:
+        return True
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index c43580ac5da1..8a05c84d4242 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -587,6 +587,13 @@ def support_hybrid_kv_cache(cls) -> bool:
         """
         return False
 
+    @classmethod
+    def support_static_graph_mode(cls) -> bool:
+        """
+        Returns whether static graph mode is supported by the current platform.
+        """
+        return False
+
     @classmethod
     def use_sync_weight_loader(cls) -> bool:
         """
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 9470434aa428..0c7b9c2a4abf 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -477,3 +477,7 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
     @classmethod
     def support_hybrid_kv_cache(cls) -> bool:
         return True
+
+    @classmethod
+    def support_static_graph_mode(cls) -> bool:
+        return True
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 4d3bef4b4294..eb591ae4454e 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -113,12 +113,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # lazy import to avoid circular import
         from vllm.config import CompilationLevel, CUDAGraphMode
         compilation_config = vllm_config.compilation_config
-        if compilation_config.cudagraph_mode is None or \
-            compilation_config.cudagraph_mode.max_cudagraph_mode() \
-                != CUDAGraphMode.NONE:
-            logger.info("[XPU] CUDA graph is not supported on XPU, disabling "
-                        "cudagraphs. Fallback to cudagraph_mode=NONE")
-            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
+
+        assert compilation_config.cudagraph_mode == CUDAGraphMode.NONE, \
+            "CUDA graph mode should be NONE on XPU"
 
         if vllm_config.lora_config is not None:
             compilation_config.level = CompilationLevel.NO_COMPILATION
@@ -169,6 +166,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
     def support_hybrid_kv_cache(cls) -> bool:
         return True
 
+    @classmethod
+    def support_static_graph_mode(cls) -> bool:
+        return False
+
     @classmethod
     def is_pin_memory_available(cls):
         return True
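
For illustration, a minimal sketch of how an out-of-tree platform would opt into the new hook. `MyPlatform` is a hypothetical name, and only the `support_static_graph_mode()` override is shown, not a complete platform implementation:

    # Sketch only: a hypothetical platform plugin opting into static graphs.
    from vllm.platforms.interface import Platform

    class MyPlatform(Platform):
        @classmethod
        def support_static_graph_mode(cls) -> bool:
            # Return True only if the device can capture and replay static
            # graphs (CUDA-graph-style execution); the base Platform class
            # defaults to False, so platforms without the capability (like
            # XPU here) need no special handling in the config code.
            return True

With the interface default being False, XPU's explicit `return False` override is technically redundant, but it documents the intent next to the new assert in `check_and_update_config`.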