Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 240 additions & 0 deletions src/app/endpoints/conversations_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
"""Handler for REST API calls to manage conversation history."""

import logging
from typing import Any

from fastapi import APIRouter, Request, Depends, HTTPException, status

from configuration import configuration
from authentication import get_auth_dependency
from authorization.middleware import authorize
from models.cache_entry import CacheEntry
from models.config import Action
from models.responses import (
ConversationsListResponseV2,
ConversationResponse,
ConversationDeleteResponse,
UnauthorizedResponse,
)
from utils.endpoints import check_configuration_loaded
from utils.suid import check_suid

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["conversations_v2"])
auth_dependency = get_auth_dependency()


conversation_responses: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"chat_history": [
{
"messages": [
{"content": "Hi", "type": "user"},
{"content": "Hello!", "type": "assistant"},
],
"started_at": "2024-01-01T00:00:00Z",
"completed_at": "2024-01-01T00:00:05Z",
"provider": "provider ID",
"model": "model ID",
}
],
},
400: {
"description": "Missing or invalid credentials provided by client",
"model": UnauthorizedResponse,
},
401: {
"description": "Unauthorized: Invalid or missing Bearer token",
"model": UnauthorizedResponse,
},
404: {
"detail": {
"response": "Conversation not found",
"cause": "The specified conversation ID does not exist.",
}
},
}

conversation_delete_responses: dict[int | str, dict[str, Any]] = {
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"success": True,
"message": "Conversation deleted successfully",
},
400: {
"description": "Missing or invalid credentials provided by client",
"model": UnauthorizedResponse,
},
401: {
"description": "Unauthorized: Invalid or missing Bearer token",
"model": UnauthorizedResponse,
},
404: {
"detail": {
"response": "Conversation not found",
"cause": "The specified conversation ID does not exist.",
}
},
}

conversations_list_responses: dict[int | str, dict[str, Any]] = {
200: {
"conversations": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
}
]
}
}


@router.get("/conversations", responses=conversations_list_responses)
@authorize(Action.LIST_CONVERSATIONS)
async def get_conversations_list_endpoint_handler(
request: Request, # pylint: disable=unused-argument
auth: Any = Depends(auth_dependency),
) -> ConversationsListResponseV2:
"""Handle request to retrieve all conversations for the authenticated user."""
check_configuration_loaded(configuration)

user_id = auth[0]

logger.info("Retrieving conversations for user %s", user_id)

if configuration.conversation_cache is None:
logger.warning("Converastion cache is not configured")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"response": "Conversation cache is not configured",
"cause": "Conversation cache is not configured",
},
)

conversations = configuration.conversation_cache.list(user_id, False)
logger.info("Conversations for user %s: %s", user_id, len(conversations))

return ConversationsListResponseV2(conversations=conversations)


@router.get("/conversations/{conversation_id}", responses=conversation_responses)
@authorize(Action.GET_CONVERSATION)
async def get_conversation_endpoint_handler(
request: Request, # pylint: disable=unused-argument
conversation_id: str,
auth: Any = Depends(auth_dependency),
) -> ConversationResponse:
"""Handle request to retrieve a conversation by ID."""
check_configuration_loaded(configuration)
check_valid_conversation_id(conversation_id)

user_id = auth[0]
logger.info("Retrieving conversation %s for user %s", conversation_id, user_id)

if configuration.conversation_cache is None:
logger.warning("Converastion cache is not configured")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"response": "Conversation cache is not configured",
"cause": "Conversation cache is not configured",
},
)

check_conversation_existence(user_id, conversation_id)

conversation = configuration.conversation_cache.get(user_id, conversation_id, False)
chat_history = [transform_chat_message(entry) for entry in conversation]

return ConversationResponse(
conversation_id=conversation_id, chat_history=chat_history
)


@router.delete(
"/conversations/{conversation_id}", responses=conversation_delete_responses
)
@authorize(Action.DELETE_CONVERSATION)
async def delete_conversation_endpoint_handler(
request: Request, # pylint: disable=unused-argument
conversation_id: str,
auth: Any = Depends(auth_dependency),
) -> ConversationDeleteResponse:
"""Handle request to delete a conversation by ID."""
check_configuration_loaded(configuration)
check_valid_conversation_id(conversation_id)

user_id = auth[0]
logger.info("Deleting conversation %s for user %s", conversation_id, user_id)

if configuration.conversation_cache is None:
logger.warning("Converastion cache is not configured")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"response": "Conversation cache is not configured",
"cause": "Conversation cache is not configured",
},
)

check_conversation_existence(user_id, conversation_id)

logger.info("Deleting conversation %s for user %s", conversation_id, user_id)
deleted = configuration.conversation_cache.delete(user_id, conversation_id, False)

if deleted:
return ConversationDeleteResponse(
conversation_id=conversation_id,
success=True,
response="Conversation deleted successfully",
)
return ConversationDeleteResponse(
conversation_id=conversation_id,
success=True,
response="Conversation can not be deleted",
)


def check_valid_conversation_id(conversation_id: str) -> None:
"""Check validity of conversation ID format."""
if not check_suid(conversation_id):
logger.error("Invalid conversation ID format: %s", conversation_id)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"response": "Invalid conversation ID format",
"cause": f"Conversation ID {conversation_id} is not a valid UUID",
},
)


def check_conversation_existence(user_id: str, conversation_id: str) -> None:
"""Check if conversation exists."""
# checked already, but we need to make pyright happy
if configuration.conversation_cache is None:
return
conversations = configuration.conversation_cache.list(user_id, False)
if conversation_id not in conversations:
logger.error("No conversation found for conversation ID %s", conversation_id)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"response": "Conversation not found",
"cause": f"Conversation {conversation_id} could not be retrieved.",
},
)


def transform_chat_message(entry: CacheEntry) -> dict[str, Any]:
"""Transform the message read from cache into format used by response payload."""
return {
"provider": entry.provider,
"model": entry.model,
"query": entry.query,
"response": entry.response,
"messages": [
{"content": entry.query, "type": "user"},
{"content": entry.response, "type": "assistant"},
],
}
11 changes: 11 additions & 0 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
check_configuration_loaded,
get_agent,
get_system_prompt,
store_conversation_into_cache,
validate_conversation_ownership,
validate_model_provider_override,
)
Expand Down Expand Up @@ -279,6 +280,16 @@ async def query_endpoint_handler(
provider_id=provider_id,
)

store_conversation_into_cache(
configuration,
user_id,
conversation_id,
provider_id,
model_id,
query_request.query,
summary.llm_response,
)

# Convert tool calls to response format
logger.info("Processing tool calls...")
tool_calls = [
Expand Down
11 changes: 11 additions & 0 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
check_configuration_loaded,
get_agent,
get_system_prompt,
store_conversation_into_cache,
validate_model_provider_override,
)
from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
Expand Down Expand Up @@ -704,6 +705,16 @@ async def response_generator(
attachments=query_request.attachments or [],
)

store_conversation_into_cache(
configuration,
user_id,
conversation_id,
provider_id,
model_id,
query_request.query,
summary.llm_response,
)

persist_user_conversation_details(
user_id=user_id,
conversation_id=conversation_id,
Expand Down
2 changes: 2 additions & 0 deletions src/app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
streaming_query,
authorized,
conversations,
conversations_v2,
metrics,
)

Expand All @@ -31,6 +32,7 @@ def include_routers(app: FastAPI) -> None:
app.include_router(config.router, prefix="/v1")
app.include_router(feedback.router, prefix="/v1")
app.include_router(conversations.router, prefix="/v1")
app.include_router(conversations_v2.router, prefix="/v2")

# road-core does not version these endpoints
app.include_router(health.router)
Expand Down
10 changes: 10 additions & 0 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,16 @@ class ConversationsListResponse(BaseModel):
}


class ConversationsListResponseV2(BaseModel):
"""Model representing a response for listing conversations of a user.

Attributes:
conversations: List of conversation IDs associated with the user.
"""

conversations: list[str]


class ErrorResponse(BaseModel):
"""Model representing error response for query endpoint."""

Expand Down
26 changes: 26 additions & 0 deletions src/utils/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from llama_stack_client.lib.agents.agent import AsyncAgent

import constants
from models.cache_entry import CacheEntry
from models.requests import QueryRequest
from models.database.conversations import UserConversation
from models.config import Action
Expand Down Expand Up @@ -135,6 +136,31 @@ def validate_model_provider_override(
)


# # pylint: disable=R0913,R0917
def store_conversation_into_cache(
config: AppConfig,
user_id: str,
conversation_id: str,
provider_id: str,
model_id: str,
query: str,
response: str,
) -> None:
"""Store one part of conversation into conversation history cache."""
if config.conversation_cache_configuration.type is not None:
cache = config.conversation_cache
if cache is None:
logger.warning("Conversation cache configured but not initialized")
return
cache_entry = CacheEntry(
query=query,
response=response,
provider=provider_id,
model=model_id,
)
cache.insert_or_append(user_id, conversation_id, cache_entry, False)


# # pylint: disable=R0913,R0917
async def get_agent(
client: AsyncLlamaStackClient,
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/app/test_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from app.endpoints import (
conversations,
conversations_v2,
root,
info,
models,
Expand Down Expand Up @@ -60,7 +61,7 @@ def test_include_routers() -> None:
include_routers(app)

# are all routers added?
assert len(app.routers) == 11
assert len(app.routers) == 12
assert root.router in app.get_routers()
assert info.router in app.get_routers()
assert models.router in app.get_routers()
Expand All @@ -80,7 +81,7 @@ def test_check_prefixes() -> None:
include_routers(app)

# are all routers added?
assert len(app.routers) == 11
assert len(app.routers) == 12
assert app.get_router_prefix(root.router) == ""
assert app.get_router_prefix(info.router) == "/v1"
assert app.get_router_prefix(models.router) == "/v1"
Expand All @@ -92,3 +93,4 @@ def test_check_prefixes() -> None:
assert app.get_router_prefix(authorized.router) == ""
assert app.get_router_prefix(conversations.router) == "/v1"
assert app.get_router_prefix(metrics.router) == ""
assert app.get_router_prefix(conversations_v2.router) == "/v2"
Loading