diff --git a/MANIFEST.in b/MANIFEST.in index c9715ab..fe7adb1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,3 @@ -include src/web/index.html \ No newline at end of file +include src/web/index.html +include src/promptlab/alembic.ini +recursive-include src/promptlab/migrations * \ No newline at end of file diff --git a/docs/database_management.md b/docs/database_management.md new file mode 100644 index 0000000..2aaf1de --- /dev/null +++ b/docs/database_management.md @@ -0,0 +1,163 @@ +# Database Management in PromptLab + +This document explains the centralized database management system implemented in PromptLab to solve the multiple initialization issue. + +## Problem Solved + +Previously, PromptLab had multiple `init_engine` functions that could initialize the database concurrently, leading to: +- Race conditions during startup +- Redundant database initialization +- Potentially inconsistent database state +- Performance issues + +## Solution Overview + +### Centralized Database Manager + +The new system implements a singleton `DatabaseManager` class that ensures: +- **Single Point of Initialization**: Only one place handles database setup +- **Thread Safety**: Uses locking mechanisms to prevent race conditions +- **One-Time Operation**: Database is initialized only once, regardless of how many times it's requested +- **Migration Support**: Integrated Alembic support for schema migrations + +### Key Components + +1. **DatabaseManager** (`src/promptlab/sqlite/db_manager.py`) + - Singleton pattern ensures single instance + - Thread-safe initialization with double-checked locking + - Automatic Alembic migration support + - Logging for debugging and monitoring + +2. **Enhanced Session Management** (`src/promptlab/sqlite/session.py`) + - Thread-safe session initialization + - Utility functions for checking initialization state + - Reset functionality for testing + +3. 
**CLI Commands** (`src/promptlab/_cli.py`) + - `promptlab db init`: Initialize database + - `promptlab db migrate`: Run migrations + - `promptlab db revision`: Create new migration + +## Usage + +### Starting the Studio + +The studio startup remains the same: +```bash +promptlab studio start -d /path/to/database.db -p 8000 +``` + +The database will be automatically initialized on first startup. + +### Manual Database Operations + +Initialize a database: +```bash +promptlab db init -d /path/to/database.db +``` + +Run migrations: +```bash +promptlab db migrate -d /path/to/database.db +``` + +Create a new migration: +```bash +promptlab db revision -d /path/to/database.db -m "Add new table" +``` + +### Programmatic Usage + +```python +from promptlab.sqlite.db_manager import db_manager +from promptlab.tracer.local_tracer import LocalTracer + +# The database will be automatically initialized +tracer = LocalTracer({"type": "local", "db_file": "/path/to/db.sqlite"}) + +# Or initialize manually +db_manager.initialize_database("/path/to/db.sqlite") +``` + +## Migration System + +### Alembic Integration + +The system now includes full Alembic support for database schema migrations: + +- **Automatic Migration Detection**: On startup, the system checks for pending migrations +- **Safe Migration Execution**: Migrations are applied automatically and safely +- **Version Tracking**: Database schema version is tracked in the `alembic_version` table + +### Creating Migrations + +1. Make changes to your SQLAlchemy models in `src/promptlab/sqlite/models.py` +2. Generate a migration: + ```bash + promptlab db revision -d /path/to/database.db -m "Description of changes" + ``` +3. Review the generated migration file in `src/promptlab/migrations/versions/` +4. 
Apply migrations: + ```bash + promptlab db migrate -d /path/to/database.db + ``` + +## File Structure + +``` +promptlab/ +├── src/promptlab/alembic.ini # Alembic configuration +├── src/promptlab/migrations/ # Migration files +│ ├── env.py # Alembic environment +│ ├── script.py.mako # Migration template +│ └── versions/ # Version files +├── src/promptlab/sqlite/ +│ ├── db_manager.py # Centralized database manager +│ ├── session.py # Enhanced session management +│ └── models.py # SQLAlchemy models +└── tests/unit/ + └── test_database_initialization.py # Tests for new system +``` + +## Thread Safety + +The new system implements several thread safety mechanisms: + +1. **Double-Checked Locking**: Prevents race conditions during initialization +2. **Global State Protection**: Uses threading locks to protect shared state +3. **Singleton Pattern**: A single module-level `db_manager` instance is shared by all callers + +## Testing + +Run the database initialization tests: +```bash +python -m pytest tests/unit/test_database_initialization.py +``` + +The tests verify: +- Single initialization across multiple calls +- Thread safety with concurrent access +- Proper LocalTracer integration +- Database file creation + +## Benefits + +1. **Reliability**: Eliminates race conditions and initialization conflicts +2. **Performance**: Avoids redundant database operations +3. **Maintainability**: Centralized database logic is easier to maintain +4. **Scalability**: Proper migration system supports schema evolution +5. 
**Debugging**: Comprehensive logging helps troubleshoot issues + +## Backwards Compatibility + +The changes are backwards compatible: +- Existing `LocalTracer` usage remains unchanged +- CLI commands work as before +- No breaking changes to public APIs + +## Future Enhancements + +- Database connection pooling for improved performance +- Support for multiple database backends +- Enhanced migration rollback capabilities +- Database health monitoring and metrics diff --git a/pyproject.toml b/pyproject.toml index a9d4c6e..03c9c47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "fastapi", "uvicorn[standard]>=0.18.0", "sqlalchemy", + "alembic>=1.8.0", "passlib", "python-jose", "python-multipart", @@ -46,7 +47,7 @@ Homepage = "https://github.com/imum-ai/promptlab" Issues = "https://github.com/imum-ai/promptlab/issues" [tool.setuptools.package-data] -promptlab = ["web/*.html"] +promptlab = ["web/*.html", "alembic.ini", "migrations/*", "migrations/**/*"] [tool.setuptools.packages.find] where = ["src"] diff --git a/run_tests.sh b/run_tests.sh deleted file mode 100755 index bb1d0f3..0000000 --- a/run_tests.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -# Create a virtual environment -python3 -m venv venv - -# Activate the virtual environment -source venv/bin/activate - -# Install dependencies -pip install -e . -pip install pytest pytest-asyncio pytest-cov - -# Run unit tests -echo "Running unit tests..." -pytest tests/unit -v --cov=src/promptlab - -# Run integration tests -echo "Running integration tests..." -pytest tests/integration -v - -# Run performance tests -echo "Running performance tests..." 
-pytest tests/performance -v - - diff --git a/src/promptlab/_cli.py b/src/promptlab/_cli.py index e0ab2a0..4d4b69b 100644 --- a/src/promptlab/_cli.py +++ b/src/promptlab/_cli.py @@ -31,6 +31,5 @@ def start(db, port): click.echo(f"Running on port: {port}") - if __name__ == "__main__": promptlab() diff --git a/src/promptlab/alembic.ini b/src/promptlab/alembic.ini new file mode 100644 index 0000000..e867d79 --- /dev/null +++ b/src/promptlab/alembic.ini @@ -0,0 +1,71 @@ +# Alembic configuration for PromptLab database migrations + +[alembic] +# Path to migration scripts +script_location = migrations + +# Template used to generate migration file names +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# Max length of characters to apply to the "slug" field +truncate_slug_length = 40 + +# Set to 'true' to run the environment during the 'revision' command +revision_environment = false + +# Set to 'true' to allow .pyc and .pyo files without a source .py file to be detected +# as revisions in the versions/ directory +sourceless = false + +# Version table name +version_table = alembic_version + +# Version path separator (default: os.pathsep) +version_path_separator = : + +# Set to 'true' to search source files recursively in the versions/ directory +recursive_version_locations = false + +# The output encoding used when revision files are written from script.py.mako +output_encoding = utf-8 + +# Database URL placeholder - will be set programmatically +sqlalchemy.url = + +[post_write_hooks] +# Post-write hooks define scripts or Python functions that are run +# on newly-generated revision scripts + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class 
"""Alembic environment configuration for PromptLab migrations."""

from logging.config import fileConfig
from pathlib import Path
import sys

from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context

# Add the src directory to the path so we can import our models
project_root = Path(__file__).parent.parent.parent.parent
src_path = project_root / "src"
sys.path.insert(0, str(src_path))

# Import your models here (after path modification)
from promptlab.sqlite.models import Base  # noqa: E402

# This is the Alembic Config object
config = context.config

# Interpret the config file for Python logging.
# disable_existing_loggers=False: this script also runs in-process (via
# alembic.command), and fileConfig's default would disable every logger the
# application created before the migration started.
if config.config_file_name is not None:
    fileConfig(config.config_file_name, disable_existing_loggers=False)

# Set the target metadata for 'autogenerate' support
target_metadata = Base.metadata


def get_url():
    """Return the database URL to migrate.

    Returns:
        The ``sqlalchemy.url`` main option from the Alembic config when it is
        non-empty, otherwise the fallback ``sqlite:///promptlab.db``.
    """
    url = config.get_main_option("sqlalchemy.url")
    if url:
        return url

    # Fallback to a default SQLite URL for offline mode
    return "sqlite:///promptlab.db"


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The context is configured with only a URL (no Engine), so
    ``context.execute()`` calls emit SQL to the script output.
    """
    context.configure(
        url=get_url(),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Builds an Engine from the config section, with the URL taken from
    ``get_url()``, and runs the migrations over a live connection.
    """
    # Default to {} so a missing ini section cannot yield None, which would
    # make the item assignment below raise TypeError.
    configuration = config.get_section(config.config_ini_section, {})
    configuration["sqlalchemy.url"] = get_url()

    connectable = engine_from_config(
        configuration,
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
+revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/src/promptlab/migrations/versions/2025_08_21_1302-9f33dbbebd6c_initial_migration.py b/src/promptlab/migrations/versions/2025_08_21_1302-9f33dbbebd6c_initial_migration.py new file mode 100644 index 0000000..8f89205 --- /dev/null +++ b/src/promptlab/migrations/versions/2025_08_21_1302-9f33dbbebd6c_initial_migration.py @@ -0,0 +1,79 @@ +"""Mako template for Alembic migration scripts.""" + +"""Initial migration + +Revision ID: 9f33dbbebd6c +Revises: +Create Date: 2025-08-21 13:02:37.445724 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9f33dbbebd6c' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('users', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('username', sa.String(), nullable=False), + sa.Column('password_hash', sa.String(), nullable=False), + sa.Column('role', sa.String(), nullable=False), + sa.Column('status', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('username') + ) + op.create_table('assets', + sa.Column('asset_name', sa.String(), nullable=False), + sa.Column('asset_version', sa.Integer(), nullable=False), + sa.Column('asset_description', sa.Text(), nullable=True), + sa.Column('asset_type', sa.String(), nullable=True), + sa.Column('asset_binary', sa.Text(), nullable=True), + sa.Column('is_deployed', sa.Boolean(), nullable=True), + sa.Column('deployment_time', sa.DateTime(), nullable=True), + sa.Column('status', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('asset_name', 'asset_version') + ) + op.create_table('experiments', + sa.Column('experiment_id', sa.String(), nullable=False), + sa.Column('model', sa.Text(), nullable=True), + sa.Column('asset', sa.Text(), nullable=True), + sa.Column('status', sa.Integer(), nullable=True), + sa.Column('created_at', sa.DateTime(), nullable=True), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('experiment_id') + ) + op.create_table('experiment_result', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('experiment_id', sa.String(), nullable=True), + sa.Column('dataset_record_id', sa.String(), nullable=True), + sa.Column('completion', sa.Text(), nullable=True), + sa.Column('prompt_tokens', sa.Integer(), nullable=True), + sa.Column('completion_tokens', sa.Integer(), 
import logging
import threading
from pathlib import Path

logger = logging.getLogger(__name__)


class DatabaseManager:
    """Hold the SQLAlchemy engine and session factory for the SQLite database.

    ``initialize_database`` creates the engine, applies Alembic migrations and
    seeds the default admin user. Once it has succeeded, later calls return
    immediately, whatever ``db_file`` they pass.
    """

    def __init__(self):
        self._engine = None
        # sessionmaker factory; stays None until _init_engine has run.
        self._session = None
        self._db_initialized = False
        # Re-entrant because initialize_database holds the lock while it
        # calls _init_engine, which takes it again.
        self._init_lock = threading.RLock()

    def initialize_database(self, db_file: str) -> None:
        """Initialize the database, run migrations, create the default admin user.

        Args:
            db_file: Path of the SQLite database file. Its parent directory is
                created if missing.

        Raises:
            Exception: Whatever engine creation or admin-user creation raised.
                The manager is then left uninitialized so the call can be
                retried.
        """
        if self._db_initialized:
            logger.info("Database already initialized, skipping initialization")
            return

        with self._init_lock:
            # Re-check under the lock: another thread may have finished the
            # whole initialization while this one was waiting.
            if self._db_initialized:
                logger.info("Database already initialized, skipping initialization")
                return

            db_path = Path(db_file)
            db_path.parent.mkdir(parents=True, exist_ok=True)
            db_url = f"sqlite:///{db_file}"
            logger.info("Initializing database at: %s", db_file)

            try:
                self._init_engine(db_url)
                self._run_migrations(db_url)
                self._create_default_admin_user()
            except Exception as e:
                logger.error("Failed to initialize database: %s", e)
                # Drop the half-built engine so a later call starts clean.
                if self._engine is not None:
                    self._engine.dispose()
                self._engine = None
                self._session = None
                raise

            # Set only after every step succeeded, so concurrent callers never
            # skip initialization while the schema is still being created.
            self._db_initialized = True
            logger.info("Database initialized successfully")

    def _init_engine(self, db_url):
        """Create the engine and session factory once; later calls are no-ops."""
        with self._init_lock:
            if self._engine is not None:
                return
            from sqlalchemy import create_engine
            from sqlalchemy.orm import sessionmaker

            self._engine = create_engine(
                db_url, connect_args={"check_same_thread": False}
            )
            self._session = sessionmaker(
                autocommit=False, autoflush=False, bind=self._engine
            )

    def get_session(self):
        """Return a new SQLAlchemy session.

        Raises:
            RuntimeError: If the session factory has not been created yet.
        """
        if self._session is None:
            raise RuntimeError(
                "Session not initialized. Call initialize_database first."
            )
        return self._session()

    def _create_default_admin_user(self):
        """Insert the ``admin`` user if no user with that username exists."""
        # Acquire the session before the try block: if this raises, there is
        # nothing to close and the RuntimeError reaches the caller unmasked.
        session = self.get_session()
        try:
            from passlib.context import CryptContext

            from .models import User

            if not session.query(User).filter_by(username="admin").first():
                pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
                # NOTE(review): ships a well-known default password ("admin");
                # kept for compatibility, flagged for follow-up.
                admin_user = User(
                    username="admin",
                    password_hash=pwd_context.hash("admin"),
                    role="admin",
                )
                session.add(admin_user)
                session.commit()
        finally:
            session.close()

    def _run_migrations(self, db_url: str) -> None:
        """Upgrade the schema at ``db_url`` to Alembic ``head``.

        Does nothing except log a warning when ``alembic.ini`` is not found
        next to the package or Alembic cannot be imported. Other errors are
        logged and not re-raised.
        """
        try:
            package_root = Path(__file__).parent.parent
            alembic_cfg_path = package_root / "alembic.ini"
            if not alembic_cfg_path.exists():
                logger.warning("Alembic configuration not found, skipping migrations")
                return

            # Imported here so the ImportError handler below is reachable.
            from alembic import command
            from alembic.config import Config

            alembic_cfg = Config(str(alembic_cfg_path))
            alembic_cfg.set_main_option("sqlalchemy.url", db_url)
            script_location = alembic_cfg.get_main_option("script_location")

            # A relative script_location is resolved against the ini file's
            # directory instead of the process working directory.
            if script_location and not Path(script_location).is_absolute():
                abs_script_location = str(
                    (alembic_cfg_path.parent / script_location).resolve()
                )
                alembic_cfg.set_main_option("script_location", abs_script_location)

            logger.info("Running Alembic migrations to update database schema...")
            command.upgrade(alembic_cfg, "head")
            logger.info("Database migrations applied successfully.")
        except ImportError:
            logger.warning("Alembic not installed, skipping migrations")
        except Exception as e:
            # NOTE(review): swallowed as in the original; presumably so
            # databases created before Alembic was introduced still open.
            # Confirm this is intended.
            logger.error("Error running migrations: %s", e)


db_manager = DatabaseManager()
CryptContext(schemes=["bcrypt"], deprecated="auto") - session = _SessionLocal() - try: - if not session.query(User).filter_by(username="admin").first(): - admin_user = User( - username="admin", password_hash=pwd_context.hash("admin"), role="admin" - ) - session.add(admin_user) - session.commit() - finally: - session.close() - - -def init_engine(db_url): - global _engine, _SessionLocal - _engine = create_engine(db_url, connect_args={"check_same_thread": False}) - _SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine) - Base.metadata.create_all(bind=_engine) - - # Insert default admin user if not exists - _create_default_admin_user() - - -def get_session(): - if _SessionLocal is None: - raise RuntimeError("Session not initialized. Call init_engine first.") - return _SessionLocal() diff --git a/src/promptlab/tracer/local_tracer.py b/src/promptlab/tracer/local_tracer.py index 79d0b05..77cab8b 100644 --- a/src/promptlab/tracer/local_tracer.py +++ b/src/promptlab/tracer/local_tracer.py @@ -6,7 +6,8 @@ from sqlalchemy.orm import joinedload from promptlab.types import ExperimentConfig, TracerConfig, Dataset, PromptTemplate -from promptlab.sqlite.session import get_session, init_engine +# from promptlab.sqlite.session import get_session +from promptlab.sqlite.db_manager import db_manager from promptlab.enums import AssetType from promptlab.sqlite.sql import SQLQuery from promptlab.tracer.tracer import Tracer @@ -20,11 +21,10 @@ class LocalTracer(Tracer): def __init__(self, tracer_config: TracerConfig): - db_url = f"sqlite:///{tracer_config.db_file}" - init_engine(db_url) + db_manager.initialize_database(tracer_config.db_file) def _create_asset(self, asset: ORMAsset): - session = get_session() + session = db_manager.get_session() try: session.add(asset) session.commit() @@ -95,7 +95,7 @@ def create_prompttemplate(self, template: PromptTemplate): def trace_experiment( self, experiment_config: ExperimentConfig, experiment_summary: List[Dict] ) -> None: - 
session = get_session() + session = db_manager.get_session() try: experiment_id = experiment_summary[0]["experiment_id"] @@ -144,7 +144,7 @@ def trace_experiment( session.close() def get_asset(self, asset_name: str, asset_version: int) -> ORMAsset: - session = get_session() + session = db_manager.get_session() try: asset = ( session.query(ORMAsset) @@ -163,7 +163,7 @@ def get_asset(self, asset_name: str, asset_version: int) -> ORMAsset: session.close() def get_assets_by_type(self, asset_type: str) -> List[Any]: - session = get_session() + session = db_manager.get_session() try: if asset_type not in AssetType._value2member_map_: raise ValueError(f"Invalid asset type: {asset_type}") @@ -181,7 +181,7 @@ def get_assets_by_type(self, asset_type: str) -> List[Any]: session.close() def get_latest_asset(self, asset_name: str) -> ORMAsset: - session = get_session() + session = db_manager.get_session() try: asset = ( session.query(ORMAsset) @@ -197,7 +197,7 @@ def get_latest_asset(self, asset_name: str) -> ORMAsset: session.close() def get_user_by_username(self, username: str) -> User: - session = get_session() + session = db_manager.get_session() try: user = session.query(User).filter_by(username=username).first() if not user: @@ -210,7 +210,7 @@ def get_user_by_username(self, username: str) -> User: session.close() def get_experiments(self): - session = get_session() + session = db_manager.get_session() try: return ( session.execute(text(SQLQuery.SELECT_EXPERIMENTS_QUERY)) @@ -224,7 +224,7 @@ def get_experiments(self): session.close() def get_users(self): - session = get_session() + session = db_manager.get_session() try: return session.query(User).filter(User.status == 1).all() except Exception: @@ -234,7 +234,7 @@ def get_users(self): session.close() def create_user(self, user: User): - session = get_session() + session = db_manager.get_session() try: session.add(user) session.commit() @@ -245,7 +245,7 @@ def create_user(self, user: User): session.close() def 
deactivate_user_by_username(self, username: str): - session = get_session() + session = db_manager.get_session() try: user = session.query(User).filter_by(username=username).first() if not user: @@ -262,7 +262,7 @@ def me(self) -> User: _current_username = ( "admin" # This should be replaced with the actual current user logic ) - session = get_session() + session = db_manager.get_session() try: user = ( session.query(User).filter_by(username=_current_username).first()