2828logger = logging_config .get_logger ()
2929logging .getLogger ("paramiko" ).setLevel (logging .INFO )
3030
31+ MPI_FINISHED_STATUS_FILE = "/tmp/done"
32+
3133
3234def get_modelparallel_exception_classes ():
3335 """Set exception classes"""
@@ -60,7 +62,9 @@ class WorkerRunner(process.ProcessRunner):
6062 master execution to finish.
6163 """
6264
63- def __init__ (self , user_entry_point , args , env_vars , processes_per_host , master_hostname ):
65+ def __init__ (
66+ self , user_entry_point , args , env_vars , processes_per_host , master_hostname , current_host
67+ ):
6468 """Initialize a WorkerRunner, which is responsible for preparing distributed
6569 training with MPI and waiting for MPI master execution to finish.
6670
@@ -69,9 +73,11 @@ def __init__(self, user_entry_point, args, env_vars, processes_per_host, master_
6973 args ([str]): A list of arguments to include when executing the entry point.
7074 env_vars (dict(str,str)): A dictionary of environment variables.
7175 master_hostname (str): The master hostname.
76+ current_hostname (str): Current hostname.
7277 """
7378 super (WorkerRunner , self ).__init__ (user_entry_point , args , env_vars , processes_per_host )
7479 self ._master_hostname = str (master_hostname )
80+ self ._current_host = str (current_host )
7581
7682 def run (
7783 self , wait = True , capture_error = False
@@ -81,6 +87,8 @@ def run(
8187 - wait for the MPI Master to create its SSH daemon
8288 - start its SSH daemon
8389 - monitor the MPI orted process and wait it to finish the MPI execution
90+ - wait for the status file from master
91+ - Exit once orted process is finished and status file is found.
8492 """
8593 logger .info ("Starting MPI run as worker node." )
8694 if wait :
@@ -95,18 +103,41 @@ def run(
95103
96104 if wait :
97105 logger .info ("Waiting for MPI process to finish." )
98- _wait_orted_process_to_finish ()
106+ gone , alive = _wait_orted_process_to_finish ()
107+ logger .info (f"Reporting status for ORTEd process. gone: { gone } alive: { alive } " )
99108 logger .info ("Orted process exited" )
100109 time .sleep (30 )
110+ logger .info (f"Begin looking for status file on { self ._current_host } " )
111+ status_file = MPI_FINISHED_STATUS_FILE + "." + self ._master_hostname
112+ file_found = self ._wait_for_status_file (status_file )
113+ if file_found :
114+ logger .info ("MPI training job status file found. Exit gracefully" )
115+ else :
116+ logger .info ("Status file not found. Exiting..." )
117+ logger .info ("End looking for status file" )
101118 logger .info ("MPI process finished." )
102119
120+ def _wait_for_status_file (self , status_file ):
121+ start_time = time .time ()
122+ file_found = os .path .exists (status_file )
123+ while not file_found :
124+ time .sleep (30 )
125+ curr_time = time .time ()
126+ # Check connectivity with master every 2 minutes
127+ if int (curr_time - start_time ) % 120 == 0 :
128+ logger .info ("status file not found..." )
129+ if not _can_connect (self ._master_hostname ):
130+ return False
131+ file_found = os .path .exists (status_file )
132+ return True
133+
103134 def _wait_master_to_start (self ): # type: () -> None
104135 while not _can_connect (self ._master_hostname ):
105136 time .sleep (1 )
106137
107- def _wait_master_to_finish (self ): # type: () -> None
108- while _can_connect (self ._master_hostname ):
109- time .sleep (30 )
138+ # def _wait_master_to_finish(self): # type: () -> None
139+ # while _can_connect(self._master_hostname):
140+ # time.sleep(30)
110141
111142
112143def _write_env_vars_to_file (): # type: () -> None
@@ -115,11 +146,17 @@ def _write_env_vars_to_file(): # type: () -> None
115146 f .write ("{}={}\n " .format (name , os .environ .get (name )))
116147
117148
149+ def _on_terminate (proc ):
150+ logger .info ("Invoked on_terminate from psutil.wait_for_procs" )
151+ logger .info ("process {} terminated with exit code {}" .format (proc , proc .returncode ))
152+
153+
118154def _wait_orted_process_to_finish (): # type: () -> None
119155 orted = _orted_process ()
120156 logger .info ("Orted process found %s" , orted )
121157 logger .info ("Waiting for orted process %s" , orted )
122- psutil .wait_procs (orted )
158+ gone , alive = psutil .wait_procs (orted , callback = _on_terminate )
159+ return gone , alive
123160
124161
125162def _orted_process (): # pylint: disable=inconsistent-return-statements
@@ -150,6 +187,7 @@ def __init__(
150187 interval = 1 ,
151188 timeout_in_seconds = 60 * 60 ,
152189 num_processes = None ,
190+ instance_type = "ml.p3.16xlarge" ,
153191 ):
154192 """Initialize a MasterRunner, which is responsible for preparing distributed
155193 training with MPI and synchronizing work among the Workers.
@@ -178,6 +216,8 @@ def __init__(
178216 self ._custom_mpi_options = custom_mpi_options
179217 self ._network_interface_name = network_interface_name
180218 self ._interval = interval
219+ self ._env_vars = env_vars
220+ self ._instance_type = instance_type
181221 self .timeout_in_seconds = timeout_in_seconds
182222
183223 def _setup (self ): # type: () -> None
@@ -265,6 +305,12 @@ def _create_command(self):
265305 ]
266306
267307 command .extend (additional_options )
308+ # EFA settings
309+ if self ._instance_type in ["ml.p3dn.24xlarge" , "ml.p4d.24xlarge" ]:
310+ # Use EFA's RDMA functionality for one-sided and two-sided transfer
311+ command .extend (["-x" , "FI_EFA_USE_DEVICE_RDMA=1" ])
312+ # Use simple protocol to handle the out-of-order data delivery from EFA
313+ command .extend (["-x" , "NCCL_PROTO=simple" ])
268314
269315 for credential in [
270316 "AWS_ACCESS_KEY_ID" ,
@@ -280,6 +326,12 @@ def _create_command(self):
280326 command .extend (super (MasterRunner , self )._create_command ())
281327 return command
282328
329+ def _get_instance_type (self ):
330+ """Get instance type"""
331+ instance_type = self ._env_vars .get ("current_instance_type" , None )
332+ logger .info ("instance type: %s" % instance_type )
333+ return instance_type
334+
283335 def _python_command (self ):
284336 """Use mpi4py to force processes to abort if an uncaught exception occurs.
285337 https://docs.chainer.org/en/stable/chainermn/tutorial/tips_faqs.html#mpi-process-hangs-after-an-unhandled-python-exception
@@ -326,7 +378,26 @@ def run(self, wait=True, capture_error=False):
326378 capture_error = capture_error ,
327379 cwd = environment .code_dir ,
328380 )
329-
381+ logger .info ("Begin writing status file from leader node to worker nodes (if any)" )
382+ # Write status file to all nodes
383+ status_file = MPI_FINISHED_STATUS_FILE + "." + self ._master_hostname
384+ for host in self ._hosts :
385+ if host != self ._master_hostname :
386+ status = _write_status_file (host , status_file )
387+ retry_count = 5 if not status else 0
388+ while not status :
389+ if retry_count == 0 :
390+ break
391+ logger .info (f"Retry creating status file onto { host } " )
392+ retry_count -= 1
393+ time .sleep (1 )
394+ status = _write_status_file (host , status_file )
395+
396+ if not status :
397+ logger .info (f"Failed to create status file onto { host } " )
398+
399+ time .sleep (30 )
400+ logger .info ("Finished writing status file from leader node to worker nodes (if any)" )
330401 self ._tear_down ()
331402 return process_spawned
332403
@@ -378,8 +449,28 @@ def _can_connect(host, port=22): # type: (str, int) -> bool
378449 return True
379450 except Exception as e : # pylint: disable=broad-except
380451 logger .info ("Cannot connect to host %s" , host )
452+ logger .info (
453+ "Connection failed with exception: \n %s. \
454+ Can be ignored for worker when master completes and exits." ,
455+ str (e ),
456+ )
457+ return False
381458
382- logger .info ("Connection failed with exception: \n %s" , str (e ))
459+
460+ def _write_status_file (host , status_file ):
461+ try :
462+ logger .info (f"Start writing mpirun finished status to { host } " )
463+ output = subprocess .run (
464+ ["ssh" , str (host ), "touch" , f"{ status_file } " ],
465+ capture_output = True ,
466+ text = True ,
467+ check = True ,
468+ )
469+ logger .info (f"output from subprocess run { output } " )
470+ logger .info ("Finished writing status file" )
471+ return True
472+ except subprocess .CalledProcessError :
473+ logger .info (f"Cannot connect to { host } " )
383474 return False
384475
385476
0 commit comments