diff --git a/tests/test_common/test_utils/test_checkpoints.py b/tests/test_common/test_utils/test_checkpoints.py index 512ad86..8f94a36 100644 --- a/tests/test_common/test_utils/test_checkpoints.py +++ b/tests/test_common/test_utils/test_checkpoints.py @@ -34,13 +34,14 @@ def test_save(self): loss = 0.5 epoch = 10 Checkpoints.save( - "./input/test_model.pth", + "tests/test_common/test_utils/test_model.pth", model_save, optimizer, epoch, loss, ) - self.assertTrue(os.path.isfile("./input/test_model.pth")) + self.assertTrue(os.path.isfile("tests/test_common/test_utils/test_model.pth")) + os.remove("tests/test_common/test_utils/test_model.pth") def test_load(self): """ @@ -51,11 +52,12 @@ def test_load(self): loss = 0.5 epoch = 10 Checkpoints.save( - "./input/test_model.pth", + "tests/test_common/test_utils/test_model.pth", model, optimizer, epoch, loss, ) - model_load = Checkpoints.load("./input/test_model.pth") + model_load = Checkpoints.load("tests/test_common/test_utils/test_model.pth") self.assertEqual(list(model.state_dict()), list(model_load.state_dict())) + os.remove("tests/test_common/test_utils/test_model.pth")