
Commit 99dcf93

fix: move results' keys to device (#19813)
1 parent b9680a3 commit 99dcf93

File tree

1 file changed: +3 -10 lines

  • src/lightning/pytorch/trainer/connectors/logger_connector/result.py

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 3 additions & 10 deletions
@@ -403,26 +403,19 @@ def log(
 
         # register logged value if it doesn't exist
         if key not in self:
-            self.register_key(key, meta, value)
+            metric = _ResultMetric(meta, isinstance(value, Tensor))
+            self[key] = metric
 
         # check the stored metadata and the current one match
         elif meta != self[key].meta:
             raise MisconfigurationException(
                 f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
             )
+        self[key].to(value.device)
 
         batch_size = self._extract_batch_size(self[key], batch_size, meta)
         self.update_metrics(key, value, batch_size)
 
-    def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
-        """Create one _ResultMetric object per value.
-
-        Value can be provided as a nested collection
-
-        """
-        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
-        self[key] = metric
-
     def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
         result_metric = self[key]
         # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
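
Why the change matters: previously the stored _ResultMetric was moved to the value's device only once, inside register_key, when the key was first registered. After this commit, self[key].to(value.device) runs on every log call, so the stored metric follows the incoming tensor even when its device differs from the one seen at registration. Below is a minimal sketch of that behavior, assuming simplified stand-ins for _ResultMetric and the result collection (not the actual Lightning implementations).

import torch
from torch import Tensor


class _ResultMetric:
    # Simplified stand-in: keeps a running sum on a single device.
    def __init__(self) -> None:
        self.value = torch.tensor(0.0)

    def to(self, device: torch.device) -> "_ResultMetric":
        self.value = self.value.to(device)
        return self

    def update(self, value: Tensor) -> None:
        # Raises a RuntimeError if `self.value` and `value` are on different devices.
        self.value = self.value + value


class _ResultCollection(dict):
    # Simplified stand-in for the collection that `self.log` writes into.
    def log(self, key: str, value: Tensor) -> None:
        if key not in self:
            self[key] = _ResultMetric()
        # The fix: move the stored metric to the incoming value's device on
        # every call, not only when the key is first registered.
        self[key].to(value.device)
        self[key].update(value)


results = _ResultCollection()
results.log("loss", torch.tensor(1.0))                     # first seen on CPU
if torch.cuda.is_available():
    results.log("loss", torch.tensor(2.0, device="cuda"))  # no device mismatch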
