44import logging
55import re
66from json import JSONDecodeError
7- from typing import Any , AsyncIterator
7+ from typing import Any , AsyncIterator , Iterator
88
99from cachetools import TTLCache # type: ignore
1010
1111from llama_stack_client import APIConnectionError
1212from llama_stack_client .lib .agents .agent import AsyncAgent # type: ignore
1313from llama_stack_client import AsyncLlamaStackClient # type: ignore
14- from llama_stack_client .types .shared .interleaved_content_item import TextContentItem
1514from llama_stack_client .types import UserMessage # type: ignore
1615
16+ from llama_stack_client .lib .agents .event_logger import interleaved_content_as_str
17+ from llama_stack_client .types .shared import ToolCall
18+ from llama_stack_client .types .shared .interleaved_content_item import TextContentItem
19+
1720from fastapi import APIRouter , HTTPException , Request , Depends , status
1821from fastapi .responses import StreamingResponse
1922
# Per-conversation agent cache: entries evicted after 1 hour (ttl=3600s)
# or once 1000 agents are held, whichever comes first.
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)
4750
4851
49- async def get_agent ( # pylint: disable=too-many-arguments,too-many-positional-arguments
52+ # # pylint: disable=R0913,R0917
53+ async def get_agent (
5054 client : AsyncLlamaStackClient ,
5155 model_id : str ,
5256 system_prompt : str ,
@@ -127,7 +131,7 @@ def stream_end_event(metadata_map: dict) -> str:
127131 )
128132
129133
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterator[str]:
    """Build streaming events from a chunk response.

    Processes one chunk from the Llama Stack streaming response and
    dispatches it to the matching ``_handle_*`` formatter, yielding zero
    or more Server-Sent Events (SSE) data strings.

    Args:
        chunk: The streaming chunk, carrying either an ``error`` attribute
            or an ``event`` with a payload.
        chunk_id: The current chunk ID counter.
        metadata_map: Mutable mapping populated with document metadata
            scraped from tool-execution results.

    Returns:
        Iterator[str]: An iterable list of formatted SSE data strings with event information
    """
    if hasattr(chunk, "error"):
        yield from _handle_error_event(chunk, chunk_id)
        return

    # Defensive guard: a chunk without an event payload would otherwise
    # raise AttributeError below (the pre-refactor code checked
    # hasattr(chunk.event, "payload")); treat such chunks as heartbeats.
    if not hasattr(chunk, "event") or not hasattr(chunk.event, "payload"):
        yield from _handle_heartbeat_event(chunk_id)
        return

    event_type = chunk.event.payload.event_type
    # step_type is only present on step_* events; None routes to heartbeat.
    step_type = getattr(chunk.event.payload, "step_type", None)

    if event_type in {"turn_start", "turn_awaiting_input"}:
        yield from _handle_turn_start_event(chunk_id)
    elif event_type == "turn_complete":
        yield from _handle_turn_complete_event(chunk, chunk_id)
    elif step_type == "shield_call":
        yield from _handle_shield_event(chunk, chunk_id)
    elif step_type == "inference":
        yield from _handle_inference_event(chunk, chunk_id)
    elif step_type == "tool_execution":
        yield from _handle_tool_execution_event(chunk, chunk_id, metadata_map)
    else:
        yield from _handle_heartbeat_event(chunk_id)
170+
171+
172+ # -----------------------------------
173+ # Error handling
174+ # -----------------------------------
def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Yield an SSE ``error`` event carrying the chunk's error message."""
    error_payload = {
        "event": "error",
        "data": {"id": chunk_id, "token": chunk.error["message"]},
    }
    yield format_stream_data(error_payload)
185+
186+
187+ # -----------------------------------
188+ # Turn handling
189+ # -----------------------------------
def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
    """Yield an empty ``token`` event signalling that a turn has started."""
    yield format_stream_data(
        {"event": "token", "data": {"id": chunk_id, "token": ""}}
    )
200+
201+
def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Yield a ``turn_complete`` event carrying the turn's final message."""
    final_message = chunk.event.payload.turn.output_message.content
    yield format_stream_data(
        {
            "event": "turn_complete",
            "data": {"id": chunk_id, "token": final_message},
        }
    )
212+
213+
214+ # -----------------------------------
215+ # Shield handling
216+ # -----------------------------------
def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Yield a ``token`` event describing the shield-call outcome.

    Only ``step_complete`` payloads produce output; the token is either
    "No Violation" or a description of the reported violation.
    """
    payload = chunk.event.payload
    if payload.event_type != "step_complete":
        return

    violation = payload.step_details.violation
    if violation:
        token = f"Violation: {violation.user_message} (Metadata: {violation.metadata})"
    else:
        token = "No Violation"

    yield format_stream_data(
        {
            "event": "token",
            "data": {
                "id": chunk_id,
                "role": payload.step_type,
                "token": token,
            },
        }
    )
245+
246+
247+ # -----------------------------------
248+ # Inference handling
249+ # -----------------------------------
def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
    """Yield SSE events for inference steps: start markers, streamed text,
    and tool-call deltas (either raw strings or ToolCall objects)."""
    payload = chunk.event.payload

    def build(event_name: str, token: str) -> str:
        # All inference events share the same envelope shape.
        return format_stream_data(
            {
                "event": event_name,
                "data": {
                    "id": chunk_id,
                    "role": payload.step_type,
                    "token": token,
                },
            }
        )

    if payload.event_type == "step_start":
        yield build("token", "")
    elif payload.event_type == "step_progress":
        delta = payload.delta
        if delta.type == "tool_call":
            if isinstance(delta.tool_call, str):
                yield build("tool_call", delta.tool_call)
            elif isinstance(delta.tool_call, ToolCall):
                yield build("tool_call", delta.tool_call.tool_name)
        elif delta.type == "text":
            yield build("token", delta.text)
299+
300+
301+ # -----------------------------------
302+ # Tool Execution handling
303+ # -----------------------------------
def _handle_tool_execution_event(
    chunk: Any, chunk_id: int, metadata_map: dict
) -> Iterator[str]:
    """Yield ``tool_call`` SSE events for tool-execution steps.

    On ``step_complete`` this also scrapes document metadata out of
    knowledge_search responses into ``metadata_map``.
    """
    payload = chunk.event.payload

    def build(token: str) -> str:
        # Every tool-execution event uses the same envelope.
        return format_stream_data(
            {
                "event": "tool_call",
                "data": {
                    "id": chunk_id,
                    "role": payload.step_type,
                    "token": token,
                },
            }
        )

    def scrape_knowledge_search(content: Any) -> str:
        # Collect document metadata from every text item and return a
        # one-line summary taken from the first item.
        first_line = ""
        for index, item in enumerate(content):
            if not isinstance(item, TextContentItem):
                continue
            if index == 0:
                first_line = item.text
                cut = first_line.find("\n")
                if cut > 0:
                    first_line = first_line[:cut]
            for raw in METADATA_PATTERN.findall(item.text):
                try:
                    doc_meta = json.loads(raw.replace("'", '"'))
                    if "document_id" in doc_meta:
                        metadata_map[doc_meta["document_id"]] = doc_meta
                except JSONDecodeError:
                    logger.debug(
                        "JSONDecodeError was thrown in processing %s",
                        raw,
                    )
        return first_line

    if payload.event_type == "step_start":
        yield build("")
        return
    if payload.event_type != "step_complete":
        return

    details = payload.step_details
    for call in details.tool_calls:
        yield build(f"Tool:{call.tool_name} arguments:{call.arguments}")

    for response in details.tool_responses:
        if response.tool_name == "query_from_memory":
            inserted_context = interleaved_content_as_str(response.content)
            yield build(f"Fetched {len(inserted_context)} bytes from memory")
        elif response.tool_name == "knowledge_search" and response.content:
            summary = scrape_knowledge_search(response.content)
            yield build(f"Tool:{response.tool_name} summary:{summary}")
        else:
            yield build(f"Tool:{response.tool_name} response:{response.content}")
389+
390+
391+ # -----------------------------------
392+ # Catch-all for everything else
393+ # -----------------------------------
def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
    """Yield a ``heartbeat`` event for chunks matched by no other handler."""
    yield format_stream_data(
        {"event": "heartbeat", "data": {"id": chunk_id, "token": "heartbeat"}}
    )
197404
198405
199406@router .post ("/streaming_query" )
@@ -233,7 +440,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
233440 yield stream_start_event (conversation_id )
234441
235442 async for chunk in turn_response :
236- if event := stream_build_event (chunk , chunk_id , metadata_map ):
443+ for event in stream_build_event (chunk , chunk_id , metadata_map ):
237444 complete_response += json .loads (event .replace ("data: " , "" ))[
238445 "data"
239446 ]["token" ]
0 commit comments