|
35 | 35 |
|
36 | 36 | from tests_pytorch.models.test_hooks import get_members |
37 | 37 |
|
| 38 | +from parity_pytorch import RunIf |
| 39 | + |
38 | 40 |
|
39 | 41 | def test_fx_validator(): |
40 | 42 | funcs_name = get_members(Callback) |
@@ -639,3 +641,43 @@ def test_result_collection_no_batch_size_extraction(): |
639 | 641 | assert results["training_step.epoch_log_val"].value == log_val * batch_size |
640 | 642 | assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size |
641 | 643 | assert results["training_step.epoch_sum_log_val"].value == log_val |
| 644 | + |
| 645 | + |
| 646 | +def test_result_collection_changes_device(): # mock_torch): |
| 647 | + results = _ResultCollection(training=True) |
| 648 | + fx, name = "training_step", "step_log_val" |
| 649 | + log_val = torch.tensor(7.0) |
| 650 | + |
| 651 | + # same device as the original tensor |
| 652 | + results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean") |
| 653 | + assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device |
| 654 | + |
| 655 | + # moved to cpu |
| 656 | + cumulated_batch_size = results[f"{fx}.{name}"].cumulated_batch_size = Mock(spec=torch.Tensor) |
| 657 | + cumulated_batch_size.to.return_value = Mock(spec=torch.Tensor) |
| 658 | + results.cpu() |
| 659 | + cumulated_batch_size.to.assert_called_once_with(log_val.device) |
| 660 | + |
| 661 | + # same device as the new tensor |
| 662 | + results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean") |
| 663 | + cumulated_batch_size.to.return_value.to.assert_called_once_with(log_val.device) |
| 664 | + |
| 665 | + |
| 666 | +@RunIf(min_gpus=1) |
| 667 | +def test_gpu_result_collection_changes_device(): |
| 668 | + results = _ResultCollection(training=True) |
| 669 | + fx, name = "training_step", "step_log_val" |
| 670 | + log_val = torch.tensor(7.0, device="cuda:0") |
| 671 | + |
| 672 | + # same device as the original tensor |
| 673 | + results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean") |
| 674 | + assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device |
| 675 | + |
| 676 | + # moved to cpu |
| 677 | + cumulated_batch_size = results[f"{fx}.{name}"].cumulated_batch_size |
| 678 | + results.cpu() |
| 679 | + assert cumulated_batch_size.device == "cpu" |
| 680 | + |
| 681 | + # same device as the new tensor |
| 682 | + results.log(fx, name, log_val, on_step=True, on_epoch=False, reduce_fx="mean") |
| 683 | + assert results[f"{fx}.{name}"].cumulated_batch_size.device == log_val.device |
0 commit comments