Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"aiohttp>=3.12.14",
"authlib>=1.6.0",
"openai==1.99.1",
"sqlalchemy>=2.0.42",
]

[tool.pyright]
Expand Down
124 changes: 124 additions & 0 deletions src/app/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Database engine management."""

from pathlib import Path
from typing import Any

from sqlalchemy import create_engine, text
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import sessionmaker, Session
from log import get_logger, logging
from configuration import configuration
from models.database.base import Base
from models.config import SQLiteDatabaseConfiguration, PostgreSQLDatabaseConfiguration

logger = get_logger(__name__)

engine: Engine | None = None
SessionLocal: sessionmaker | None = None


def get_engine() -> Engine:
"""Get the database engine. Raises an error if not initialized."""
if engine is None:
raise RuntimeError(
"Database engine not initialized. Call initialize_database() first."
)
return engine


def create_tables() -> None:
"""Create tables."""
Base.metadata.create_all(get_engine())


def get_session() -> Session:
"""Get a database session. Raises an error if not initialized."""
if SessionLocal is None:
raise RuntimeError(
"Database session not initialized. Call initialize_database() first."
)
return SessionLocal()


def _create_sqlite_engine(config: SQLiteDatabaseConfiguration, **kwargs: Any) -> Engine:
"""Create SQLite database engine."""
if not Path(config.db_path).parent.exists():
raise FileNotFoundError(
f"SQLite database directory does not exist: {config.db_path}"
)

try:
return create_engine(f"sqlite:///{config.db_path}", **kwargs)
except Exception as e:
logger.exception("Failed to create SQLite engine")
raise RuntimeError(f"SQLite engine creation failed: {e}") from e


def _create_postgres_engine(
config: PostgreSQLDatabaseConfiguration, **kwargs: Any
) -> Engine:
"""Create PostgreSQL database engine."""
postgres_url = (
f"postgresql://{config.user}:{config.password}@"
f"{config.host}:{config.port}/{config.db}"
f"?sslmode={config.ssl_mode}&gssencmode={config.gss_encmode}"
)

is_custom_schema = config.namespace is not None and config.namespace != "public"

connect_args = {}
if is_custom_schema:
connect_args["options"] = f"-csearch_path={config.namespace}"

if config.ca_cert_path is not None:
connect_args["sslrootcert"] = str(config.ca_cert_path)

try:
postgres_engine = create_engine(
postgres_url, connect_args=connect_args, **kwargs
)
except Exception as e:
logger.exception("Failed to create PostgreSQL engine")
raise RuntimeError(f"PostgreSQL engine creation failed: {e}") from e

if is_custom_schema:
try:
with postgres_engine.connect() as connection:
connection.execute(
text(f'CREATE SCHEMA IF NOT EXISTS "{config.namespace}"')
)
connection.commit()
logger.info("Schema '%s' created or already exists", config.namespace)
except Exception as e:
logger.exception("Failed to create schema '%s'", config.namespace)
raise RuntimeError(
f"Schema creation failed for '{config.namespace}': {e}"
) from e

return postgres_engine


def initialize_database() -> None:
"""Initialize the database engine."""
db_config = configuration.database_configuration

global engine, SessionLocal # pylint: disable=global-statement

# Debug print all SQL statements if our logger is at-least DEBUG level
echo = bool(logger.isEnabledFor(logging.DEBUG))

create_engine_kwargs = {
"echo": echo,
}

match db_config.db_type:
case "sqlite":
sqlite_config = db_config.config
assert isinstance(sqlite_config, SQLiteDatabaseConfiguration)
engine = _create_sqlite_engine(sqlite_config, **create_engine_kwargs)
case "postgres":
postgres_config = db_config.config
assert isinstance(postgres_config, PostgreSQLDatabaseConfiguration)
engine = _create_postgres_engine(postgres_config, **create_engine_kwargs)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
113 changes: 109 additions & 4 deletions src/app/endpoints/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@

from client import AsyncLlamaStackClientHolder
from configuration import configuration
from models.responses import ConversationResponse, ConversationDeleteResponse
from models.responses import (
ConversationResponse,
ConversationDeleteResponse,
ConversationsListResponse,
ConversationDetails,
)
from models.database.conversations import UserConversation
from auth import get_auth_dependency
from utils.endpoints import check_configuration_loaded
from app.database import get_session
from utils.endpoints import check_configuration_loaded, validate_conversation_ownership
from utils.suid import check_suid

logger = logging.getLogger("app.endpoints.handlers")
Expand Down Expand Up @@ -66,6 +73,35 @@
},
}

conversations_list_responses: dict[int | str, dict[str, Any]] = {
200: {
"conversations": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"created_at": "2024-01-01T00:00:00Z",
"last_message_at": "2024-01-01T00:05:00Z",
"last_used_model": "gemini/gemini-1.5-flash",
"last_used_provider": "gemini",
"message_count": 5,
},
{
"conversation_id": "456e7890-e12b-34d5-a678-901234567890",
"created_at": "2024-01-01T01:00:00Z",
"last_message_at": "2024-01-01T01:02:00Z",
"last_used_model": "gemini/gemini-2.0-flash",
"last_used_provider": "gemini",
"message_count": 2,
},
]
},
503: {
"detail": {
"response": "Unable to connect to Llama Stack",
"cause": "Connection error.",
}
},
}


def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
"""Simplify session data to include only essential conversation information.
Expand Down Expand Up @@ -109,10 +145,64 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
return chat_history


@router.get("/conversations", responses=conversations_list_responses)
def get_conversations_list_endpoint_handler(
auth: Any = Depends(auth_dependency),
) -> ConversationsListResponse:
"""Handle request to retrieve all conversations for the authenticated user."""
check_configuration_loaded(configuration)

user_id, _, _ = auth

logger.info("Retrieving conversations for user %s", user_id)

with get_session() as session:
try:
# Get all conversations for this user
user_conversations = (
session.query(UserConversation).filter_by(user_id=user_id).all()
)

# Return conversation summaries with metadata
conversations = [
ConversationDetails(
conversation_id=conv.id,
created_at=conv.created_at.isoformat() if conv.created_at else None,
last_message_at=(
conv.last_message_at.isoformat()
if conv.last_message_at
else None
),
message_count=conv.message_count,
last_used_model=conv.last_used_model,
last_used_provider=conv.last_used_provider,
)
for conv in user_conversations
]

logger.info(
"Found %d conversations for user %s", len(conversations), user_id
)

return ConversationsListResponse(conversations=conversations)

except Exception as e:
logger.exception(
"Error retrieving conversations for user %s: %s", user_id, e
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"response": "Unknown error",
"cause": f"Unknown error while getting conversations for user {user_id}",
},
) from e

Comment on lines +159 to +200
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Catch exceptions from session acquisition as well.
If get_session().enter raises (e.g., DB connection failure), it happens before the try: block and bypasses your HTTPException mapping. Wrap the with in the try.

-    with get_session() as session:
-        try:
-            # Get all conversations for this user
-            user_conversations = (
-                session.query(UserConversation).filter_by(user_id=user_id).all()
-            )
-            ...
-            return ConversationsListResponse(conversations=conversations)
-
-        except Exception as e:
-            logger.exception(
-                "Error retrieving conversations for user %s: %s", user_id, e
-            )
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail={
-                    "response": "Unknown error",
-                    "cause": f"Unknown error while getting conversations for user {user_id}",
-                },
-            ) from e
+    try:
+        with get_session() as session:
+            # Get all conversations for this user
+            user_conversations = (
+                session.query(UserConversation).filter_by(user_id=user_id).all()
+            )
+            ...
+            return ConversationsListResponse(conversations=conversations)
+    except Exception as e:
+        logger.exception(
+            "Error retrieving conversations for user %s: %s", user_id, e
+        )
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={
+                "response": "Unknown error",
+                "cause": "Unknown error while getting conversations",
+            },
+        ) from e
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
with get_session() as session:
try:
# Get all conversations for this user
user_conversations = (
session.query(UserConversation).filter_by(user_id=user_id).all()
)
# Return conversation summaries with metadata
conversations = [
ConversationDetails(
conversation_id=conv.id,
created_at=conv.created_at.isoformat() if conv.created_at else None,
last_message_at=(
conv.last_message_at.isoformat()
if conv.last_message_at
else None
),
message_count=conv.message_count,
last_used_model=conv.last_used_model,
last_used_provider=conv.last_used_provider,
)
for conv in user_conversations
]
logger.info(
"Found %d conversations for user %s", len(conversations), user_id
)
return ConversationsListResponse(conversations=conversations)
except Exception as e:
logger.exception(
"Error retrieving conversations for user %s: %s", user_id, e
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"response": "Unknown error",
"cause": f"Unknown error while getting conversations for user {user_id}",
},
) from e
try:
with get_session() as session:
# Get all conversations for this user
user_conversations = (
session.query(UserConversation).filter_by(user_id=user_id).all()
)
# Return conversation summaries with metadata
conversations = [
ConversationDetails(
conversation_id=conv.id,
created_at=conv.created_at.isoformat() if conv.created_at else None,
last_message_at=(
conv.last_message_at.isoformat()
if conv.last_message_at
else None
),
message_count=conv.message_count,
last_used_model=conv.last_used_model,
last_used_provider=conv.last_used_provider,
)
for conv in user_conversations
]
logger.info(
"Found %d conversations for user %s", len(conversations), user_id
)
return ConversationsListResponse(conversations=conversations)
except Exception as e:
logger.exception(
"Error retrieving conversations for user %s: %s", user_id, e
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"response": "Unknown error",
"cause": "Unknown error while getting conversations",
},
) from e
🤖 Prompt for AI Agents
In src/app/endpoints/conversations.py around lines 159 to 200, the current try
block does not cover the get_session() context manager acquisition, so
exceptions raised during session creation (like DB connection failures) are not
caught and mapped to HTTPException. To fix this, move the try block to wrap the
entire with get_session() statement, ensuring any exceptions during session
acquisition are caught and handled properly.


@router.get("/conversations/{conversation_id}", responses=conversation_responses)
async def get_conversation_endpoint_handler(
conversation_id: str,
_auth: Any = Depends(auth_dependency),
auth: Any = Depends(auth_dependency),
) -> ConversationResponse:
"""Handle request to retrieve a conversation by ID."""
check_configuration_loaded(configuration)
Expand All @@ -128,6 +218,13 @@ async def get_conversation_endpoint_handler(
},
)

user_id, _, _ = auth

validate_conversation_ownership(
user_id=user_id,
conversation_id=conversation_id,
)

agent_id = conversation_id
logger.info("Retrieving conversation %s", conversation_id)

Expand Down Expand Up @@ -187,7 +284,7 @@ async def get_conversation_endpoint_handler(
)
async def delete_conversation_endpoint_handler(
conversation_id: str,
_auth: Any = Depends(auth_dependency),
auth: Any = Depends(auth_dependency),
) -> ConversationDeleteResponse:
"""Handle request to delete a conversation by ID."""
check_configuration_loaded(configuration)
Expand All @@ -202,6 +299,14 @@ async def delete_conversation_endpoint_handler(
"cause": f"Conversation ID {conversation_id} is not a valid UUID",
},
)

user_id, _, _ = auth

validate_conversation_ownership(
user_id=user_id,
conversation_id=conversation_id,
)

agent_id = conversation_id
logger.info("Deleting conversation %s", conversation_id)

Expand Down
Loading
Loading