Skip to content

Commit 739b786

Browse files
authored
Merge pull request #348 from eranco74/MGMT-21377
Convert lightspeed-core to async architecture
2 parents 7d77c74 + 83b4fac commit 739b786

22 files changed: +843 −1318 lines changed

scripts/generate_openapi_schema.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
# it is needed to read proper configuration in order to start the app to generate schema
1111
from configuration import configuration
1212

13-
from client import LlamaStackClientHolder
13+
from client import AsyncLlamaStackClientHolder
1414

1515
cfg_file = "lightspeed-stack.yaml"
1616
configuration.load_configuration(cfg_file)
1717

1818
# Llama Stack client needs to be loaded before REST API is fully initialized
19-
LlamaStackClientHolder().load(configuration.configuration.llama_stack)
19+
import asyncio # noqa: E402
20+
21+
asyncio.run(AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack))
2022

2123
from app.main import app # noqa: E402 pylint: disable=C0413
2224

src/app/endpoints/authorized.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
"""Handler for REST API call to authorized endpoint."""
22

3-
import asyncio
43
import logging
54
from typing import Any
65

7-
from fastapi import APIRouter, Request
6+
from fastapi import APIRouter, Depends
87

98
from auth import get_auth_dependency
109
from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse
@@ -31,8 +30,10 @@
3130

3231

3332
@router.post("/authorized", responses=authorized_responses)
34-
def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse:
33+
async def authorized_endpoint_handler(
34+
auth: Any = Depends(auth_dependency),
35+
) -> AuthorizedResponse:
3536
"""Handle request to the /authorized endpoint."""
3637
# Ignore the user token, we should not return it in the response
37-
user_id, user_name, _ = asyncio.run(auth_dependency(_request))
38+
user_id, user_name, _ = auth
3839
return AuthorizedResponse(user_id=user_id, username=user_name)

src/app/endpoints/conversations.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from fastapi import APIRouter, HTTPException, status, Depends
99

10-
from client import LlamaStackClientHolder
10+
from client import AsyncLlamaStackClientHolder
1111
from configuration import configuration
1212
from models.responses import ConversationResponse, ConversationDeleteResponse
1313
from auth import get_auth_dependency
@@ -110,7 +110,7 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
110110

111111

112112
@router.get("/conversations/{conversation_id}", responses=conversation_responses)
113-
def get_conversation_endpoint_handler(
113+
async def get_conversation_endpoint_handler(
114114
conversation_id: str,
115115
_auth: Any = Depends(auth_dependency),
116116
) -> ConversationResponse:
@@ -132,9 +132,9 @@ def get_conversation_endpoint_handler(
132132
logger.info("Retrieving conversation %s", conversation_id)
133133

134134
try:
135-
client = LlamaStackClientHolder().get_client()
135+
client = AsyncLlamaStackClientHolder().get_client()
136136

137-
session_data = client.agents.session.list(agent_id=agent_id).data[0]
137+
session_data = (await client.agents.session.list(agent_id=agent_id)).data[0]
138138

139139
logger.info("Successfully retrieved conversation %s", conversation_id)
140140

@@ -179,7 +179,7 @@ def get_conversation_endpoint_handler(
179179
@router.delete(
180180
"/conversations/{conversation_id}", responses=conversation_delete_responses
181181
)
182-
def delete_conversation_endpoint_handler(
182+
async def delete_conversation_endpoint_handler(
183183
conversation_id: str,
184184
_auth: Any = Depends(auth_dependency),
185185
) -> ConversationDeleteResponse:
@@ -201,10 +201,12 @@ def delete_conversation_endpoint_handler(
201201

202202
try:
203203
# Get Llama Stack client
204-
client = LlamaStackClientHolder().get_client()
204+
client = AsyncLlamaStackClientHolder().get_client()
205205
# Delete session using the conversation_id as session_id
206206
# In this implementation, conversation_id and session_id are the same
207-
client.agents.session.delete(agent_id=agent_id, session_id=conversation_id)
207+
await client.agents.session.delete(
208+
agent_id=agent_id, session_id=conversation_id
209+
)
208210

209211
logger.info("Successfully deleted conversation %s", conversation_id)
210212

src/app/endpoints/models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from llama_stack_client import APIConnectionError
77
from fastapi import APIRouter, HTTPException, Request, status
88

9-
from client import LlamaStackClientHolder
9+
from client import AsyncLlamaStackClientHolder
1010
from configuration import configuration
1111
from models.responses import ModelsResponse
1212
from utils.endpoints import check_configuration_loaded
@@ -43,7 +43,7 @@
4343

4444

4545
@router.get("/models", responses=models_responses)
46-
def models_endpoint_handler(_request: Request) -> ModelsResponse:
46+
async def models_endpoint_handler(_request: Request) -> ModelsResponse:
4747
"""Handle requests to the /models endpoint."""
4848
check_configuration_loaded(configuration)
4949

@@ -52,9 +52,9 @@ def models_endpoint_handler(_request: Request) -> ModelsResponse:
5252

5353
try:
5454
# try to get Llama Stack client
55-
client = LlamaStackClientHolder().get_client()
55+
client = AsyncLlamaStackClientHolder().get_client()
5656
# retrieve models
57-
models = client.models.list()
57+
models = await client.models.list()
5858
m = [dict(m) for m in models]
5959
return ModelsResponse(models=m)
6060

src/app/endpoints/query.py

Lines changed: 21 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
"""Handler for REST API call to provide answer to query."""
22

3-
from contextlib import suppress
43
from datetime import datetime, UTC
54
import json
65
import logging
76
import os
87
from pathlib import Path
98
from typing import Annotated, Any
109

11-
from llama_stack_client.lib.agents.agent import Agent
1210
from llama_stack_client import APIConnectionError
13-
from llama_stack_client import LlamaStackClient # type: ignore
11+
from llama_stack_client import AsyncLlamaStackClient # type: ignore
1412
from llama_stack_client.types import UserMessage, Shield # type: ignore
1513
from llama_stack_client.types.agents.turn_create_params import (
1614
ToolgroupAgentToolGroupWithArgs,
@@ -20,18 +18,17 @@
2018

2119
from fastapi import APIRouter, HTTPException, status, Depends
2220

23-
from client import LlamaStackClientHolder
21+
from auth import get_auth_dependency
22+
from auth.interface import AuthTuple
23+
from client import AsyncLlamaStackClientHolder
2424
from configuration import configuration
2525
import metrics
2626
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
2727
from models.requests import QueryRequest, Attachment
2828
import constants
29-
from auth import get_auth_dependency
30-
from auth.interface import AuthTuple
31-
from utils.endpoints import check_configuration_loaded, get_system_prompt
29+
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
3230
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3331
from utils.suid import get_suid
34-
from utils.types import GraniteToolParser
3532

3633
logger = logging.getLogger("app.endpoints.handlers")
3734
router = APIRouter(tags=["query"])
@@ -68,53 +65,8 @@ def is_transcripts_enabled() -> bool:
6865
return configuration.user_data_collection_configuration.transcripts_enabled
6966

7067

71-
def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
72-
client: LlamaStackClient,
73-
model_id: str,
74-
system_prompt: str,
75-
available_input_shields: list[str],
76-
available_output_shields: list[str],
77-
conversation_id: str | None,
78-
no_tools: bool = False,
79-
) -> tuple[Agent, str, str]:
80-
"""Get existing agent or create a new one with session persistence."""
81-
existing_agent_id = None
82-
if conversation_id:
83-
with suppress(ValueError):
84-
existing_agent_id = client.agents.retrieve(
85-
agent_id=conversation_id
86-
).agent_id
87-
88-
logger.debug("Creating new agent")
89-
# TODO(lucasagomes): move to ReActAgent
90-
agent = Agent(
91-
client,
92-
model=model_id,
93-
instructions=system_prompt,
94-
input_shields=available_input_shields if available_input_shields else [],
95-
output_shields=available_output_shields if available_output_shields else [],
96-
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
97-
enable_session_persistence=True,
98-
)
99-
100-
agent.initialize()
101-
102-
if existing_agent_id and conversation_id:
103-
orphan_agent_id = agent.agent_id
104-
agent.agent_id = conversation_id
105-
client.agents.delete(agent_id=orphan_agent_id)
106-
sessions_response = client.agents.session.list(agent_id=conversation_id)
107-
logger.info("session response: %s", sessions_response)
108-
session_id = str(sessions_response.data[0]["session_id"])
109-
else:
110-
conversation_id = agent.agent_id
111-
session_id = agent.create_session(get_suid())
112-
113-
return agent, conversation_id, session_id
114-
115-
11668
@router.post("/query", responses=query_response)
117-
def query_endpoint_handler(
69+
async def query_endpoint_handler(
11870
query_request: QueryRequest,
11971
auth: Annotated[AuthTuple, Depends(auth_dependency)],
12072
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
@@ -129,11 +81,11 @@ def query_endpoint_handler(
12981

13082
try:
13183
# try to get Llama Stack client
132-
client = LlamaStackClientHolder().get_client()
84+
client = AsyncLlamaStackClientHolder().get_client()
13385
model_id, provider_id = select_model_and_provider_id(
134-
client.models.list(), query_request
86+
await client.models.list(), query_request
13587
)
136-
response, conversation_id = retrieve_response(
88+
response, conversation_id = await retrieve_response(
13789
client,
13890
model_id,
13991
query_request,
@@ -253,19 +205,21 @@ def is_input_shield(shield: Shield) -> bool:
253205
return _is_inout_shield(shield) or not is_output_shield(shield)
254206

255207

256-
def retrieve_response( # pylint: disable=too-many-locals
257-
client: LlamaStackClient,
208+
async def retrieve_response( # pylint: disable=too-many-locals
209+
client: AsyncLlamaStackClient,
258210
model_id: str,
259211
query_request: QueryRequest,
260212
token: str,
261213
mcp_headers: dict[str, dict[str, str]] | None = None,
262214
) -> tuple[str, str]:
263215
"""Retrieve response from LLMs and agents."""
264216
available_input_shields = [
265-
shield.identifier for shield in filter(is_input_shield, client.shields.list())
217+
shield.identifier
218+
for shield in filter(is_input_shield, await client.shields.list())
266219
]
267220
available_output_shields = [
268-
shield.identifier for shield in filter(is_output_shield, client.shields.list())
221+
shield.identifier
222+
for shield in filter(is_output_shield, await client.shields.list())
269223
]
270224
if not available_input_shields and not available_output_shields:
271225
logger.info("No available shields. Disabling safety")
@@ -284,7 +238,7 @@ def retrieve_response( # pylint: disable=too-many-locals
284238
if query_request.attachments:
285239
validate_attachments_metadata(query_request.attachments)
286240

287-
agent, conversation_id, session_id = get_agent(
241+
agent, conversation_id, session_id = await get_agent(
288242
client,
289243
model_id,
290244
system_prompt,
@@ -294,6 +248,7 @@ def retrieve_response( # pylint: disable=too-many-locals
294248
query_request.no_tools or False,
295249
)
296250

251+
logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
297252
# bypass tools and MCP servers if no_tools is True
298253
if query_request.no_tools:
299254
mcp_headers = {}
@@ -318,15 +273,17 @@ def retrieve_response( # pylint: disable=too-many-locals
318273
),
319274
}
320275

321-
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
276+
vector_db_ids = [
277+
vector_db.identifier for vector_db in await client.vector_dbs.list()
278+
]
322279
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
323280
mcp_server.name for mcp_server in configuration.mcp_servers
324281
]
325282
# Convert empty list to None for consistency with existing behavior
326283
if not toolgroups:
327284
toolgroups = None
328285

329-
response = agent.create_turn(
286+
response = await agent.create_turn(
330287
messages=[UserMessage(role="user", content=query_request.query)],
331288
session_id=session_id,
332289
documents=query_request.get_documents(),

src/app/endpoints/streaming_query.py

Lines changed: 2 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
"""Handler for REST API call to provide answer to streaming query."""
22

33
import ast
4-
from contextlib import suppress
54
import json
65
import re
76
import logging
87
from typing import Annotated, Any, AsyncIterator, Iterator
98

109
from llama_stack_client import APIConnectionError
11-
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
1210
from llama_stack_client import AsyncLlamaStackClient # type: ignore
1311
from llama_stack_client.types import UserMessage # type: ignore
1412

@@ -25,10 +23,8 @@
2523
from configuration import configuration
2624
import metrics
2725
from models.requests import QueryRequest
28-
from utils.endpoints import check_configuration_loaded, get_system_prompt
26+
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
2927
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
30-
from utils.suid import get_suid
31-
from utils.types import GraniteToolParser
3228

3329
from app.endpoints.query import (
3430
get_rag_toolgroups,
@@ -45,50 +41,6 @@
4541
auth_dependency = get_auth_dependency()
4642

4743

48-
# # pylint: disable=R0913,R0917
49-
async def get_agent(
50-
client: AsyncLlamaStackClient,
51-
model_id: str,
52-
system_prompt: str,
53-
available_input_shields: list[str],
54-
available_output_shields: list[str],
55-
conversation_id: str | None,
56-
no_tools: bool = False,
57-
) -> tuple[AsyncAgent, str, str]:
58-
"""Get existing agent or create a new one with session persistence."""
59-
existing_agent_id = None
60-
if conversation_id:
61-
with suppress(ValueError):
62-
agent_response = await client.agents.retrieve(agent_id=conversation_id)
63-
existing_agent_id = agent_response.agent_id
64-
65-
logger.debug("Creating new agent")
66-
agent = AsyncAgent(
67-
client, # type: ignore[arg-type]
68-
model=model_id,
69-
instructions=system_prompt,
70-
input_shields=available_input_shields if available_input_shields else [],
71-
output_shields=available_output_shields if available_output_shields else [],
72-
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
73-
enable_session_persistence=True,
74-
)
75-
76-
await agent.initialize()
77-
78-
if existing_agent_id and conversation_id:
79-
orphan_agent_id = agent.agent_id
80-
agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access
81-
await client.agents.delete(agent_id=orphan_agent_id)
82-
sessions_response = await client.agents.session.list(agent_id=conversation_id)
83-
logger.info("session response: %s", sessions_response)
84-
session_id = str(sessions_response.data[0]["session_id"])
85-
else:
86-
conversation_id = agent.agent_id
87-
session_id = await agent.create_session(get_suid())
88-
89-
return agent, conversation_id, session_id
90-
91-
9244
METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")
9345

9446

@@ -556,8 +508,7 @@ async def retrieve_response(
556508
query_request.no_tools or False,
557509
)
558510

559-
logger.debug("Session ID: %s", conversation_id)
560-
511+
logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
561512
# bypass tools and MCP servers if no_tools is True
562513
if query_request.no_tools:
563514
mcp_headers = {}

0 commit comments

Comments (0)