Skip to content

Commit c3a8a16

Browse files
authored
Merge pull request #40 from zhaoqizqwang/add-inference-classes
Make function classmethod and update unit tests
2 parents 5430f3f + b22e4ad commit c3a8a16

File tree

10 files changed

+288
-171
lines changed

10 files changed

+288
-171
lines changed

sagemaker-hyperpod/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
"pyyaml==6.0.2",
1717
"ratelimit==2.2.1",
1818
"tabulate==0.9.0",
19+
"pydantic==2.11.7",
1920
"pytest==8.3.2",
2021
"pytest-cov==5.0.0",
2122
"pytest-order==1.3.0",
2223
"tox==4.18.0",
2324
"ruff==0.6.2",
2425
"hera-workflows==5.16.3",
2526
],
26-
)
27+
)

sagemaker-hyperpod/src/sagemaker/hyperpod/hyperpod_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def list_clusters(
125125

126126
print(tabulate(table_data, headers=headers))
127127

128-
def set_context_cluster(
128+
def set_context(
129129
self,
130130
cluster_name: str,
131131
region: Optional[str] = None,

sagemaker-hyperpod/src/sagemaker/hyperpod/inference/config/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
INFERENCE_ENDPOINT_CONFIG_KIND = "InferenceEndpointConfig"
1111
INFERENCE_ENDPOINT_CONFIG_PLURAL = "inferenceendpointconfigs"
1212
MODEL_HUB_NAME = "SageMakerPublicHub"
13+
DEFAULT_MOUNT_PATH = "/opt/ml/model"
1314

1415
KIND_PLURAL_MAP = {
1516
JUMPSTART_MODEL_KIND: JUMPSTART_MODEL_PLURAL,

sagemaker-hyperpod/src/sagemaker/hyperpod/inference/hp_endpoint.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
)
1212
from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase
1313
from datetime import datetime
14-
from typing import Union, Dict, Literal
14+
from typing import Dict, Literal
15+
import boto3
1516

1617

1718
class HPEndpoint(HPEndpointBase):
@@ -93,9 +94,11 @@ def _validate_inputs(
9394
raise TypeError(
9495
f"fsx_mount_name must be of type str, got {type(fsx_mount_name)}"
9596
)
96-
97+
9798
# Validate model_volume_mount_name if provided
98-
if model_volume_mount_name is not None and not isinstance(model_volume_mount_name, str):
99+
if model_volume_mount_name is not None and not isinstance(
100+
model_volume_mount_name, str
101+
):
99102
raise TypeError(
100103
f"model_volume_mount_name must be of type str, got {type(model_volume_mount_name)}"
101104
)
@@ -114,9 +117,10 @@ def _get_default_endpoint_name(self, model_name):
114117

115118
return model_name + "-" + time_str
116119

120+
@classmethod
117121
def create(
118-
self,
119-
namespace: str,
122+
cls,
123+
namespace: str = None,
120124
model_name: str = None,
121125
model_version: str = None,
122126
instance_type: str = None,
@@ -132,7 +136,9 @@ def create(
132136
model_volume_mount_name: str = None,
133137
model_volume_mount_path: str = None,
134138
):
135-
self._validate_inputs(
139+
instance = cls()
140+
141+
instance._validate_inputs(
136142
model_name,
137143
instance_type,
138144
image,
@@ -143,16 +149,14 @@ def create(
143149
fsx_dns_name,
144150
fsx_file_system_id,
145151
fsx_mount_name,
146-
endpoint_name,
147152
model_volume_mount_name,
148-
model_volume_mount_path,
149153
)
150154

151155
if not endpoint_name:
152-
endpoint_name = self._get_default_endpoint_name(model_name)
156+
endpoint_name = instance._get_default_endpoint_name(model_name)
153157

154158
if not model_volume_mount_path:
155-
model_volume_mount_path = "/opt/ml/model"
159+
model_volume_mount_path = DEFAULT_MOUNT_PATH
156160

157161
if model_source_type == "s3":
158162
model_source_config = ModelSourceConfig(
@@ -181,7 +185,7 @@ def create(
181185
model_invocation_port=ModelInvocationPort(container_port=container_port),
182186
resources=Resources(),
183187
)
184-
188+
185189
# create spec config
186190
spec = InferenceEndpointConfigSpec(
187191
instance_type=instance_type,
@@ -192,56 +196,80 @@ def create(
192196
endpoint_name=endpoint_name,
193197
)
194198

195-
self.call_create_api(
199+
instance.call_create_api(
196200
name=spec.modelName, # use model name as metadata name
197201
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
198202
namespace=namespace,
199203
spec=spec,
200204
)
201205

206+
return super().get_endpoint(endpoint_name)
207+
208+
@classmethod
202209
def create_from_spec(
203-
self,
210+
cls,
204211
spec: InferenceEndpointConfigSpec,
205212
namespace: str = None,
206213
):
207-
self.call_create_api(
214+
cls().call_create_api(
208215
name=spec.modelName, # use model name as metadata name
209216
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
210217
namespace=namespace,
211218
spec=spec,
212219
)
213220

214-
def create_from_dict(self, input: Dict, namespace: str):
221+
region = boto3.session.Session().region_name
222+
223+
return super().get_endpoint(endpoint_name=spec.endpointName, region=region)
224+
225+
@classmethod
226+
def create_from_dict(
227+
cls,
228+
input: Dict,
229+
namespace: str = None,
230+
):
215231
spec = InferenceEndpointConfigSpec.model_validate(input, by_name=True)
216232

217-
self.call_create_api(
233+
cls().call_create_api(
234+
name=spec.modelName, # use model name as metadata name
235+
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
218236
namespace=namespace,
219237
spec=spec,
220238
)
221239

222-
def list_endpoints(self, namespace: str):
223-
return self.call_list_api(
240+
region = boto3.session.Session().region_name
241+
242+
return super().get_endpoint(endpoint_name=spec.endpointName, region=region)
243+
244+
@classmethod
245+
def list_endpoints(
246+
cls,
247+
namespace: str = None,
248+
):
249+
return cls().call_list_api(
224250
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
225251
namespace=namespace,
226252
)
227253

254+
@classmethod
228255
def describe_endpoint(
229-
self,
256+
cls,
230257
name: str,
231-
namespace: str,
258+
namespace: str = None,
232259
):
233-
return self.call_get_api(
260+
return cls().call_get_api(
234261
name=name,
235262
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
236263
namespace=namespace,
237264
)
238265

266+
@classmethod
239267
def delete_endpoint(
240-
self,
268+
cls,
241269
name: str,
242-
namespace: str,
270+
namespace: str = None,
243271
):
244-
return self.call_delete_api(
272+
cls().call_delete_api(
245273
name=name, # use model id as metadata name
246274
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
247275
namespace=namespace,

sagemaker-hyperpod/src/sagemaker/hyperpod/inference/hp_endpoint_base.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,35 @@
99
InferenceEndpointConfigSpec,
1010
)
1111
from types import SimpleNamespace
12+
import boto3
13+
from sagemaker_core.main.resources import Endpoint
14+
15+
16+
def get_current_region():
17+
session = boto3.session.Session()
18+
return session.region_name
1219

1320

1421
class HPEndpointBase:
22+
_endpoint = None
23+
24+
def get_name(self):
25+
if not self._endpoint or not hasattr(self._endpoint, "endpoint_name"):
26+
print(f"Endpoint is not set!")
27+
28+
print(f"Endpoint name is: {self._endpoint.endpoint_name}")
29+
30+
@classmethod
31+
def get_endpoint(
32+
cls,
33+
endpoint_name: str,
34+
region: str = None,
35+
):
36+
if not region:
37+
region = get_current_region()
38+
39+
return Endpoint.get(endpoint_name, region=region)
40+
1541
def _validate_connection(self):
1642
try:
1743
k8s_config.load_kube_config()
@@ -51,6 +77,8 @@ def call_create_api(
5177
plural=KIND_PLURAL_MAP[kind],
5278
body=body,
5379
)
80+
81+
self.set_endpoint(spec.sageMakerEndpoint.name)
5482
print("\nSuccessful deployed model and its endpoint!")
5583
except Exception as e:
5684
print(f"\nFailed to deploy model and its endpoint: {e}")
@@ -125,3 +153,12 @@ def call_list_api(
125153
)
126154
except Exception as e:
127155
print(f"\nFailed to list endpoint: {e}")
156+
157+
def invoke(self, body, **kwargs):
158+
if self._endpoint is None:
159+
raise Exception("Endpoint not initialized. Please set endpoint first.")
160+
161+
try:
162+
self._endpoint.invoke(body, **kwargs)
163+
except Exception as e:
164+
print(f"\nFailed to invoke endpoint: {e}")

sagemaker-hyperpod/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def _validate_inputs(
1818
model_id: str,
1919
instance_type: str,
2020
):
21-
2221
# Validate required parameters when spec is None
2322
if model_id is None or instance_type is None:
2423
raise ValueError("Must provide both model_id and instance_type.")
@@ -45,19 +44,21 @@ def _get_default_endpoint_name(self, model_id: str):
4544

4645
return model_id + "-" + time_str
4746

47+
@classmethod
4848
def create(
49-
self,
50-
namespace: str,
49+
cls,
50+
namespace: str = None,
5151
model_id: str = None,
5252
model_version: str = None,
5353
instance_type: str = None,
5454
sagemaker_endpoint: str = None,
5555
accept_eula: bool = False,
5656
):
57-
self._validate_inputs(model_id, instance_type)
57+
instance = cls()
58+
instance._validate_inputs(model_id, instance_type)
5859

5960
if not sagemaker_endpoint:
60-
sagemaker_endpoint = self._get_default_endpoint_name(model_id)
61+
sagemaker_endpoint = instance._get_default_endpoint_name(model_id)
6162

6263
spec = JumpStartModelSpec(
6364
model=Model(
@@ -69,35 +70,53 @@ def create(
6970
sage_maker_endpoint=SageMakerEndpoint(name=sagemaker_endpoint),
7071
)
7172

72-
self.call_create_api(
73+
instance.call_create_api(
7374
name=spec.model.modelId, # use model id as metadata name
7475
kind=JUMPSTART_MODEL_KIND,
7576
namespace=namespace,
7677
spec=spec,
7778
)
7879

80+
return cls.get_endpoint(endpoint_name=sagemaker_endpoint)
81+
82+
@classmethod
7983
def create_from_spec(
80-
self,
84+
cls,
8185
spec: JumpStartModelSpec,
8286
namespace: str = None,
8387
):
84-
self.call_create_api(
88+
cls().call_create_api(
8589
name=spec.model.modelId, # use model id as metadata name
8690
kind=JUMPSTART_MODEL_KIND,
8791
namespace=namespace,
8892
spec=spec,
8993
)
9094

91-
def create_from_dict(self, input: Dict, namespace: str):
95+
return super().get_endpoint(endpoint_name=spec.sageMakerEndpoint.name)
96+
97+
@classmethod
98+
def create_from_dict(
99+
cls,
100+
input: Dict,
101+
namespace: str = None,
102+
):
92103
spec = JumpStartModelSpec.model_validate(input, by_name=True)
93104

94-
self.call_create_api(
105+
cls().call_create_api(
106+
name=spec.model.modelId, # use model id as metadata name
107+
kind=JUMPSTART_MODEL_KIND,
95108
namespace=namespace,
96109
spec=spec,
97110
)
98111

99-
def list_endpoints(self, namespace: str):
100-
response = self.call_list_api(
112+
return super().get_endpoint(endpoint_name=spec.sageMakerEndpoint.name)
113+
114+
@classmethod
115+
def list_endpoints(
116+
cls,
117+
namespace: str = None,
118+
):
119+
response = cls().call_list_api(
101120
kind=JUMPSTART_MODEL_KIND,
102121
namespace=namespace,
103122
)
@@ -109,27 +128,31 @@ def list_endpoints(self, namespace: str):
109128
headers = ["METADATA NAME", "CREATE TIME"]
110129

111130
print(tabulate(output_data, headers=headers))
131+
return response
112132

133+
@classmethod
113134
def describe_endpoint(
114-
self,
135+
cls,
115136
name: str,
116-
namespace: str,
137+
namespace: str = None,
117138
):
118-
response = self.call_get_api(
139+
response = cls().call_get_api(
119140
name=name,
120141
kind=JUMPSTART_MODEL_KIND,
121142
namespace=namespace,
122143
)
123144

124145
response["metadata"].pop("managedFields")
125146
print(yaml.dump(response))
147+
return response
126148

149+
@classmethod
127150
def delete_endpoint(
128-
self,
151+
cls,
129152
name: str,
130-
namespace: str,
153+
namespace: str = None,
131154
):
132-
return self.call_delete_api(
155+
return cls().call_delete_api(
133156
name=name, # use model id as metadata name
134157
kind=JUMPSTART_MODEL_KIND,
135158
namespace=namespace,

0 commit comments

Comments
 (0)