diff --git a/src/sagemaker/hyperpod/cli/commands/cluster.py b/src/sagemaker/hyperpod/cli/commands/cluster.py index cb19f24c..c84b1376 100644 --- a/src/sagemaker/hyperpod/cli/commands/cluster.py +++ b/src/sagemaker/hyperpod/cli/commands/cluster.py @@ -26,6 +26,7 @@ from kubernetes import client from ratelimit import limits, sleep_and_retry from tabulate import tabulate +from sagemaker.hyperpod.cli.parsers import parse_list_parameter from sagemaker.hyperpod.cli.clients.kubernetes_client import ( KubernetesClient, @@ -97,9 +98,10 @@ ) @click.option( "--clusters", + callback=parse_list_parameter, type=click.STRING, required=False, - help="Optional. A list of HyperPod cluster names that users want to check the capacity for. This is useful for users who know some of their most commonly used clusters and want to check the capacity status of the clusters in the AWS account.", + help="Optional. List of HyperPod cluster names to check capacity for. Supports JSON format: '[\"cluster1\", \"cluster2\"]' or simple format: '[cluster1, cluster2]'", ) @click.option( "--debug", @@ -183,7 +185,7 @@ def list_cluster( sys.exit(1) if clusters: - cluster_names = clusters.split(",") + cluster_names = clusters else: try: cluster_names = _get_hyperpod_clusters(sm_client) diff --git a/src/sagemaker/hyperpod/cli/commands/cluster_stack.py b/src/sagemaker/hyperpod/cli/commands/cluster_stack.py index 285ba1f7..e3335221 100644 --- a/src/sagemaker/hyperpod/cli/commands/cluster_stack.py +++ b/src/sagemaker/hyperpod/cli/commands/cluster_stack.py @@ -8,6 +8,7 @@ import json import os from typing import Optional +from sagemaker.hyperpod.cli.parsers import parse_list_parameter from sagemaker_core.main.resources import Cluster from sagemaker_core.main.shapes import ClusterInstanceGroupSpecification @@ -22,23 +23,7 @@ logger = logging.getLogger(__name__) -def parse_status_list(ctx, param, value): - """Parse status list from string format like "['CREATE_COMPLETE', 'UPDATE_COMPLETE']" """ - if not value: - return None - - try: - # Handle both string representation and direct list - if isinstance(value, str): - # Parse string like "['item1', 'item2']" - parsed = ast.literal_eval(value) - if isinstance(parsed, list): - return parsed - else: - raise click.BadParameter(f"Expected list format, got: {type(parsed).__name__}") - return value - except (ValueError, SyntaxError) as e: - raise click.BadParameter(f"Invalid list format. Use: \"['STATUS1', 'STATUS2']\". Error: {e}") +# Use unified parser for consistent behavior - no need for custom function @click.command("cluster-stack") @@ -223,8 +208,8 @@ def describe_cluster_stack(stack_name: str, debug: bool, region: str) -> None: @click.option("--region", help="AWS region") @click.option("--debug", is_flag=True, help="Enable debug logging") @click.option("--status", - callback=parse_status_list, - help="Filter by stack status. Format: \"['CREATE_COMPLETE', 'UPDATE_COMPLETE']\"") + callback=parse_list_parameter, + help="Filter by stack status. Supports JSON format: '[\"CREATE_COMPLETE\", \"UPDATE_COMPLETE\"]' or simple format: '[CREATE_COMPLETE, UPDATE_COMPLETE]'") @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_cluster_stack_cli") def list_cluster_stacks(region, debug, status): """List all HyperPod cluster stacks. @@ -376,4 +361,3 @@ def update_cluster( logger.info("Cluster has been updated") click.secho(f"Cluster {cluster_name} has been updated") - diff --git a/src/sagemaker/hyperpod/cli/inference_utils.py b/src/sagemaker/hyperpod/cli/inference_utils.py index 5ecf2395..7439997a 100644 --- a/src/sagemaker/hyperpod/cli/inference_utils.py +++ b/src/sagemaker/hyperpod/cli/inference_utils.py @@ -4,6 +4,7 @@ from typing import Callable, Optional, Mapping, Type import sys from sagemaker.hyperpod.cli.common_utils import extract_version_from_args, get_latest_version, load_schema_for_version +from sagemaker.hyperpod.cli.parsers import parse_dict_parameter def generate_click_command( @@ -18,14 +19,7 @@ def generate_click_command( version = extract_version_from_args(registry, schema_pkg, default_version) def decorator(func: Callable) -> Callable: - # Parser for the single JSON‐dict env var flag - def _parse_json_flag(ctx, param, value): - if value is None: - return None - try: - return json.loads(value) - except json.JSONDecodeError as e: - raise click.BadParameter(f"{param.name!r} must be valid JSON: {e}") + # Use unified parser for consistent behavior # 1) the wrapper click actually invokes def wrapped_func(*args, **kwargs): @@ -46,21 +40,21 @@ def wrapped_func(*args, **kwargs): props = schema.get("properties", {}) json_flags = { - "env": ("JSON object of environment variables, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''), - "dimensions": ("JSON object of dimensions, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''), - "resources_limits": ('JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\''), - "resources_requests": ('JSON object of resource requests, e.g. \'{"cpu":"1","memory":"2Gi"}\''), + "env": "Environment variables. Supports JSON format: '{\"VAR1\":\"foo\",\"VAR2\":\"bar\"}' or simple format: '{VAR1: foo, VAR2: bar}'", + "dimensions": "Dimensions. Supports JSON format: '{\"VAR1\":\"foo\",\"VAR2\":\"bar\"}' or simple format: '{VAR1: foo, VAR2: bar}'", + "resources_limits": "Resource limits. Supports JSON format: '{\"cpu\":\"2\",\"memory\":\"4Gi\"}' or simple format: '{cpu: 2, memory: 4Gi}'", + "resources_requests": "Resource requests. Supports JSON format: '{\"cpu\":\"1\",\"memory\":\"2Gi\"}' or simple format: '{cpu: 1, memory: 2Gi}'", } for flag_name, help_text in json_flags.items(): if flag_name in props: wrapped_func = click.option( f"--{flag_name.replace('_', '-')}", - callback=_parse_json_flag, + callback=parse_dict_parameter, type=str, default=None, help=help_text, - metavar="JSON", + metavar="JSON|SIMPLE", )(wrapped_func) # 3) auto-inject all schema.json fields @@ -99,4 +93,4 @@ def wrapped_func(*args, **kwargs): return wrapped_func - return decorator \ No newline at end of file + return decorator diff --git a/src/sagemaker/hyperpod/cli/init_utils.py b/src/sagemaker/hyperpod/cli/init_utils.py index 63624718..34ea037e 100644 --- a/src/sagemaker/hyperpod/cli/init_utils.py +++ b/src/sagemaker/hyperpod/cli/init_utils.py @@ -20,6 +20,7 @@ CFN ) from sagemaker.hyperpod.cluster_management.hp_cluster_stack import HpClusterStack +from sagemaker.hyperpod.cli.parsers import parse_dict_parameter, parse_complex_object_parameter log = logging.getLogger() @@ -207,58 +208,7 @@ def decorator(func: Callable) -> Callable: # If template can't be fetched, use empty dict pass - # JSON flag parser - def _parse_json_flag(ctx, param, value): - if value is None: - return None - try: - return json.loads(value) - except json.JSONDecodeError: - # Try to fix unquoted list items: [python, train.py] -> ["python", "train.py"] - if value.strip().startswith('[') and value.strip().endswith(']'): - try: - # Remove brackets and split by comma - inner = value.strip()[1:-1] - items = [item.strip().strip('"').strip("'") for item in inner.split(',')] - return items - except: - pass - raise click.BadParameter(f"{param.name!r} must be valid JSON or a list like [item1, item2]") - - - # Volume flag parser - def _parse_volume_flag(ctx, param, value): - if not value: - return None - - # Handle multiple volume flags - if not isinstance(value, (list, tuple)): - value = [value] - - from hyperpod_pytorch_job_template.v1_0.model import VolumeConfig - volumes = [] - - for vol_str in value: - # Parse volume string: name=model-data,type=hostPath,mount_path=/data,path=/data - vol_dict = {} - for pair in vol_str.split(','): - if '=' in pair: - key, val = pair.split('=', 1) - key = key.strip() - val = val.strip() - - # Convert read_only to boolean - if key == 'read_only': - vol_dict[key] = val.lower() in ('true', '1', 'yes', 'on') - else: - vol_dict[key] = val - - try: - volumes.append(VolumeConfig(**vol_dict)) - except Exception as e: - raise click.BadParameter(f"Invalid volume configuration '{vol_str}': {e}") - - return volumes + # Use unified parsers for consistent behavior across all parameter types @functools.wraps(func) def wrapped(*args, **kwargs): @@ -318,9 +268,9 @@ def wrapped(*args, **kwargs): if flag_name in union_props: wrapped = click.option( f"--{flag}", - callback=_parse_json_flag, - metavar="JSON", - help=f"JSON object for {flag.replace('-', ' ')}", + callback=parse_dict_parameter, + metavar="JSON|SIMPLE", + help=f"{flag.replace('-', ' ').title()}. Supports JSON format or simple format", )(wrapped) @@ -366,8 +316,8 @@ def wrapped(*args, **kwargs): wrapped = click.option( f"--{name.replace('_','-')}", multiple=True, - callback=_parse_volume_flag, - help=help_text, + callback=lambda ctx, param, value: parse_complex_object_parameter(ctx, param, value, allow_multiple=True), + help="Volume configurations. Supports JSON format or key-value format with multiple --volume flags", )(wrapped) else: wrapped = click.option( @@ -393,9 +343,9 @@ def wrapped(*args, **kwargs): if cfn_param_name == 'Tags': wrapped = click.option( f"--{pascal_to_kebab(cfn_param_name)}", - callback=_parse_json_flag, - metavar="JSON", - help=cfn_param_details.get('Description', ''), + callback=parse_dict_parameter, + metavar="JSON|SIMPLE", + help=cfn_param_details.get('Description', '') + " Supports JSON format or simple format", )(wrapped) else: cfn_default = cfn_param_details.get('Default') @@ -946,4 +896,4 @@ def pascal_to_kebab(pascal_str): if char.isupper() and i > 0: result.append('-') result.append(char.lower()) - return ''.join(result) \ No newline at end of file + return ''.join(result) diff --git a/src/sagemaker/hyperpod/cli/parsers.py b/src/sagemaker/hyperpod/cli/parsers.py new file mode 100644 index 00000000..f0d118a5 --- /dev/null +++ b/src/sagemaker/hyperpod/cli/parsers.py @@ -0,0 +1,406 @@ +""" +Unified parameter parsing module for SageMaker HyperPod CLI. + +This module provides consistent parsing for different parameter types across all CLI commands: +- List parameters: JSON format ['item1', 'item2'] or simple format [item1, item2] +- Dictionary parameters: JSON format {"key": "value"} or simple format {key: value} +- Complex object parameters: JSON format {"key": "value"} or key=value format +""" + +import json +import re +import click +from typing import Any, Dict, List, Union + + +class ParameterParsingError(click.BadParameter): + """Custom exception for parameter parsing errors with helpful messages.""" + pass + + +def parse_list_parameter(ctx, param, value: str) -> List[Any]: + """ + Parse list parameters supporting multiple formats. + + Supported formats: + - JSON: '["item1", "item2", "item3"]' + - Simple: '[item1, item2, item3]' (with or without spaces) + + Args: + ctx: Click context + param: Click parameter object + value: Input string value + + Returns: + List of parsed items + + Raises: + ParameterParsingError: If parsing fails for both JSON and simple formats + """ + if value is None or value == "": + return None + + param_name = param.name if param else "parameter" + + # Try JSON parsing first + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return parsed + else: + # JSON parsed but not a list - provide specific error + raise ParameterParsingError( + f"Expected a list for --{param_name}, got {type(parsed).__name__}" + ) + except (json.JSONDecodeError, ValueError): + # JSON parsing failed, continue to simple parsing + pass + + # Try simple list parsing: [item1, item2] or [item1,item2] + try: + return _parse_simple_list(value) + except Exception: + pass + + # Both formats failed - provide helpful error message + raise ParameterParsingError( + f"Invalid format for --{param_name}. Supported formats:\n" + f" JSON: '[\"item1\", \"item2\", \"item3\"]'\n" + f" Simple: '[item1, item2, item3]'" + ) + + +def parse_dict_parameter(ctx, param, value: str) -> Dict[str, Any]: + """ + Parse dictionary parameters supporting multiple formats. + + Supported formats: + - JSON: '{"key": "value", "key2": "value2"}' + - Simple: '{key: value, key2: "value with spaces"}' + + Args: + ctx: Click context + param: Click parameter object + value: Input string value + + Returns: + Dictionary of parsed key-value pairs + + Raises: + ParameterParsingError: If parsing fails for both JSON and simple formats + """ + if value is None: + return None + + param_name = param.name if param else "parameter" + + # Try JSON parsing first + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return parsed + else: + # JSON parsed but not a dict - let it fall through to simple parsing + pass + except (json.JSONDecodeError, ValueError): + # JSON parsing failed, continue to simple parsing + pass + + # Try simple dict parsing: {key: value, key2: value2} + try: + return _parse_simple_dict(value) + except Exception: + pass + + # Both formats failed - provide helpful error message + raise ParameterParsingError( + f"Invalid format for --{param_name}. Supported formats:\n" + f" JSON: '{{\"key\": \"value\", \"key2\": \"value2\"}}'\n" + f" Simple: '{{key: value, key2: \"value with spaces\"}}'" + ) + + +def parse_complex_object_parameter(ctx, param, value: Union[str, List[str]], allow_multiple: bool = True) -> List[Dict[str, Any]]: + """ + Parse complex object parameters supporting multiple formats and multiple flag usage. + + Supported formats: + - JSON single object: '{"key": "value", "key2": "value2"}' + - JSON array (if allow_multiple=True): '[{"key": "value"}, {"key2": "value2"}]' + - Key-value single: 'key=value,key2=value2' + - Multiple flags (if allow_multiple=True): --param obj1 --param obj2 + + Args: + ctx: Click context + param: Click parameter object + value: Input string value or list of string values (for multiple flags) + allow_multiple: Whether to support multiple objects via multiple flags or JSON arrays + + Returns: + List of dictionaries representing complex objects (single item list if allow_multiple=False) + + Raises: + ParameterParsingError: If parsing fails for both JSON and key-value formats + """ + if not value: + return None + + param_name = param.name if param else "parameter" + + # Handle multiple flag usage: --volume config1 --volume config2 + if not isinstance(value, (list, tuple)): + value = [value] + + # Check allow_multiple constraint + if not allow_multiple and len(value) > 1: + raise ParameterParsingError( + f"--{param_name} does not support multiple values. " + f"Received {len(value)} values: {value}" + ) + + results = [] + for i, item in enumerate(value): + try: + # Try JSON parsing first + try: + parsed = json.loads(item) + if isinstance(parsed, dict): + results.append(parsed) + continue + elif isinstance(parsed, list) and allow_multiple: + # JSON array format: '[{"key": "value"}, {"key2": "value2"}]' + for j, array_item in enumerate(parsed): + if not isinstance(array_item, dict): + raise ParameterParsingError( + f"--{param_name} JSON array item {j+1} must be an object, got {type(array_item).__name__}" + ) + results.append(array_item) + continue + elif isinstance(parsed, list) and not allow_multiple: + raise ParameterParsingError( + f"--{param_name} does not support JSON arrays. Use single object format." + ) + except (json.JSONDecodeError, ValueError): + pass + + # Try key-value parsing: key=value,key2=value2 + parsed_dict = _parse_key_value_pairs(item) + results.append(parsed_dict) + + except Exception as e: + if isinstance(e, ParameterParsingError): + raise e + raise ParameterParsingError( + f"Invalid format for --{param_name} item {i+1}: '{item}'. Supported formats:\n" + f" JSON object: '{{\"key\": \"value\", \"key2\": \"value2\"}}'\n" + + (f" JSON array: '[{{\"key\": \"value\"}}, {{\"key2\": \"value2\"}}]'\n" if allow_multiple else "") + + f" Key-value: 'key=value,key2=value2'" + ) + + # For single-object mode, return single item list or enforce single result + if not allow_multiple and len(results) > 1: + raise ParameterParsingError( + f"--{param_name} produced multiple objects but only single object is allowed" + ) + + return results + + +def _parse_simple_list(value: str) -> List[str]: + """ + Parse simple list format: [item1, item2] or [item1,item2] + + Args: + value: String in format [item1, item2] + + Returns: + List of string items + + Raises: + ValueError: If format is invalid + """ + value = value.strip() + + if not (value.startswith('[') and value.endswith(']')): + raise ValueError("List must be enclosed in brackets") + + # Remove brackets and get inner content + inner = value[1:-1].strip() + + if not inner: + return [] + + # For simple format, check for common malformed JSON patterns + if inner.endswith(','): # trailing comma + raise ValueError("Invalid list format - trailing comma detected") + + # Split by comma and clean up items + items = [] + for item in inner.split(','): + item = item.strip() + if not item: # Empty item (e.g., trailing comma) + continue + # Remove surrounding quotes if present + if ((item.startswith('"') and item.endswith('"')) or + (item.startswith("'") and item.endswith("'"))): + item = item[1:-1] + items.append(item) + + return items + + +def _parse_simple_dict(value: str) -> Dict[str, str]: + """ + Parse simple dictionary format: {key: value, key2: "value with spaces"} + + Args: + value: String in format {key: value, key2: value2} + + Returns: + Dictionary of string key-value pairs + + Raises: + ValueError: If format is invalid + """ + value = value.strip() + + if not (value.startswith('{') and value.endswith('}')): + raise ValueError("Dictionary must be enclosed in braces") + + # Remove braces and get inner content + inner = value[1:-1].strip() + + if not inner: + return {} + + # Parse key-value pairs using regex to handle quoted values + result = {} + + # Split by comma, but respect quotes + pairs = _split_respecting_quotes(inner, ',') + + for pair in pairs: + pair = pair.strip() + if ':' not in pair: + raise ValueError(f"Invalid key-value pair: '{pair}'. Expected format: 'key: value'") + + key_part, value_part = pair.split(':', 1) + key = key_part.strip() + value_str = value_part.strip() + + # Remove surrounding quotes from key if present + if ((key.startswith('"') and key.endswith('"')) or + (key.startswith("'") and key.endswith("'"))): + key = key[1:-1] + + # Remove surrounding quotes from value and unescape inner quotes + if ((value_str.startswith('"') and value_str.endswith('"')) or + (value_str.startswith("'") and value_str.endswith("'"))): + quote_char = value_str[0] + value_str = value_str[1:-1] + # Unescape inner quotes + if quote_char == '"': + value_str = value_str.replace('\\"', '"') + else: + value_str = value_str.replace("\\'", "'") + + result[key] = value_str + + return result + + +def _parse_key_value_pairs(value: str) -> Dict[str, str]: + """ + Parse key-value pairs format: key=value,key2=value2 + + Args: + value: String in format key=value,key2=value2 + + Returns: + Dictionary of string key-value pairs + + Raises: + ValueError: If format is invalid + """ + result = {} + + # Split by comma and parse each key=value pair + for pair in value.split(','): + if '=' not in pair: + raise ValueError(f"Invalid key-value pair: '{pair}'. Expected format: 'key=value'") + + key, val = pair.split('=', 1) # Split only on first '=' to handle values with '=' + key = key.strip() + val = val.strip() + + if not key: + raise ValueError(f"Empty key in pair: '{pair}'") + + result[key] = val + + return result + + +def _split_respecting_quotes(text: str, delimiter: str) -> List[str]: + """ + Split text by delimiter while respecting quoted sections. + + Args: + text: Text to split + delimiter: Delimiter to split on + + Returns: + List of split parts + """ + parts = [] + current = [] + in_quotes = False + quote_char = None + + i = 0 + while i < len(text): + char = text[i] + + if char in ('"', "'") and (i == 0 or text[i-1] != '\\'): + if not in_quotes: + in_quotes = True + quote_char = char + elif char == quote_char: + in_quotes = False + quote_char = None + + if char == delimiter and not in_quotes: + parts.append(''.join(current)) + current = [] + else: + current.append(char) + + i += 1 + + if current: + parts.append(''.join(current)) + + return parts + + +def parse_comma_separated_list(ctx, param, value: str) -> List[str]: + """ + Parse comma-separated list format: item1,item2,item3 + + This is a legacy format parser for backward compatibility. + Use parse_list_parameter for new implementations. + + Args: + ctx: Click context + param: Click parameter object + value: Input string value + + Returns: + List of string items + """ + if value is None: + return None + + # Split by comma and clean up items + items = [item.strip() for item in value.split(',') if item.strip()] + return items diff --git a/src/sagemaker/hyperpod/cli/training_utils.py b/src/sagemaker/hyperpod/cli/training_utils.py index 1a3d057a..8dddc8f9 100644 --- a/src/sagemaker/hyperpod/cli/training_utils.py +++ b/src/sagemaker/hyperpod/cli/training_utils.py @@ -5,6 +5,7 @@ from pydantic import ValidationError import sys from sagemaker.hyperpod.cli.common_utils import extract_version_from_args, get_latest_version, load_schema_for_version +from sagemaker.hyperpod.cli.parsers import parse_dict_parameter, parse_list_parameter, parse_complex_object_parameter def generate_click_command( @@ -27,45 +28,7 @@ def generate_click_command( version = extract_version_from_args(registry, schema_pkg, default_version) def decorator(func: Callable) -> Callable: - # Parser for the single JSON‐dict env var flag - def _parse_json_flag(ctx, param, value): - if value is None: - return None - try: - return json.loads(value) - except json.JSONDecodeError as e: - raise click.BadParameter(f"{param.name!r} must be valid JSON: {e}") - - # Parser for list flags - def _parse_list_flag(ctx, param, value): - if value is None: - return None - # Remove brackets and split by comma - value = value.strip("[]") - return [item.strip() for item in value.split(",") if item.strip()] - - def _parse_volume_param(ctx, param, value): - """Parse volume parameters from command line format to dictionary format.""" - if not value: - return None - - volumes = [] - for i, v in enumerate(value): - try: - # Split by comma and then by equals, with validation - parts = {} - for item in v.split(','): - if '=' not in item: - raise click.UsageError(f"Invalid volume format in volume {i+1}: '{item}' should be key=value") - key, val = item.split('=', 1) # Split only on first '=' to handle values with '=' - parts[key.strip()] = val.strip() - - volumes.append(parts) - except Exception as e: - raise click.UsageError(f"Error parsing volume {i+1}: {str(e)}") - - # Note: Detailed validation will be handled by schema validation - return volumes + # Use unified parsers for consistent behavior across all parameter types # 1) the wrapper click will call def wrapped_func(*args, **kwargs): @@ -103,47 +66,47 @@ def wrapped_func(*args, **kwargs): wrapped_func = click.option( "--environment", - callback=_parse_json_flag, + callback=parse_dict_parameter, type=str, default=None, help=( - "JSON object of environment variables, e.g. " - '\'{"VAR1":"foo","VAR2":"bar"}\'' + "Environment variables. Supports JSON format: " + '\'{"VAR1":"foo","VAR2":"bar"}\' or simple format: ' + '\'{VAR1: foo, VAR2: bar}\'' ), - metavar="JSON", + metavar="JSON|SIMPLE", )(wrapped_func) wrapped_func = click.option( "--label-selector", - callback=_parse_json_flag, - help='JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\'', - metavar="JSON", + callback=parse_dict_parameter, + help='Node label selector. Supports JSON format: \'{"cpu":"2","memory":"4Gi"}\' or simple format: \'{cpu: 2, memory: 4Gi}\'', + metavar="JSON|SIMPLE", )(wrapped_func) wrapped_func = click.option( "--volume", multiple=True, - callback=_parse_volume_param, - help="List of volume configurations. \ - Command structure: --volume name=,type=,mount_path=, \ - For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \ - For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \ - If multiple --volume flag if multiple volumes are needed.", + callback=lambda ctx, param, value: parse_complex_object_parameter(ctx, param, value, allow_multiple=True), + help="Volume configurations. Supports JSON format: " + '\'{"name":"vol1","type":"hostPath","mount_path":"/data","path":"/data"}\' ' + "or key-value format: 'name=vol1,type=hostPath,mount_path=/data,path=/data'. " + "Use multiple --volume flags for multiple volumes.", )(wrapped_func) # Add list options list_params = { - "command": "List of command arguments", - "args": "List of script arguments, e.g. '[--batch-size, 32, --learning-rate, 0.001]'", + "command": "Command arguments. Supports JSON format: '[\"python\", \"train.py\"]' or simple format: '[python, train.py]'", + "args": "Script arguments. Supports JSON format: '[\"--batch-size\", \"32\", \"--learning-rate\", \"0.001\"]' or simple format: '[--batch-size, 32, --learning-rate, 0.001]'", } for param_name, help_text in list_params.items(): wrapped_func = click.option( f"--{param_name}", - callback=_parse_list_flag, + callback=parse_list_parameter, type=str, default=None, help=help_text, - metavar="LIST", + metavar="JSON|SIMPLE", )(wrapped_func) excluded_props = set( diff --git a/test/unit_tests/cli/test_cluster_stack.py b/test/unit_tests/cli/test_cluster_stack.py index ddff5b63..b0e8af09 100644 --- a/test/unit_tests/cli/test_cluster_stack.py +++ b/test/unit_tests/cli/test_cluster_stack.py @@ -4,7 +4,8 @@ from click.testing import CliRunner from datetime import datetime import click -from sagemaker.hyperpod.cli.commands.cluster_stack import update_cluster, list_cluster_stacks, parse_status_list +from sagemaker.hyperpod.cli.commands.cluster_stack import update_cluster, list_cluster_stacks +from sagemaker.hyperpod.cli.parsers import parse_list_parameter class TestUpdateCluster: @@ -227,7 +228,7 @@ def test_list_cluster_stacks_invalid_status_format(self, mock_setup_logging, moc # Assert assert result.exit_code != 0 - assert 'Invalid list format' in result.output + assert 'Invalid format for --status' in result.output mock_hp_cluster_list.assert_not_called() @patch('sagemaker.hyperpod.cli.commands.cluster_stack.HpClusterStack.list') @@ -262,38 +263,45 @@ def test_list_cluster_stacks_single_status(self, mock_setup_logging, mock_hp_clu mock_hp_cluster_list.assert_called_once_with(region=None, stack_status_filter=['CREATE_IN_PROGRESS']) -class TestParseStatusList: - """Test cases for parse_status_list function""" +class TestParseListParameter: + """Test cases for parse_list_parameter function used by status parsing""" - def test_parse_status_list_valid_format(self): - """Test parsing valid list format.""" - result = parse_status_list(None, None, "['CREATE_COMPLETE', 'UPDATE_COMPLETE']") + def test_parse_list_parameter_valid_json_format(self): + """Test parsing valid JSON list format.""" + result = parse_list_parameter(None, None, '["CREATE_COMPLETE", "UPDATE_COMPLETE"]') assert result == ['CREATE_COMPLETE', 'UPDATE_COMPLETE'] - def test_parse_status_list_single_item(self): + def test_parse_list_parameter_simple_format(self): + """Test parsing simple list format.""" + result = parse_list_parameter(None, None, '[CREATE_COMPLETE, UPDATE_COMPLETE]') + assert result == ['CREATE_COMPLETE', 'UPDATE_COMPLETE'] + + def test_parse_list_parameter_single_item(self): """Test parsing single item list.""" - result = parse_status_list(None, None, "['CREATE_COMPLETE']") + result = parse_list_parameter(None, None, '["CREATE_COMPLETE"]') assert result == ['CREATE_COMPLETE'] - def test_parse_status_list_empty_input(self): + def test_parse_list_parameter_empty_input(self): """Test parsing empty/None input.""" - result = parse_status_list(None, None, None) + result = parse_list_parameter(None, None, None) assert result is None - result = parse_status_list(None, None, "") + result = parse_list_parameter(None, None, "") assert result is None - def test_parse_status_list_invalid_format(self): - """Test parsing invalid format raises BadParameter.""" - with pytest.raises(click.BadParameter) as exc_info: - parse_status_list(None, None, "invalid-format") - assert "Invalid list format" in str(exc_info.value) + def test_parse_list_parameter_invalid_format(self): + """Test parsing invalid format raises ParameterParsingError.""" + from sagemaker.hyperpod.cli.parsers import ParameterParsingError + with pytest.raises(ParameterParsingError) as exc_info: + parse_list_parameter(None, None, "invalid-format") + assert "Invalid format" in str(exc_info.value) - def test_parse_status_list_non_list_format(self): - """Test parsing valid syntax but non-list raises BadParameter.""" - with pytest.raises(click.BadParameter) as exc_info: - parse_status_list(None, None, "'not-a-list'") - assert "Expected list format" in str(exc_info.value) + def test_parse_list_parameter_non_list_format(self): + """Test parsing valid syntax but non-list raises ParameterParsingError.""" + from sagemaker.hyperpod.cli.parsers import ParameterParsingError + with pytest.raises(ParameterParsingError) as exc_info: + parse_list_parameter(None, None, '"not-a-list"') + assert "Expected a list for --parameter" in str(exc_info.value) @patch('sagemaker.hyperpod.cluster_management.hp_cluster_stack.importlib.resources.read_text') @@ -511,4 +519,4 @@ def test_create_cluster_stack_helper_handles_empty_resource_name_prefix(self, mo # Verify empty prefix is not modified call_args = mock_cluster_stack.call_args[1] - assert call_args['resource_name_prefix'] == '' \ No newline at end of file + assert call_args['resource_name_prefix'] == '' diff --git a/test/unit_tests/cli/test_inference_utils.py b/test/unit_tests/cli/test_inference_utils.py index 1eee54f8..57d658a8 100644 --- a/test/unit_tests/cli/test_inference_utils.py +++ b/test/unit_tests/cli/test_inference_utils.py @@ -73,7 +73,7 @@ def cmd(name, namespace, version, domain): # invalid JSON produces click error res_err = self.runner.invoke(cmd, ['--env', 'notjson']) assert res_err.exit_code == 2 - assert 'must be valid JSON' in res_err.output + assert 'Invalid format for --env' in res_err.output @patch('sagemaker.hyperpod.cli.inference_utils.load_schema_for_version') def test_type_mapping_and_defaults(self, mock_load_schema): diff --git a/test/unit_tests/cli/test_training_utils.py b/test/unit_tests/cli/test_training_utils.py index 4253f41a..8a0a189a 100644 --- a/test/unit_tests/cli/test_training_utils.py +++ b/test/unit_tests/cli/test_training_utils.py @@ -92,7 +92,7 @@ def cmd(version, debug, config): # Test invalid JSON input result = self.runner.invoke(cmd, ['--environment', 'invalid']) assert result.exit_code == 2 - assert 'must be valid JSON' in result.output + assert 'Invalid format for --environment' in result.output @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') def test_list_parameters(self, mock_get_data): @@ -404,14 +404,14 @@ def cmd(version, debug, config): '--volume', 'name=model-data,type=hostPath,mount_path,path=/host/data' ]) assert result.exit_code == 2 - assert "should be key=value" in result.output + assert "Invalid format for --volume" in result.output # Test empty volume parameter result = self.runner.invoke(cmd, [ '--volume', '' ]) assert result.exit_code == 2 - assert "Error parsing volume" in result.output + assert "Invalid format for --volume" in result.output @patch('sagemaker.hyperpod.cli.training_utils.pkgutil.get_data') def test_volume_flag_with_equals_in_value(self, mock_get_data): @@ -584,13 +584,13 @@ def cmd(version, debug, domain): { 'args': ['--job-name', 'test-job', '--environment', 'invalid-json'], 'expected_error': True, - 'error_message': "must be valid JSON" + 'error_message': "Invalid format for --environment" }, # Invalid volume format { 'args': ['--job-name', 'test-job', '--volume', 'invalid-volume-format'], 'expected_error': True, - 'error_message': "Invalid volume format" + 'error_message': "Invalid format for --volume" }, # Multiple valid volumes { diff --git a/test/unit_tests/cli/test_unified_parsers.py b/test/unit_tests/cli/test_unified_parsers.py new file mode 100644 index 00000000..525064b5 --- /dev/null +++ b/test/unit_tests/cli/test_unified_parsers.py @@ -0,0 +1,531 @@ +#!/usr/bin/env python3 +""" +Comprehensive test suite for unified parameter parsers. + +Tests all parameter types, supported formats, edge cases, and error handling +to ensure the unified parsing system works correctly across all scenarios. +""" + +import pytest +import sys +import os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../src')) + +from sagemaker.hyperpod.cli.parsers import ( + parse_list_parameter, + parse_dict_parameter, + parse_complex_object_parameter, + parse_comma_separated_list, + ParameterParsingError +) + + +class MockParam: + """Mock click parameter for testing.""" + def __init__(self, name): + self.name = name + + +class TestListParameters: + """Test suite for list parameter parsing.""" + + def setup_method(self): + """Set up test fixtures.""" + self.param = MockParam("command") + self.ctx = None + + def test_json_format_basic(self): + """Test basic JSON list format.""" + result = parse_list_parameter(self.ctx, self.param, '["python", "train.py"]') + assert result == ["python", "train.py"] + + def test_json_format_with_spaces(self): + """Test JSON format with spaces in values.""" + result = parse_list_parameter(self.ctx, self.param, '["python", "my script.py", "--arg", "value with spaces"]') + assert result == ["python", "my script.py", "--arg", "value with spaces"] + + def test_json_format_empty_list(self): + """Test empty JSON list.""" + result = parse_list_parameter(self.ctx, self.param, '[]') + assert result == [] + + def test_json_format_single_item(self): + """Test JSON list with single item.""" + result = parse_list_parameter(self.ctx, self.param, '["single"]') + assert result == ["single"] + + def test_simple_format_basic(self): + """Test basic simple list format.""" + result = parse_list_parameter(self.ctx, self.param, '[python, train.py]') + assert result == ["python", "train.py"] + + def test_simple_format_no_spaces(self): + """Test simple format without spaces.""" + result = parse_list_parameter(self.ctx, self.param, '[python,train.py,--epochs,10]') + assert result == ["python", "train.py", "--epochs", "10"] + + def test_simple_format_with_quotes(self): + """Test simple format with quoted values for spaces.""" + result = parse_list_parameter(self.ctx, self.param, '[python, "my script.py", --arg, "value with spaces"]') + assert result == ["python", "my script.py", "--arg", "value with spaces"] + + def test_simple_format_empty_list(self): + """Test empty simple list.""" + result = parse_list_parameter(self.ctx, self.param, '[]') + assert result == [] + + def test_simple_format_mixed_quotes(self): + """Test simple format with mixed quote types.""" + result = parse_list_parameter(self.ctx, self.param, "[python, 'single quote', \"double quote\"]") + assert result == ["python", "single quote", "double quote"] + + def test_none_input(self): + """Test None input returns None.""" + result = parse_list_parameter(self.ctx, self.param, None) + assert result is None + + def test_invalid_format_no_brackets(self): + """Test error for invalid format without brackets.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_list_parameter(self.ctx, self.param, 'python, train.py') + + error_msg = str(exc_info.value) + assert "Invalid format for --command" in error_msg + assert "JSON:" in error_msg + assert "Simple:" in error_msg + + def test_invalid_json_format(self): + """Test error for invalid JSON.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_list_parameter(self.ctx, self.param, '["invalid", json,]') + + error_msg = str(exc_info.value) + assert "Invalid format for --command" in error_msg + + def test_non_list_json(self): + """Test error when JSON parses but isn't a list.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_list_parameter(self.ctx, self.param, '{"not": "a list"}') + + error_msg = str(exc_info.value) + assert "Expected a list for --command, got dict" in error_msg + + +class TestDictParameters: + """Test suite for dictionary parameter parsing.""" + + def setup_method(self): + """Set up test fixtures.""" + self.param = MockParam("environment") + self.ctx = None + + def test_json_format_basic(self): + """Test basic JSON dictionary format.""" + result = parse_dict_parameter(self.ctx, self.param, '{"VAR1": "foo", "VAR2": "bar"}') + assert result == {"VAR1": "foo", "VAR2": "bar"} + + def test_json_format_with_spaces(self): + """Test JSON format with spaces in values.""" + result = parse_dict_parameter(self.ctx, self.param, '{"VAR1": "value with spaces", "VAR2": "another value"}') + assert result == {"VAR1": "value with spaces", "VAR2": "another value"} + + def test_json_format_empty_dict(self): + """Test empty JSON dictionary.""" + result = parse_dict_parameter(self.ctx, self.param, '{}') + assert result == {} + + def test_json_format_nested_quotes(self): + """Test JSON format with nested quotes.""" + result = parse_dict_parameter(self.ctx, self.param, '{"VAR1": "He said \\"hello\\"", "VAR2": "bar"}') + assert result == {"VAR1": "He said \"hello\"", "VAR2": "bar"} + + def test_simple_format_basic(self): + """Test basic simple dictionary format.""" + result = parse_dict_parameter(self.ctx, self.param, '{VAR1: foo, VAR2: bar}') + assert result == {"VAR1": "foo", "VAR2": "bar"} + + def test_simple_format_no_spaces(self): + """Test simple format without spaces.""" + result = parse_dict_parameter(self.ctx, self.param, '{VAR1:foo,VAR2:bar}') + assert result == {"VAR1": "foo", "VAR2": "bar"} + + def test_simple_format_with_quotes(self): + """Test simple format with quoted values.""" + result = parse_dict_parameter(self.ctx, self.param, '{VAR1: "value with spaces", VAR2: bar}') + assert result == {"VAR1": "value with spaces", "VAR2": "bar"} + + def test_simple_format_quoted_keys(self): + """Test simple format with quoted keys.""" + result = parse_dict_parameter(self.ctx, self.param, '{"VAR1": foo, "VAR2": "bar"}') + assert result == {"VAR1": "foo", "VAR2": "bar"} + + def test_simple_format_mixed_quotes(self): + """Test simple format with mixed quote types.""" + result = parse_dict_parameter(self.ctx, self.param, "{VAR1: 'single', VAR2: \"double\"}") + assert result == {"VAR1": "single", "VAR2": "double"} + + def test_simple_format_empty_dict(self): + """Test empty simple dictionary.""" + result = parse_dict_parameter(self.ctx, self.param, '{}') + assert result == {} + + def test_none_input(self): + """Test None input returns None.""" + result = parse_dict_parameter(self.ctx, self.param, None) + assert result is None + + def test_invalid_format_no_braces(self): + """Test error for invalid format without braces.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_dict_parameter(self.ctx, self.param, 'VAR1: foo, VAR2: bar') + + error_msg = str(exc_info.value) + assert "Invalid format for --environment" in error_msg + assert "JSON:" in error_msg + assert "Simple:" in error_msg + + def test_invalid_json_format(self): + """Test error for invalid JSON.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_dict_parameter(self.ctx, self.param, '{"VAR1": "foo", invalid}') + + error_msg = str(exc_info.value) + assert "Invalid format for --environment" in error_msg + + def test_simple_format_missing_colon(self): + """Test error for simple format missing colon.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_dict_parameter(self.ctx, self.param, '{VAR1 foo, VAR2: bar}') + + error_msg = str(exc_info.value) + assert "Invalid format for --environment" in error_msg + + def test_non_dict_json(self): + """Test error when JSON parses but isn't a dictionary.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_dict_parameter(self.ctx, self.param, '["not", "a", "dict"]') + + error_msg = str(exc_info.value) + assert "Invalid format for --environment" in error_msg + + +class TestComplexObjectParameters: + """Test suite for complex object parameter parsing.""" + + def setup_method(self): + """Set up test fixtures.""" + self.param = MockParam("volume") + self.ctx = None + + def test_json_format_single_object(self): + """Test JSON format with single object.""" + result = parse_complex_object_parameter(self.ctx, self.param, '{"name": "vol1", "type": "hostPath", "mount_path": "/data"}') + assert len(result) == 1 + assert result[0] == {"name": "vol1", "type": "hostPath", "mount_path": "/data"} + + def test_json_format_array_allow_multiple(self): + """Test JSON array format when allow_multiple=True.""" + json_array = '[{"name": "vol1", "type": "hostPath"}, {"name": "vol2", "type": "pvc"}]' + result = parse_complex_object_parameter(self.ctx, self.param, json_array, allow_multiple=True) + assert len(result) == 2 + assert result[0] == {"name": "vol1", "type": "hostPath"} + assert result[1] == {"name": "vol2", "type": "pvc"} + + def test_json_format_array_disallow_multiple(self): + """Test JSON array format error when allow_multiple=False.""" + json_array = '[{"name": "vol1", "type": "hostPath"}, {"name": "vol2", "type": "pvc"}]' + with pytest.raises(ParameterParsingError) as exc_info: + parse_complex_object_parameter(self.ctx, self.param, json_array, allow_multiple=False) + + error_msg = str(exc_info.value) + assert "does not support JSON arrays" in error_msg + + def test_key_value_format_basic(self): + """Test basic key-value format.""" + result = parse_complex_object_parameter(self.ctx, self.param, 'name=vol1,type=hostPath,mount_path=/data') + assert len(result) == 1 + assert result[0] == {"name": "vol1", "type": "hostPath", "mount_path": "/data"} + + def test_key_value_format_with_equals_in_value(self): + """Test key-value format with equals sign in value.""" + result = parse_complex_object_parameter(self.ctx, self.param, 'name=vol1,command=echo "x=y",type=hostPath') + assert len(result) == 1 + assert result[0] == {"name": "vol1", "command": "echo \"x=y\"", "type": "hostPath"} + + def test_multiple_flags_allow_multiple(self): + """Test multiple flag usage when allow_multiple=True.""" + values = [ + 'name=vol1,type=hostPath,mount_path=/data1', + 'name=vol2,type=pvc,mount_path=/data2,claim_name=my-pvc' + ] + result = parse_complex_object_parameter(self.ctx, self.param, values, allow_multiple=True) + assert len(result) == 2 + assert result[0] == {"name": "vol1", "type": "hostPath", "mount_path": "/data1"} + assert result[1] == {"name": "vol2", "type": "pvc", "mount_path": "/data2", "claim_name": "my-pvc"} + + def test_multiple_flags_disallow_multiple(self): + """Test multiple flag usage error when allow_multiple=False.""" + values = ['name=vol1,type=hostPath', 'name=vol2,type=pvc'] + with pytest.raises(ParameterParsingError) as exc_info: + parse_complex_object_parameter(self.ctx, self.param, values, allow_multiple=False) + + error_msg = str(exc_info.value) + assert "does not support multiple values" in error_msg + + def test_mixed_json_and_key_value(self): + """Test mixing JSON and key-value formats in multiple flags.""" + values = [ + '{"name": "vol1", "type": "hostPath"}', + 'name=vol2,type=pvc,claim_name=my-pvc' + ] + result = parse_complex_object_parameter(self.ctx, self.param, values, allow_multiple=True) + assert len(result) == 2 + assert result[0] == {"name": "vol1", "type": "hostPath"} + assert result[1] == {"name": "vol2", "type": "pvc", "claim_name": "my-pvc"} + + def test_none_input(self): + """Test None input returns None.""" + result = parse_complex_object_parameter(self.ctx, self.param, None) + assert result is None + + def test_empty_list_input(self): + """Test empty list input returns None.""" + result = parse_complex_object_parameter(self.ctx, self.param, []) + assert result is None + + def test_invalid_key_value_format(self): + """Test error for invalid key-value format.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_complex_object_parameter(self.ctx, self.param, 'name=vol1,invalid_pair,type=hostPath') + + error_msg = str(exc_info.value) + assert "Invalid format for --volume" in error_msg + assert "JSON object:" in error_msg + assert "Key-value:" in error_msg + + def test_json_array_invalid_item_type(self): + """Test error for JSON array with invalid item type.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_complex_object_parameter(self.ctx, self.param, '[{"name": "vol1"}, "invalid"]', allow_multiple=True) + + error_msg = str(exc_info.value) + assert "JSON array item 2 must be an object" in error_msg + + def test_empty_key_in_key_value(self): + """Test error for empty key in key-value format.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_complex_object_parameter(self.ctx, self.param, '=value,name=vol1') + + error_msg = str(exc_info.value) + assert "Invalid format for --volume" in error_msg + + +class TestCommaSeparatedList: + """Test suite for legacy comma-separated list parsing.""" + + def setup_method(self): + """Set up test fixtures.""" + self.param = MockParam("clusters") + self.ctx = None + + def test_basic_comma_separated(self): + """Test basic comma-separated format.""" + result = parse_comma_separated_list(self.ctx, self.param, 'cluster1,cluster2,cluster3') + assert result == ['cluster1', 'cluster2', 'cluster3'] + + def test_comma_separated_with_spaces(self): + """Test comma-separated format with spaces.""" + result = parse_comma_separated_list(self.ctx, self.param, 'cluster1, cluster2, cluster3') + assert result == ['cluster1', 'cluster2', 'cluster3'] + + def test_single_item(self): + """Test single item.""" + result = parse_comma_separated_list(self.ctx, self.param, 'single-cluster') + assert result == ['single-cluster'] + + def test_empty_items_filtered(self): + """Test that empty items are filtered out.""" + result = parse_comma_separated_list(self.ctx, self.param, 'cluster1,,cluster2, ,cluster3') + assert result == ['cluster1', 'cluster2', 'cluster3'] + + def test_none_input(self): + """Test None input returns None.""" + result = parse_comma_separated_list(self.ctx, self.param, None) + assert result is None + + +class TestEdgeCases: + """Test suite for edge cases and special scenarios.""" + + def setup_method(self): + """Set up test fixtures.""" + self.list_param = MockParam("args") + self.dict_param = MockParam("environment") + self.complex_param = MockParam("volume") + self.ctx = None + + def test_unicode_characters(self): + """Test handling of unicode characters.""" + # List with unicode + result = parse_list_parameter(self.ctx, self.list_param, '["café", "naïve", "résumé"]') + assert result == ["café", "naïve", "résumé"] + + # Dict with unicode + result = parse_dict_parameter(self.ctx, self.dict_param, '{"café": "naïve", "résumé": "value"}') + assert result == {"café": "naïve", "résumé": "value"} + + def test_special_characters_in_values(self): + """Test handling of special characters.""" + # JSON with special chars + result = parse_dict_parameter(self.ctx, self.dict_param, '{"KEY": "value@#$%^&*()"}') + assert result == {"KEY": "value@#$%^&*()"} + + # Simple format with special chars (quoted) + result = parse_dict_parameter(self.ctx, self.dict_param, '{KEY: "value@#$%^&*()"}') + assert result == {"KEY": "value@#$%^&*()"} + + def test_numbers_and_booleans(self): + """Test handling of numbers and booleans in simple format.""" + result = parse_dict_parameter(self.ctx, self.dict_param, '{PORT: 8080, DEBUG: true, TIMEOUT: 30.5}') + assert result == {"PORT": "8080", "DEBUG": "true", "TIMEOUT": "30.5"} + + def test_nested_quotes(self): + """Test handling of nested quotes in values.""" + result = parse_dict_parameter(self.ctx, self.dict_param, '{CMD: "echo \\"hello world\\""}') + assert result == {"CMD": "echo \"hello world\""} + + def test_path_values(self): + """Test handling of file paths.""" + result = parse_complex_object_parameter(self.ctx, self.complex_param, + 'name=vol1,type=hostPath,mount_path=/opt/ml/model,path=/home/user/data') + assert result[0]["mount_path"] == "/opt/ml/model" + assert result[0]["path"] == "/home/user/data" + + def test_empty_string_values(self): + """Test handling of empty string values.""" + result = parse_dict_parameter(self.ctx, self.dict_param, '{KEY1: "", KEY2: value}') + assert result == {"KEY1": "", "KEY2": "value"} + + def test_whitespace_handling(self): + """Test handling of various whitespace scenarios.""" + # Extra whitespace in simple list + result = parse_list_parameter(self.ctx, self.list_param, '[ python , train.py ]') + assert result == ["python", "train.py"] + + # Extra whitespace in simple dict + result = parse_dict_parameter(self.ctx, self.dict_param, '{ VAR1 : foo , VAR2 : bar }') + assert result == {"VAR1": "foo", "VAR2": "bar"} + + +class TestErrorMessageQuality: + """Test suite for error message quality and helpfulness.""" + + def setup_method(self): + """Set up test fixtures.""" + self.param = MockParam("test_param") + self.ctx = None + + def test_list_error_message_content(self): + """Test that list error messages contain helpful information.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_list_parameter(self.ctx, self.param, 'invalid format') + + error_msg = str(exc_info.value) + assert "--test_param" in error_msg + assert "JSON:" in error_msg + assert "Simple:" in error_msg + assert "[\"item1\", \"item2\", \"item3\"]" in error_msg + assert "[item1, item2, item3]" in error_msg + + def test_dict_error_message_content(self): + """Test that dict error messages contain helpful information.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_dict_parameter(self.ctx, self.param, 'invalid format') + + error_msg = str(exc_info.value) + assert "--test_param" in error_msg + assert "JSON:" in error_msg + assert "Simple:" in error_msg + assert "{\"key\": \"value\", \"key2\": \"value2\"}" in error_msg + assert "{key: value, key2: \"value with spaces\"}" in error_msg + + def test_complex_object_error_message_content(self): + """Test that complex object error messages contain helpful information.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_complex_object_parameter(self.ctx, self.param, 'invalid format', allow_multiple=True) + + error_msg = str(exc_info.value) + assert "--test_param" in error_msg + assert "JSON object:" in error_msg + assert "JSON array:" in error_msg # Should show array format when allow_multiple=True + assert "Key-value:" in error_msg + + def test_complex_object_single_mode_error(self): + """Test error message when allow_multiple=False.""" + with pytest.raises(ParameterParsingError) as exc_info: + parse_complex_object_parameter(self.ctx, self.param, 'invalid format', allow_multiple=False) + + error_msg = str(exc_info.value) + assert "JSON array:" not in error_msg # Should not show array format when allow_multiple=False + + def test_parameter_name_in_error_messages(self): + """Test that parameter names are correctly included in error messages.""" + test_cases = [ + (MockParam("environment"), parse_dict_parameter, "invalid"), + (MockParam("command"), parse_list_parameter, "invalid"), + (MockParam("volume"), parse_complex_object_parameter, "invalid"), + ] + + for param, parser_func, invalid_input in test_cases: + with pytest.raises(ParameterParsingError) as exc_info: + if parser_func == parse_complex_object_parameter: + parser_func(self.ctx, param, invalid_input) + else: + parser_func(self.ctx, param, invalid_input) + + error_msg = str(exc_info.value) + assert f"--{param.name}" in error_msg + + +class TestBackwardCompatibility: + """Test suite for backward compatibility with existing formats.""" + + def setup_method(self): + """Set up test fixtures.""" + self.ctx = None + + def test_existing_json_formats_still_work(self): + """Test that all existing JSON formats continue to work.""" + param = MockParam("test") + + # Lists + result = parse_list_parameter(self.ctx, param, '["python", "train.py"]') + assert result == ["python", "train.py"] + + # Dicts + result = parse_dict_parameter(self.ctx, param, '{"VAR1":"foo","VAR2":"bar"}') + assert result == {"VAR1": "foo", "VAR2": "bar"} + + # Complex objects + result = parse_complex_object_parameter(self.ctx, param, '{"name":"vol1","type":"hostPath"}') + assert result == [{"name": "vol1", "type": "hostPath"}] + + def test_existing_key_value_formats_still_work(self): + """Test that existing key-value formats continue to work.""" + param = MockParam("volume") + + result = parse_complex_object_parameter(self.ctx, param, 'name=model-data,type=hostPath,mount_path=/data,path=/data') + expected = {"name": "model-data", "type": "hostPath", "mount_path": "/data", "path": "/data"} + assert result == [expected] + + def test_existing_comma_separated_still_works(self): + """Test that comma-separated format still works.""" + param = MockParam("clusters") + + result = parse_comma_separated_list(self.ctx, param, "cluster1,cluster2,cluster3") + assert result == ["cluster1", "cluster2", "cluster3"] + + +if __name__ == "__main__": + # Run the tests if executed directly + pytest.main([__file__, "-v"]) diff --git a/test/unit_tests/test_cluster.py b/test/unit_tests/test_cluster.py index 37d52ce3..66b2f0b5 100644 --- a/test/unit_tests/test_cluster.py +++ b/test/unit_tests/test_cluster.py @@ -793,7 +793,7 @@ def test_list_clusters_with_clusters_list( result = self.runner.invoke( list_cluster, - ["--clusters", "cluster-3"], + ["--clusters", "[cluster-3]"], ) self.assertEqual(result.exit_code, 0) self.assertNotIn("cluster-1", result.output)