Skip to content

Commit 87640f1

Browse files
committed
Streaming improvements
1 parent 8c26807 commit 87640f1

File tree

2 files changed

+549
-81
lines changed

2 files changed

+549
-81
lines changed

src/app/endpoints/streaming_query.py

Lines changed: 212 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33
import json
44
import logging
55
import re
6-
from typing import Any, AsyncIterator
6+
from typing import Any, AsyncIterator, Iterator
77

88
from cachetools import TTLCache # type: ignore
99

1010
from llama_stack_client import APIConnectionError
1111
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
1212
from llama_stack_client import AsyncLlamaStackClient # type: ignore
13-
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
1413
from llama_stack_client.types import UserMessage # type: ignore
1514

15+
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
16+
from llama_stack_client.types.shared import ToolCall
17+
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
18+
1619
from fastapi import APIRouter, HTTPException, Request, Depends, status
1720
from fastapi.responses import StreamingResponse
1821

@@ -122,7 +125,9 @@ def stream_end_event(metadata_map: dict) -> str:
122125
)
123126

124127

125-
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
128+
# pylint: disable=R1702
129+
# pylint: disable=R0912
130+
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
126131
"""Build a streaming event from a chunk response.
127132
128133
This function processes chunks from the LLama Stack streaming response and formats
@@ -137,52 +142,227 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
137142
chunk_id: The current chunk ID counter (gets incremented for each token)
138143
139144
Returns:
140-
str | None: A formatted SSE data string with event information, or None if
141-
the chunk doesn't contain processable event data
145+
Iterator[str]: An iterable list of formatted SSE data strings with event information
142146
"""
143-
# pylint: disable=R1702
144-
if hasattr(chunk.event, "payload"):
145-
if chunk.event.payload.event_type == "step_progress":
146-
if hasattr(chunk.event.payload.delta, "text"):
147-
text = chunk.event.payload.delta.text
148-
return format_stream_data(
147+
# -----------------------------------
148+
# Error handling
149+
# -----------------------------------
150+
if hasattr(chunk, "error"):
151+
yield format_stream_data(
152+
{
153+
"event": "error",
154+
"data": {
155+
"id": chunk_id,
156+
"token": chunk.error["message"],
157+
},
158+
}
159+
)
160+
return
161+
162+
# -----------------------------------
163+
# Turn handling
164+
# -----------------------------------
165+
if chunk.event.payload.event_type in {"turn_start", "turn_awaiting_input"}:
166+
yield format_stream_data(
167+
{
168+
"event": "token",
169+
"data": {
170+
"id": chunk_id,
171+
"token": "",
172+
},
173+
}
174+
)
175+
return
176+
177+
if chunk.event.payload.event_type == "turn_complete":
178+
yield format_stream_data(
179+
{
180+
"event": "turn_complete",
181+
"data": {
182+
"id": chunk_id,
183+
"token": chunk.event.payload.turn.output_message.content,
184+
},
185+
}
186+
)
187+
return
188+
189+
# -----------------------------------
190+
# Shield handling
191+
# -----------------------------------
192+
if chunk.event.payload.step_type == "shield_call":
193+
if chunk.event.payload.event_type == "step_complete":
194+
violation = chunk.event.payload.step_details.violation
195+
if not violation:
196+
yield format_stream_data(
149197
{
150198
"event": "token",
151199
"data": {
152200
"id": chunk_id,
153201
"role": chunk.event.payload.step_type,
154-
"token": text,
202+
"token": "No Violation",
155203
},
156204
}
157205
)
158-
if (
159-
chunk.event.payload.event_type == "step_complete"
160-
and chunk.event.payload.step_details.step_type == "tool_execution"
161-
):
162-
for r in chunk.event.payload.step_details.tool_responses:
163-
if r.tool_name == "knowledge_search" and r.content:
164-
for text_content_item in r.content:
165-
if isinstance(text_content_item, TextContentItem):
166-
for match in METADATA_PATTERN.findall(
167-
text_content_item.text
168-
):
169-
meta = json.loads(match.replace("'", '"'))
170-
metadata_map[meta["document_id"]] = meta
171-
if chunk.event.payload.step_details.tool_calls:
172-
tool_name = str(
173-
chunk.event.payload.step_details.tool_calls[0].tool_name
206+
else:
207+
yield format_stream_data(
208+
{
209+
"event": "token",
210+
"data": {
211+
"id": chunk_id,
212+
"role": chunk.event.payload.step_type,
213+
"token": f"{violation.metadata} {violation.user_message}",
214+
},
215+
}
174216
)
175-
return format_stream_data(
217+
return
218+
219+
# -----------------------------------
220+
# Inference handling
221+
# -----------------------------------
222+
if chunk.event.payload.step_type == "inference":
223+
if chunk.event.payload.event_type == "step_start":
224+
yield format_stream_data(
225+
{
226+
"event": "token",
227+
"data": {
228+
"id": chunk_id,
229+
"role": chunk.event.payload.step_type,
230+
"token": "",
231+
},
232+
}
233+
)
234+
235+
elif chunk.event.payload.event_type == "step_progress":
236+
if chunk.event.payload.delta.type == "tool_call":
237+
if isinstance(chunk.event.payload.delta.tool_call, str):
238+
yield format_stream_data(
239+
{
240+
"event": "tool_call",
241+
"data": {
242+
"id": chunk_id,
243+
"role": chunk.event.payload.step_type,
244+
"token": chunk.event.payload.delta.tool_call,
245+
},
246+
}
247+
)
248+
elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
249+
yield format_stream_data(
250+
{
251+
"event": "tool_call",
252+
"data": {
253+
"id": chunk_id,
254+
"role": chunk.event.payload.step_type,
255+
"token": chunk.event.payload.delta.tool_call.tool_name,
256+
},
257+
}
258+
)
259+
260+
elif chunk.event.payload.delta.type == "text":
261+
yield format_stream_data(
176262
{
177263
"event": "token",
178264
"data": {
179265
"id": chunk_id,
180266
"role": chunk.event.payload.step_type,
181-
"token": tool_name,
267+
"token": chunk.event.payload.delta.text,
182268
},
183269
}
184270
)
185-
return None
271+
272+
return
273+
274+
# -----------------------------------
275+
# Tool Execution handling
276+
# -----------------------------------
277+
if chunk.event.payload.step_type == "tool_execution":
278+
if chunk.event.payload.event_type == "step_start":
279+
yield format_stream_data(
280+
{
281+
"event": "tool_call",
282+
"data": {
283+
"id": chunk_id,
284+
"role": chunk.event.payload.step_type,
285+
"token": "",
286+
},
287+
}
288+
)
289+
290+
elif chunk.event.payload.event_type == "step_complete":
291+
for t in chunk.event.payload.step_details.tool_calls:
292+
yield format_stream_data(
293+
{
294+
"event": "tool_call",
295+
"data": {
296+
"id": chunk_id,
297+
"role": chunk.event.payload.step_type,
298+
"token": f"Tool:{t.tool_name} arguments:{t.arguments}",
299+
},
300+
}
301+
)
302+
303+
for r in chunk.event.payload.step_details.tool_responses:
304+
if r.tool_name == "query_from_memory":
305+
inserted_context = interleaved_content_as_str(r.content)
306+
yield format_stream_data(
307+
{
308+
"event": "tool_call",
309+
"data": {
310+
"id": chunk_id,
311+
"role": chunk.event.payload.step_type,
312+
"token": f"Fetched {len(inserted_context)} bytes from memory",
313+
},
314+
}
315+
)
316+
317+
elif r.tool_name == "knowledge_search" and r.content:
318+
summary = ""
319+
for i, text_content_item in enumerate(r.content):
320+
if isinstance(text_content_item, TextContentItem):
321+
if i == 0:
322+
summary = text_content_item.text
323+
summary = summary[: summary.find("\n")]
324+
for match in METADATA_PATTERN.findall(
325+
text_content_item.text
326+
):
327+
meta = json.loads(match.replace("'", '"'))
328+
metadata_map[meta["document_id"]] = meta
329+
yield format_stream_data(
330+
{
331+
"event": "tool_call",
332+
"data": {
333+
"id": chunk_id,
334+
"role": chunk.event.payload.step_type,
335+
"token": f"Tool:{r.tool_name} summary:{summary}",
336+
},
337+
}
338+
)
339+
340+
else:
341+
yield format_stream_data(
342+
{
343+
"event": "tool_call",
344+
"data": {
345+
"id": chunk_id,
346+
"role": chunk.event.payload.step_type,
347+
"token": f"Tool:{r.tool_name} response:{r.content}",
348+
},
349+
}
350+
)
351+
352+
return
353+
354+
# -----------------------------------
355+
# Catch-all for everything else
356+
# -----------------------------------
357+
yield format_stream_data(
358+
{
359+
"event": "heartbeat",
360+
"data": {
361+
"id": chunk_id,
362+
"token": "heartbeat",
363+
},
364+
}
365+
)
186366

187367

188368
@router.post("/streaming_query")
@@ -222,7 +402,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
222402
yield stream_start_event(conversation_id)
223403

224404
async for chunk in turn_response:
225-
if event := stream_build_event(chunk, chunk_id, metadata_map):
405+
for event in stream_build_event(chunk, chunk_id, metadata_map):
226406
complete_response += json.loads(event.replace("data: ", ""))[
227407
"data"
228408
]["token"]

0 commit comments

Comments (0)