Skip to content

Commit 65d4cba

Browse files
committed
Streaming improvements
1 parent 8c26807 commit 65d4cba

File tree

1 file changed

+243
-34
lines changed

1 file changed

+243
-34
lines changed

src/app/endpoints/streaming_query.py

Lines changed: 243 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33
import json
44
import logging
55
import re
6-
from typing import Any, AsyncIterator
6+
from typing import Any, AsyncIterator, Iterator
77

88
from cachetools import TTLCache # type: ignore
99

1010
from llama_stack_client import APIConnectionError
1111
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
1212
from llama_stack_client import AsyncLlamaStackClient # type: ignore
13-
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
1413
from llama_stack_client.types import UserMessage # type: ignore
1514

15+
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
16+
from llama_stack_client.types.shared import ToolCall
17+
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
18+
1619
from fastapi import APIRouter, HTTPException, Request, Depends, status
1720
from fastapi.responses import StreamingResponse
1821

@@ -122,7 +125,9 @@ def stream_end_event(metadata_map: dict) -> str:
122125
)
123126

124127

125-
# pylint: disable=R1702
# pylint: disable=R0912
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
    """Build streaming events from a chunk response.

    This function processes chunks from the LLama Stack streaming response and
    formats them as Server-Sent Events (SSE) data strings.  Chunks are
    dispatched by event/step type to dedicated private helpers.

    Args:
        chunk: A chunk object from the streaming response.
        chunk_id: The current chunk ID counter.
        metadata_map: Mutable mapping updated in place with document metadata
            extracted from knowledge-search tool responses.

    Yields:
        Iterator[str]: Formatted SSE data strings with event information.
    """
    # -----------------------------------
    # Error handling
    # -----------------------------------
    if hasattr(chunk, "error"):
        yield format_stream_data(
            {
                "event": "error",
                "data": {
                    "id": chunk_id,
                    "token": chunk.error["message"],
                },
            }
        )
        return

    event_type = chunk.event.payload.event_type
    # Turn-level payloads carry no 'step_type'; default to None so the
    # dispatch below cannot raise AttributeError for unexpected payloads.
    step_type = getattr(chunk.event.payload, "step_type", None)

    # -----------------------------------
    # Turn handling
    # -----------------------------------
    if event_type in {"turn_start", "turn_awaiting_input", "turn_complete"}:
        yield from _handle_turn_event(chunk, chunk_id)
        return

    # -----------------------------------
    # Shield handling
    # -----------------------------------
    if event_type == "step_complete" and step_type == "shield_call":
        yield from _handle_shield_event(chunk, chunk_id)
        return

    # -----------------------------------
    # Inference handling
    # -----------------------------------
    if step_type == "inference":
        yield from _handle_inference_event(chunk, chunk_id)
        return

    # -----------------------------------
    # Tool Execution handling
    # -----------------------------------
    if step_type == "tool_execution":
        yield from _handle_tool_execution_event(chunk, chunk_id, metadata_map)
        return

    # -----------------------------------
    # Catch-all for everything else
    # -----------------------------------
    # Surface unknown chunks explicitly instead of emitting a debug
    # placeholder token (previously the literal string "manstis").
    yield format_stream_data(
        {
            "event": "error",
            "data": {
                "id": chunk_id,
                "token": f"Unhandled event type: {event_type}",
            },
        }
    )


def _handle_turn_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Emit SSE events for turn-level chunks.

    'turn_complete' carries the final output message; 'turn_start' and
    'turn_awaiting_input' emit an empty token as a keep-alive.
    """
    if chunk.event.payload.event_type == "turn_complete":
        yield format_stream_data(
            {
                "event": "turn_complete",
                "data": {
                    "id": chunk_id,
                    "token": chunk.event.payload.turn.output_message.content,
                },
            }
        )
    else:
        yield format_stream_data(
            {
                "event": "token",
                "data": {
                    "id": chunk_id,
                    "token": "",
                },
            }
        )


def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Emit an SSE event describing the outcome of a shield (safety) call."""
    violation = chunk.event.payload.step_details.violation
    if not violation:
        token = "No Violation"
    else:
        token = f"{violation.metadata} {violation.user_message}"
    yield format_stream_data(
        {
            "event": "token",
            "data": {
                "id": chunk_id,
                "role": chunk.event.payload.step_type,
                "token": token,
            },
        }
    )


def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Emit SSE events for inference steps (text and tool-call deltas)."""
    payload = chunk.event.payload
    if payload.event_type == "step_start":
        yield format_stream_data(
            {
                "event": "token",
                "data": {
                    "id": chunk_id,
                    "role": payload.step_type,
                    "token": "",
                },
            }
        )

    elif payload.event_type == "step_progress":
        if payload.delta.type == "tool_call":
            # The tool-call delta is either a raw string fragment or a
            # structured ToolCall object; emit the name in the latter case.
            if isinstance(payload.delta.tool_call, str):
                yield format_stream_data(
                    {
                        "event": "tool_call",
                        "data": {
                            "id": chunk_id,
                            "role": payload.step_type,
                            "token": payload.delta.tool_call,
                        },
                    }
                )
            elif isinstance(payload.delta.tool_call, ToolCall):
                yield format_stream_data(
                    {
                        "event": "tool_call",
                        "data": {
                            "id": chunk_id,
                            "role": payload.step_type,
                            "token": payload.delta.tool_call.tool_name,
                        },
                    }
                )

        elif payload.delta.type == "text":
            yield format_stream_data(
                {
                    "event": "token",
                    "data": {
                        "id": chunk_id,
                        "role": payload.step_type,
                        "token": payload.delta.text,
                    },
                }
            )

    elif payload.event_type == "step_complete":
        yield format_stream_data(
            {
                "event": "step_complete",
                "data": {
                    "id": chunk_id,
                    "token": "",
                },
            }
        )


def _handle_tool_execution_event(
    chunk: Any, chunk_id: int, metadata_map: dict
) -> Iterator[str]:
    """Emit SSE events for tool-execution steps.

    Updates ``metadata_map`` in place with document metadata parsed from
    knowledge-search tool responses.
    """
    payload = chunk.event.payload
    if payload.event_type == "step_start":
        yield format_stream_data(
            {
                "event": "tool_call",
                "data": {
                    "id": chunk_id,
                    # PatternFly Chat UI expects 'role=inference' to render correctly
                    "role": "inference",  # payload.step_type,
                    "token": "",
                },
            }
        )

    elif payload.event_type == "step_complete":
        for tool_call in payload.step_details.tool_calls:
            yield format_stream_data(
                {
                    "event": "tool_call",
                    "data": {
                        "id": chunk_id,
                        # PatternFly Chat UI expects 'role=inference' to render correctly
                        "role": "inference",  # payload.step_type,
                        "token": f"Tool:{tool_call.tool_name} arguments:{tool_call.arguments}",
                    },
                }
            )

        for response in payload.step_details.tool_responses:
            if response.tool_name == "query_from_memory":
                inserted_context = interleaved_content_as_str(response.content)
                yield format_stream_data(
                    {
                        "event": "tool_call",
                        "data": {
                            "id": chunk_id,
                            # PatternFly Chat UI expects 'role=inference' to render correctly
                            "role": "inference",  # payload.step_type,
                            "token": f"Fetched {len(inserted_context)} bytes from memory",
                        },
                    }
                )

            elif response.tool_name == "knowledge_search" and response.content:
                summary = ""
                for index, text_content_item in enumerate(response.content):
                    if isinstance(text_content_item, TextContentItem):
                        if index == 0:
                            # First line of the first text item is the summary.
                            # partition() is safe when no newline is present
                            # (slicing by find() == -1 dropped the last char).
                            summary = text_content_item.text.partition("\n")[0]
                        for match in METADATA_PATTERN.findall(
                            text_content_item.text
                        ):
                            meta = json.loads(match.replace("'", '"'))
                            metadata_map[meta["document_id"]] = meta
                yield format_stream_data(
                    {
                        "event": "tool_call",
                        "data": {
                            "id": chunk_id,
                            # PatternFly Chat UI expects 'role=inference' to render correctly
                            "role": "inference",  # payload.step_type,
                            "token": f"Tool:{response.tool_name} summary:{summary}\n",
                        },
                    }
                )

            else:
                yield format_stream_data(
                    {
                        "event": "tool_call",
                        "data": {
                            "id": chunk_id,
                            # PatternFly Chat UI expects 'role=inference' to render correctly
                            "role": "inference",  # payload.step_type,
                            "token": f"Tool:{response.tool_name} response:{response.content}",
                        },
                    }
                )

        # We swallow the 'step_complete' event and re-emit 'token' events with the tool details.
        # Ensure we send a 'step_complete' event so the UI knows the 'tool_execution' completed.
        yield format_stream_data(
            {
                "event": "step_complete",
                "data": {
                    "id": chunk_id,
                    "token": "",
                },
            }
        )
186395

187396

188397
@router.post("/streaming_query")
@@ -222,7 +431,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
222431
yield stream_start_event(conversation_id)
223432

224433
async for chunk in turn_response:
225-
if event := stream_build_event(chunk, chunk_id, metadata_map):
434+
for event in stream_build_event(chunk, chunk_id, metadata_map):
226435
complete_response += json.loads(event.replace("data: ", ""))[
227436
"data"
228437
]["token"]

0 commit comments

Comments (0)