Skip to content

Commit bac51ee

Browse files
fix: address code review feedback for PostgreSQL store
- Add pool ownership tracking to avoid closing externally-provided pools
- Enhance table name validation (63 char limit, no leading digits)
- Fix collection name sanitization to allow underscores and validate non-empty
- Add index name length handling with hash fallback for long names
- Fix upsert operations to preserve created_at timestamps on updates
- Fix bulk operations to use method-level ttl/created_at/expires_at values
- Add proper limit clamping to prevent negative values
- Improve ImportError message with explicit installation command
- Merge README changes from main branch (update doc URLs)

Co-authored-by: William Easton <[email protected]>
1 parent a10278f commit bac51ee

File tree

3 files changed

+46
-19
lines changed

3 files changed

+46
-19
lines changed

README.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ This monorepo contains two libraries:
88

99
## Documentation
1010

11-
- [Full Documentation](https://strawgate.github.io/py-key-value/)
12-
- [Getting Started Guide](https://strawgate.github.io/py-key-value/getting-started/)
13-
- [Wrappers Guide](https://strawgate.github.io/py-key-value/wrappers/)
14-
- [Adapters Guide](https://strawgate.github.io/py-key-value/adapters/)
15-
- [API Reference](https://strawgate.github.io/py-key-value/api/protocols/)
11+
- [Full Documentation](https://strawgate.com/py-key-value/)
12+
- [Getting Started Guide](https://strawgate.com/py-key-value/getting-started/)
13+
- [Wrappers Guide](https://strawgate.com/py-key-value/wrappers/)
14+
- [Adapters Guide](https://strawgate.com/py-key-value/adapters/)
15+
- [API Reference](https://strawgate.com/py-key-value/api/protocols/)
1616

1717
## Why use this library?
1818

19-
- **Multiple backends**: DynamoDB, Elasticsearch, Memcached, MongoDB,
20-
PostgreSQL, Redis, RocksDB, Valkey, and In-memory, Disk, etc
19+
- **Multiple backends**: DynamoDB, Elasticsearch, Memcached, MongoDB, PostgreSQL,
20+
Redis, RocksDB, Valkey, and In-memory, Disk, etc
2121
- **TTL support**: Automatic expiration handling across all store types
2222
- **Type-safe**: Full type hints with Protocol-based interfaces
2323
- **Adapters**: Pydantic model support, raise-on-missing behavior, etc
@@ -131,7 +131,7 @@ pip install py-key-value-aio[memory]
131131
pip install py-key-value-aio[disk]
132132
pip install py-key-value-aio[dynamodb]
133133
pip install py-key-value-aio[elasticsearch]
134-
# or: redis, mongodb, postgresql, memcached, valkey, vault, registry, rocksdb, see below for all options
134+
# or: redis, mongodb, postgresql, memcached, valkey, vault, rocksdb, see below for all options
135135
```
136136

137137
```python

key-value/key-value-aio/src/key_value/aio/stores/postgresql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
try:
44
from key_value.aio.stores.postgresql.store import PostgreSQLStore
55
except ImportError as e:
6-
msg = "PostgreSQLStore requires py-key-value-aio[postgresql]"
6+
msg = 'PostgreSQLStore requires the "postgresql" extra. Install via: pip install "py-key-value-aio[postgresql]"'
77
raise ImportError(msg) from e
88

99
__all__ = ["PostgreSQLStore"]

key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""PostgreSQL-based key-value store using asyncpg.
22
33
Note: SQL queries in this module use f-strings for table names, which triggers S608 warnings.
4-
This is safe because table names are validated in __init__ to be alphanumeric (plus underscores).
4+
This is safe because table names are validated in __init__ to be alphanumeric plus underscores.
55
"""
66

77
from collections.abc import AsyncIterator, Sequence
@@ -33,6 +33,7 @@
3333
# PostgreSQL table name length limit is 63 characters
3434
# Use 200 for consistency with MongoDB
3535
MAX_COLLECTION_LENGTH = 200
36+
POSTGRES_MAX_IDENTIFIER_LEN = 63
3637

3738

3839
class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore):
@@ -64,6 +65,7 @@ class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore,
6465
"""
6566

6667
_pool: asyncpg.Pool | None # type: ignore[type-arg]
68+
_owns_pool: bool
6769
_url: str | None
6870
_host: str
6971
_port: int
@@ -143,18 +145,26 @@ def __init__(
143145
) -> None:
144146
"""Initialize the PostgreSQL store."""
145147
self._pool = pool
148+
self._owns_pool = pool is None # Only own the pool if we create it
146149
self._url = url
147150
self._host = host
148151
self._port = port
149152
self._database = database
150153
self._user = user
151154
self._password = password
152155

153-
# Validate and sanitize table name to prevent SQL injection
156+
# Validate and sanitize table name to prevent SQL injection and invalid identifiers
154157
table_name = table_name or DEFAULT_TABLE
155158
if not table_name.replace("_", "").isalnum():
156159
msg = f"Table name must be alphanumeric (with underscores): {table_name}"
157160
raise ValueError(msg)
161+
if table_name[0].isdigit():
162+
msg = f"Table name must not start with a digit: {table_name}"
163+
raise ValueError(msg)
164+
# PostgreSQL identifier limit is 63 bytes
165+
if len(table_name) > POSTGRES_MAX_IDENTIFIER_LEN:
166+
msg = f"Table name too long (>{POSTGRES_MAX_IDENTIFIER_LEN}): {table_name}"
167+
raise ValueError(msg)
158168
self._table_name = table_name
159169

160170
super().__init__(default_collection=default_collection)
@@ -200,14 +210,15 @@ async def __aenter__(self) -> Self:
200210
user=self._user,
201211
password=self._password,
202212
)
213+
self._owns_pool = True
203214

204215
await super().__aenter__()
205216
return self
206217

207218
@override
208219
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # pyright: ignore[reportAny]
209220
await super().__aexit__(exc_type, exc_val, exc_tb)
210-
if self._pool is not None:
221+
if self._pool is not None and self._owns_pool:
211222
await self._pool.close()
212223

213224
def _sanitize_collection_name(self, collection: str) -> str:
@@ -218,8 +229,19 @@ def _sanitize_collection_name(self, collection: str) -> str:
218229
219230
Returns:
220231
A sanitized collection name.
232+
233+
Raises:
234+
ValueError: If the sanitized collection name is empty.
221235
"""
222-
return sanitize_string(value=collection, max_length=MAX_COLLECTION_LENGTH, allowed_characters=ALPHANUMERIC_CHARACTERS)
236+
sanitized = sanitize_string(
237+
value=collection,
238+
max_length=MAX_COLLECTION_LENGTH,
239+
allowed_characters=ALPHANUMERIC_CHARACTERS + "_",
240+
)
241+
if not sanitized:
242+
msg = "Collection name cannot be empty after sanitization"
243+
raise ValueError(msg)
244+
return sanitized
223245

224246
@override
225247
async def _setup_collection(self, *, collection: str) -> None:
@@ -244,8 +266,13 @@ async def _setup_collection(self, *, collection: str) -> None:
244266
"""
245267

246268
# Create index on expires_at for efficient TTL queries
269+
# Ensure index name <= 63 chars (PostgreSQL identifier limit)
270+
index_name = f"idx_{self._table_name}_expires_at"
271+
if len(index_name) > POSTGRES_MAX_IDENTIFIER_LEN:
272+
import hashlib
273+
index_name = "idx_" + hashlib.sha256(self._table_name.encode()).hexdigest()[:16] + "_exp"
247274
create_index_sql = f""" # noqa: S608
248-
CREATE INDEX IF NOT EXISTS idx_{self._table_name}_expires_at
275+
CREATE INDEX IF NOT EXISTS {index_name}
249276
ON {self._table_name}(expires_at)
250277
WHERE expires_at IS NOT NULL
251278
"""
@@ -374,7 +401,6 @@ async def _put_managed_entry(
374401
DO UPDATE SET
375402
value = EXCLUDED.value,
376403
ttl = EXCLUDED.ttl,
377-
created_at = EXCLUDED.created_at,
378404
expires_at = EXCLUDED.expires_at
379405
""", # noqa: S608
380406
sanitized_collection,
@@ -411,9 +437,9 @@ async def _put_managed_entries(
411437

412438
sanitized_collection = self._sanitize_collection_name(collection=collection)
413439

414-
# Prepare data for batch insert
440+
# Prepare data for batch insert using method-level ttl/created_at/expires_at
415441
values = [
416-
(sanitized_collection, key, entry.value, entry.ttl, entry.created_at, entry.expires_at)
442+
(sanitized_collection, key, entry.value, ttl, created_at, expires_at)
417443
for key, entry in zip(keys, managed_entries, strict=True)
418444
]
419445

@@ -427,7 +453,6 @@ async def _put_managed_entries(
427453
DO UPDATE SET
428454
value = EXCLUDED.value,
429455
ttl = EXCLUDED.ttl,
430-
created_at = EXCLUDED.created_at,
431456
expires_at = EXCLUDED.expires_at
432457
""", # noqa: S608
433458
values,
@@ -490,7 +515,9 @@ async def _get_collection_names(self, *, limit: int | None = None) -> list[str]:
490515
Returns:
491516
A list of collection names.
492517
"""
493-
limit = min(limit or DEFAULT_PAGE_SIZE, PAGE_LIMIT)
518+
if limit is None or limit <= 0:
519+
limit = DEFAULT_PAGE_SIZE
520+
limit = min(limit, PAGE_LIMIT)
494521

495522
async with self._acquire_connection() as conn:
496523
rows = await conn.fetch( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]

0 commit comments

Comments
 (0)