Skip to content

Commit 476d12d

Browse files
committed
Streaming improvements
1 parent c698c95 commit 476d12d

File tree

2 files changed

+587
-89
lines changed

2 files changed

+587
-89
lines changed

src/app/endpoints/streaming_query.py

Lines changed: 247 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
import logging
55
import re
66
from json import JSONDecodeError
7-
from typing import Any, AsyncIterator
7+
from typing import Any, AsyncIterator, Iterator
88

99
from cachetools import TTLCache # type: ignore
1010

1111
from llama_stack_client import APIConnectionError
1212
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
1313
from llama_stack_client import AsyncLlamaStackClient # type: ignore
14-
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
1514
from llama_stack_client.types import UserMessage # type: ignore
1615

16+
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
17+
from llama_stack_client.types.shared import ToolCall
18+
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
19+
1720
from fastapi import APIRouter, HTTPException, Request, Depends, status
1821
from fastapi.responses import StreamingResponse
1922

@@ -46,7 +49,8 @@
4649
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)
4750

4851

49-
async def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
52+
# # pylint: disable=R0913,R0917
53+
async def get_agent(
5054
client: AsyncLlamaStackClient,
5155
model_id: str,
5256
system_prompt: str,
@@ -127,7 +131,7 @@ def stream_end_event(metadata_map: dict) -> str:
127131
)
128132

129133

130-
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
134+
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
    """Build streaming events from a chunk response.

    Dispatches a chunk from the Llama Stack streaming response to the
    handler matching its event/step type and yields the resulting
    formatted SSE data strings.

    Args:
        chunk: A single chunk from the streaming turn response.
        chunk_id: The current chunk ID counter.
        metadata_map: Mapping mutated by the tool-execution handler with
            document metadata parsed from tool responses.

    Yields:
        Iterator[str]: Formatted SSE data strings with event information.
    """
    # Error chunks carry no event payload — report the error and stop.
    if hasattr(chunk, "error"):
        yield from _handle_error_event(chunk, chunk_id)
        return

    payload = chunk.event.payload
    event_type = payload.event_type
    # Turn-level payloads have no step_type attribute.
    step_type = getattr(payload, "step_type", None)

    if event_type in {"turn_start", "turn_awaiting_input"}:
        yield from _handle_turn_start_event(chunk_id)
        return
    if event_type == "turn_complete":
        yield from _handle_turn_complete_event(chunk, chunk_id)
        return

    if step_type == "shield_call":
        yield from _handle_shield_event(chunk, chunk_id)
    elif step_type == "inference":
        yield from _handle_inference_event(chunk, chunk_id)
    elif step_type == "tool_execution":
        yield from _handle_tool_execution_event(chunk, chunk_id, metadata_map)
    else:
        # Unrecognized chunks become heartbeats so the stream stays alive.
        yield from _handle_heartbeat_event(chunk_id)
170+
171+
172+
# -----------------------------------
# Error handling
# -----------------------------------
def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Yield an SSE ``error`` event carrying the chunk's error message."""
    error_event = {
        "event": "error",
        "data": {
            "id": chunk_id,
            "token": chunk.error["message"],
        },
    }
    yield format_stream_data(error_event)
185+
186+
187+
# -----------------------------------
# Turn handling
# -----------------------------------
def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
    """Yield an empty ``token`` event marking the start of a turn."""
    start_event = {
        "event": "token",
        "data": {
            "id": chunk_id,
            "token": "",
        },
    }
    yield format_stream_data(start_event)
200+
201+
202+
def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Yield a ``turn_complete`` event carrying the turn's final output message."""
    final_message = chunk.event.payload.turn.output_message.content
    yield format_stream_data(
        {
            "event": "turn_complete",
            "data": {
                "id": chunk_id,
                "token": final_message,
            },
        }
    )
212+
213+
214+
# -----------------------------------
# Shield handling
# -----------------------------------
def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Yield a ``token`` event describing the outcome of a shield-call step.

    Emits nothing until the step completes; on ``step_complete`` reports
    either ``"No Violation"`` or a description of the detected violation.

    Args:
        chunk: A shield-call step chunk from the streaming response.
        chunk_id: The current chunk ID counter.

    Yields:
        A formatted SSE data string for the shield outcome.
    """
    payload = chunk.event.payload
    if payload.event_type != "step_complete":
        return

    violation = payload.step_details.violation
    # Build the token once instead of duplicating the whole event payload
    # in each branch (and avoid reusing the `violation` name for a string).
    if violation:
        token = f"Violation: {violation.user_message} (Metadata: {violation.metadata})"
    else:
        token = "No Violation"

    yield format_stream_data(
        {
            "event": "token",
            "data": {
                "id": chunk_id,
                "role": payload.step_type,
                "token": token,
            },
        }
    )
245+
246+
247+
# -----------------------------------
# Inference handling
# -----------------------------------
def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Yield SSE events for an inference step.

    ``step_start`` produces an empty ``token`` event; ``step_progress``
    produces a ``tool_call`` event for tool-call deltas (streamed either as
    a raw string or a parsed ``ToolCall``) or a ``token`` event for text
    deltas. Other event types yield nothing.
    """
    payload = chunk.event.payload

    if payload.event_type == "step_start":
        yield format_stream_data(
            {
                "event": "token",
                "data": {
                    "id": chunk_id,
                    "role": payload.step_type,
                    "token": "",
                },
            }
        )
        return

    if payload.event_type != "step_progress":
        return

    delta = payload.delta
    if delta.type == "tool_call":
        # The delta streams the call either as raw text or a parsed ToolCall.
        if isinstance(delta.tool_call, str):
            token = delta.tool_call
        elif isinstance(delta.tool_call, ToolCall):
            token = delta.tool_call.tool_name
        else:
            return
        yield format_stream_data(
            {
                "event": "tool_call",
                "data": {
                    "id": chunk_id,
                    "role": payload.step_type,
                    "token": token,
                },
            }
        )
    elif delta.type == "text":
        yield format_stream_data(
            {
                "event": "token",
                "data": {
                    "id": chunk_id,
                    "role": payload.step_type,
                    "token": delta.text,
                },
            }
        )
299+
300+
301+
# -----------------------------------
# Tool Execution handling
# -----------------------------------
def _handle_tool_execution_event(
    chunk: Any, chunk_id: int, metadata_map: dict
) -> Iterator[str]:
    """Yield ``tool_call`` SSE events for a tool-execution step.

    On ``step_start`` emits an empty token. On ``step_complete`` emits one
    event per tool call (name and arguments) and one per tool response:
    ``query_from_memory`` responses report the fetched byte count,
    ``knowledge_search`` responses report a one-line summary (and have
    their document metadata harvested into *metadata_map*), and all other
    responses are echoed verbatim.

    Args:
        chunk: A tool-execution step chunk from the streaming response.
        chunk_id: The current chunk ID counter.
        metadata_map: Mutated in place — document metadata parsed from
            knowledge_search responses is stored keyed by ``document_id``.

    Yields:
        Formatted SSE data strings for the tool-execution step.
    """
    payload = chunk.event.payload
    if payload.event_type == "step_start":
        yield _tool_event(chunk_id, payload.step_type, "")
        return
    if payload.event_type != "step_complete":
        return

    for call in payload.step_details.tool_calls:
        yield _tool_event(
            chunk_id,
            payload.step_type,
            f"Tool:{call.tool_name} arguments:{call.arguments}",
        )

    for response in payload.step_details.tool_responses:
        if response.tool_name == "query_from_memory":
            inserted_context = interleaved_content_as_str(response.content)
            yield _tool_event(
                chunk_id,
                payload.step_type,
                f"Fetched {len(inserted_context)} bytes from memory",
            )
        elif response.tool_name == "knowledge_search" and response.content:
            summary = _summarize_knowledge_search(response.content, metadata_map)
            yield _tool_event(
                chunk_id,
                payload.step_type,
                f"Tool:{response.tool_name} summary:{summary}",
            )
        else:
            yield _tool_event(
                chunk_id,
                payload.step_type,
                f"Tool:{response.tool_name} response:{response.content}",
            )


def _tool_event(chunk_id: int, role: str, token: str) -> str:
    """Format a single ``tool_call`` SSE data string."""
    return format_stream_data(
        {
            "event": "tool_call",
            "data": {
                "id": chunk_id,
                "role": role,
                "token": token,
            },
        }
    )


def _summarize_knowledge_search(content: Any, metadata_map: dict) -> str:
    """Extract a summary line from a knowledge_search response and harvest
    document metadata into *metadata_map*.

    The summary is the first text item's content, truncated at its first
    newline (only when the newline is not at position 0, preserving the
    original truncation behavior). Every text item is scanned for metadata
    blobs; entries containing a ``document_id`` are recorded.
    """
    summary = ""
    for index, item in enumerate(content):
        if not isinstance(item, TextContentItem):
            continue
        if index == 0:
            summary = item.text
            newline_pos = summary.find("\n")
            if newline_pos > 0:
                summary = summary[:newline_pos]
        for match in METADATA_PATTERN.findall(item.text):
            try:
                # Metadata blobs use single quotes; normalize for json.loads.
                meta = json.loads(match.replace("'", '"'))
                if "document_id" in meta:
                    metadata_map[meta["document_id"]] = meta
            except JSONDecodeError:
                logger.debug(
                    "JSONDecodeError was thrown in processing %s",
                    match,
                )
    return summary
196-
return None
389+
390+
391+
# -----------------------------------
# Catch-all for everything else
# -----------------------------------
def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
    """Yield a ``heartbeat`` event for chunks with no recognized event type."""
    heartbeat = {
        "event": "heartbeat",
        "data": {
            "id": chunk_id,
            "token": "heartbeat",
        },
    }
    yield format_stream_data(heartbeat)
197404

198405

199406
@router.post("/streaming_query")
@@ -233,7 +440,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
233440
yield stream_start_event(conversation_id)
234441

235442
async for chunk in turn_response:
236-
if event := stream_build_event(chunk, chunk_id, metadata_map):
443+
for event in stream_build_event(chunk, chunk_id, metadata_map):
237444
complete_response += json.loads(event.replace("data: ", ""))[
238445
"data"
239446
]["token"]

0 commit comments

Comments
 (0)