Skip to content

Commit 0cd3b1e

Browse files
committed
add test and make error class optional
1 parent bddfaa2 commit 0cd3b1e

File tree

3 files changed

+73
-14
lines changed

3 files changed

+73
-14
lines changed

src/sagemaker_training/process.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
_DEFAULT_BUF_SIZE = 1024 * 64
3737

3838

39-
async def watch(stream, error_classes, proc_per_host):
39+
async def watch(stream, proc_per_host, error_classes=None):
4040
"""Process the stdout and stderr streams on the fly.
4141
Decode the output lines
4242
Remove new line characters (if any)
@@ -45,12 +45,16 @@ async def watch(stream, error_classes, proc_per_host):
4545
4646
Args:
4747
stream: asyncio subprocess PIPE
48-
error_classes (list): List of exception classes to watch and raise
4948
proc_per_host (int): Number of processes per each host
49+
error_classes (list): List of exception classes to watch and raise
5050
5151
Returns:
5252
output: Filtered stderr
5353
"""
54+
if not error_classes:
55+
error_classes = []
56+
if not isinstance(error_classes, list):
57+
error_classes = [error_classes]
5458
output = []
5559
buf_size = _DEFAULT_BUF_SIZE
5660
start = False
@@ -105,17 +109,17 @@ async def watch(stream, error_classes, proc_per_host):
105109
return " ".join(output)
106110

107111

108-
async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, **kwargs):
112+
async def run_async(cmd, processes_per_host, env, cwd, stderr, error_classes=None, **kwargs):
109113
"""Method responsible for launching asyncio subprocess shell
110-
Use asyncio gather to collect processed stdout and stderr
114+
Use asyncio gather to collect processed stdout and stderr
111115
112116
Args:
113117
cmd (list): The command to be run
114-
error_classes (list): List of exception classes to watch and raise
115118
processes_per_host (int): Number of processes per host
116119
env: os.environ
117120
cwd (str): The location from which to run the command (default: None).
118121
If None, this defaults to the ``code_dir`` of the environment.
122+
error_classes (list): List of exception classes to watch and raise
119123
**kwargs: Extra arguments that are passed to the asyncio create subprocess constructor.
120124
121125
Returns:
@@ -126,14 +130,18 @@ async def run_async(cmd, error_classes, processes_per_host, env, cwd, stderr, **
126130
Raises:
127131
ExecuteUserScriptError: If there is an exception raised when creating the process.
128132
"""
133+
if not error_classes:
134+
error_classes = []
135+
if not isinstance(error_classes, list):
136+
error_classes = [error_classes]
129137
cmd = " ".join(cmd)
130138
proc = await asyncio.create_subprocess_shell(
131139
cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
132140
)
133141

134142
output = await asyncio.gather(
135-
watch(proc.stdout, error_classes, processes_per_host),
136-
watch(proc.stderr, error_classes, processes_per_host),
143+
watch(proc.stdout, processes_per_host, error_classes=error_classes),
144+
watch(proc.stderr, processes_per_host, error_classes=error_classes),
137145
)
138146
return_code = proc.returncode
139147
return return_code, output, proc
@@ -166,16 +174,20 @@ def create(
166174
Raises:
167175
ExecuteUserScriptError: If there is an exception raised when creating the process.
168176
"""
177+
if not error_classes:
178+
error_classes = []
179+
if not isinstance(error_classes, list):
180+
error_classes = [error_classes]
169181
try:
170182
stderr = PIPE if capture_error else None
171183
rc, output, proc = asyncio.run(
172184
run_async(
173185
cmd,
174-
error_classes,
175186
processes_per_host,
176187
env=env or os.environ,
177188
cwd=cwd or environment.code_dir,
178189
stderr=stderr,
190+
error_classes=error_classes,
179191
**kwargs,
180192
)
181193
)
@@ -206,7 +218,10 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=
206218
Raises:
207219
ExecuteUserScriptError: If there is an exception raised when creating the process.
208220
"""
209-
221+
if not error_classes:
222+
error_classes = []
223+
if not isinstance(error_classes, list):
224+
error_classes = [error_classes]
210225
if capture_error:
211226
return_code, output, process = create(
212227
cmd,
@@ -239,10 +254,11 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=
239254
# default error class will be user script error
240255
error_class = errors.ExecuteUserScriptError
241256
# use first found target error class if available
242-
for error_name in error_classes:
243-
if error_name in stderr:
244-
error_class = type(error_name, (errors.ExecuteUserScriptError,), {})
245-
break
257+
if stderr:
258+
for error_name in error_classes:
259+
if str(error_name) in stderr:
260+
error_class = type(error_name, (errors.ExecuteUserScriptError,), {})
261+
break
246262

247263
raise error_class(
248264
cmd=" ".join(cmd) if isinstance(cmd, list) else cmd,

test/unit/test_environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def test_env_mapping_properties(training_env):
212212
"output_intermediate_dir",
213213
"is_master",
214214
"master_hostname",
215+
"is_modelparallel_enabled",
215216
}
216217

217218

test/unit/test_process.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ async def test_watch(event_loop, capsys):
105105
expected_stream += (
106106
"[1,mpirank:0,algo-1]<stderr>:FileNotFoundError: [Errno 2] No such file or directory\n"
107107
)
108-
expected_errmsg = ":FileNotFoundError: [Errno 2] No such file or directory\n"
108+
expected_errmsg = "FileNotFoundError: [Errno 2] No such file or directory\n"
109109

110110
stream = asyncio.StreamReader()
111111
stream.feed_data(b"[1,10]<stdout>:This is stdout\n")
@@ -119,6 +119,48 @@ async def test_watch(event_loop, capsys):
119119
assert output == expected_errmsg
120120

121121

122+
@pytest.mark.asyncio
123+
async def test_watch_custom_error(event_loop, capsys):
124+
num_processes_per_host = 8
125+
expected_stream = "[1,mpirank:10,algo-2]<stdout>:This is stdout\n"
126+
expected_stream += "[1,mpirank:10,algo-2]<stderr>:This is stderr\n"
127+
expected_stream += "[1,mpirank:0,algo-1]<stderr>:SMDDPNCCLError: unhandled cuda error\n"
128+
expected_errmsg = "SMDDPNCCLError: unhandled cuda error\n"
129+
130+
stream = asyncio.StreamReader()
131+
stream.feed_data(b"[1,10]<stdout>:This is stdout\n")
132+
stream.feed_data(b"[1,10]<stderr>:This is stderr\n")
133+
stream.feed_data(b"[1,0]<stderr>:SMDDPNCCLError: unhandled cuda error")
134+
stream.feed_eof()
135+
136+
error_classes = ["SMDDPNCCLError"]
137+
output = await process.watch(stream, num_processes_per_host, error_classes=error_classes)
138+
captured_stream = capsys.readouterr()
139+
assert captured_stream.out == expected_stream
140+
assert output == expected_errmsg
141+
142+
# test errors piped in stdout
143+
stream = asyncio.StreamReader()
144+
stream.feed_data(b"[1,10]<stdout>:This is stdout\n")
145+
stream.feed_data(b"[1,10]<stderr>:This is stderr\n")
146+
stream.feed_data(b"[1,0]<stdout>:SMDDPNCCLError: unhandled cuda error")
147+
stream.feed_eof()
148+
149+
error_classes = ["SMDDPNCCLError"]
150+
output = await process.watch(stream, num_processes_per_host, error_classes=error_classes)
151+
assert output == expected_errmsg
152+
153+
# test single item
154+
stream = asyncio.StreamReader()
155+
stream.feed_data(b"[1,10]<stdout>:This is stdout\n")
156+
stream.feed_data(b"[1,10]<stderr>:This is stderr\n")
157+
stream.feed_data(b"[1,0]<stdout>:SMDDPNCCLError: unhandled cuda error")
158+
stream.feed_eof()
159+
error_classes = "SMDDPNCCLError"
160+
output = await process.watch(stream, num_processes_per_host, error_classes=error_classes)
161+
assert output == expected_errmsg
162+
163+
122164
@patch("asyncio.run", AsyncMock(side_effect=ValueError("FAIL")))
123165
def test_create_error():
124166
with pytest.raises(errors.ExecuteUserScriptError):

0 commit comments

Comments
 (0)