diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py index 9415968b..d81a664e 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/model.py @@ -1,5 +1,5 @@ -from pydantic import BaseModel, ConfigDict, Field -from typing import Optional, List, Dict, Union +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from typing import Optional, List, Dict, Union, Literal from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import ( Containers, ReplicaSpec, @@ -8,9 +8,42 @@ Spec, Template, Metadata, + Volumes, + HostPath, + PersistentVolumeClaim ) +class VolumeConfig(BaseModel): + name: str = Field(..., description="Volume name") + type: Literal['hostPath', 'pvc'] = Field(..., description="Volume type") + mount_path: str = Field(..., description="Mount path in container") + path: Optional[str] = Field(None, description="Host path (required for hostPath volumes)") + claim_name: Optional[str] = Field(None, description="PVC claim name (required for pvc volumes)") + read_only: Optional[Literal['true', 'false']] = Field(None, description="Read-only flag for pvc volumes") + + @field_validator('mount_path', 'path') + @classmethod + def paths_must_be_absolute(cls, v): + """Validate that paths are absolute (start with /).""" + if v and not v.startswith('/'): + raise ValueError('Path must be absolute (start with /)') + return v + + @model_validator(mode='after') + def validate_type_specific_fields(self): + """Validate that required fields are present based on volume type.""" + + if self.type == 'hostPath': + if not self.path: + raise ValueError('hostPath volumes require path field') + elif self.type == 'pvc': + if not self.claim_name: + raise ValueError('PVC volumes require claim_name field') + + return self + + class 
PyTorchJobConfig(BaseModel): model_config = ConfigDict(extra="forbid") @@ -60,22 +93,41 @@ class PyTorchJobConfig(BaseModel): max_retry: Optional[int] = Field( default=None, alias="max_retry", description="Maximum number of job retries" ) - volumes: Optional[List[str]] = Field( - default=None, description="List of volumes to mount" - ) - persistent_volume_claims: Optional[List[str]] = Field( - default=None, - alias="persistent_volume_claims", - description="List of persistent volume claims", + volume: Optional[List[VolumeConfig]] = Field( + default=None, description="List of volume configurations. \ + Command structure: --volume name=,type=,mount_path=, \ + For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \ + For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \ + If multiple --volume flag if multiple volumes are needed \ + " ) service_account_name: Optional[str] = Field( default=None, alias="service_account_name", description="Service account name" ) + @field_validator('volume') + def validate_no_duplicates(cls, v): + """Validate no duplicate volume names or mount paths.""" + if not v: + return v + + # Check for duplicate volume names + names = [vol.name for vol in v] + if len(names) != len(set(names)): + raise ValueError("Duplicate volume names found") + + # Check for duplicate mount paths + mount_paths = [vol.mount_path for vol in v] + if len(mount_paths) != len(set(mount_paths)): + raise ValueError("Duplicate mount paths found") + + return v + def to_domain(self) -> Dict: """ Convert flat config to domain model (HyperPodPytorchJobSpec) """ + # Create container with required fields container_kwargs = { "name": "container-name", @@ -97,17 +149,42 @@ def to_domain(self) -> Dict: container_kwargs["env"] = [ {"name": k, "value": v} for k, v in self.environment.items() ] - if self.volumes is not None: - container_kwargs["volume_mounts"] = [ - {"name": 
v, "mount_path": f"/mnt/{v}"} for v in self.volumes - ] + + if self.volume is not None: + volume_mounts = [] + for i, vol in enumerate(self.volume): + volume_mount = {"name": vol.name, "mount_path": vol.mount_path} + volume_mounts.append(volume_mount) + + container_kwargs["volume_mounts"] = volume_mounts + # Create container object - container = Containers(**container_kwargs) + try: + container = Containers(**container_kwargs) + except Exception as e: + raise # Create pod spec kwargs spec_kwargs = {"containers": list([container])} + # Add volumes to pod spec if present + if self.volume is not None: + volumes = [] + for i, vol in enumerate(self.volume): + if vol.type == "hostPath": + host_path = HostPath(path=vol.path) + volume_obj = Volumes(name=vol.name, host_path=host_path) + elif vol.type == "pvc": + pvc_config = PersistentVolumeClaim( + claim_name=vol.claim_name, + read_only=vol.read_only == "true" if vol.read_only else False + ) + volume_obj = Volumes(name=vol.name, persistent_volume_claim=pvc_config) + volumes.append(volume_obj) + + spec_kwargs["volumes"] = volumes + # Add node selector if any selector fields are present node_selector = {} if self.instance_type is not None: @@ -175,5 +252,4 @@ def to_domain(self) -> Dict: "namespace": self.namespace, "spec": job_kwargs, } - return result diff --git a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/schema.json b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/schema.json index 809a95c6..0c6c58a8 100644 --- a/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/schema.json +++ b/hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_0/schema.json @@ -1,83 +1,319 @@ { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "HyperPod PyTorch Job Parameters", - "type": "object", - "properties": { - "job-name": {"type": "string", "description": "Job name", "minLength": 1}, - "namespace": {"type": "string", "description": "Kubernetes namespace"}, 
- "image": {"type": "string", "description": "Docker image for training"}, - "command": { - "type": "array", - "items": {"type": "string"}, - "description": "Command to run in the container" - }, - "args": { - "type": "array", - "items": {"type": "string"}, - "description": "Arguments for the entry script" - }, - "environment": { - "type": "object", - "additionalProperties": {"type": "string"}, - "description": "Environment variables as key-value pairs" - }, - "pull-policy": { - "type": "string", - "enum": ["Always", "Never", "IfNotPresent"], - "description": "Image pull policy" - }, - "instance-type": { - "type": "string", - "description": "Instance type for training" - }, - "node-count": { - "type": "integer", - "minimum": 1, - "description": "Number of nodes" - }, - "tasks-per-node": { - "type": "integer", - "minimum": 1, - "description": "Number of tasks per node" - }, - "label-selector": { - "type": "object", - "additionalProperties": {"type": "string"}, - "description": "Node label selector as key-value pairs" - }, - "deep-health-check-passed-nodes-only": { - "type": "boolean", - "description": "Schedule pods only on nodes that passed deep health check" - }, - "scheduler-type": {"type": "string", "description": "Scheduler type"}, - "queue-name": { - "type": "string", - "description": "Queue name for job scheduling" - }, - "priority": { - "type": "string", - "description": "Priority class for job scheduling" - }, - "max-retry": { - "type": "integer", - "minimum": 0, - "description": "Maximum number of job retries" - }, - "volumes": { - "type": "array", - "items": {"type": "string"}, - "description": "List of volumes to mount" - }, - "persistent-volume-claims": { - "type": "array", - "items": {"type": "string"}, - "description": "List of persistent volume claims" - }, - "service-account-name": { - "type": "string", - "description": "Service account name" - } - }, - "required": ["job-name", "image"], - "additionalProperties": false -} + "$defs": { + 
"VolumeConfig": { + "properties": { + "name": { + "description": "Volume name", + "title": "Name", + "type": "string" + }, + "type": { + "description": "Volume type", + "enum": [ + "hostPath", + "pvc" + ], + "title": "Type", + "type": "string" + }, + "mount_path": { + "description": "Mount path in container", + "title": "Mount Path", + "type": "string" + }, + "path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Host path (required for hostPath volumes)", + "title": "Path" + }, + "claim_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "PVC claim name (required for pvc volumes)", + "title": "Claim Name" + }, + "read_only": { + "anyOf": [ + { + "enum": [ + "true", + "false" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Read-only flag for pvc volumes", + "title": "Read Only" + } + }, + "required": [ + "name", + "type", + "mount_path" + ], + "title": "VolumeConfig", + "type": "object" + } + }, + "additionalProperties": false, + "properties": { + "job_name": { + "description": "Job name", + "title": "Job Name", + "type": "string" + }, + "image": { + "description": "Docker image for training", + "title": "Image", + "type": "string" + }, + "namespace": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Kubernetes namespace", + "title": "Namespace" + }, + "command": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Command to run in the container", + "title": "Command" + }, + "args": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Arguments for the entry script", + "title": "Args" + }, + "environment": { + "anyOf": [ + { + "additionalProperties": 
{ + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Environment variables as key_value pairs", + "title": "Environment" + }, + "pull_policy": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Image pull policy", + "title": "Pull Policy" + }, + "instance_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Instance type for training", + "title": "Instance Type" + }, + "node_count": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of nodes", + "title": "Node Count" + }, + "tasks_per_node": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of tasks per node", + "title": "Tasks Per Node" + }, + "label_selector": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Node label selector as key_value pairs", + "title": "Label Selector" + }, + "deep_health_check_passed_nodes_only": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": false, + "description": "Schedule pods only on nodes that passed deep health check", + "title": "Deep Health Check Passed Nodes Only" + }, + "scheduler_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Scheduler type", + "title": "Scheduler Type" + }, + "queue_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Queue name for job scheduling", + "title": "Queue Name" + }, + "priority": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Priority class for job scheduling", + "title": "Priority" + }, 
+ "max_retry": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Maximum number of job retries", + "title": "Max Retry" + }, + "volume": { + "anyOf": [ + { + "items": { + "$ref": "#/$defs/VolumeConfig" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "default": null, + "description": "List of volume configurations. Command structure: --volume name=,type=,mount_path=, For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false If multiple --volume flag if multiple volumes are needed ", + "title": "Volume" + }, + "service_account_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Service account name", + "title": "Service Account Name" + } + }, + "required": [ + "job_name", + "image" + ], + "title": "PyTorchJobConfig", + "type": "object" +} \ No newline at end of file diff --git a/src/sagemaker/hyperpod/cli/training_utils.py b/src/sagemaker/hyperpod/cli/training_utils.py index eeecb022..a08bb735 100644 --- a/src/sagemaker/hyperpod/cli/training_utils.py +++ b/src/sagemaker/hyperpod/cli/training_utils.py @@ -1,7 +1,8 @@ import json import pkgutil import click -from typing import Callable, Optional, Mapping, Type +from typing import Callable, Optional, Mapping, Type, Dict, Any +from pydantic import ValidationError def load_schema_for_version( @@ -24,7 +25,7 @@ def load_schema_for_version( def generate_click_command( *, version_key: Optional[str] = None, - schema_pkg: str = "hyperpod_jumpstart_inference_template", + schema_pkg: str, registry: Mapping[str, Type] = None, ) -> Callable: """ @@ -57,6 +58,26 @@ def _parse_list_flag(ctx, param, value): value = value.strip("[]") return [item.strip() for item in value.split(",") if item.strip()] + def _parse_volume_param(ctx, param, 
value): + """Parse volume parameters from command line format to dictionary format.""" + volumes = [] + for i, v in enumerate(value): + try: + # Split by comma and then by equals, with validation + parts = {} + for item in v.split(','): + if '=' not in item: + raise click.UsageError(f"Invalid volume format in volume {i+1}: '{item}' should be key=value") + key, val = item.split('=', 1) # Split only on first '=' to handle values with '=' + parts[key.strip()] = val.strip() + + volumes.append(parts) + except Exception as e: + raise click.UsageError(f"Error parsing volume {i+1}: {str(e)}") + + # Note: Detailed validation will be handled by schema validation + return volumes + # 1) the wrapper click will call def wrapped_func(*args, **kwargs): # extract version @@ -68,93 +89,81 @@ def wrapped_func(*args, **kwargs): if Model is None: raise click.ClickException(f"Unsupported schema version: {version}") - # validate & to_domain - flat = Model(**kwargs) - domain_config = flat.to_domain() + try: + flat = Model(**kwargs) + domain_config = flat.to_domain() + except ValidationError as e: + error_messages = [] + for err in e.errors(): + loc = ".".join(str(x) for x in err["loc"]) + msg = err["msg"] + error_messages.append(f" – {loc}: {msg}") + + raise click.UsageError( + f"❌ Configuration validation errors:\n" + "\n".join(error_messages) + ) # call your handler return func(version, debug, domain_config) # 2) inject click options from JSON Schema excluded_props = set(["version"]) - if schema_pkg == "hyperpod_jumpstart_inference_template": + + wrapped_func = click.option( + "--environment", + callback=_parse_json_flag, + type=str, + default=None, + help=( + "JSON object of environment variables, e.g. " + '\'{"VAR1":"foo","VAR2":"bar"}\'' + ), + metavar="JSON", + )(wrapped_func) + wrapped_func = click.option( + "--label_selector", + callback=_parse_json_flag, + help='JSON object of node label selectors, e.g. 
\'{"cpu":"2","memory":"4Gi"}\'', + metavar="JSON", + )(wrapped_func) + + wrapped_func = click.option( + "--volume", + multiple=True, + callback=_parse_volume_param, + help="List of volume configurations. \ + Command structure: --volume name=,type=,mount_path=, \ + For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \ + For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \ + Use multiple --volume flags if multiple volumes are needed.", + )(wrapped_func) + + # Add list options + list_params = { + "command": "List of command arguments", + "args": "List of script arguments, e.g. '[--batch-size, 32, --learning-rate, 0.001]'", + } + + for param_name, help_text in list_params.items(): wrapped_func = click.option( - "--env", - callback=_parse_json_flag, + f"--{param_name}", + callback=_parse_list_flag, type=str, default=None, - help=( - "JSON object of environment variables, e.g. " - '\'{"VAR1":"foo","VAR2":"bar"}\'' - ), - metavar="JSON", - )(wrapped_func) - wrapped_func = click.option( - "--resources-limits", - callback=_parse_json_flag, - help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'', - metavar="JSON", - )(wrapped_func) - - wrapped_func = click.option( - "--resources-requests", - callback=_parse_json_flag, - help='JSON object of resource requests, e.g. \'{"cpu":"1","memory":"2Gi"}\'', - metavar="JSON", + help=help_text, + metavar="LIST", )(wrapped_func) - excluded_props = set( - ["version", "env", "resources_limits", "resources_requests"] - ) - - elif schema_pkg == "hyperpod_pytorch_job_template": - wrapped_func = click.option( - "--environment", - callback=_parse_json_flag, - type=str, - default=None, - help=( - "JSON object of environment variables, e.g. 
" - '\'{"VAR1":"foo","VAR2":"bar"}\'' - ), - metavar="JSON", - )(wrapped_func) - wrapped_func = click.option( - "--label_selector", - callback=_parse_json_flag, - help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'', - metavar="JSON", - )(wrapped_func) - - # Add list options - list_params = { - "command": "List of command arguments", - "args": "List of script arguments, e.g. '[--batch-size, 32, --learning-rate, 0.001]'", - "volumes": "List of volumes, e.g. '[vol1, vol2, vol3]'", - "persistent_volume_claims": "List of persistent volume claims, e.g. '[pvc1, pvc2]'", - } - - for param_name, help_text in list_params.items(): - wrapped_func = click.option( - f"--{param_name}", - callback=_parse_list_flag, - type=str, - default=None, - help=help_text, - metavar="LIST", - )(wrapped_func) - - excluded_props = set( - [ - "version", - "environment", - "label_selector", - "command", - "args", - "volumes", - "persistent_volume_claims", - ] - ) + excluded_props = set( + [ + "version", + "environment", + "label_selector", + "command", + "args", + "volume", + ] + ) schema = load_schema_for_version(version_key or "1.0", schema_pkg) props = schema.get("properties", {}) diff --git a/test/unit_tests/cli/test_training_utils.py b/test/unit_tests/cli/test_training_utils.py index af7c65e5..683280b4 100644 --- a/test/unit_tests/cli/test_training_utils.py +++ b/test/unit_tests/cli/test_training_utils.py @@ -186,7 +186,7 @@ def to_domain(self): registry = {'1.0': DummyModel} @click.command() - @generate_click_command(registry=registry) + @generate_click_command(registry=registry, schema_pkg="hyperpod-pytorch-job") def cmd(version, debug, config): click.echo(json.dumps({ 'node_count': config.node_count, @@ -211,3 +211,271 @@ def cmd(version, debug, config): result = self.runner.invoke(cmd, ['--node-count', 'not-a-number']) assert result.exit_code == 2 assert "Invalid value" in result.output + + + @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') + def 
test_volume_flag_parsing(self, mock_get_data): + """Test volume flag parsing functionality""" + schema = { + 'properties': { + 'volume': { + 'type': 'array', + 'items': { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'type': {'type': 'string'}, + 'mount_path': {'type': 'string'}, + 'path': {'type': 'string'}, + 'claim_name': {'type': 'string'}, + 'read_only': {'type': 'string'} + } + } + } + } + } + mock_get_data.return_value = json.dumps(schema).encode() + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command( + schema_pkg="hyperpod_pytorch_job_template", + registry=registry + ) + def cmd(version, debug, config): + click.echo(json.dumps({ + 'volume': config.volume if hasattr(config, 'volume') else None + })) + + # Test single hostPath volume + result = self.runner.invoke(cmd, [ + '--volume', 'name=model-data,type=hostPath,mount_path=/data,path=/host/data' + ]) + assert result.exit_code == 0 + output = json.loads(result.output) + expected_volume = [{ + 'name': 'model-data', + 'type': 'hostPath', + 'mount_path': '/data', + 'path': '/host/data' + }] + assert output['volume'] == expected_volume + + # Test single PVC volume + result = self.runner.invoke(cmd, [ + '--volume', 'name=training-output,type=pvc,mount_path=/output,claim_name=my-pvc,read_only=false' + ]) + assert result.exit_code == 0 + output = json.loads(result.output) + expected_volume = [{ + 'name': 'training-output', + 'type': 'pvc', + 'mount_path': '/output', + 'claim_name': 'my-pvc', + 'read_only': 'false' + }] + assert output['volume'] == expected_volume + + # Test multiple volumes + result = self.runner.invoke(cmd, [ + '--volume', 'name=model-data,type=hostPath,mount_path=/data,path=/host/data', + '--volume', 'name=training-output,type=pvc,mount_path=/output,claim_name=my-pvc,read_only=true' + ]) + assert result.exit_code == 0 + 
output = json.loads(result.output) + expected_volumes = [ + { + 'name': 'model-data', + 'type': 'hostPath', + 'mount_path': '/data', + 'path': '/host/data' + }, + { + 'name': 'training-output', + 'type': 'pvc', + 'mount_path': '/output', + 'claim_name': 'my-pvc', + 'read_only': 'true' + } + ] + assert output['volume'] == expected_volumes + + + @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') + def test_volume_domain_conversion(self, mock_get_data): + """Test volume domain conversion functionality""" + schema = { + 'properties': { + 'job_name': {'type': 'string'}, + 'image': {'type': 'string'}, + 'volume': { + 'type': 'array', + 'items': {'type': 'object'} + } + }, + 'required': ['job_name', 'image'] + } + mock_get_data.return_value = json.dumps(schema).encode() + + class MockVolumeModel: + def __init__(self, **kwargs): + self.job_name = kwargs.get('job_name') + self.image = kwargs.get('image') + self.volume = kwargs.get('volume') + + def to_domain(self): + domain_volumes = [] + if self.volume: + for vol in self.volume: + if vol.get('type') == 'hostPath': + domain_volumes.append({ + 'name': vol.get('name'), + 'type': 'hostPath', + 'mount_path': vol.get('mount_path'), + 'host_path': {'path': vol.get('path')} + }) + elif vol.get('type') == 'pvc': + domain_volumes.append({ + 'name': vol.get('name'), + 'type': 'pvc', + 'mount_path': vol.get('mount_path'), + 'persistent_volume_claim': { + 'claim_name': vol.get('claim_name'), + 'read_only': vol.get('read_only') == 'true' + } + }) + + return { + 'name': self.job_name, + 'image': self.image, + 'volumes': domain_volumes + } + + registry = {'1.0': MockVolumeModel} + + @click.command() + @generate_click_command( + schema_pkg="hyperpod_pytorch_job_template", + registry=registry + ) + def cmd(version, debug, config): + click.echo(json.dumps(config)) + + # Test hostPath volume domain conversion + result = self.runner.invoke(cmd, [ + '--job-name', 'test-job', + '--image', 'test-image', + '--volume', 
'name=model-data,type=hostPath,mount_path=/data,path=/host/data' + ]) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['volumes'][0]['type'] == 'hostPath' + assert output['volumes'][0]['host_path']['path'] == '/host/data' + + # Test PVC volume domain conversion + result = self.runner.invoke(cmd, [ + '--job-name', 'test-job', + '--image', 'test-image', + '--volume', 'name=training-output,type=pvc,mount_path=/output,claim_name=my-pvc,read_only=true' + ]) + assert result.exit_code == 0 + output = json.loads(result.output) + assert output['volumes'][0]['type'] == 'pvc' + assert output['volumes'][0]['persistent_volume_claim']['claim_name'] == 'my-pvc' + assert output['volumes'][0]['persistent_volume_claim']['read_only'] is True + + + @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') + def test_volume_flag_parsing_errors(self, mock_get_data): + """Test volume flag parsing error handling""" + schema = { + 'properties': { + 'volume': { + 'type': 'array', + 'items': {'type': 'object'} + } + } + } + mock_get_data.return_value = json.dumps(schema).encode() + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command( + schema_pkg="hyperpod_pytorch_job_template", + registry=registry + ) + def cmd(version, debug, config): + click.echo("success") + + # Test invalid format (missing equals sign) + result = self.runner.invoke(cmd, [ + '--volume', 'name=model-data,type=hostPath,mount_path,path=/host/data' + ]) + assert result.exit_code == 2 + assert "should be key=value" in result.output + + # Test empty volume parameter + result = self.runner.invoke(cmd, [ + '--volume', '' + ]) + assert result.exit_code == 2 + assert "Error parsing volume" in result.output + + @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') + def test_volume_flag_with_equals_in_value(self, mock_get_data): + 
"""Test volume flag parsing with equals signs in values""" + schema = { + 'properties': { + 'volume': { + 'type': 'array', + 'items': {'type': 'object'} + } + } + } + mock_get_data.return_value = json.dumps(schema).encode() + + class DummyModel: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + def to_domain(self): + return self + + registry = {'1.0': DummyModel} + + @click.command() + @generate_click_command( + schema_pkg="hyperpod_pytorch_job_template", + registry=registry + ) + def cmd(version, debug, config): + click.echo(json.dumps({ + 'volume': config.volume if hasattr(config, 'volume') else None + })) + + # Test volume with equals sign in path value + result = self.runner.invoke(cmd, [ + '--volume', 'name=model-data,type=hostPath,mount_path=/data,path=/host/data=special' + ]) + assert result.exit_code == 0 + output = json.loads(result.output) + expected_volume = [{ + 'name': 'model-data', + 'type': 'hostPath', + 'mount_path': '/data', + 'path': '/host/data=special' + }] + assert output['volume'] == expected_volume \ No newline at end of file