Skip to content

Commit 38a059d

Browse files
refactor: optimize PostgreSQL store to use one-time table setup
- Move table/index creation from _setup_collection to _setup (called once)
- Remove collection sanitization since collection names are column values
- Remove PostgreSQLV1CollectionSanitizationStrategy class and exports
- Update tests to verify collection names work without restrictions
- Update docstrings to clarify collections are stored as values

This simplifies the implementation since all collections share a single table, eliminating unnecessary per-collection setup overhead.

Co-authored-by: William Easton <[email protected]>
1 parent 6211f03 commit 38a059d

File tree

3 files changed

+35
-120
lines changed

3 files changed

+35
-120
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""PostgreSQL store for py-key-value-aio."""
22

33
try:
4-
from key_value.aio.stores.postgresql.store import PostgreSQLStore, PostgreSQLV1CollectionSanitizationStrategy
4+
from key_value.aio.stores.postgresql.store import PostgreSQLStore
55
except ImportError as e:
66
msg = 'PostgreSQLStore requires the "postgresql" extra. Install via: pip install "py-key-value-aio[postgresql]"'
77
raise ImportError(msg) from e
88

9-
__all__ = ["PostgreSQLStore", "PostgreSQLV1CollectionSanitizationStrategy"]
9+
__all__ = ["PostgreSQLStore"]

key-value/key-value-aio/src/key_value/aio/stores/postgresql/store.py

Lines changed: 17 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from typing import Any, overload
1313

1414
from key_value.shared.utils.managed_entry import ManagedEntry
15-
from key_value.shared.utils.sanitization import HybridSanitizationStrategy, SanitizationStrategy
16-
from key_value.shared.utils.sanitize import ALPHANUMERIC_CHARACTERS
1715
from typing_extensions import Self, override
1816

1917
from key_value.aio.stores.base import BaseContextManagerStore, BaseDestroyCollectionStore, BaseEnumerateCollectionsStore, BaseStore
@@ -34,31 +32,15 @@
3432
PAGE_LIMIT = 10000
3533

3634
# PostgreSQL table name length limit is 63 characters
37-
# Use 200 for consistency with MongoDB
38-
MAX_COLLECTION_LENGTH = 200
3935
POSTGRES_MAX_IDENTIFIER_LEN = 63
40-
COLLECTION_ALLOWED_CHARACTERS = ALPHANUMERIC_CHARACTERS + "_"
41-
42-
43-
class PostgreSQLV1CollectionSanitizationStrategy(HybridSanitizationStrategy):
44-
def __init__(self) -> None:
45-
super().__init__(
46-
replacement_character="_",
47-
max_length=MAX_COLLECTION_LENGTH,
48-
allowed_characters=COLLECTION_ALLOWED_CHARACTERS,
49-
)
5036

5137

5238
class PostgreSQLStore(BaseEnumerateCollectionsStore, BaseDestroyCollectionStore, BaseContextManagerStore, BaseStore):
5339
"""PostgreSQL-based key-value store using asyncpg.
5440
55-
This store uses a single table with columns for collection, key, value (JSONB), and metadata.
56-
Collections are stored as a column value rather than separate tables.
57-
58-
By default, collections are not sanitized. This means that there are character and length restrictions on
59-
collection names that may cause errors when trying to get and put entries.
60-
61-
To avoid issues, you may want to consider leveraging the `PostgreSQLV1CollectionSanitizationStrategy` strategy.
41+
This store uses a single shared table with columns for collection, key, value (JSONB), and metadata.
42+
Collections are stored as values in the collection column, not as separate tables or SQL identifiers,
43+
so there are no character restrictions on collection names.
6244
6345
Example:
6446
Basic usage with default connection:
@@ -99,15 +81,13 @@ def __init__(
9981
pool: asyncpg.Pool, # type: ignore[type-arg]
10082
table_name: str | None = None,
10183
default_collection: str | None = None,
102-
collection_sanitization_strategy: SanitizationStrategy | None = None,
10384
) -> None:
10485
"""Initialize the PostgreSQL store with an existing connection pool.
10586
10687
Args:
10788
pool: An existing asyncpg connection pool to use.
10889
table_name: The name of the table to use for storage (default: kv_store).
10990
default_collection: The default collection to use if no collection is provided.
110-
collection_sanitization_strategy: The sanitization strategy to use for collections.
11191
"""
11292

11393
@overload
@@ -117,15 +97,13 @@ def __init__(
11797
url: str,
11898
table_name: str | None = None,
11999
default_collection: str | None = None,
120-
collection_sanitization_strategy: SanitizationStrategy | None = None,
121100
) -> None:
122101
"""Initialize the PostgreSQL store with a connection URL.
123102
124103
Args:
125104
url: PostgreSQL connection URL (e.g., postgresql://user:pass@localhost/dbname).
126105
table_name: The name of the table to use for storage (default: kv_store).
127106
default_collection: The default collection to use if no collection is provided.
128-
collection_sanitization_strategy: The sanitization strategy to use for collections.
129107
"""
130108

131109
@overload
@@ -139,7 +117,6 @@ def __init__(
139117
password: str | None = None,
140118
table_name: str | None = None,
141119
default_collection: str | None = None,
142-
collection_sanitization_strategy: SanitizationStrategy | None = None,
143120
) -> None:
144121
"""Initialize the PostgreSQL store with connection parameters.
145122
@@ -151,7 +128,6 @@ def __init__(
151128
password: Database password (default: None).
152129
table_name: The name of the table to use for storage (default: kv_store).
153130
default_collection: The default collection to use if no collection is provided.
154-
collection_sanitization_strategy: The sanitization strategy to use for collections.
155131
"""
156132

157133
def __init__(
@@ -166,7 +142,6 @@ def __init__(
166142
password: str | None = None,
167143
table_name: str | None = None,
168144
default_collection: str | None = None,
169-
collection_sanitization_strategy: SanitizationStrategy | None = None,
170145
) -> None:
171146
"""Initialize the PostgreSQL store."""
172147
self._pool = pool
@@ -178,7 +153,7 @@ def __init__(
178153
self._user = user
179154
self._password = password
180155

181-
# Validate and sanitize table name to prevent SQL injection and invalid identifiers
156+
# Validate table name to prevent SQL injection and invalid identifiers
182157
table_name = table_name or DEFAULT_TABLE
183158
if not table_name.replace("_", "").isalnum():
184159
msg = f"Table name must be alphanumeric (with underscores): {table_name}"
@@ -192,10 +167,7 @@ def __init__(
192167
raise ValueError(msg)
193168
self._table_name = table_name
194169

195-
super().__init__(
196-
default_collection=default_collection,
197-
collection_sanitization_strategy=collection_sanitization_strategy,
198-
)
170+
super().__init__(default_collection=default_collection)
199171

200172
def _ensure_pool_initialized(self) -> asyncpg.Pool: # type: ignore[type-arg]
201173
"""Ensure the connection pool is initialized.
@@ -250,14 +222,12 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: #
250222
await self._pool.close()
251223

252224
@override
253-
async def _setup_collection(self, *, collection: str) -> None:
225+
async def _setup(self) -> None:
254226
"""Set up the database table and indexes if they don't exist.
255227
256-
Args:
257-
collection: The collection name (used for validation, but all collections share the same table).
228+
This is called once when the store is first used. Since all collections share the same table,
229+
we only need to set up the schema once.
258230
"""
259-
_ = self._sanitize_collection(collection=collection)
260-
261231
# Create the main table if it doesn't exist
262232
table_sql = (
263233
f"CREATE TABLE IF NOT EXISTS {self._table_name} ("
@@ -295,12 +265,10 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
295265
Returns:
296266
The managed entry if found and not expired, None otherwise.
297267
"""
298-
sanitized_collection = self._sanitize_collection(collection=collection)
299-
300268
async with self._acquire_connection() as conn:
301269
row = await conn.fetchrow( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
302270
f"SELECT value, ttl, created_at, expires_at FROM {self._table_name} WHERE collection = $1 AND key = $2",
303-
sanitized_collection,
271+
collection,
304272
key,
305273
)
306274

@@ -318,7 +286,7 @@ async def _get_managed_entry(self, *, key: str, collection: str) -> ManagedEntry
318286
if managed_entry.is_expired:
319287
await conn.execute( # pyright: ignore[reportUnknownMemberType]
320288
f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = $2",
321-
sanitized_collection,
289+
collection,
322290
key,
323291
)
324292
return None
@@ -339,13 +307,11 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
339307
if not keys:
340308
return []
341309

342-
sanitized_collection = self._sanitize_collection(collection=collection)
343-
344310
async with self._acquire_connection() as conn:
345311
# Use ANY to query for multiple keys
346312
rows = await conn.fetch( # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
347313
f"SELECT key, value, ttl, created_at, expires_at FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])",
348-
sanitized_collection,
314+
collection,
349315
list(keys),
350316
)
351317

@@ -370,7 +336,7 @@ async def _get_managed_entries(self, *, collection: str, keys: Sequence[str]) ->
370336
if expired_keys:
371337
await conn.execute( # pyright: ignore[reportUnknownMemberType]
372338
f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])",
373-
sanitized_collection,
339+
collection,
374340
expired_keys,
375341
)
376342

@@ -391,7 +357,6 @@ async def _put_managed_entry(
391357
collection: The collection to store in.
392358
managed_entry: The managed entry to store.
393359
"""
394-
sanitized_collection = self._sanitize_collection(collection=collection)
395360

396361
async with self._acquire_connection() as conn:
397362
upsert_sql = (
@@ -403,7 +368,7 @@ async def _put_managed_entry(
403368
)
404369
await conn.execute( # pyright: ignore[reportUnknownMemberType]
405370
upsert_sql,
406-
sanitized_collection,
371+
collection,
407372
key,
408373
managed_entry.value,
409374
managed_entry.ttl,
@@ -435,12 +400,8 @@ async def _put_managed_entries(
435400
if not keys:
436401
return
437402

438-
sanitized_collection = self._sanitize_collection(collection=collection)
439-
440403
# Prepare data for batch insert using method-level ttl/created_at/expires_at
441-
values = [
442-
(sanitized_collection, key, entry.value, ttl, created_at, expires_at) for key, entry in zip(keys, managed_entries, strict=True)
443-
]
404+
values = [(collection, key, entry.value, ttl, created_at, expires_at) for key, entry in zip(keys, managed_entries, strict=True)]
444405

445406
async with self._acquire_connection() as conn:
446407
# Use executemany for batch insert
@@ -467,12 +428,10 @@ async def _delete_managed_entry(self, *, key: str, collection: str) -> bool:
467428
Returns:
468429
True if the entry was deleted, False if it didn't exist.
469430
"""
470-
sanitized_collection = self._sanitize_collection(collection=collection)
471-
472431
async with self._acquire_connection() as conn:
473432
result = await conn.execute( # pyright: ignore[reportUnknownMemberType]
474433
f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = $2",
475-
sanitized_collection,
434+
collection,
476435
key,
477436
)
478437
# PostgreSQL execute returns a string like "DELETE N" where N is the number of rows deleted
@@ -492,12 +451,10 @@ async def _delete_managed_entries(self, *, keys: Sequence[str], collection: str)
492451
if not keys:
493452
return 0
494453

495-
sanitized_collection = self._sanitize_collection(collection=collection)
496-
497454
async with self._acquire_connection() as conn:
498455
result = await conn.execute( # pyright: ignore[reportUnknownMemberType]
499456
f"DELETE FROM {self._table_name} WHERE collection = $1 AND key = ANY($2::text[])",
500-
sanitized_collection,
457+
collection,
501458
list(keys),
502459
)
503460
# PostgreSQL execute returns a string like "DELETE N" where N is the number of rows deleted
@@ -535,12 +492,10 @@ async def _delete_collection(self, *, collection: str) -> bool:
535492
Returns:
536493
True if any entries were deleted, False otherwise.
537494
"""
538-
sanitized_collection = self._sanitize_collection(collection=collection)
539-
540495
async with self._acquire_connection() as conn:
541496
result = await conn.execute( # pyright: ignore[reportUnknownMemberType]
542497
f"DELETE FROM {self._table_name} WHERE collection = $1",
543-
sanitized_collection,
498+
collection,
544499
)
545500
# Return True if any rows were deleted
546501
return result.split()[-1] != "0"

key-value/key-value-aio/tests/stores/postgresql/test_postgresql.py

Lines changed: 16 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing_extensions import override
88

99
from key_value.aio.stores.base import BaseStore
10-
from key_value.aio.stores.postgresql import PostgreSQLStore, PostgreSQLV1CollectionSanitizationStrategy
10+
from key_value.aio.stores.postgresql import PostgreSQLStore
1111
from tests.conftest import docker_container, should_skip_docker_tests
1212
from tests.stores.base import BaseStoreTests, ContextManagerStoreTestMixin
1313

@@ -104,65 +104,25 @@ async def store(self, setup_postgresql: None) -> PostgreSQLStore:
104104

105105
return store
106106

107-
@pytest.fixture
108-
async def postgresql_store(self, store: PostgreSQLStore) -> PostgreSQLStore:
109-
"""Provide the PostgreSQL store fixture."""
110-
return store
111-
112-
@pytest.fixture
113-
async def sanitizing_store(self, setup_postgresql: None) -> PostgreSQLStore:
114-
"""Create a PostgreSQL store with collection sanitization enabled."""
115-
store = PostgreSQLStore(
116-
host=POSTGRESQL_HOST,
117-
port=POSTGRESQL_HOST_PORT,
118-
database=POSTGRESQL_TEST_DB,
119-
user=POSTGRESQL_USER,
120-
password=POSTGRESQL_PASSWORD,
121-
table_name="kv_store_sanitizing",
122-
collection_sanitization_strategy=PostgreSQLV1CollectionSanitizationStrategy(),
123-
)
124-
125-
# Clean up the database before each test
126-
async with store:
127-
if store._pool is not None: # pyright: ignore[reportPrivateUsage]
128-
async with store._pool.acquire() as conn: # pyright: ignore[reportPrivateUsage, reportUnknownMemberType, reportUnknownVariableType]
129-
# Drop and recreate the kv_store_sanitizing table
130-
with contextlib.suppress(Exception):
131-
await conn.execute("DROP TABLE IF EXISTS kv_store_sanitizing") # pyright: ignore[reportUnknownMemberType]
132-
133-
return store
134-
135107
@pytest.mark.skip(reason="Distributed Caches are unbounded")
136108
@override
137109
async def test_not_unbounded(self, store: BaseStore): ...
138110

139111
@override
140-
async def test_long_collection_name(self, store: PostgreSQLStore, sanitizing_store: PostgreSQLStore): # pyright: ignore[reportIncompatibleMethodOverride]
141-
"""Test that long collection names fail without sanitization but work with it."""
142-
with pytest.raises(Exception): # noqa: B017, PT011
143-
await store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"})
144-
145-
await sanitizing_store.put(collection="test_collection" * 100, key="test_key", value={"test": "test"})
146-
assert await sanitizing_store.get(collection="test_collection" * 100, key="test_key") == {"test": "test"}
112+
async def test_long_collection_name(self, store: PostgreSQLStore): # pyright: ignore[reportIncompatibleMethodOverride]
113+
"""Test that long collection names work since they're just column values."""
114+
# Long collection names should work fine since they're stored as column values, not SQL identifiers
115+
long_collection = "test_collection" * 100
116+
await store.put(collection=long_collection, key="test_key", value={"test": "test"})
117+
assert await store.get(collection=long_collection, key="test_key") == {"test": "test"}
147118

148119
@override
149-
async def test_special_characters_in_collection_name(self, store: PostgreSQLStore, sanitizing_store: PostgreSQLStore): # pyright: ignore[reportIncompatibleMethodOverride]
150-
"""Test that special characters in collection names fail without sanitization but work with it."""
151-
# Without sanitization, special characters should work (PostgreSQL allows them in column values)
152-
# but may cause issues with certain characters
153-
await store.put(collection="test_collection", key="test_key", value={"test": "test"})
154-
assert await store.get(collection="test_collection", key="test_key") == {"test": "test"}
155-
156-
# With sanitization, special characters should work
157-
await sanitizing_store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"})
158-
assert await sanitizing_store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"}
159-
160-
async def test_postgresql_collection_name_sanitization(self, sanitizing_store: PostgreSQLStore):
161-
"""Test that the V1 sanitization strategy produces expected collection names."""
162-
await sanitizing_store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"})
163-
assert await sanitizing_store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"}
164-
165-
collections = await sanitizing_store.collections()
166-
# The sanitized collection name should only contain alphanumeric characters and underscores
167-
assert len(collections) == 1
168-
assert all(c.isalnum() or c in "_-" for c in collections[0])
120+
async def test_special_characters_in_collection_name(self, store: PostgreSQLStore): # pyright: ignore[reportIncompatibleMethodOverride]
121+
"""Test that special characters in collection names work since they're just column values."""
122+
# Special characters should work fine since collection names are stored as column values
123+
await store.put(collection="test_collection!@#$%^&*()", key="test_key", value={"test": "test"})
124+
assert await store.get(collection="test_collection!@#$%^&*()", key="test_key") == {"test": "test"}
125+
126+
# Verify the collection name is stored as-is
127+
collections = await store.collections()
128+
assert "test_collection!@#$%^&*()" in collections

0 commit comments

Comments
 (0)