Skip to content

Commit d1d48b4

Browse files
authored
Merge pull request #347 from omertuc/conv
LCORE-494: Add database persistence layer and conversation tracking system
2 parents 879d42a + 3f7ed75 commit d1d48b4

21 files changed

+984
-41
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ dependencies = [
3434
"aiohttp>=3.12.14",
3535
"authlib>=1.6.0",
3636
"openai==1.99.1",
37+
"sqlalchemy>=2.0.42",
3738
]
3839

3940
[tool.pyright]

src/app/database.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""Database engine management."""
2+
3+
from pathlib import Path
4+
from typing import Any
5+
6+
from sqlalchemy import create_engine, text
7+
from sqlalchemy.engine.base import Engine
8+
from sqlalchemy.orm import sessionmaker, Session
9+
from log import get_logger, logging
10+
from configuration import configuration
11+
from models.database.base import Base
12+
from models.config import SQLiteDatabaseConfiguration, PostgreSQLDatabaseConfiguration
13+
14+
logger = get_logger(__name__)
15+
16+
engine: Engine | None = None
17+
SessionLocal: sessionmaker | None = None
18+
19+
20+
def get_engine() -> Engine:
21+
"""Get the database engine. Raises an error if not initialized."""
22+
if engine is None:
23+
raise RuntimeError(
24+
"Database engine not initialized. Call initialize_database() first."
25+
)
26+
return engine
27+
28+
29+
def create_tables() -> None:
30+
"""Create tables."""
31+
Base.metadata.create_all(get_engine())
32+
33+
34+
def get_session() -> Session:
35+
"""Get a database session. Raises an error if not initialized."""
36+
if SessionLocal is None:
37+
raise RuntimeError(
38+
"Database session not initialized. Call initialize_database() first."
39+
)
40+
return SessionLocal()
41+
42+
43+
def _create_sqlite_engine(config: SQLiteDatabaseConfiguration, **kwargs: Any) -> Engine:
44+
"""Create SQLite database engine."""
45+
if not Path(config.db_path).parent.exists():
46+
raise FileNotFoundError(
47+
f"SQLite database directory does not exist: {config.db_path}"
48+
)
49+
50+
try:
51+
return create_engine(f"sqlite:///{config.db_path}", **kwargs)
52+
except Exception as e:
53+
logger.exception("Failed to create SQLite engine")
54+
raise RuntimeError(f"SQLite engine creation failed: {e}") from e
55+
56+
57+
def _create_postgres_engine(
58+
config: PostgreSQLDatabaseConfiguration, **kwargs: Any
59+
) -> Engine:
60+
"""Create PostgreSQL database engine."""
61+
postgres_url = (
62+
f"postgresql://{config.user}:{config.password}@"
63+
f"{config.host}:{config.port}/{config.db}"
64+
f"?sslmode={config.ssl_mode}&gssencmode={config.gss_encmode}"
65+
)
66+
67+
is_custom_schema = config.namespace is not None and config.namespace != "public"
68+
69+
connect_args = {}
70+
if is_custom_schema:
71+
connect_args["options"] = f"-csearch_path={config.namespace}"
72+
73+
if config.ca_cert_path is not None:
74+
connect_args["sslrootcert"] = str(config.ca_cert_path)
75+
76+
try:
77+
postgres_engine = create_engine(
78+
postgres_url, connect_args=connect_args, **kwargs
79+
)
80+
except Exception as e:
81+
logger.exception("Failed to create PostgreSQL engine")
82+
raise RuntimeError(f"PostgreSQL engine creation failed: {e}") from e
83+
84+
if is_custom_schema:
85+
try:
86+
with postgres_engine.connect() as connection:
87+
connection.execute(
88+
text(f'CREATE SCHEMA IF NOT EXISTS "{config.namespace}"')
89+
)
90+
connection.commit()
91+
logger.info("Schema '%s' created or already exists", config.namespace)
92+
except Exception as e:
93+
logger.exception("Failed to create schema '%s'", config.namespace)
94+
raise RuntimeError(
95+
f"Schema creation failed for '{config.namespace}': {e}"
96+
) from e
97+
98+
return postgres_engine
99+
100+
101+
def initialize_database() -> None:
102+
"""Initialize the database engine."""
103+
db_config = configuration.database_configuration
104+
105+
global engine, SessionLocal # pylint: disable=global-statement
106+
107+
# Debug print all SQL statements if our logger is at-least DEBUG level
108+
echo = bool(logger.isEnabledFor(logging.DEBUG))
109+
110+
create_engine_kwargs = {
111+
"echo": echo,
112+
}
113+
114+
match db_config.db_type:
115+
case "sqlite":
116+
sqlite_config = db_config.config
117+
assert isinstance(sqlite_config, SQLiteDatabaseConfiguration)
118+
engine = _create_sqlite_engine(sqlite_config, **create_engine_kwargs)
119+
case "postgres":
120+
postgres_config = db_config.config
121+
assert isinstance(postgres_config, PostgreSQLDatabaseConfiguration)
122+
engine = _create_postgres_engine(postgres_config, **create_engine_kwargs)
123+
124+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

src/app/endpoints/conversations.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,16 @@
99

1010
from client import AsyncLlamaStackClientHolder
1111
from configuration import configuration
12-
from models.responses import ConversationResponse, ConversationDeleteResponse
12+
from models.responses import (
13+
ConversationResponse,
14+
ConversationDeleteResponse,
15+
ConversationsListResponse,
16+
ConversationDetails,
17+
)
18+
from models.database.conversations import UserConversation
1319
from auth import get_auth_dependency
14-
from utils.endpoints import check_configuration_loaded
20+
from app.database import get_session
21+
from utils.endpoints import check_configuration_loaded, validate_conversation_ownership
1522
from utils.suid import check_suid
1623

1724
logger = logging.getLogger("app.endpoints.handlers")
@@ -66,6 +73,35 @@
6673
},
6774
}
6875

76+
conversations_list_responses: dict[int | str, dict[str, Any]] = {
77+
200: {
78+
"conversations": [
79+
{
80+
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
81+
"created_at": "2024-01-01T00:00:00Z",
82+
"last_message_at": "2024-01-01T00:05:00Z",
83+
"last_used_model": "gemini/gemini-1.5-flash",
84+
"last_used_provider": "gemini",
85+
"message_count": 5,
86+
},
87+
{
88+
"conversation_id": "456e7890-e12b-34d5-a678-901234567890",
89+
"created_at": "2024-01-01T01:00:00Z",
90+
"last_message_at": "2024-01-01T01:02:00Z",
91+
"last_used_model": "gemini/gemini-2.0-flash",
92+
"last_used_provider": "gemini",
93+
"message_count": 2,
94+
},
95+
]
96+
},
97+
503: {
98+
"detail": {
99+
"response": "Unable to connect to Llama Stack",
100+
"cause": "Connection error.",
101+
}
102+
},
103+
}
104+
69105

70106
def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
71107
"""Simplify session data to include only essential conversation information.
@@ -109,10 +145,64 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
109145
return chat_history
110146

111147

148+
@router.get("/conversations", responses=conversations_list_responses)
149+
def get_conversations_list_endpoint_handler(
150+
auth: Any = Depends(auth_dependency),
151+
) -> ConversationsListResponse:
152+
"""Handle request to retrieve all conversations for the authenticated user."""
153+
check_configuration_loaded(configuration)
154+
155+
user_id, _, _ = auth
156+
157+
logger.info("Retrieving conversations for user %s", user_id)
158+
159+
with get_session() as session:
160+
try:
161+
# Get all conversations for this user
162+
user_conversations = (
163+
session.query(UserConversation).filter_by(user_id=user_id).all()
164+
)
165+
166+
# Return conversation summaries with metadata
167+
conversations = [
168+
ConversationDetails(
169+
conversation_id=conv.id,
170+
created_at=conv.created_at.isoformat() if conv.created_at else None,
171+
last_message_at=(
172+
conv.last_message_at.isoformat()
173+
if conv.last_message_at
174+
else None
175+
),
176+
message_count=conv.message_count,
177+
last_used_model=conv.last_used_model,
178+
last_used_provider=conv.last_used_provider,
179+
)
180+
for conv in user_conversations
181+
]
182+
183+
logger.info(
184+
"Found %d conversations for user %s", len(conversations), user_id
185+
)
186+
187+
return ConversationsListResponse(conversations=conversations)
188+
189+
except Exception as e:
190+
logger.exception(
191+
"Error retrieving conversations for user %s: %s", user_id, e
192+
)
193+
raise HTTPException(
194+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
195+
detail={
196+
"response": "Unknown error",
197+
"cause": f"Unknown error while getting conversations for user {user_id}",
198+
},
199+
) from e
200+
201+
112202
@router.get("/conversations/{conversation_id}", responses=conversation_responses)
113203
async def get_conversation_endpoint_handler(
114204
conversation_id: str,
115-
_auth: Any = Depends(auth_dependency),
205+
auth: Any = Depends(auth_dependency),
116206
) -> ConversationResponse:
117207
"""Handle request to retrieve a conversation by ID."""
118208
check_configuration_loaded(configuration)
@@ -128,6 +218,13 @@ async def get_conversation_endpoint_handler(
128218
},
129219
)
130220

221+
user_id, _, _ = auth
222+
223+
validate_conversation_ownership(
224+
user_id=user_id,
225+
conversation_id=conversation_id,
226+
)
227+
131228
agent_id = conversation_id
132229
logger.info("Retrieving conversation %s", conversation_id)
133230

@@ -187,7 +284,7 @@ async def get_conversation_endpoint_handler(
187284
)
188285
async def delete_conversation_endpoint_handler(
189286
conversation_id: str,
190-
_auth: Any = Depends(auth_dependency),
287+
auth: Any = Depends(auth_dependency),
191288
) -> ConversationDeleteResponse:
192289
"""Handle request to delete a conversation by ID."""
193290
check_configuration_loaded(configuration)
@@ -202,6 +299,14 @@ async def delete_conversation_endpoint_handler(
202299
"cause": f"Conversation ID {conversation_id} is not a valid UUID",
203300
},
204301
)
302+
303+
user_id, _, _ = auth
304+
305+
validate_conversation_ownership(
306+
user_id=user_id,
307+
conversation_id=conversation_id,
308+
)
309+
205310
agent_id = conversation_id
206311
logger.info("Deleting conversation %s", conversation_id)
207312

0 commit comments

Comments
 (0)