
Commit a58639c

Nightly PyTorch version is now 2.0 (#16017)
1 parent 5375982 commit a58639c

22 files changed: +33 -40 lines changed

src/lightning_lite/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 _logger.propagate = False


-# In PyTorch 1.14+, setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()`
+# In PyTorch 2.0+, setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()`
 # to use an NVML-based implementation that doesn't poison forks.
 # https://github.com/pytorch/pytorch/issues/83973
 os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
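
A minimal sketch of the behavior this variable opts into (assuming a PyTorch build that honors `PYTORCH_NVML_BASED_CUDA_CHECK`, i.e. a 2.0 nightly or newer): the default `torch.cuda.is_available()` initializes the CUDA driver, which poisons any child process later created with `fork()`, while the NVML-based check queries the driver without that side effect. The parent/worker structure below is illustrative, not part of the diff.

```python
import os

# Must be set before torch first checks for CUDA; lightning_lite does this at import time.
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"

import multiprocessing as mp

import torch


def worker() -> None:
    # Safe in a forked child: the parent's availability check did not create a CUDA context.
    print("CUDA available in child:", torch.cuda.is_available())


if __name__ == "__main__":
    print("CUDA available in parent:", torch.cuda.is_available())
    ctx = mp.get_context("fork")  # fork is POSIX-only; Windows has no fork start method
    ctx.Process(target=worker).start()
```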

src/lightning_lite/accelerators/cuda.py

Lines changed: 4 additions & 4 deletions
@@ -24,7 +24,7 @@
 from lightning_lite.utilities.imports import (
     _TORCH_GREATER_EQUAL_1_12,
     _TORCH_GREATER_EQUAL_1_13,
-    _TORCH_GREATER_EQUAL_1_14,
+    _TORCH_GREATER_EQUAL_2_0,
 )

 _log = logging.getLogger(__name__)
@@ -86,12 +86,12 @@ def _get_all_available_cuda_gpus() -> List[int]:
     return list(range(num_cuda_devices()))


-# TODO: Remove once minimum supported PyTorch version is 1.14
+# TODO: Remove once minimum supported PyTorch version is 2.0
 @contextmanager
 def _patch_cuda_is_available() -> Generator:
     """Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
     possible."""
-    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_1_14:
+    if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_2_0:
         # we can safely patch is_available if both torch has CUDA compiled and the NVML count is succeeding
         # otherwise, patching is_available could lead to attribute errors or infinite recursion
         orig_check = torch.cuda.is_available
@@ -127,7 +127,7 @@ def is_cuda_available() -> bool:
     if the platform allows it.
     """
     # We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning_lite.__init__.py
-    return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_1_14 else num_cuda_devices() > 0
+    return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_2_0 else num_cuda_devices() > 0


 # TODO: Remove once minimum supported PyTorch version is 1.13
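
`_patch_cuda_is_available` above uses the standard monkeypatch-and-restore pattern. A minimal sketch of that pattern under simplified assumptions (`patch_is_available` and its `replacement` argument are illustrative names, not the library's API):

```python
from contextlib import contextmanager
from typing import Callable, Iterator

import torch


@contextmanager
def patch_is_available(replacement: Callable[[], bool]) -> Iterator[None]:
    """Temporarily replace `torch.cuda.is_available`, restoring it afterwards."""
    orig_check = torch.cuda.is_available  # keep a reference so we can restore it
    torch.cuda.is_available = replacement
    try:
        yield
    finally:
        torch.cuda.is_available = orig_check  # restored on success and on error


# Everything inside the block sees the replacement implementation.
with patch_is_available(lambda: False):
    assert not torch.cuda.is_available()
```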

src/lightning_lite/utilities/imports.py

Lines changed: 1 addition & 3 deletions
@@ -25,9 +25,7 @@
 # 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383
 _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)

-_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
-_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
 _TORCH_GREATER_EQUAL_1_11 = compare_version("torch", operator.ge, "1.11.0")
 _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")
 _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
-_TORCH_GREATER_EQUAL_1_14 = compare_version("torch", operator.ge, "1.14.0", use_base_version=True)
+_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0", use_base_version=True)
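
`use_base_version=True` is what lets a nightly satisfy this check: nightlies report pre-release versions such as `2.0.0.devYYYYMMDD`, and under PEP 440 a dev release sorts *before* its final release, so a plain `>= 2.0.0` comparison would reject every 2.0 nightly. A short illustration with `packaging` (the version string is a made-up example):

```python
from packaging.version import Version

nightly = Version("2.0.0.dev20221214")  # hypothetical nightly version string

# Full comparison: a .dev pre-release is *older* than the final 2.0.0 release.
assert nightly < Version("2.0.0")

# Base-version comparison (what use_base_version=True opts into): the nightly passes.
assert Version(nightly.base_version) >= Version("2.0.0")
```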

src/lightning_lite/utilities/types.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@
 from torch.optim import Optimizer
 from typing_extensions import Protocol, runtime_checkable

-from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_2_0

 _PATH = Union[str, Path]
 _DEVICE = Union[torch.device, str, int]
@@ -75,7 +75,7 @@ def step(self, epoch: Optional[int] = None) -> None:


 _TORCH_LRSCHEDULER = (
-    torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_1_14 else torch.optim.lr_scheduler._LRScheduler
+    torch.optim.lr_scheduler.LRScheduler if _TORCH_GREATER_EQUAL_2_0 else torch.optim.lr_scheduler._LRScheduler
 )

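The conditional alias above exists because PyTorch 2.0 renames the scheduler base class from `_LRScheduler` to the public `LRScheduler`. A version-agnostic `isinstance` check can be sketched like this (illustrative usage only; the `getattr` fallback stands in for the version flag used in the diff):

```python
import torch

# Resolve the scheduler base class on both old and new PyTorch.
_TORCH_LRSCHEDULER = getattr(torch.optim.lr_scheduler, "LRScheduler", torch.optim.lr_scheduler._LRScheduler)

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Holds on PyTorch 1.x and 2.x alike, e.g. when validating user-configured schedulers.
assert isinstance(scheduler, _TORCH_LRSCHEDULER)
```
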
src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - HPC checkpoints are now loaded automatically only in slurm environment when no specific value for `ckpt_path` has been set ([#14911](https://github.com/Lightning-AI/lightning/pull/14911))
 - The `Callback.on_load_checkpoint` now gets the full checkpoint dictionary and the `callback_state` argument was renamed `checkpoint` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))
 - Moved the warning about saving nn.Module in `save_hyperparameters()` to before the deepcopy ([#15132](https://github.com/Lightning-AI/lightning/pull/15132))
-- To avoid issues with forking processes, from PyTorch 1.13 and higher, Lightning will directly use the PyTorch NVML-based check for `torch.cuda.device_count` and from PyTorch 1.14 and higher, Lightning will configure PyTorch to use a NVML-based check for `torch.cuda.is_available`. ([#15110](https://github.com/Lightning-AI/lightning/pull/15110), [#15133](https://github.com/Lightning-AI/lightning/pull/15133))
+- To avoid issues with forking processes, from PyTorch 1.13 and higher, Lightning will directly use the PyTorch NVML-based check for `torch.cuda.device_count` and from PyTorch 2.0 and higher, Lightning will configure PyTorch to use a NVML-based check for `torch.cuda.is_available`. ([#15110](https://github.com/Lightning-AI/lightning/pull/15110), [#15133](https://github.com/Lightning-AI/lightning/pull/15133))
 - The `NeptuneLogger` now uses `neptune.init_run` instead of the deprecated `neptune.init` to initialize a run ([#15393](https://github.com/Lightning-AI/lightning/pull/15393))

 ### Deprecated

src/pytorch_lightning/callbacks/quantization.py

Lines changed: 1 addition & 1 deletion
@@ -26,8 +26,8 @@
 from torch.quantization import FakeQuantizeBase

 import pytorch_lightning as pl
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.callbacks.callback import Callback
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _TORCH_GREATER_EQUAL_1_11:

src/pytorch_lightning/core/module.py

Lines changed: 3 additions & 2 deletions
@@ -37,6 +37,7 @@
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning_lite.utilities.distributed import _distributed_available, _sync_ddp
+from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11
 from lightning_lite.utilities.types import Steppable
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
@@ -45,9 +46,9 @@
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.loggers import Logger
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
-from pytorch_lightning.utilities import _IS_WINDOWS, GradClipAlgorithmType
+from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import (

src/pytorch_lightning/demos/mnist_datamodule.py

Lines changed: 2 additions & 1 deletion
@@ -24,8 +24,9 @@
 from torch import Tensor
 from torch.utils.data import DataLoader, Dataset, random_split

+from lightning_lite.utilities.imports import _IS_WINDOWS
 from pytorch_lightning import LightningDataModule
-from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCHVISION_AVAILABLE
+from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib

src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py

Lines changed: 1 addition & 1 deletion
@@ -16,9 +16,9 @@
 import torch

 from lightning_lite.utilities.enums import PrecisionType
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12

 if _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available():
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

src/pytorch_lightning/strategies/ddp.py

Lines changed: 1 addition & 1 deletion
@@ -39,6 +39,7 @@
     _sync_ddp_if_available,
 )
 from lightning_lite.utilities.distributed import group as _group
+from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11
 from lightning_lite.utilities.optimizer import _optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
 from lightning_lite.utilities.types import ReduceOp
@@ -53,7 +54,6 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
-from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_11
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
