Commit 23a87a9

Deployment inference using OpenAI client

1 parent 38992ff · commit 23a87a9

1 file changed: 94 additions, 24 deletions

ads/aqua/extension/deployment_handler.py
@@ -7,8 +7,8 @@
 
 from tornado.web import HTTPError
 
-from ads.aqua.app import logger
 from ads.aqua.client.client import Client, ExtendedRequestError
+from ads.aqua.client.openai_client import OpenAI
 from ads.aqua.common.decorator import handle_exceptions
 from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
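The import swap above replaces the bespoke Client with the OpenAI-compatible client from ads.aqua.client.openai_client, so call sites move from the custom chat()/generate() helpers to the standard chat.completions.create() surface. A minimal sketch of the difference, assuming a reachable deployment endpoint; the base URL and model name are illustrative placeholders, not values from this commit:

    # Minimal sketch of the call-style change; base_url and model are
    # hypothetical placeholders, not values from this commit.
    from ads.aqua.client.openai_client import OpenAI

    client = OpenAI(base_url="https://<model-deployment-host>/predict/v1")

    # Before (ads.aqua.client.client.Client):
    #     for chunk in client.chat(messages=[...], payload={...}, stream=True): ...
    # After (OpenAI-compatible surface):
    for chunk in client.chat.completions.create(
        model="odsc-llm",  # hypothetical served-model name
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    ):
        print(chunk)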
@@ -178,6 +178,43 @@ def list_shapes(self):
 
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+    def _extract_text_from_choice(self, choice):
+        # choice may be a dict or an object
+        if isinstance(choice, dict):
+            # streaming chunk: {"delta": {"content": "..."}}
+            delta = choice.get("delta")
+            if isinstance(delta, dict):
+                return delta.get("content") or delta.get("text") or None
+            # non-streaming: {"message": {"content": "..."}}
+            msg = choice.get("message")
+            if isinstance(msg, dict):
+                return msg.get("content") or msg.get("text")
+            # fallback top-level fields
+            return choice.get("text") or choice.get("content")
+        # object-like choice
+        delta = getattr(choice, "delta", None)
+        if delta is not None:
+            return getattr(delta, "content", None) or getattr(delta, "text", None)
+        msg = getattr(choice, "message", None)
+        if msg is not None:
+            if isinstance(msg, str):
+                return msg
+            return getattr(msg, "content", None) or getattr(msg, "text", None)
+        return getattr(choice, "text", None) or getattr(choice, "content", None)
+
+    def _extract_text_from_chunk(self, chunk):
+        if isinstance(chunk, dict):
+            choices = chunk.get("choices") or []
+            if choices:
+                return self._extract_text_from_choice(choices[0])
+            # fallback top-level
+            return chunk.get("text") or chunk.get("content")
+        # object-like chunk
+        choices = getattr(chunk, "choices", None)
+        if choices:
+            return self._extract_text_from_choice(choices[0])
+        return getattr(chunk, "text", None) or getattr(chunk, "content", None)
+
     def _get_model_deployment_response(
         self,
         model_deployment_id: str,
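The two helper methods added in this hunk normalize both dict-shaped chunks (parsed JSON) and object-shaped chunks (SDK response models) down to a plain text fragment. A standalone sketch of the shapes they handle, with handler standing in for an AquaDeploymentStreamingInferenceHandler instance:

    from types import SimpleNamespace

    # dict-shaped streaming chunk, as parsed from raw SSE JSON
    dict_chunk = {"choices": [{"delta": {"content": "Hel"}}]}

    # object-shaped chunk, as produced by the OpenAI SDK response models
    obj_chunk = SimpleNamespace(
        choices=[SimpleNamespace(delta=SimpleNamespace(content="lo"))]
    )

    # handler._extract_text_from_chunk(dict_chunk)  -> "Hel"
    # handler._extract_text_from_chunk(obj_chunk)   -> "lo"
    # A role-only delta such as {"delta": {"role": "assistant"}} returns None,
    # so downstream consumers may want to skip falsy fragments.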
@@ -233,45 +270,78 @@ def _get_model_deployment_response
         endpoint_type = model_deployment.environment_variables.get(
             "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
         )
-        aqua_client = Client(endpoint=endpoint)
+        aqua_client = OpenAI(base_url=self.endpoint)
+
+        allowed = {
+            "max_tokens",
+            "temperature",
+            "top_p",
+            "stop",
+            "n",
+            "presence_penalty",
+            "frequency_penalty",
+            "logprobs",
+            "user",
+            "echo",
+        }
+
+        # normalize and filter
+        if self.params.get("stop") == []:
+            self.params["stop"] = None
+
+        model = self.params.pop("model")
+        filtered = {k: v for k, v in self.params.items() if k in allowed}
 
         if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
             endpoint_type,
             route_override_header,
         ):
             try:
-                for chunk in aqua_client.chat(
-                    messages=payload.pop("messages"),
-                    payload=payload,
+                for chunk in aqua_client.chat.completions.create(
+                    model=model,
+                    messages=[{"role": "user", "content": self.prompt}],
                     stream=True,
+                    **filtered,
                 ):
-                    try:
-                        if "text" in chunk["choices"][0]:
-                            yield chunk["choices"][0]["text"]
-                        elif "content" in chunk["choices"][0]["delta"]:
-                            yield chunk["choices"][0]["delta"]["content"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                    yield self._extract_text_from_chunk(chunk)
+                    # try:
+                    #     if "text" in chunk["choices"][0]:
+                    #         yield chunk["choices"][0]["text"]
+                    #     elif "content" in chunk["choices"][0]["delta"]:
+                    #         yield chunk["choices"][0]["delta"]["content"]
+                    # except Exception as e:
+                    #     logger.debug(
+                    #         f"Exception occurred while parsing streaming response: {e}"
+                    #     )
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
 
         elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
             try:
-                for chunk in aqua_client.generate(
-                    prompt=payload.pop("prompt"),
-                    payload=payload,
-                    stream=True,
+                for chunk in aqua_client.completions.create(
+                    prompt=self.prompt, stream=True, model=model, **filtered
                 ):
-                    try:
-                        yield chunk["choices"][0]["text"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                    yield self._extract_text_from_chunk(chunk)
+                    # try:
+                    #     yield chunk["choices"][0]["text"]
+                    # except Exception as e:
+                    #     logger.debug(
+                    #         f"Exception occurred while parsing streaming response: {e}"
+                    #     )
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
+
+        elif endpoint_type == PredictEndpoints.RESPONSES:
+            response = aqua_client.responses.create(
+                input=self.prompt, stream=True, model=model, **filtered
+            )
+            try:
+                for chunk in response:
+                    yield self._extract_text_from_chunk(chunk)
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
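The parameter handling introduced in this hunk follows a whitelist-and-splat pattern: an empty stop list is normalized to None, model is popped out and passed explicitly, and only recognized generation options are forwarded to create(). A condensed standalone sketch of that logic; the incoming params dict is illustrative:

    allowed = {
        "max_tokens", "temperature", "top_p", "stop", "n",
        "presence_penalty", "frequency_penalty", "logprobs", "user", "echo",
    }

    params = {"model": "odsc-llm", "temperature": 0.2, "stop": [], "foo": 1}

    if params.get("stop") == []:   # normalize empty stop list to None
        params["stop"] = None

    model = params.pop("model")    # always passed explicitly, never splatted
    filtered = {k: v for k, v in params.items() if k in allowed}
    # filtered == {"temperature": 0.2, "stop": None}; unknown "foo" is dropped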
