 import json
 import logging
 import re
-from typing import Any, AsyncIterator
+from typing import Any, AsyncIterator, Iterator

 from cachetools import TTLCache  # type: ignore

 from llama_stack_client import APIConnectionError
 from llama_stack_client.lib.agents.agent import AsyncAgent  # type: ignore
 from llama_stack_client import AsyncLlamaStackClient  # type: ignore
-from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
 from llama_stack_client.types import UserMessage  # type: ignore

+from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
+from llama_stack_client.types.shared import ToolCall
+from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
+
 from fastapi import APIRouter, HTTPException, Request, Depends, status
 from fastapi.responses import StreamingResponse
@@ -122,7 +125,7 @@ def stream_end_event(metadata_map: dict) -> str:
     )


-def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
+def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
     """Build a streaming event from a chunk response.

     This function processes chunks from the Llama Stack streaming response and formats
@@ -137,52 +140,258 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
         chunk_id: The current chunk ID counter (gets incremented for each token)

     Returns:
-        str | None: A formatted SSE data string with event information, or None if
-        the chunk doesn't contain processable event data
+        Iterator[str]: An iterator of formatted SSE data strings, one per event
     """
-    # pylint: disable=R1702
-    if hasattr(chunk.event, "payload"):
-        if chunk.event.payload.event_type == "step_progress":
-            if hasattr(chunk.event.payload.delta, "text"):
-                text = chunk.event.payload.delta.text
-                return format_stream_data(
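+    # Error chunks carry no event payload, so handle them before dispatching.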
+    if hasattr(chunk, "error"):
+        yield from _handle_error_event(chunk, chunk_id)
+        return
+
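+    # Dispatch on event_type for turn lifecycle events, then on step_type for
+    # the per-step events emitted inside a turn.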
+    event_type = chunk.event.payload.event_type
+    step_type = getattr(chunk.event.payload, "step_type", None)
+
+    if event_type in {"turn_start", "turn_awaiting_input"}:
+        yield from _handle_turn_start_event(chunk_id)
+    elif event_type == "turn_complete":
+        yield from _handle_turn_complete_event(chunk, chunk_id)
+    elif step_type == "shield_call":
+        yield from _handle_shield_event(chunk, chunk_id)
+    elif step_type == "inference":
+        yield from _handle_inference_event(chunk, chunk_id)
+    elif step_type == "tool_execution":
+        yield from _handle_tool_execution_event(chunk, chunk_id, metadata_map)
+    else:
+        yield from _handle_heartbeat_event(chunk_id)
+
+
+# -----------------------------------
+# Error handling
+# -----------------------------------
+def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
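+    # Forward the upstream error message to the client as an SSE "error" event.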
+    yield format_stream_data(
+        {
+            "event": "error",
+            "data": {
+                "id": chunk_id,
+                "token": chunk.error["message"],
+            },
+        }
+    )
+
+
+# -----------------------------------
+# Turn handling
+# -----------------------------------
+def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
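+    # Emit an empty token so clients see the stream open as soon as the turn starts.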
+    yield format_stream_data(
+        {
+            "event": "token",
+            "data": {
+                "id": chunk_id,
+                "token": "",
+            },
+        }
+    )
+
+
+def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
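+    # The turn is done; emit the complete final message as a single event.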
+    yield format_stream_data(
+        {
+            "event": "turn_complete",
+            "data": {
+                "id": chunk_id,
+                "token": chunk.event.payload.turn.output_message.content,
+            },
+        }
+    )
+
+
+# -----------------------------------
+# Shield handling
+# -----------------------------------
+def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
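+    # Shield steps run safety checks; only step_complete carries the verdict.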
+    if chunk.event.payload.event_type == "step_complete":
+        violation = chunk.event.payload.step_details.violation
+        if not violation:
+            yield format_stream_data(
+                {
+                    "event": "token",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": "No Violation",
+                    },
+                }
+            )
+        else:
+            violation = (
+                f"Violation: {violation.user_message} (Metadata: {violation.metadata})"
+            )
+            yield format_stream_data(
+                {
+                    "event": "token",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": violation,
+                    },
+                }
+            )
+
+
+# -----------------------------------
+# Inference handling
+# -----------------------------------
+def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
+    if chunk.event.payload.event_type == "step_start":
+        yield format_stream_data(
+            {
+                "event": "token",
+                "data": {
+                    "id": chunk_id,
+                    "role": chunk.event.payload.step_type,
+                    "token": "",
+                },
+            }
+        )
+
+    elif chunk.event.payload.event_type == "step_progress":
+        if chunk.event.payload.delta.type == "tool_call":
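+            # The tool_call delta arrives either as a raw string fragment of the
+            # arguments or as a parsed ToolCall object.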
+            if isinstance(chunk.event.payload.delta.tool_call, str):
+                yield format_stream_data(
                     {
-                        "event": "token",
+                        "event": "tool_call",
                         "data": {
                             "id": chunk_id,
                             "role": chunk.event.payload.step_type,
-                            "token": text,
+                            "token": chunk.event.payload.delta.tool_call,
                         },
                     }
                 )
-        if (
-            chunk.event.payload.event_type == "step_complete"
-            and chunk.event.payload.step_details.step_type == "tool_execution"
-        ):
-            for r in chunk.event.payload.step_details.tool_responses:
-                if r.tool_name == "knowledge_search" and r.content:
-                    for text_content_item in r.content:
-                        if isinstance(text_content_item, TextContentItem):
-                            for match in METADATA_PATTERN.findall(
-                                text_content_item.text
-                            ):
+            elif isinstance(chunk.event.payload.delta.tool_call, ToolCall):
+                yield format_stream_data(
+                    {
+                        "event": "tool_call",
+                        "data": {
+                            "id": chunk_id,
+                            "role": chunk.event.payload.step_type,
+                            "token": chunk.event.payload.delta.tool_call.tool_name,
+                        },
+                    }
+                )
+
+        elif chunk.event.payload.delta.type == "text":
+            yield format_stream_data(
+                {
+                    "event": "token",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": chunk.event.payload.delta.text,
+                    },
+                }
+            )
+
+
+# -----------------------------------
+# Tool Execution handling
+# -----------------------------------
+# pylint: disable=R1702,R0912
+def _handle_tool_execution_event(
+    chunk: Any, chunk_id: int, metadata_map: dict
+) -> Iterator[str]:
+    if chunk.event.payload.event_type == "step_start":
+        yield format_stream_data(
+            {
+                "event": "tool_call",
+                "data": {
+                    "id": chunk_id,
+                    "role": chunk.event.payload.step_type,
+                    "token": "",
+                },
+            }
+        )
+
+    elif chunk.event.payload.event_type == "step_complete":
+        for t in chunk.event.payload.step_details.tool_calls:
+            yield format_stream_data(
+                {
+                    "event": "tool_call",
+                    "data": {
+                        "id": chunk_id,
+                        "role": chunk.event.payload.step_type,
+                        "token": f"Tool:{t.tool_name} arguments:{t.arguments}",
+                    },
+                }
+            )
+
+        for r in chunk.event.payload.step_details.tool_responses:
+            if r.tool_name == "query_from_memory":
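+                # Report how much context was fetched rather than echoing it all.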
+                inserted_context = interleaved_content_as_str(r.content)
+                yield format_stream_data(
+                    {
+                        "event": "tool_call",
+                        "data": {
+                            "id": chunk_id,
+                            "role": chunk.event.payload.step_type,
+                            "token": f"Fetched {len(inserted_context)} bytes from memory",
+                        },
+                    }
+                )
+
+            elif r.tool_name == "knowledge_search" and r.content:
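+                # Use the first line of the first text item as a short summary
+                # of the search results.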
+                summary = ""
+                for i, text_content_item in enumerate(r.content):
+                    if isinstance(text_content_item, TextContentItem):
+                        if i == 0:
+                            summary = text_content_item.text
+                            newline_pos = summary.find("\n")
+                            if newline_pos > 0:
+                                summary = summary[:newline_pos]
+                        for match in METADATA_PATTERN.findall(text_content_item.text):
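+                            # Metadata arrives as Python-repr dicts; swap the quotes
+                            # so json.loads can parse them.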
+                            try:
                                 meta = json.loads(match.replace("'", '"'))
-                                metadata_map[meta["document_id"]] = meta
-            if chunk.event.payload.step_details.tool_calls:
-                tool_name = str(
-                    chunk.event.payload.step_details.tool_calls[0].tool_name
+                                if "document_id" in meta:
+                                    metadata_map[meta["document_id"]] = meta
+                            except (json.JSONDecodeError, KeyError) as e:
+                                logger.warning("Failed to parse metadata: %s", e)
+
+                yield format_stream_data(
+                    {
+                        "event": "tool_call",
+                        "data": {
+                            "id": chunk_id,
+                            "role": chunk.event.payload.step_type,
+                            "token": f"Tool:{r.tool_name} summary:{summary}",
+                        },
+                    }
                 )
-                return format_stream_data(
+
+            else:
+                yield format_stream_data(
                     {
-                        "event": "token",
+                        "event": "tool_call",
                         "data": {
                             "id": chunk_id,
                             "role": chunk.event.payload.step_type,
-                            "token": tool_name,
+                            "token": f"Tool:{r.tool_name} response:{r.content}",
                         },
                     }
                 )
-    return None
+
+
+# -----------------------------------
+# Catch-all for everything else
+# -----------------------------------
+def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
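+    # Unknown chunk types become heartbeats so the SSE connection stays alive.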
+    yield format_stream_data(
+        {
+            "event": "heartbeat",
+            "data": {
+                "id": chunk_id,
+                "token": "heartbeat",
+            },
+        }
+    )


 @router.post("/streaming_query")
@@ -222,7 +431,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
         yield stream_start_event(conversation_id)

         async for chunk in turn_response:
-            if event := stream_build_event(chunk, chunk_id, metadata_map):
+            for event in stream_build_event(chunk, chunk_id, metadata_map):
                 complete_response += json.loads(event.replace("data: ", ""))[
                     "data"
                 ]["token"]