@@ -221,6 +221,7 @@ def list_shapes(self):
 
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+
     def _extract_text_from_choice(self, choice):
         # choice may be a dict or an object
         if isinstance(choice, dict):
@@ -246,23 +247,23 @@ def _extract_text_from_choice(self, choice):
         return getattr(choice, "text", None) or getattr(choice, "content", None)
 
     def _extract_text_from_chunk(self, chunk):
-        if isinstance(chunk, dict):
-            choices = chunk.get("choices") or []
-            if choices:
-                return self._extract_text_from_choice(choices[0])
-            # fallback top-level
-            return chunk.get("text") or chunk.get("content")
-        # object-like chunk
-        choices = getattr(chunk, "choices", None)
-        if choices:
-            return self._extract_text_from_choice(choices[0])
-        return getattr(chunk, "text", None) or getattr(chunk, "content", None)
+        if chunk:
+            if isinstance(chunk, dict):
+                choices = chunk.get("choices") or []
+                if choices:
+                    return self._extract_text_from_choice(choices[0])
+                # fallback top-level
+                return chunk.get("text") or chunk.get("content")
+            # object-like chunk
+            choices = getattr(chunk, "choices", None)
+            if choices:
+                return self._extract_text_from_choice(choices[0])
+            return getattr(chunk, "text", None) or getattr(chunk, "content", None)
 
     def _get_model_deployment_response(
         self,
         model_deployment_id: str,
-        payload: dict,
-        route_override_header: Optional[str],
+        payload: dict
     ):
         """
         Returns the model deployment inference response in a streaming fashion.
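
The rewritten helper now guards against empty chunks and accepts both dict payloads and SDK objects. A quick illustration of the two shapes it is expected to normalize; the shapes below are assumptions inferred from the code above, not taken from any SDK documentation:

from types import SimpleNamespace

# Dict-style chunk, as a parsed SSE payload might look (assumed shape)
dict_chunk = {"choices": [{"text": "Hello"}]}

# Object-style chunk, as the OpenAI SDK yields (assumed shape)
obj_chunk = SimpleNamespace(choices=[SimpleNamespace(text="Hello", content=None)])

# handler._extract_text_from_chunk(dict_chunk) -> "Hello"
# handler._extract_text_from_chunk(obj_chunk)  -> "Hello"
# handler._extract_text_from_chunk(None)       -> None (the new "if chunk" guard)
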
@@ -309,11 +310,9 @@ def _get_model_deployment_response(
         """
 
         model_deployment = AquaDeploymentApp().get(model_deployment_id)
-        endpoint = model_deployment.endpoint + "/predictWithResponseStream"
-        endpoint_type = model_deployment.environment_variables.get(
-            "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
-        )
-        aqua_client = OpenAI(base_url=self.endpoint)
+        endpoint = model_deployment.endpoint + "/predictWithResponseStream/v1"
+        endpoint_type = payload["endpoint_type"]
+        aqua_client = OpenAI(base_url=endpoint)
 
         allowed = {
             "max_tokens",
@@ -327,64 +326,144 @@ def _get_model_deployment_response(
             "user",
             "echo",
         }
+        responses_allowed = {"temperature", "top_p"}
 
         # normalize and filter
-        if self.params.get("stop") == []:
-            self.params["stop"] = None
+        if payload.get("stop") == []:
+            payload["stop"] = None
 
-        model = self.params.pop("model")
-        filtered = {k: v for k, v in self.params.items() if k in allowed}
+        encoded_image = "NA"
+        if "encoded_image" in payload:
+            encoded_image = payload["encoded_image"]
 
-        if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
-            endpoint_type,
-            route_override_header,
-        ):
+        model = payload.pop("model")
+        filtered = {k: v for k, v in payload.items() if k in allowed}
+        responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}
+
+        if (
+            PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT == endpoint_type
+            and encoded_image == "NA"
+        ):
             try:
-                for chunk in aqua_client.chat.completions.create(
-                    model=model,
-                    messages=[{"role": "user", "content": self.prompt}],
-                    stream=True,
-                    **filtered,
-                ):
-                    yield self._extract_text_from_chunk(chunk)
-                # try:
-                #     if "text" in chunk["choices"][0]:
-                #         yield chunk["choices"][0]["text"]
-                #     elif "content" in chunk["choices"][0]["delta"]:
-                #         yield chunk["choices"][0]["delta"]["content"]
-                # except Exception as e:
-                #     logger.debug(
-                #         f"Exception occurred while parsing streaming response: {e}"
-                #     )
+                api_kwargs = {
+                    "model": model,
+                    "messages": [{"role": "user", "content": payload["prompt"]}],
+                    "stream": True,
+                    **filtered,
+                }
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                stream = aqua_client.chat.completions.create(**api_kwargs)
+
+                for chunk in stream:
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
 
+        elif (
+            endpoint_type == PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT
+            and encoded_image != "NA"
+        ):
+            file_type = payload.pop("file_type")
+            if file_type.startswith("image"):
+                api_kwargs = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": payload["prompt"]},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": f"{encoded_image}"},
+                                },
+                            ],
+                        }
+                    ],
+                    "stream": True,
+                    **filtered,
+                }
+
+                # Add chat_template for image-based chat completions
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                response = aqua_client.chat.completions.create(**api_kwargs)
+
+            elif file_type.startswith("audio"):
+                api_kwargs = {
+                    "model": model,
+                    "messages": [
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": payload["prompt"]},
+                                {
+                                    "type": "audio_url",
+                                    "audio_url": {"url": f"{encoded_image}"},
+                                },
+                            ],
+                        }
+                    ],
+                    "stream": True,
+                    **filtered,
+                }
+
+                # Add chat_template for audio-based chat completions
+                if "chat_template" in payload:
+                    chat_template = payload.pop("chat_template")
+                    api_kwargs["extra_body"] = {"chat_template": chat_template}
+
+                response = aqua_client.chat.completions.create(**api_kwargs)
+
+            try:
+                for chunk in response:
+                    piece = self._extract_text_from_chunk(chunk)
+                    if piece:
+                        yield piece
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
         elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
             try:
-                for chunk in aqua_client.self.session.completions.create(
-                    prompt=self.prompt, stream=True, model=model, **filtered
+                for chunk in aqua_client.completions.create(
+                    prompt=payload["prompt"], stream=True, model=model, **filtered
                 ):
-                    yield self._extract_text_from_chunk(chunk)
-                # try:
-                #     yield chunk["choices"][0]["text"]
-                # except Exception as e:
-                #     logger.debug(
-                #         f"Exception occurred while parsing streaming response: {e}"
-                #     )
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
 
         elif endpoint_type == PredictEndpoints.RESPONSES:
-            response = aqua_client.responses.create(
-                prompt=self.prompt, stream=True, model=model, **filtered
-            )
+            api_kwargs = {
+                "model": model,
+                "input": payload["prompt"],
+                "stream": True,
+            }
+
+            if "temperature" in responses_filtered:
+                api_kwargs["temperature"] = responses_filtered["temperature"]
+            if "top_p" in responses_filtered:
+                api_kwargs["top_p"] = responses_filtered["top_p"]
+
+            response = aqua_client.responses.create(**api_kwargs)
             try:
                 for chunk in response:
-                    yield self._extract_text_from_chunk(chunk)
+                    if chunk:
+                        piece = self._extract_text_from_chunk(chunk)
+                        if piece:
+                            yield piece
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
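
The parameter gating near the top of this method can be checked in isolation. A small worked example with a hypothetical payload (the allowed set is abbreviated to the names visible in this hunk):

allowed = {"max_tokens", "temperature", "top_p", "stop", "user", "echo"}
responses_allowed = {"temperature", "top_p"}

payload = {
    "model": "odsc-llm",                      # hypothetical model name
    "prompt": "Hi",
    "temperature": 0.2,
    "stop": [],
    "endpoint_type": "/v1/chat/completions",  # assumed enum value
}

# normalize: an empty stop list becomes None, as in the handler
if payload.get("stop") == []:
    payload["stop"] = None

model = payload.pop("model")
filtered = {k: v for k, v in payload.items() if k in allowed}
responses_filtered = {k: v for k, v in payload.items() if k in responses_allowed}

assert filtered == {"temperature": 0.2, "stop": None}
assert responses_filtered == {"temperature": 0.2}
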
@@ -410,19 +489,20 @@ def post(self, model_deployment_id):
         prompt = input_data.get("prompt")
         messages = input_data.get("messages")
 
         if not prompt and not messages:
             raise HTTPError(
                 400, Errors.MISSING_REQUIRED_PARAMETER.format("prompt/messages")
             )
         if not input_data.get("model"):
             raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
-        route_override_header = self.request.headers.get("route", None)
         self.set_header("Content-Type", "text/event-stream")
         response_gen = self._get_model_deployment_response(
-            model_deployment_id, input_data, route_override_header
+            model_deployment_id, input_data
         )
         try:
             for chunk in response_gen:
                 self.write(chunk)
                 self.flush()
             self.finish()
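
End to end, the handler emits plain text pieces over a text/event-stream response. A hedged client-side sketch of consuming it; the route path and payload keys are assumptions for illustration, not the extension's documented API:

import requests

resp = requests.post(
    "http://localhost:8888/aqua/deployments/<md_ocid>/inference",  # hypothetical route
    json={
        "model": "odsc-llm",                 # hypothetical model name
        "prompt": "Hello",
        "endpoint_type": "/v1/completions",  # assumed enum value
    },
    stream=True,
)
resp.raise_for_status()
for piece in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(piece, end="", flush=True)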