Skip to content

Commit a98652b

Browse files
Aditi2424 and adishaa authored
Fix training integration tests (#113)
* Fix training integration tests
* Update CLI tests
* minor changes
Co-authored-by: adishaa <[email protected]>
1 parent babb17d commit a98652b

File tree

7 files changed

+426
-18
lines changed

7 files changed

+426
-18
lines changed

test/integration_tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_job_name():
2121
@pytest.fixture(scope="class")
def image_uri():
    """Return the ECR image URI used by the training integration tests."""
    training_image = "448049793756.dkr.ecr.us-west-2.amazonaws.com/ptjob:mnist"
    return training_image

2626
@pytest.fixture(scope="class")
2727
def cluster_name():
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import time
2+
import uuid
3+
import pytest
4+
import boto3
5+
from click.testing import CliRunner
6+
from sagemaker.hyperpod.cli.commands.inference import (
7+
custom_create,
8+
custom_invoke,
9+
custom_list,
10+
custom_describe,
11+
custom_delete,
12+
custom_get_operator_logs,
13+
custom_list_pods
14+
)
15+
from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
16+
17+
# --------- Test Configuration ---------
# Kubernetes namespace the CLI integration endpoints are created in.
NAMESPACE = "integration"
# Version string passed to the CLI commands via --version.
VERSION = "1.0"
# AWS region hosting the S3 model artifacts and the endpoint.
REGION = "us-east-2"
# test_wait_until_inservice polls for up to TIMEOUT_MINUTES,
# sleeping POLL_INTERVAL_SECONDS between status checks.
TIMEOUT_MINUTES = 15
POLL_INTERVAL_SECONDS = 30
24+
@pytest.fixture(scope="module")
def runner():
    """Provide one shared Click test runner for every test in this module."""
    cli_runner = CliRunner()
    return cli_runner
27+
28+
@pytest.fixture(scope="module")
def custom_endpoint_name():
    """Name of the endpoint created/deleted by this module's CLI tests.

    Fixed: the original returned an f-string with no placeholders
    (flake8 F541) — a plain literal is equivalent and clearer.
    """
    return "custom-cli-integration"
31+
32+
@pytest.fixture(scope="module")
def sagemaker_client():
    """Create a boto3 SageMaker client bound to the test region."""
    client = boto3.client("sagemaker", region_name=REGION)
    return client
35+
36+
# --------- Custom Endpoint Tests ---------
37+
38+
def test_custom_create(runner, custom_endpoint_name):
    """Create a custom endpoint through the CLI and expect exit code 0."""
    # CloudWatch dimensions must embed the endpoint name, so build the
    # JSON string around the fixture value.
    dimensions_json = '{"EndpointName": "' + custom_endpoint_name + '", "VariantName": "AllTraffic"}'
    create_args = [
        "--namespace", NAMESPACE,
        "--version", VERSION,
        "--instance-type", "ml.g5.8xlarge",
        "--model-name", "test-model-integration",
        "--model-source-type", "s3",
        "--model-location", "deepseek15b",
        "--s3-bucket-name", "test-model-s3-zhaoqi",
        "--s3-region", REGION,
        "--image-uri", "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.4.0-tgi2.3.1-gpu-py311-cu124-ubuntu22.04-v2.0",
        "--container-port", "8080",
        "--model-volume-mount-name", "model-weights",
        "--endpoint-name", custom_endpoint_name,
        "--resources-requests", '{"cpu": "30000m", "nvidia.com/gpu": 1, "memory": "100Gi"}',
        "--resources-limits", '{"nvidia.com/gpu": 1}',
        "--tls-certificate-output-s3-uri", "s3://tls-bucket-inf1-beta2",
        "--metrics-enabled", "true",
        "--metric-collection-period", "30",
        "--metric-name", "Invocations",
        "--metric-stat", "Sum",
        "--metric-type", "Average",
        "--min-value", "0.0",
        "--cloud-watch-trigger-name", "SageMaker-Invocations-new",
        "--cloud-watch-trigger-namespace", "AWS/SageMaker",
        "--target-value", "10",
        "--use-cached-metrics", "true",
        "--dimensions", dimensions_json,
        "--env", '{ "HF_MODEL_ID": "/opt/ml/model", "SAGEMAKER_PROGRAM": "inference.py", "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", "MODEL_CACHE_ROOT": "/opt/ml/model", "SAGEMAKER_ENV": "1" }',
    ]
    result = runner.invoke(custom_create, create_args)
    assert result.exit_code == 0, result.output
70+
71+
72+
def test_custom_list(runner, custom_endpoint_name):
    """The created endpoint should appear in the CLI list output."""
    listing = runner.invoke(custom_list, ["--namespace", NAMESPACE])
    assert listing.exit_code == 0
    assert custom_endpoint_name in listing.output
76+
77+
78+
def test_custom_describe(runner, custom_endpoint_name):
    """Describe the endpoint with --full and check its name is echoed back."""
    describe_args = [
        "--name", custom_endpoint_name,
        "--namespace", NAMESPACE,
        "--full",
    ]
    outcome = runner.invoke(custom_describe, describe_args)
    assert outcome.exit_code == 0
    assert custom_endpoint_name in outcome.output
86+
87+
88+
def test_wait_until_inservice(custom_endpoint_name):
    """Poll the SDK until the custom endpoint reaches CreationCompleted.

    Fails fast if the deployment reports DeploymentFailed, and fails on
    timeout after TIMEOUT_MINUTES. (Fixed: the original log messages said
    "JumpStart endpoint" and "DeploymentComplete" — copy-paste artifacts;
    this test polls a custom endpoint for the CreationCompleted state.)
    """
    print(f"[INFO] Waiting for custom endpoint '{custom_endpoint_name}' to be CreationCompleted...")
    deadline = time.time() + (TIMEOUT_MINUTES * 60)
    poll_count = 0

    while time.time() < deadline:
        poll_count += 1
        print(f"[DEBUG] Poll #{poll_count}: Checking endpoint status...")

        try:
            ep = HPEndpoint.get(name=custom_endpoint_name, namespace=NAMESPACE)
            state = ep.status.endpoints.sagemaker.state
            print(f"[DEBUG] Current state: {state}")
            if state == "CreationCompleted":
                print("[INFO] Endpoint is in CreationCompleted state.")
                return

            deployment_state = ep.status.deploymentStatus.deploymentObjectOverallState
            if deployment_state == "DeploymentFailed":
                # pytest.fail raises an outcome exception that is NOT caught
                # by the `except Exception` below, so this aborts the test.
                pytest.fail("Endpoint deployment failed.")

        except Exception as e:
            # Transient API errors while the endpoint is still coming up are
            # expected; log and keep polling until the deadline.
            print(f"[ERROR] Exception during polling: {e}")

        time.sleep(POLL_INTERVAL_SECONDS)

    pytest.fail("[ERROR] Timed out waiting for endpoint to be CreationCompleted")
116+
117+
118+
def test_custom_invoke(runner, custom_endpoint_name):
    """Invoke the endpoint via the CLI and expect a non-error response."""
    payload = '{"inputs": "What is the capital of USA?"}'
    invocation = runner.invoke(custom_invoke, [
        "--endpoint-name", custom_endpoint_name,
        "--body", payload,
    ])
    assert invocation.exit_code == 0
    assert "error" not in invocation.output.lower()
125+
126+
127+
def test_custom_get_operator_logs(runner):
    """Fetching the last hour of operator logs should succeed."""
    outcome = runner.invoke(custom_get_operator_logs, ["--since-hours", "1"])
    assert outcome.exit_code == 0
130+
131+
132+
def test_custom_list_pods(runner):
    """Listing pods in the test namespace should succeed."""
    outcome = runner.invoke(custom_list_pods, ["--namespace", NAMESPACE])
    assert outcome.exit_code == 0
135+
136+
137+
def test_custom_delete(runner, custom_endpoint_name):
    """Clean up: delete the endpoint created earlier in this module."""
    delete_args = [
        "--name", custom_endpoint_name,
        "--namespace", NAMESPACE,
    ]
    outcome = runner.invoke(custom_delete, delete_args)
    assert outcome.exit_code == 0
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import time
2+
import uuid
3+
import json
4+
import pytest
5+
import boto3
6+
7+
from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint
8+
from sagemaker.hyperpod.inference.config.hp_endpoint_config import (
9+
ModelSourceConfig, S3Storage, TlsConfig, Worker, ModelVolumeMount,
10+
ModelInvocationPort, Resources, EnvironmentVariables, AutoScalingSpec,
11+
CloudWatchTrigger, Dimensions, Metrics
12+
)
13+
import sagemaker_core.main.code_injection.codec as codec
14+
15+
# --------- Test Configuration ---------
# Kubernetes namespace the SDK integration endpoints are created in.
NAMESPACE = "integration"
# AWS region hosting the S3 model artifacts and the endpoint.
REGION = "us-east-2"
# Fixed: ENDPOINT_NAME and MODEL_NAME were f-strings with no
# placeholders (flake8 F541); plain literals are equivalent.
ENDPOINT_NAME = "custom-sdk-integration"

MODEL_NAME = "ds-model-integration"
S3_BUCKET = "test-model-s3-zhaoqi"
MODEL_LOCATION = "deepseek15b"
IMAGE_URI = "763104351884.dkr.ecr.us-east-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.4.0-tgi2.3.1-gpu-py311-cu124-ubuntu22.04-v2.0"
TLS_URI = "s3://tls-bucket-inf1-beta2"

# test_wait_until_inservice polls for up to TIMEOUT_MINUTES,
# sleeping POLL_INTERVAL_SECONDS between status checks.
TIMEOUT_MINUTES = 15
POLL_INTERVAL_SECONDS = 30
28+
29+
@pytest.fixture(scope="module")
def sagemaker_client():
    """Create a boto3 SageMaker client bound to the test region."""
    client = boto3.client("sagemaker", region_name=REGION)
    return client
32+
33+
@pytest.fixture(scope="module")
def custom_endpoint():
    """Build (but do not create) the HPEndpoint spec under test.

    Mirrors the CLI integration test's configuration: an S3-hosted model,
    the TGI inference image on one GPU, TLS certificate output to S3, and
    CloudWatch-driven autoscaling on the Invocations metric.
    """
    # TLS: where the operator writes the generated certificate.
    tls = TlsConfig(tls_certificate_output_s3_uri=TLS_URI)

    # Model Source: model weights pulled from the S3 bucket/prefix above.
    model_src = ModelSourceConfig(
        model_source_type="s3",
        model_location=MODEL_LOCATION,
        s3_storage=S3Storage(
            bucket_name=S3_BUCKET,
            region=REGION
        )
    )

    # Env vars consumed by the serving container (HF/TGI conventions).
    env_vars = [
        EnvironmentVariables(name="HF_MODEL_ID", value="/opt/ml/model"),
        EnvironmentVariables(name="SAGEMAKER_PROGRAM", value="inference.py"),
        EnvironmentVariables(name="SAGEMAKER_SUBMIT_DIRECTORY", value="/opt/ml/model/code"),
        EnvironmentVariables(name="MODEL_CACHE_ROOT", value="/opt/ml/model"),
        EnvironmentVariables(name="SAGEMAKER_ENV", value="1"),
    ]

    # Worker: serving container image, mount, port, and resources.
    worker = Worker(
        image=IMAGE_URI,
        model_volume_mount=ModelVolumeMount(name="model-weights"),
        model_invocation_port=ModelInvocationPort(container_port=8080),
        resources=Resources(
            requests={"cpu": "30000m", "nvidia.com/gpu": 1, "memory": "100Gi"},
            limits={"nvidia.com/gpu": 1}
        ),
        environment_variables=env_vars
    )

    # AutoScaling: scale on AWS/SageMaker Invocations for this endpoint's
    # AllTraffic variant.
    dimensions = [
        Dimensions(name="EndpointName", value=ENDPOINT_NAME),
        Dimensions(name="VariantName", value="AllTraffic"),
    ]
    cw_trigger = CloudWatchTrigger(
        dimensions=dimensions,
        metric_collection_period=30,
        metric_name="Invocations",
        metric_stat="Sum",
        metric_type="Average",
        min_value=0.0,
        name="SageMaker-Invocations",
        namespace="AWS/SageMaker",
        target_value=10,
        use_cached_metrics=True
    )
    auto_scaling = AutoScalingSpec(cloud_watch_trigger=cw_trigger)

    # Metrics collection is enabled for the endpoint.
    metrics = Metrics(enabled=True)

    # NOTE(review): this only constructs the spec — test_create_endpoint
    # is what actually calls .create() against the cluster.
    return HPEndpoint(
        endpoint_name=ENDPOINT_NAME,
        instance_type="ml.g5.8xlarge",
        model_name=MODEL_NAME,
        tls_config=tls,
        model_source_config=model_src,
        worker=worker,
        auto_scaling_spec=auto_scaling,
        metrics=metrics
    )
101+
102+
def test_create_endpoint(custom_endpoint):
    """Create the endpoint and confirm its metadata carries the expected name."""
    custom_endpoint.create(namespace=NAMESPACE)
    created_name = custom_endpoint.metadata.name
    assert created_name == ENDPOINT_NAME
105+
106+
def test_list_endpoint():
    """The created endpoint must show up in the namespace listing."""
    listed = HPEndpoint.list(namespace=NAMESPACE)
    listed_names = [item.metadata.name for item in listed]
    assert ENDPOINT_NAME in listed_names
110+
111+
def test_get_endpoint():
    """Fetch the endpoint by name and verify its model name round-trips."""
    fetched = HPEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
    assert fetched.modelName == MODEL_NAME
114+
115+
def test_wait_until_inservice():
    """Poll the SDK until the custom endpoint reaches CreationCompleted.

    Fails fast if the deployment reports DeploymentFailed, and fails on
    timeout after TIMEOUT_MINUTES. (Fixed: the original log messages said
    "JumpStart endpoint" and "DeploymentComplete" — copy-paste artifacts;
    this test polls a custom endpoint for the CreationCompleted state.)
    """
    print(f"[INFO] Waiting for custom endpoint '{ENDPOINT_NAME}' to be CreationCompleted...")
    deadline = time.time() + (TIMEOUT_MINUTES * 60)
    poll_count = 0

    while time.time() < deadline:
        poll_count += 1
        print(f"[DEBUG] Poll #{poll_count}: Checking endpoint status...")

        try:
            ep = HPEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
            state = ep.status.endpoints.sagemaker.state
            print(f"[DEBUG] Current state: {state}")
            if state == "CreationCompleted":
                print("[INFO] Endpoint is in CreationCompleted state.")
                return

            deployment_state = ep.status.deploymentStatus.deploymentObjectOverallState
            if deployment_state == "DeploymentFailed":
                # pytest.fail raises an outcome exception that is NOT caught
                # by the `except Exception` below, so this aborts the test.
                pytest.fail("Endpoint deployment failed.")

        except Exception as e:
            # Transient API errors while the endpoint is still coming up are
            # expected; log and keep polling until the deadline.
            print(f"[ERROR] Exception during polling: {e}")

        time.sleep(POLL_INTERVAL_SECONDS)

    pytest.fail("[ERROR] Timed out waiting for endpoint to be CreationCompleted")
143+
144+
def test_invoke_endpoint(monkeypatch):
    """Invoke the deployed endpoint and check the response body for errors.

    sagemaker_core's codec transform normally post-processes the raw AWS
    response; patch it so that a response carrying a streaming "Body" is
    read straight into a plain string, while every other call is delegated
    to the original transform unchanged.
    """
    original_transform = codec.transform  # Save original

    def mock_transform(data, shape, object_instance=None):
        # Only intercept responses that carry a raw Body stream; decode it
        # once here (a StreamingBody can only be read a single time).
        if "Body" in data:
            return {"body": data["Body"].read().decode("utf-8")}
        return original_transform(data, shape, object_instance)  # Call original

    # Patch the name as imported into sagemaker_core.main.resources, not
    # the codec module itself, so the resource code picks up the mock.
    monkeypatch.setattr("sagemaker_core.main.resources.transform", mock_transform)

    ep = HPEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
    data = '{"inputs":"What is the capital of USA?"}'
    response = ep.invoke(body=data)

    assert "error" not in response.body.lower()
159+
160+
161+
def test_get_operator_logs():
    """One hour of operator logs should be retrievable and non-empty."""
    endpoint = HPEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
    operator_logs = endpoint.get_operator_logs(since_hours=1)
    assert operator_logs
165+
166+
167+
def test_list_pods():
    """Pods backing the endpoint should be listable in the test namespace."""
    endpoint = HPEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
    endpoint_pods = endpoint.list_pods(NAMESPACE)
    assert endpoint_pods
171+
172+
173+
def test_delete_endpoint():
    """Clean up: delete the endpoint created by this module."""
    endpoint = HPEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
    endpoint.delete()

0 commit comments

Comments
 (0)