@@ -24,7 +24,7 @@
 from lightning_lite.utilities.imports import (
     _TORCH_GREATER_EQUAL_1_12,
     _TORCH_GREATER_EQUAL_1_13,
-    _TORCH_GREATER_EQUAL_1_14,
+    _TORCH_GREATER_EQUAL_2_0,
 )

 _log = logging.getLogger(__name__)
@@ -86,12 +86,12 @@ def _get_all_available_cuda_gpus() -> List[int]:
     return list(range(num_cuda_devices()))


-# TODO: Remove once minimum supported PyTorch version is 1.14
+# TODO: Remove once minimum supported PyTorch version is 2.0
 @contextmanager
 def _patch_cuda_is_available() -> Generator:
     """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
     possible."""
-    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_1_14:
+    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_2_0:
         # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
         # otherwise, patching is_available could lead to attribute errors or infinite recursion
         orig_check = torch.cuda.is_available
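The patch applied here follows the standard save/swap/restore pattern: keep a reference to the original torch.cuda.is_available, replace it with the NVML-based check, and put the original back in a finally block so the patch cannot leak if the wrapped code raises. A minimal sketch of that pattern, assuming torch's private torch.cuda._device_count_nvml helper (it returns -1 when the NVML query itself fails, which is why the diff guards on >= 0 before patching):

from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def _patched_is_available() -> Generator:
    # Keep a reference to the original so it can always be restored.
    orig_check = torch.cuda.is_available
    try:
        # The NVML-based check asks the driver for a device count without
        # initializing a CUDA context, so it remains safe in forked workers.
        torch.cuda.is_available = lambda: torch.cuda._device_count_nvml() > 0
        yield
    finally:
        # Restore unconditionally, even if the block above raised.
        torch.cuda.is_available = orig_check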
@@ -127,7 +127,7 @@ def is_cuda_available() -> bool:
     if the platform allows it.
     """
     # We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning_lite.__init__.py
-    return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_1_14 else num_cuda_devices() > 0
+    return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0


 # TODO: Remove once minimum supported PyTorch version is 1.13
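On PyTorch 2.0 and newer the NVML-based check is built into torch.cuda.is_available() itself once PYTORCH_NVML_BASED_CUDA_CHECK=1 is set, which is why the branch above can call it directly; older versions fall back to counting devices through the patched path. A self-contained sketch of that dispatch, with simplified version parsing and torch.cuda.device_count() standing in for lightning_lite's num_cuda_devices():

import os

# Set before any CUDA query; per the comment in the diff, lightning_lite
# sets this in its package __init__.py.
os.environ.setdefault("PYTORCH_NVML_BASED_CUDA_CHECK", "1")

import torch
from packaging.version import Version

# Simplified stand-in for lightning_lite's _TORCH_GREATER_EQUAL_2_0 flag.
_TORCH_GREATER_EQUAL_2_0 = Version(torch.__version__) >= Version("2.0.0")


def is_cuda_available() -> bool:
    # torch >= 2.0: is_available() already honors the NVML-based check.
    # Older torch: count devices via the fork-safe path instead.
    return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else torch.cuda.device_count() > 0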