Skip to content

Commit 1b1241c

Browse files
authored
Fix TPU tests (#16628)
1 parent bf51844 commit 1b1241c

File tree

3 files changed

+35
-148
lines changed

3 files changed

+35
-148
lines changed

tests/tests_fabric/strategies/launchers/test_xla.py

Lines changed: 0 additions & 137 deletions
This file was deleted.

tests/tests_fabric/strategies/test_xla.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import os
1515
from functools import partial
1616
from unittest import mock
17-
from unittest.mock import Mock
17+
from unittest.mock import MagicMock, Mock
1818

1919
import pytest
2020
import torch
@@ -84,14 +84,26 @@ def test_tpu_reduce():
8484
@RunIf(tpu=True)
8585
@mock.patch("lightning.fabric.strategies.xla.XLAStrategy.root_device")
8686
def test_xla_mp_device_dataloader_attribute(_, monkeypatch):
87+
dataset = RandomDataset(32, 64)
88+
dataloader = DataLoader(dataset)
89+
strategy = XLAStrategy()
90+
isinstance_return = True
91+
8792
import torch_xla.distributed.parallel_loader as parallel_loader
8893

89-
mp_loader_mock = Mock()
94+
class MpDeviceLoaderMock(MagicMock):
95+
def __instancecheck__(self, instance):
96+
# to make `isinstance(dataloader, MpDeviceLoader)` pass with a mock as class
97+
return isinstance_return
98+
99+
mp_loader_mock = MpDeviceLoaderMock()
90100
monkeypatch.setattr(parallel_loader, "MpDeviceLoader", mp_loader_mock)
91101

92-
dataset = RandomDataset(32, 64)
93-
dataloader = DataLoader(dataset)
94-
strategy = XLAStrategy()
102+
processed_dataloader = strategy.process_dataloader(dataloader)
103+
assert processed_dataloader is dataloader
104+
mp_loader_mock.assert_not_called() # no-op
105+
106+
isinstance_return = False
95107
processed_dataloader = strategy.process_dataloader(dataloader)
96108
mp_loader_mock.assert_called_with(dataloader, strategy.root_device)
97109
assert processed_dataloader.dataset == processed_dataloader._loader.dataset

tests/tests_pytorch/accelerators/test_tpu.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import os
1616
from copy import deepcopy
1717
from unittest import mock
18-
from unittest.mock import Mock, patch
18+
from unittest.mock import MagicMock, patch
1919

2020
import pytest
2121
import torch
@@ -267,16 +267,28 @@ def test_xla_checkpoint_plugin_being_default(tpu_available):
267267

268268

269269
@RunIf(tpu=True)
270-
@patch("pytorch.lightning.strategies.tpu_spawn.TPUSpawnStrategy.root_device")
270+
@patch("lightning.pytorch.strategies.tpu_spawn.TPUSpawnStrategy.root_device")
271271
def test_xla_mp_device_dataloader_attribute(_, monkeypatch):
272+
dataset = RandomDataset(32, 64)
273+
dataloader = DataLoader(dataset)
274+
strategy = TPUSpawnStrategy()
275+
isinstance_return = True
276+
272277
import torch_xla.distributed.parallel_loader as parallel_loader
273278

274-
mp_loader_mock = Mock()
279+
class MpDeviceLoaderMock(MagicMock):
280+
def __instancecheck__(self, instance):
281+
# to make `isinstance(dataloader, MpDeviceLoader)` pass with a mock as class
282+
return isinstance_return
283+
284+
mp_loader_mock = MpDeviceLoaderMock()
275285
monkeypatch.setattr(parallel_loader, "MpDeviceLoader", mp_loader_mock)
276286

277-
dataset = RandomDataset(32, 64)
278-
dataloader = DataLoader(dataset)
279-
strategy = TPUSpawnStrategy()
287+
processed_dataloader = strategy.process_dataloader(dataloader)
288+
assert processed_dataloader is dataloader
289+
mp_loader_mock.assert_not_called() # no-op
290+
291+
isinstance_return = False
280292
processed_dataloader = strategy.process_dataloader(dataloader)
281293
mp_loader_mock.assert_called_with(dataloader, strategy.root_device)
282294
assert processed_dataloader.dataset == processed_dataloader._loader.dataset

0 commit comments

Comments
 (0)