Merged

47 commits
8b2d721
Create a new distribution mechanism for PT-XLA
Lokiiiiii Aug 4, 2022
45cf028
Adding new unit tests targeting PT-XLA distributed training
Lokiiiiii Aug 9, 2022
cd2a397
Reformatting according to guidelines
Lokiiiiii Aug 9, 2022
e42af3c
Linting changes
Lokiiiiii Aug 9, 2022
90bbc75
Linting changes
Lokiiiiii Aug 10, 2022
6deb2be
Linting changes
Lokiiiiii Aug 10, 2022
4a8be0c
Test Mock syntax fix
Lokiiiiii Aug 10, 2022
37015da
Test Mock syntax fix
Lokiiiiii Aug 10, 2022
4f909c0
Fixing syntax error
Lokiiiiii Aug 10, 2022
5ef16e3
Fixing syntax error
Lokiiiiii Aug 10, 2022
c40e721
Revert "Fixing syntax error"
Lokiiiiii Aug 10, 2022
c02c497
Fixing syntax error
Lokiiiiii Aug 10, 2022
e13bf28
Fixing syntax error
Lokiiiiii Aug 10, 2022
a726717
+ new test to target the PT-XLA Distributed runner
Lokiiiiii Aug 10, 2022
29d7dd0
+ new test to target the PT-XLA Distributed runner
Lokiiiiii Aug 10, 2022
cf36079
+ new test to target the PT-XLA Distributed runner
Lokiiiiii Aug 10, 2022
54bf7ed
+ new test to target the PT-XLA Distributed runner
Lokiiiiii Aug 10, 2022
430597d
+ new test to target the PT-XLA Distributed runner
Lokiiiiii Aug 10, 2022
8dd8fa5
+ new test to target the PT-XLA Distributed runner
Lokiiiiii Aug 10, 2022
39e833e
+ new test to target the PT-XLA Distributed runner
Lokiiiiii Aug 10, 2022
5901c8c
Add verbose reporting for tox tests
Lokiiiiii Aug 10, 2022
c95960f
Fixing syntax errors
Lokiiiiii Aug 10, 2022
55f1aa6
Fixing syntax errors
Lokiiiiii Aug 10, 2022
0597fcc
Fixing syntax errors
Lokiiiiii Aug 10, 2022
e26aa4d
Fixing syntax errors
Lokiiiiii Aug 10, 2022
d98c445
Adding more tests targeting PT-XLA DT mechanism
Lokiiiiii Aug 10, 2022
7750a8b
edits for flake8
Lokiiiiii Aug 10, 2022
e385d56
edits for black
Lokiiiiii Aug 10, 2022
b84e3bd
fixing test errors
Lokiiiiii Aug 10, 2022
6c844c4
fixing test errors
Lokiiiiii Aug 10, 2022
a3bbb96
fixing test errors
Lokiiiiii Aug 10, 2022
291fbed
fixing test errors
Lokiiiiii Aug 10, 2022
083bbd0
fixing test errors
Lokiiiiii Aug 10, 2022
28e4c76
fixing container build for unit testing
Lokiiiiii Aug 10, 2022
a26e701
fixing container build for unit testing
Lokiiiiii Aug 10, 2022
700dc87
retry tests
Lokiiiiii Aug 10, 2022
1578fd6
fixing container build for unit testing
Lokiiiiii Aug 10, 2022
de559b1
fixing container build for unit testing
Lokiiiiii Aug 10, 2022
95b8b71
fixing container execution for unit testing
Lokiiiiii Aug 10, 2022
b7c9260
fixing container execution for unit testing
Lokiiiiii Aug 10, 2022
0466f80
Refactoring some tests as integration tests
Lokiiiiii Aug 11, 2022
c6816f3
Refactoring some tests as integration tests
Lokiiiiii Aug 11, 2022
6e019e2
Refactoring some tests as integration tests
Lokiiiiii Aug 11, 2022
0db2adf
Refactoring some tests as integration tests
Lokiiiiii Aug 11, 2022
55cb859
Refactoring some tests as integration tests
Lokiiiiii Aug 11, 2022
cccaf14
Removing stale files
Lokiiiiii Aug 11, 2022
9d2b7ee
Removing stale test container
Lokiiiiii Aug 15, 2022
10 changes: 6 additions & 4 deletions buildspec.yml
@@ -17,14 +17,16 @@ phases:
# run unit tests
- AWS_ACCESS_KEY_ID= AWS_SECRET_ACCESS_KEY= AWS_SESSION_TOKEN=
AWS_CONTAINER_CREDENTIALS_RELATIVE_URI= AWS_DEFAULT_REGION=
tox -e py37,py38 --parallel all -- test/unit
tox -v -e py37,py38 --parallel all -- test/unit

# build toolkit
- python setup.py sdist

# run functional tests
- $(aws ecr get-login --no-include-email --region us-west-2)
- IGNORE_COVERAGE=- tox -e py37,py38 -- test/functional
- IGNORE_COVERAGE=- tox -v -e py37,py38 -- test/functional

# build dummy container
- python setup.py sdist
- cp dist/sagemaker_training-*.tar.gz test/container/dummy/sagemaker_training.tar.gz
- cd test/container
- docker build -t sagemaker-training-toolkit-test:dummy -f dummy/Dockerfile .
@@ -33,4 +35,4 @@ phases:
- cd ../..

# run local integration tests
- IGNORE_COVERAGE=- tox -e py37,py38 -- test/integration/local
- IGNORE_COVERAGE=- tox -v -e py37,py38 -- test/integration/local
1 change: 1 addition & 0 deletions src/sagemaker_training/params.py
@@ -38,6 +38,7 @@
MULTI_WORKER_MIRRORED_STRATEGY_ENABLED = (
"sagemaker_multi_worker_mirrored_strategy_enabled"
) # type: str
PYTORCH_XLA_MULTI_WORKER_ENABLED = "sagemaker_pytorch_xla_multi_worker_enabled" # type: str
REGION_NAME_PARAM = "sagemaker_region" # type: str
REGION_NAME_ENV = REGION_NAME_PARAM.upper() # type: str
DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT" # type: str
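For context, a minimal sketch of how a framework entry point might read the new flag and pick the matching runner type. This is not code from this PR; the use of additional_framework_parameters is an assumption based on how the other distribution flags are typically consumed.

from sagemaker_training import environment, params, runner

env = environment.Environment()

# Hypothetical opt-in check: a truthy hyperparameter value enables PT-XLA distribution.
if env.additional_framework_parameters.get(params.PYTORCH_XLA_MULTI_WORKER_ENABLED, False):
    runner_type = runner.RunnerType.PyTorchXLA
else:
    runner_type = runner.RunnerType.Process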
150 changes: 150 additions & 0 deletions src/sagemaker_training/pytorch_xla.py
@@ -0,0 +1,150 @@
# Copyright 2018-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality related to distributed training using
PT-XLA (PyTorch - Accelerated Linear Algebra)."""
from __future__ import absolute_import

import os

from sagemaker_training import (
    _entry_point_type,
    environment,
    errors,
    logging_config,
    process,
)


logger = logging_config.get_logger()


class PyTorchXLARunner(process.ProcessRunner):
    """Responsible for PT-XLA distributed training."""

    MESH_SERVICE_PORT = 53957
    WORKER_PORT = 43857

    def __init__(
        self,
        user_entry_point,
        args,
        env_vars,
        processes_per_host,
        master_hostname,
        current_host,
        hosts,
        num_gpus,
    ):
"""Initialize a PyTorchXLARunner, which is responsible for distributed
training with PT-XLA.

Args:
user_entry_point (str): The name of the user entry point.
args ([str]): A list of arguments to include when executing the entry point.
env_vars (dict(str,str)): A dictionary of environment variables.
master_hostname (str): The master hostname.
current_host (str): The current hostname.
hosts ([str]): A list of hosts.
num_gpus (int): The number of GPUs available per host.
"""

        super(PyTorchXLARunner, self).__init__(user_entry_point, args, env_vars, processes_per_host)

        self._master_hostname = master_hostname
        self._current_host = current_host
        self._hosts = hosts
        self._num_gpus = num_gpus

        self._num_hosts = len(self._hosts)
        self._rank = self._hosts.index(self._current_host)

    def _setup(self):  # type: () -> None
        logger.info("Starting distributed training through PT-XLA Runtime.")
        self._check_compatibility()

        os.environ["XRT_HOST_ORDINAL"] = str(self._rank)
        os.environ["XRT_SHARD_WORLD_SIZE"] = str(self._num_hosts)
        address = "localservice:{};{}:" + str(self.WORKER_PORT)
        os.environ["XRT_WORKERS"] = "|".join(
            [address.format(i, host) for i, host in enumerate(self._hosts)]
        )
        os.environ["GPU_NUM_DEVICES"] = str(self._num_gpus)
@harryzorus Aug 15, 2022
We need to set some extra environment variables to ensure EFA works correctly & optimally with NCCL. Please add the following:

# Add configuration for NCCL & EFA NCCL Plugin

# Set NCCL logging to info to debug
# customers issues easier
os.environ["NCCL_DEBUG"]="info"

# Use `simple` protocol to handle
# the out of order data delivery from EFA
os.environ["NCCL_PROTO"]="simple"

# use GPU RDMA when available
# applicable only to p4d.24xlarge
os.environ["FI_EFA_USE_DEVICE_RDMA"]="1"

# Use multiple connections per GPU to
# better saturate the EFA bandwidth
os.environ["OFI_NCCL_NIC_DUP_CONNS"]=str(args.num_gpus)

        if self._num_hosts > 1:
            os.environ[
                "XRT_MESH_SERVICE_ADDRESS"
            ] = f"{self._master_hostname}:{self.MESH_SERVICE_PORT}"

        logger.info("Completed environment setup for distributed training through PT-XLA Runtime.")

    def _create_command(self):
        entrypoint_type = _entry_point_type.get(environment.code_dir, self._user_entry_point)

        if entrypoint_type is _entry_point_type.PYTHON_PACKAGE:
            raise errors.ClientError(
                "Distributed Training through PT-XLA is not supported for Python packages. "
                "Please use a python script as the entry-point"
            )
        if entrypoint_type is _entry_point_type.PYTHON_PROGRAM:
            return self._pytorch_xla_command() + [self._user_entry_point] + self._args
        else:
            raise errors.ClientError(
                "Distributed Training through PT-XLA is only supported for Python scripts. "
                "Please use a python script as the entry-point"
            )

    def _pytorch_xla_command(self):
        return self._python_command() + [
            "-m",
            "torch_xla.distributed.xla_spawn",
            "--num_gpus",
            str(self._num_gpus),
        ]

    def _check_compatibility(self):
        self._check_processor_compatibility()
        self._check_for_torch_xla()
        self._check_for_sagemaker_integration()

    def _check_for_sagemaker_integration(self):
        # pylint: disable=no-self-use
        try:
            import torch_xla.distributed.xla_spawn  # pylint: disable=unused-import # noqa: F401
        except ModuleNotFoundError as exception:
            raise ModuleNotFoundError(
                "Unable to find SageMaker integration code in PT-XLA. "
                "AWS SageMaker adds custom code on top of open source "
                "PT-XLA to provide platform specific "
                "optimizations. These SageMaker specific binaries are"
                " shipped as part of our Deep Learning Containers."
                " Please refer to "
                "https://github.com/aws/deep-learning-containers"
                "/blob/master/available_images.md"
            ) from exception

    def _check_for_torch_xla(self):
        # pylint: disable=no-self-use
        try:
            import torch_xla  # pylint: disable=unused-import # noqa: F401
        except ModuleNotFoundError as exception:
            raise ModuleNotFoundError(
                "Unable to find PT-XLA in the execution environment. "
                "This distribution mechanism requires PT-XLA to be available"
                " in the execution environment. "
                "SageMaker Training Compiler provides ready-to-use containers with PT-XLA. "
                "Please refer to https://github.com/aws/deep-learning-containers"
                "/blob/master/available_images.md "
            ) from exception

    def _check_processor_compatibility(self):
        if not self._num_gpus > 0:
            raise ValueError("Distributed training through PT-XLA is only supported for GPUs.")
15 changes: 14 additions & 1 deletion src/sagemaker_training/runner.py
@@ -17,7 +17,7 @@

import enum

from sagemaker_training import environment, mpi, params, process, smdataparallel
from sagemaker_training import environment, mpi, params, process, pytorch_xla, smdataparallel


class RunnerType(enum.Enum):
@@ -26,11 +26,13 @@ class RunnerType(enum.Enum):
MPI = "MPI"
Process = "Process"
SMDataParallel = "SMDataParallel"
PyTorchXLA = "PyTorchXLA"


ProcessRunnerType = RunnerType.Process
MPIRunnerType = RunnerType.MPI
SMDataParallelRunnerType = RunnerType.SMDataParallel
PyTorchXLARunnerType = RunnerType.PyTorchXLA


def get(identifier, user_entry_point=None, args=None, env_vars=None, extra_opts=None):
@@ -103,6 +105,17 @@ def _get_by_runner_type(
        return mpi.WorkerRunner(
            user_entry_point, args, env_vars, processes_per_host, env.master_hostname
        )
    elif identifier is RunnerType.PyTorchXLA:
        return pytorch_xla.PyTorchXLARunner(
            user_entry_point,
            args,
            env_vars,
            processes_per_host,
            env.master_hostname,
            env.current_host,
            env.distribution_hosts,
            env.num_gpus,
        )
    elif identifier is RunnerType.Process:
        return process.ProcessRunner(user_entry_point, args, env_vars, processes_per_host)
    else:
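Finally, a minimal sketch of obtaining the new runner through this factory inside a SageMaker training container and launching the user script with it. The entry point and arguments are hypothetical; the call matches the get() signature shown above.

from sagemaker_training import runner

pt_xla_runner = runner.get(
    runner.RunnerType.PyTorchXLA,
    user_entry_point="train.py",   # hypothetical entry point
    args=["--epochs", "10"],
    env_vars={},
)
pt_xla_runner.run()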