diff --git a/src/sagemaker/hyperpod/inference/hp_endpoint.py b/src/sagemaker/hyperpod/inference/hp_endpoint.py
index f4bc2b22..bb6c3c74 100644
--- a/src/sagemaker/hyperpod/inference/hp_endpoint.py
+++ b/src/sagemaker/hyperpod/inference/hp_endpoint.py
@@ -19,6 +19,7 @@
 from typing import Dict, List, Optional
 from sagemaker_core.main.resources import Endpoint
 from pydantic import Field, ValidationError
+from kubernetes import client
 
 
 class HPEndpoint(_HPEndpoint, HPEndpointBase):
@@ -211,3 +212,45 @@ def validate_instance_type(self, instance_type: str):
             raise Exception(
                 f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
             )
+
+    @classmethod
+    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
+    def list_pods(cls, namespace=None):
+        """List pods that belong to custom inference endpoints in a namespace.
+
+        Args:
+            namespace: Namespace to search; when empty, the value returned by
+                get_default_namespace() is used.
+
+        Returns:
+            List of names of the pods whose "app" label equals the name of
+            an endpoint resource listed in that namespace.
+        """
+        cls.verify_kube_config()
+
+        if not namespace:
+            namespace = get_default_namespace()
+
+        v1 = client.CoreV1Api()
+        list_pods_response = v1.list_namespaced_pod(namespace=namespace)
+
+        list_response = cls.call_list_api(
+            kind=INFERENCE_ENDPOINT_CONFIG_KIND,
+            namespace=namespace,
+        )
+
+        endpoints = set()
+        if list_response and list_response["items"]:
+            for item in list_response["items"]:
+                endpoints.add(item["metadata"]["name"])
+
+        # list_namespaced_pod returns every pod in the namespace, so keep only
+        # the pods whose "app" label names one of the endpoints found above.
+        pods = []
+        for item in list_pods_response.items:
+            # metadata.labels is None when a pod has no labels at all.
+            labels = item.metadata.labels or {}
+            if labels.get("app") in endpoints:
+                pods.append(item.metadata.name)
+
+        return pods
diff --git a/src/sagemaker/hyperpod/inference/hp_endpoint_base.py b/src/sagemaker/hyperpod/inference/hp_endpoint_base.py
index 1a5c22c2..c8f2c451 100644
--- a/src/sagemaker/hyperpod/inference/hp_endpoint_base.py
+++ b/src/sagemaker/hyperpod/inference/hp_endpoint_base.py
@@ -209,23 +209,6 @@ def get_logs(
 
         return logs
 
-    @classmethod
-    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
-    def list_pods(cls, namespace=None):
-        cls.verify_kube_config()
-
-        if not namespace:
-            namespace = get_default_namespace()
-
-        v1 = client.CoreV1Api()
-        response = v1.list_namespaced_pod(namespace=namespace)
-
-        pods = []
-        for item in response.items:
-            pods.append(item.metadata.name)
-
-        return pods
-
     @classmethod
     @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_namespaces")
     def list_namespaces(cls):
diff --git a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py
index c3a45711..ad872227 100644
--- a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py
+++ b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py
@@ -20,6 +20,7 @@
     _hyperpod_telemetry_emitter,
 )
 from sagemaker.hyperpod.common.telemetry.constants import Feature
+from kubernetes import client
 
 
 class HPJumpStartEndpoint(_HPJumpStartEndpoint, HPEndpointBase):
@@ -240,3 +241,45 @@ def validate_instance_type(self, model_id: str, instance_type: str):
             raise Exception(
                 f"Current HyperPod cluster does not have instance type {instance_type}. Supported instance types are {cluster_instance_types}"
             )
+
+    @classmethod
+    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint")
+    def list_pods(cls, namespace=None):
+        """List pods that belong to JumpStart endpoints in a namespace.
+
+        Args:
+            namespace: Namespace to search; when empty, the value returned by
+                get_default_namespace() is used.
+
+        Returns:
+            List of names of the pods whose "app" label equals the name of
+            a JumpStart model resource listed in that namespace.
+        """
+        cls.verify_kube_config()
+
+        if not namespace:
+            namespace = get_default_namespace()
+
+        v1 = client.CoreV1Api()
+        list_pods_response = v1.list_namespaced_pod(namespace=namespace)
+
+        list_response = cls.call_list_api(
+            kind=JUMPSTART_MODEL_KIND,
+            namespace=namespace,
+        )
+
+        endpoints = set()
+        if list_response and list_response["items"]:
+            for item in list_response["items"]:
+                endpoints.add(item["metadata"]["name"])
+
+        # list_namespaced_pod returns every pod in the namespace, so keep only
+        # the pods whose "app" label names one of the endpoints found above.
+        pods = []
+        for item in list_pods_response.items:
+            # metadata.labels is None when a pod has no labels at all.
+            labels = item.metadata.labels or {}
+            if labels.get("app") in endpoints:
+                pods.append(item.metadata.name)
+
+        return pods
diff --git a/test/unit_tests/inference/test_hp_endpoint.py b/test/unit_tests/inference/test_hp_endpoint.py
index a225e586..c948fd30 100644
--- a/test/unit_tests/inference/test_hp_endpoint.py
+++ b/test/unit_tests/inference/test_hp_endpoint.py
@@ -194,3 +194,42 @@ def test_invoke(self, mock_endpoint_get, mock_get_cluster_context):
             body={"input": "test"}, content_type="application/json"
         )
         self.assertEqual(result, "response")
+
+    @patch.object(HPEndpoint, "call_list_api")
+    @patch("kubernetes.client.CoreV1Api")
+    @patch.object(HPEndpoint, "verify_kube_config")
+    def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api):
+        mock_pod1 = MagicMock()
+        mock_pod1.metadata.name = "custom-endpoint-pod1"
+        mock_pod1.metadata.labels = {"app": "custom-endpoint"}
+        mock_pod2 = MagicMock()
+        mock_pod2.metadata.name = "custom-endpoint-pod2"
+        mock_pod2.metadata.labels = {"app": "custom-endpoint"}
+        mock_pod3 = MagicMock()
+        mock_pod3.metadata.name = "not-custom-endpoint-pod"
+        mock_pod3.metadata.labels = {"app": "not-custom-endpoint"}
+        # Pods created without labels expose metadata.labels as None.
+        mock_pod4 = MagicMock()
+        mock_pod4.metadata.name = "unlabeled-pod"
+        mock_pod4.metadata.labels = None
+        mock_core_api.return_value.list_namespaced_pod.return_value.items = [
+            mock_pod1,
+            mock_pod2,
+            mock_pod3,
+            mock_pod4,
+        ]
+
+        mock_list_api.return_value = {
+            "items": [
+                {
+                    "metadata": {"name": "custom-endpoint"}
+                }
+            ]
+        }
+
+        result = self.endpoint.list_pods(namespace="test-ns")
+
+        self.assertEqual(result, ["custom-endpoint-pod1", "custom-endpoint-pod2"])
+        mock_core_api.return_value.list_namespaced_pod.assert_called_once_with(
+            namespace="test-ns"
+        )
diff --git a/test/unit_tests/inference/test_hp_endpoint_base.py b/test/unit_tests/inference/test_hp_endpoint_base.py
index 4e27d89a..aeca28b4 100644
--- a/test/unit_tests/inference/test_hp_endpoint_base.py
+++ b/test/unit_tests/inference/test_hp_endpoint_base.py
@@ -109,25 +109,6 @@ def test_get_logs(self, mock_verify_config, mock_core_api):
             timestamps=True,
         )
 
-    @patch("kubernetes.client.CoreV1Api")
-    @patch.object(HPEndpointBase, "verify_kube_config")
-    def test_list_pods(self, mock_verify_config, mock_core_api):
-        mock_pod1 = MagicMock()
-        mock_pod1.metadata.name = "pod1"
-        mock_pod2 = MagicMock()
-        mock_pod2.metadata.name = "pod2"
-        mock_core_api.return_value.list_namespaced_pod.return_value.items = [
-            mock_pod1,
-            mock_pod2,
-        ]
-
-        result = self.base.list_pods(namespace="test-ns")
-
-        self.assertEqual(result, ["pod1", "pod2"])
-        mock_core_api.return_value.list_namespaced_pod.assert_called_once_with(
-            namespace="test-ns"
-        )
-
     @patch("kubernetes.client.CoreV1Api")
     @patch.object(HPEndpointBase, "verify_kube_config")
     def test_list_namespaces(self, mock_verify_config, mock_core_api):
diff --git a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py
index b067836a..b0cdb514 100644
--- a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py
+++ b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py
@@ -140,3 +140,42 @@ def test_invoke(self, mock_endpoint_get, mock_get_cluster_context):
             body={"input": "test"}, content_type="application/json"
         )
         self.assertEqual(result, "response")
+
+    @patch.object(HPJumpStartEndpoint, "call_list_api")
+    @patch("kubernetes.client.CoreV1Api")
+    @patch.object(HPJumpStartEndpoint, "verify_kube_config")
+    def test_list_pods(self, mock_verify_config, mock_core_api, mock_list_api):
+        mock_pod1 = MagicMock()
+        mock_pod1.metadata.name = "js-endpoint-pod1"
+        mock_pod1.metadata.labels = {"app": "js-endpoint"}
+        mock_pod2 = MagicMock()
+        mock_pod2.metadata.name = "js-endpoint-pod2"
+        mock_pod2.metadata.labels = {"app": "js-endpoint"}
+        mock_pod3 = MagicMock()
+        mock_pod3.metadata.name = "not-js-endpoint-pod"
+        mock_pod3.metadata.labels = {"app": "not-js-endpoint"}
+        # Pods created without labels expose metadata.labels as None.
+        mock_pod4 = MagicMock()
+        mock_pod4.metadata.name = "unlabeled-pod"
+        mock_pod4.metadata.labels = None
+        mock_core_api.return_value.list_namespaced_pod.return_value.items = [
+            mock_pod1,
+            mock_pod2,
+            mock_pod3,
+            mock_pod4,
+        ]
+
+        mock_list_api.return_value = {
+            "items": [
+                {
+                    "metadata": {"name": "js-endpoint"}
+                }
+            ]
+        }
+
+        result = self.endpoint.list_pods(namespace="test-ns")
+
+        self.assertEqual(result, ["js-endpoint-pod1", "js-endpoint-pod2"])
+        mock_core_api.return_value.list_namespaced_pod.assert_called_once_with(
+            namespace="test-ns"
+        )