Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions code_soup/common/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from code_soup.common.utils.checkpoints import Checkpoints
from code_soup.common.utils.seeding import Seeding
38 changes: 31 additions & 7 deletions tests/test_common/test_utils/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,59 @@

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

from code_soup.common.utils.checkpoints import Checkpoints
from code_soup.common.utils import Checkpoints


class TheModelClass(nn.Module):
"""
Model class for tests
"""

def __init__(self):
super(TheModelClass, self).__init__()
self.dense = nn.Linear(2, 1)
self.activation = nn.Sigmoid()

def forward(self, x):
return self.activation(self.dense(x))


class TestCheckpoints(unittest.TestCase):
def test_save(self):
"""
Test that the model is saved
"""
model_save = models.resnet18(pretrained=True)
model_save = TheModelClass()
optimizer = optim.SGD(model_save.parameters(), lr=0.01, momentum=0.9)
loss = 0.5
epoch = 10
Checkpoints.save(
"tests/test_common/test_utils/test_model.pth",
"./input/test_model.pth",
model_save,
optimizer,
epoch,
loss,
)
self.assert_(os.path.isfile("tests/test_common/test_utils/test_model.pth"))
self.assertTrue(os.path.isfile("./input/test_model.pth"))

def test_load(self):
"""
Test that the model is loaded
"""
model = models.resnet18()
model = Checkpoints.load("tests/test_common/test_utils/test_model.pth")
model_load = models.resnet18(pretrained=True)
model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss = 0.5
epoch = 10
Checkpoints.save(
"./input/test_model.pth",
model,
optimizer,
epoch,
loss,
)
model_load = Checkpoints.load("./input/test_model.pth")
self.assertEqual(list(model.state_dict()), list(model_load.state_dict()))
Binary file removed tests/test_common/test_utils/test_model.pth
Binary file not shown.
8 changes: 5 additions & 3 deletions tests/test_common/test_utils/test_seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.transforms.transforms import ToTensor

from code_soup.common.utils.seeding import Seeding
from code_soup.common.utils import Seeding


class TestSeeding(unittest.TestCase):
"""Test the seed function."""

def test_seed(self):
"""Test that the seed is set."""
random.seed(42)
initial_state = random.getstate()
Seeding.seed(42)
final_state = random.getstate()
self.assertEqual(initial_state, final_state)
self.assertEqual(np.random.get_state()[1][0], 42)
self.assertEqual(torch.get_rng_state().tolist()[0], 42)