Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, ConfigDict, Field
from typing import Optional, List, Dict, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing import Optional, List, Dict, Union, Literal
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
Containers,
ReplicaSpec,
Expand All @@ -8,9 +8,42 @@
Spec,
Template,
Metadata,
Volumes,
HostPath,
PersistentVolumeClaim
)


class VolumeConfig(BaseModel):
    """Configuration for a single volume mounted into the job's container.

    Two volume types are supported:

    * ``hostPath`` -- requires ``path``.
    * ``pvc`` -- requires ``claim_name``; ``read_only`` is optional.
    """

    name: str = Field(..., description="Volume name")
    type: Literal['hostPath', 'pvc'] = Field(..., description="Volume type")
    mount_path: str = Field(..., description="Mount path in container")
    path: Optional[str] = Field(None, description="Host path (required for hostPath volumes)")
    claim_name: Optional[str] = Field(None, description="PVC claim name (required for pvc volumes)")
    # Kept as the strings 'true'/'false' rather than a bool.
    read_only: Optional[Literal['true', 'false']] = Field(None, description="Read-only flag for pvc volumes")

    @field_validator('mount_path', 'path')
    @classmethod
    def paths_must_be_absolute(cls, v):
        """Validate that a supplied path is absolute (starts with ``/``).

        Args:
            v: The ``mount_path`` or ``path`` value; ``None`` when the
                optional ``path`` was explicitly passed as ``None``.

        Returns:
            ``v`` unchanged when it is ``None`` or an absolute path.

        Raises:
            ValueError: If ``v`` is a string that does not start with ``/``.
                This includes the empty string.
        """
        # Compare against None instead of relying on truthiness: an empty
        # string is falsy and would otherwise slip through as a valid path.
        if v is not None and not v.startswith('/'):
            raise ValueError('Path must be absolute (start with /)')
        return v

    @model_validator(mode='after')
    def validate_type_specific_fields(self):
        """Validate that the field required by the chosen volume type is set.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If a ``hostPath`` volume has no ``path`` or a ``pvc``
                volume has no ``claim_name``.
        """
        if self.type == 'hostPath':
            if not self.path:
                raise ValueError('hostPath volumes require path field')
        elif self.type == 'pvc':
            if not self.claim_name:
                raise ValueError('PVC volumes require claim_name field')

        return self


class PyTorchJobConfig(BaseModel):
model_config = ConfigDict(extra="forbid")

Expand Down Expand Up @@ -60,22 +93,41 @@ class PyTorchJobConfig(BaseModel):
max_retry: Optional[int] = Field(
default=None, alias="max_retry", description="Maximum number of job retries"
)
volumes: Optional[List[str]] = Field(
default=None, description="List of volumes to mount"
)
persistent_volume_claims: Optional[List[str]] = Field(
default=None,
alias="persistent_volume_claims",
description="List of persistent volume claims",
volume: Optional[List[VolumeConfig]] = Field(
default=None, description="List of volume configurations. \
Command structure: --volume name=<volume_name>,type=<volume_type>,mount_path=<mount_path>,<type-specific options> \
For hostPath: --volume name=model-data,type=hostPath,mount_path=/data,path=/data \
For persistentVolumeClaim: --volume name=training-output,type=pvc,mount_path=/mnt/output,claim_name=training-output-pvc,read_only=false \
If multiple --volume flag if multiple volumes are needed \
"
)
service_account_name: Optional[str] = Field(
default=None, alias="service_account_name", description="Service account name"
)

@field_validator('volume')
@classmethod
def validate_no_duplicates(cls, v):
    """Reject volume lists that reuse a volume name or a mount path.

    Args:
        v: The parsed ``volume`` value: a list of volume configurations,
            or ``None``/an empty list when no volumes were given.

    Returns:
        ``v`` unchanged when it is empty or contains no duplicates.

    Raises:
        ValueError: If two entries share the same ``name`` or the same
            ``mount_path``.
    """
    # Nothing to check for None or an empty list.
    if not v:
        return v

    # A set drops repeats, so a size mismatch means at least one duplicate.
    names = [vol.name for vol in v]
    if len(names) != len(set(names)):
        raise ValueError("Duplicate volume names found")

    mount_paths = [vol.mount_path for vol in v]
    if len(mount_paths) != len(set(mount_paths)):
        raise ValueError("Duplicate mount paths found")

    return v

def to_domain(self) -> Dict:
"""
Convert flat config to domain model (HyperPodPytorchJobSpec)
"""

# Create container with required fields
container_kwargs = {
"name": "container-name",
Expand All @@ -97,17 +149,42 @@ def to_domain(self) -> Dict:
container_kwargs["env"] = [
{"name": k, "value": v} for k, v in self.environment.items()
]
if self.volumes is not None:
container_kwargs["volume_mounts"] = [
{"name": v, "mount_path": f"/mnt/{v}"} for v in self.volumes
]

if self.volume is not None:
volume_mounts = []
for i, vol in enumerate(self.volume):
volume_mount = {"name": vol.name, "mount_path": vol.mount_path}
volume_mounts.append(volume_mount)

container_kwargs["volume_mounts"] = volume_mounts


# Create container object
container = Containers(**container_kwargs)
try:
container = Containers(**container_kwargs)
except Exception as e:
raise

# Create pod spec kwargs
spec_kwargs = {"containers": list([container])}

# Add volumes to pod spec if present
if self.volume is not None:
volumes = []
for i, vol in enumerate(self.volume):
if vol.type == "hostPath":
host_path = HostPath(path=vol.path)
volume_obj = Volumes(name=vol.name, host_path=host_path)
elif vol.type == "pvc":
pvc_config = PersistentVolumeClaim(
claim_name=vol.claim_name,
read_only=vol.read_only == "true" if vol.read_only else False
)
volume_obj = Volumes(name=vol.name, persistent_volume_claim=pvc_config)
volumes.append(volume_obj)

spec_kwargs["volumes"] = volumes

# Add node selector if any selector fields are present
node_selector = {}
if self.instance_type is not None:
Expand Down Expand Up @@ -175,5 +252,4 @@ def to_domain(self) -> Dict:
"namespace": self.namespace,
"spec": job_kwargs,
}

return result
Loading
Loading