 from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
 from llama_stack_client.types.agents.turn_create_params import Document
 
-from app.database import get_session
 from app.endpoints.query import (
     get_rag_toolgroups,
     is_input_shield,
 from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT
 import metrics
 from metrics.utils import update_llm_token_count_from_turn
-from models.cache_entry import CacheEntry
 from models.config import Action
+from models.context import ResponseGeneratorContext
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
 from models.responses import ForbiddenResponse, UnauthorizedResponse
 from utils.endpoints import (
     check_configuration_loaded,
-    create_referenced_documents_with_metadata,
+    cleanup_after_streaming,
     create_rag_chunks_dict,
     get_agent,
     get_system_prompt,
-    store_conversation_into_cache,
     validate_model_provider_override,
 )
 from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
@@ -696,17 +694,8 @@ def _handle_heartbeat_event(
     )
 
 
-def create_agent_response_generator(  # pylint: disable=too-many-arguments,too-many-locals
-    conversation_id: str,
-    user_id: str,
-    model_id: str,
-    provider_id: str,
-    query_request: QueryRequest,
-    metadata_map: dict[str, dict[str, Any]],
-    client: AsyncLlamaStackClient,
-    llama_stack_model_id: str,
-    started_at: str,
-    _skip_userid_check: bool,
+def create_agent_response_generator(  # pylint: disable=too-many-locals
+    context: ResponseGeneratorContext,
 ) -> Any:
     """
     Create a response generator function for Agent API streaming.
@@ -715,16 +704,7 @@ def create_agent_response_generator( # pylint: disable=too-many-arguments,too-m
     responses from the Agent API and yields Server-Sent Events (SSE).
 
     Args:
-        conversation_id: The conversation identifier
-        user_id: The user identifier
-        model_id: The model identifier
-        provider_id: The provider identifier
-        query_request: The query request object
-        metadata_map: Dictionary for storing metadata from tool responses
-        client: The Llama Stack client
-        llama_stack_model_id: The full llama stack model ID
-        started_at: Timestamp when the request started
-        _skip_userid_check: Whether to skip user ID validation
+        context: Context object containing all necessary parameters for response generation
 
     Returns:
         An async generator function that yields SSE-formatted strings
@@ -748,10 +728,10 @@ async def response_generator(
         summary = TurnSummary(llm_response="No response from the model", tool_calls=[])
 
         # Determine media type for response formatting
-        media_type = query_request.media_type or MEDIA_TYPE_JSON
+        media_type = context.query_request.media_type or MEDIA_TYPE_JSON
 
         # Send start event at the beginning of the stream
-        yield stream_start_event(conversation_id)
+        yield stream_start_event(context.conversation_id)
 
         latest_turn: Any | None = None
 
@@ -764,10 +744,10 @@ async def response_generator(
                     p.turn.output_message.content
                 )
                 latest_turn = p.turn
-                system_prompt = get_system_prompt(query_request, configuration)
+                system_prompt = get_system_prompt(context.query_request, configuration)
                 try:
                     update_llm_token_count_from_turn(
-                        p.turn, model_id, provider_id, system_prompt
+                        p.turn, context.model_id, context.provider_id, system_prompt
                     )
                 except Exception:  # pylint: disable=broad-except
                     logger.exception("Failed to update token usage metrics")
@@ -776,7 +756,11 @@ async def response_generator(
                 summary.append_tool_calls_from_llama(p.step_details)
 
             for event in stream_build_event(
-                chunk, chunk_id, metadata_map, media_type, conversation_id
+                chunk,
+                chunk_id,
+                context.metadata_map,
+                media_type,
+                context.conversation_id,
             ):
                 chunk_id += 1
                 yield event
@@ -788,76 +772,33 @@ async def response_generator(
             else TokenCounter()
         )
 
-        yield stream_end_event(metadata_map, summary, token_usage, media_type)
-
-        if not is_transcripts_enabled():
-            logger.debug("Transcript collection is disabled in the configuration")
-        else:
-            store_transcript(
-                user_id=user_id,
-                conversation_id=conversation_id,
-                model_id=model_id,
-                provider_id=provider_id,
-                query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
-                query=query_request.query,
-                query_request=query_request,
-                summary=summary,
-                rag_chunks=create_rag_chunks_dict(summary),
-                truncated=False,  # TODO(lucasagomes): implement truncation as part
-                # of quota work
-                attachments=query_request.attachments or [],
-            )
-
-        # Get the initial topic summary for the conversation
-        topic_summary = None
-        with get_session() as session:
-            existing_conversation = (
-                session.query(UserConversation).filter_by(id=conversation_id).first()
-            )
-            if not existing_conversation:
-                topic_summary = await get_topic_summary(
-                    query_request.query, client, llama_stack_model_id
-                )
-
-        completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
-
-        referenced_documents = create_referenced_documents_with_metadata(
-            summary, metadata_map
-        )
-
-        cache_entry = CacheEntry(
-            query=query_request.query,
-            response=summary.llm_response,
-            provider=provider_id,
-            model=model_id,
-            started_at=started_at,
-            completed_at=completed_at,
-            referenced_documents=(
-                referenced_documents if referenced_documents else None
-            ),
-        )
-
-        store_conversation_into_cache(
-            configuration,
-            user_id,
-            conversation_id,
-            cache_entry,
-            _skip_userid_check,
-            topic_summary,
-        )
-
-        persist_user_conversation_details(
-            user_id=user_id,
-            conversation_id=conversation_id,
-            model=model_id,
-            provider_id=provider_id,
-            topic_summary=topic_summary,
+        yield stream_end_event(context.metadata_map, summary, token_usage, media_type)
+
+        # Perform cleanup tasks (database and cache operations)
+        await cleanup_after_streaming(
+            user_id=context.user_id,
+            conversation_id=context.conversation_id,
+            model_id=context.model_id,
+            provider_id=context.provider_id,
+            llama_stack_model_id=context.llama_stack_model_id,
+            query_request=context.query_request,
+            summary=summary,
+            metadata_map=context.metadata_map,
+            started_at=context.started_at,
+            client=context.client,
+            config=configuration,
+            skip_userid_check=context.skip_userid_check,
+            get_topic_summary_func=get_topic_summary,
+            is_transcripts_enabled_func=is_transcripts_enabled,
+            store_transcript_func=store_transcript,
+            persist_user_conversation_details_func=persist_user_conversation_details,
+            rag_chunks=create_rag_chunks_dict(summary),
         )
 
     return response_generator
 
 
-async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-locals,too-many-statements,too-many-arguments
+async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-locals,too-many-statements,too-many-arguments,too-many-positional-arguments
     request: Request,
     query_request: QueryRequest,
     auth: AuthTuple,
@@ -866,7 +807,7 @@ async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-loc
     create_response_generator_func: Callable[..., Any],
 ) -> StreamingResponse:
     """
-    Base handler for streaming query endpoints.
+    Handle streaming query endpoints with common logic.
 
     This base handler contains all the common logic for streaming query endpoints
     and accepts functions for API-specific behavior (Agent API vs Responses API).
@@ -937,20 +878,23 @@ async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-loc
     )
     metadata_map: dict[str, dict[str, Any]] = {}
 
-    # Create the response generator using the provided factory function
-    response_generator = create_response_generator_func(
+    # Create context object for response generator
+    context = ResponseGeneratorContext(
         conversation_id=conversation_id,
         user_id=user_id,
+        skip_userid_check=_skip_userid_check,
         model_id=model_id,
         provider_id=provider_id,
-        query_request=query_request,
-        metadata_map=metadata_map,
-        client=client,
         llama_stack_model_id=llama_stack_model_id,
+        query_request=query_request,
         started_at=started_at,
-        _skip_userid_check=_skip_userid_check,
+        client=client,
+        metadata_map=metadata_map,
     )
 
+    # Create the response generator using the provided factory function
+    response_generator = create_response_generator_func(context)
+
     # Update metrics for the LLM call
     metrics.llm_calls_total.labels(provider_id, model_id).inc()
 
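
Note: the ResponseGeneratorContext container imported from models.context is not part of this diff. Based on the fields passed at the construction site above, a minimal sketch of its shape (field names taken from the diff, types and the dataclass form assumed) might look like this:

# Sketch only: the real definition lives in models/context.py and may differ
# (for example, it could be a Pydantic model rather than a dataclass).
from dataclasses import dataclass, field
from typing import Any

@dataclass
class ResponseGeneratorContext:
    conversation_id: str
    user_id: str
    skip_userid_check: bool
    model_id: str
    provider_id: str
    llama_stack_model_id: str
    query_request: Any          # QueryRequest in the real code
    started_at: str
    client: Any                 # AsyncLlamaStackClient in the real code
    metadata_map: dict[str, dict[str, Any]] = field(default_factory=dict)

Grouping these ten values into one object is what lets the factory drop to a single parameter and the too-many-arguments pylint suppression go away.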
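Similarly, cleanup_after_streaming is defined in utils/endpoints.py and not shown here. From the call site in the diff, its signature is presumably along these lines (keyword-only parameters mirror the call; types are assumed), with a body expected to cover the steps that were previously inline: storing the transcript, computing the topic summary for new conversations, writing the cache entry, and persisting the conversation details.

# Sketch only: signature inferred from the call site above; the actual helper
# in utils/endpoints.py may differ.
from typing import Any, Callable

async def cleanup_after_streaming(  # pylint: disable=too-many-arguments
    *,
    user_id: str,
    conversation_id: str,
    model_id: str,
    provider_id: str,
    llama_stack_model_id: str,
    query_request: Any,
    summary: Any,
    metadata_map: dict[str, dict[str, Any]],
    started_at: str,
    client: Any,
    config: Any,
    skip_userid_check: bool,
    get_topic_summary_func: Callable[..., Any],
    is_transcripts_enabled_func: Callable[[], bool],
    store_transcript_func: Callable[..., Any],
    persist_user_conversation_details_func: Callable[..., Any],
    rag_chunks: Any,
) -> None:
    # Expected responsibilities (previously inline in response_generator):
    # 1. store the transcript when transcript collection is enabled,
    # 2. compute a topic summary if the conversation is new,
    # 3. build and store the cache entry for the conversation,
    # 4. persist the user conversation details.
    ...

Passing get_topic_summary, is_transcripts_enabled, store_transcript, and persist_user_conversation_details in as callables keeps the helper reusable for both the Agent API and Responses API paths that the base handler serves.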