|
15 | 15 | import os |
16 | 16 | from copy import deepcopy |
17 | 17 | from unittest import mock |
18 | | -from unittest.mock import Mock, patch |
| 18 | +from unittest.mock import MagicMock, patch |
19 | 19 |
|
20 | 20 | import pytest |
21 | 21 | import torch |
@@ -267,16 +267,28 @@ def test_xla_checkpoint_plugin_being_default(tpu_available): |
267 | 267 |
|
268 | 268 |
|
269 | 269 | @RunIf(tpu=True) |
270 | | -@patch("pytorch.lightning.strategies.tpu_spawn.TPUSpawnStrategy.root_device") |
| 270 | +@patch("lightning.pytorch.strategies.tpu_spawn.TPUSpawnStrategy.root_device") |
271 | 271 | def test_xla_mp_device_dataloader_attribute(_, monkeypatch): |
| 272 | + dataset = RandomDataset(32, 64) |
| 273 | + dataloader = DataLoader(dataset) |
| 274 | + strategy = TPUSpawnStrategy() |
| 275 | + isinstance_return = True |
| 276 | + |
272 | 277 | import torch_xla.distributed.parallel_loader as parallel_loader |
273 | 278 |
|
274 | | - mp_loader_mock = Mock() |
| 279 | + class MpDeviceLoaderMock(MagicMock): |
| 280 | + def __instancecheck__(self, instance): |
| 281 | + # to make `isinstance(dataloader, MpDeviceLoader)` pass with a mock as class |
| 282 | + return isinstance_return |
| 283 | + |
| 284 | + mp_loader_mock = MpDeviceLoaderMock() |
275 | 285 | monkeypatch.setattr(parallel_loader, "MpDeviceLoader", mp_loader_mock) |
276 | 286 |
|
277 | | - dataset = RandomDataset(32, 64) |
278 | | - dataloader = DataLoader(dataset) |
279 | | - strategy = TPUSpawnStrategy() |
| 287 | + processed_dataloader = strategy.process_dataloader(dataloader) |
| 288 | + assert processed_dataloader is dataloader |
| 289 | + mp_loader_mock.assert_not_called() # no-op |
| 290 | + |
| 291 | + isinstance_return = False |
280 | 292 | processed_dataloader = strategy.process_dataloader(dataloader) |
281 | 293 | mp_loader_mock.assert_called_with(dataloader, strategy.root_device) |
282 | 294 | assert processed_dataloader.dataset == processed_dataloader._loader.dataset |
|
0 commit comments