Skip to content

Commit a1255b9

Browse files
committed
fix functional test
1 parent f178241 commit a1255b9

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

src/sagemaker_training/process.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -256,23 +256,23 @@ def check_error(cmd, error_classes, processes_per_host, cwd=None, capture_error=
256256
if return_code == 137:
257257
extra_info = "OutOfMemory: Process killed by SIGKILL (signal 9)"
258258

259-
# default to user script error
260-
error_class = errors.ExecuteUserScriptError
261-
# use first found target error class if available
262-
if stderr:
259+
# throw internal error classes first
260+
internal_errors = [err for err in dir(errors) if isclass(getattr(errors, err))]
261+
error_class = next(
262+
(name for name in error_classes if name in internal_errors), "ExecuteUserScriptError"
263+
)
264+
error_class = getattr(errors, error_class)
265+
266+
# only replace ExecuteUserScriptError with custom library errors
267+
if stderr and error_class == errors.ExecuteUserScriptError:
263268
# find the first target error in stderr
264269
error_name = next((str(name) for name in error_classes if str(name) in stderr), False)
265270
if error_name:
266-
# if error name is one of toolkit errors
267-
if str(error_name) in [x for x in dir(errors) if isclass(getattr(errors, x))]:
268-
error_class = getattr(errors, error_name)
269-
else:
270-
# if error is one of custom errors
271-
error_class = type(
272-
error_name,
273-
(errors._CalledProcessError,), # pylint: disable=protected-access
274-
{},
275-
)
271+
error_class = type(
272+
error_name,
273+
(errors._CalledProcessError,), # pylint: disable=protected-access
274+
{},
275+
)
276276

277277
raise error_class(
278278
cmd=" ".join(cmd) if isinstance(cmd, list) else cmd,

0 commit comments

Comments
 (0)