Skip to content

Commit bf25167

Browse files
authored
Add testing for PyTorch 2.4 (Trainer) (#20010)
1 parent 96b75df commit bf25167

File tree

23 files changed

+157
-58
lines changed

23 files changed

+157
-58
lines changed

.azure/gpu-tests-pytorch.yml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ jobs:
5555
"Lightning | latest":
5656
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
5757
PACKAGE_NAME: "lightning"
58+
"Lightning | future":
59+
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
60+
PACKAGE_NAME: "lightning"
5861
pool: lit-rtx-3090
5962
variables:
6063
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
@@ -76,9 +79,12 @@ jobs:
7679
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
7780
scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))')
7881
echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
82+
python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
83+
echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
7984
displayName: "set env. vars"
8085
- bash: |
81-
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
86+
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
87+
echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
8288
condition: endsWith(variables['Agent.JobName'], 'future')
8389
displayName: "set env. vars 4 future"
8490
@@ -107,7 +113,7 @@ jobs:
107113
108114
- bash: |
109115
extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
110-
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
116+
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
111117
displayName: "Install package & dependencies"
112118
113119
- bash: pip uninstall -y lightning

.github/checkgroup.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,15 @@ subprojects:
142142
- "build-cuda (3.10, 2.2, 12.1.0)"
143143
- "build-cuda (3.11, 2.1, 12.1.0)"
144144
- "build-cuda (3.11, 2.2, 12.1.0)"
145+
- "build-cuda (3.11, 2.3, 12.1.0)"
146+
- "build-cuda (3.11, 2.4, 12.1.0)"
145147
#- "build-NGC"
146148
- "build-pl (3.10, 2.1, 12.1.0)"
147149
- "build-pl (3.10, 2.2, 12.1.0)"
148150
- "build-pl (3.11, 2.1, 12.1.0)"
149151
- "build-pl (3.11, 2.2, 12.1.0)"
152+
- "build-pl (3.11, 2.3, 12.1.0)"
153+
- "build-pl (3.11, 2.4, 12.1.0)"
150154

151155
# SECTION: lightning_fabric
152156

.github/workflows/ci-tests-pytorch.yml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ jobs:
5353
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
5454
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
5555
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
56+
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
57+
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
58+
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
5659
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
5760
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
5861
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
@@ -82,7 +85,7 @@ jobs:
8285
PACKAGE_NAME: ${{ matrix.pkg-name }}
8386
TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
8487
TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
85-
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
88+
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
8689
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
8790
PYPI_CACHE_DIR: "_pip-wheels"
8891
# TODO: Remove this - Enable running MPS tests on this platform
@@ -124,11 +127,13 @@ jobs:
124127
- name: Env. variables
125128
run: |
126129
# Switch PyTorch URL
127-
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
130+
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
128131
# Switch coverage scope
129132
python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV
130133
# if you install the mono-package, set the dependencies only for this subpackage
131134
python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV
135+
# Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support"
136+
python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV
132137
133138
- name: Install package & dependencies
134139
timeout-minutes: 20

.github/workflows/docker-build.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ jobs:
4747
- { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
4848
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
4949
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
50+
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
51+
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
5052
steps:
5153
- uses: actions/checkout@v4
5254
with:
@@ -74,7 +76,7 @@ jobs:
7476
tags = [f"latest-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
7577
if ver:
7678
tags += [f"{ver}-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
77-
if py_ver == '3.10' and pt_ver == '2.1' and cuda_ver == '12.1.0':
79+
if py_ver == '3.11' and pt_ver == '2.3' and cuda_ver == '12.1.0':
7880
tags += ["latest"]
7981
8082
tags = [f"{repo}:{tag}" for tag in tags]
@@ -108,6 +110,7 @@ jobs:
108110
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
109111
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
110112
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
113+
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
111114
# - { python_version: "3.12", pytorch_version: "2.2", cuda_version: "12.1.0" } # todo: pending on `onnxruntime`
112115
steps:
113116
- uses: actions/checkout@v4

dockers/base-cuda/Dockerfile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
2020

2121
ARG PYTHON_VERSION=3.10
2222
ARG PYTORCH_VERSION=2.1
23-
ARG MAX_ALLOWED_NCCL=2.17.1
23+
ARG MAX_ALLOWED_NCCL=2.22.3
2424

2525
SHELL ["/bin/bash", "-c"]
2626
# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/
@@ -92,7 +92,8 @@ RUN \
9292
-r requirements/pytorch/test.txt \
9393
-r requirements/pytorch/strategies.txt \
9494
--find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \
95-
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch_test.html"
95+
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch" \
96+
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/pytorch-triton"
9697

9798
RUN \
9899
# Show what we have

docs/source-pytorch/versioning.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@ The table below indicates the coverage of tested versions in our CI. Versions ou
7979
- ``torch``
8080
- ``torchmetrics``
8181
- Python
82+
* - 2.4
83+
- 2.4
84+
- 2.4
85+
- ≥2.1, ≤2.4
86+
- ≥0.7.0
87+
- ≥3.9, ≤3.12
88+
* - 2.3
89+
- 2.3
90+
- 2.3
91+
- ≥2.0, ≤2.3
92+
- ≥0.7.0
93+
- ≥3.8, ≤3.11
8294
* - 2.2
8395
- 2.2
8496
- 2.2

requirements/pytorch/base.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

44
numpy >=1.21.0, <1.27.0
5-
torch >=2.1.0, <2.4.0
5+
torch >=2.1.0, <2.5.0
66
tqdm >=4.57.0, <4.67.0
77
PyYAML >=5.4, <6.1.0
88
fsspec[http] >=2022.5.0, <2024.4.0

requirements/pytorch/examples.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
33

44
requests <2.32.0
5-
torchvision >=0.16.0, <0.19.0
5+
torchvision >=0.16.0, <0.20.0
66
ipython[all] <8.15.0
77
torchmetrics >=0.10.0, <1.3.0
88
lightning-utilities >=0.8.0, <0.12.0

src/lightning/fabric/utilities/cloud_io.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
def _load(
3434
path_or_url: Union[IO, _PATH],
3535
map_location: _MAP_LOCATION_TYPE = None,
36+
weights_only: bool = False,
3637
) -> Any:
3738
"""Loads a checkpoint.
3839
@@ -46,15 +47,21 @@ def _load(
4647
return torch.load(
4748
path_or_url,
4849
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
50+
weights_only=weights_only,
4951
)
5052
if str(path_or_url).startswith("http"):
5153
return torch.hub.load_state_dict_from_url(
5254
str(path_or_url),
5355
map_location=map_location, # type: ignore[arg-type]
56+
weights_only=weights_only,
5457
)
5558
fs = get_filesystem(path_or_url)
5659
with fs.open(path_or_url, "rb") as f:
57-
return torch.load(f, map_location=map_location) # type: ignore[arg-type]
60+
return torch.load(
61+
f,
62+
map_location=map_location, # type: ignore[arg-type]
63+
weights_only=weights_only,
64+
)
5865

5966

6067
def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:

src/lightning/pytorch/plugins/precision/amp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import lightning.pytorch as pl
2121
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
22+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
2223
from lightning.fabric.utilities.types import Optimizable
2324
from lightning.pytorch.plugins.precision.precision import Precision
2425
from lightning.pytorch.utilities import GradClipAlgorithmType
@@ -39,7 +40,7 @@ def __init__(
3940
self,
4041
precision: Literal["16-mixed", "bf16-mixed"],
4142
device: str,
42-
scaler: Optional[torch.cuda.amp.GradScaler] = None,
43+
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
4344
) -> None:
4445
if precision not in ("16-mixed", "bf16-mixed"):
4546
raise ValueError(
@@ -49,7 +50,7 @@ def __init__(
4950

5051
self.precision = precision
5152
if scaler is None and self.precision == "16-mixed":
52-
scaler = torch.cuda.amp.GradScaler()
53+
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
5354
if scaler is not None and self.precision == "bf16-mixed":
5455
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
5556
self.device = device

0 commit comments

Comments
 (0)