@@ -170,24 +170,26 @@ def test_on_conflict_nothing_foreign_key_by_id():
170170 assert obj1 .data == "some data"
171171
172172
173- def test_on_conflict_nothing_duplicate_rows ():
173+ @pytest .mark .parametrize (
174+ "rows,expected_row_count" ,
175+ [
176+ ([dict (amount = 1 ), dict (amount = 1 )], 1 ),
177+ (iter ([dict (amount = 1 ), dict (amount = 1 )]), 1 ),
178+ ((row for row in [dict (amount = 1 ), dict (amount = 1 )]), 1 ),
179+ ([], 0 ),
180+ (iter ([]), 0 ),
181+ ((row for row in []), 0 ),
182+ ],
183+ )
184+ def test_on_conflict_nothing_duplicate_rows (rows , expected_row_count ):
174185 """Tests whether duplicate rows are filtered out when doing a insert
175186 NOTHING and no error is raised when the list of rows contains
176187 duplicates."""
177188
178189 model = get_fake_model ({"amount" : models .IntegerField (unique = True )})
179190
180- rows = [dict (amount = 1 ), dict (amount = 1 )]
181-
182- inserted_rows = model .objects .on_conflict (
183- ["amount" ], ConflictAction .NOTHING
184- ).bulk_insert (rows )
185-
186- assert len (inserted_rows ) == 1
187-
188- rows = iter ([dict (amount = 2 ), dict (amount = 2 )])
189191 inserted_rows = model .objects .on_conflict (
190192 ["amount" ], ConflictAction .NOTHING
191193 ).bulk_insert (rows )
192194
193- assert len (inserted_rows ) == 1
195+ assert len (inserted_rows ) == expected_row_count
0 commit comments