@@ -43,11 +43,11 @@ def __init__(self, *args, **kw):
4343
4444
4545@patch ("sagemaker_training.mpi._write_env_vars_to_file" )
46+ @patch ("sagemaker_training.mpi.logger" )
4647@patch ("os.path.exists" )
4748@patch ("time.sleep" )
4849@patch ("paramiko.SSHClient" , new_callable = MockSSHClient )
4950@patch ("sagemaker_training.mpi._on_terminate" )
50- @patch ("sagemaker_training.mpi.WorkerRunner._wait_for_status_file" )
5151@patch ("psutil.wait_procs" )
5252@patch ("psutil.process_iter" )
5353@patch ("paramiko.AutoAddPolicy" )
@@ -57,18 +57,18 @@ def test_mpi_worker_run(
5757 policy ,
5858 process_iter ,
5959 wait_procs ,
60- wait_for_status_file ,
6160 on_terminate ,
6261 ssh_client ,
6362 sleep ,
6463 path_exists ,
64+ logger ,
6565 write_env_vars ,
6666):
6767
6868 process = MagicMock (info = {"name" : "orted" })
6969 process_iter .side_effect = lambda attrs : [process ]
7070 wait_procs .return_value = (process , None )
71- wait_for_status_file . return_value = True
71+ path_exists . side_effect = [ True , False , True ]
7272 worker = mpi .WorkerRunner (
7373 user_entry_point = "train.sh" ,
7474 args = ["-v" , "--lr" , "35" ],
@@ -79,7 +79,6 @@ def test_mpi_worker_run(
7979 )
8080
8181 worker .run ()
82-
8382 write_env_vars .assert_called_once ()
8483
8584 ssh_client ().load_system_host_keys .assert_called ()
@@ -90,8 +89,8 @@ def test_mpi_worker_run(
9089
9190 popen .assert_called_with (["/usr/sbin/sshd" , "-D" ])
9291 path_exists .call_count == 2
93- wait_for_status_file . assert_called_once ()
94- wait_for_status_file . assert_called_with ("/tmp/done.algo-1 " )
92+ path_exists . return_value = True
93+ logger . info . assert_called_with ("MPI process finished. " )
9594
9695
9796@patch ("sagemaker_training.mpi._write_env_vars_to_file" )
0 commit comments