11"""Handler for REST API call to provide answer to query."""
22
3- from contextlib import suppress
43from datetime import datetime , UTC
54import json
65import logging
76import os
87from pathlib import Path
98from typing import Any
109
11- from llama_stack_client .lib .agents .agent import Agent
1210from llama_stack_client import APIConnectionError
13- from llama_stack_client import LlamaStackClient # type: ignore
11+ from llama_stack_client import AsyncLlamaStackClient # type: ignore
1412from llama_stack_client .types import UserMessage , Shield # type: ignore
1513from llama_stack_client .types .agents .turn_create_params import (
1614 ToolgroupAgentToolGroupWithArgs ,
2018
2119from fastapi import APIRouter , HTTPException , status , Depends
2220
23- from client import LlamaStackClientHolder
21+ from client import AsyncLlamaStackClientHolder
2422from configuration import configuration
2523import metrics
2624from models .responses import QueryResponse , UnauthorizedResponse , ForbiddenResponse
2725from models .requests import QueryRequest , Attachment
2826import constants
2927from auth import get_auth_dependency
3028from utils .common import retrieve_user_id
31- from utils .endpoints import check_configuration_loaded , get_system_prompt
29+ from utils .endpoints import check_configuration_loaded , get_agent , get_system_prompt
3230from utils .mcp_headers import mcp_headers_dependency , handle_mcp_headers_with_toolgroups
3331from utils .suid import get_suid
34- from utils .types import GraniteToolParser
3532
3633logger = logging .getLogger ("app.endpoints.handlers" )
3734router = APIRouter (tags = ["query" ])
@@ -68,50 +65,8 @@ def is_transcripts_enabled() -> bool:
6865 return configuration .user_data_collection_configuration .transcripts_enabled
6966
7067
71- def get_agent ( # pylint: disable=too-many-arguments,too-many-positional-arguments
72- client : LlamaStackClient ,
73- model_id : str ,
74- system_prompt : str ,
75- available_input_shields : list [str ],
76- available_output_shields : list [str ],
77- conversation_id : str | None ,
78- no_tools : bool = False ,
79- ) -> tuple [Agent , str , str ]:
80- """Get existing agent or create a new one with session persistence."""
81- existing_agent_id = None
82- if conversation_id :
83- with suppress (ValueError ):
84- existing_agent_id = client .agents .retrieve (
85- agent_id = conversation_id
86- ).agent_id
87-
88- logger .debug ("Creating new agent" )
89- # TODO(lucasagomes): move to ReActAgent
90- agent = Agent (
91- client ,
92- model = model_id ,
93- instructions = system_prompt ,
94- input_shields = available_input_shields if available_input_shields else [],
95- output_shields = available_output_shields if available_output_shields else [],
96- tool_parser = None if no_tools else GraniteToolParser .get_parser (model_id ),
97- enable_session_persistence = True ,
98- )
99- if existing_agent_id and conversation_id :
100- orphan_agent_id = agent .agent_id
101- agent .agent_id = conversation_id
102- client .agents .delete (agent_id = orphan_agent_id )
103- sessions_response = client .agents .session .list (agent_id = conversation_id )
104- logger .info ("session response: %s" , sessions_response )
105- session_id = str (sessions_response .data [0 ]["session_id" ])
106- else :
107- conversation_id = agent .agent_id
108- session_id = agent .create_session (get_suid ())
109-
110- return agent , conversation_id , session_id
111-
112-
11368@router .post ("/query" , responses = query_response )
114- def query_endpoint_handler (
69+ async def query_endpoint_handler (
11570 query_request : QueryRequest ,
11671 auth : Any = Depends (auth_dependency ),
11772 mcp_headers : dict [str , dict [str , str ]] = Depends (mcp_headers_dependency ),
@@ -126,11 +81,11 @@ def query_endpoint_handler(
12681
12782 try :
12883 # try to get Llama Stack client
129- client = LlamaStackClientHolder ().get_client ()
84+ client = AsyncLlamaStackClientHolder ().get_client ()
13085 model_id , provider_id = select_model_and_provider_id (
131- client .models .list (), query_request
86+ await client .models .list (), query_request
13287 )
133- response , conversation_id = retrieve_response (
88+ response , conversation_id = await retrieve_response (
13489 client ,
13590 model_id ,
13691 query_request ,
@@ -250,19 +205,21 @@ def is_input_shield(shield: Shield) -> bool:
250205 return _is_inout_shield (shield ) or not is_output_shield (shield )
251206
252207
253- def retrieve_response ( # pylint: disable=too-many-locals
254- client : LlamaStackClient ,
208+ async def retrieve_response ( # pylint: disable=too-many-locals
209+ client : AsyncLlamaStackClient ,
255210 model_id : str ,
256211 query_request : QueryRequest ,
257212 token : str ,
258213 mcp_headers : dict [str , dict [str , str ]] | None = None ,
259214) -> tuple [str , str ]:
260215 """Retrieve response from LLMs and agents."""
261216 available_input_shields = [
262- shield .identifier for shield in filter (is_input_shield , client .shields .list ())
217+ shield .identifier
218+ for shield in filter (is_input_shield , await client .shields .list ())
263219 ]
264220 available_output_shields = [
265- shield .identifier for shield in filter (is_output_shield , client .shields .list ())
221+ shield .identifier
222+ for shield in filter (is_output_shield , await client .shields .list ())
266223 ]
267224 if not available_input_shields and not available_output_shields :
268225 logger .info ("No available shields. Disabling safety" )
@@ -281,7 +238,7 @@ def retrieve_response( # pylint: disable=too-many-locals
281238 if query_request .attachments :
282239 validate_attachments_metadata (query_request .attachments )
283240
284- agent , conversation_id , session_id = get_agent (
241+ agent , conversation_id , session_id = await get_agent (
285242 client ,
286243 model_id ,
287244 system_prompt ,
@@ -291,6 +248,7 @@ def retrieve_response( # pylint: disable=too-many-locals
291248 query_request .no_tools or False ,
292249 )
293250
251+ logger .debug ("Conversation ID: %s, session ID: %s" , conversation_id , session_id )
294252 # bypass tools and MCP servers if no_tools is True
295253 if query_request .no_tools :
296254 mcp_headers = {}
@@ -315,15 +273,17 @@ def retrieve_response( # pylint: disable=too-many-locals
315273 ),
316274 }
317275
318- vector_db_ids = [vector_db .identifier for vector_db in client .vector_dbs .list ()]
276+ vector_db_ids = [
277+ vector_db .identifier for vector_db in await client .vector_dbs .list ()
278+ ]
319279 toolgroups = (get_rag_toolgroups (vector_db_ids ) or []) + [
320280 mcp_server .name for mcp_server in configuration .mcp_servers
321281 ]
322282 # Convert empty list to None for consistency with existing behavior
323283 if not toolgroups :
324284 toolgroups = None
325285
326- response = agent .create_turn (
286+ response = await agent .create_turn (
327287 messages = [UserMessage (role = "user" , content = query_request .query )],
328288 session_id = session_id ,
329289 documents = query_request .get_documents (),
0 commit comments