Skip to content

Commit 0bd1563

Browse files
committed
Fix test assertion
The assertion was referencing the variable before the object was moved to device
1 parent 46dfe5b commit 0bd1563

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ def test_result_collection_no_batch_size_extraction():
643643

644644

645645
@RunIf(min_cuda_gpus=1)
646-
def test_gpu_result_collection_changes_device():
646+
def test_result_collection_changes_device():
647647
"""Test that the keys in the ResultCollection are moved to the device together with the collection."""
648648
results = _ResultCollection(training=True)
649649
fx, name = "training_step", "step_log_val"
@@ -654,9 +654,8 @@ def test_gpu_result_collection_changes_device():
654654
assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device
655655

656656
# moved to cpu
657-
cumulated_batch_size = results[f"{fx}.{name}"].cumulated_batch_size
658657
results.cpu()
659-
assert cumulated_batch_size.device == "cpu"
658+
assert results[f"{fx}.{name}"].cumulated_batch_size.device == torch.device("cpu")
660659

661660
# same device as the new tensor
662661
results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean")

0 commit comments

Comments
 (0)