
Commit 281e6d6

Aditi2424 and adishaa authored

Add integ test for training CLI and SDK (#100)

* Add integ test for training cli
* Add integ test for training sdk
* relax pydantic version
* fix pydantic version
* return latest cluster and fix set cluster context test

---------

Co-authored-by: adishaa <[email protected]>

1 parent f27a3cb, commit 281e6d6

File tree

7 files changed: +386 -2 lines changed

.gitignore

Lines changed: 3 additions & 1 deletion

@@ -24,4 +24,6 @@ __pycache__/
 
 # Ignore all contents of result and results directories
 /result/
-/results/
+/results/
+
+.idea/

__init__.py

Whitespace-only changes.

setup.py

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@
         "ruff==0.6.2",
         "hera-workflows==5.16.3",
         "sagemaker-core<2.0.0",
-        "pydantic==2.11.7"
+        "pydantic>=2.10.6,<3.0.0"
     ],
     entry_points={
         "console_scripts": [
Lines changed: 162 additions & 0 deletions

@@ -0,0 +1,162 @@

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import pytest
import time

from sagemaker.hyperpod.cli.utils import setup_logger
from test.integration_tests.utils import execute_command
from test.integration_tests.abstract_integration_tests import AbstractIntegrationTests

logger = setup_logger(__name__)


class TestHypCLICommands(AbstractIntegrationTests):
    """Integration tests for HyperPod CLI using hyp commands."""

    def test_list_clusters(self, cluster_name):
        """Test listing clusters."""
        assert cluster_name

    def test_set_cluster_context(self, cluster_name):
        """Test setting cluster context."""
        result = execute_command([
            "hyp", "set-cluster-context",
            "--cluster-name", cluster_name
        ])
        assert result.returncode == 0
        context_line = result.stdout.strip().splitlines()[-1]
        assert any(text in context_line for text in ["Updated context", "Added new context"])

    def test_get_cluster_context(self, cluster_name):
        """Test getting the current cluster context."""
        result = execute_command(["hyp", "get-cluster-context"])
        assert result.returncode == 0

        context_output = result.stdout.strip()
        assert "Cluster context:" in context_output
        # Just verify we got a valid ARN without checking the specific name
        current_arn = context_output.split("Cluster context:")[-1].strip()
        assert "arn:aws:eks:" in current_arn

    def test_create_job(self, test_job_name, image_uri):
        """Test creating a PyTorch job using the correct CLI parameters."""
        result = execute_command([
            "hyp", "create", "hp-pytorch-job",
            "--version", "1.0",
            "--job-name", test_job_name,
            "--image", image_uri,
            "--pull-policy", "Always",
            "--tasks-per-node", "1",
            "--max-retry", "1"
        ])
        assert result.returncode == 0
        logger.info(f"Created job: {test_job_name}")

        # Wait a moment for the job to be created
        time.sleep(5)

    def test_list_jobs(self, test_job_name):
        """Test listing jobs and verifying the created job is present."""
        list_result = execute_command(["hyp", "list", "hp-pytorch-job"])
        assert list_result.returncode == 0

        # Check that the created job appears in the listing
        assert test_job_name in list_result.stdout
        logger.info("Successfully listed jobs")

    def test_list_pods(self, test_job_name):
        """Test listing pods for a specific job."""
        # Wait a moment to ensure pods are created
        time.sleep(10)

        list_pods_result = execute_command([
            "hyp", "list-pods", "hp-pytorch-job",
            "--job-name", test_job_name
        ])
        assert list_pods_result.returncode == 0

        # Verify the output contains the expected headers and the job name
        output = list_pods_result.stdout.strip()
        assert f"Pods for job: {test_job_name}" in output
        assert "POD NAME" in output
        assert "NAMESPACE" in output

        # Verify at least one pod is listed (pod names embed the job name)
        assert f"{test_job_name}-pod-" in output

        logger.info(f"Successfully listed pods for job: {test_job_name}")

    # @pytest.mark.skip(reason="Skipping since there is ")
    def test_get_logs(self, test_job_name):
        """Test getting logs for a specific pod in a job."""
        # First, get the pod name from the list-pods command
        list_pods_result = execute_command([
            "hyp", "list-pods", "hp-pytorch-job",
            "--job-name", test_job_name
        ])
        assert list_pods_result.returncode == 0

        # Extract the pod name from the output
        output_lines = list_pods_result.stdout.strip().split('\n')
        pod_name = None
        for line in output_lines:
            if f"{test_job_name}-pod-" in line:
                # The pod name is the first column of the line
                pod_name = line.split()[0].strip()
                break

        assert pod_name is not None, f"Could not find pod for job {test_job_name}"
        logger.info(f"Found pod: {pod_name}")

        # Now get logs for this pod
        get_logs_result = execute_command([
            "hyp", "get-logs", "hp-pytorch-job",
            "--job-name", test_job_name,
            "--pod-name", pod_name
        ])
        assert get_logs_result.returncode == 0

        # Verify the output contains the expected header
        logs_output = get_logs_result.stdout.strip()
        assert f"Listing logs for pod: {pod_name}" in logs_output

        logger.info(f"Successfully retrieved logs for pod: {pod_name}")

    def test_describe_job(self, test_job_name):
        """Test describing a specific job and verifying the output."""
        describe_result = execute_command(["hyp", "describe", "hp-pytorch-job", "--job-name", test_job_name])
        assert describe_result.returncode == 0

        # Check that the job name appears in the description output
        assert test_job_name in describe_result.stdout
        logger.info(f"Successfully described job: {test_job_name}")

    @pytest.mark.run(order=99)
    def test_delete_job(self, test_job_name):
        """Test deleting a job and verifying deletion."""
        delete_result = execute_command(["hyp", "delete", "hp-pytorch-job", "--job-name", test_job_name])
        assert delete_result.returncode == 0
        logger.info(f"Successfully deleted job: {test_job_name}")

        # Wait a moment for the job to be deleted
        time.sleep(5)

        # Verify the job is no longer listed
        list_result = execute_command(["hyp", "list", "hp-pytorch-job"])
        assert list_result.returncode == 0

        # The job name should no longer be in the output
        assert test_job_name not in list_result.stdout
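These tests shell out through execute_command, imported from test/integration_tests/utils, which is not part of this diff. A plausible minimal sketch of such a helper, assuming it simply wraps subprocess.run with captured text output (the real implementation may differ):

# Hypothetical stand-in for test/integration_tests/utils.execute_command;
# the real helper is not shown in this commit. Captured text output lets
# the tests assert on result.returncode and result.stdout.
import subprocess
from typing import List


def execute_command(command: List[str]) -> subprocess.CompletedProcess:
    """Run a CLI command, capturing stdout/stderr as text."""
    return subprocess.run(command, capture_output=True, text=True)

Note that @pytest.mark.run(order=99) on test_delete_job assumes an ordering plugin such as pytest-ordering is installed; without one, pytest ignores the marker and the deletion test may not run last.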

test/integration_tests/conftest.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@

import uuid
import pytest
import json
from test.integration_tests.utils import execute_command
from sagemaker.hyperpod.training import (
    HyperPodPytorchJob,
    Container,
    ReplicaSpec,
    Resources,
    RunPolicy,
    Spec,
    Template,
)
from sagemaker.hyperpod.common.config import Metadata

@pytest.fixture(scope="class")
def test_job_name():
    """Generate a unique job name for testing."""
    return f"test-pytorch-job-{str(uuid.uuid4())[:8]}"

@pytest.fixture(scope="class")
def image_uri():
    """Return a standard PyTorch image URI for testing."""
    return "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.2.0-cpu-py310-ubuntu20.04-sagemaker"

@pytest.fixture(scope="class")
def cluster_name():
    """Fixture to list clusters once and return the latest cluster name."""
    result = execute_command(["hyp", "list-cluster"])
    assert result.returncode == 0

    try:
        json_start = result.stdout.index('[')
        json_text = result.stdout[json_start:]
        clusters = json.loads(json_text)
    except Exception as e:
        raise AssertionError(f"Failed to parse cluster list JSON: {e}\nRaw Output:\n{result.stdout}")

    assert clusters, "No clusters found in list-cluster output"
    return clusters[-1]["Cluster"]

@pytest.fixture(scope="class")
def pytorch_job(test_job_name, image_uri):
    """Create a HyperPodPytorchJob instance for testing."""
    nproc_per_node = "1"
    replica_specs = [
        ReplicaSpec(
            name="pod",
            template=Template(
                spec=Spec(
                    containers=[
                        Container(
                            name="container-name",
                            image=image_uri,
                            image_pull_policy="Always",
                            resources=Resources(
                                requests={"nvidia.com/gpu": "0"},
                                limits={"nvidia.com/gpu": "0"},
                            ),
                            # command=[]
                        )
                    ]
                )
            ),
        )
    ]
    run_policy = RunPolicy(clean_pod_policy="None")

    pytorch_job = HyperPodPytorchJob(
        metadata=Metadata(name=test_job_name),
        nproc_per_node=nproc_per_node,
        replica_specs=replica_specs,
        run_policy=run_policy,
    )

    return pytorch_job
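For reference, the cluster_name fixture above assumes that hyp list-cluster prints a JSON array, possibly preceded by log lines, with the newest cluster last; this matches the "return latest cluster" fix in the commit message. An illustrative sketch of that parsing contract, where everything except the "Cluster" key is hypothetical:

# Illustrative only: the output shape the cluster_name fixture relies on.
# Only the "Cluster" key and last-is-latest ordering matter to the fixture.
import json

example_stdout = 'Fetching clusters...\n[{"Cluster": "integ-a"}, {"Cluster": "integ-b"}]'

clusters = json.loads(example_stdout[example_stdout.index('['):])
assert clusters[-1]["Cluster"] == "integ-b"  # the latest cluster wins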
Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import pytest
import time
import yaml

from sagemaker.hyperpod.training import (
    HyperPodPytorchJob,
    Container,
    ReplicaSpec,
    Resources,
    RunPolicy,
    Spec,
    Template,
)
from sagemaker.hyperpod.common.config import Metadata
from sagemaker.hyperpod.cli.utils import setup_logger
from test.integration_tests.abstract_integration_tests import AbstractIntegrationTests

logger = setup_logger(__name__)


class TestHyperPodTrainingSDK(AbstractIntegrationTests):
    """Integration tests for the HyperPod Training SDK."""

    def test_create_job(self, pytorch_job):
        """Test creating a PyTorch job using the SDK."""
        try:
            # The create() method doesn't return anything
            pytorch_job.create()
            logger.info(f"Job creation initiated: {pytorch_job.metadata.name}")

            # Wait for the job to be created and its status to become
            # available, retrying a few times with an increasing delay
            max_attempts = 5
            for attempt in range(1, max_attempts + 1):
                try:
                    logger.info(f"Waiting for job status to be available (attempt {attempt}/{max_attempts})...")
                    time.sleep(attempt * 5)  # 5, 10, 15, 20, 25 seconds

                    # Get the job directly instead of using refresh
                    HyperPodPytorchJob.get(pytorch_job.metadata.name, pytorch_job.metadata.namespace)

                    # If we got here without an exception, the job exists
                    logger.info(f"Job successfully created: {pytorch_job.metadata.name}")
                    return
                except Exception as e:
                    if "status" in str(e) and attempt < max_attempts:
                        logger.info(f"Status not available yet, retrying... ({e})")
                        continue
                    else:
                        raise

            # If we get here, we've exhausted our attempts
            pytest.fail(f"Job was created but status never became available after {max_attempts} attempts")
        except Exception as e:
            logger.error(f"Error creating job: {e}")
            pytest.fail(f"Failed to create job: {e}")

    def test_list_jobs(self, pytorch_job):
        """Test listing jobs and verifying the created job is present."""
        jobs = HyperPodPytorchJob.list()
        assert jobs is not None

        # Check that the created job's name is in the list
        job_names = [job.metadata.name for job in jobs]
        assert pytorch_job.metadata.name in job_names

    def test_refresh_job(self, pytorch_job):
        """Test refreshing a job and reading its status."""
        pytorch_job.refresh()
        time.sleep(15)
        assert pytorch_job.status is not None, "Job status should not be None"
        logger.info(f"Refreshed job status:\n{yaml.dump(pytorch_job.status)}")

    def test_list_pods(self, pytorch_job):
        """Test listing pods for a specific job."""
        pods = pytorch_job.list_pods()
        assert pods is not None

        # Check that at least one pod is listed
        assert len(pods) > 0

        # Store the first pod name for later use
        pytest.pod_name = pods[0]

        logger.info(f"Successfully listed pods: {pods}")

    def test_get_logs(self, pytorch_job):
        """Test getting logs for a specific pod in a job."""
        pod_name = getattr(pytest, "pod_name", None)
        if not pod_name:
            pytest.skip("No pod name available from previous test")

        logs = pytorch_job.get_logs_from_pod(pod_name)
        assert logs is not None

        logger.info(f"Successfully retrieved logs for pod: {pod_name}")

    def test_delete_job(self, pytorch_job):
        """Test deleting a job."""
        pytorch_job.delete()
        logger.info(f"Successfully deleted job: {pytorch_job.metadata.name}")

        # Wait a moment for the job to be deleted
        time.sleep(5)

        # Verify the job is no longer listed
        jobs = HyperPodPytorchJob.list()
        job_names = [job.metadata.name for job in jobs]
        assert pytorch_job.metadata.name not in job_names
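Outside of pytest, the create -> inspect -> delete lifecycle these tests walk through looks roughly like the following condensed sketch. The job name and "default" namespace are assumptions for illustration; the job itself would be constructed as in the pytorch_job fixture in conftest.py:

# Condensed sketch of the SDK flow exercised by the tests above.
# Assumes a job named "demo-pytorch-job" already exists in "default".
import time

from sagemaker.hyperpod.training import HyperPodPytorchJob

job = HyperPodPytorchJob.get("demo-pytorch-job", "default")  # fetch by name
job.refresh()                                                # pull the latest status
print(job.status)

pods = job.list_pods()                                       # pods backing the job
logs = job.get_logs_from_pod(pods[0])                        # logs from the first pod

job.delete()                                                 # tear the job down
time.sleep(5)                                                # give deletion time to settle
assert "demo-pytorch-job" not in [j.metadata.name for j in HyperPodPytorchJob.list()]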
