Skip to content

Commit a534b62

Browse files
committed
stream inference endpoint
1 parent ab7e984 commit a534b62

File tree

2 files changed

+136
-55
lines changed

2 files changed

+136
-55
lines changed

ads/aqua/common/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class PredictEndpoints(ExtendedEnum):
2424
CHAT_COMPLETIONS_ENDPOINT = "/v1/chat/completions"
2525
TEXT_COMPLETIONS_ENDPOINT = "/v1/completions"
2626
EMBEDDING_ENDPOINT = "/v1/embedding"
27+
RESPONSES = "/v1/responses"
2728

2829

2930
class Tags(ExtendedEnum):

ads/aqua/extension/deployment_handler.py

Lines changed: 135 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def list_shapes(self):
221221

222222

223223
class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
224+
224225
def _extract_text_from_choice(self, choice):
225226
# choice may be a dict or an object
226227
if isinstance(choice, dict):
@@ -246,23 +247,23 @@ def _extract_text_from_choice(self, choice):
246247
return getattr(choice, "text", None) or getattr(choice, "content", None)
247248

248249
def _extract_text_from_chunk(self, chunk):
249-
if isinstance(chunk, dict):
250-
choices = chunk.get("choices") or []
250+
if chunk :
251+
if isinstance(chunk, dict):
252+
choices = chunk.get("choices") or []
253+
if choices:
254+
return self._extract_text_from_choice(choices[0])
255+
# fallback top-level
256+
return chunk.get("text") or chunk.get("content")
257+
# object-like chunk
258+
choices = getattr(chunk, "choices", None)
251259
if choices:
252260
return self._extract_text_from_choice(choices[0])
253-
# fallback top-level
254-
return chunk.get("text") or chunk.get("content")
255-
# object-like chunk
256-
choices = getattr(chunk, "choices", None)
257-
if choices:
258-
return self._extract_text_from_choice(choices[0])
259-
return getattr(chunk, "text", None) or getattr(chunk, "content", None)
261+
return getattr(chunk, "text", None) or getattr(chunk, "content", None)
260262

261263
def _get_model_deployment_response(
262264
self,
263265
model_deployment_id: str,
264-
payload: dict,
265-
route_override_header: Optional[str],
266+
payload: dict
266267
):
267268
"""
268269
Returns the model deployment inference response in a streaming fashion.
@@ -309,11 +310,9 @@ def _get_model_deployment_response(
309310
"""
310311

311312
model_deployment = AquaDeploymentApp().get(model_deployment_id)
312-
endpoint = model_deployment.endpoint + "/predictWithResponseStream"
313-
endpoint_type = model_deployment.environment_variables.get(
314-
"MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
315-
)
316-
aqua_client = OpenAI(base_url=self.endpoint)
313+
endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"
314+
endpoint_type = payload["endpoint_type"]
315+
aqua_client = OpenAI(base_url=endpoint)
317316

318317
allowed = {
319318
"max_tokens",
@@ -327,64 +326,144 @@ def _get_model_deployment_response(
327326
"user",
328327
"echo",
329328
}
329+
responses_allowed = {
330+
"temperature", "top_p"
331+
}
330332

331333
# normalize and filter
332-
if self.params.get("stop") == []:
333-
self.params["stop"] = None
334+
if payload.get("stop") == []:
335+
payload["stop"] = None
334336

335-
model = self.params.pop("model")
336-
filtered = {k: v for k, v in self.params.items() if k in allowed}
337+
encoded_image = "NA"
338+
if encoded_image in payload :
339+
encoded_image = payload["encoded_image"]
337340

338-
if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
339-
endpoint_type,
340-
route_override_header,
341-
):
341+
model = payload.pop("model")
342+
filtered = {k: v for k, v in payload.items() if k in allowed}
343+
responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}
344+
345+
if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT == endpoint_type and encoded_image == "NA":
342346
try:
343-
for chunk in aqua_client.chat.completions.create(
344-
model=model,
345-
messages=[{"role": "user", "content": self.prompt}],
346-
stream=True,
347-
**filtered,
348-
):
349-
yield self._extract_text_from_chunk(chunk)
350-
# try:
351-
# if "text" in chunk["choices"][0]:
352-
# yield chunk["choices"][0]["text"]
353-
# elif "content" in chunk["choices"][0]["delta"]:
354-
# yield chunk["choices"][0]["delta"]["content"]
355-
# except Exception as e:
356-
# logger.debug(
357-
# f"Exception occurred while parsing streaming response: {e}"
358-
# )
347+
api_kwargs = {
348+
"model": model,
349+
"messages": [{"role": "user", "content": payload["prompt"]}],
350+
"stream": True,
351+
**filtered
352+
}
353+
if "chat_template" in payload:
354+
chat_template = payload.pop("chat_template")
355+
api_kwargs["extra_body"] = {"chat_template": chat_template}
356+
357+
stream = aqua_client.chat.completions.create(**api_kwargs)
358+
359+
for chunk in stream:
360+
if chunk :
361+
piece = self._extract_text_from_chunk(chunk)
362+
if piece :
363+
yield piece
359364
except ExtendedRequestError as ex:
360365
raise HTTPError(400, str(ex))
361366
except Exception as ex:
362367
raise HTTPError(500, str(ex))
363368

369+
elif (
370+
endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
371+
and encoded_image != "NA"
372+
):
373+
file_type = payload.pop("file_type")
374+
if file_type.startswith("image"):
375+
api_kwargs = {
376+
"model": model,
377+
"messages": [
378+
{
379+
"role": "user",
380+
"content": [
381+
{"type": "text", "text": payload["prompt"]},
382+
{
383+
"type": "image_url",
384+
"image_url": {"url": f"{self.encoded_image}"},
385+
},
386+
],
387+
}
388+
],
389+
"stream": True,
390+
**filtered
391+
}
392+
393+
# Add chat_template for image-based chat completions
394+
if "chat_template" in payload:
395+
chat_template = payload.pop("chat_template")
396+
api_kwargs["extra_body"] = {"chat_template": chat_template}
397+
398+
response = aqua_client.chat.completions.create(**api_kwargs)
399+
400+
elif self.file_type.startswith("audio"):
401+
api_kwargs = {
402+
"model": model,
403+
"messages": [
404+
{
405+
"role": "user",
406+
"content": [
407+
{"type": "text", "text": payload["prompt"]},
408+
{
409+
"type": "audio_url",
410+
"audio_url": {"url": f"{self.encoded_image}"},
411+
},
412+
],
413+
}
414+
],
415+
"stream": True,
416+
**filtered
417+
}
418+
419+
# Add chat_template for audio-based chat completions
420+
if "chat_template" in payload:
421+
chat_template = payload.pop("chat_template")
422+
api_kwargs["extra_body"] = {"chat_template": chat_template}
423+
424+
response = aqua_client.chat.completions.create(**api_kwargs)
425+
try:
426+
for chunk in response:
427+
piece = self._extract_text_from_chunk(chunk)
428+
if piece:
429+
print(piece, end="", flush=True)
430+
except ExtendedRequestError as ex:
431+
raise HTTPError(400, str(ex))
432+
except Exception as ex:
433+
raise HTTPError(500, str(ex))
364434
elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
365435
try:
366-
for chunk in aqua_client.self.session.completions.create(
367-
prompt=self.prompt, stream=True, model=model, **filtered
436+
for chunk in aqua_client.completions.create(
437+
prompt=payload["prompt"], stream=True, model=model, **filtered
368438
):
369-
yield self._extract_text_from_chunk(chunk)
370-
# try:
371-
# yield chunk["choices"][0]["text"]
372-
# except Exception as e:
373-
# logger.debug(
374-
# f"Exception occurred while parsing streaming response: {e}"
375-
# )
439+
if chunk :
440+
piece = self._extract_text_from_chunk(chunk)
441+
if piece :
442+
yield piece
376443
except ExtendedRequestError as ex:
377444
raise HTTPError(400, str(ex))
378445
except Exception as ex:
379446
raise HTTPError(500, str(ex))
380447

381448
elif endpoint_type == PredictEndpoints.RESPONSES:
382-
response = aqua_client.responses.create(
383-
prompt=self.prompt, stream=True, model=model, **filtered
384-
)
449+
api_kwargs = {
450+
"model": model,
451+
"input": payload["prompt"],
452+
"stream": True
453+
}
454+
455+
if "temperature" in responses_filtered:
456+
api_kwargs["temperature"] = responses_filtered["temperature"]
457+
if "top_p" in responses_filtered:
458+
api_kwargs["top_p"] = responses_filtered["top_p"]
459+
460+
response = aqua_client.responses.create(**api_kwargs)
385461
try:
386462
for chunk in response:
387-
yield self._extract_text_from_chunk(chunk)
463+
if chunk :
464+
piece = self._extract_text_from_chunk(chunk)
465+
if piece :
466+
yield piece
388467
except ExtendedRequestError as ex:
389468
raise HTTPError(400, str(ex))
390469
except Exception as ex:
@@ -410,19 +489,20 @@ def post(self, model_deployment_id):
410489
prompt = input_data.get("prompt")
411490
messages = input_data.get("messages")
412491

492+
413493
if not prompt and not messages:
414494
raise HTTPError(
415495
400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
416496
)
417497
if not input_data.get("model"):
418498
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
419-
route_override_header = self.request.headers.get("route", None)
420499
self.set_header("Content-Type", "text/event-stream")
421500
response_gen = self._get_model_deployment_response(
422-
model_deployment_id, input_data, route_override_header
501+
model_deployment_id, input_data
423502
)
424503
try:
425504
for chunk in response_gen:
505+
print(chunk)
426506
self.write(chunk)
427507
self.flush()
428508
self.finish()

0 commit comments

Comments
 (0)