From cee34810b16d8841a02949fadf658095e65171f9 Mon Sep 17 00:00:00 2001 From: Luuk Kempen Date: Tue, 23 Sep 2025 14:17:31 +0200 Subject: [PATCH 1/2] Fix pkg_resources deprecation warning --- requirements.txt | 1 + torchtnt/utils/version.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 67f474827e..cbf69048ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ numpy==1.24.4 fsspec tensorboard packaging +importlib-metadata;python_version<'3.8' psutil pyre_extensions typing_extensions diff --git a/torchtnt/utils/version.py b/torchtnt/utils/version.py index 0cb7e7408f..0219112174 100644 --- a/torchtnt/utils/version.py +++ b/torchtnt/utils/version.py @@ -9,10 +9,14 @@ import platform -import pkg_resources import torch from packaging.version import Version +try: + import importlib.metadata as importlib_metadata +except ImportError: + import importlib_metadata + def is_windows() -> bool: """ @@ -48,8 +52,8 @@ def get_torch_version() -> Version: if hasattr(torch, "__version__"): pkg_version = Version(torch.__version__) else: - # try pkg_resources to infer version - pkg_version = Version(pkg_resources.get_distribution("torch").version) + # try importlib.metadata to infer version + pkg_version = Version(importlib_metadata.version("torch")) except TypeError as e: raise TypeError("PyTorch version could not be detected automatically.") from e From 615a9e764cb1d5cc4d7ec0b3408cb60ac0bce0ea Mon Sep 17 00:00:00 2001 From: Luuk Kempen Date: Tue, 23 Sep 2025 15:07:49 +0200 Subject: [PATCH 2/2] Add tests for getting Torch version via importlib --- tests/utils/test_version.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_version.py b/tests/utils/test_version.py index 5cc2bef1bd..860fde22ed 100644 --- a/tests/utils/test_version.py +++ b/tests/utils/test_version.py @@ -8,11 +8,16 @@ # pyre-strict import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, PropertyMock, patch import torch from packaging.version import Version +try: + import importlib.metadata as importlib_metadata +except ImportError: + import importlib_metadata + from torchtnt.utils import version @@ -47,6 +52,18 @@ def test_get_torch_version(self) -> None: self.assertNotEqual(version.get_torch_version(), Version("1.8.3")) self.assertEqual(version.get_torch_version(), Version("1.12.0")) + def test_get_torch_version_importlib(self) -> None: + with patch.object(torch, "__version__", new_callable=PropertyMock): + delattr(torch, "__version__") + + with patch.object(importlib_metadata, "version", return_value="1.8.3"): + self.assertEqual(version.get_torch_version(), Version("1.8.3")) + self.assertNotEqual(version.get_torch_version(), Version("1.12.0")) + + with patch.object(importlib_metadata, "version", return_value="1.12.0"): + self.assertNotEqual(version.get_torch_version(), Version("1.8.3")) + self.assertEqual(version.get_torch_version(), Version("1.12.0")) + def test_torch_version_comparators(self) -> None: with patch.object(torch, "__version__", "2.0.0a0"): self.assertFalse(version.is_torch_version_geq("2.1.0"))