Commit 14f2894

Author: Satish Pasumarthi (authored and committed)
fix: improve worker node wait logic
1 parent bb20f65 commit 14f2894

7 files changed: +318 -18 lines changed


src/sagemaker_training/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -69,3 +69,5 @@
     "OutOfRangeError",
     "InvalidArgumentError",
 ]
+
+_MPI_ERRORS_ = ["mpirun.real", "ORTE"]

src/sagemaker_training/mpi.py

Lines changed: 99 additions & 8 deletions
@@ -28,6 +28,8 @@
 logger = logging_config.get_logger()
 logging.getLogger("paramiko").setLevel(logging.INFO)
 
+MPI_FINISHED_STATUS_FILE = "/tmp/done"
+
 
 def get_modelparallel_exception_classes():
     """Set exception classes"""
@@ -60,7 +62,9 @@ class WorkerRunner(process.ProcessRunner):
     master execution to finish.
     """
 
-    def __init__(self, user_entry_point, args, env_vars, processes_per_host, master_hostname):
+    def __init__(
+        self, user_entry_point, args, env_vars, processes_per_host, master_hostname, current_host
+    ):
         """Initialize a WorkerRunner, which is responsible for preparing distributed
         training with MPI and waiting for MPI master execution to finish.
 
@@ -69,9 +73,11 @@ def __init__(self, user_entry_point, args, env_vars, processes_per_host, master_
             args ([str]): A list of arguments to include when executing the entry point.
             env_vars (dict(str,str)): A dictionary of environment variables.
             master_hostname (str): The master hostname.
+            current_hostname (str): Current hostname.
         """
         super(WorkerRunner, self).__init__(user_entry_point, args, env_vars, processes_per_host)
         self._master_hostname = str(master_hostname)
+        self._current_host = str(current_host)
 
     def run(
         self, wait=True, capture_error=False
@@ -81,6 +87,8 @@ def run(
         - wait for the MPI Master to create its SSH daemon
         - start its SSH daemon
         - monitor the MPI orted process and wait it to finish the MPI execution
+        - wait for the status file from master
+        - Exit once orted process is finished and status file is found.
         """
         logger.info("Starting MPI run as worker node.")
         if wait:
@@ -95,18 +103,41 @@ def run(
 
         if wait:
             logger.info("Waiting for MPI process to finish.")
-            _wait_orted_process_to_finish()
+            gone, alive = _wait_orted_process_to_finish()
+            logger.info(f"Reporting status for ORTEd process. gone: {gone} alive: {alive}")
             logger.info("Orted process exited")
             time.sleep(30)
+            logger.info(f"Begin looking for status file on {self._current_host}")
+            status_file = MPI_FINISHED_STATUS_FILE + "." + self._master_hostname
+            file_found = self._wait_for_status_file(status_file)
+            if file_found:
+                logger.info("MPI training job status file found. Exit gracefully")
+            else:
+                logger.info("Status file not found. Exiting...")
+            logger.info("End looking for status file")
         logger.info("MPI process finished.")
 
+    def _wait_for_status_file(self, status_file):
+        start_time = time.time()
+        file_found = os.path.exists(status_file)
+        while not file_found:
+            time.sleep(30)
+            curr_time = time.time()
+            # Check connectivity with master every 2 minutes
+            if int(curr_time - start_time) % 120 == 0:
+                logger.info("status file not found...")
+                if not _can_connect(self._master_hostname):
+                    return False
+            file_found = os.path.exists(status_file)
+        return True
+
     def _wait_master_to_start(self):  # type: () -> None
         while not _can_connect(self._master_hostname):
             time.sleep(1)
 
-    def _wait_master_to_finish(self):  # type: () -> None
-        while _can_connect(self._master_hostname):
-            time.sleep(30)
+    # def _wait_master_to_finish(self):  # type: () -> None
+    #     while _can_connect(self._master_hostname):
+    #         time.sleep(30)
 
 
 def _write_env_vars_to_file():  # type: () -> None
@@ -115,11 +146,17 @@ def _write_env_vars_to_file():  # type: () -> None
             f.write("{}={}\n".format(name, os.environ.get(name)))
 
 
+def _on_terminate(proc):
+    logger.info("Invoked on_terminate from psutil.wait_for_procs")
+    logger.info("process {} terminated with exit code {}".format(proc, proc.returncode))
+
+
 def _wait_orted_process_to_finish():  # type: () -> None
     orted = _orted_process()
     logger.info("Orted process found %s", orted)
     logger.info("Waiting for orted process %s", orted)
-    psutil.wait_procs(orted)
+    gone, alive = psutil.wait_procs(orted, callback=_on_terminate)
+    return gone, alive
 
 
 def _orted_process():  # pylint: disable=inconsistent-return-statements
@@ -150,6 +187,7 @@ def __init__(
         interval=1,
         timeout_in_seconds=60 * 60,
         num_processes=None,
+        instance_type="ml.p3.16xlarge",
     ):
         """Initialize a MasterRunner, which is responsible for preparing distributed
         training with MPI and synchronizing work among the Workers.
@@ -178,6 +216,8 @@ def __init__(
         self._custom_mpi_options = custom_mpi_options
         self._network_interface_name = network_interface_name
         self._interval = interval
+        self._env_vars = env_vars
+        self._instance_type = instance_type
         self.timeout_in_seconds = timeout_in_seconds
 
     def _setup(self):  # type: () -> None
@@ -265,6 +305,12 @@ def _create_command(self):
         ]
 
         command.extend(additional_options)
+        # EFA settings
+        if self._instance_type in ["ml.p3dn.24xlarge", "ml.p4d.24xlarge"]:
+            # Use EFA's RDMA functionality for one-sided and two-sided transfer
+            command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"])
+            # Use simple protocol to handle the out-of-order data delivery from EFA
+            command.extend(["-x", "NCCL_PROTO=simple"])
 
         for credential in [
             "AWS_ACCESS_KEY_ID",
@@ -280,6 +326,12 @@ def _create_command(self):
         command.extend(super(MasterRunner, self)._create_command())
         return command
 
+    def _get_instance_type(self):
+        """Get instance type"""
+        instance_type = self._env_vars.get("current_instance_type", None)
+        logger.info("instance type: %s" % instance_type)
+        return instance_type
+
     def _python_command(self):
         """Use mpi4py to force processes to abort if an uncaught exception occurs.
         https://docs.chainer.org/en/stable/chainermn/tutorial/tips_faqs.html#mpi-process-hangs-after-an-unhandled-python-exception
@@ -326,7 +378,26 @@ def run(self, wait=True, capture_error=False):
             capture_error=capture_error,
             cwd=environment.code_dir,
         )
-
+        logger.info("Begin writing status file from leader node to worker nodes (if any)")
+        # Write status file to all nodes
+        status_file = MPI_FINISHED_STATUS_FILE + "." + self._master_hostname
+        for host in self._hosts:
+            if host != self._master_hostname:
+                status = _write_status_file(host, status_file)
+                retry_count = 5 if not status else 0
+                while not status:
+                    if retry_count == 0:
+                        break
+                    logger.info(f"Retry creating status file onto {host}")
+                    retry_count -= 1
+                    time.sleep(1)
+                    status = _write_status_file(host, status_file)
+
+                if not status:
+                    logger.info(f"Failed to create status file onto {host}")
+
+        time.sleep(30)
+        logger.info("Finished writing status file from leader node to worker nodes (if any)")
         self._tear_down()
         return process_spawned
 
@@ -378,8 +449,28 @@ def _can_connect(host, port=22):  # type: (str, int) -> bool
         return True
     except Exception as e:  # pylint: disable=broad-except
         logger.info("Cannot connect to host %s", host)
+        logger.info(
+            "Connection failed with exception: \n %s. \
+            Can be ignored for worker when master completes and exits.",
+            str(e),
+        )
+        return False
 
-        logger.info("Connection failed with exception: \n %s", str(e))
+
+def _write_status_file(host, status_file):
+    try:
+        logger.info(f"Start writing mpirun finished status to {host}")
+        output = subprocess.run(
+            ["ssh", str(host), "touch", f"{status_file}"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+        logger.info(f"output from subprocess run {output}")
+        logger.info("Finished writing status file")
+        return True
+    except subprocess.CalledProcessError:
+        logger.info(f"Cannot connect to {host}")
         return False
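
Taken together, the mpi.py changes make the worker's exit condition explicit: the worker first waits for the orted processes to exit (psutil.wait_procs now gets an _on_terminate callback that logs each termination), then polls for a status file named /tmp/done.<master_hostname> that the master touches over SSH once mpirun has returned, falling back to a reachability check against the master roughly every two minutes. The following standalone sketch illustrates that worker-side wait; it mirrors the helper names and the 30-second poll interval from the diff, but the socket-based connectivity check and the function names are illustrative stand-ins, not the toolkit's own code.

import os
import socket
import time

MPI_FINISHED_STATUS_FILE = "/tmp/done"


def can_connect(host, port=22, timeout=5):
    # Coarse stand-in for mpi._can_connect(), which uses paramiko in the toolkit.
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


def wait_for_status_file(master_hostname, poll_seconds=30, connect_check_seconds=120):
    # Poll for the file the master touches over SSH after mpirun finishes.
    status_file = MPI_FINISHED_STATUS_FILE + "." + master_hostname
    start_time = time.time()
    while not os.path.exists(status_file):
        time.sleep(poll_seconds)
        # Roughly every two minutes, confirm the master is still reachable;
        # if it is gone and never wrote the file, stop waiting.
        if int(time.time() - start_time) % connect_check_seconds == 0:
            if not can_connect(master_hostname):
                return False
    return True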

src/sagemaker_training/process.py

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@
 
 from sagemaker_training import (
     _entry_point_type,
+    _MPI_ERRORS_,
     _PYTHON_ERRORS_,
     environment,
     errors,
@@ -115,7 +116,7 @@ async def watch(stream, proc_per_host, error_classes=None):
             if any(
                 str(err) in err_line
                 for err in (
-                    _PYTHON_ERRORS_ + error_classes
+                    _PYTHON_ERRORS_ + _MPI_ERRORS_ + error_classes
                     if isinstance(error_classes, list)
                     else [error_classes]
                 )
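
Together with the new _MPI_ERRORS_ list in __init__.py, this widens the log watcher so that lines carrying mpirun.real/ORTE failure markers are surfaced the same way Python error markers already are. A simplified, synchronous sketch of that membership test is below; the real watch() is an async coroutine, and the _PYTHON_ERRORS_ value here is an abbreviated stand-in rather than the toolkit's actual list.

_PYTHON_ERRORS_ = ["Traceback (most recent call last)"]  # abbreviated stand-in
_MPI_ERRORS_ = ["mpirun.real", "ORTE"]


def is_error_line(err_line, error_classes=None):
    # Mirror the check in process.watch(): flag the line if any known Python,
    # MPI, or framework-specific error marker appears in it.
    if error_classes is None:
        error_classes = []
    elif not isinstance(error_classes, list):
        error_classes = [error_classes]
    markers = _PYTHON_ERRORS_ + _MPI_ERRORS_ + [str(err) for err in error_classes]
    return any(marker in err_line for marker in markers)


print(is_error_line("ORTE has lost communication with a remote daemon."))  # True
print(is_error_line("Epoch 3/10 completed"))  # False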

src/sagemaker_training/runner.py

Lines changed: 8 additions & 1 deletion
@@ -90,6 +90,7 @@ def _get_by_runner_type(
     elif identifier is RunnerType.MPI and env.is_master:
         num_processes = _mpi_param_value(mpi_args, env, params.MPI_NUM_PROCESSES)
         custom_mpi_options = _mpi_param_value(mpi_args, env, params.MPI_CUSTOM_OPTIONS, "")
+        current_instance_type = env.current_instance_type
         return mpi.MasterRunner(
             user_entry_point,
             args,
@@ -100,10 +101,16 @@ def _get_by_runner_type(
             custom_mpi_options,
             env.network_interface_name,
             num_processes=num_processes,
+            instance_type=current_instance_type,
         )
     elif identifier is RunnerType.MPI:
         return mpi.WorkerRunner(
-            user_entry_point, args, env_vars, processes_per_host, env.master_hostname
+            user_entry_point,
+            args,
+            env_vars,
+            processes_per_host,
+            env.master_hostname,
+            env.current_host,
         )
     elif identifier is RunnerType.PyTorchXLA:
         return pytorch_xla.PyTorchXLARunner(
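
These runner.py edits are the wiring that feeds the new behavior: the master path forwards env.current_instance_type so MasterRunner can decide whether to emit the EFA flags, and the worker path passes env.current_host into the expanded WorkerRunner constructor. A hedged construction example follows, with placeholder hostnames and entry-point arguments; in the toolkit these values come from the training environment rather than literals.

from sagemaker_training import mpi

# Placeholder values for illustration only; the toolkit reads them from the
# training environment (env.master_hostname, env.current_host, and so on).
worker = mpi.WorkerRunner(
    user_entry_point="train.py",
    args=["--epochs", "10"],
    env_vars={},
    processes_per_host=8,
    master_hostname="algo-1",
    current_host="algo-2",
)
# With this commit, worker.run() waits for the orted processes to exit and then
# polls for /tmp/done.algo-1 before returning.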

src/sagemaker_training/smdataparallel.py

Lines changed: 45 additions & 1 deletion
@@ -42,6 +42,8 @@
 )
 exception_classes = [errors.ExecuteUserScriptError]
 
+MPI_FINISHED_STATUS_FILE = "/tmp/done"
+
 
 class SMDataParallelRunner(process.ProcessRunner):
     """Prepare SMDataParallel-based distributed training.
@@ -185,9 +187,12 @@ def _get_mpirun_command(
         mpirun_command.extend(additional_options)
 
         instance_type = self._get_instance_type()
-        # Use EFA's RDMA functionality for one-sided and two-sided transfer
+        # EFA settings
         if instance_type in ["ml.p3dn.24xlarge", "ml.p4d.24xlarge"]:
+            # Use EFA's RDMA functionality for one-sided and two-sided transfer
             mpirun_command.extend(["-x", "FI_EFA_USE_DEVICE_RDMA=1"])
+            # Use simple protocol to handle the out-of-order data delivery from EFA
+            mpirun_command.extend(["-x", "NCCL_PROTO=simple"])
 
         if smdataparallel_server_addr and smdataparallel_server_port:
             # in case of multi-node [distributed] training, smdataparallel_server_addr,
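
The EFA-related edit above is the same one made in MasterRunner._create_command() in mpi.py: only on the two EFA-capable instance types does the generated mpirun command pick up the extra -x exports. A small sketch of that gate, written as a standalone helper for illustration (it is not a function in the toolkit):

def efa_mpirun_options(instance_type):
    # Mirror the instance-type check added to the mpirun command builders.
    options = []
    if instance_type in ["ml.p3dn.24xlarge", "ml.p4d.24xlarge"]:
        # Use EFA's RDMA functionality for one-sided and two-sided transfer
        options += ["-x", "FI_EFA_USE_DEVICE_RDMA=1"]
        # Use simple protocol to handle the out-of-order data delivery from EFA
        options += ["-x", "NCCL_PROTO=simple"]
    return options


print(efa_mpirun_options("ml.p4d.24xlarge"))  # ['-x', 'FI_EFA_USE_DEVICE_RDMA=1', '-x', 'NCCL_PROTO=simple']
print(efa_mpirun_options("ml.p3.16xlarge"))   # []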
@@ -300,6 +305,28 @@ def run(self, wait=True, capture_error=False):
             capture_error=capture_error,
             cwd=environment.code_dir,
         )
+
+        logger.info("Begin writing status file from leader node to worker nodes")
+        # Write status file to all nodes
+        status_file = MPI_FINISHED_STATUS_FILE + "." + self._master_hostname
+        for host in self._hosts:
+            if host != self._master_hostname:
+                status = _write_status_file(host, status_file)
+                retry_count = 5 if not status else 0
+                while not status:
+                    if retry_count == 0:
+                        break
+                    logger.info(f"Retry creating status file onto {host}")
+                    retry_count -= 1
+                    time.sleep(1)
+                    status = _write_status_file(host, status_file)
+
+                if not status:
+                    logger.info(f"Failed to create status file onto {host}")
+
+        time.sleep(30)
+        logger.info("Finished writing status file from leader node to worker nodes")
+
         self._tear_down()
         return process_spawned
 
@@ -357,6 +384,23 @@ def _can_connect(host, port=22):
     logger.info("Connection closed")
 
 
+def _write_status_file(host, status_file):
+    try:
+        logger.info(f"Start writing mpirun finished status to {host}")
+        output = subprocess.run(
+            ["ssh", str(host), "touch", f"{status_file}"],
+            capture_output=True,
+            text=True,
+            check=True,
+        )
+        logger.info(f"output from subprocess run {output}")
+        logger.info("Finished writing status file")
+        return True
+    except subprocess.CalledProcessError:
+        logger.info(f"Cannot connect to {host}")
+        return False
+
+
 def _parse_custom_mpi_options(custom_mpi_options):
     """Parse custom MPI options provided by user. Known options default value will be overridden
     and unknown options will be identified separately."""
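
On the leader side, both MasterRunner.run() and SMDataParallelRunner.run() now finish the same way: for every other host, run "ssh <host> touch /tmp/done.<leader>", retry up to five times a second apart if the SSH call fails, then sleep 30 seconds before tearing down. A condensed, self-contained sketch of that notification loop is below; the function names and the hostnames in the trailing comment are illustrative, while the toolkit's own version lives in _write_status_file() and the run() methods shown above.

import subprocess
import time

MPI_FINISHED_STATUS_FILE = "/tmp/done"


def write_status_file(host, status_file):
    # Touch the status file on a worker over SSH, as _write_status_file() does.
    try:
        subprocess.run(
            ["ssh", str(host), "touch", status_file],
            capture_output=True,
            text=True,
            check=True,
        )
        return True
    except subprocess.CalledProcessError:
        return False


def notify_workers(hosts, master_hostname, retries=5):
    status_file = MPI_FINISHED_STATUS_FILE + "." + master_hostname
    for host in hosts:
        if host == master_hostname:
            continue
        status = write_status_file(host, status_file)
        attempts_left = retries
        while not status and attempts_left > 0:
            attempts_left -= 1
            time.sleep(1)
            status = write_status_file(host, status_file)
        if not status:
            print(f"Failed to create status file onto {host}")


# Example: notify_workers(["algo-1", "algo-2"], master_hostname="algo-1")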