Skip to content

Commit 422e91f

Browse files
committed
Add additional tests for ON CONFLICT DO NOTHING duplicate rows filtering
1 parent 200f2b9 commit 422e91f

File tree

2 files changed

+22
-19
lines changed

2 files changed

+22
-19
lines changed

psqlextra/query.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@
4141
QuerySetBase = QuerySet
4242

4343

44+
def peek_iterator(iterable):
45+
try:
46+
first = next(iterable)
47+
except StopIteration:
48+
return None
49+
return list(chain([first], iterable))
50+
51+
4452
class PostgresQuerySet(QuerySetBase, Generic[TModel]):
4553
"""Adds support for PostgreSQL specifics."""
4654

@@ -177,14 +185,7 @@ def bulk_insert(
177185
if rows is None:
178186
return []
179187

180-
def peek(iterable):
181-
try:
182-
first = next(iterable)
183-
except StopIteration:
184-
return None
185-
return list(chain([first], iterable))
186-
187-
rows = peek(iter(rows))
188+
rows = peek_iterator(iter(rows))
188189

189190
if not rows:
190191
return []

tests/test_on_conflict_nothing.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -170,24 +170,26 @@ def test_on_conflict_nothing_foreign_key_by_id():
170170
assert obj1.data == "some data"
171171

172172

173-
def test_on_conflict_nothing_duplicate_rows():
173+
@pytest.mark.parametrize(
174+
"rows,expected_row_count",
175+
[
176+
([dict(amount=1), dict(amount=1)], 1),
177+
(iter([dict(amount=1), dict(amount=1)]), 1),
178+
((row for row in [dict(amount=1), dict(amount=1)]), 1),
179+
([], 0),
180+
(iter([]), 0),
181+
((row for row in []), 0),
182+
],
183+
)
184+
def test_on_conflict_nothing_duplicate_rows(rows, expected_row_count):
174185
"""Tests whether duplicate rows are filtered out when doing a insert
175186
NOTHING and no error is raised when the list of rows contains
176187
duplicates."""
177188

178189
model = get_fake_model({"amount": models.IntegerField(unique=True)})
179190

180-
rows = [dict(amount=1), dict(amount=1)]
181-
182-
inserted_rows = model.objects.on_conflict(
183-
["amount"], ConflictAction.NOTHING
184-
).bulk_insert(rows)
185-
186-
assert len(inserted_rows) == 1
187-
188-
rows = iter([dict(amount=2), dict(amount=2)])
189191
inserted_rows = model.objects.on_conflict(
190192
["amount"], ConflictAction.NOTHING
191193
).bulk_insert(rows)
192194

193-
assert len(inserted_rows) == 1
195+
assert len(inserted_rows) == expected_row_count

0 commit comments

Comments
 (0)