 
 from tornado.web import HTTPError
 
-from ads.aqua.app import logger
 from ads.aqua.client.client import Client, ExtendedRequestError
+from ads.aqua.client.openai_client import OpenAI
 from ads.aqua.common.decorator import handle_exceptions
 from ads.aqua.common.enums import PredictEndpoints
 from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -178,6 +178,43 @@ def list_shapes(self): |
 
 
 class AquaDeploymentStreamingInferenceHandler(AquaAPIhandler):
+    def _extract_text_from_choice(self, choice):
+        # choice may be a dict or an object
+        if isinstance(choice, dict):
+            # streaming chunk: {"delta": {"content": "..."}}
+            delta = choice.get("delta")
+            if isinstance(delta, dict):
+                return delta.get("content") or delta.get("text")
+            # non-streaming: {"message": {"content": "..."}}
+            msg = choice.get("message")
+            if isinstance(msg, dict):
+                return msg.get("content") or msg.get("text")
+            # fallback: top-level fields
+            return choice.get("text") or choice.get("content")
+        # object-like choice
+        delta = getattr(choice, "delta", None)
+        if delta is not None:
+            return getattr(delta, "content", None) or getattr(delta, "text", None)
+        msg = getattr(choice, "message", None)
+        if msg is not None:
+            if isinstance(msg, str):
+                return msg
+            return getattr(msg, "content", None) or getattr(msg, "text", None)
+        return getattr(choice, "text", None) or getattr(choice, "content", None)
+
+    def _extract_text_from_chunk(self, chunk):
+        if isinstance(chunk, dict):
+            choices = chunk.get("choices") or []
+            if choices:
+                return self._extract_text_from_choice(choices[0])
+            # fallback: top-level fields
+            return chunk.get("text") or chunk.get("content")
+        # object-like chunk
+        choices = getattr(chunk, "choices", None)
+        if choices:
+            return self._extract_text_from_choice(choices[0])
+        # Responses API streaming events carry the text as a top-level
+        # string `delta` rather than under `choices`.
+        delta = getattr(chunk, "delta", None)
+        if isinstance(delta, str):
+            return delta
+        return getattr(chunk, "text", None) or getattr(chunk, "content", None)
+
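For reference, a quick test-style sketch of the chunk shapes `_extract_text_from_chunk` normalizes. The module path and the `object.__new__` shortcut, which skips tornado's handler `__init__` and is safe only because these helpers touch no instance state, are assumptions for illustration, not part of the change:

```python
from types import SimpleNamespace

# Assumed module path; adjust to wherever the handler actually lives.
from ads.aqua.extension.deployment_handler import (
    AquaDeploymentStreamingInferenceHandler,
)

# Bypass tornado's __init__; the extraction helpers are pure functions.
h = object.__new__(AquaDeploymentStreamingInferenceHandler)

# Chat-completions streaming chunk, dict form: text sits under delta.content.
assert h._extract_text_from_chunk({"choices": [{"delta": {"content": "Hel"}}]}) == "Hel"

# Text-completions chunk, dict form: text sits at the top of the choice.
assert h._extract_text_from_chunk({"choices": [{"text": "lo"}]}) == "lo"

# SDK object form, e.g. openai's ChatCompletionChunk.
obj = SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="!", text=None))])
assert h._extract_text_from_chunk(obj) == "!"

# Responses API streaming event: the text arrives as a top-level string delta.
event = SimpleNamespace(type="response.output_text.delta", delta="?")
assert h._extract_text_from_chunk(event) == "?"
```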
     def _get_model_deployment_response(
         self,
         model_deployment_id: str,
@@ -233,45 +270,78 @@ def _get_model_deployment_response( |
         endpoint_type = model_deployment.environment_variables.get(
             "MODEL_DEPLOY_PREDICT_ENDPOINT", PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT
         )
-        aqua_client = Client(endpoint=endpoint)
+        aqua_client = OpenAI(base_url=endpoint)
+
+        allowed = {
+            "max_tokens",
+            "temperature",
+            "top_p",
+            "stop",
+            "n",
+            "presence_penalty",
+            "frequency_penalty",
+            "logprobs",
+            "user",
+            "echo",
+        }
+
+        # normalize and filter the caller-supplied parameters
+        if self.params.get("stop") == []:
+            self.params["stop"] = None
+
+        model = self.params.pop("model", None)
+        filtered = {k: v for k, v in self.params.items() if k in allowed}
 
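A minimal sketch of what the normalize-and-filter step yields, reusing the `allowed` set defined above; the parameter values are illustrative. Keys outside `allowed` (such as `stream`) are dropped, an empty `stop` list is coerced to `None`, and `model` is pulled out to be passed explicitly:

```python
params = {"model": "odsc-llm", "temperature": 0.2, "stream": True, "stop": []}

if params.get("stop") == []:
    params["stop"] = None
model = params.pop("model", None)
filtered = {k: v for k, v in params.items() if k in allowed}

assert model == "odsc-llm"
assert filtered == {"temperature": 0.2, "stop": None}  # "stream" was dropped
```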
         if PredictEndpoints.CHAT_COMPLETIONS_ENDPOINT in (
             endpoint_type,
             route_override_header,
         ):
             try:
-                for chunk in aqua_client.chat(
-                    messages=payload.pop("messages"),
-                    payload=payload,
+                for chunk in aqua_client.chat.completions.create(
+                    model=model,
+                    messages=[{"role": "user", "content": self.prompt}],
                     stream=True,
+                    **filtered,
                 ):
-                    try:
-                        if "text" in chunk["choices"][0]:
-                            yield chunk["choices"][0]["text"]
-                        elif "content" in chunk["choices"][0]["delta"]:
-                            yield chunk["choices"][0]["delta"]["content"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                    text = self._extract_text_from_chunk(chunk)
+                    if text:
+                        yield text
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
                 raise HTTPError(500, str(ex))
 
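For comparison, this is roughly what the new OpenAI-compatible chat call looks like when exercised directly against a deployment; the base URL and model name are placeholders, not values from this change:

```python
from ads.aqua.client.openai_client import OpenAI

client = OpenAI(base_url="https://<deployment-endpoint>/predict/v1")  # placeholder URL
for chunk in client.chat.completions.create(
    model="odsc-llm",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
):
    # Streamed chunks carry incremental text under choices[0].delta.content.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```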
         elif endpoint_type == PredictEndpoints.TEXT_COMPLETIONS_ENDPOINT:
             try:
-                for chunk in aqua_client.generate(
-                    prompt=payload.pop("prompt"),
-                    payload=payload,
-                    stream=True,
+                for chunk in aqua_client.completions.create(
+                    prompt=self.prompt, stream=True, model=model, **filtered
                 ):
-                    try:
-                        yield chunk["choices"][0]["text"]
-                    except Exception as e:
-                        logger.debug(
-                            f"Exception occurred while parsing streaming response: {e}"
-                        )
+                    text = self._extract_text_from_chunk(chunk)
+                    if text:
+                        yield text
+            except ExtendedRequestError as ex:
+                raise HTTPError(400, str(ex))
+            except Exception as ex:
+                raise HTTPError(500, str(ex))
+
+        elif endpoint_type == PredictEndpoints.RESPONSES:
+            try:
+                # The Responses API takes `input` rather than `prompt`.
+                response = aqua_client.responses.create(
+                    input=self.prompt, stream=True, model=model, **filtered
+                )
+                for chunk in response:
+                    text = self._extract_text_from_chunk(chunk)
+                    if text:
+                        yield text
             except ExtendedRequestError as ex:
                 raise HTTPError(400, str(ex))
             except Exception as ex:
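Finally, a hypothetical caller-side sketch of how the generator is meant to be consumed; the wrapper function and the argument list (everything beyond the deployment id is elided here) are assumptions, while `write` and `flush` are the standard tornado `RequestHandler` primitives:

```python
def relay_stream(handler, model_deployment_id: str):
    # The generator yields plain text fragments in arrival order.
    for fragment in handler._get_model_deployment_response(model_deployment_id):
        handler.write(fragment)  # append the fragment to the HTTP response
        handler.flush()          # push it to the client immediately
```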