Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ numpy==1.24.4
fsspec
tensorboard
packaging
importlib-metadata;python_version<'3.8'
psutil
pyre_extensions
typing_extensions
Expand Down
19 changes: 18 additions & 1 deletion tests/utils/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@
# pyre-strict

import unittest
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, PropertyMock, patch

import torch
from packaging.version import Version

try:
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata

from torchtnt.utils import version


Expand Down Expand Up @@ -47,6 +52,18 @@ def test_get_torch_version(self) -> None:
self.assertNotEqual(version.get_torch_version(), Version("1.8.3"))
self.assertEqual(version.get_torch_version(), Version("1.12.0"))

def test_get_torch_version_importlib(self) -> None:
with patch.object(torch, "__version__", new_callable=PropertyMock):
delattr(torch, "__version__")

with patch.object(importlib_metadata, "version", return_value="1.8.3"):
self.assertEqual(version.get_torch_version(), Version("1.8.3"))
self.assertNotEqual(version.get_torch_version(), Version("1.12.0"))

with patch.object(importlib_metadata, "version", return_value="1.12.0"):
self.assertNotEqual(version.get_torch_version(), Version("1.8.3"))
self.assertEqual(version.get_torch_version(), Version("1.12.0"))

def test_torch_version_comparators(self) -> None:
with patch.object(torch, "__version__", "2.0.0a0"):
self.assertFalse(version.is_torch_version_geq("2.1.0"))
10 changes: 7 additions & 3 deletions torchtnt/utils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@

import platform

import pkg_resources
import torch
from packaging.version import Version

try:
import importlib.metadata as importlib_metadata
except ImportError:
import importlib_metadata


def is_windows() -> bool:
"""
Expand Down Expand Up @@ -48,8 +52,8 @@ def get_torch_version() -> Version:
if hasattr(torch, "__version__"):
pkg_version = Version(torch.__version__)
else:
# try pkg_resources to infer version
pkg_version = Version(pkg_resources.get_distribution("torch").version)
# try importlib.metadata to infer version
pkg_version = Version(importlib_metadata.version("torch"))
except TypeError as e:
raise TypeError("PyTorch version could not be detected automatically.") from e

Expand Down