Skip to content

Commit ad50e5a

Browse files
author
Satish Pasumarthi
committed
fix: provide option to override capture error
1 parent ad1724d commit ad50e5a

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/sagemaker_pytorch_container/training.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,12 +81,17 @@ def train(training_environment):
8181
runner_type = runner.PyTorchXLARunnerType
8282
logger.info('Invoking PT-XLA Runner')
8383
logger.info('Invoking user training script.')
84+
85+
# get capture_error from framework parameters
86+
capture_error = training_environment.additional_framework_parameters.get("sagemaker_toolkit_capture_error", True)
87+
logger.info(f'capture_error is {capture_error}. Default is True')
88+
8489
try:
8590
entry_point.run(uri=training_environment.module_dir,
8691
user_entry_point=training_environment.user_entry_point,
8792
args=training_environment.to_cmd_args(),
8893
env_vars=training_environment.to_env_vars(),
89-
capture_error=True,
94+
capture_error=capture_error,
9095
runner_type=runner_type)
9196
except errors.ExecuteUserScriptError as err:
9297
message = str(err)

test/unit/test_train.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,20 @@ def test_train(run_entry_point, training_env):
7474
runner_type=runner.ProcessRunnerType)
7575

7676

77+
@patch('sagemaker_training.entry_point.run')
@patch('socket.gethostbyname', MagicMock())
def test_train_no_capture_error(run_entry_point, training_env):
    """Check that ``train`` forwards a ``capture_error`` override.

    The test stores ``False`` under the ``sagemaker_toolkit_capture_error``
    key of the environment's additional framework parameters, invokes
    ``train``, and then asserts that the patched ``entry_point.run`` was
    last called with ``capture_error=False`` alongside the usual arguments.
    """
    # Arrange: request that error capturing be switched off.
    framework_params = training_env.additional_framework_parameters
    framework_params["sagemaker_toolkit_capture_error"] = False

    # Act.
    train(training_env)

    # Assert: every keyword handed to the entry point, with the override applied.
    expected_kwargs = dict(
        uri=training_env.module_dir,
        user_entry_point=training_env.user_entry_point,
        args=training_env.to_cmd_args(),
        env_vars=training_env.to_env_vars(),
        capture_error=False,
        runner_type=runner.ProcessRunnerType,
    )
    run_entry_point.assert_called_with(**expected_kwargs)
89+
90+
7791
@patch("sagemaker_training.entry_point.run")
7892
@patch('socket.gethostbyname', MagicMock())
7993
def test_train_smdataparallel(run_module, training_env):

0 commit comments

Comments
 (0)