Feature: Create a new distribution mechanism for PT-XLA #137
Merged

Changes from all commits (47 commits, all by Lokiiiiii):
8b2d721 Create a new distribution mechanism for PT-XLA
45cf028 Adding new unit tests targetting PT-XLA distributed training
cd2a397 Reformatting according to guidelines
e42af3c Linting changes
90bbc75 Linting changes
6deb2be Linting changes
4a8be0c Test Mock syntax fix
37015da Test Mock syntax fix
4f909c0 Fixing syntax error
5ef16e3 Fixing syntax error
c40e721 Revert "Fixing syntax error"
c02c497 Fixing syntax error
e13bf28 Fixing syntax error
a726717 + new test to target the PT-XLA Distributed runner
29d7dd0 + new test to target the PT-XLA Distributed runner
cf36079 + new test to target the PT-XLA Distributed runner
54bf7ed + new test to target the PT-XLA Distributed runner
430597d + new test to target the PT-XLA Distributed runner
8dd8fa5 + new test to target the PT-XLA Distributed runner
39e833e + new test to target the PT-XLA Distributed runner
5901c8c Add verbose reporting for tox tests
c95960f Fixing syntax errors
55f1aa6 Fixing syntax errors
0597fcc Fixing syntax errors
e26aa4d Fixing syntax errors
d98c445 Adding more tests targeting PT-XLA DT mechanism
7750a8b edits for flake8
e385d56 edits for black
b84e3bd fixing test errors
6c844c4 fixing test errors
a3bbb96 fixing test errors
291fbed fixing test errors
083bbd0 fixing test errors
28e4c76 fixing container build for unit testing
a26e701 fixing container build for unit testing
700dc87 retry tests
1578fd6 fixing container build for unit testing
de559b1 fixing container build for unit testing
95b8b71 fixing container execution for unit testing
b7c9260 fixing container execution for unit testing
0466f80 Refactoring some tests as integration tests
c6816f3 Refactoring some tests as integration tests
6e019e2 Refactoring some tests as integration tests
0db2adf Refactoring some tests as integration tests
55cb859 Refactoring some tests as integration tests
cccaf14 Removing stale files
9d2b7ee Removing stale test container
New file (+150 lines; the file path is not preserved on this page):

```python
# Copyright 2018-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License'). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the 'license' file accompanying this file. This file is
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains functionality related to distributed training using
PT-XLA (PyTorch - Accelerated Linear Algebra)."""
from __future__ import absolute_import

import os

from sagemaker_training import (
    _entry_point_type,
    environment,
    errors,
    logging_config,
    process,
)

logger = logging_config.get_logger()


class PyTorchXLARunner(process.ProcessRunner):
    """Responsible for PT-XLA distributed training."""

    MESH_SERVICE_PORT = 53957
    WORKER_PORT = 43857

    def __init__(
        self,
        user_entry_point,
        args,
        env_vars,
        processes_per_host,
        master_hostname,
        current_host,
        hosts,
        num_gpus,
    ):
        """Initialize a PyTorchXLARunner, which is responsible for distributed
        training with PT-XLA.

        Args:
            user_entry_point (str): The name of the user entry point.
            args ([str]): A list of arguments to include when executing the entry point.
            env_vars (dict(str,str)): A dictionary of environment variables.
            processes_per_host (int): The number of processes to run on each host.
            master_hostname (str): The master hostname.
            current_host (str): The current hostname.
            hosts ([str]): A list of hosts.
            num_gpus (int): The number of GPUs available per host.
        """
        super(PyTorchXLARunner, self).__init__(user_entry_point, args, env_vars, processes_per_host)

        self._master_hostname = master_hostname
        self._current_host = current_host
        self._hosts = hosts
        self._num_gpus = num_gpus

        self._num_hosts = len(self._hosts)
        self._rank = self._hosts.index(self._current_host)

    def _setup(self):  # type: () -> None
        logger.info("Starting distributed training through PT-XLA Runtime.")
        self._check_compatibility()

        # Describe the XRT topology: this host's ordinal, the world size,
        # and one "localservice:<ordinal>;<host>:<port>" entry per host.
        os.environ["XRT_HOST_ORDINAL"] = str(self._rank)
        os.environ["XRT_SHARD_WORLD_SIZE"] = str(self._num_hosts)
        address = "localservice:{};{}:" + str(self.WORKER_PORT)
        os.environ["XRT_WORKERS"] = "|".join(
            [address.format(i, host) for i, host in enumerate(self._hosts)]
        )
        os.environ["GPU_NUM_DEVICES"] = str(self._num_gpus)
        if self._num_hosts > 1:
            # Multi-host jobs rendezvous through the mesh service on the master.
            os.environ[
                "XRT_MESH_SERVICE_ADDRESS"
            ] = f"{self._master_hostname}:{self.MESH_SERVICE_PORT}"

        logger.info("Completed environment setup for distributed training through PT-XLA Runtime.")

    def _create_command(self):
        entrypoint_type = _entry_point_type.get(environment.code_dir, self._user_entry_point)

        if entrypoint_type is _entry_point_type.PYTHON_PACKAGE:
            raise errors.ClientError(
                "Distributed Training through PT-XLA is not supported for Python packages. "
                "Please use a python script as the entry-point"
            )
        if entrypoint_type is _entry_point_type.PYTHON_PROGRAM:
            return self._pytorch_xla_command() + [self._user_entry_point] + self._args
        else:
            raise errors.ClientError(
                "Distributed Training through PT-XLA is only supported for Python scripts. "
                "Please use a python script as the entry-point"
            )

    def _pytorch_xla_command(self):
        return self._python_command() + [
            "-m",
            "torch_xla.distributed.xla_spawn",
            "--num_gpus",
            str(self._num_gpus),
        ]

    def _check_compatibility(self):
        self._check_processor_compatibility()
        self._check_for_torch_xla()
        self._check_for_sagemaker_integration()

    def _check_for_sagemaker_integration(self):
        # pylint: disable=no-self-use
        try:
            import torch_xla.distributed.xla_spawn  # pylint: disable=unused-import # noqa: F401
        except ModuleNotFoundError as exception:
            raise ModuleNotFoundError(
                "Unable to find SageMaker integration code in PT-XLA. "
                "AWS SageMaker adds custom code on top of open source "
                "PT-XLA to provide platform specific "
                "optimizations. These SageMaker specific binaries are"
                " shipped as part of our Deep Learning Containers."
                " Please refer to "
                "https://github.com/aws/deep-learning-containers"
                "/blob/master/available_images.md"
            ) from exception

    def _check_for_torch_xla(self):
        # pylint: disable=no-self-use
        try:
            import torch_xla  # pylint: disable=unused-import # noqa: F401
        except ModuleNotFoundError as exception:
            raise ModuleNotFoundError(
                "Unable to find PT-XLA in the execution environment. "
                "This distribution mechanism requires PT-XLA to be available"
                " in the execution environment. "
                "SageMaker Training Compiler provides ready-to-use containers with PT-XLA. "
                "Please refer to https://github.com/aws/deep-learning-containers"
                "/blob/master/available_images.md "
            ) from exception

    def _check_processor_compatibility(self):
        if not self._num_gpus > 0:
            raise ValueError("Distributed training through PT-XLA is only supported for GPUs.")
```
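To make the runner's environment setup concrete, here is a standalone sketch that reproduces what `_setup()` exports for a hypothetical two-host job. The hostnames, current host, and GPU count below are illustrative assumptions, not values from the PR; only the two port constants come from the class above.

```python
# Reproduce the XRT environment _setup() computes, without importing
# sagemaker_training. Hostnames and GPU count are hypothetical.
WORKER_PORT = 43857        # PyTorchXLARunner.WORKER_PORT
MESH_SERVICE_PORT = 53957  # PyTorchXLARunner.MESH_SERVICE_PORT

hosts = ["algo-1", "algo-2"]
current_host = "algo-2"
master_hostname = "algo-1"
num_gpus = 8

rank = hosts.index(current_host)
address = "localservice:{};{}:" + str(WORKER_PORT)
xrt_workers = "|".join(address.format(i, host) for i, host in enumerate(hosts))

print("XRT_HOST_ORDINAL =", rank)          # 1
print("XRT_SHARD_WORLD_SIZE =", len(hosts))  # 2
print("XRT_WORKERS =", xrt_workers)
# localservice:0;algo-1:43857|localservice:1;algo-2:43857
print("GPU_NUM_DEVICES =", num_gpus)        # 8
print("XRT_MESH_SERVICE_ADDRESS =", f"{master_hostname}:{MESH_SERVICE_PORT}")
# algo-1:53957
```

On each host, `_create_command()` then wraps the user script in torch_xla's spawner, e.g. `python -m torch_xla.distributed.xla_spawn --num_gpus 8 train.py <args>` (entry-point name illustrative).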
Review comment: We need to set some extra environment variables to ensure EFA works correctly and optimally with NCCL. Please add the following:
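The reviewer's actual variable list is not preserved on this page. As a hedged sketch of the kind of settings meant, the EFA/NCCL tuning that AWS documents for EFA-enabled training looks like the following; the specific variables and values are assumptions, not the reviewer's exact request:

```python
import os

# Sketch only: commonly documented EFA/NCCL settings, not necessarily
# the exact list the reviewer asked for.
os.environ["FI_PROVIDER"] = "efa"           # route libfabric traffic over EFA
os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"  # enable GPUDirect RDMA on instances that support it
os.environ["RDMAV_FORK_SAFE"] = "1"         # keep the RDMA stack safe across fork()
os.environ["NCCL_PROTO"] = "simple"         # protocol AWS recommends when NCCL runs over EFA
```

In the runner above, such settings would most naturally be exported in `_setup()` alongside the XRT variables.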