5 changes: 4 additions & 1 deletion README.md
@@ -134,7 +134,7 @@ hyperpod connect-cluster --cluster-name <cluster-name> [--region <region>] [--na
This command submits a new training job to the connected SageMaker HyperPod cluster.

```
hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <kubeflow/PyTorchJob>] [--image <image>] [--command <command>] [--entry-script <script>] [--script-args <arg1 arg2>] [--environment <key=value>] [--pull-policy <Always|IfNotPresent|Never>] [--instance-type <instance-type>] [--node-count <count>] [--tasks-per-node <count>] [--label-selector <key=value>] [--deep-health-check-passed-nodes-only] [--scheduler-type <Kueue SageMaker None>] [--queue-name <queue-name>] [--priority <priority>] [--auto-resume] [--max-retry <count>] [--restart-policy <Always|OnFailure|Never|ExitCode>] [--volumes <volume1,volume2>] [--persistent-volume-claims <claim1:/mount/path,claim2:/mount/path>] [--results-dir <dir>] [--service-account-name <account>]
hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <kubeflow/PyTorchJob>] [--image <image>] [--command <command>] [--entry-script <script>] [--script-args <arg1 arg2>] [--environment <key=value>] [--pull-policy <Always|IfNotPresent|Never>] [--instance-type <instance-type>] [--node-count <count>] [--tasks-per-node <count>] [--label-selector <key=value>] [--deep-health-check-passed-nodes-only] [--scheduler-type <Kueue SageMaker None>] [--queue-name <queue-name>] [--priority <priority>] [--auto-resume] [--max-retry <count>] [--restart-policy <Always|OnFailure|Never|ExitCode>] [--volumes <volume1,volume2>] [--persistent-volume-claims <claim1:/mount/path,claim2:/mount/path>] [--results-dir <dir>] [--service-account-name <account>] [--pre-script <cmd1 cmd2>] [--post-script <cmd1 cmd2>]
```

* `job-name` (string) - Required. The base name of the job. A unique identifier (UUID) will automatically be appended to the name like `<job-name>-<UUID>`.
@@ -147,6 +147,9 @@ hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <
* `script-args` (list[string]) - Optional. The list of arguments for entry scripts.
* `environment` (dict[string, string]) - Optional. The environment variables (key-value pairs) to set in the containers.
* `node-count` (int) - Required. The number of nodes (instances) to launch the jobs on.
* `pre-script` (list[string]) - Optional. Commands to run before the job starts. Multiple commands should be separated by commas (see the invocation sketch below).
* `post-script` (list[string]) - Optional. Commands to run after the job completes. Multiple commands should be separated by commas.
* `instance-type` (string) - Required. The instance type to launch the job on. Note that you can only use instance types that are available within your SageMaker quotas and are prefixed with `ml`. If `node.kubernetes.io/instance-type` is provided via the `label-selector`, it takes precedence for node selection.
* `tasks-per-node` (int) - Optional. The number of devices to use per instance.
* `label-selector` (dict[string, list[string]]) - Optional. A dictionary of labels and their values that will override the predefined node selection rules based on the SageMaker HyperPod `node-health-status` label and values. If users provide this field, the CLI will launch the job with this customized label selection.
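As a quick illustration (the values below are placeholders, not part of this change), the new flags slot into an invocation alongside the existing required options, with each flag taking a single quoted string of comma-separated commands:

```
hyperpod start-job --job-name demo-job \
  --image pytorch:2.0.1-cuda11.7-cudnn8-runtime \
  --instance-type ml.g5.xlarge \
  --node-count 2 \
  --entry-script /opt/train/src/train.py \
  --pre-script "pip install -r /opt/train/requirements.txt, echo 'setup done'" \
  --post-script "echo 'training complete', echo 'uploading results'"
```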
33 changes: 32 additions & 1 deletion src/hyperpod_cli/commands/job.py
@@ -431,6 +431,18 @@ def cancel_job(
help="Optional. Add a temp directory for containers to store data in the hosts."
" <volume_name>:</host/mount/path>:</container/mount/path>,<volume_name>:</host/mount/path1>:</container/mount/path1>",
)
@click.option(
"--pre-script",
type=click.STRING,
required=False,
help="Optional. Commands to run before the job starts. Multiple commands should be separated by semicolons.",
)
@click.option(
"--post-script",
type=click.STRING,
required=False,
help="Optional. Commands to run after the job completes. Multiple commands should be separated by semicolons.",
)
@click.option(
"--recipe",
type=click.STRING,
@@ -549,6 +561,8 @@ def start_job(
service_account_name: Optional[str],
persistent_volume_claims: Optional[str],
volumes: Optional[str],
pre_script: Optional[str],
post_script: Optional[str],
recipe: Optional[str],
override_parameters: Optional[str],
debug: bool,
@@ -721,6 +735,23 @@ def start_job(
custom_labels[KUEUE_WORKLOAD_PRIORITY_CLASS_LABEL_KEY] = priority
priority = None

# Handle pre_script
if pre_script:
_override_or_remove(
config["training_cfg"],
"pre_script",
pre_script.split(',')
)

# Handle post_script
if post_script:
_override_or_remove(
config["training_cfg"],
"post_script",
post_script.split(',')
)

_override_or_remove(
config["cluster"]["cluster_config"],
"custom_labels",
@@ -807,7 +838,7 @@ def start_job(
auto_resume=auto_resume,
label_selector=label_selector,
max_retry=max_retry,
deep_health_check_passed_nodes_only=deep_health_check_passed_nodes_only,
deep_health_check_passed_nodes_only=deep_health_check_passed_nodes_only
)
# TODO: Unblock this after fixing customer using EKS cluster.
console_link = utils.get_cluster_console_url()
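For readers skimming the diff, here is a small standalone sketch (illustrative values, not the CLI's actual code path) of what the comma-splitting above places into `training_cfg`. Note that a plain `str.split(",")` keeps any whitespace around each command, which is also what the unit test below asserts:

```python
# Illustration of how --pre-script / --post-script strings become lists in training_cfg.
config = {"training_cfg": {}}

pre_script = "pip install -r requirements.txt, echo 'setup done'"  # placeholder value
post_script = "echo 'training complete'"                            # placeholder value

if pre_script:
    config["training_cfg"]["pre_script"] = pre_script.split(",")
if post_script:
    config["training_cfg"]["post_script"] = post_script.split(",")

print(config["training_cfg"]["pre_script"])
# -> ['pip install -r requirements.txt', " echo 'setup done'"]
```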
47 changes: 47 additions & 0 deletions test/unit_tests/test_job.py
@@ -886,6 +886,53 @@ def test_start_job_with_cli_args_label_selection_not_json_str(
)
self.assertEqual(result.exit_code, 1)

@mock.patch("yaml.dump")
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
@mock.patch("hyperpod_cli.commands.job.JobValidator")
@mock.patch("boto3.Session")
def test_start_job_with_cli_args_pre_script_and_post_script(
self,
mock_boto3,
mock_validator_cls,
mock_kubernetes_client,
mock_yaml_dump,
):
mock_validator = mock_validator_cls.return_value
mock_validator.validate_aws_credential.return_value = True
mock_kubernetes_client.get_current_context_namespace.return_value = "kubeflow"
mock_yaml_dump.return_value = None
result = self.runner.invoke(
start_job,
[
"--job-name",
"test-job",
"--instance-type",
"ml.c5.xlarge",
"--image",
"pytorch:1.9.0-cuda11.1-cudnn8-runtime",
"--node-count",
"2",
"--label-selector",
"{NonJsonStr",
"--entry-script",
"/opt/train/src/train.py",
"--pre-script",
"echo 'test', echo 'test 1'",
"--post-script",
"echo 'test 1', echo 'test 2'",
"--label-selector",
'{"preferred": {"node.kubernetes.io/instance-type": ["ml.c5.xlarge"]}}'
],
)

# Assert that yaml.dump was called with the correct configuration
mock_yaml_dump.assert_called_once()
call_args = mock_yaml_dump.call_args[0]
self.assertEqual(call_args[0]['training_cfg']['pre_script'], ["echo 'test'", " echo 'test 1'"])
self.assertEqual(call_args[0]['training_cfg']['post_script'], ["echo 'test 1'", " echo 'test 2'"])

self.assertEqual(result.exit_code, 1)

@mock.patch("yaml.dump")
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
@mock.patch("hyperpod_cli.commands.job.JobValidator")