Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/sagemaker_training/smdataparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,17 @@ def _wait_for_workers(self): # type: () -> None
logger.info("Waiting for MPI workers to establish their SSH connections")

workers = [host for host in self._hosts if host != self._master_hostname]
with timeout.timeout(seconds=self.timeout_in_seconds):
for host in workers:
while not _can_connect(host):
time.sleep(self._interval)
logger.info("Worker %s available for communication", host)
try:
with timeout.timeout(seconds=self.timeout_in_seconds):
for host in workers:
while not _can_connect(host):
time.sleep(self._interval)
logger.info("Worker %s available for communication", host)
except timeout.TimeoutError:
logger.exception(
"Connection between the hosts couldn't established. Aborting the training."
)
raise

def _get_mpirun_command(
self,
Expand Down Expand Up @@ -321,8 +327,7 @@ def _can_connect(host, port=22):
logger.info("Can connect to host %s at port %s", host, port)
return True
except Exception: # pylint: disable=broad-except
logger.info("Cannot connect to host %s at port %s", host, port)
logger.exception("Connection failed")
logger.info("Cannot connect to host %s at port %s. Retrying...", host, port)
return False
finally:
client.close()
Expand Down