diff --git a/src/sagemaker/hyperpod/inference/hp_endpoint.py b/src/sagemaker/hyperpod/inference/hp_endpoint.py index d91424fb..963a456c 100644 --- a/src/sagemaker/hyperpod/inference/hp_endpoint.py +++ b/src/sagemaker/hyperpod/inference/hp_endpoint.py @@ -54,6 +54,7 @@ def create( kind=INFERENCE_ENDPOINT_CONFIG_KIND, namespace=namespace, spec=spec, + debug=debug, ) self.metadata = Metadata( @@ -71,9 +72,10 @@ def create_from_dict( input: Dict, name: str = None, namespace: str = None, + debug=False ) -> None: logger = self.get_logger() - logger = setup_logging(logger) + logger = setup_logging(logger, debug) spec = _HPEndpoint.model_validate(input, by_name=True) @@ -93,6 +95,7 @@ def create_from_dict( kind=INFERENCE_ENDPOINT_CONFIG_KIND, namespace=namespace, spec=spec, + debug=debug, ) self.metadata = Metadata( diff --git a/src/sagemaker/hyperpod/inference/hp_endpoint_base.py b/src/sagemaker/hyperpod/inference/hp_endpoint_base.py index 5c68c367..5407e7be 100644 --- a/src/sagemaker/hyperpod/inference/hp_endpoint_base.py +++ b/src/sagemaker/hyperpod/inference/hp_endpoint_base.py @@ -63,6 +63,7 @@ def call_create_api( kind: str, namespace: str, spec: Union[_HPJumpStartEndpoint, _HPEndpoint], + debug: bool = False, ): """Create an inference endpoint using Kubernetes API. @@ -104,7 +105,7 @@ def call_create_api( cls.verify_kube_config() logger = cls.get_logger() - logger = setup_logging(logger) + logger = setup_logging(logger, debug) custom_api = client.CustomObjectsApi() diff --git a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py index 670a6e36..724e67e6 100644 --- a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py +++ b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py @@ -59,6 +59,7 @@ def create( kind=JUMPSTART_MODEL_KIND, namespace=namespace, spec=spec, + debug=debug, ) self.metadata = Metadata( @@ -76,9 +77,10 @@ def create_from_dict( input: Dict, name: str = None, namespace: str = None, + debug = False ) -> None: logger = self.get_logger() - logger = setup_logging(logger) + logger = setup_logging(logger, debug) spec = _HPJumpStartEndpoint.model_validate(input, by_name=True) @@ -102,6 +104,7 @@ def create_from_dict( kind=JUMPSTART_MODEL_KIND, namespace=namespace, spec=spec, + debug=debug, ) self.metadata = Metadata( diff --git a/test/unit_tests/inference/test_hp_endpoint.py b/test/unit_tests/inference/test_hp_endpoint.py index 2faaf384..ccface05 100644 --- a/test/unit_tests/inference/test_hp_endpoint.py +++ b/test/unit_tests/inference/test_hp_endpoint.py @@ -104,6 +104,7 @@ def test_create(self, mock_create_api, mock_validate_instance_type): kind=INFERENCE_ENDPOINT_CONFIG_KIND, namespace="test-ns", spec=unittest.mock.ANY, + debug=False, ) self.assertEqual(self.endpoint.metadata.name, "test-name") @@ -115,7 +116,13 @@ def test_create_from_dict(self, mock_create_api, mock_validate_instance_type): self.endpoint.create_from_dict(input_dict, namespace="test-ns") - mock_create_api.assert_called_once() + mock_create_api.assert_called_once_with( + name=unittest.mock.ANY, + kind=INFERENCE_ENDPOINT_CONFIG_KIND, + namespace="test-ns", + spec=unittest.mock.ANY, + debug=False, + ) @patch.object(HPEndpoint, "call_get_api") def test_refresh(self, mock_get_api): diff --git a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py index 6887bcf0..452ad326 100644 --- a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py +++ b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py @@ -44,6 +44,7 @@ def test_create(self, mock_create_api, mock_validate_instance_type): kind=JUMPSTART_MODEL_KIND, namespace="test-ns", spec=unittest.mock.ANY, + debug=False, ) self.assertEqual(self.endpoint.metadata.name, "test-name") @@ -60,7 +61,13 @@ def test_create_from_dict(self, mock_create_api, mock_validate_instance_type): input_dict, name="test-name", namespace="test-ns" ) - mock_create_api.assert_called_once() + mock_create_api.assert_called_once_with( + name="test-name", + kind=JUMPSTART_MODEL_KIND, + namespace="test-ns", + spec=unittest.mock.ANY, + debug=False, + ) @patch.object(HPJumpStartEndpoint, "call_get_api") def test_refresh(self, mock_get_api):