diff --git a/src/sagemaker_xgboost_container/algorithm_mode/train.py b/src/sagemaker_xgboost_container/algorithm_mode/train.py index 24641614..37b34365 100644 --- a/src/sagemaker_xgboost_container/algorithm_mode/train.py +++ b/src/sagemaker_xgboost_container/algorithm_mode/train.py @@ -198,6 +198,9 @@ def sagemaker_train( train_dmatrix, val_dmatrix, train_val_dmatrix = get_validated_dmatrices( train_path, val_path, file_type, csv_weights, is_pipe, combine_train_val ) + + missing_validation_data = validation_channel and not val_dmatrix + train_args = dict( train_cfg=validated_train_config, train_dmatrix=train_dmatrix, @@ -210,22 +213,37 @@ def sagemaker_train( # Wait for hosts to find each other logging.info(f"Distributed node training with {num_hosts} hosts: {sm_hosts}") distributed.wait_hostname_resolution(sm_hosts) + include_in_training = True if not train_dmatrix: logging.warning( - "Host {} does not have data. Will broadcast to cluster and will not be used in distributed" - " training.".format(sm_current_host) + f"Host {sm_current_host} does not have training data. Will broadcast to " + f"cluster and this host {sm_current_host} will not be used in distributed training. " + f"Please divide the training data across instances properly. See https://docs.aws.amazon.com/" + f"sagemaker/latest/dg/xgboost.html#Instance-XGBoost-distributed-training-divide-data. " ) + include_in_training = False + if missing_validation_data: + logging.warning( + f"Host {sm_current_host} does not have validation data " + f"in the validation channel : {validation_channel}. " + f"Will broadcast to cluster and this host {sm_current_host} will not be used " + f"in distributed training. Please divide the validation data across instances properly. " + f"See https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html" + f"#Instance-XGBoost-distributed-training-divide-data. " + ) + include_in_training = False + distributed.rabit_run( exec_fun=train_job, args=train_args, - include_in_training=(train_dmatrix is not None), + include_in_training=include_in_training, hosts=sm_hosts, current_host=sm_current_host, update_rabit_args=True, ) elif num_hosts == 1: if train_dmatrix: - if validation_channel and not val_dmatrix: + if missing_validation_data: raise exc.UserError(f"No data in validation channel path {val_path}") logging.info("Single node training.") train_args.update({"is_master": True})