
Commit c4c2a8a

ChaiBapchya authored and ChoiByungWook committed
feature: add data parallelism support (#11) (#12) (#13)
1 parent 925fcd1 commit c4c2a8a

File tree

9 files changed, +365 -1 lines changed


buildspec.yml

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@ version: 0.2
 
 env:
   variables:
-    FRAMEWORK_VERSION: '1.4.0'
+    FRAMEWORK_VERSION: '1.6.0'
     CPU_INSTANCE_TYPE: 'ml.c4.xlarge'
     ECR_REPO: 'sagemaker-test'

src/sagemaker_pytorch_container/training.py

Lines changed: 8 additions & 0 deletions

@@ -20,6 +20,7 @@
 from sagemaker_training import entry_point, environment, errors, runner
 
 MASTER_PORT = '7777'
+LAUNCH_SMDATAPARALLEL_ENV_NAME = 'sagemaker_distributed_dataparallel_enabled'
 
 logger = logging.getLogger(__name__)
 
@@ -50,8 +51,15 @@ def train(training_environment):
 
     mpi_enabled = training_environment.additional_framework_parameters.get('sagemaker_mpi_enabled')
 
+    smdataparallel_enabled = training_environment.additional_framework_parameters.get(
+        LAUNCH_SMDATAPARALLEL_ENV_NAME, False
+    )
+
     if mpi_enabled:
         runner_type = runner.MPIRunnerType
+    elif smdataparallel_enabled:
+        runner_type = runner.SMDataParallelRunnerType
+        logger.info('Invoking SMDataParallel')
     else:
         runner_type = runner.ProcessRunnerType
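
A minimal sketch (not part of the commit) of the runner-selection precedence the hunk above introduces: sagemaker_mpi_enabled still takes priority, the new SMDataParallel flag comes next, and the plain process runner remains the default. Plain strings stand in for the sagemaker_training runner types.

def select_runner(additional_framework_parameters):
    # Mirrors the precedence in train(): MPI > SMDataParallel > process runner.
    if additional_framework_parameters.get('sagemaker_mpi_enabled'):
        return 'MPIRunnerType'
    if additional_framework_parameters.get('sagemaker_distributed_dataparallel_enabled', False):
        return 'SMDataParallelRunnerType'
    return 'ProcessRunnerType'

assert select_runner({}) == 'ProcessRunnerType'
assert select_runner({'sagemaker_distributed_dataparallel_enabled': True}) == 'SMDataParallelRunnerType'
assert select_runner({'sagemaker_mpi_enabled': True,
                      'sagemaker_distributed_dataparallel_enabled': True}) == 'MPIRunnerType'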

Lines changed: 10 additions & 0 deletions

ARG region
from 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-training:1.6.0-cpu-py36-ubuntu16.04

COPY lib/changehostname.c /
COPY lib/start_with_right_hostname.sh /usr/local/bin/start_with_right_hostname.sh
RUN chmod +x /usr/local/bin/start_with_right_hostname.sh

COPY dist/sagemaker_pytorch_training-*.tar.gz /sagemaker_pytorch_training.tar.gz
RUN pip install --upgrade --no-cache-dir /sagemaker_pytorch_training.tar.gz && \
    rm /sagemaker_pytorch_training.tar.gz
Lines changed: 28 additions & 0 deletions

ARG region
from 763104351884.dkr.ecr.$region.amazonaws.com/pytorch-training:1.6.0-gpu-py36-cu101-ubuntu16.04

# TODO: Remove once the 1.6.0-gpu-py3 DLC image installs mpi4py
RUN pip3 install mpi4py==3.0.3

# TODO: Remove once the 1.6.0-gpu-py3 DLC image fixes OpenSSH config
# Configure OpenSSH so that nodes can communicate with each other
RUN mkdir -p /var/run/sshd && \
    sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd

RUN rm -rf /root/.ssh/ && \
    mkdir -p /root/.ssh/ && \
    ssh-keygen -q -t rsa -N '' -f /root/.ssh/id_rsa && \
    cp /root/.ssh/id_rsa.pub /root/.ssh/authorized_keys \
    && printf "Host *\n StrictHostKeyChecking no\n" >> /root/.ssh/config

# TODO: Remove once the 1.6.0-gpu-py3 DLC image fixes MPI config
# Comment line in MPI config to prevent mutually exclusive MCA settings
RUN sed -i '62,62 s/^/#/' /home/.openmpi/etc/openmpi-mca-params.conf

COPY lib/changehostname.c /
COPY lib/start_with_right_hostname.sh /usr/local/bin/start_with_right_hostname.sh
RUN chmod +x /usr/local/bin/start_with_right_hostname.sh

COPY dist/sagemaker_pytorch_training-*.tar.gz /sagemaker_pytorch_training.tar.gz
RUN pip install --upgrade --no-cache-dir /sagemaker_pytorch_training.tar.gz && \
    rm /sagemaker_pytorch_training.tar.gz
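
Both DLC-based Dockerfiles above take the AWS region as a build argument so that the base-image reference resolves to the regional deep-learning-container registry. A hedged example of how such an image might be built locally; the Dockerfile path, tag, and region are assumptions, not part of the commit, and pulling the base image first requires an ECR login for account 763104351884 in the chosen region.

# Assumed invocation; adjust the Dockerfile path and tag to your checkout.
docker build \
    --build-arg region=us-west-2 \
    -t sagemaker-pytorch-training:1.6.0-gpu-py36 \
    -f Dockerfile.dlc.gpu .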
Lines changed: 23 additions & 0 deletions

from pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime

RUN apt-get update && apt-get install -y --no-install-recommends \
    jq \
    build-essential \
    cmake \
    gcc
RUN rm -rf /var/lib/apt/lists/*

COPY lib/changehostname.c /
COPY lib/start_with_right_hostname.sh /usr/local/bin/start_with_right_hostname.sh
RUN chmod +x /usr/local/bin/start_with_right_hostname.sh

COPY dist/sagemaker_pytorch_training-*.tar.gz /sagemaker_pytorch_training.tar.gz
RUN pip install --no-cache-dir /sagemaker_pytorch_training.tar.gz && \
    rm /sagemaker_pytorch_training.tar.gz

ENV SAGEMAKER_TRAINING_MODULE=sagemaker_pytorch_container.training:main

WORKDIR /

# Starts framework
ENTRYPOINT ["bash", "-m", "start_with_right_hostname.sh"]
Lines changed: 51 additions & 0 deletions

# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
import tarfile

import pytest
from sagemaker.pytorch import PyTorch

from integration import resources_path
from utils.local_mode_utils import assert_files_exist


# TODO: Enable the test once the SMDataParallel DLC is publicly accessible
@pytest.mark.skip(reason="SMDataParallel DLC is not publicly accessible")
@pytest.mark.skip_cpu
@pytest.mark.skip_generic
def test_smdataparallel_training(sagemaker_local_session, image_uri, framework_version, tmpdir):
    output_path = 'file://' + str(tmpdir)

    estimator = PyTorch(
        entry_point=os.path.join(resources_path, 'mnist', 'smdataparallel_mnist.py'),
        role='SageMakerRole',
        train_instance_type="local_gpu",
        sagemaker_session=sagemaker_local_session,
        train_instance_count=1,
        image_name=image_uri,
        output_path=output_path,
        framework_version=framework_version,
        hyperparameters={'sagemaker_distributed_dataparallel_enabled': True, "save-model": ""})

    success_files = {
        'model': ['mnist_cnn.pt'],
        'output': ['success'],
    }
    _train_and_assert_success(estimator, str(tmpdir), success_files)


def _train_and_assert_success(estimator, output_path, output_files):
    estimator.fit()
    assert_files_exist(output_path, output_files)
Lines changed: 52 additions & 0 deletions

# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
import tarfile

import pytest
from sagemaker.pytorch import PyTorch

from integration import resources_path, DEFAULT_TIMEOUT
from integration.sagemaker.timeout import timeout


@pytest.mark.skip(reason="SMDataParallel DLC is not publicly accessible")
@pytest.mark.skip_cpu
@pytest.mark.skip_generic
@pytest.mark.parametrize(
    "instances, train_instance_type",
    [(1, "ml.p3.16xlarge"), (2, "ml.p3.16xlarge"), (1, "ml.p3dn.24xlarge"), (2, "ml.p3dn.24xlarge")],
)
def test_smdataparallel_training(instances, train_instance_type, sagemaker_session, image_uri, framework_version, tmpdir):
    default_bucket = sagemaker_session.default_bucket()
    output_path = "s3://" + os.path.join(default_bucket, "pytorch/smdataparallel")

    estimator = PyTorch(
        entry_point=os.path.join(resources_path, "mnist", "smdataparallel_mnist.py"),
        role="SageMakerRole",
        train_instance_type=train_instance_type,
        sagemaker_session=sagemaker_session,
        train_instance_count=instances,
        image_name=image_uri,
        output_path=output_path,
        framework_version=framework_version,
        hyperparameters={
            "sagemaker_distributed_dataparallel_enabled": True
        }
    )

    with timeout(minutes=DEFAULT_TIMEOUT):
        estimator.fit()
Lines changed: 177 additions & 0 deletions

# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and limitations under the License.

from __future__ import print_function
import argparse
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
import smdistributed.dataparallel.torch.distributed as dist
dist.init_process_group()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0 and args.rank == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data) * args.world_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        if args.verbose:
            print('Batch', batch_idx, "from rank", args.rank)


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--verbose', action='store_true', default=False,
                        help='For displaying SMDataParallel-specific logs')
    parser.add_argument('--data-path', type=str, default='/tmp/data', help='Path for downloading '
                        'the MNIST dataset')

    args = parser.parse_args()
    args.world_size = dist.get_world_size()
    args.rank = rank = dist.get_rank()
    args.local_rank = local_rank = dist.get_local_rank()
    args.lr = 1.0
    args.batch_size //= args.world_size // 8
    args.batch_size = max(args.batch_size, 1)
    data_path = args.data_path

    if args.verbose:
        print('Hello from rank', rank, 'of local_rank',
              local_rank, 'in world size of', args.world_size)

    if not torch.cuda.is_available():
        raise Exception("Must run SMDataParallel MNIST example on CUDA-capable devices.")

    torch.manual_seed(args.seed)

    device = torch.device("cuda")

    if local_rank == 0:
        train_dataset = datasets.MNIST(data_path, train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
    else:
        time.sleep(8)
        train_dataset = datasets.MNIST(data_path, train=True, download=False,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=args.world_size,
        rank=rank)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler)
    if rank == 0:
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_path, train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=args.test_batch_size, shuffle=True)

    model = DDP(Net().to(device))
    torch.cuda.set_device(local_rank)
    model.cuda(local_rank)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        if rank == 0:
            test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()
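
A worked illustration (not part of the commit) of the per-worker batch-size rescaling in main() above, which assumes 8 GPUs per node: the requested --batch-size is divided by world_size // 8, so the per-GPU batch shrinks as more nodes join while the per-node batch stays roughly constant.

def per_worker_batch_size(batch_size, world_size, gpus_per_node=8):
    # Same arithmetic as main(): batch_size //= world_size // gpus_per_node,
    # floored at 1. Assumes world_size >= gpus_per_node, as on p3.16xlarge nodes.
    return max(batch_size // (world_size // gpus_per_node), 1)

assert per_worker_batch_size(64, 8) == 64    # one 8-GPU node: unchanged
assert per_worker_batch_size(64, 16) == 32   # two 8-GPU nodes: halved
assert per_worker_batch_size(64, 32) == 16   # four 8-GPU nodes: quartered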
