33import json
44import logging
55import re
6- from typing import Any , AsyncIterator
6+ from typing import Any , AsyncIterator , Iterator
77
88from cachetools import TTLCache # type: ignore
99
1010from llama_stack_client import APIConnectionError
1111from llama_stack_client .lib .agents .agent import AsyncAgent # type: ignore
1212from llama_stack_client import AsyncLlamaStackClient # type: ignore
13- from llama_stack_client .types .shared .interleaved_content_item import TextContentItem
1413from llama_stack_client .types import UserMessage # type: ignore
1514
15+ from llama_stack_client .lib .agents .event_logger import interleaved_content_as_str
16+ from llama_stack_client .types .shared import ToolCall
17+ from llama_stack_client .types .shared .interleaved_content_item import TextContentItem
18+
1619from fastapi import APIRouter , HTTPException , Request , Depends , status
1720from fastapi .responses import StreamingResponse
1821
# pylint: disable=R1702
# pylint: disable=R0912
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
    """Build streaming events from a chunk response.

    This function processes chunks from the Llama Stack streaming response and
    yields formatted Server-Sent-Event (SSE) data strings (produced by
    ``format_stream_data``). Dispatch order: transport errors, turn lifecycle
    events, shield (safety) calls, inference deltas, tool-execution steps; any
    unrecognized chunk falls through to a heartbeat event so the client
    connection stays alive.

    Args:
        chunk: A chunk object from the Llama Stack streaming response.
        chunk_id: The current chunk ID counter. It is only echoed into each
            emitted event here; the caller is responsible for advancing it.
        metadata_map: Mutable dict updated in place with document metadata
            parsed out of ``knowledge_search`` tool responses
            (keyed by ``document_id``).

    Yields:
        Iterator[str]: Formatted SSE data strings with event information.
    """
    # -----------------------------------
    # Error handling
    # -----------------------------------
    if hasattr(chunk, "error"):
        yield format_stream_data(
            {
                "event": "error",
                "data": {
                    "id": chunk_id,
                    "token": chunk.error["message"],
                },
            }
        )
        return

    # -----------------------------------
    # Turn handling
    # -----------------------------------
    if chunk.event.payload.event_type in {"turn_start", "turn_awaiting_input"}:
        # Emit an empty token so the client sees activity at turn boundaries.
        yield format_stream_data(
            {
                "event": "token",
                "data": {
                    "id": chunk_id,
                    "token": "",
                },
            }
        )
        return

    if chunk.event.payload.event_type == "turn_complete":
        yield format_stream_data(
            {
                "event": "turn_complete",
                "data": {
                    "id": chunk_id,
                    "token": chunk.event.payload.turn.output_message.content,
                },
            }
        )
        return

    # -----------------------------------
    # Shield handling
    # -----------------------------------
    if chunk.event.payload.step_type == "shield_call":
        if chunk.event.payload.event_type == "step_complete":
            violation = chunk.event.payload.step_details.violation
            if not violation:
                yield format_stream_data(
                    {
                        "event": "token",
                        "data": {
                            "id": chunk_id,
                            "role": chunk.event.payload.step_type,
                            "token": "No Violation",
                        },
                    }
                )
            else:
                yield format_stream_data(
                    {
                        "event": "token",
                        "data": {
                            "id": chunk_id,
                            "role": chunk.event.payload.step_type,
                            "token": f"{violation.metadata} {violation.user_message}",
                        },
                    }
                )
        return

    # -----------------------------------
    # Inference handling
    # -----------------------------------
    if chunk.event.payload.step_type == "inference":
        if chunk.event.payload.event_type == "step_start":
            yield format_stream_data(
                {
                    "event": "token",
                    "data": {
                        "id": chunk_id,
                        "role": chunk.event.payload.step_type,
                        "token": "",
                    },
                }
            )

        elif chunk.event.payload.event_type == "step_progress":
            if chunk.event.payload.delta.type == "tool_call":
                # The delta may carry either a partial argument string or a
                # fully parsed ToolCall object; emit each form accordingly.
                if isinstance(chunk.event.payload.delta.tool_call, str):
                    yield format_stream_data(
                        {
                            "event": "tool_call",
                            "data": {
                                "id": chunk_id,
                                "role": chunk.event.payload.step_type,
                                "token": chunk.event.payload.delta.tool_call,
                            },
                        }
                    )
                elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
                    yield format_stream_data(
                        {
                            "event": "tool_call",
                            "data": {
                                "id": chunk_id,
                                "role": chunk.event.payload.step_type,
                                "token": chunk.event.payload.delta.tool_call.tool_name,
                            },
                        }
                    )

            elif chunk.event.payload.delta.type == "text":
                yield format_stream_data(
                    {
                        "event": "token",
                        "data": {
                            "id": chunk_id,
                            "role": chunk.event.payload.step_type,
                            "token": chunk.event.payload.delta.text,
                        },
                    }
                )

        return

    # -----------------------------------
    # Tool Execution handling
    # -----------------------------------
    if chunk.event.payload.step_type == "tool_execution":
        if chunk.event.payload.event_type == "step_start":
            yield format_stream_data(
                {
                    "event": "tool_call",
                    "data": {
                        "id": chunk_id,
                        "role": chunk.event.payload.step_type,
                        "token": "",
                    },
                }
            )

        elif chunk.event.payload.event_type == "step_complete":
            for t in chunk.event.payload.step_details.tool_calls:
                yield format_stream_data(
                    {
                        "event": "tool_call",
                        "data": {
                            "id": chunk_id,
                            "role": chunk.event.payload.step_type,
                            "token": f"Tool:{t.tool_name} arguments:{t.arguments}",
                        },
                    }
                )

            for r in chunk.event.payload.step_details.tool_responses:
                if r.tool_name == "query_from_memory":
                    inserted_context = interleaved_content_as_str(r.content)
                    yield format_stream_data(
                        {
                            "event": "tool_call",
                            "data": {
                                "id": chunk_id,
                                "role": chunk.event.payload.step_type,
                                "token": f"Fetched {len(inserted_context)} bytes from memory",
                            },
                        }
                    )

                elif r.tool_name == "knowledge_search" and r.content:
                    summary = ""
                    for i, text_content_item in enumerate(r.content):
                        if isinstance(text_content_item, TextContentItem):
                            if i == 0:
                                # First line of the first item is the summary.
                                # split() (not find()) so a single-line text is
                                # kept whole: find("\n") returns -1 when absent
                                # and summary[:-1] would drop the last char.
                                summary = text_content_item.text.split("\n", 1)[0]
                            for match in METADATA_PATTERN.findall(
                                text_content_item.text
                            ):
                                # Metadata arrives single-quoted; normalize to
                                # double quotes so json.loads accepts it.
                                meta = json.loads(match.replace("'", '"'))
                                metadata_map[meta["document_id"]] = meta
                    yield format_stream_data(
                        {
                            "event": "tool_call",
                            "data": {
                                "id": chunk_id,
                                "role": chunk.event.payload.step_type,
                                "token": f"Tool:{r.tool_name} summary:{summary}",
                            },
                        }
                    )

                else:
                    yield format_stream_data(
                        {
                            "event": "tool_call",
                            "data": {
                                "id": chunk_id,
                                "role": chunk.event.payload.step_type,
                                "token": f"Tool:{r.tool_name} response:{r.content}",
                            },
                        }
                    )

        return

    # -----------------------------------
    # Catch-all for everything else
    # -----------------------------------
    yield format_stream_data(
        {
            "event": "heartbeat",
            "data": {
                "id": chunk_id,
                "token": "heartbeat",
            },
        }
    )
186366
187367
188368@router .post ("/streaming_query" )
@@ -222,7 +402,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
222402 yield stream_start_event (conversation_id )
223403
224404 async for chunk in turn_response :
225- if event := stream_build_event (chunk , chunk_id , metadata_map ):
405+ for event in stream_build_event (chunk , chunk_id , metadata_map ):
226406 complete_response += json .loads (event .replace ("data: " , "" ))[
227407 "data"
228408 ]["token" ]
0 commit comments