12 changes: 12 additions & 0 deletions src/sagemaker_training/environment.py
@@ -562,6 +562,9 @@ def __init__(resource_config=None, input_data_config=None, hyperparameters
self._master_hostname = list(hosts)[0]
self._is_master = current_host == self._master_hostname

mp_parameters = os.environ.get(params.SM_HP_MP_PARAMETERS)
self._is_modelparallel_enabled = mp_parameters and mp_parameters != "{}"

@property
def model_dir(self): # type: () -> str
"""The directory where models should be saved.
@@ -909,6 +912,15 @@ def framework_module(self): # type: () -> str
"""
return self._framework_module

@property
def is_modelparallel_enabled(self): # type: () -> bool
"""Whether SM ModelParallel is enabled.

Returns:
bool: True if SM ModelParallel is enabled
"""
return self._is_modelparallel_enabled


def write_env_vars(env_vars=None): # type: (dict) -> None
"""Write the dictionary env_vars in the system, as environment variables.
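A minimal standalone sketch (not part of the diff) of the check the new property performs, assuming the model-parallel hyperparameters arrive as the SM_HP_MP_PARAMETERS environment variable holding a JSON string; the "partitions" payload below is purely illustrative:

import os

def modelparallel_enabled(environ=os.environ):
    # Mirrors Environment._is_modelparallel_enabled: the flag is set when the
    # SM_HP_MP_PARAMETERS hyperparameter is present and is not an empty JSON object.
    mp_parameters = environ.get("SM_HP_MP_PARAMETERS")
    return bool(mp_parameters) and mp_parameters != "{}"

assert modelparallel_enabled({"SM_HP_MP_PARAMETERS": '{"partitions": 2}'})
assert not modelparallel_enabled({"SM_HP_MP_PARAMETERS": "{}"})
assert not modelparallel_enabled({})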
75 changes: 71 additions & 4 deletions src/sagemaker_training/mpi.py
@@ -13,7 +13,7 @@
"""This module contains functionality related to distributed training using
MPI (Message Passing Interface)."""
import argparse
import inspect
from inspect import getfile, isclass
import logging
import os
import subprocess
@@ -23,11 +23,31 @@
import psutil

import gethostname
from sagemaker_training import logging_config, process, timeout
from sagemaker_training import environment, errors, logging_config, process, timeout

logger = logging_config.get_logger()
logging.getLogger("paramiko").setLevel(logging.INFO)

exception_classes = []
try:
from smdistributed.modelparallel.backend import exceptions

# names of the exception classes SMMP wants the training toolkit to catch and log
exception_classes = [x for x in dir(exceptions) if isclass(getattr(exceptions, x))]
except ImportError:
logger.info("No exception classes found in smdistributed.modelparallel.backend")

try:
from smdistributed.modelparallel.torch import exceptions as torch_exceptions

# names of the torch exception classes SMMP wants the training toolkit to catch and log
exception_classes += [x for x in dir(torch_exceptions) if isclass(getattr(torch_exceptions, x))]
except ImportError:
logger.info("No torch exception classes found in smdistributed.modelparallel.torch")

if not exception_classes:
exception_classes = [errors.ExecuteUserScriptError]


class WorkerRunner(process.ProcessRunner):
"""Runner responsible for preparing MPI distributed training and waiting for MPI
@@ -235,12 +255,16 @@ def _create_command(self):
"-x",
"PATH",
"-x",
"LD_PRELOAD=%s" % inspect.getfile(gethostname),
"LD_PRELOAD=%s" % getfile(gethostname),
]

command.extend(additional_options)

for credential in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]:
for credential in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
]:
if credential in os.environ:
command.extend(["-x", credential])

@@ -256,6 +280,49 @@ def _python_command(self):
"""
return super(MasterRunner, self)._python_command() + ["-m", "mpi4py"]

def run(self, wait=True, capture_error=False):
"""Run the process.

Args:
wait (bool): A boolean indicating whether to wait and check for errors.
Defaults to True.
capture_error (bool): A boolean indicating whether to direct stderr to a stream
that can later be read. Defaults to False.

Returns:
process (subprocess.Popen): The spawned process.
"""
self._setup()

cmd = self._create_command()

logging_config.log_script_invocation(cmd, self._env_vars)

training_env = environment.Environment()
if wait:
process_spawned = process.check_error(
cmd,
exception_classes
if training_env.is_modelparallel_enabled
else errors.ExecuteUserScriptError,
self._processes_per_host,
capture_error=capture_error,
cwd=environment.code_dir,
)
else:
_, _, process_spawned = process.create(
cmd,
exception_classes
if training_env.is_modelparallel_enabled
else errors.ExecuteUserScriptError,
self._processes_per_host,
capture_error=capture_error,
cwd=environment.code_dir,
)

self._tear_down()
return process_spawned


_SSH_DAEMON_NOT_FOUND_ERROR_MESSAGE = """
SSH daemon not found, please install SSH to allow MPI to communicate across the different nodes in the cluster.
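A hedged sketch (not in the diff) of the branch MasterRunner.run() now takes when deciding which error classes to hand to the process module; the helper name is illustrative only:

from sagemaker_training import environment, errors

def select_error_classes(smmp_exception_classes):
    # When SM ModelParallel is enabled, watch for the SMMP exception list
    # collected at import time; otherwise fall back to ExecuteUserScriptError.
    training_env = environment.Environment()
    if training_env.is_modelparallel_enabled:
        return smmp_exception_classes
    return errors.ExecuteUserScriptError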
1 change: 1 addition & 0 deletions src/sagemaker_training/params.py
@@ -62,3 +62,4 @@
SMDATAPARALLEL_CUSTOM_MPI_OPTIONS = (
"sagemaker_distributed_dataparallel_custom_mpi_options"
) # type: str
SM_HP_MP_PARAMETERS = "SM_HP_MP_PARAMETERS"
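As a point of reference (an assumption, not shown in the diff): SageMaker surfaces training-job hyperparameters to the toolkit as SM_HP_-prefixed environment variables, which is where the constant's value comes from. A tiny illustration:

def hyperparameter_env_var(name):
    # Hypothetical helper: the hyperparameter "mp_parameters" is surfaced as the
    # environment variable named by params.SM_HP_MP_PARAMETERS.
    return "SM_HP_" + name.upper()

assert hyperparameter_env_var("mp_parameters") == "SM_HP_MP_PARAMETERS"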
117 changes: 96 additions & 21 deletions src/sagemaker_training/process.py
@@ -17,6 +17,7 @@

import asyncio
from asyncio.subprocess import PIPE
from inspect import isclass
import os
import re
import subprocess
@@ -38,7 +39,24 @@
_DEFAULT_BUF_SIZE = 1024 * 64


async def watch(stream, proc_per_host):
def process_error_classes(error_classes):
"""Process error classes and return a list of string.
Input could be class, string, or None

Args:
error_classes (list): List of error classes

Returns:
error_classes: processed classes
"""
if not error_classes:
return []
if not isinstance(error_classes, list):
error_classes = [error_classes]
return [error.__name__ if isclass(error) else error for error in error_classes]


async def watch(stream, proc_per_host, error_classes=None):
"""Process the stdout and stderr streams on the fly.
Decode the output lines
Remove new line characters (if any)
@@ -48,10 +66,12 @@ async def watch(stream, proc_per_host):
Args:
stream: asyncio subprocess PIPE
proc_per_host (int): Number of processes per each host
error_classes (list): List of exception classes to watch and raise

Returns:
output: Filtered stderr
"""
error_classes = process_error_classes(error_classes)
output = []
buf_size = _DEFAULT_BUF_SIZE
start = False
@@ -82,20 +102,31 @@
line,
)
print(line)
# log only if necessary
# log only if necessary, remove node and rank id for de-duplication
err_line = re.sub(r"\[(\d),(\d)\]<stderr>", "", err_line)
# in case error piped to stdout
err_line = re.sub(r"\[(\d),(\d)\]<stdout>", "", err_line)

if start:
if line not in output:
output.append(err_line)
if err_line not in output:
output.append(err_line.strip(" :\n") + "\n")
else:
if any(err in err_line for err in _PYTHON_ERRORS_):
if any(
str(err) in err_line
for err in (
_PYTHON_ERRORS_ + error_classes
if isinstance(error_classes, list)
else [error_classes]
)
):
# start logging error message if target exceptions found
start = True
output.append(err_line + "\n")
output.append(err_line.strip(" :\n") + "\n")

return " ".join(output)


async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
async def run_async(cmd, processes_per_host, env, cwd, stderr, error_classes=None, **kwargs):
"""Method responsible for launching asyncio subprocess shell
Use asyncio gather to collect processed stdout and stderr

@@ -105,6 +136,7 @@ async def run_async(cmd, processes_per_host, env, cwd, stderr, **kwargs):
env: os.environ
cwd (str): The location from which to run the command (default: None).
If None, this defaults to the ``code_dir`` of the environment.
error_classes (list): List of exception classes to watch and raise
**kwargs: Extra arguments that are passed to the asyncio create subprocess constructor.

Returns:
@@ -113,28 +145,37 @@
asyncio.subprocess.Process: The asyncio process for the given command.

Raises:
error_class: If there is an exception raised when creating the process.
ExecuteUserScriptError: If there is an exception raised when creating the process.
"""
cmd = " ".join(cmd)
proc = await asyncio.create_subprocess_shell(
cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
)

output = await asyncio.gather(
watch(proc.stdout, processes_per_host), watch(proc.stderr, processes_per_host)
watch(proc.stdout, processes_per_host, error_classes=error_classes),
watch(proc.stderr, processes_per_host, error_classes=error_classes),
)
logger.info("Waiting for the process to finish and give a return code.")
return_code = await proc.wait()
logger.info(f"Done waiting for a return code. Received {return_code} from exiting process.")
return return_code, output, proc


def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_error=False, **kwargs):
def create(
cmd,
error_classes,
processes_per_host,
cwd=None,
env=None,
capture_error=False,
**kwargs,
):
"""Spawn a process with asyncio for the given command.

Args:
cmd (list): The command to be run.
error_class (cls): The class to use when raising an exception.
error_classes (list): List of exception classes to watch and raise.
cwd (str): The location from which to run the command (default: None).
If None, this defaults to the ``code_dir`` of the environment.
env: os.environ
@@ -146,7 +187,7 @@ def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_err
asyncio.subprocess.Process: The asyncio process for the given command.

Raises:
error_class: If there is an exception raised when creating the process.
ExecuteUserScriptError: If there is an exception raised when creating the process.
"""
try:
stderr = PIPE if capture_error else None
@@ -157,20 +198,25 @@ def create(cmd, error_class, processes_per_host, cwd=None, env=None, capture_err
env=env or os.environ,
cwd=cwd or environment.code_dir,
stderr=stderr,
error_classes=error_classes,
**kwargs,
)
)
return rc, output, proc
except Exception as e: # pylint: disable=broad-except
six.reraise(error_class, error_class(e), sys.exc_info()[2])
six.reraise(
errors.ExecuteUserScriptError,
errors.ExecuteUserScriptError(e),
sys.exc_info()[2],
)


def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=True, **kwargs):
def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=True, **kwargs):
"""Run a commmand, raising an exception if there is an error.

Args:
cmd ([str]): The command to be run.
error_class (cls): The class to use when raising an exception.
error_classes (list): List of exception classes to watch and raise.
processes_per_host (int): Number of processes per host
capture_error (bool): Whether or not to include stderr in
the exception message (default: True). In either case,
@@ -181,32 +227,57 @@ def check_error(cmd, error_class, processes_per_host, cwd=None, capture_error=Tr
subprocess.Popen: The process for the given command.

Raises:
error_class: If there is an exception raised when creating the process.
ExecuteUserScriptError: If there is an exception raised when creating the process.
"""

error_classes = process_error_classes(error_classes)
if capture_error:
return_code, output, process = create(
cmd,
error_class,
error_classes,
processes_per_host,
env=os.environ,
cwd=cwd or environment.code_dir,
capture_error=True,
**kwargs,
)
stderr = output[1]
stderr = " ".join(output)
# remove duplicates while preserving order
stderr = "\n".join(list(dict.fromkeys(stderr.split("\n")))).strip()
else:
stderr = None
# remove extra quotes for subprocess.Popen
cmd[-1] = cmd[-1].strip('"')
process = subprocess.Popen(
cmd, env=os.environ, cwd=cwd or environment.code_dir, stderr=stderr, **kwargs
cmd,
env=os.environ,
cwd=cwd or environment.code_dir,
stderr=stderr,
**kwargs,
)
return_code = process.wait()
if return_code:
extra_info = None
if return_code == 137:
extra_info = "OutOfMemory: Process killed by SIGKILL (signal 9)"

# raise the toolkit's internal error classes first
internal_errors = [err for err in dir(errors) if isclass(getattr(errors, err))]
error_class = next(
(name for name in error_classes if name in internal_errors), "ExecuteUserScriptError"
)
error_class = getattr(errors, error_class)

# only replace ExecuteUserScriptError with custom library errors
if stderr and error_class == errors.ExecuteUserScriptError:
# find the first target error in stderr
error_name = next((str(name) for name in error_classes if str(name) in stderr), False)
if error_name:
error_class = type(
error_name,
(errors._CalledProcessError,), # pylint: disable=protected-access
{},
)

raise error_class(
cmd=" ".join(cmd) if isinstance(cmd, list) else cmd,
return_code=return_code,
@@ -259,7 +330,11 @@ def _create_command(self):
six.moves.shlex_quote(arg) # pylint: disable=too-many-function-args
for arg in self._args
]
return ["/bin/sh", "-c", '"./%s %s"' % (self._user_entry_point, " ".join(args))]
return [
"/bin/sh",
"-c",
'"./%s %s"' % (self._user_entry_point, " ".join(args)),
]

def _python_command(self): # pylint: disable=no-self-use
return [python_executable()]
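Two behaviours in the process.py changes are worth calling out, shown here as a standalone, hypothetical sketch rather than the toolkit's own code: the normalization that process_error_classes applies, and the way check_error() promotes a watched name found in stderr to a dynamically created exception type. The error name "SMPUnsupportedError" and the stand-in base class are illustrative only:

from inspect import isclass

def normalize_error_classes(error_classes):
    # Same idea as process_error_classes: classes become their names,
    # plain strings pass through, and None becomes an empty list.
    if not error_classes:
        return []
    if not isinstance(error_classes, list):
        error_classes = [error_classes]
    return [err.__name__ if isclass(err) else err for err in error_classes]

assert normalize_error_classes(None) == []
assert normalize_error_classes(ValueError) == ["ValueError"]
assert normalize_error_classes(["SMPUnsupportedError", RuntimeError]) == [
    "SMPUnsupportedError",
    "RuntimeError",
]

class _CalledProcessError(Exception):
    """Stand-in for errors._CalledProcessError, for illustration only."""

def error_type_for(stderr, error_names):
    # Mirrors the check_error() fallback: if one of the watched names appears
    # in stderr, raise it as a dynamically created subclass instead of the
    # generic ExecuteUserScriptError.
    name = next((n for n in error_names if n in stderr), None)
    return type(name, (_CalledProcessError,), {}) if name else None

exc = error_type_for("SMPUnsupportedError: bad config", ["SMPUnsupportedError"])
assert exc is not None and exc.__name__ == "SMPUnsupportedError"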