Skip to content

Commit f27a3cb

Browse files
authored
Make metadata name same as endpoint name; Updated instance type validation (#110)
Unit test passes and verified in jupyter notebook
1 parent 5eeed51 commit f27a3cb

File tree

4 files changed

+105
-25
lines changed

4 files changed

+105
-25
lines changed

src/sagemaker/hyperpod/inference/hp_endpoint.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from sagemaker.hyperpod.common.config.metadata import Metadata
22
from sagemaker.hyperpod.inference.config.constants import *
3-
from sagemaker.hyperpod.common.utils import append_uuid, get_default_namespace
3+
from sagemaker.hyperpod.common.utils import (
4+
append_uuid,
5+
get_default_namespace,
6+
get_cluster_instance_types,
7+
)
48
from sagemaker.hyperpod.inference.config.hp_endpoint_config import (
59
InferenceEndpointConfigStatus,
610
_HPEndpoint,
@@ -31,15 +35,17 @@ def create(
3135

3236
spec = _HPEndpoint(**self.model_dump(by_alias=True, exclude_none=True))
3337

34-
if not name:
35-
name = append_uuid(spec.modelName)
36-
3738
if not namespace:
3839
namespace = get_default_namespace()
3940

4041
if spec.endpointName:
4142
spec.endpointName = append_uuid(spec.endpointName)
4243

44+
if not name:
45+
name = spec.endpointName
46+
47+
self.validate_instance_type(spec.instanceType)
48+
4349
self.call_create_api(
4450
name=name, # use model name as metadata name
4551
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
@@ -53,7 +59,7 @@ def create(
5359
)
5460

5561
self.get_logger().info(
56-
f"Creating sagemaker model and endpoint. Metadata name: {name}. Endpoint name: {spec.endpointName}.\n The process may take a few minutes..."
62+
f"Creating sagemaker model and endpoint. Endpoint name: {spec.endpointName}.\n The process may take a few minutes..."
5763
)
5864

5965
def create_from_dict(
@@ -64,15 +70,17 @@ def create_from_dict(
6470
) -> None:
6571
spec = _HPEndpoint.model_validate(input, by_name=True)
6672

67-
if not name:
68-
name = append_uuid(spec.modelName)
69-
7073
if not namespace:
7174
namespace = get_default_namespace()
7275

7376
if spec.endpointName:
7477
spec.endpointName = append_uuid(spec.endpointName)
7578

79+
if not name:
80+
name = spec.endpointName
81+
82+
self.validate_instance_type(spec.instanceType)
83+
7684
self.call_create_api(
7785
name=name, # use model name as metadata name
7886
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
@@ -86,7 +94,7 @@ def create_from_dict(
8694
)
8795

8896
self.get_logger().info(
89-
f"Creating sagemaker model and endpoint. Metadata name: {name}. Endpoint name: {spec.endpointName}.\n The process may take a few minutes..."
97+
f"Creating sagemaker model and endpoint. Endpoint name: {spec.endpointName}.\n The process may take a few minutes..."
9098
)
9199

92100
def refresh(self):
@@ -172,3 +180,22 @@ def invoke(self, body, content_type="application/json"):
172180
)
173181

174182
return endpoint.invoke(body=body, content_type=content_type)
183+
184+
def validate_instance_type(self, instance_type: str):
185+
cluster_instance_types = None
186+
187+
# verify supported instance types from HyperPod cluster
188+
try:
189+
cluster_instance_types = get_cluster_instance_types(
190+
cluster=HyperPodManager.get_current_cluster(),
191+
region=HyperPodManager.get_current_region(),
192+
)
193+
except Exception as e:
194+
self.get_logger().warning(
195+
f"Failed to get instance types from HyperPod cluster: {e}"
196+
)
197+
198+
if cluster_instance_types and (instance_type not in cluster_instance_types):
199+
raise Exception(
200+
f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
201+
)

src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,17 @@ def create(
3737

3838
spec = _HPJumpStartEndpoint(**self.model_dump(by_alias=True, exclude_none=True))
3939

40-
if not name:
41-
name = append_uuid(spec.model.modelId)
42-
4340
endpoint_name = ""
4441
if spec.sageMakerEndpoint and spec.sageMakerEndpoint.name:
4542
spec.sageMakerEndpoint.name = append_uuid(spec.sageMakerEndpoint.name)
4643
endpoint_name = spec.sageMakerEndpoint.name
4744

45+
if not endpoint_name and not name:
46+
raise Exception('Input "name" is required if endpoint name is not provided')
47+
48+
if not name:
49+
name = endpoint_name
50+
4851
if not namespace:
4952
namespace = get_default_namespace()
5053

@@ -63,7 +66,7 @@ def create(
6366
)
6467

6568
self.get_logger().info(
66-
f"Creating JumpStart model and sagemaker endpoint. Metadata name: {name}. Endpoint name: {endpoint_name}.\n The process may take a few minutes..."
69+
f"Creating JumpStart model and sagemaker endpoint. Endpoint name: {endpoint_name}.\n The process may take a few minutes..."
6770
)
6871

6972
def create_from_dict(
@@ -74,17 +77,22 @@ def create_from_dict(
7477
) -> None:
7578
spec = _HPJumpStartEndpoint.model_validate(input, by_name=True)
7679

77-
if not name:
78-
name = append_uuid(spec.model.modelId)
79-
8080
endpoint_name = ""
8181
if spec.sageMakerEndpoint and spec.sageMakerEndpoint.name:
8282
spec.sageMakerEndpoint.name = append_uuid(spec.sageMakerEndpoint.name)
8383
endpoint_name = spec.sageMakerEndpoint.name
8484

85+
if not endpoint_name and not name:
86+
raise Exception('Input "name" is required if endpoint name is not provided')
87+
88+
if not name:
89+
name = endpoint_name
90+
8591
if not namespace:
8692
namespace = get_default_namespace()
8793

94+
self.validate_instance_type(spec.model.modelId, spec.server.instanceType)
95+
8896
self.call_create_api(
8997
name=name, # use model name as metadata name
9098
kind=JUMPSTART_MODEL_KIND,
@@ -98,7 +106,7 @@ def create_from_dict(
98106
)
99107

100108
self.get_logger().info(
101-
f"Creating JumpStart model and sagemaker endpoint. Metadata name: {name}. Endpoint name: {endpoint_name}.\n The process may take a few minutes..."
109+
f"Creating JumpStart model and sagemaker endpoint. Endpoint name: {endpoint_name}.\n The process may take a few minutes..."
102110
)
103111

104112
def refresh(self):

test/unit_tests/inference/test_hp_endpoint.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def setUp(self):
9494

9595
self.endpoint = HPEndpoint(
9696
endpoint_name="s3-test-endpoint-name",
97-
instance_type="ml.g5.8xlarge",
97+
instance_type="ml.g5.xlarge",
9898
model_name="deepseek15b-test-model-name",
9999
tls_config=tls_config,
100100
model_source_config=model_source_config,
@@ -103,9 +103,16 @@ def setUp(self):
103103
metrics=metrics,
104104
)
105105

106+
@patch("sagemaker.hyperpod.hyperpod_manager.HyperPodManager.get_current_cluster")
107+
@patch("sagemaker.hyperpod.hyperpod_manager.HyperPodManager.get_current_region")
108+
@patch("sagemaker.hyperpod.common.utils.get_cluster_instance_types")
106109
@patch.object(HPEndpoint, "call_create_api")
107-
def test_create(self, mock_create_api):
108-
self.endpoint.modelName = "test-model"
110+
def test_create(
111+
self, mock_create_api, mock_get_cluster_types, mock_get_region, mock_get_cluster
112+
):
113+
mock_get_cluster_types.return_value = ["ml.g5.xlarge"]
114+
mock_get_region.return_value = "us-west-2"
115+
mock_get_cluster.return_value = "test-cluster"
109116

110117
self.endpoint.create(name="test-name", namespace="test-ns")
111118

@@ -117,8 +124,17 @@ def test_create(self, mock_create_api):
117124
)
118125
self.assertEqual(self.endpoint.metadata.name, "test-name")
119126

127+
@patch("sagemaker.hyperpod.hyperpod_manager.HyperPodManager.get_current_cluster")
128+
@patch("sagemaker.hyperpod.hyperpod_manager.HyperPodManager.get_current_region")
129+
@patch("sagemaker.hyperpod.common.utils.get_cluster_instance_types")
120130
@patch.object(HPEndpoint, "call_create_api")
121-
def test_create_from_dict(self, mock_create_api):
131+
def test_create_from_dict(
132+
self, mock_create_api, mock_get_cluster_types, mock_get_region, mock_get_cluster
133+
):
134+
mock_get_cluster_types.return_value = ["ml.g5.xlarge"]
135+
mock_get_region.return_value = "us-west-2"
136+
mock_get_cluster.return_value = "test-cluster"
137+
122138
input_dict = self.endpoint.model_dump(exclude_none=True)
123139

124140
self.endpoint.create_from_dict(input_dict, namespace="test-ns")

test/unit_tests/inference/test_hp_jumpstart_endpoint.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,23 @@ def setUp(self):
3434
tls_config=tls_config,
3535
)
3636

37+
@patch("sagemaker.hyperpod.hyperpod_manager.HyperPodManager.get_current_cluster")
38+
@patch("sagemaker.hyperpod.hyperpod_manager.HyperPodManager.get_current_region")
39+
@patch("sagemaker.hyperpod.common.utils.get_cluster_instance_types")
40+
@patch("sagemaker.hyperpod.common.utils.get_jumpstart_model_instance_types")
3741
@patch.object(HPJumpStartEndpoint, "call_create_api")
38-
def test_create(self, mock_create_api):
39-
self.endpoint.model = MagicMock()
40-
self.endpoint.model.modelId = "test-model-id"
42+
def test_create(
43+
self,
44+
mock_create_api,
45+
mock_get_model_types,
46+
mock_get_cluster_types,
47+
mock_get_region,
48+
mock_get_cluster,
49+
):
50+
mock_get_model_types.return_value = ["ml.c5.2xlarge"]
51+
mock_get_cluster_types.return_value = ["ml.c5.2xlarge"]
52+
mock_get_region.return_value = "us-west-2"
53+
mock_get_cluster.return_value = "test-cluster"
4154

4255
self.endpoint.create(name="test-name", namespace="test-ns")
4356

@@ -49,8 +62,24 @@ def test_create(self, mock_create_api):
4962
)
5063
self.assertEqual(self.endpoint.metadata.name, "test-name")
5164

65+
@patch("sagemaker.hyperpod.hyperpod_manager.HyperPodManager.get_current_cluster")
66+
@patch("sagemaker.hyperpod.hyperpod_manager.HyperPodManager.get_current_region")
67+
@patch("sagemaker.hyperpod.common.utils.get_cluster_instance_types")
68+
@patch("sagemaker.hyperpod.common.utils.get_jumpstart_model_instance_types")
5269
@patch.object(HPJumpStartEndpoint, "call_create_api")
53-
def test_create_from_dict(self, mock_create_api):
70+
def test_create_from_dict(
71+
self,
72+
mock_create_api,
73+
mock_get_model_types,
74+
mock_get_cluster_types,
75+
mock_get_region,
76+
mock_get_cluster,
77+
):
78+
mock_get_model_types.return_value = ["ml.c5.2xlarge"]
79+
mock_get_cluster_types.return_value = ["ml.c5.2xlarge"]
80+
mock_get_region.return_value = "us-west-2"
81+
mock_get_cluster.return_value = "test-cluster"
82+
5483
input_dict = {
5584
"model": {"modelId": "test-model"},
5685
"server": {"instance_type": "ml.c5.2xlarge"},

0 commit comments

Comments
 (0)