Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/sagemaker/hyperpod/cli/commands/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from kubernetes import client
from ratelimit import limits, sleep_and_retry
from tabulate import tabulate
from sagemaker.hyperpod.cli.parsers import parse_list_parameter

from sagemaker.hyperpod.cli.clients.kubernetes_client import (
KubernetesClient,
Expand Down Expand Up @@ -97,9 +98,10 @@
)
@click.option(
"--clusters",
callback=parse_list_parameter,
type=click.STRING,
required=False,
help="Optional. A list of HyperPod cluster names that users want to check the capacity for. This is useful for users who know some of their most commonly used clusters and want to check the capacity status of the clusters in the AWS account.",
help="Optional. List of HyperPod cluster names to check capacity for. Supports JSON format: '[\"cluster1\", \"cluster2\"]' or simple format: '[cluster1, cluster2]'",
)
@click.option(
"--debug",
Expand Down Expand Up @@ -183,7 +185,7 @@ def list_cluster(
sys.exit(1)

if clusters:
cluster_names = clusters.split(",")
cluster_names = clusters
else:
try:
cluster_names = _get_hyperpod_clusters(sm_client)
Expand Down
24 changes: 4 additions & 20 deletions src/sagemaker/hyperpod/cli/commands/cluster_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import os
from typing import Optional
from sagemaker.hyperpod.cli.parsers import parse_list_parameter

from sagemaker_core.main.resources import Cluster
from sagemaker_core.main.shapes import ClusterInstanceGroupSpecification
Expand All @@ -22,23 +23,7 @@
logger = logging.getLogger(__name__)


def parse_status_list(ctx, param, value):
    """Click callback that parses a Python-literal list string into a list.

    Accepts input such as "['CREATE_COMPLETE', 'UPDATE_COMPLETE']".

    Args:
        ctx: Click context (unused).
        param: Click parameter (unused).
        value: Raw option value. Falsy values yield None; non-string values
            are returned unchanged.

    Returns:
        The parsed list, the untouched non-string value, or None.

    Raises:
        click.BadParameter: If the string is not a valid Python literal, or
            the literal is not a list.
    """
    if not value:
        return None

    # Handle a direct (already parsed) list: nothing to do.
    if not isinstance(value, str):
        return value

    try:
        # Parse string like "['item1', 'item2']"
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError, RecursionError, MemoryError) as e:
        # literal_eval is documented to raise any of these on malformed input
        # (e.g. "{[1]: 2}" raises TypeError); report all of them as a usage
        # error instead of letting a traceback reach the CLI user.
        raise click.BadParameter(f"Invalid list format. Use: \"['STATUS1', 'STATUS2']\". Error: {e}") from e

    if not isinstance(parsed, list):
        raise click.BadParameter(f"Expected list format, got: {type(parsed).__name__}")
    return parsed
# --status values are parsed by the shared parse_list_parameter callback (imported above), so this module needs no parser of its own.


@click.command("cluster-stack")
Expand Down Expand Up @@ -223,8 +208,8 @@ def describe_cluster_stack(stack_name: str, debug: bool, region: str) -> None:
@click.option("--region", help="AWS region")
@click.option("--debug", is_flag=True, help="Enable debug logging")
@click.option("--status",
callback=parse_status_list,
help="Filter by stack status. Format: \"['CREATE_COMPLETE', 'UPDATE_COMPLETE']\"")
callback=parse_list_parameter,
help="Filter by stack status. Supports JSON format: '[\"CREATE_COMPLETE\", \"UPDATE_COMPLETE\"]' or simple format: '[CREATE_COMPLETE, UPDATE_COMPLETE]'")
@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_cluster_stack_cli")
def list_cluster_stacks(region, debug, status):
"""List all HyperPod cluster stacks.
Expand Down Expand Up @@ -376,4 +361,3 @@ def update_cluster(

logger.info("Cluster has been updated")
click.secho(f"Cluster {cluster_name} has been updated")

24 changes: 9 additions & 15 deletions src/sagemaker/hyperpod/cli/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Callable, Optional, Mapping, Type
import sys
from sagemaker.hyperpod.cli.common_utils import extract_version_from_args, get_latest_version, load_schema_for_version
from sagemaker.hyperpod.cli.parsers import parse_dict_parameter


def generate_click_command(
Expand All @@ -18,14 +19,7 @@ def generate_click_command(
version = extract_version_from_args(registry, schema_pkg, default_version)

def decorator(func: Callable) -> Callable:
# Parser for the single JSON‐dict env var flag
def _parse_json_flag(ctx, param, value):
    """Click callback: decode a JSON option value.

    Returns None when the option was not supplied, otherwise whatever
    ``json.loads`` produces for the given text. Invalid JSON is reported
    as a ``click.BadParameter`` naming the offending parameter.
    """
    if value is not None:
        try:
            decoded = json.loads(value)
        except json.JSONDecodeError as e:
            raise click.BadParameter(f"{param.name!r} must be valid JSON: {e}")
        return decoded
    return None
# Dict-valued flags below are parsed by the shared parse_dict_parameter callback, so no local JSON parser is defined here.

# 1) the wrapper click actually invokes
def wrapped_func(*args, **kwargs):
Expand All @@ -46,21 +40,21 @@ def wrapped_func(*args, **kwargs):
props = schema.get("properties", {})

json_flags = {
"env": ("JSON object of environment variables, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''),
"dimensions": ("JSON object of dimensions, e.g. " '\'{"VAR1":"foo","VAR2":"bar"}\''),
"resources_limits": ('JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\''),
"resources_requests": ('JSON object of resource requests, e.g. \'{"cpu":"1","memory":"2Gi"}\''),
"env": "Environment variables. Supports JSON format: '{\"VAR1\":\"foo\",\"VAR2\":\"bar\"}' or simple format: '{VAR1: foo, VAR2: bar}'",
"dimensions": "Dimensions. Supports JSON format: '{\"VAR1\":\"foo\",\"VAR2\":\"bar\"}' or simple format: '{VAR1: foo, VAR2: bar}'",
"resources_limits": "Resource limits. Supports JSON format: '{\"cpu\":\"2\",\"memory\":\"4Gi\"}' or simple format: '{cpu: 2, memory: 4Gi}'",
"resources_requests": "Resource requests. Supports JSON format: '{\"cpu\":\"1\",\"memory\":\"2Gi\"}' or simple format: '{cpu: 1, memory: 2Gi}'",
}

for flag_name, help_text in json_flags.items():
if flag_name in props:
wrapped_func = click.option(
f"--{flag_name.replace('_', '-')}",
callback=_parse_json_flag,
callback=parse_dict_parameter,
type=str,
default=None,
help=help_text,
metavar="JSON",
metavar="JSON|SIMPLE",
)(wrapped_func)

# 3) auto-inject all schema.json fields
Expand Down Expand Up @@ -99,4 +93,4 @@ def wrapped_func(*args, **kwargs):

return wrapped_func

return decorator
return decorator
72 changes: 11 additions & 61 deletions src/sagemaker/hyperpod/cli/init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
CFN
)
from sagemaker.hyperpod.cluster_management.hp_cluster_stack import HpClusterStack
from sagemaker.hyperpod.cli.parsers import parse_dict_parameter, parse_complex_object_parameter

log = logging.getLogger()

Expand Down Expand Up @@ -207,58 +208,7 @@ def decorator(func: Callable) -> Callable:
# If template can't be fetched, use empty dict
pass

# JSON flag parser
def _parse_json_flag(ctx, param, value):
    """Click callback: decode JSON, with a fallback for unquoted lists.

    Args:
        ctx: Click context (unused).
        param: Click parameter; its name is used in the error message.
        value: Raw option text, or None when the option was not supplied.

    Returns:
        None if ``value`` is None; otherwise the decoded JSON value, or, when
        the text is not valid JSON but is wrapped in square brackets, a list
        of the comma-separated items with surrounding whitespace and quotes
        removed.

    Raises:
        click.BadParameter: If the text is neither valid JSON nor a
            bracketed list.
    """
    if value is None:
        return None
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        # Try to fix unquoted list items: [python, train.py] -> ["python", "train.py"]
        stripped = value.strip()
        if stripped.startswith('[') and stripped.endswith(']'):
            # Remove brackets and split by comma. These are pure string
            # operations, so no exception guard is needed around them.
            return [item.strip().strip('"').strip("'") for item in stripped[1:-1].split(',')]
        raise click.BadParameter(f"{param.name!r} must be valid JSON or a list like [item1, item2]")


# Volume flag parser
def _parse_volume_flag(ctx, param, value):
    """Click callback: build VolumeConfig objects from key=value strings.

    Each volume string is a comma-separated list of ``key=value`` pairs, e.g.
    ``name=model-data,type=hostPath,mount_path=/data,path=/data``. A falsy
    value yields None. Any failure constructing a ``VolumeConfig`` is reported
    as ``click.BadParameter``.
    """
    if not value:
        return None

    # A lone string is handled as a one-element sequence of volume specs.
    specs = value if isinstance(value, (list, tuple)) else [value]

    from hyperpod_pytorch_job_template.v1_0.model import VolumeConfig

    truthy_words = ('true', '1', 'yes', 'on')
    parsed_volumes = []

    for spec in specs:
        fields = {}
        for segment in spec.split(','):
            raw_key, separator, raw_val = segment.partition('=')
            if not separator:
                # Segments without '=' carry no key/value and are ignored.
                continue
            field_name = raw_key.strip()
            field_text = raw_val.strip()
            if field_name == 'read_only':
                # read_only is the one field stored as a boolean.
                fields[field_name] = field_text.lower() in truthy_words
            else:
                fields[field_name] = field_text

        try:
            volume = VolumeConfig(**fields)
        except Exception as e:
            raise click.BadParameter(f"Invalid volume configuration '{spec}': {e}")
        parsed_volumes.append(volume)

    return parsed_volumes
# JSON/dict flags and volume flags below are parsed by the shared parse_dict_parameter and parse_complex_object_parameter callbacks, so no local parsers are defined here.

@functools.wraps(func)
def wrapped(*args, **kwargs):
Expand Down Expand Up @@ -318,9 +268,9 @@ def wrapped(*args, **kwargs):
if flag_name in union_props:
wrapped = click.option(
f"--{flag}",
callback=_parse_json_flag,
metavar="JSON",
help=f"JSON object for {flag.replace('-', ' ')}",
callback=parse_dict_parameter,
metavar="JSON|SIMPLE",
help=f"{flag.replace('-', ' ').title()}. Supports JSON format or simple format",
)(wrapped)


Expand Down Expand Up @@ -366,8 +316,8 @@ def wrapped(*args, **kwargs):
wrapped = click.option(
f"--{name.replace('_','-')}",
multiple=True,
callback=_parse_volume_flag,
help=help_text,
callback=lambda ctx, param, value: parse_complex_object_parameter(ctx, param, value, allow_multiple=True),
help="Volume configurations. Supports JSON format or key-value format with multiple --volume flags",
)(wrapped)
else:
wrapped = click.option(
Expand All @@ -393,9 +343,9 @@ def wrapped(*args, **kwargs):
if cfn_param_name == 'Tags':
wrapped = click.option(
f"--{pascal_to_kebab(cfn_param_name)}",
callback=_parse_json_flag,
metavar="JSON",
help=cfn_param_details.get('Description', ''),
callback=parse_dict_parameter,
metavar="JSON|SIMPLE",
help=cfn_param_details.get('Description', '') + " Supports JSON format or simple format",
)(wrapped)
else:
cfn_default = cfn_param_details.get('Default')
Expand Down Expand Up @@ -946,4 +896,4 @@ def pascal_to_kebab(pascal_str):
if char.isupper() and i > 0:
result.append('-')
result.append(char.lower())
return ''.join(result)
return ''.join(result)
Loading
Loading