Skip to content

Commit c51986e

Browse files
mollyheamazonnargokul
authored andcommitted
Fix merge conflict issues, update cluster template to add default in model.py (#186)
* Fix merge conflict issues, update cluster template to add default in model.py * Update model.py to remove default for network related params
1 parent f747fbe commit c51986e

File tree

4 files changed

+196
-252
lines changed

4 files changed

+196
-252
lines changed

hyperpod-cluster-stack-template/hyperpod_cluster_stack_template/v1_0/model.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,51 @@
22
from typing import Optional, Literal, List, Any
33

44
class ClusterStackBase(BaseModel):
5-
stage: Optional[str] = Field(None, description="Deployment stage (gamma, prod)")
5+
stage: Optional[str] = Field("gamma", description="Deployment stage (gamma, prod)")
66
enable_hp_inference_feature: Optional[str] = Field(None, description="Feature flag for enabling HP inference")
77
custom_bucket_name: Optional[str] = Field(None, description="Custom S3 bucket name for templates")
8-
resource_name_prefix: Optional[str] = Field(None, description="Prefix to be used for all resources")
8+
resource_name_prefix: Optional[str] = Field("sagemaker-hyperpod-eks", description="Prefix to be used for all resources")
99
vpc_cidr: Optional[str] = Field(None, description="The IP range (CIDR notation) for the VPC")
1010
availability_zone_ids: Optional[str] = Field(None, description="List of AZs to deploy subnets in")
1111
vpc_id: Optional[str] = Field(None, description="The ID of the VPC")
1212
nat_gateway_ids: Optional[str] = Field(None, description="Comma-separated list of NAT Gateway IDs")
13-
security_group_id: Optional[str] = Field(None, description="The ID of the security group")
14-
kubernetes_version: Optional[str] = Field(None, description="The Kubernetes version")
13+
security_group_id: Optional[str] = Field("", description="The ID of the security group")
14+
kubernetes_version: Optional[str] = Field("1.31", description="The Kubernetes version")
1515
node_provisioning_mode: Optional[str] = Field(None, description="The node provisioning mode")
16-
eks_cluster_name: Optional[str] = Field(None, description="The name of the EKS cluster")
16+
eks_cluster_name: Optional[str] = Field("eks", description="The name of the EKS cluster")
1717
eks_private_subnet_ids: Optional[str] = Field(None, description="Comma-delimited list of private subnet IDs")
1818
security_group_ids: Optional[str] = Field(None, description="The Id of your cluster security group")
1919
private_route_table_ids: Optional[str] = Field(None, description="Comma-separated list of private route table IDs")
2020
s3_bucket_name: Optional[str] = Field(None, description="The name of the S3 bucket")
21-
github_raw_url: Optional[str] = Field(None, description="The raw GitHub URL for the lifecycle script")
22-
helm_repo_url: Optional[str] = Field(None, description="The URL of the Helm repo")
23-
helm_repo_path: Optional[str] = Field(None, description="The path to the HyperPod Helm chart")
24-
helm_operators: Optional[str] = Field(None, description="The configuration of HyperPod Helm chart")
25-
namespace: Optional[str] = Field(None, description="The namespace to deploy the HyperPod Helm chart")
26-
helm_release: Optional[str] = Field(None, description="The name of the Helm release")
27-
hyperpod_cluster_name: Optional[str] = Field(None, description="Name of SageMaker HyperPod Cluster")
28-
node_recovery: Optional[str] = Field(None, description="Instance recovery setting")
29-
sagemaker_iam_role_name: Optional[str] = Field(None, description="The name of the IAM role")
21+
github_raw_url: Optional[str] = Field("https://raw.githubusercontent.com/aws-samples/awsome-distributed-training/refs/heads/main/1.architectures/7.sagemaker-hyperpod-eks/LifecycleScripts/base-config/on_create.sh", description="The raw GitHub URL for the lifecycle script")
22+
helm_repo_url: Optional[str] = Field("https://github.com/aws/sagemaker-hyperpod-cli.git", description="The URL of the Helm repo")
23+
helm_repo_path: Optional[str] = Field("helm_chart/HyperPodHelmChart", description="The path to the HyperPod Helm chart")
24+
helm_operators: Optional[str] = Field("", description="The configuration of HyperPod Helm chart")
25+
namespace: Optional[str] = Field("kube-system", description="The namespace to deploy the HyperPod Helm chart")
26+
helm_release: Optional[str] = Field("hyperpod-dependencies", description="The name of the Helm release")
27+
hyperpod_cluster_name: Optional[str] = Field("hp-cluster", description="Name of SageMaker HyperPod Cluster")
28+
node_recovery: Optional[str] = Field("Automatic", description="Instance recovery setting")
29+
sagemaker_iam_role_name: Optional[str] = Field("iam-role", description="The name of the IAM role")
3030
private_subnet_ids: Optional[str] = Field(None, description="Comma-separated list of private subnet IDs")
31-
on_create_path: Optional[str] = Field(None, description="The file name of lifecycle script")
32-
instance_group_settings: Optional[str] = Field(None, description="JSON array string containing instance group configurations")
33-
rig_settings: Optional[str] = Field(None, description="JSON array string containing restricted instance group configurations")
31+
on_create_path: Optional[str] = Field("sagemaker-hyperpod-eks-bucket", description="The file name of lifecycle script")
32+
instance_group_settings: Optional[str] = Field('[{"InstanceCount":1,"InstanceGroupName":"ig-1","InstanceStorageConfigs":[],"InstanceType":"ml.t3.medium","ThreadsPerCore":1},{"InstanceCount":1,"InstanceGroupName":"ig-2","InstanceStorageConfigs":[],"InstanceType":"ml.t3.medium","ThreadsPerCore":1}]', description="JSON array string containing instance group configurations")
33+
rig_settings: Optional[str] = Field("", description="JSON array string containing restricted instance group configurations")
3434
rig_s3_bucket_name: Optional[str] = Field(None, description="The name of the S3 bucket for RIG resources")
35-
tags: Optional[str] = Field(None, description="Custom tags for managing the SageMaker HyperPod cluster")
36-
fsx_subnet_id: Optional[str] = Field(None, description="The subnet id for FSx")
35+
tags: Optional[str] = Field("", description="Custom tags for managing the SageMaker HyperPod cluster")
36+
fsx_subnet_id: Optional[str] = Field("", description="The subnet id for FSx")
3737
fsx_availability_zone_id: Optional[str] = Field(None, description="The availability zone for FSx")
38-
per_unit_storage_throughput: Optional[int] = Field(None, description="Per unit storage throughput")
38+
per_unit_storage_throughput: Optional[int] = Field(250, description="Per unit storage throughput")
3939
data_compression_type: Optional[str] = Field(None, description="Data compression type")
40-
file_system_type_version: Optional[float] = Field(None, description="File system type version")
41-
storage_capacity: Optional[int] = Field(None, description="Storage capacity in GiB")
42-
fsx_file_system_id: Optional[str] = Field(None, description="Existing FSx file system ID")
43-
create_vpc_stack: Optional[bool] = Field(None, description="Boolean to Create VPC Stack")
44-
create_security_group_stack: Optional[bool] = Field(None, description="Boolean to Create Security Group Stack")
45-
create_eks_cluster_stack: Optional[bool] = Field(None, description="Boolean to Create EKS Cluster Stack")
46-
create_s3_bucket_stack: Optional[bool] = Field(None, description="Boolean to Create S3 Bucket Stack")
47-
create_s3_endpoint_stack: Optional[bool] = Field(None, description="Boolean to Create S3 Endpoint Stack")
48-
create_life_cycle_script_stack: Optional[bool] = Field(None, description="Boolean to Create Life Cycle Script Stack")
49-
create_sagemaker_iam_role_stack: Optional[bool] = Field(None, description="Boolean to Create SageMaker IAM Role Stack")
50-
create_helm_chart_stack: Optional[bool] = Field(None, description="Boolean to Create Helm Chart Stack")
51-
create_hyperpod_cluster_stack: Optional[bool] = Field(None, description="Boolean to Create HyperPod Cluster Stack")
52-
create_fsx_stack: Optional[bool] = Field(None, description="Boolean to Create FSx Stack")
40+
file_system_type_version: Optional[float] = Field(2.15, description="File system type version")
41+
storage_capacity: Optional[int] = Field(1200, description="Storage capacity in GiB")
42+
fsx_file_system_id: Optional[str] = Field("", description="Existing FSx file system ID")
43+
create_vpc_stack: Optional[bool] = Field(True, description="Boolean to Create VPC Stack")
44+
create_security_group_stack: Optional[bool] = Field(True, description="Boolean to Create Security Group Stack")
45+
create_eks_cluster_stack: Optional[bool] = Field(True, description="Boolean to Create EKS Cluster Stack")
46+
create_s3_bucket_stack: Optional[bool] = Field(True, description="Boolean to Create S3 Bucket Stack")
47+
create_s3_endpoint_stack: Optional[bool] = Field(True, description="Boolean to Create S3 Endpoint Stack")
48+
create_life_cycle_script_stack: Optional[bool] = Field(True, description="Boolean to Create Life Cycle Script Stack")
49+
create_sagemaker_iam_role_stack: Optional[bool] = Field(True, description="Boolean to Create SageMaker IAM Role Stack")
50+
create_helm_chart_stack: Optional[bool] = Field(True, description="Boolean to Create Helm Chart Stack")
51+
create_hyperpod_cluster_stack: Optional[bool] = Field(True, description="Boolean to Create HyperPod Cluster Stack")
52+
create_fsx_stack: Optional[bool] = Field(True, description="Boolean to Create FSx Stack")

src/sagemaker/hyperpod/cli/commands/init.py

Lines changed: 4 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
generate_click_command,
2222
save_config_yaml,
2323
TEMPLATES,
24+
load_config,
2425
load_config_and_validate,
2526
validate_config_against_model,
2627
filter_validation_errors_for_user_input,
@@ -152,7 +153,7 @@ def reset():
152153
dir_path = Path(".").resolve()
153154

154155
# 1) Load and validate config
155-
data, template, version = load_config_and_validate(dir_path)
156+
data, template, version = load_config(dir_path)
156157

157158
# 2) Build config with default values from schema
158159
full_cfg, comment_map = build_config_from_schema(template, version)
@@ -199,7 +200,7 @@ def configure(ctx, model_config):
199200
"""
200201
# 1) Load existing config without validation
201202
dir_path = Path(".").resolve()
202-
data, template, version = load_config_and_validate(dir_path)
203+
data, template, version = load_config(dir_path)
203204

204205
# 2) Determine which fields the user actually provided
205206
# Use Click's parameter source tracking to identify command-line provided parameters
@@ -259,55 +260,6 @@ def validate():
259260
Validate this directory's config.yaml against the appropriate schema.
260261
"""
261262
dir_path = Path(".").resolve()
262-
data, template, version = load_config_and_validate(dir_path)
263-
264-
info = TEMPLATES[template]
265-
266-
if info["schema_type"] == CFN:
267-
# CFN validation using HpClusterStack
268-
payload = {}
269-
for k, v in data.items():
270-
if k not in ("template", "namespace") and v is not None:
271-
# Convert lists to JSON strings for CFN parameters
272-
if isinstance(v, list):
273-
payload[k] = json.dumps(v)
274-
else:
275-
payload[k] = str(v)
276-
277-
try:
278-
HpClusterStack(**payload)
279-
click.secho("✔️ config.yaml is valid!", fg="green")
280-
except ValidationError as e:
281-
click.secho("❌ Validation errors:", fg="red")
282-
for err in e.errors():
283-
loc = ".".join(str(x) for x in err["loc"])
284-
msg = err["msg"]
285-
click.echo(f" – {loc}: {msg}")
286-
sys.exit(1)
287-
else:
288-
# CRD validation using schema registry
289-
registry = info["registry"]
290-
model = registry.get(version)
291-
if model is None:
292-
click.secho(f"❌ Unsupported schema version: {version}", fg="red")
293-
sys.exit(1)
294-
295-
payload = {
296-
k: v
297-
for k, v in data.items()
298-
if k not in ("template", "namespace")
299-
}
300-
301-
try:
302-
model(**payload)
303-
click.secho("✔️ config.yaml is valid!", fg="green")
304-
except ValidationError as e:
305-
click.secho("❌ Validation errors:", fg="red")
306-
for err in e.errors():
307-
loc = ".".join(str(x) for x in err["loc"])
308-
msg = err["msg"]
309-
click.echo(f" – {loc}: {msg}")
310-
sys.exit(1)
311263
load_config_and_validate(dir_path)
312264

313265

@@ -360,45 +312,6 @@ def submit(region):
360312
click.secho(f"❌ Missing config.yaml or {jinja_file.name}. Run `hyp init` first.", fg="red")
361313
sys.exit(1)
362314

363-
# 4) Validate config based on schema type
364-
if schema_type == CFN:
365-
# For CFN templates, use HpClusterStack validation
366-
from sagemaker.hyperpod.cluster_management.hp_cluster_stack import HpClusterStack
367-
import json
368-
payload = {}
369-
for k, v in data.items():
370-
if k not in ('template', 'namespace') and v is not None:
371-
# Convert lists to JSON strings, everything else to string
372-
if isinstance(v, list):
373-
payload[k] = json.dumps(v)
374-
else:
375-
payload[k] = str(v)
376-
try:
377-
HpClusterStack(**payload)
378-
except ValidationError as e:
379-
click.secho("❌ HpClusterStack Validation errors:", fg="red")
380-
for err in e.errors():
381-
loc = '.'.join(str(x) for x in err['loc'])
382-
msg = err['msg']
383-
click.echo(f" – {loc}: {msg}")
384-
sys.exit(1)
385-
else:
386-
# For CRD templates, use registry validation
387-
registry = info["registry"]
388-
model = registry.get(version)
389-
if model is None:
390-
click.secho(f"❌ Unsupported schema version: {version}", fg="red")
391-
sys.exit(1)
392-
payload = {k: v for k, v in data.items() if k not in ('template', 'namespace')}
393-
try:
394-
model(**payload)
395-
except ValidationError as e:
396-
click.secho("❌ Validation errors:", fg="red")
397-
for err in e.errors():
398-
loc = '.'.join(str(x) for x in err['loc'])
399-
msg = err['msg']
400-
click.echo(f" – {loc}: {msg}")
401-
sys.exit(1)
402315
# 4) Validate config using consolidated function
403316
validation_errors = validate_config_against_model(data, template, version)
404317
is_valid = display_validation_results(
@@ -463,7 +376,7 @@ def submit(region):
463376
region=region)
464377
else:
465378
dir_path = Path(".").resolve()
466-
data, template, version = load_config_and_validate(dir_path)
379+
data, template, version = load_config(dir_path)
467380
namespace = data.get("namespace", "default")
468381
registry = TEMPLATES[template]["registry"]
469382
model = registry.get(version)

0 commit comments

Comments
 (0)