Skip to content

Commit b62700e

Browse files
Drop the DataLoader iterator when pickling (#17130)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a26e54f commit b62700e

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1212

1313
### Fixed
1414

15+
- Fixed issue where pickling the module instance would fail with a DataLoader error ([#17130](https://github.com/Lightning-AI/lightning/pull/17130))
1516
- Fixed WandbLogger not showing "best" aliases for model checkpoints when `ModelCheckpoint(save_top_k>0)` is used ([#17121](https://github.com/Lightning-AI/lightning/pull/17121))
1617

1718

src/lightning/pytorch/utilities/combined_loader.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from collections.abc import Iterable
15-
from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple, Type, TypeVar, Union
15+
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, TypeVar, Union
1616

17-
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter
17+
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter
1818
from typing_extensions import Self, TypedDict
1919

2020
from lightning.fabric.utilities.data import sized_len
@@ -38,6 +38,18 @@ def __iter__(self) -> Self:
3838
def reset(self) -> None:
3939
self.iterators = []
4040

41+
def __getstate__(self) -> Dict[str, Any]:
42+
state = self.__dict__.copy()
43+
44+
# workaround an inconvenient `NotImplementedError`:
45+
# https://github.com/pytorch/pytorch/blob/v2.0.0/torch/utils/data/dataloader.py#L652-L658
46+
state["iterators"] = [
47+
None if isinstance(iterator, _BaseDataLoaderIter) else iterator_state
48+
for iterator, iterator_state in zip(self.iterators, state["iterators"])
49+
]
50+
51+
return state
52+
4153

4254
class _MaxSizeCycle(_ModeIterator[List]):
4355
def __init__(self, iterables: List[Iterable]) -> None:

tests/tests_pytorch/utilities/test_combined_loader.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import math
15+
import pickle
1516
from typing import Any, get_args, NamedTuple, Sequence
1617

1718
import pytest
@@ -536,3 +537,22 @@ def test_combined_dataloader_for_training_with_ddp(use_distributed_sampler, mode
536537

537538
def test_supported_modes():
538539
assert set(_SUPPORTED_MODES) == set(get_args(_LITERAL_SUPPORTED_MODES))
540+
541+
542+
def test_combined_loader_can_be_pickled():
543+
dataloader = DataLoader([0, 1, 2, 3])
544+
545+
# sanity check that and error would be raised. if this ever changes, `_ModeIterator.__getstate__` should be updated
546+
iterator = iter(dataloader)
547+
with pytest.raises(NotImplementedError, match="cannot be pickled"):
548+
pickle.dumps(iterator)
549+
550+
numbers = list(range(10))
551+
cl = CombinedLoader([dataloader, numbers])
552+
iter(cl)
553+
554+
iterator = cl._iterator
555+
assert iterator.__getstate__() == {"iterables": [dataloader, numbers], "iterators": [None, iterator.iterators[1]]}
556+
557+
# no error
558+
pickle.dumps(cl)

0 commit comments

Comments
 (0)