
Commit f3c324e

Fixes to address PR reviews
1 parent: 7e2c1b7

File tree

7 files changed: +316 −362 lines

7 files changed

+316
-362
lines changed

src/app/endpoints/streaming_query.py

Lines changed: 46 additions & 102 deletions
@@ -24,7 +24,6 @@
 from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
 from llama_stack_client.types.agents.turn_create_params import Document

-from app.database import get_session
 from app.endpoints.query import (
     get_rag_toolgroups,
     is_input_shield,
@@ -45,18 +44,17 @@
 from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT
 import metrics
 from metrics.utils import update_llm_token_count_from_turn
-from models.cache_entry import CacheEntry
 from models.config import Action
+from models.context import ResponseGeneratorContext
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
 from models.responses import ForbiddenResponse, UnauthorizedResponse
 from utils.endpoints import (
     check_configuration_loaded,
-    create_referenced_documents_with_metadata,
+    cleanup_after_streaming,
     create_rag_chunks_dict,
     get_agent,
     get_system_prompt,
-    store_conversation_into_cache,
     validate_model_provider_override,
 )
 from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
@@ -696,17 +694,8 @@ def _handle_heartbeat_event(
     )


-def create_agent_response_generator(  # pylint: disable=too-many-arguments,too-many-locals
-    conversation_id: str,
-    user_id: str,
-    model_id: str,
-    provider_id: str,
-    query_request: QueryRequest,
-    metadata_map: dict[str, dict[str, Any]],
-    client: AsyncLlamaStackClient,
-    llama_stack_model_id: str,
-    started_at: str,
-    _skip_userid_check: bool,
+def create_agent_response_generator(  # pylint: disable=too-many-locals
+    context: ResponseGeneratorContext,
 ) -> Any:
     """
     Create a response generator function for Agent API streaming.
@@ -715,16 +704,7 @@ def create_agent_response_generator(  # pylint: disable=too-m
     responses from the Agent API and yields Server-Sent Events (SSE).

     Args:
-        conversation_id: The conversation identifier
-        user_id: The user identifier
-        model_id: The model identifier
-        provider_id: The provider identifier
-        query_request: The query request object
-        metadata_map: Dictionary for storing metadata from tool responses
-        client: The Llama Stack client
-        llama_stack_model_id: The full llama stack model ID
-        started_at: Timestamp when the request started
-        _skip_userid_check: Whether to skip user ID validation
+        context: Context object containing all necessary parameters for response generation

     Returns:
         An async generator function that yields SSE-formatted strings
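
The replaced parameter list maps one-to-one onto the new context object. Its real definition lives in models/context.py, which is not part of this diff; the sketch below is an assumption reconstructed from the deleted signature and from the construction site in streaming_query_endpoint_handler_base further down:

    # Sketch only: field names and types inferred from this diff; the real
    # ResponseGeneratorContext is defined in models/context.py.
    from dataclasses import dataclass
    from typing import Any

    from llama_stack_client import AsyncLlamaStackClient

    from models.requests import QueryRequest


    @dataclass
    class ResponseGeneratorContext:
        """Per-request state handed to a streaming response generator."""

        conversation_id: str
        user_id: str
        skip_userid_check: bool  # replaces the old _skip_userid_check parameter
        model_id: str
        provider_id: str
        llama_stack_model_id: str
        query_request: QueryRequest
        started_at: str
        client: AsyncLlamaStackClient
        metadata_map: dict[str, dict[str, Any]]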
@@ -748,10 +728,10 @@ async def response_generator(
         summary = TurnSummary(llm_response="No response from the model", tool_calls=[])

         # Determine media type for response formatting
-        media_type = query_request.media_type or MEDIA_TYPE_JSON
+        media_type = context.query_request.media_type or MEDIA_TYPE_JSON

         # Send start event at the beginning of the stream
-        yield stream_start_event(conversation_id)
+        yield stream_start_event(context.conversation_id)

         latest_turn: Any | None = None

@@ -764,10 +744,10 @@ async def response_generator(
                     p.turn.output_message.content
                 )
                 latest_turn = p.turn
-                system_prompt = get_system_prompt(query_request, configuration)
+                system_prompt = get_system_prompt(context.query_request, configuration)
                 try:
                     update_llm_token_count_from_turn(
-                        p.turn, model_id, provider_id, system_prompt
+                        p.turn, context.model_id, context.provider_id, system_prompt
                     )
                 except Exception:  # pylint: disable=broad-except
                     logger.exception("Failed to update token usage metrics")
@@ -776,7 +756,11 @@ async def response_generator(
                 summary.append_tool_calls_from_llama(p.step_details)

             for event in stream_build_event(
-                chunk, chunk_id, metadata_map, media_type, conversation_id
+                chunk,
+                chunk_id,
+                context.metadata_map,
+                media_type,
+                context.conversation_id,
             ):
                 chunk_id += 1
                 yield event
@@ -788,76 +772,33 @@ async def response_generator(
             else TokenCounter()
         )

-        yield stream_end_event(metadata_map, summary, token_usage, media_type)
-
-        if not is_transcripts_enabled():
-            logger.debug("Transcript collection is disabled in the configuration")
-        else:
-            store_transcript(
-                user_id=user_id,
-                conversation_id=conversation_id,
-                model_id=model_id,
-                provider_id=provider_id,
-                query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
-                query=query_request.query,
-                query_request=query_request,
-                summary=summary,
-                rag_chunks=create_rag_chunks_dict(summary),
-                truncated=False,  # TODO(lucasagomes): implement truncation as part
-                # of quota work
-                attachments=query_request.attachments or [],
-            )
-
-        # Get the initial topic summary for the conversation
-        topic_summary = None
-        with get_session() as session:
-            existing_conversation = (
-                session.query(UserConversation).filter_by(id=conversation_id).first()
-            )
-            if not existing_conversation:
-                topic_summary = await get_topic_summary(
-                    query_request.query, client, llama_stack_model_id
-                )
-
-        completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
-
-        referenced_documents = create_referenced_documents_with_metadata(
-            summary, metadata_map
-        )
-
-        cache_entry = CacheEntry(
-            query=query_request.query,
-            response=summary.llm_response,
-            provider=provider_id,
-            model=model_id,
-            started_at=started_at,
-            completed_at=completed_at,
-            referenced_documents=(
-                referenced_documents if referenced_documents else None
-            ),
-        )
-
-        store_conversation_into_cache(
-            configuration,
-            user_id,
-            conversation_id,
-            cache_entry,
-            _skip_userid_check,
-            topic_summary,
-        )
-
-        persist_user_conversation_details(
-            user_id=user_id,
-            conversation_id=conversation_id,
-            model=model_id,
-            provider_id=provider_id,
-            topic_summary=topic_summary,
+        yield stream_end_event(context.metadata_map, summary, token_usage, media_type)
+
+        # Perform cleanup tasks (database and cache operations)
+        await cleanup_after_streaming(
+            user_id=context.user_id,
+            conversation_id=context.conversation_id,
+            model_id=context.model_id,
+            provider_id=context.provider_id,
+            llama_stack_model_id=context.llama_stack_model_id,
+            query_request=context.query_request,
+            summary=summary,
+            metadata_map=context.metadata_map,
+            started_at=context.started_at,
+            client=context.client,
+            config=configuration,
+            skip_userid_check=context.skip_userid_check,
+            get_topic_summary_func=get_topic_summary,
+            is_transcripts_enabled_func=is_transcripts_enabled,
+            store_transcript_func=store_transcript,
+            persist_user_conversation_details_func=persist_user_conversation_details,
+            rag_chunks=create_rag_chunks_dict(summary),
         )

     return response_generator


-async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-locals,too-many-statements,too-many-arguments
+async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-locals,too-many-statements,too-many-arguments,too-many-positional-arguments
     request: Request,
     query_request: QueryRequest,
     auth: AuthTuple,
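
The whole deleted block above (transcript storage, the topic-summary lookup, CacheEntry construction, cache storage, and conversation persistence) now sits behind a single awaited helper, cleanup_after_streaming, imported from utils.endpoints. Its implementation is not shown in this diff; the following is a hedged sketch of its signature, reconstructed from the keyword arguments at the call site, with types inferred (TurnSummary, QueryRequest, and AsyncLlamaStackClient are the module's own types):

    from typing import Any, Callable

    # Signature sketch only: parameter names come from the call site; the
    # real helper lives in utils/endpoints.py and is not part of this diff.
    async def cleanup_after_streaming(  # pylint: disable=too-many-arguments
        user_id: str,
        conversation_id: str,
        model_id: str,
        provider_id: str,
        llama_stack_model_id: str,
        query_request: QueryRequest,
        summary: TurnSummary,
        metadata_map: dict[str, dict[str, Any]],
        started_at: str,
        client: AsyncLlamaStackClient,
        config: Any,
        skip_userid_check: bool,
        get_topic_summary_func: Callable[..., Any],
        is_transcripts_enabled_func: Callable[[], bool],
        store_transcript_func: Callable[..., Any],
        persist_user_conversation_details_func: Callable[..., Any],
        rag_chunks: Any,  # result of create_rag_chunks_dict(summary)
    ) -> None:
        """Run the post-stream transcript, cache, and database bookkeeping."""

Passing get_topic_summary, is_transcripts_enabled, store_transcript, and persist_user_conversation_details in as function arguments presumably keeps utils.endpoints from importing back into this endpoint module and creating a cycle.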
@@ -866,7 +807,7 @@ async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-loc
     create_response_generator_func: Callable[..., Any],
 ) -> StreamingResponse:
     """
-    Base handler for streaming query endpoints.
+    Handle streaming query endpoints with common logic.

     This base handler contains all the common logic for streaming query endpoints
     and accepts functions for API-specific behavior (Agent API vs Responses API).
@@ -937,20 +878,23 @@ async def streaming_query_endpoint_handler_base(  # pylint: disable=too-many-loc
     )
     metadata_map: dict[str, dict[str, Any]] = {}

-    # Create the response generator using the provided factory function
-    response_generator = create_response_generator_func(
+    # Create context object for response generator
+    context = ResponseGeneratorContext(
         conversation_id=conversation_id,
         user_id=user_id,
+        skip_userid_check=_skip_userid_check,
         model_id=model_id,
         provider_id=provider_id,
-        query_request=query_request,
-        metadata_map=metadata_map,
-        client=client,
         llama_stack_model_id=llama_stack_model_id,
+        query_request=query_request,
         started_at=started_at,
-        _skip_userid_check=_skip_userid_check,
+        client=client,
+        metadata_map=metadata_map,
     )

+    # Create the response generator using the provided factory function
+    response_generator = create_response_generator_func(context)
+
     # Update metrics for the LLM call
     metrics.llm_calls_total.labels(provider_id, model_id).inc()

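With this refactor, both streaming back ends share one factory shape: the handler builds a single ResponseGeneratorContext and hands it to whichever factory was injected. Adding a new per-request field now means extending the context object rather than threading another positional parameter through every factory, which is what the retired too-many-arguments suppression on create_agent_response_generator had been covering for. Illustratively:

    # Agent API, as wired in this file:
    response_generator = create_response_generator_func(context)
    # where create_response_generator_func is create_agent_response_generator,
    # or a Responses API equivalent with the same single-argument signature
    # (the counterpart's name is hypothetical, not shown in this diff).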