diff --git a/tests/zero_code_change/pt_utils.py b/tests/zero_code_change/pt_utils.py index 9bb1c073e..3d6cb78de 100644 --- a/tests/zero_code_change/pt_utils.py +++ b/tests/zero_code_change/pt_utils.py @@ -7,7 +7,6 @@ import torch.nn.functional as F import torchvision import torchvision.transforms as transforms -from packaging import version def get_dataloaders() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: @@ -15,26 +14,15 @@ def get_dataloaders() -> Tuple[torch.utils.data.DataLoader, torch.utils.data.Dat [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] ) - # Temporary Change to allow the test to run with pytorch 1.7 RC3 - # Smdebug breaks when num_workers>0 for Pytorch 1.7.0 - if version.parse(torch.__version__) >= version.parse("1.7.0"): - num_workers = 0 - else: - num_workers = 2 - trainset = torchvision.datasets.CIFAR10( root="./data", train=True, download=True, transform=transform ) - trainloader = torch.utils.data.DataLoader( - trainset, batch_size=4, shuffle=True, num_workers=num_workers - ) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10( root="./data", train=False, download=True, transform=transform ) - testloader = torch.utils.data.DataLoader( - testset, batch_size=4, shuffle=False, num_workers=num_workers - ) + testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck") return trainloader, testloader diff --git a/tests/zero_code_change/test_pytorch_integration.py b/tests/zero_code_change/test_pytorch_integration.py index 21e7759f8..eb6d06536 100644 --- a/tests/zero_code_change/test_pytorch_integration.py +++ b/tests/zero_code_change/test_pytorch_integration.py @@ -12,7 +12,6 @@ # Third Party import pytest -import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim @@ -23,10 +22,6 @@ from smdebug.core.utils import SagemakerSimulator, ScriptSimulator -@pytest.mark.skipif( - torch.__version__ == "1.7.0", - reason="Disabling the test temporarily until we root cause the version incompatibility", -) @pytest.mark.parametrize("script_mode", [False]) @pytest.mark.parametrize("use_loss_module", [True, False]) def test_pytorch(script_mode, use_loss_module):