diff --git a/smdebug/pytorch/hook.py b/smdebug/pytorch/hook.py index 1eeb0636f..64dfa888c 100644 --- a/smdebug/pytorch/hook.py +++ b/smdebug/pytorch/hook.py @@ -3,6 +3,7 @@ # Third Party import torch import torch.distributed as dist +from torch.nn.modules.loss import _Loss # First Party from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys @@ -154,6 +155,9 @@ def forward_hook(self, module, inputs, outputs): if not self._get_collections_to_save_for_step(): return + if isinstance(module, _Loss): + module._module_name = module._get_name() + module_name = module._module_name # This overwhelms the logs; turn back on if you really need it # logger.debug("Processing the global step {0} for module {1}".format(self.step, module_name)) diff --git a/tests/zero_code_change/test_pytorch_integration.py b/tests/zero_code_change/test_pytorch_integration.py index 21e7759f8..e562269cc 100644 --- a/tests/zero_code_change/test_pytorch_integration.py +++ b/tests/zero_code_change/test_pytorch_integration.py @@ -23,20 +23,43 @@ from smdebug.core.utils import SagemakerSimulator, ScriptSimulator +class CustomCrossEntropyLoss(nn.modules.loss._WeightedLoss): + __constants__ = ["weight", "ignore_index", "reduction"] + + def __init__( + self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean" + ): + super(CustomCrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction) + self.ignore_index = ignore_index + + def forward(self, input, target): + return F.cross_entropy( + input, + target, + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + ) + + @pytest.mark.skipif( torch.__version__ == "1.7.0", reason="Disabling the test temporarily until we root cause the version incompatibility", ) @pytest.mark.parametrize("script_mode", [False]) @pytest.mark.parametrize("use_loss_module", [True, False]) -def test_pytorch(script_mode, use_loss_module): +@pytest.mark.parametrize("use_custom_loss_module", [True, False]) +def test_pytorch(script_mode, use_loss_module, use_custom_loss_module): smd.del_hook() sim_class = ScriptSimulator if script_mode else SagemakerSimulator with sim_class() as sim: trainloader, testloader = get_dataloaders() net = Net() - criterion = nn.CrossEntropyLoss() + if use_custom_loss_module: + criterion = CustomCrossEntropyLoss() + else: + criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) if script_mode: