Skip to content

Commit 200f2b9

Browse files
seroyPhotonios
authored andcommitted
Fix StopIteration in deduplication rows code when conflict_action == ConflictAction.NOTHING and rows parameter is iterator or generator
1 parent 13b5672 commit 200f2b9

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

psqlextra/query.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,19 @@ def bulk_insert(
174174
A list of either the dicts of the rows inserted, including the pk or
175175
the models of the rows inserted with defaults for any fields not specified
176176
"""
177+
if rows is None:
178+
return []
179+
180+
def peek(iterable):
181+
try:
182+
first = next(iterable)
183+
except StopIteration:
184+
return None
185+
return list(chain([first], iterable))
177186

178-
def is_empty(r):
179-
return all([False for _ in r])
187+
rows = peek(iter(rows))
180188

181-
if not rows or is_empty(rows):
189+
if not rows:
182190
return []
183191

184192
if not self.conflict_target and not self.conflict_action:

tests/test_on_conflict_nothing.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,15 @@ def test_on_conflict_nothing_duplicate_rows():
179179

180180
rows = [dict(amount=1), dict(amount=1)]
181181

182-
(
183-
model.objects.on_conflict(
184-
["amount"], ConflictAction.NOTHING
185-
).bulk_insert(rows)
186-
)
182+
inserted_rows = model.objects.on_conflict(
183+
["amount"], ConflictAction.NOTHING
184+
).bulk_insert(rows)
185+
186+
assert len(inserted_rows) == 1
187+
188+
rows = iter([dict(amount=2), dict(amount=2)])
189+
inserted_rows = model.objects.on_conflict(
190+
["amount"], ConflictAction.NOTHING
191+
).bulk_insert(rows)
192+
193+
assert len(inserted_rows) == 1

0 commit comments

Comments
 (0)