@@ -379,6 +379,30 @@ class AbstractRedisCluster:
379379
380380 ERRORS_ALLOW_RETRY = (ConnectionError , TimeoutError , ClusterDownError )
381381
382+ def replace_default_node (self , target_node : "ClusterNode" = None ) -> None :
383+ """Replace the default cluster node.
384+ A random cluster node will be chosen if target_node isn't passed, and primaries
385+ will be prioritized. The default node will not be changed if there are no other
386+ nodes in the cluster.
387+
388+ Args:
389+ target_node (ClusterNode, optional): Target node to replace the default
390+ node. Defaults to None.
391+ """
392+ if target_node :
393+ self .nodes_manager .default_node = target_node
394+ else :
395+ curr_node = self .get_default_node ()
396+ primaries = [node for node in self .get_primaries () if node != curr_node ]
397+ if primaries :
398+ # Choose a primary if the cluster contains different primaries
399+ self .nodes_manager .default_node = random .choice (primaries )
400+ else :
401+ # Otherwise, hoose a primary if the cluster contains different primaries
402+ replicas = [node for node in self .get_replicas () if node != curr_node ]
403+ if replicas :
404+ self .nodes_manager .default_node = random .choice (replicas )
405+
382406
383407class RedisCluster (AbstractRedisCluster , RedisClusterCommands ):
384408 @classmethod
@@ -811,7 +835,14 @@ def set_response_callback(self, command, callback):
811835 """Set a custom Response Callback"""
812836 self .cluster_response_callbacks [command ] = callback
813837
814- def _determine_nodes (self , * args , ** kwargs ):
838+ def _determine_nodes (self , * args , ** kwargs ) -> tuple [list ["ClusterNode" ], bool ]:
839+ """Determine which nodes should be executed the command on
840+
841+ Returns:
842+ tuple[list[Type[ClusterNode]], bool]:
843+ A tuple containing a list of target nodes and a bool indicating
844+ if the return node was chosen because it is the default node
845+ """
815846 command = args [0 ].upper ()
816847 if len (args ) >= 2 and f"{ args [0 ]} { args [1 ]} " .upper () in self .command_flags :
817848 command = f"{ args [0 ]} { args [1 ]} " .upper ()
@@ -825,28 +856,28 @@ def _determine_nodes(self, *args, **kwargs):
825856 command_flag = self .command_flags .get (command )
826857 if command_flag == self .__class__ .RANDOM :
827858 # return a random node
828- return [self .get_random_node ()]
859+ return [self .get_random_node ()], False
829860 elif command_flag == self .__class__ .PRIMARIES :
830861 # return all primaries
831- return self .get_primaries ()
862+ return self .get_primaries (), False
832863 elif command_flag == self .__class__ .REPLICAS :
833864 # return all replicas
834- return self .get_replicas ()
865+ return self .get_replicas (), False
835866 elif command_flag == self .__class__ .ALL_NODES :
836867 # return all nodes
837- return self .get_nodes ()
868+ return self .get_nodes (), False
838869 elif command_flag == self .__class__ .DEFAULT_NODE :
839870 # return the cluster's default node
840- return [self .nodes_manager .default_node ]
871+ return [self .nodes_manager .default_node ], True
841872 elif command in self .__class__ .SEARCH_COMMANDS [0 ]:
842- return [self .nodes_manager .default_node ]
873+ return [self .nodes_manager .default_node ], True
843874 else :
844875 # get the node that holds the key's slot
845876 slot = self .determine_slot (* args )
846877 node = self .nodes_manager .get_node_from_slot (
847878 slot , self .read_from_replicas and command in READ_COMMANDS
848879 )
849- return [node ]
880+ return [node ], False
850881
851882 def _should_reinitialized (self ):
852883 # To reinitialize the cluster on every MOVED error,
@@ -990,6 +1021,7 @@ def execute_command(self, *args, **kwargs):
9901021 dict<Any, ClusterNode>
9911022 """
9921023 target_nodes_specified = False
1024+ is_default_node = False
9931025 target_nodes = None
9941026 passed_targets = kwargs .pop ("target_nodes" , None )
9951027 if passed_targets is not None and not self ._is_nodes_flag (passed_targets ):
@@ -1013,7 +1045,7 @@ def execute_command(self, *args, **kwargs):
10131045 res = {}
10141046 if not target_nodes_specified :
10151047 # Determine the nodes to execute the command on
1016- target_nodes = self ._determine_nodes (
1048+ target_nodes , is_default_node = self ._determine_nodes (
10171049 * args , ** kwargs , nodes_flag = passed_targets
10181050 )
10191051 if not target_nodes :
@@ -1025,6 +1057,9 @@ def execute_command(self, *args, **kwargs):
10251057 # Return the processed result
10261058 return self ._process_result (args [0 ], res , ** kwargs )
10271059 except Exception as e :
1060+ if is_default_node :
1061+ # Replace the default cluster node
1062+ self .replace_default_node ()
10281063 if retry_attempts > 0 and type (e ) in self .__class__ .ERRORS_ALLOW_RETRY :
10291064 # The nodes and slots cache were reinitialized.
10301065 # Try again with the new cluster setup.
@@ -1883,7 +1918,7 @@ def _send_cluster_commands(
18831918 # if we have to run through it again, we only retry
18841919 # the commands that failed.
18851920 attempt = sorted (stack , key = lambda x : x .position )
1886-
1921+ is_default_node = False
18871922 # build a list of node objects based on node names we need to
18881923 nodes = {}
18891924
@@ -1900,7 +1935,7 @@ def _send_cluster_commands(
19001935 if passed_targets and not self ._is_nodes_flag (passed_targets ):
19011936 target_nodes = self ._parse_target_nodes (passed_targets )
19021937 else :
1903- target_nodes = self ._determine_nodes (
1938+ target_nodes , is_default_node = self ._determine_nodes (
19041939 * c .args , node_flag = passed_targets
19051940 )
19061941 if not target_nodes :
@@ -1926,6 +1961,8 @@ def _send_cluster_commands(
19261961 # Connection retries are being handled in the node's
19271962 # Retry object. Reinitialize the node -> slot table.
19281963 self .nodes_manager .initialize ()
1964+ if is_default_node :
1965+ self .replace_default_node ()
19291966 raise
19301967 nodes [node_name ] = NodeCommands (
19311968 redis_node .parse_response ,
@@ -2007,6 +2044,8 @@ def _send_cluster_commands(
20072044 self .reinitialize_counter += 1
20082045 if self ._should_reinitialized ():
20092046 self .nodes_manager .initialize ()
2047+ if is_default_node :
2048+ self .replace_default_node ()
20102049 for c in attempt :
20112050 try :
20122051 # send each command individually like we
0 commit comments