
Commit 24a8048

Streaming improvements
1 parent 8c26807 commit 24a8048

File tree: 2 files changed (+582, -82 lines)

src/app/endpoints/streaming_query.py

Lines changed: 242 additions & 33 deletions
```diff
@@ -3,16 +3,19 @@
 import json
 import logging
 import re
-from typing import Any, AsyncIterator
+from typing import Any, AsyncIterator, Iterator
 
 from cachetools import TTLCache  # type: ignore
 
 from llama_stack_client import APIConnectionError
 from llama_stack_client.lib.agents.agent import AsyncAgent  # type: ignore
 from llama_stack_client import AsyncLlamaStackClient  # type: ignore
-from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
 from llama_stack_client.types import UserMessage  # type: ignore
 
+from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
+from llama_stack_client.types.shared import ToolCall
+from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
+
 from fastapi import APIRouter, HTTPException, Request, Depends, status
 from fastapi.responses import StreamingResponse
```
```diff
@@ -122,7 +125,7 @@ def stream_end_event(metadata_map: dict) -> str:
     )
 
 
-def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
+def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
     """Build a streaming event from a chunk response.
 
     This function processes chunks from the LLama Stack streaming response and formats
```
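
`format_stream_data` is used by every handler in this commit but is itself unchanged, so it does not appear in the diff. Judging by the `event.replace("data: ", "")` round-trip in `response_generator` below, it wraps a payload dict as a single Server-Sent Events message. A minimal sketch under that assumption:

```python
import json


def format_stream_data(d: dict) -> str:
    # Sketch only; the real helper lives elsewhere in this module.
    # Serialize the payload dict as one SSE "data:" message, terminated
    # by the blank line that ends an SSE event.
    return f"data: {json.dumps(d)}\n\n"
```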
```diff
@@ -137,52 +140,258 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
         chunk_id: The current chunk ID counter (gets incremented for each token)
 
     Returns:
-        str | None: A formatted SSE data string with event information, or None if
-        the chunk doesn't contain processable event data
+        Iterator[str]: An iterable list of formatted SSE data strings with event information
     """
-    # pylint: disable=R1702
-    if hasattr(chunk.event, "payload"):
-        if chunk.event.payload.event_type == "step_progress":
-            if hasattr(chunk.event.payload.delta, "text"):
-                text = chunk.event.payload.delta.text
-                return format_stream_data(
+    if hasattr(chunk, "error"):
+        yield from _handle_error_event(chunk, chunk_id)
+        return
+
+    event_type = chunk.event.payload.event_type
+    step_type = getattr(chunk.event.payload, "step_type", None)
+
+    if event_type in {"turn_start", "turn_awaiting_input"}:
+        yield from _handle_turn_start_event(chunk_id)
+    elif event_type == "turn_complete":
+        yield from _handle_turn_complete_event(chunk, chunk_id)
+    elif step_type == "shield_call":
+        yield from _handle_shield_event(chunk, chunk_id)
+    elif step_type == "inference":
+        yield from _handle_inference_event(chunk, chunk_id)
+    elif step_type == "tool_execution":
+        yield from _handle_tool_execution_event(chunk, chunk_id, metadata_map)
+    else:
+        yield from _handle_heartbeat_event(chunk_id)
+
+
+# -----------------------------------
+# Error handling
+# -----------------------------------
+def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
+    yield format_stream_data(
+        {
+            "event": "error",
+            "data": {
+                "id": chunk_id,
+                "token": chunk.error["message"],
+            },
+        }
+    )
+
+
+# -----------------------------------
+# Turn handling
+# -----------------------------------
+def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
+    yield format_stream_data(
+        {
+            "event": "token",
+            "data": {
+                "id": chunk_id,
+                "token": "",
+            },
+        }
+    )
+
+
+def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
+    yield format_stream_data(
+        {
+            "event": "turn_complete",
+            "data": {
+                "id": chunk_id,
+                "token": chunk.event.payload.turn.output_message.content,
+            },
+        }
+    )
+
+
+# -----------------------------------
+# Shield handling
+# -----------------------------------
+def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
+    if chunk.event.payload.event_type == "step_complete":
+        violation = chunk.event.payload.step_details.violation
+        if not violation:
+            yield format_stream_data(
+                {
+                    "event": "token",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": "No Violation",
+                    },
+                }
+            )
+        else:
+            violation = (
+                f"Violation: {violation.user_message} (Metadata: {violation.metadata})"
+            )
+            yield format_stream_data(
+                {
+                    "event": "token",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": violation,
+                    },
+                }
+            )
+
+
+# -----------------------------------
+# Inference handling
+# -----------------------------------
+def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
+    if chunk.event.payload.event_type == "step_start":
+        yield format_stream_data(
+            {
+                "event": "token",
+                "data": {
+                    "id": chunk_id,
+                    "role": chunk.event.payload.step_type,
+                    "token": "",
+                },
+            }
+        )
+
+    elif chunk.event.payload.event_type == "step_progress":
+        if chunk.event.payload.delta.type == "tool_call":
+            if isinstance(chunk.event.payload.delta.tool_call, str):
+                yield format_stream_data(
                     {
-                        "event": "token",
+                        "event": "tool_call",
                         "data": {
                             "id": chunk_id,
                             "role": chunk.event.payload.step_type,
-                            "token": text,
+                            "token": chunk.event.payload.delta.tool_call,
                         },
                     }
                 )
-        if (
-            chunk.event.payload.event_type == "step_complete"
-            and chunk.event.payload.step_details.step_type == "tool_execution"
-        ):
-            for r in chunk.event.payload.step_details.tool_responses:
-                if r.tool_name == "knowledge_search" and r.content:
-                    for text_content_item in r.content:
-                        if isinstance(text_content_item, TextContentItem):
-                            for match in METADATA_PATTERN.findall(
-                                text_content_item.text
-                            ):
+            elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
+                yield format_stream_data(
+                    {
+                        "event": "tool_call",
+                        "data": {
+                            "id": chunk_id,
+                            "role": chunk.event.payload.step_type,
+                            "token": chunk.event.payload.delta.tool_call.tool_name,
+                        },
+                    }
+                )
+
+        elif chunk.event.payload.delta.type == "text":
+            yield format_stream_data(
+                {
+                    "event": "token",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": chunk.event.payload.delta.text,
+                    },
+                }
+            )
+
+
+# -----------------------------------
+# Tool Execution handling
+# -----------------------------------
+# pylint: disable=R1702,R0912
+def _handle_tool_execution_event(
+    chunk: Any, chunk_id: int, metadata_map: dict
+) -> Iterator[str]:
+    if chunk.event.payload.event_type == "step_start":
+        yield format_stream_data(
+            {
+                "event": "tool_call",
+                "data": {
+                    "id": chunk_id,
+                    "role": chunk.event.payload.step_type,
+                    "token": "",
+                },
+            }
+        )
+
+    elif chunk.event.payload.event_type == "step_complete":
+        for t in chunk.event.payload.step_details.tool_calls:
+            yield format_stream_data(
+                {
+                    "event": "tool_call",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": f"Tool:{t.tool_name} arguments:{t.arguments}",
+                    },
+                }
+            )
+
+        for r in chunk.event.payload.step_details.tool_responses:
+            if r.tool_name == "query_from_memory":
+                inserted_context = interleaved_content_as_str(r.content)
+                yield format_stream_data(
+                    {
+                        "event": "tool_call",
+                        "data": {
+                            "id": chunk_id,
+                            "role": chunk.event.payload.step_type,
+                            "token": f"Fetched {len(inserted_context)} bytes from memory",
+                        },
+                    }
+                )
+
+            elif r.tool_name == "knowledge_search" and r.content:
+                summary = ""
+                for i, text_content_item in enumerate(r.content):
+                    if isinstance(text_content_item, TextContentItem):
+                        if i == 0:
+                            summary = text_content_item.text
+                            newline_pos = summary.find("\n")
+                            if newline_pos > 0:
+                                summary = summary[:newline_pos]
+                        for match in METADATA_PATTERN.findall(text_content_item.text):
+                            try:
                                 meta = json.loads(match.replace("'", '"'))
-                                metadata_map[meta["document_id"]] = meta
-            if chunk.event.payload.step_details.tool_calls:
-                tool_name = str(
-                    chunk.event.payload.step_details.tool_calls[0].tool_name
+                                if "document_id" in meta:
+                                    metadata_map[meta["document_id"]] = meta
+                            except (json.JSONDecodeError, KeyError) as e:
+                                logger.warning("Failed to parse metadata: %s", e)
+
+                yield format_stream_data(
+                    {
+                        "event": "tool_call",
+                        "data": {
+                            "id": chunk_id,
+                            "role": chunk.event.payload.step_type,
+                            "token": f"Tool:{r.tool_name} summary:{summary}",
+                        },
+                    }
                 )
-                return format_stream_data(
+
+            else:
+                yield format_stream_data(
                     {
-                        "event": "token",
+                        "event": "tool_call",
                         "data": {
                             "id": chunk_id,
                             "role": chunk.event.payload.step_type,
-                            "token": tool_name,
+                            "token": f"Tool:{r.tool_name} response:{r.content}",
                         },
                     }
                 )
-    return None
+
+
+# -----------------------------------
+# Catch-all for everything else
+# -----------------------------------
+def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
+    yield format_stream_data(
+        {
+            "event": "heartbeat",
+            "data": {
+                "id": chunk_id,
+                "token": "heartbeat",
+            },
+        }
+    )
 
 
 @router.post("/streaming_query")
```
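
The `knowledge_search` branch above now wraps the metadata parse in `try`/`except` and checks for `document_id` instead of assuming well-formed tool output. `METADATA_PATTERN` is defined earlier in the file and is not part of this diff; the sketch below assumes it captures a Python-style dict after a `Metadata:` marker, which is consistent with the `match.replace("'", '"')` normalization before `json.loads`:

```python
import json
import re

# Assumed pattern shape; the real METADATA_PATTERN is defined earlier in the file.
METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+\})\n")

# Hypothetical knowledge_search output, for illustration only.
sample = (
    "Result 1\n"
    "Content: some retrieved passage\n"
    "\nMetadata: {'docs_url': 'https://example.com/doc', 'document_id': 'doc-1'}\n"
)

metadata_map: dict = {}
for match in METADATA_PATTERN.findall(sample):
    try:
        # The tool emits single-quoted dicts, so normalize to JSON first.
        meta = json.loads(match.replace("'", '"'))
        if "document_id" in meta:
            metadata_map[meta["document_id"]] = meta
    except (json.JSONDecodeError, KeyError) as e:
        print(f"Failed to parse metadata: {e}")

print(metadata_map)  # {'doc-1': {'docs_url': ..., 'document_id': 'doc-1'}}
```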
```diff
@@ -222,7 +431,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
         yield stream_start_event(conversation_id)
 
         async for chunk in turn_response:
-            if event := stream_build_event(chunk, chunk_id, metadata_map):
+            for event in stream_build_event(chunk, chunk_id, metadata_map):
                 complete_response += json.loads(event.replace("data: ", ""))[
                     "data"
                 ]["token"]
```
