11"""Handler for REST API call to provide answer to query."""
22
3- from contextlib import suppress
43from datetime import datetime , UTC
54import json
65import logging
76import os
87from pathlib import Path
98from typing import Annotated , Any
109
11- from llama_stack_client .lib .agents .agent import Agent
1210from llama_stack_client import APIConnectionError
13- from llama_stack_client import LlamaStackClient # type: ignore
11+ from llama_stack_client import AsyncLlamaStackClient # type: ignore
1412from llama_stack_client .types import UserMessage , Shield # type: ignore
1513from llama_stack_client .types .agents .turn_create_params import (
1614 ToolgroupAgentToolGroupWithArgs ,
2018
2119from fastapi import APIRouter , HTTPException , status , Depends
2220
23- from client import LlamaStackClientHolder
21+ from auth import get_auth_dependency
22+ from auth .interface import AuthTuple
23+ from client import AsyncLlamaStackClientHolder
2424from configuration import configuration
2525import metrics
2626from models .responses import QueryResponse , UnauthorizedResponse , ForbiddenResponse
2727from models .requests import QueryRequest , Attachment
2828import constants
29- from auth import get_auth_dependency
30- from auth .interface import AuthTuple
31- from utils .endpoints import check_configuration_loaded , get_system_prompt
29+ from utils .endpoints import check_configuration_loaded , get_agent , get_system_prompt
3230from utils .mcp_headers import mcp_headers_dependency , handle_mcp_headers_with_toolgroups
3331from utils .suid import get_suid
34- from utils .types import GraniteToolParser
3532
3633logger = logging .getLogger ("app.endpoints.handlers" )
3734router = APIRouter (tags = ["query" ])
@@ -68,53 +65,8 @@ def is_transcripts_enabled() -> bool:
6865 return configuration .user_data_collection_configuration .transcripts_enabled
6966
7067
71- def get_agent ( # pylint: disable=too-many-arguments,too-many-positional-arguments
72- client : LlamaStackClient ,
73- model_id : str ,
74- system_prompt : str ,
75- available_input_shields : list [str ],
76- available_output_shields : list [str ],
77- conversation_id : str | None ,
78- no_tools : bool = False ,
79- ) -> tuple [Agent , str , str ]:
80- """Get existing agent or create a new one with session persistence."""
81- existing_agent_id = None
82- if conversation_id :
83- with suppress (ValueError ):
84- existing_agent_id = client .agents .retrieve (
85- agent_id = conversation_id
86- ).agent_id
87-
88- logger .debug ("Creating new agent" )
89- # TODO(lucasagomes): move to ReActAgent
90- agent = Agent (
91- client ,
92- model = model_id ,
93- instructions = system_prompt ,
94- input_shields = available_input_shields if available_input_shields else [],
95- output_shields = available_output_shields if available_output_shields else [],
96- tool_parser = None if no_tools else GraniteToolParser .get_parser (model_id ),
97- enable_session_persistence = True ,
98- )
99-
100- agent .initialize ()
101-
102- if existing_agent_id and conversation_id :
103- orphan_agent_id = agent .agent_id
104- agent .agent_id = conversation_id
105- client .agents .delete (agent_id = orphan_agent_id )
106- sessions_response = client .agents .session .list (agent_id = conversation_id )
107- logger .info ("session response: %s" , sessions_response )
108- session_id = str (sessions_response .data [0 ]["session_id" ])
109- else :
110- conversation_id = agent .agent_id
111- session_id = agent .create_session (get_suid ())
112-
113- return agent , conversation_id , session_id
114-
115-
11668@router .post ("/query" , responses = query_response )
117- def query_endpoint_handler (
69+ async def query_endpoint_handler (
11870 query_request : QueryRequest ,
11971 auth : Annotated [AuthTuple , Depends (auth_dependency )],
12072 mcp_headers : dict [str , dict [str , str ]] = Depends (mcp_headers_dependency ),
@@ -129,11 +81,11 @@ def query_endpoint_handler(
12981
13082 try :
13183 # try to get Llama Stack client
132- client = LlamaStackClientHolder ().get_client ()
84+ client = AsyncLlamaStackClientHolder ().get_client ()
13385 model_id , provider_id = select_model_and_provider_id (
134- client .models .list (), query_request
86+ await client .models .list (), query_request
13587 )
136- response , conversation_id = retrieve_response (
88+ response , conversation_id = await retrieve_response (
13789 client ,
13890 model_id ,
13991 query_request ,
@@ -253,19 +205,21 @@ def is_input_shield(shield: Shield) -> bool:
253205 return _is_inout_shield (shield ) or not is_output_shield (shield )
254206
255207
256- def retrieve_response ( # pylint: disable=too-many-locals
257- client : LlamaStackClient ,
208+ async def retrieve_response ( # pylint: disable=too-many-locals
209+ client : AsyncLlamaStackClient ,
258210 model_id : str ,
259211 query_request : QueryRequest ,
260212 token : str ,
261213 mcp_headers : dict [str , dict [str , str ]] | None = None ,
262214) -> tuple [str , str ]:
263215 """Retrieve response from LLMs and agents."""
264216 available_input_shields = [
265- shield .identifier for shield in filter (is_input_shield , client .shields .list ())
217+ shield .identifier
218+ for shield in filter (is_input_shield , await client .shields .list ())
266219 ]
267220 available_output_shields = [
268- shield .identifier for shield in filter (is_output_shield , client .shields .list ())
221+ shield .identifier
222+ for shield in filter (is_output_shield , await client .shields .list ())
269223 ]
270224 if not available_input_shields and not available_output_shields :
271225 logger .info ("No available shields. Disabling safety" )
@@ -284,7 +238,7 @@ def retrieve_response( # pylint: disable=too-many-locals
284238 if query_request .attachments :
285239 validate_attachments_metadata (query_request .attachments )
286240
287- agent , conversation_id , session_id = get_agent (
241+ agent , conversation_id , session_id = await get_agent (
288242 client ,
289243 model_id ,
290244 system_prompt ,
@@ -294,6 +248,7 @@ def retrieve_response( # pylint: disable=too-many-locals
294248 query_request .no_tools or False ,
295249 )
296250
251+ logger .debug ("Conversation ID: %s, session ID: %s" , conversation_id , session_id )
297252 # bypass tools and MCP servers if no_tools is True
298253 if query_request .no_tools :
299254 mcp_headers = {}
@@ -318,15 +273,17 @@ def retrieve_response( # pylint: disable=too-many-locals
318273 ),
319274 }
320275
321- vector_db_ids = [vector_db .identifier for vector_db in client .vector_dbs .list ()]
276+ vector_db_ids = [
277+ vector_db .identifier for vector_db in await client .vector_dbs .list ()
278+ ]
322279 toolgroups = (get_rag_toolgroups (vector_db_ids ) or []) + [
323280 mcp_server .name for mcp_server in configuration .mcp_servers
324281 ]
325282 # Convert empty list to None for consistency with existing behavior
326283 if not toolgroups :
327284 toolgroups = None
328285
329- response = agent .create_turn (
286+ response = await agent .create_turn (
330287 messages = [UserMessage (role = "user" , content = query_request .query )],
331288 session_id = session_id ,
332289 documents = query_request .get_documents (),
0 commit comments