11"""PostgreSQL-based key-value store using asyncpg.
22
33Note: SQL queries in this module use f-strings for table names, which triggers S608 warnings.
4- This is safe because table names are validated in __init__ to be alphanumeric ( plus underscores) .
4+ This is safe because table names are validated in __init__ to be alphanumeric plus underscores.
55"""
66
77from collections .abc import AsyncIterator , Sequence
3333# PostgreSQL identifiers (such as table names) are limited to 63 characters, but collection
3434# names are passed as query values rather than identifiers, so use 200 for consistency with MongoDB
3535MAX_COLLECTION_LENGTH = 200
36+ POSTGRES_MAX_IDENTIFIER_LEN = 63
3637
3738
3839class PostgreSQLStore (BaseEnumerateCollectionsStore , BaseDestroyCollectionStore , BaseContextManagerStore , BaseStore ):
@@ -64,6 +65,7 @@ class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore,
6465 """
6566
6667 _pool : asyncpg .Pool | None # type: ignore[type-arg]
68+ _owns_pool : bool
6769 _url : str | None
6870 _host : str
6971 _port : int
@@ -143,18 +145,26 @@ def __init__(
143145 ) -> None :
144146 """Initialize the PostgreSQL store."""
145147 self ._pool = pool
148+ self ._owns_pool = pool is None # Only own the pool if we create it
146149 self ._url = url
147150 self ._host = host
148151 self ._port = port
149152 self ._database = database
150153 self ._user = user
151154 self ._password = password
152155
153- # Validate and sanitize table name to prevent SQL injection
156+ # Validate and sanitize table name to prevent SQL injection and invalid identifiers
154157 table_name = table_name or DEFAULT_TABLE
155158 if not table_name .replace ("_" , "" ).isalnum ():
156159 msg = f"Table name must be alphanumeric (with underscores): { table_name } "
157160 raise ValueError (msg )
161+ if table_name [0 ].isdigit ():
162+ msg = f"Table name must not start with a digit: { table_name } "
163+ raise ValueError (msg )
164+ # PostgreSQL identifier limit is 63 bytes; len() counts characters, so this is exact only for ASCII names
165+ if len (table_name ) > POSTGRES_MAX_IDENTIFIER_LEN :
166+ msg = f"Table name too long (>{ POSTGRES_MAX_IDENTIFIER_LEN } ): { table_name } "
167+ raise ValueError (msg )
158168 self ._table_name = table_name
159169
160170 super ().__init__ (default_collection = default_collection )
@@ -200,14 +210,15 @@ async def __aenter__(self) -> Self:
200210 user = self ._user ,
201211 password = self ._password ,
202212 )
213+ self ._owns_pool = True
203214
204215 await super ().__aenter__ ()
205216 return self
206217
207218 @override
208219 async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None : # pyright: ignore[reportAny]
209220 await super ().__aexit__ (exc_type , exc_val , exc_tb )
210- if self ._pool is not None :
221+ if self ._pool is not None and self . _owns_pool :
211222 await self ._pool .close ()
212223
213224 def _sanitize_collection_name (self , collection : str ) -> str :
@@ -218,8 +229,19 @@ def _sanitize_collection_name(self, collection: str) -> str:
218229
219230 Returns:
220231 A sanitized collection name.
232+
233+ Raises:
234+ ValueError: If the sanitized collection name is empty.
221235 """
222- return sanitize_string (value = collection , max_length = MAX_COLLECTION_LENGTH , allowed_characters = ALPHANUMERIC_CHARACTERS )
236+ sanitized = sanitize_string (
237+ value = collection ,
238+ max_length = MAX_COLLECTION_LENGTH ,
239+ allowed_characters = ALPHANUMERIC_CHARACTERS + "_" ,
240+ )
241+ if not sanitized :
242+ msg = "Collection name cannot be empty after sanitization"
243+ raise ValueError (msg )
244+ return sanitized
223245
224246 @override
225247 async def _setup_collection (self , * , collection : str ) -> None :
@@ -244,8 +266,13 @@ async def _setup_collection(self, *, collection: str) -> None:
244266 """
245267
246268 # Create index on expires_at for efficient TTL queries
269+ # Ensure index name <= 63 chars (PostgreSQL identifier limit)
270+ index_name = f"idx_{ self ._table_name } _expires_at"
271+ if len (index_name ) > POSTGRES_MAX_IDENTIFIER_LEN :
272+ import hashlib
273+ index_name = "idx_" + hashlib .sha256 (self ._table_name .encode ()).hexdigest ()[:16 ] + "_exp"
247274 create_index_sql = f""" # noqa: S608
248- CREATE INDEX IF NOT EXISTS idx_ { self . _table_name } _expires_at
275+ CREATE INDEX IF NOT EXISTS { index_name }
249276 ON { self ._table_name } (expires_at)
250277 WHERE expires_at IS NOT NULL
251278 """
@@ -374,7 +401,6 @@ async def _put_managed_entry(
374401 DO UPDATE SET
375402 value = EXCLUDED.value,
376403 ttl = EXCLUDED.ttl,
377- created_at = EXCLUDED.created_at,
378404 expires_at = EXCLUDED.expires_at
379405 """ , # noqa: S608
380406 sanitized_collection ,
@@ -411,9 +437,9 @@ async def _put_managed_entries(
411437
412438 sanitized_collection = self ._sanitize_collection_name (collection = collection )
413439
414- # Prepare data for batch insert
440+ # Prepare data for batch insert using method-level ttl/created_at/expires_at
415441 values = [
416- (sanitized_collection , key , entry .value , entry . ttl , entry . created_at , entry . expires_at )
442+ (sanitized_collection , key , entry .value , ttl , created_at , expires_at )
417443 for key , entry in zip (keys , managed_entries , strict = True )
418444 ]
419445
@@ -427,7 +453,6 @@ async def _put_managed_entries(
427453 DO UPDATE SET
428454 value = EXCLUDED.value,
429455 ttl = EXCLUDED.ttl,
430- created_at = EXCLUDED.created_at,
431456 expires_at = EXCLUDED.expires_at
432457 """ , # noqa: S608
433458 values ,
@@ -490,7 +515,9 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
490515 Returns:
491516 A list of collection names.
492517 """
493- limit = min (limit or DEFAULT_PAGE_SIZE , PAGE_LIMIT )
518+ if limit is None or limit <= 0 :
519+ limit = DEFAULT_PAGE_SIZE
520+ limit = min (limit , PAGE_LIMIT )
494521
495522 async with self ._acquire_connection () as conn :
496523 rows = await conn .fetch ( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
0 commit comments