src/lightning/pytorch/CHANGELOG.md (3 changes: 3 additions & 0 deletions)

```diff
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))
 
+-
 
 ### Changed
 
@@ -44,6 +45,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue that would cause too many printouts of the seed info when using `seed_everything()` ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))
 
+- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
+
 
 
 ## [2.3.0] - 2024-06-13
```
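Both `seed_everything()` entries above cite #20108, which is unrelated to this fix and only appears here as surrounding context. For reference, a minimal sketch of the flag that entry describes, assuming the default remains verbose:

```python
# Sketch of the `verbose` flag from #20108 (assumption: the default is
# verbose=True, so existing call sites keep printing the seed info once).
from lightning.pytorch import seed_everything

seed_everything(42)                 # prints the seed info
seed_everything(42, verbose=False)  # sets the seed silently
```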
src/lightning/pytorch/trainer/connectors/logger_connector/result.py (13 changes: 3 additions & 10 deletions)

```diff
@@ -400,26 +400,19 @@ def log(
 
         # register logged value if it doesn't exist
         if key not in self:
-            self.register_key(key, meta, value)
+            metric = _ResultMetric(meta, isinstance(value, Tensor))
+            self[key] = metric
 
         # check the stored metadata and the current one match
         elif meta != self[key].meta:
             raise MisconfigurationException(
                 f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
             )
+        self[key].to(value.device)
 
         batch_size = self._extract_batch_size(self[key], batch_size, meta)
         self.update_metrics(key, value, batch_size)
 
-    def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
-        """Create one _ResultMetric object per value.
-
-        Value can be provided as a nested collection
-
-        """
-        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
-        self[key] = metric
-
     def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
         result_metric = self[key]
         # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
```
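The net effect: `register_key` used to pin a new `_ResultMetric` to the device of the first value it saw, and nothing re-synced it afterwards. Inlining the creation and calling `self[key].to(value.device)` on every `log()` keeps the stored state on the device of the latest value. A minimal sketch of the failure mode this addresses, assuming a CUDA device is available; it mirrors the new test below and uses Lightning's internal `_ResultCollection` API:

```python
# Sketch of the device-mismatch scenario fixed here (assumes one CUDA device;
# import path per Lightning 2.x internals).
import torch
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection

results = _ResultCollection(training=True)
value = torch.tensor(1.0, device="cuda:0")

# First log creates the _ResultMetric; with this fix it is then moved to value.device.
results.log("training_step", "loss", value, on_step=True, on_epoch=False, reduce_fx="mean")

# Moving the whole collection relocates every stored metric state to the CPU...
results.cpu()

# ...so before this fix, logging the CUDA tensor again updated CPU-resident
# state with a CUDA value. `self[key].to(value.device)` now re-syncs it first.
results.log("training_step", "loss", value, on_step=True, on_epoch=False, reduce_fx="mean")
```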
tests/tests_pytorch/trainer/logging_/test_logger_connector.py (21 changes: 21 additions & 0 deletions)

```diff
@@ -33,6 +33,7 @@
 from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection
 from torchmetrics import AveragePrecision as AvgPre
 
+from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.models.test_hooks import get_members
 
 
@@ -639,3 +640,23 @@ def test_result_collection_no_batch_size_extraction():
     assert results["training_step.epoch_log_val"].value == log_val * batch_size
     assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
     assert results["training_step.epoch_sum_log_val"].value == log_val
+
+
+@RunIf(min_cuda_gpus=1)
+def test_result_collection_changes_device():
+    """Test that the keys in the ResultCollection are moved to the device together with the collection."""
+    results = _ResultCollection(training=True)
+    fx, name = "training_step", "step_log_val"
+    log_val = torch.tensor(7.0, device="cuda:0")
+
+    # same device as the original tensor
+    results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+    assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device
+
+    # moved to cpu
+    results.cpu()
+    assert results[f"{fx}.{name}"].cumulated_batch_size.device == torch.device("cpu")
+
+    # same device as the new tensor
+    results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")
+    assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device
```