diff --git a/README.md b/README.md index f2303b9..e5e30dd 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,59 @@ The Postgres MCP Pro Docker image will automatically remap the hostname `localho Replace `postgresql://...` with your [Postgres database connection URI](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-URIS). +##### Multiple Database Connections + +Postgres MCP Pro supports connecting to multiple databases simultaneously. This is useful when you need to work across different databases (e.g., application database, ETL database, analytics database). + +To configure multiple connections, define additional environment variables with the pattern `DATABASE_URI_<NAME>`: + +```json +{ + "mcpServers": { + "postgres": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", "DATABASE_URI_APP", + "-e", "DATABASE_URI_ETL", + "-e", "DATABASE_URI_ANALYTICS", + "-e", "DATABASE_DESC_APP", + "-e", "DATABASE_DESC_ETL", + "-e", "DATABASE_DESC_ANALYTICS", + "crystaldba/postgres-mcp", + "--access-mode=unrestricted" + ], + "env": { + "DATABASE_URI_APP": "postgresql://user:pass@localhost:5432/app_db", + "DATABASE_URI_ETL": "postgresql://user:pass@localhost:5432/etl_db", + "DATABASE_URI_ANALYTICS": "postgresql://user:pass@localhost:5432/analytics_db", + "DATABASE_DESC_APP": "Main application database with user data and transactions", + "DATABASE_DESC_ETL": "ETL staging database for data processing pipelines", + "DATABASE_DESC_ANALYTICS": "Read-only analytics database with aggregated metrics" + } + } + } +} +``` + +Each connection is identified by its name (the part after `DATABASE_URI_`, converted to lowercase): +- `DATABASE_URI_APP` → connection name: `"app"` +- `DATABASE_URI_ETL` → connection name: `"etl"` +- `DATABASE_URI_ANALYTICS` → connection name: `"analytics"` + +**Connection Descriptions**: You can optionally provide descriptions for each connection using `DATABASE_DESC_<NAME>` environment variables. 
These descriptions help the AI assistant understand which database to use for different tasks. The descriptions are: +- Automatically displayed in the server context (visible to the AI without requiring a tool call) +- Useful for guiding the AI to select the appropriate database + +When using tools, the LLM will specify which connection to use via the `conn_name` parameter: +- `list_schemas(conn_name="app")` - Lists schemas in the app database +- `explain_query(conn_name="etl", sql="SELECT ...")` - Explains query in the ETL database + +For backward compatibility, `DATABASE_URI` (without a suffix) maps to the connection name `"default"`. + + ##### Access Mode Postgres MCP Pro supports multiple *access modes* to give you control over the operations that the AI agent can perform on the database: diff --git a/src/postgres_mcp/env_utils.py b/src/postgres_mcp/env_utils.py new file mode 100644 index 0000000..88e73c3 --- /dev/null +++ b/src/postgres_mcp/env_utils.py @@ -0,0 +1,51 @@ +"""Utility functions for environment variable handling.""" + +import os + + +def discover_database_connections() -> dict[str, str]: + """ + Discover all DATABASE_URI_* environment variables. + + Returns: + Dict mapping connection names to connection URLs + - DATABASE_URI -> "default" + - DATABASE_URI_APP -> "app" + - DATABASE_URI_ETL -> "etl" + """ + discovered = {} + + for env_var, url in os.environ.items(): + if env_var == "DATABASE_URI": + discovered["default"] = url + elif env_var.startswith("DATABASE_URI_"): + # Extract postfix and lowercase it + postfix = env_var[len("DATABASE_URI_") :] + conn_name = postfix.lower() + discovered[conn_name] = url + + return discovered + + +def discover_database_descriptions() -> dict[str, str]: + """ + Discover all DATABASE_DESC_* environment variables. 
+ + Returns: + Dict mapping connection names to descriptions + - DATABASE_DESC -> "default" + - DATABASE_DESC_APP -> "app" + - DATABASE_DESC_ETL -> "etl" + """ + descriptions = {} + + for env_var, desc in os.environ.items(): + if env_var == "DATABASE_DESC": + descriptions["default"] = desc + elif env_var.startswith("DATABASE_DESC_"): + # Extract postfix and lowercase it + postfix = env_var[len("DATABASE_DESC_") :] + conn_name = postfix.lower() + descriptions[conn_name] = desc + + return descriptions diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index af5669a..27fa572 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -3,6 +3,7 @@ import asyncio import logging import os +import re import signal import sys from enum import Enum @@ -22,19 +23,51 @@ from .artifacts import ExplainPlanArtifact from .database_health import DatabaseHealthTool from .database_health import HealthType +from .env_utils import discover_database_connections +from .env_utils import discover_database_descriptions from .explain import ExplainPlanTool from .index.index_opt_base import MAX_NUM_INDEX_TUNING_QUERIES from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation -from .sql import DbConnPool +from .sql import ConnectionRegistry from .sql import SafeSqlDriver from .sql import SqlDriver from .sql import check_hypopg_installation_status from .sql import obfuscate_password from .top_queries import TopQueriesCalc -# Initialize FastMCP with default settings -mcp = FastMCP("postgres-mcp") +INSTRUCTIONS_TEMPLATE = """\ +This PostgreSQL MCP Lite server gives (un)restricted DB access via one or more connection strings. 
+ +Available database connections: +{conn_list} +""" + + +def build_instructions() -> str: + """Build server instructions including available connections.""" + # Discover connections from environment variables + conn_urls = discover_database_connections() + conn_descs = discover_database_descriptions() + + # Build connection list + if not conn_urls: + conn_list = "- No connections configured (set DATABASE_URI environment variable)" + else: + conn_items = [] + for name in sorted(conn_urls.keys()): + desc = conn_descs.get(name, "") + if desc: + conn_items.append(f"- '{name}': {desc}") + else: + conn_items.append(f"- '{name}'") + conn_list = "\n".join(conn_items) + + instructions = INSTRUCTIONS_TEMPLATE.format(conn_list=conn_list) + return instructions + + +mcp = FastMCP("postgres-mcp", instructions=build_instructions()) # Constants PG_STAT_STATEMENTS = "pg_stat_statements" @@ -53,20 +86,32 @@ class AccessMode(str, Enum): # Global variables -db_connection = DbConnPool() +connection_registry = ConnectionRegistry() current_access_mode = AccessMode.UNRESTRICTED shutdown_in_progress = False -async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: - """Get the appropriate SQL driver based on the current access mode.""" +async def get_sql_driver(conn_name: str) -> Union[SqlDriver, SafeSqlDriver]: + """ + Get the appropriate SQL driver based on the current access mode. 
+ + Args: + conn_name: Connection name (e.g., "default", "app", "etl") + + Returns: + SqlDriver or SafeSqlDriver instance + + Raises: + ValueError: If connection name doesn't exist + """ + db_connection = connection_registry.get_connection(conn_name) base_driver = SqlDriver(conn=db_connection) if current_access_mode == AccessMode.RESTRICTED: - logger.debug("Using SafeSqlDriver with restrictions (RESTRICTED mode)") + logger.debug(f"Using SafeSqlDriver with restrictions for '{conn_name}' (RESTRICTED mode)") return SafeSqlDriver(sql_driver=base_driver, timeout=30) # 30 second timeout else: - logger.debug("Using unrestricted SqlDriver (UNRESTRICTED mode)") + logger.debug(f"Using unrestricted SqlDriver for '{conn_name}' (UNRESTRICTED mode)") return base_driver @@ -81,10 +126,12 @@ def format_error_response(error: str) -> ResponseType: @mcp.tool(description="List all schemas in the database") -async def list_schemas() -> ResponseType: +async def list_schemas( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), +) -> ResponseType: """List all schemas in the database.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) rows = await sql_driver.execute_query( """ SELECT @@ -108,12 +155,13 @@ async def list_schemas() -> ResponseType: @mcp.tool(description="List objects in a schema") async def list_objects( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), schema_name: str = Field(description="Schema name"), object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), ) -> ResponseType: """List objects of a given type in a schema.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) if object_type in ("table", "view"): table_type = "BASE TABLE" if object_type == "table" else "VIEW" @@ -176,13 +224,14 @@ async def list_objects( @mcp.tool(description="Show detailed 
information about a database object") async def get_object_details( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), schema_name: str = Field(description="Schema name"), object_name: str = Field(description="Object name"), object_type: str = Field(description="Object type: 'table', 'view', 'sequence', or 'extension'", default="table"), ) -> ResponseType: """Get detailed information about a database object.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) if object_type in ("table", "view"): # Get columns @@ -309,6 +358,7 @@ async def get_object_details( @mcp.tool(description="Explains the execution plan for a SQL query, showing how the database will execute it and provides detailed cost estimates.") async def explain_query( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), sql: str = Field(description="SQL query to explain"), analyze: bool = Field( description="When True, actually runs the query to show real execution statistics instead of estimates. " @@ -333,12 +383,13 @@ async def explain_query( Explains the execution plan for a SQL query. 
Args: + conn_name: Connection name to use sql: The SQL query to explain analyze: When True, actually runs the query for real statistics hypothetical_indexes: Optional list of indexes to simulate """ try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) explain_tool = ExplainPlanTool(sql_driver=sql_driver) result: ExplainPlanArtifact | ErrorResult | None = None @@ -388,11 +439,12 @@ async def explain_query( # Query function declaration without the decorator - we'll add it dynamically based on access mode async def execute_sql( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), sql: str = Field(description="SQL to run", default="all"), ) -> ResponseType: """Executes a SQL query against the database.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) rows = await sql_driver.execute_query(sql) # type: ignore if rows is None: return format_text_response("No results") @@ -405,12 +457,13 @@ async def execute_sql( @mcp.tool(description="Analyze frequently executed queries in the database and recommend optimal indexes") @validate_call async def analyze_workload_indexes( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), max_index_size_mb: int = Field(description="Max index size in MB", default=10000), method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"), ) -> ResponseType: """Analyze frequently executed queries in the database and recommend optimal indexes.""" try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) if method == "dta": index_tuning = DatabaseTuningAdvisor(sql_driver) else: @@ -426,6 +479,7 @@ async def analyze_workload_indexes( @mcp.tool(description="Analyze a list of (up to 10) SQL queries and recommend optimal indexes") @validate_call async def analyze_query_indexes( + conn_name: str = Field(description="Connection name (e.g., 
'default', 'app', 'etl')"), queries: list[str] = Field(description="List of Query strings to analyze"), max_index_size_mb: int = Field(description="Max index size in MB", default=10000), method: Literal["dta", "llm"] = Field(description="Method to use for analysis", default="dta"), @@ -437,7 +491,7 @@ async def analyze_query_indexes( return format_error_response(f"Please provide a list of up to {MAX_NUM_INDEX_TUNING_QUERIES} queries to analyze.") try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) if method == "dta": index_tuning = DatabaseTuningAdvisor(sql_driver) else: @@ -463,6 +517,7 @@ async def analyze_query_indexes( "You can optionally specify a single health check or a comma-separated list of health checks. The default is 'all' checks." ) async def analyze_db_health( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), health_type: str = Field( description=f"Optional. Valid values are: {', '.join(sorted([t.value for t in HealthType]))}.", default="all", @@ -471,10 +526,11 @@ async def analyze_db_health( """Analyze database health for specified components. Args: + conn_name: Connection name to use health_type: Comma-separated list of health check types to perform. 
Valid values: index, connection, vacuum, sequence, replication, buffer, constraint, all """ - health_tool = DatabaseHealthTool(await get_sql_driver()) + health_tool = DatabaseHealthTool(await get_sql_driver(conn_name)) result = await health_tool.health(health_type=health_type) return format_text_response(result) @@ -484,6 +540,7 @@ async def analyze_db_health( description=f"Reports the slowest or most resource-intensive queries using data from the '{PG_STAT_STATEMENTS}' extension.", ) async def get_top_queries( + conn_name: str = Field(description="Connection name (e.g., 'default', 'app', 'etl')"), sort_by: str = Field( description="Ranking criteria: 'total_time' for total execution time or 'mean_time' for mean execution time per call, or 'resources' " "for resource-intensive queries", @@ -492,7 +549,7 @@ async def get_top_queries( limit: int = Field(description="Number of queries to return when ranking based on mean_time or total_time", default=10), ) -> ResponseType: try: - sql_driver = await get_sql_driver() + sql_driver = await get_sql_driver(conn_name) top_queries_tool = TopQueriesCalc(sql_driver=sql_driver) if sort_by == "resources": @@ -539,6 +596,12 @@ async def main(): default=8000, help="Port for SSE server (default: 8000)", ) + parser.add_argument( + "--db", + action="append", + metavar="NAME=URL", + help="Database connection (can be repeated): --db prod=postgresql://... --db staging=postgresql://...", + ) args = parser.parse_args() @@ -554,24 +617,45 @@ async def main(): logger.info(f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode") - # Get database URL from environment variable or command line - database_url = os.environ.get("DATABASE_URI", args.database_url) - - if not database_url: - raise ValueError( - "Error: No database URL provided. 
Please specify via 'DATABASE_URI' environment variable or command-line argument.", - ) + # Initialize database connection registry + # For backwards compatibility, support command-line database_url argument + if args.database_url and "DATABASE_URI" not in os.environ: + os.environ["DATABASE_URI"] = args.database_url + logger.info("Set default database connection from positional argument") + + # Process --db arguments with validation + if args.db: + for db_spec in args.db: + if "=" not in db_spec: + logger.error(f"Invalid --db format: '{db_spec}'. Expected NAME=URL") + sys.exit(1) + + name, url = db_spec.split("=", 1) + name = name.strip().upper() + + # Validate name contains only alphanumeric and underscore + if not re.match(r"^[A-Z0-9_]+$", name): + logger.error(f"Invalid connection name '{name}'. Only alphanumeric characters and underscores allowed.") + sys.exit(1) + + # Check if already set in environment (env vars take precedence) + env_var = f"DATABASE_URI_{name}" if name != "DEFAULT" else "DATABASE_URI" + if env_var in os.environ: + logger.info(f"Skipping --db {name}=... 
(already set via {env_var} environment variable)") + else: + os.environ[env_var] = url + logger.info(f"Set database connection '{name.lower()}' from command-line argument") - # Initialize database connection pool try: - await db_connection.pool_connect(database_url) - logger.info("Successfully connected to database and initialized connection pool") + await connection_registry.discover_and_connect() + conn_names = connection_registry.get_connection_names() + logger.info(f"Successfully initialized {len(conn_names)} connection(s): {', '.join(conn_names)}") except Exception as e: logger.warning( - f"Could not connect to database: {obfuscate_password(str(e))}", + f"Could not initialize database connections: {obfuscate_password(str(e))}", ) logger.warning( - "The MCP server will start but database operations will fail until a valid connection is established.", + "The MCP server will start but database operations will fail until valid connections are established.", ) # Set up proper shutdown handling @@ -609,10 +693,10 @@ async def shutdown(sig=None): if sig: logger.info(f"Received exit signal {sig.name}") - # Close database connections + # Close all database connections try: - await db_connection.close() - logger.info("Closed database connections") + await connection_registry.close_all() + logger.info("Closed all database connections") except Exception as e: logger.error(f"Error closing database connections: {e}") diff --git a/src/postgres_mcp/sql/__init__.py b/src/postgres_mcp/sql/__init__.py index 1fded3b..921d0c3 100644 --- a/src/postgres_mcp/sql/__init__.py +++ b/src/postgres_mcp/sql/__init__.py @@ -10,12 +10,14 @@ from .extension_utils import reset_postgres_version_cache from .index import IndexDefinition from .safe_sql import SafeSqlDriver +from .sql_driver import ConnectionRegistry from .sql_driver import DbConnPool from .sql_driver import SqlDriver from .sql_driver import obfuscate_password __all__ = [ "ColumnCollector", + "ConnectionRegistry", "DbConnPool", 
"IndexDefinition", "SafeSqlDriver", diff --git a/src/postgres_mcp/sql/sql_driver.py b/src/postgres_mcp/sql/sql_driver.py index 5beacb0..2d0683a 100644 --- a/src/postgres_mcp/sql/sql_driver.py +++ b/src/postgres_mcp/sql/sql_driver.py @@ -1,5 +1,6 @@ """SQL driver adapter for PostgreSQL connections.""" +import asyncio import logging import re from dataclasses import dataclass @@ -14,6 +15,9 @@ from psycopg_pool import AsyncConnectionPool from typing_extensions import LiteralString +from ..env_utils import discover_database_connections +from ..env_utils import discover_database_descriptions + logger = logging.getLogger(__name__) @@ -136,6 +140,139 @@ def last_error(self) -> Optional[str]: return self._last_error +class ConnectionRegistry: + """Registry for managing multiple database connections.""" + + def __init__(self): + self.connections: Dict[str, DbConnPool] = {} + self._connection_urls: Dict[str, str] = {} + self._connection_descriptions: Dict[str, str] = {} + + def discover_connections(self) -> Dict[str, str]: + """ + Discover all DATABASE_URI_* environment variables. + + Returns: + Dict mapping connection names to connection URLs + - DATABASE_URI -> "default" + - DATABASE_URI_APP -> "app" + - DATABASE_URI_ETL -> "etl" + """ + return discover_database_connections() + + def discover_descriptions(self) -> Dict[str, str]: + """ + Discover all DATABASE_DESC_* environment variables. + + Returns: + Dict mapping connection names to descriptions + - DATABASE_DESC -> "default" + - DATABASE_DESC_APP -> "app" + - DATABASE_DESC_ETL -> "etl" + """ + return discover_database_descriptions() + + async def discover_and_connect(self) -> None: + """ + Discover all DATABASE_URI_* environment variables and connect to them. + Connections are initialized in parallel for efficiency. + """ + discovered = self.discover_connections() + + if not discovered: + raise ValueError("No database connections found. 
Please set DATABASE_URI or DATABASE_URI_* environment variables.") + + logger.info(f"Discovered {len(discovered)} database connection(s): {', '.join(discovered.keys())}") + + # Store URLs and descriptions for reference + self._connection_urls = discovered.copy() + self._connection_descriptions = self.discover_descriptions() + + # Create connection pools + for conn_name, url in discovered.items(): + self.connections[conn_name] = DbConnPool(url) + + # Connect to all databases in parallel + async def connect_single(conn_name: str, pool: DbConnPool) -> tuple[str, bool, Optional[str]]: + """Connect to a single database and return status.""" + try: + await pool.pool_connect() + return (conn_name, True, None) + except Exception as e: + error_msg = obfuscate_password(str(e)) + logger.warning(f"Failed to connect to '{conn_name}': {error_msg}") + return (conn_name, False, error_msg) + + # Execute all connections in parallel + results = await asyncio.gather(*[connect_single(name, pool) for name, pool in self.connections.items()], return_exceptions=False) + + # Log results + for conn_name, success, error in results: + if success: + logger.info(f"Successfully connected to '{conn_name}'") + else: + logger.warning(f"Connection '{conn_name}' failed: {error}") + + def get_connection(self, conn_name: str) -> DbConnPool: + """ + Get a connection pool by name. + + Args: + conn_name: Connection name (e.g., "default", "app", "etl") + + Returns: + DbConnPool instance + + Raises: + ValueError: If connection name doesn't exist + """ + if conn_name not in self.connections: + available = ", ".join(f"'{name}'" for name in sorted(self.connections.keys())) + raise ValueError(f"Connection '{conn_name}' not found. 
Available connections: {available}") + + pool = self.connections[conn_name] + + # Check if connection is valid + if not pool.is_valid: + error_msg = pool.last_error or "Unknown error" + raise ValueError(f"Connection '{conn_name}' is not available: {obfuscate_password(error_msg)}") + + return pool + + async def close_all(self) -> None: + """Close all database connections.""" + close_tasks = [] + for conn_name, pool in self.connections.items(): + logger.info(f"Closing connection '{conn_name}'...") + close_tasks.append(pool.close()) + + # Close all connections in parallel + await asyncio.gather(*close_tasks, return_exceptions=True) + + self.connections.clear() + self._connection_urls.clear() + self._connection_descriptions.clear() + + def get_connection_names(self) -> List[str]: + """Get list of all connection names.""" + return list(self.connections.keys()) + + def get_connection_info(self) -> List[Dict[str, str]]: + """ + Get information about all configured connections. + + Returns: + List of dicts with 'name' and optional 'description' for each connection + """ + info = [] + for conn_name in sorted(self.connections.keys()): + conn_info = {"name": conn_name} + if conn_name in self._connection_descriptions: + conn_info["description"] = self._connection_descriptions[conn_name] + info.append(conn_info) + return info + + class SqlDriver: """Adapter class that wraps a PostgreSQL connection with the interface expected by DTA.""" @@ -196,30 +333,20 @@ async def execute_query( Returns: List of RowResult objects or None on error """ - try: + if self.conn is None: + self.connect() if self.conn is None: - self.connect() - if self.conn is None: - raise ValueError("Connection not established") - - # Handle connection pool vs direct connection - if self.is_pool: - # For pools, get a connection from the pool - pool = await self.conn.pool_connect() - async with pool.connection() as connection: - return await self._execute_with_connection(connection, query, params, 
force_readonly=force_readonly) - else: - # Direct connection approach - return await self._execute_with_connection(self.conn, query, params, force_readonly=force_readonly) - except Exception as e: - # Mark pool as invalid if there was a connection issue - if self.conn and self.is_pool: - self.conn._is_valid = False # type: ignore - self.conn._last_error = str(e) # type: ignore - elif self.conn and not self.is_pool: - self.conn = None - - raise e + raise ValueError("Connection not established") + + # Handle connection pool vs direct connection + if self.is_pool: + # For pools, get a connection from the pool + pool = await self.conn.pool_connect() + async with pool.connection() as connection: + return await self._execute_with_connection(connection, query, params, force_readonly=force_readonly) + else: + # Direct connection approach + return await self._execute_with_connection(self.conn, query, params, force_readonly=force_readonly) async def _execute_with_connection(self, connection, query, params, force_readonly) -> Optional[List[RowResult]]: """Execute query with the given connection.""" diff --git a/tests/unit/explain/test_server.py b/tests/unit/explain/test_server.py index 0fea0e0..5b4602f 100644 --- a/tests/unit/explain/test_server.py +++ b/tests/unit/explain/test_server.py @@ -46,7 +46,7 @@ async def test_explain_query_basic(): # Use patch to replace the actual explain_query function with our own mock with patch.object(server, "explain_query", return_value=[mock_response]): # Call the patched function - result = await server.explain_query("SELECT * FROM users") + result = await server.explain_query(conn_name="default", sql="SELECT * FROM users") # Verify we get the expected result assert isinstance(result, list) @@ -74,7 +74,7 @@ async def test_explain_query_analyze(): # Use patch to replace the actual explain_query function with our own mock with patch.object(server, "explain_query", return_value=[mock_response]): # Call the patched function with analyze=True 
- result = await server.explain_query("SELECT * FROM users", analyze=True) + result = await server.explain_query(conn_name="default", sql="SELECT * FROM users", analyze=True) # Verify we get the expected result assert isinstance(result, list) @@ -104,7 +104,7 @@ async def test_explain_query_hypothetical_indexes(): # Use patch to replace the actual explain_query function with our own mock with patch.object(server, "explain_query", return_value=[mock_response]): # Call the patched function with hypothetical_indexes - result = await server.explain_query(test_sql, hypothetical_indexes=test_indexes) + result = await server.explain_query(conn_name="default", sql=test_sql, hypothetical_indexes=test_indexes) # Verify we get the expected result assert isinstance(result, list) @@ -123,7 +123,7 @@ async def test_explain_query_error_handling(): # Use patch to replace the actual function with our mock that returns an error with patch.object(server, "explain_query", return_value=[mock_response]): # Call the patched function - result = await server.explain_query("INVALID SQL") + result = await server.explain_query(conn_name="default", sql="INVALID SQL") # Verify error is formatted correctly assert isinstance(result, list) diff --git a/tests/unit/explain/test_server_integration.py b/tests/unit/explain/test_server_integration.py index aa8d704..43c39ea 100644 --- a/tests/unit/explain/test_server_integration.py +++ b/tests/unit/explain/test_server_integration.py @@ -45,7 +45,7 @@ async def test_explain_query_integration(): with patch("postgres_mcp.server.get_sql_driver"): # Patch the ExplainPlanTool with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query("SELECT * FROM users", hypothetical_indexes=None) + result = await explain_query(conn_name="default", sql="SELECT * FROM users", hypothetical_indexes=None) # Verify result matches our expected plan data assert isinstance(result, list) @@ -67,7 +67,7 @@ async def test_explain_query_with_analyze_integration(): 
with patch("postgres_mcp.server.get_sql_driver"): # Patch the ExplainPlanTool with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query("SELECT * FROM users", analyze=True, hypothetical_indexes=None) + result = await explain_query(conn_name="default", sql="SELECT * FROM users", analyze=True, hypothetical_indexes=None) # Verify result matches our expected plan data assert isinstance(result, list) @@ -98,7 +98,7 @@ async def test_explain_query_with_hypothetical_indexes_integration(): with patch("postgres_mcp.server.get_sql_driver", return_value=mock_safe_driver): # Patch the ExplainPlanTool with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query(test_sql, hypothetical_indexes=test_indexes) + result = await explain_query(conn_name="default", sql=test_sql, hypothetical_indexes=test_indexes) # Verify result matches our expected plan data assert isinstance(result, list) @@ -129,7 +129,7 @@ async def test_explain_query_missing_hypopg_integration(): with patch("postgres_mcp.server.get_sql_driver", return_value=mock_safe_driver): # Patch the ExplainPlanTool with patch("postgres_mcp.server.ExplainPlanTool"): - result = await explain_query(test_sql, hypothetical_indexes=test_indexes) + result = await explain_query(conn_name="default", sql=test_sql, hypothetical_indexes=test_indexes) # Verify result assert isinstance(result, list) @@ -152,7 +152,7 @@ async def test_explain_query_error_handling_integration(): "postgres_mcp.server.get_sql_driver", side_effect=Exception(error_message), ): - result = await explain_query("INVALID SQL") + result = await explain_query(conn_name="default", sql="INVALID SQL") # Verify error is correctly formatted assert isinstance(result, list) diff --git a/tests/unit/sql/test_readonly_enforcement.py b/tests/unit/sql/test_readonly_enforcement.py index e079c02..4f2ba97 100644 --- a/tests/unit/sql/test_readonly_enforcement.py +++ b/tests/unit/sql/test_readonly_enforcement.py @@ -27,9 +27,9 @@ async def 
test_force_readonly_enforcement(): # Test UNRESTRICTED mode with patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), patch( - "postgres_mcp.server.db_connection", mock_conn_pool + "postgres_mcp.server.connection_registry.get_connection", return_value=mock_conn_pool ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): - driver = await get_sql_driver() + driver = await get_sql_driver(conn_name="default") assert isinstance(driver, SqlDriver) assert not isinstance(driver, SafeSqlDriver) @@ -56,9 +56,9 @@ async def test_force_readonly_enforcement(): # Test RESTRICTED mode with patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), patch( - "postgres_mcp.server.db_connection", mock_conn_pool + "postgres_mcp.server.connection_registry.get_connection", return_value=mock_conn_pool ), patch.object(SqlDriver, "_execute_with_connection", mock_execute): - driver = await get_sql_driver() + driver = await get_sql_driver(conn_name="default") assert isinstance(driver, SafeSqlDriver) # Test default behavior diff --git a/tests/unit/test_access_mode.py b/tests/unit/test_access_mode.py index f7d3b80..b772e1d 100644 --- a/tests/unit/test_access_mode.py +++ b/tests/unit/test_access_mode.py @@ -32,9 +32,9 @@ async def test_get_sql_driver_returns_correct_driver(access_mode, expected_drive """Test that get_sql_driver returns the correct driver type based on access mode.""" with ( patch("postgres_mcp.server.current_access_mode", access_mode), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), ): - driver = await get_sql_driver() + driver = await get_sql_driver(conn_name="default") assert isinstance(driver, expected_driver_type) # When in RESTRICTED mode, verify timeout is set @@ -48,9 +48,9 @@ async def test_get_sql_driver_sets_timeout_in_restricted_mode(mock_db_connection """Test that get_sql_driver sets the timeout in 
restricted mode.""" with ( patch("postgres_mcp.server.current_access_mode", AccessMode.RESTRICTED), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), ): - driver = await get_sql_driver() + driver = await get_sql_driver(conn_name="default") assert isinstance(driver, SafeSqlDriver) assert driver.timeout == 30 assert hasattr(driver, "sql_driver") @@ -61,9 +61,9 @@ async def test_get_sql_driver_in_unrestricted_mode_no_timeout(mock_db_connection """Test that get_sql_driver in unrestricted mode is a regular SqlDriver.""" with ( patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.connection_registry.get_connection", return_value=mock_db_connection), ): - driver = await get_sql_driver() + driver = await get_sql_driver(conn_name="default") assert isinstance(driver, SqlDriver) assert not hasattr(driver, "timeout") @@ -90,7 +90,7 @@ async def test_command_line_parsing(): with ( patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), - patch("postgres_mcp.server.db_connection.pool_connect", AsyncMock()), + patch("postgres_mcp.server.connection_registry.discover_and_connect", AsyncMock()), patch("postgres_mcp.server.mcp.run_stdio_async", AsyncMock()), patch("postgres_mcp.server.shutdown", AsyncMock()), ):