Commit 1dfbe7d

fix: move results' keys to device (#19813)
1 parent 6c70dd7 commit 1dfbe7d

File tree

3 files changed: +48 -10 lines


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))
 
+- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19813](https://github.com/Lightning-AI/pytorch-lightning/issues/19813))
+
+-
 
 ### Changed
 
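The changelog entry above is the whole behaviour change in one sentence: the per-key state stored by the result collection should follow the device of whatever value is being logged. As a rough illustration, here is a minimal stand-alone sketch (toy class and names of my own, not Lightning's actual `_ResultCollection`/`_ResultMetric`) of why re-aligning on every `log` call matters once the collection can be moved, e.g. with `.cpu()`, between calls:

```python
import torch


class ToyResultStore:
    """Toy stand-in for a result collection keyed by metric name (not Lightning code)."""

    def __init__(self) -> None:
        self._state: dict = {}

    def log(self, key: str, value: torch.Tensor) -> None:
        if key not in self._state:
            # register a running sum the first time this key is logged
            self._state[key] = torch.zeros_like(value)
        # mirrors the `self[key].to(value.device)` added by this commit: if the store
        # was moved (e.g. via `.cpu()`) since registration, re-align before updating
        self._state[key] = self._state[key].to(value.device)
        self._state[key] = self._state[key] + value

    def cpu(self) -> None:
        # move all stored state to CPU, similar in spirit to `_ResultCollection.cpu()`
        self._state = {k: v.cpu() for k, v in self._state.items()}


device = "cuda:0" if torch.cuda.is_available() else "cpu"
store = ToyResultStore()
store.log("loss", torch.tensor(1.0, device=device))
store.cpu()                                          # stored state now lives on CPU
store.log("loss", torch.tensor(2.0, device=device))  # state follows the value's device again
print(store._state["loss"])                          # tensor(3.) on `device`
```

On a CPU-only machine the two devices never actually differ, but on a GPU run the `.to(value.device)` step is what prevents a cross-device update error after the collection has been moved.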

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 3 additions & 10 deletions
@@ -400,26 +400,19 @@ def log(
 
         # register logged value if it doesn't exist
         if key not in self:
-            self.register_key(key, meta, value)
+            metric = _ResultMetric(meta, isinstance(value, Tensor))
+            self[key] = metric
 
         # check the stored metadata and the current one match
         elif meta != self[key].meta:
             raise MisconfigurationException(
                 f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
             )
+        self[key].to(value.device)
 
         batch_size = self._extract_batch_size(self[key], batch_size, meta)
         self.update_metrics(key, value, batch_size)
 
-    def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
-        """Create one _ResultMetric object per value.
-
-        Value can be provided as a nested collection
-
-        """
-        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
-        self[key] = metric
-
     def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
         result_metric = self[key]
         # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
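In the old code, `register_key` moved the freshly created `_ResultMetric` to `value.device` once, at registration time; after this change, `log()` registers the metric as-is and then calls `self[key].to(value.device)` on every invocation, so state that was registered earlier or later moved by `.cpu()` is pulled back to the device of the incoming value. Assuming `_ResultMetric` behaves like a `torch.nn.Module` whose tensor state follows `.to()` (the `ToyMetric` below is illustrative, not the real class), the mechanism looks roughly like this:

```python
import torch
from torch import nn


class ToyMetric(nn.Module):
    """Illustrative metric holder; registered buffers travel with `.to()` / `.cpu()`."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("value", torch.tensor(0.0))
        self.register_buffer("cumulated_batch_size", torch.tensor(0))

    def update(self, batch_value: torch.Tensor, batch_size: int) -> None:
        # both operands must live on the same device for these updates
        self.value += batch_value
        self.cumulated_batch_size += batch_size


metric = ToyMetric()
logged = torch.tensor(3.0)           # on a GPU run this could arrive on "cuda:0"
metric.to(logged.device)             # counterpart of the `self[key].to(value.device)` one-liner
metric.update(logged, batch_size=2)
print(metric.value, metric.cumulated_batch_size.device)
```

Calling `.to()` unconditionally is presumably acceptable on the hot logging path because it is effectively a no-op whenever the stored state already sits on the value's device.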

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

Lines changed: 42 additions & 0 deletions
@@ -35,6 +35,8 @@
 
 from tests_pytorch.models.test_hooks import get_members
 
+from parity_pytorch import RunIf
+
 
 def test_fx_validator():
     funcs_name = get_members(Callback)
@@ -639,3 +641,43 @@ def test_result_collection_no_batch_size_extraction():
     assert results["training_step.epoch_log_val"].value == log_val * batch_size
     assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
     assert results["training_step.epoch_sum_log_val"].value == log_val
+
+
+def test_result_collection_changes_device():
+    results = _ResultCollection(training=True)
+    fx, name = "training_step", "step_log_val"
+    log_val = torch.tensor(7.0)
+
+    # same device as the original tensor
+    results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+    assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device
+
+    # moved to cpu
+    cumulated_batch_size = results[f"{fx}.{name}"].cumulated_batch_size = Mock(spec=torch.Tensor)
+    cumulated_batch_size.to.return_value = Mock(spec=torch.Tensor)
+    results.cpu()
+    cumulated_batch_size.to.assert_called_once_with(log_val.device)
+
+    # same device as the new tensor
+    results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+    cumulated_batch_size.to.return_value.to.assert_called_once_with(log_val.device)
+
+
+@RunIf(min_gpus=1)
+def test_gpu_result_collection_changes_device():
+    results = _ResultCollection(training=True)
+    fx, name = "training_step", "step_log_val"
+    log_val = torch.tensor(7.0, device="cuda:0")
+
+    # same device as the original tensor
+    results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+    assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device
+
+    # moved to cpu
+    cumulated_batch_size = results[f"{fx}.{name}"].cumulated_batch_size
+    results.cpu()
+    assert cumulated_batch_size.device == "cpu"
+
+    # same device as the new tensor
+    results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+    assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device
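The first test above verifies the device move without requiring a GPU by swapping the stored `cumulated_batch_size` tensor for a `Mock(spec=torch.Tensor)` and asserting on its `.to()` calls. A stripped-down, runnable illustration of that pattern (the names here are mine, not taken from the test):

```python
import torch
from unittest.mock import Mock

stored = Mock(spec=torch.Tensor)              # stands in for the registered tensor state
stored.to.return_value = Mock(spec=torch.Tensor)

target = torch.tensor(7.0)                    # the newly logged value
moved = stored.to(target.device)              # the device move under test

stored.to.assert_called_once_with(target.device)
print("recorded call:", stored.to.call_args)  # -> call(device(type='cpu'))
```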
