11from collections import OrderedDict
22from itertools import chain
3- from typing import Dict , Iterable , List , Optional , Tuple , Union
3+ from typing import Any , Dict , Iterable , List , Optional , Tuple , Union
44
55from django .core .exceptions import SuspiciousOperation
66from django .db import connections , models , router
77from django .db .models import Expression , Q
88from django .db .models .fields import NOT_PROVIDED
99
10+ from .expressions import ExcludedCol
1011from .sql import PostgresInsertQuery , PostgresQuery
1112from .types import ConflictAction
1213
@@ -27,6 +28,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
2728 self .conflict_action = None
2829 self .conflict_update_condition = None
2930 self .index_predicate = None
31+ self .update_values = None
3032
3133 def annotate (self , ** annotations ):
3234 """Custom version of the standard annotate function that allows using
@@ -84,6 +86,7 @@ def on_conflict(
8486 action : ConflictAction ,
8587 index_predicate : Optional [Union [Expression , Q , str ]] = None ,
8688 update_condition : Optional [Union [Expression , Q , str ]] = None ,
89+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
8790 ):
8891 """Sets the action to take when conflicts arise when attempting to
8992 insert/create a new row.
@@ -101,12 +104,18 @@ def on_conflict(
101104
102105 update_condition:
103106 Only update if this SQL expression evaluates to true.
107+
108+ update_values:
109+ Optionally, values/expressions to use when rows
110+ conflict. If not specified, all columns specified
111+ in the rows are updated with the values you specified.
104112 """
105113
106114 self .conflict_target = fields
107115 self .conflict_action = action
108116 self .conflict_update_condition = update_condition
109117 self .index_predicate = index_predicate
118+ self .update_values = update_values
110119
111120 return self
112121
@@ -260,6 +269,7 @@ def upsert(
260269 index_predicate : Optional [Union [Expression , Q , str ]] = None ,
261270 using : Optional [str ] = None ,
262271 update_condition : Optional [Union [Expression , Q , str ]] = None ,
272+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
263273 ) -> int :
264274 """Creates a new record or updates the existing one with the specified
265275 data.
@@ -282,6 +292,11 @@ def upsert(
282292 update_condition:
283293 Only update if this SQL expression evaluates to true.
284294
295+ update_values:
296+ Optionally, values/expressions to use when rows
297+ conflict. If not specified, all columns specified
298+ in the rows are updated with the values you specified.
299+
285300 Returns:
286301 The primary key of the row that was created/updated.
287302 """
@@ -291,6 +306,7 @@ def upsert(
291306 ConflictAction .UPDATE ,
292307 index_predicate = index_predicate ,
293308 update_condition = update_condition ,
309+ update_values = update_values ,
294310 )
295311 return self .insert (** fields , using = using )
296312
@@ -301,6 +317,7 @@ def upsert_and_get(
301317 index_predicate : Optional [Union [Expression , Q , str ]] = None ,
302318 using : Optional [str ] = None ,
303319 update_condition : Optional [Union [Expression , Q , str ]] = None ,
320+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
304321 ):
305322 """Creates a new record or updates the existing one with the specified
306323 data and then gets the row.
@@ -323,6 +340,11 @@ def upsert_and_get(
323340 update_condition:
324341 Only update if this SQL expression evaluates to true.
325342
343+ update_values:
344+ Optionally, values/expressions to use when rows
345+ conflict. If not specified, all columns specified
346+ in the rows are updated with the values you specified.
347+
326348 Returns:
327349 The model instance representing the row
328350 that was created/updated.
@@ -333,6 +355,7 @@ def upsert_and_get(
333355 ConflictAction .UPDATE ,
334356 index_predicate = index_predicate ,
335357 update_condition = update_condition ,
358+ update_values = update_values ,
336359 )
337360 return self .insert_and_get (** fields , using = using )
338361
@@ -344,6 +367,7 @@ def bulk_upsert(
344367 return_model : bool = False ,
345368 using : Optional [str ] = None ,
346369 update_condition : Optional [Union [Expression , Q , str ]] = None ,
370+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
347371 ):
348372 """Creates a set of new records or updates the existing ones with the
349373 specified data.
@@ -370,6 +394,11 @@ def bulk_upsert(
370394 update_condition:
371395 Only update if this SQL expression evaluates to true.
372396
397+ update_values:
398+ Optionally, values/expressions to use when rows
399+ conflict. If not specified, all columns specified
400+ in the rows are updated with the values you specified.
401+
373402 Returns:
374403 A list of either the dicts of the rows upserted, including the pk or
375404 the models of the rows upserted
@@ -386,7 +415,9 @@ def is_empty(r):
386415 ConflictAction .UPDATE ,
387416 index_predicate = index_predicate ,
388417 update_condition = update_condition ,
418+ update_values = update_values ,
389419 )
420+
390421 return self .bulk_insert (rows , return_model , using = using )
391422
392423 def _create_model_instance (
@@ -474,15 +505,19 @@ def _build_insert_compiler(
474505 )
475506
476507 # get the fields to be used during update/insert
477- insert_fields , update_fields = self ._get_upsert_fields (first_row )
508+ insert_fields , update_values = self ._get_upsert_fields (first_row )
509+
510+ # allow the user to override what should happen on update
511+ if self .update_values is not None :
512+ update_values = self .update_values
478513
479514 # build a normal insert query
480515 query = PostgresInsertQuery (self .model )
481516 query .conflict_action = self .conflict_action
482517 query .conflict_target = self .conflict_target
483518 query .conflict_update_condition = self .conflict_update_condition
484519 query .index_predicate = self .index_predicate
485- query .values (objs , insert_fields , update_fields )
520+ query .values (objs , insert_fields , update_values )
486521
487522 compiler = query .get_compiler (using )
488523 return compiler
@@ -547,13 +582,13 @@ def _get_upsert_fields(self, kwargs):
547582
548583 model_instance = self .model (** kwargs )
549584 insert_fields = []
550- update_fields = []
585+ update_values = {}
551586
552587 for field in model_instance ._meta .local_concrete_fields :
553588 has_default = field .default != NOT_PROVIDED
554589 if field .name in kwargs or field .column in kwargs :
555590 insert_fields .append (field )
556- update_fields . append (field )
591+ update_values [ field . name ] = ExcludedCol (field . column )
557592 continue
558593 elif has_default :
559594 insert_fields .append (field )
@@ -564,13 +599,13 @@ def _get_upsert_fields(self, kwargs):
564599 # instead of a concrete field, we have to handle that
565600 if field .primary_key is True and "pk" in kwargs :
566601 insert_fields .append (field )
567- update_fields . append (field )
602+ update_values [ field . name ] = ExcludedCol (field . column )
568603 continue
569604
570605 if self ._is_magical_field (model_instance , field , is_insert = True ):
571606 insert_fields .append (field )
572607
573608 if self ._is_magical_field (model_instance , field , is_insert = False ):
574- update_fields . append (field )
609+ update_values [ field . name ] = ExcludedCol (field . column )
575610
576- return insert_fields , update_fields
611+ return insert_fields , update_values
0 commit comments