"""Handler for REST API calls to manage conversation history."""

import logging
from typing import Any

from fastapi import APIRouter, Request, Depends, HTTPException, status

from configuration import configuration
from authentication import get_auth_dependency
from authorization.middleware import authorize
from models.cache_entry import CacheEntry
from models.config import Action
from models.responses import (
    ConversationsListResponseV2,
    ConversationResponse,
    ConversationDeleteResponse,
    UnauthorizedResponse,
)
from utils.endpoints import check_configuration_loaded
from utils.suid import check_suid

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["conversations_v2"])
auth_dependency = get_auth_dependency()


# NOTE(review): the tables below are passed to FastAPI's ``responses=``
# argument.  Their values are kept exactly as originally written, but three
# things look off and should be confirmed by the owner:
#   * the 200 and 404 entries hold example payload fields directly instead of
#     the "description" / "model" / "content" keys used by the 400/401 entries;
#   * the 400 entry describes missing credentials, while the only 400 raised in
#     this module is for an invalid conversation ID format;
#   * the delete example uses the key "message", whereas the handler fills the
#     ``response`` field of ConversationDeleteResponse.
conversation_responses: dict[int | str, dict[str, Any]] = {
    200: {
        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
        "chat_history": [
            {
                "messages": [
                    {"content": "Hi", "type": "user"},
                    {"content": "Hello!", "type": "assistant"},
                ],
                "started_at": "2024-01-01T00:00:00Z",
                "completed_at": "2024-01-01T00:00:05Z",
                "provider": "provider ID",
                "model": "model ID",
            }
        ],
    },
    400: {
        "description": "Missing or invalid credentials provided by client",
        "model": UnauthorizedResponse,
    },
    401: {
        "description": "Unauthorized: Invalid or missing Bearer token",
        "model": UnauthorizedResponse,
    },
    404: {
        "detail": {
            "response": "Conversation not found",
            "cause": "The specified conversation ID does not exist.",
        }
    },
}

conversation_delete_responses: dict[int | str, dict[str, Any]] = {
    200: {
        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
        "success": True,
        "message": "Conversation deleted successfully",
    },
    400: {
        "description": "Missing or invalid credentials provided by client",
        "model": UnauthorizedResponse,
    },
    401: {
        "description": "Unauthorized: Invalid or missing Bearer token",
        "model": UnauthorizedResponse,
    },
    404: {
        "detail": {
            "response": "Conversation not found",
            "cause": "The specified conversation ID does not exist.",
        }
    },
}

conversations_list_responses: dict[int | str, dict[str, Any]] = {
    200: {
        "conversations": [
            {
                "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
            }
        ]
    }
}


def _require_conversation_cache() -> Any:
    """Return the configured conversation cache or abort the request.

    Returns:
        The object stored in ``configuration.conversation_cache``.

    Raises:
        HTTPException: 404 when no conversation cache is configured.
    """
    cache = configuration.conversation_cache
    if cache is None:
        logger.warning("Conversation cache is not configured")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "response": "Conversation cache is not configured",
                "cause": "Conversation cache is not configured",
            },
        )
    return cache


@router.get("/conversations", responses=conversations_list_responses)
@authorize(Action.LIST_CONVERSATIONS)
async def get_conversations_list_endpoint_handler(
    request: Request,  # pylint: disable=unused-argument
    auth: Any = Depends(auth_dependency),
) -> ConversationsListResponseV2:
    """Handle request to retrieve all conversations for the authenticated user.

    Args:
        request: Incoming request; not used in the handler body.
        auth: Value produced by the auth dependency; its first item is used
            as the user ID.

    Returns:
        Response wrapping whatever the cache lists for the user.

    Raises:
        HTTPException: 404 when no conversation cache is configured.
    """
    check_configuration_loaded(configuration)

    user_id = auth[0]

    logger.info("Retrieving conversations for user %s", user_id)

    cache = _require_conversation_cache()

    # NOTE(review): the positional ``False`` is presumably a
    # "skip user ID check" flag -- confirm against the cache interface.
    conversations = cache.list(user_id, False)
    logger.info("Conversations for user %s: %s", user_id, len(conversations))

    return ConversationsListResponseV2(conversations=conversations)


@router.get("/conversations/{conversation_id}", responses=conversation_responses)
@authorize(Action.GET_CONVERSATION)
async def get_conversation_endpoint_handler(
    request: Request,  # pylint: disable=unused-argument
    conversation_id: str,
    auth: Any = Depends(auth_dependency),
) -> ConversationResponse:
    """Handle request to retrieve a conversation by ID.

    Args:
        request: Incoming request; not used in the handler body.
        conversation_id: ID of the conversation taken from the URL path.
        auth: Value produced by the auth dependency; its first item is used
            as the user ID.

    Returns:
        The conversation ID together with its transformed chat history.

    Raises:
        HTTPException: 400 for a malformed conversation ID, 404 when the cache
            is not configured or the conversation is not listed for the user.
    """
    check_configuration_loaded(configuration)
    check_valid_conversation_id(conversation_id)

    user_id = auth[0]
    logger.info("Retrieving conversation %s for user %s", conversation_id, user_id)

    cache = _require_conversation_cache()

    check_conversation_existence(user_id, conversation_id)

    conversation = cache.get(user_id, conversation_id, False)
    chat_history = [transform_chat_message(entry) for entry in conversation]

    return ConversationResponse(
        conversation_id=conversation_id, chat_history=chat_history
    )


@router.delete(
    "/conversations/{conversation_id}", responses=conversation_delete_responses
)
@authorize(Action.DELETE_CONVERSATION)
async def delete_conversation_endpoint_handler(
    request: Request,  # pylint: disable=unused-argument
    conversation_id: str,
    auth: Any = Depends(auth_dependency),
) -> ConversationDeleteResponse:
    """Handle request to delete a conversation by ID.

    Args:
        request: Incoming request; not used in the handler body.
        conversation_id: ID of the conversation taken from the URL path.
        auth: Value produced by the auth dependency; its first item is used
            as the user ID.

    Returns:
        Deletion result; ``success`` mirrors what the cache reported.

    Raises:
        HTTPException: 400 for a malformed conversation ID, 404 when the cache
            is not configured or the conversation is not listed for the user.
    """
    check_configuration_loaded(configuration)
    check_valid_conversation_id(conversation_id)

    user_id = auth[0]

    cache = _require_conversation_cache()

    check_conversation_existence(user_id, conversation_id)

    # Logged once, right before the actual deletion (it used to be logged
    # twice per request).
    logger.info("Deleting conversation %s for user %s", conversation_id, user_id)
    deleted = cache.delete(user_id, conversation_id, False)

    if deleted:
        return ConversationDeleteResponse(
            conversation_id=conversation_id,
            success=True,
            response="Conversation deleted successfully",
        )
    # The cache reported that nothing was removed, so the operation must not
    # be presented to the client as successful.
    return ConversationDeleteResponse(
        conversation_id=conversation_id,
        success=False,
        response="Conversation can not be deleted",
    )


def check_valid_conversation_id(conversation_id: str) -> None:
    """Check validity of conversation ID format.

    Args:
        conversation_id: Conversation ID as received from the client.

    Raises:
        HTTPException: 400 when ``check_suid`` rejects the ID.
    """
    if not check_suid(conversation_id):
        logger.error("Invalid conversation ID format: %s", conversation_id)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "response": "Invalid conversation ID format",
                "cause": f"Conversation ID {conversation_id} is not a valid UUID",
            },
        )


def check_conversation_existence(user_id: str, conversation_id: str) -> None:
    """Check if conversation exists.

    Does nothing when no conversation cache is configured.

    Args:
        user_id: User whose conversations are listed.
        conversation_id: Conversation ID expected to be in that listing.

    Raises:
        HTTPException: 404 when the ID is not among the user's conversations.
    """
    # checked already, but we need to make pyright happy
    if configuration.conversation_cache is None:
        return
    conversations = configuration.conversation_cache.list(user_id, False)
    if conversation_id not in conversations:
        logger.error("No conversation found for conversation ID %s", conversation_id)
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={
                "response": "Conversation not found",
                "cause": f"Conversation {conversation_id} could not be retrieved.",
            },
        )


def transform_chat_message(entry: CacheEntry) -> dict[str, Any]:
    """Transform the message read from cache into format used by response payload.

    Args:
        entry: Cache entry providing ``provider``, ``model``, ``query`` and
            ``response`` attributes.

    Returns:
        Dictionary with those four values plus a ``messages`` list holding the
        query as the "user" message and the response as the "assistant" one.
    """
    return {
        "provider": entry.provider,
        "model": entry.model,
        "query": entry.query,
        "response": entry.response,
        "messages": [
            {"content": entry.query, "type": "user"},
            {"content": entry.response, "type": "assistant"},
        ],
    }
async def query_endpoint_handler( provider_id=provider_id, ) + store_conversation_into_cache( + configuration, + user_id, + conversation_id, + provider_id, + model_id, + query_request.query, + summary.llm_response, + ) + # Convert tool calls to response format logger.info("Processing tool calls...") tool_calls = [ diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 14172eeb..f48646ac 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -47,6 +47,7 @@ check_configuration_loaded, get_agent, get_system_prompt, + store_conversation_into_cache, validate_model_provider_override, ) from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency @@ -704,6 +705,16 @@ async def response_generator( attachments=query_request.attachments or [], ) + store_conversation_into_cache( + configuration, + user_id, + conversation_id, + provider_id, + model_id, + query_request.query, + summary.llm_response, + ) + persist_user_conversation_details( user_id=user_id, conversation_id=conversation_id, diff --git a/src/app/routers.py b/src/app/routers.py index 9131076a..bd4de2e5 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -13,6 +13,7 @@ streaming_query, authorized, conversations, + conversations_v2, metrics, ) @@ -31,6 +32,7 @@ def include_routers(app: FastAPI) -> None: app.include_router(config.router, prefix="/v1") app.include_router(feedback.router, prefix="/v1") app.include_router(conversations.router, prefix="/v1") + app.include_router(conversations_v2.router, prefix="/v2") # road-core does not version these endpoints app.include_router(health.router) diff --git a/src/models/responses.py b/src/models/responses.py index 458c486b..cab7c3cc 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -667,6 +667,16 @@ class ConversationsListResponse(BaseModel): } +class ConversationsListResponseV2(BaseModel): + """Model representing a response for listing 
# # pylint: disable=R0913,R0917
def store_conversation_into_cache(
    config: AppConfig,
    user_id: str,
    conversation_id: str,
    provider_id: str,
    model_id: str,
    query: str,
    response: str,
) -> None:
    """Append one query/response exchange to the conversation history cache.

    Nothing is stored when the cache configuration has no type set.  When a
    type is set but the cache object is missing, a warning is logged and the
    exchange is dropped.
    """
    # No cache type configured: nothing to do.
    if config.conversation_cache_configuration.type is None:
        return

    history_cache = config.conversation_cache
    if history_cache is None:
        logger.warning("Conversation cache configured but not initialized")
        return

    exchange = CacheEntry(
        query=query,
        response=response,
        provider=provider_id,
        model=model_id,
    )
    history_cache.insert_or_append(user_id, conversation_id, exchange, False)
- assert len(app.routers) == 11 + assert len(app.routers) == 12 assert root.router in app.get_routers() assert info.router in app.get_routers() assert models.router in app.get_routers() @@ -80,7 +81,7 @@ def test_check_prefixes() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 11 + assert len(app.routers) == 12 assert app.get_router_prefix(root.router) == "" assert app.get_router_prefix(info.router) == "/v1" assert app.get_router_prefix(models.router) == "/v1" @@ -92,3 +93,4 @@ def test_check_prefixes() -> None: assert app.get_router_prefix(authorized.router) == "" assert app.get_router_prefix(conversations.router) == "/v1" assert app.get_router_prefix(metrics.router) == "" + assert app.get_router_prefix(conversations_v2.router) == "/v2"