Skip to content

Commit 8e2d70b

Browse files
committed
feat: Convert lightspeed-core to async architecture
- Migrate endpoints from sync to async handlers
- Remove legacy sync client infrastructure
- Update unit tests for async compatibility

This resolves blocking behavior in all endpoints except streaming_query, which was already async, enabling proper concurrent request handling.

Signed-off-by: Eran Cohen <[email protected]>
1 parent a3b530d commit 8e2d70b

21 files changed

+772
-1300
lines changed

scripts/generate_openapi_schema.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
# it is needed to read proper configuration in order to start the app to generate schema
1111
from configuration import configuration
1212

13-
from client import LlamaStackClientHolder
13+
from client import AsyncLlamaStackClientHolder
1414

1515
cfg_file = "lightspeed-stack.yaml"
1616
configuration.load_configuration(cfg_file)
1717

1818
# Llama Stack client needs to be loaded before REST API is fully initialized
19-
LlamaStackClientHolder().load(configuration.configuration.llama_stack)
19+
import asyncio # noqa: E402
20+
21+
asyncio.run(AsyncLlamaStackClientHolder().load(configuration.configuration.llama_stack))
2022

2123
from app.main import app # noqa: E402 pylint: disable=C0413
2224

src/app/endpoints/authorized.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
"""Handler for REST API call to authorized endpoint."""
22

3-
import asyncio
43
import logging
54
from typing import Any
65

7-
from fastapi import APIRouter, Request
6+
from fastapi import APIRouter, Depends
87

98
from auth import get_auth_dependency
109
from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse
@@ -31,8 +30,10 @@
3130

3231

3332
@router.post("/authorized", responses=authorized_responses)
34-
def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse:
33+
async def authorized_endpoint_handler(
34+
auth: Any = Depends(auth_dependency),
35+
) -> AuthorizedResponse:
3536
"""Handle request to the /authorized endpoint."""
3637
# Ignore the user token, we should not return it in the response
37-
user_id, user_name, _ = asyncio.run(auth_dependency(_request))
38+
user_id, user_name, _ = auth
3839
return AuthorizedResponse(user_id=user_id, username=user_name)

src/app/endpoints/conversations.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from fastapi import APIRouter, HTTPException, status, Depends
99

10-
from client import LlamaStackClientHolder
10+
from client import AsyncLlamaStackClientHolder
1111
from configuration import configuration
1212
from models.responses import ConversationResponse, ConversationDeleteResponse
1313
from auth import get_auth_dependency
@@ -110,7 +110,7 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
110110

111111

112112
@router.get("/conversations/{conversation_id}", responses=conversation_responses)
113-
def get_conversation_endpoint_handler(
113+
async def get_conversation_endpoint_handler(
114114
conversation_id: str,
115115
_auth: Any = Depends(auth_dependency),
116116
) -> ConversationResponse:
@@ -132,9 +132,9 @@ def get_conversation_endpoint_handler(
132132
logger.info("Retrieving conversation %s", conversation_id)
133133

134134
try:
135-
client = LlamaStackClientHolder().get_client()
135+
client = AsyncLlamaStackClientHolder().get_client()
136136

137-
session_data = client.agents.session.list(agent_id=agent_id).data[0]
137+
session_data = (await client.agents.session.list(agent_id=agent_id)).data[0]
138138

139139
logger.info("Successfully retrieved conversation %s", conversation_id)
140140

@@ -179,7 +179,7 @@ def get_conversation_endpoint_handler(
179179
@router.delete(
180180
"/conversations/{conversation_id}", responses=conversation_delete_responses
181181
)
182-
def delete_conversation_endpoint_handler(
182+
async def delete_conversation_endpoint_handler(
183183
conversation_id: str,
184184
_auth: Any = Depends(auth_dependency),
185185
) -> ConversationDeleteResponse:
@@ -201,10 +201,12 @@ def delete_conversation_endpoint_handler(
201201

202202
try:
203203
# Get Llama Stack client
204-
client = LlamaStackClientHolder().get_client()
204+
client = AsyncLlamaStackClientHolder().get_client()
205205
# Delete session using the conversation_id as session_id
206206
# In this implementation, conversation_id and session_id are the same
207-
client.agents.session.delete(agent_id=agent_id, session_id=conversation_id)
207+
await client.agents.session.delete(
208+
agent_id=agent_id, session_id=conversation_id
209+
)
208210

209211
logger.info("Successfully deleted conversation %s", conversation_id)
210212

src/app/endpoints/models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from llama_stack_client import APIConnectionError
77
from fastapi import APIRouter, HTTPException, Request, status
88

9-
from client import LlamaStackClientHolder
9+
from client import AsyncLlamaStackClientHolder
1010
from configuration import configuration
1111
from models.responses import ModelsResponse
1212
from utils.endpoints import check_configuration_loaded
@@ -43,7 +43,7 @@
4343

4444

4545
@router.get("/models", responses=models_responses)
46-
def models_endpoint_handler(_request: Request) -> ModelsResponse:
46+
async def models_endpoint_handler(_request: Request) -> ModelsResponse:
4747
"""Handle requests to the /models endpoint."""
4848
check_configuration_loaded(configuration)
4949

@@ -52,9 +52,9 @@ def models_endpoint_handler(_request: Request) -> ModelsResponse:
5252

5353
try:
5454
# try to get Llama Stack client
55-
client = LlamaStackClientHolder().get_client()
55+
client = AsyncLlamaStackClientHolder().get_client()
5656
# retrieve models
57-
models = client.models.list()
57+
models = await client.models.list()
5858
m = [dict(m) for m in models]
5959
return ModelsResponse(models=m)
6060

src/app/endpoints/query.py

Lines changed: 19 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
"""Handler for REST API call to provide answer to query."""
22

3-
from contextlib import suppress
43
from datetime import datetime, UTC
54
import json
65
import logging
76
import os
87
from pathlib import Path
98
from typing import Any
109

11-
from llama_stack_client.lib.agents.agent import Agent
1210
from llama_stack_client import APIConnectionError
13-
from llama_stack_client import LlamaStackClient # type: ignore
11+
from llama_stack_client import AsyncLlamaStackClient # type: ignore
1412
from llama_stack_client.types import UserMessage, Shield # type: ignore
1513
from llama_stack_client.types.agents.turn_create_params import (
1614
ToolgroupAgentToolGroupWithArgs,
@@ -20,18 +18,17 @@
2018

2119
from fastapi import APIRouter, HTTPException, status, Depends
2220

23-
from client import LlamaStackClientHolder
21+
from client import AsyncLlamaStackClientHolder
2422
from configuration import configuration
2523
import metrics
2624
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
2725
from models.requests import QueryRequest, Attachment
2826
import constants
2927
from auth import get_auth_dependency
3028
from utils.common import retrieve_user_id
31-
from utils.endpoints import check_configuration_loaded, get_system_prompt
29+
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
3230
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3331
from utils.suid import get_suid
34-
from utils.types import GraniteToolParser
3532

3633
logger = logging.getLogger("app.endpoints.handlers")
3734
router = APIRouter(tags=["query"])
@@ -68,50 +65,8 @@ def is_transcripts_enabled() -> bool:
6865
return configuration.user_data_collection_configuration.transcripts_enabled
6966

7067

71-
def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
72-
client: LlamaStackClient,
73-
model_id: str,
74-
system_prompt: str,
75-
available_input_shields: list[str],
76-
available_output_shields: list[str],
77-
conversation_id: str | None,
78-
no_tools: bool = False,
79-
) -> tuple[Agent, str, str]:
80-
"""Get existing agent or create a new one with session persistence."""
81-
existing_agent_id = None
82-
if conversation_id:
83-
with suppress(ValueError):
84-
existing_agent_id = client.agents.retrieve(
85-
agent_id=conversation_id
86-
).agent_id
87-
88-
logger.debug("Creating new agent")
89-
# TODO(lucasagomes): move to ReActAgent
90-
agent = Agent(
91-
client,
92-
model=model_id,
93-
instructions=system_prompt,
94-
input_shields=available_input_shields if available_input_shields else [],
95-
output_shields=available_output_shields if available_output_shields else [],
96-
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
97-
enable_session_persistence=True,
98-
)
99-
if existing_agent_id and conversation_id:
100-
orphan_agent_id = agent.agent_id
101-
agent.agent_id = conversation_id
102-
client.agents.delete(agent_id=orphan_agent_id)
103-
sessions_response = client.agents.session.list(agent_id=conversation_id)
104-
logger.info("session response: %s", sessions_response)
105-
session_id = str(sessions_response.data[0]["session_id"])
106-
else:
107-
conversation_id = agent.agent_id
108-
session_id = agent.create_session(get_suid())
109-
110-
return agent, conversation_id, session_id
111-
112-
11368
@router.post("/query", responses=query_response)
114-
def query_endpoint_handler(
69+
async def query_endpoint_handler(
11570
query_request: QueryRequest,
11671
auth: Any = Depends(auth_dependency),
11772
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
@@ -126,11 +81,11 @@ def query_endpoint_handler(
12681

12782
try:
12883
# try to get Llama Stack client
129-
client = LlamaStackClientHolder().get_client()
84+
client = AsyncLlamaStackClientHolder().get_client()
13085
model_id, provider_id = select_model_and_provider_id(
131-
client.models.list(), query_request
86+
await client.models.list(), query_request
13287
)
133-
response, conversation_id = retrieve_response(
88+
response, conversation_id = await retrieve_response(
13489
client,
13590
model_id,
13691
query_request,
@@ -250,19 +205,21 @@ def is_input_shield(shield: Shield) -> bool:
250205
return _is_inout_shield(shield) or not is_output_shield(shield)
251206

252207

253-
def retrieve_response( # pylint: disable=too-many-locals
254-
client: LlamaStackClient,
208+
async def retrieve_response( # pylint: disable=too-many-locals
209+
client: AsyncLlamaStackClient,
255210
model_id: str,
256211
query_request: QueryRequest,
257212
token: str,
258213
mcp_headers: dict[str, dict[str, str]] | None = None,
259214
) -> tuple[str, str]:
260215
"""Retrieve response from LLMs and agents."""
261216
available_input_shields = [
262-
shield.identifier for shield in filter(is_input_shield, client.shields.list())
217+
shield.identifier
218+
for shield in filter(is_input_shield, await client.shields.list())
263219
]
264220
available_output_shields = [
265-
shield.identifier for shield in filter(is_output_shield, client.shields.list())
221+
shield.identifier
222+
for shield in filter(is_output_shield, await client.shields.list())
266223
]
267224
if not available_input_shields and not available_output_shields:
268225
logger.info("No available shields. Disabling safety")
@@ -281,7 +238,7 @@ def retrieve_response( # pylint: disable=too-many-locals
281238
if query_request.attachments:
282239
validate_attachments_metadata(query_request.attachments)
283240

284-
agent, conversation_id, session_id = get_agent(
241+
agent, conversation_id, session_id = await get_agent(
285242
client,
286243
model_id,
287244
system_prompt,
@@ -291,6 +248,7 @@ def retrieve_response( # pylint: disable=too-many-locals
291248
query_request.no_tools or False,
292249
)
293250

251+
logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
294252
# bypass tools and MCP servers if no_tools is True
295253
if query_request.no_tools:
296254
mcp_headers = {}
@@ -315,15 +273,17 @@ def retrieve_response( # pylint: disable=too-many-locals
315273
),
316274
}
317275

318-
vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
276+
vector_db_ids = [
277+
vector_db.identifier for vector_db in await client.vector_dbs.list()
278+
]
319279
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
320280
mcp_server.name for mcp_server in configuration.mcp_servers
321281
]
322282
# Convert empty list to None for consistency with existing behavior
323283
if not toolgroups:
324284
toolgroups = None
325285

326-
response = agent.create_turn(
286+
response = await agent.create_turn(
327287
messages=[UserMessage(role="user", content=query_request.query)],
328288
session_id=session_id,
329289
documents=query_request.get_documents(),

src/app/endpoints/streaming_query.py

Lines changed: 2 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
"""Handler for REST API call to provide answer to streaming query."""
22

33
import ast
4-
from contextlib import suppress
54
import json
65
import re
76
import logging
87
from typing import Any, AsyncIterator, Iterator
98

109
from llama_stack_client import APIConnectionError
11-
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
1210
from llama_stack_client import AsyncLlamaStackClient # type: ignore
1311
from llama_stack_client.types import UserMessage # type: ignore
1412

@@ -24,11 +22,9 @@
2422
from configuration import configuration
2523
import metrics
2624
from models.requests import QueryRequest
27-
from utils.endpoints import check_configuration_loaded, get_system_prompt
25+
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
2826
from utils.common import retrieve_user_id
2927
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
30-
from utils.suid import get_suid
31-
from utils.types import GraniteToolParser
3228

3329
from app.endpoints.query import (
3430
get_rag_toolgroups,
@@ -45,48 +41,6 @@
4541
auth_dependency = get_auth_dependency()
4642

4743

48-
# # pylint: disable=R0913,R0917
49-
async def get_agent(
50-
client: AsyncLlamaStackClient,
51-
model_id: str,
52-
system_prompt: str,
53-
available_input_shields: list[str],
54-
available_output_shields: list[str],
55-
conversation_id: str | None,
56-
no_tools: bool = False,
57-
) -> tuple[AsyncAgent, str, str]:
58-
"""Get existing agent or create a new one with session persistence."""
59-
existing_agent_id = None
60-
if conversation_id:
61-
with suppress(ValueError):
62-
agent_response = await client.agents.retrieve(agent_id=conversation_id)
63-
existing_agent_id = agent_response.agent_id
64-
65-
logger.debug("Creating new agent")
66-
agent = AsyncAgent(
67-
client, # type: ignore[arg-type]
68-
model=model_id,
69-
instructions=system_prompt,
70-
input_shields=available_input_shields if available_input_shields else [],
71-
output_shields=available_output_shields if available_output_shields else [],
72-
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
73-
enable_session_persistence=True,
74-
)
75-
76-
if existing_agent_id and conversation_id:
77-
orphan_agent_id = agent.agent_id
78-
agent._agent_id = conversation_id # type: ignore[assignment] # pylint: disable=protected-access
79-
await client.agents.delete(agent_id=orphan_agent_id)
80-
sessions_response = await client.agents.session.list(agent_id=conversation_id)
81-
logger.info("session response: %s", sessions_response)
82-
session_id = str(sessions_response.data[0]["session_id"])
83-
else:
84-
conversation_id = agent.agent_id
85-
session_id = await agent.create_session(get_suid())
86-
87-
return agent, conversation_id, session_id
88-
89-
9044
METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")
9145

9246

@@ -536,6 +490,7 @@ async def retrieve_response(
536490
query_request.no_tools or False,
537491
)
538492

493+
logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id)
539494
# bypass tools and MCP servers if no_tools is True
540495
if query_request.no_tools:
541496
mcp_headers = {}
@@ -562,7 +517,6 @@ async def retrieve_response(
562517
),
563518
}
564519

565-
logger.debug("Session ID: %s", conversation_id)
566520
vector_db_ids = [
567521
vector_db.identifier for vector_db in await client.vector_dbs.list()
568522
]
@@ -573,7 +527,6 @@ async def retrieve_response(
573527
if not toolgroups:
574528
toolgroups = None
575529

576-
logger.debug("Session ID: %s", conversation_id)
577530
response = await agent.create_turn(
578531
messages=[UserMessage(role="user", content=query_request.query)],
579532
session_id=session_id,

0 commit comments

Comments (0)