26 changes: 22 additions & 4 deletions src/sagemaker_xgboost_container/algorithm_mode/train.py
@@ -198,6 +198,9 @@ def sagemaker_train(
     train_dmatrix, val_dmatrix, train_val_dmatrix = get_validated_dmatrices(
         train_path, val_path, file_type, csv_weights, is_pipe, combine_train_val
     )
+
+    missing_validation_data = validation_channel and not val_dmatrix
+
     train_args = dict(
         train_cfg=validated_train_config,
         train_dmatrix=train_dmatrix,
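
Note on the new flag: validation_channel and not val_dmatrix short-circuits, so missing_validation_data ends up holding either the falsy validation_channel value (when no validation channel was configured) or a boolean. Both uses below only test its truthiness, so either form behaves the same. A small sketch with hypothetical values:

# Hypothetical values illustrating the truthiness of missing_validation_data.
validation_channel = "validation"   # a validation channel is configured
val_dmatrix = None                  # but this host loaded no validation shard
missing_validation_data = validation_channel and not val_dmatrix
print(bool(missing_validation_data))  # True -> this host lacks validation data

validation_channel = None           # no validation channel configured
missing_validation_data = validation_channel and not val_dmatrix
print(bool(missing_validation_data))  # False -> nothing to warn about
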
@@ -210,22 +210,37 @@ def sagemaker_train(
         # Wait for hosts to find each other
         logging.info(f"Distributed node training with {num_hosts} hosts: {sm_hosts}")
         distributed.wait_hostname_resolution(sm_hosts)
+        include_in_training = True
         if not train_dmatrix:
             logging.warning(
"Host {} does not have data. Will broadcast to cluster and will not be used in distributed"
" training.".format(sm_current_host)
f"Host {sm_current_host} does not have training data. Will broadcast to "
f"cluster and this host {sm_current_host} will not be used in distributed training. "
f"Please divide the training data across instances properly. See https://docs.aws.amazon.com/"
f"sagemaker/latest/dg/xgboost.html#Instance-XGBoost-distributed-training-divide-data. "
)
+            include_in_training = False
+        if missing_validation_data:
+            logging.warning(
+                f"Host {sm_current_host} does not have validation data "
+                f"in the validation channel: {validation_channel}. "
+                f"It will broadcast to the cluster but will not be used "
+                f"in distributed training. Please ensure the validation data is divided across all instances. "
+                f"See https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html"
+                f"#Instance-XGBoost-distributed-training-divide-data"
+            )
+            include_in_training = False
+
         distributed.rabit_run(
             exec_fun=train_job,
             args=train_args,
-            include_in_training=(train_dmatrix is not None),
+            include_in_training=include_in_training,
             hosts=sm_hosts,
             current_host=sm_current_host,
             update_rabit_args=True,
         )
     elif num_hosts == 1:
         if train_dmatrix:
-            if validation_channel and not val_dmatrix:
+            if missing_validation_data:
                 raise exc.UserError(f"No data in validation channel path {val_path}")
             logging.info("Single node training.")
             train_args.update({"is_master": True})
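
For readers skimming the diff, a minimal, self-contained sketch of the host-inclusion rule this change implements. The helper name should_include_host and the stand-in objects are hypothetical; train_dmatrix, val_dmatrix, validation_channel and include_in_training come from the code above, where the result is passed to distributed.rabit_run.

# Hypothetical helper mirroring the gating added above: a host joins
# distributed training only if it has training data and, when a validation
# channel is configured, validation data as well.
def should_include_host(train_dmatrix, val_dmatrix, validation_channel):
    if not train_dmatrix:
        return False  # no training data on this host
    if validation_channel and not val_dmatrix:
        return False  # validation channel configured, but no shard on this host
    return True

# Stand-in objects in place of real xgboost.DMatrix instances.
data = object()
assert should_include_host(data, data, "validation") is True
assert should_include_host(data, None, "validation") is False  # excluded
assert should_include_host(data, None, None) is True           # no validation channel
assert should_include_host(None, data, "validation") is False  # no training data

In the single-host branch the same missing-validation check raises exc.UserError instead of excluding the host, since there is no other instance to fall back on.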