| 
23 | 23 | from smdebug.core.utils import SagemakerSimulator, ScriptSimulator  | 
24 | 24 | 
 
  | 
25 | 25 | 
 
  | 
 | 26 | +class CustomCrossEntropyLoss(nn.modules.loss._WeightedLoss):  | 
 | 27 | +    __constants__ = ["weight", "ignore_index", "reduction"]  | 
 | 28 | + | 
 | 29 | +    def __init__(  | 
 | 30 | +        self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean"  | 
 | 31 | +    ):  | 
 | 32 | +        super(CustomCrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)  | 
 | 33 | +        self.ignore_index = ignore_index  | 
 | 34 | + | 
 | 35 | +    def forward(self, input, target):  | 
 | 36 | +        return F.cross_entropy(  | 
 | 37 | +            input,  | 
 | 38 | +            target,  | 
 | 39 | +            weight=self.weight,  | 
 | 40 | +            ignore_index=self.ignore_index,  | 
 | 41 | +            reduction=self.reduction,  | 
 | 42 | +        )  | 
 | 43 | + | 
 | 44 | + | 
26 | 45 | @pytest.mark.skipif(  | 
27 | 46 |     torch.__version__ == "1.7.0",  | 
28 | 47 |     reason="Disabling the test temporarily until we root cause the version incompatibility",  | 
29 | 48 | )  | 
30 | 49 | @pytest.mark.parametrize("script_mode", [False])  | 
31 | 50 | @pytest.mark.parametrize("use_loss_module", [True, False])  | 
32 |  | -def test_pytorch(script_mode, use_loss_module):  | 
 | 51 | +@pytest.mark.parametrize("use_custom_loss_module", [True, False])  | 
 | 52 | +def test_pytorch(script_mode, use_loss_module, use_custom_loss_module):  | 
33 | 53 |     smd.del_hook()  | 
34 | 54 | 
 
  | 
35 | 55 |     sim_class = ScriptSimulator if script_mode else SagemakerSimulator  | 
36 | 56 |     with sim_class() as sim:  | 
37 | 57 |         trainloader, testloader = get_dataloaders()  | 
38 | 58 |         net = Net()  | 
39 |  | -        criterion = nn.CrossEntropyLoss()  | 
 | 59 | +        if use_custom_loss_module:  | 
 | 60 | +            criterion = CustomCrossEntropyLoss()  | 
 | 61 | +        else:  | 
 | 62 | +            criterion = nn.CrossEntropyLoss()  | 
40 | 63 |         optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  | 
41 | 64 | 
 
  | 
42 | 65 |         if script_mode:  | 
 | 
0 commit comments