diff --git a/ads/aqua/common/enums.py b/ads/aqua/common/enums.py
index 21a71606c..3897eb049 100644
--- a/ads/aqua/common/enums.py
+++ b/ads/aqua/common/enums.py
@@ -24,6 +24,7 @@ class PredictEndpoints(ExtendedEnum):
     CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
     TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
     EMBEDDING_ENDPOINT = "/v1/embedding"
+    RESPONSES = "/v1/responses"
 
 
 class Tags(ExtendedEnum):
diff --git a/ads/aqua/extension/deployment_handler.py b/ads/aqua/extension/deployment_handler.py
index 3d23a1052..4c5d264cf 100644
--- a/ads/aqua/extension/deployment_handler.py
+++ b/ads/aqua/extension/deployment_handler.py
@@ -7,8 +7,8 @@
 
 from tornado.web import HTTPError
 
-from ads.aqua.app import logger
 from ads.aqua.client.client import Client, ExtendedRequestError
+from ads.aqua.client.openai_client import OpenAI
 from ads.aqua.common.decorator import handle_exceptions
 from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -221,11 +221,49 @@ def list_shapes(self):
 
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+
+    def _extract_text_from_choice(self, choice):
+        # choice may be a dict or an object
+        if isinstance(choice, dict):
+            # streaming chunk: {"delta": {"content": "..."}}
+            delta = choice.get("delta")
+            if isinstance(delta, dict):
+                return delta.get("content") or delta.get("text") or None
+            # non-streaming: {"message": {"content": "..."}}
+            msg = choice.get("message")
+            if isinstance(msg, dict):
+                return msg.get("content") or msg.get("text")
+            # fallback top-level fields
+            return choice.get("text") or choice.get("content")
+        # object-like choice
+        delta = getattr(choice, "delta", None)
+        if delta is not None:
+            return getattr(delta, "content", None) or getattr(delta, "text", None)
+        msg = getattr(choice, "message", None)
+        if msg is not None:
+            if isinstance(msg, str):
+                return msg
+            return getattr(msg, "content", None) or getattr(msg, "text", None)
+        return getattr(choice, "text", None) or getattr(choice, "content", None)
+
+    def _extract_text_from_chunk(self, chunk):
+        if chunk:
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices") or []
+                if choices:
+                    return self._extract_text_from_choice(choices[0])
+                # fallback top-level
+                return chunk.get("text") or chunk.get("content")
+            # object-like chunk
+            choices = getattr(chunk, "choices", None)
+            if choices:
+                return self._extract_text_from_choice(choices[0])
+            return getattr(chunk, "text", None) or getattr(chunk, "content", None)
+
     def _get_model_deployment_response(
         self,
         model_deployment_id: str,
-        payload: dict,
-        route_override_header: Optional[str],
+        payload: dict
     ):
         """
         Returns the model deployment inference response in a streaming fashion.
@@ -272,49 +310,160 @@ def _get_model_deployment_response(
         """
         model_deployment = AquaDeploymentApp().get(model_deployment_id)
-        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
-        endpoint_type = model_deployment.environment_variables.get(
-            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
-        )
-        aqua_client = Client(endpoint=endpoint)
-
-        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
-            endpoint_type,
-            route_override_header,
-        ):
+        endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"
+        endpoint_type = payload["endpoint_type"]
+        aqua_client = OpenAI(base_url=endpoint)
+
+        allowed = {
+            "max_tokens",
+            "temperature",
+            "top_p",
+            "stop",
+            "n",
+            "presence_penalty",
+            "frequency_penalty",
+            "logprobs",
+            "user",
+            "echo",
+        }
+        responses_allowed = {
+            "temperature", "top_p"
+        }
+
+        # normalize and filter
+        if payload.get("stop") == []:
+            payload["stop"] = None
+
+        encoded_image = "NA"
+        if "encoded_image" in payload:
+            encoded_image = payload["encoded_image"]
+
+        model = payload.pop("model")
+        filtered = {k: v for k, v in payload.items() if k in allowed}
+        responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}
+
+        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT == endpoint_type and encoded_image == "NA":
             try:
-                for chunk in aqua_client.chat(
-                    messages=payload.pop("messages"),
-                    payload=payload,
-                    stream=True,
-                ):
-                    try:
-                        if "text" in chunk["choices"][0]:
-                            yield chunk["choices"][0]["text"]
-                        elif "content" in chunk["choices"][0]["delta"]:
-                            yield chunk["choices"][0]["delta"]["content"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                api_kwargs = {
+                    "model": model,
+                    "messages": [{"role": "user", "content": payload["prompt"]}],
+                    "stream": True,
+                    **filtered
+                }
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                stream = aqua_client.chat.completions.create(**api_kwargs)
+
+                for chunk in stream:
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
+        elif (
+            endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
+            and encoded_image != "NA"
+        ):
+            file_type = payload.pop("file_type")
+            if file_type.startswith("image"):
+                api_kwargs = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": payload["prompt"]},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": encoded_image},
+                                },
+                            ],
+                        }
+                    ],
+                    "stream": True,
+                    **filtered
+                }
+
+                # Add chat_template for image-based chat completions
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                response = aqua_client.chat.completions.create(**api_kwargs)
+
+            elif file_type.startswith("audio"):
+                api_kwargs = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": payload["prompt"]},
+                                {
+                                    "type": "audio_url",
+                                    "audio_url": {"url": encoded_image},
+                                },
+                            ],
+                        }
+                    ],
+                    "stream": True,
+                    **filtered
+                }
+
+                # Add chat_template for audio-based chat completions
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                response = aqua_client.chat.completions.create(**api_kwargs)
+            try:
+                for chunk in response:
+                    piece = self._extract_text_from_chunk(chunk)
+                    if piece:
+                        yield piece
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
         elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
             try:
-                for chunk in aqua_client.generate(
-                    prompt=payload.pop("prompt"),
-                    payload=payload,
-                    stream=True,
+                for chunk in aqua_client.completions.create(
+                    prompt=payload["prompt"], stream=True, model=model, **filtered
                 ):
-                    try:
-                        yield chunk["choices"][0]["text"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
+
+        elif endpoint_type == PredictEndpoints.RESPONSES:
+            api_kwargs = {
+                "model": model,
+                "input": payload["prompt"],
+                "stream": True
+            }
+
+            if "temperature" in responses_filtered:
+                api_kwargs["temperature"] = responses_filtered["temperature"]
+            if "top_p" in responses_filtered:
+                api_kwargs["top_p"] = responses_filtered["top_p"]
+
+            response = aqua_client.responses.create(**api_kwargs)
+            try:
+                for chunk in response:
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
@@ -340,19 +489,19 @@ def post(self, model_deployment_id):
 
         prompt = input_data.get("prompt")
         messages = input_data.get("messages")
+
         if not prompt and not messages:
             raise HTTPError(
                 400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
             )
         if not input_data.get("model"):
             raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
-        route_override_header = self.request.headers.get("route", None)
         self.set_header("Content-Type", "text/event-stream")
         response_gen = self._get_model_deployment_response(
-            model_deployment_id, input_data, route_override_header
+            model_deployment_id, input_data
         )
         try:
             for chunk in response_gen:
                 self.write(chunk)
                 self.flush()
             self.finish()
diff --git a/tests/unitary/with_extras/aqua/test_deployment_handler.py b/tests/unitary/with_extras/aqua/test_deployment_handler.py
index f6ca6d271..b869cccdf 100644
--- a/tests/unitary/with_extras/aqua/test_deployment_handler.py
+++ b/tests/unitary/with_extras/aqua/test_deployment_handler.py
@@ -274,8 +274,7 @@ def test_post(self, mock_get_model_deployment_response):
 
         mock_get_model_deployment_response.assert_called_with(
             "mock-deployment-id",
-            {"prompt": "Hello", "model": "some-model"},
-            "test-route",
+            {"prompt": "Hello", "model": "some-model"}
        )
         self.handler.write.assert_any_call("chunk1")
         self.handler.write.assert_any_call("chunk2")
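
For context, a minimal sketch of how the refactored streaming path is consumed from the client side. The base URL, served model name, and payload values below are placeholders, and ads.aqua.client.openai_client.OpenAI is assumed to expose the standard openai-SDK surface that the handler relies on (chat.completions.create with stream=True yielding chunks whose choices[0].delta.content carries the text, i.e. the same field _extract_text_from_chunk reads).

# Illustrative sketch only -- host, model name, and sampling values are assumptions.
from ads.aqua.client.openai_client import OpenAI

client = OpenAI(base_url="https://<deployment-host>/predictWithResponseStream/v1")
stream = client.chat.completions.create(
    model="odsc-llm",  # served model name (placeholder)
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    max_tokens=256,
    temperature=0.2,
)
for chunk in stream:
    # Same delta field the handler extracts before writing to the event stream.
    piece = chunk.choices[0].delta.content
    if piece:
        print(piece, end="", flush=True)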