diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py
index 37d8ae0c08bf..8679d5c3019b 100644
--- a/tests/compile/test_config.py
+++ b/tests/compile/test_config.py
@@ -5,6 +5,15 @@
 import vllm
 from vllm.compilation.counter import compilation_counter
 from vllm.config import VllmConfig
+from vllm.utils import _is_torch_equal_or_newer
+
+
+def test_version():
+    assert _is_torch_equal_or_newer('2.8.0.dev20250624+cu128', '2.8.0.dev')
+    assert _is_torch_equal_or_newer('2.8.0a0+gitc82a174', '2.8.0.dev')
+    assert _is_torch_equal_or_newer('2.8.0', '2.8.0.dev')
+    assert _is_torch_equal_or_newer('2.8.1', '2.8.0.dev')
+    assert not _is_torch_equal_or_newer('2.7.1', '2.8.0.dev')
 
 
 def test_use_cudagraphs_dynamic(monkeypatch):
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 8bb8c3a2a2e4..a2bb053cec4a 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -32,7 +32,7 @@
 def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
     if compilation_config.use_inductor:
         if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
-                "2.8.0a"):
+                "2.8.0.dev"):
             logger.debug("Using InductorStandaloneAdaptor")
             return InductorStandaloneAdaptor()
         else:
diff --git a/vllm/model_executor/layers/quantization/torchao.py b/vllm/model_executor/layers/quantization/torchao.py
index 9c909a3a430c..a4e0356c0268 100644
--- a/vllm/model_executor/layers/quantization/torchao.py
+++ b/vllm/model_executor/layers/quantization/torchao.py
@@ -44,14 +44,14 @@ def __init__(self,
         """
         # TorchAO quantization relies on tensor subclasses. In order,
         # to enable proper caching this needs standalone compile
-        if is_torch_equal_or_newer("2.8.0a"):
+        if is_torch_equal_or_newer("2.8.0.dev"):
             os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
             logger.info(
                 "Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
 
         # TODO: remove after the torch dependency is updated to 2.8
         if is_torch_equal_or_newer(
-                "2.7.0") and not is_torch_equal_or_newer("2.8.0a"):
+                "2.7.0") and not is_torch_equal_or_newer("2.8.0.dev"):
             os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
             logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
diff --git a/vllm/utils.py b/vllm/utils.py
index 34be4d52c483..fdefda901c4d 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -2919,8 +2919,13 @@ def is_torch_equal_or_newer(target: str) -> bool:
         Whether the condition meets.
     """
     try:
-        torch_version = version.parse(str(torch.__version__))
-        return torch_version >= version.parse(target)
+        return _is_torch_equal_or_newer(str(torch.__version__), target)
     except Exception:
         # Fallback to PKG-INFO to load the package info, needed by the doc gen.
         return Version(importlib.metadata.version('torch')) >= Version(target)
+
+
+# Helper function used in testing.
+def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
+    torch_version = version.parse(torch_version)
+    return torch_version >= version.parse(target)
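
Note (not part of the patch): the switch from a "2.8.0a" target to "2.8.0.dev" follows PEP 440 ordering as implemented by packaging.version, where dev releases sort before pre-releases of the same version; with the old target, nightly "2.8.0.devYYYYMMDD" wheels did not satisfy the check. A minimal sketch of that ordering, assuming only the packaging library:

    from packaging import version

    # Dev releases precede alpha/beta/rc and final releases of the same
    # version, so a ".dev" lower bound admits nightlies, alpha builds,
    # and the final release alike.
    assert version.parse("2.8.0.dev") < version.parse("2.8.0a0")
    assert version.parse("2.8.0.dev") <= version.parse("2.8.0.dev20250624+cu128")
    # Old behavior: a nightly build compared below the "2.8.0a" target.
    assert version.parse("2.8.0.dev20250624+cu128") < version.parse("2.8.0a")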