@@ -833,6 +833,7 @@ class AbstractRedis:
833833 "QUIT" : bool_ok ,
834834 "STRALGO" : parse_stralgo ,
835835 "PUBSUB NUMSUB" : parse_pubsub_numsub ,
836+ "PUBSUB SHARDNUMSUB" : parse_pubsub_numsub ,
836837 "RANDOMKEY" : lambda r : r and r or None ,
837838 "RESET" : str_if_bytes ,
838839 "SCAN" : parse_scan ,
@@ -1440,8 +1441,8 @@ class PubSub:
14401441 will be returned and it's safe to start listening again.
14411442 """
14421443
1443- PUBLISH_MESSAGE_TYPES = ("message" , "pmessage" )
1444- UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe" , "punsubscribe" )
1444+ PUBLISH_MESSAGE_TYPES = ("message" , "pmessage" , "smessage" )
1445+ UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe" , "punsubscribe" , "sunsubscribe" )
14451446 HEALTH_CHECK_MESSAGE = "redis-py-health-check"
14461447
14471448 def __init__ (
@@ -1493,9 +1494,11 @@ def reset(self):
14931494 self .connection .clear_connect_callbacks ()
14941495 self .connection_pool .release (self .connection )
14951496 self .connection = None
1496- self .channels = {}
14971497 self .health_check_response_counter = 0
1498+ self .channels = {}
14981499 self .pending_unsubscribe_channels = set ()
1500+ self .shard_channels = {}
1501+ self .pending_unsubscribe_shard_channels = set ()
14991502 self .patterns = {}
15001503 self .pending_unsubscribe_patterns = set ()
15011504 self .subscribed_event .clear ()
@@ -1510,16 +1513,23 @@ def on_connect(self, connection):
15101513 # before passing them to [p]subscribe.
15111514 self .pending_unsubscribe_channels .clear ()
15121515 self .pending_unsubscribe_patterns .clear ()
1516+ self .pending_unsubscribe_shard_channels .clear ()
15131517 if self .channels :
1514- channels = {}
1515- for k , v in self .channels .items ():
1516- channels [ self . encoder . decode ( k , force = True )] = v
1518+ channels = {
1519+ self . encoder . decode ( k , force = True ): v for k , v in self .channels .items ()
1520+ }
15171521 self .subscribe (** channels )
15181522 if self .patterns :
1519- patterns = {}
1520- for k , v in self .patterns .items ():
1521- patterns [ self . encoder . decode ( k , force = True )] = v
1523+ patterns = {
1524+ self . encoder . decode ( k , force = True ): v for k , v in self .patterns .items ()
1525+ }
15221526 self .psubscribe (** patterns )
1527+ if self .shard_channels :
1528+ shard_channels = {
1529+ self .encoder .decode (k , force = True ): v
1530+ for k , v in self .shard_channels .items ()
1531+ }
1532+ self .ssubscribe (** shard_channels )
15231533
15241534 @property
15251535 def subscribed (self ):
@@ -1728,6 +1738,45 @@ def unsubscribe(self, *args):
17281738 self .pending_unsubscribe_channels .update (channels )
17291739 return self .execute_command ("UNSUBSCRIBE" , * args )
17301740
1741+ def ssubscribe (self , * args , target_node = None , ** kwargs ):
1742+ """
1743+ Subscribes the client to the specified shard channels.
1744+ Channels supplied as keyword arguments expect a channel name as the key
1745+ and a callable as the value. A channel's callable will be invoked automatically
1746+ when a message is received on that channel rather than producing a message via
1747+ ``listen()`` or ``get_sharded_message()``.
1748+ """
1749+ if args :
1750+ args = list_or_args (args [0 ], args [1 :])
1751+ new_s_channels = dict .fromkeys (args )
1752+ new_s_channels .update (kwargs )
1753+ ret_val = self .execute_command ("SSUBSCRIBE" , * new_s_channels .keys ())
1754+ # update the s_channels dict AFTER we send the command. we don't want to
1755+ # subscribe twice to these channels, once for the command and again
1756+ # for the reconnection.
1757+ new_s_channels = self ._normalize_keys (new_s_channels )
1758+ self .shard_channels .update (new_s_channels )
1759+ if not self .subscribed :
1760+ # Set the subscribed_event flag to True
1761+ self .subscribed_event .set ()
1762+ # Clear the health check counter
1763+ self .health_check_response_counter = 0
1764+ self .pending_unsubscribe_shard_channels .difference_update (new_s_channels )
1765+ return ret_val
1766+
1767+ def sunsubscribe (self , * args , target_node = None ):
1768+ """
1769+ Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
1770+ all shard_channels
1771+ """
1772+ if args :
1773+ args = list_or_args (args [0 ], args [1 :])
1774+ s_channels = self ._normalize_keys (dict .fromkeys (args ))
1775+ else :
1776+ s_channels = self .shard_channels
1777+ self .pending_unsubscribe_shard_channels .update (s_channels )
1778+ return self .execute_command ("SUNSUBSCRIBE" , * args )
1779+
17311780 def listen (self ):
17321781 "Listen for messages on channels this client has been subscribed to"
17331782 while self .subscribed :
@@ -1762,6 +1811,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
17621811 return self .handle_message (response , ignore_subscribe_messages )
17631812 return None
17641813
1814+ get_sharded_message = get_message
1815+
17651816 def ping (self , message = None ):
17661817 """
17671818 Ping the Redis server
@@ -1809,12 +1860,17 @@ def handle_message(self, response, ignore_subscribe_messages=False):
18091860 if pattern in self .pending_unsubscribe_patterns :
18101861 self .pending_unsubscribe_patterns .remove (pattern )
18111862 self .patterns .pop (pattern , None )
1863+ elif message_type == "sunsubscribe" :
1864+ s_channel = response [1 ]
1865+ if s_channel in self .pending_unsubscribe_shard_channels :
1866+ self .pending_unsubscribe_shard_channels .remove (s_channel )
1867+ self .shard_channels .pop (s_channel , None )
18121868 else :
18131869 channel = response [1 ]
18141870 if channel in self .pending_unsubscribe_channels :
18151871 self .pending_unsubscribe_channels .remove (channel )
18161872 self .channels .pop (channel , None )
1817- if not self .channels and not self .patterns :
1873+ if not self .channels and not self .patterns and not self . shard_channels :
18181874 # There are no subscriptions anymore, set subscribed_event flag
18191875 # to false
18201876 self .subscribed_event .clear ()
@@ -1823,6 +1879,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
18231879 # if there's a message handler, invoke it
18241880 if message_type == "pmessage" :
18251881 handler = self .patterns .get (message ["pattern" ], None )
1882+ elif message_type == "smessage" :
1883+ handler = self .shard_channels .get (message ["channel" ], None )
18261884 else :
18271885 handler = self .channels .get (message ["channel" ], None )
18281886 if handler :
@@ -1843,6 +1901,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
18431901 for pattern , handler in self .patterns .items ():
18441902 if handler is None :
18451903 raise PubSubError (f"Pattern: '{ pattern } ' has no handler registered" )
1904+ for s_channel , handler in self .shard_channels .items ():
1905+ if handler is None :
1906+ raise PubSubError (
1907+ f"Shard Channel: '{ s_channel } ' has no handler registered"
1908+ )
18461909
18471910 thread = PubSubWorkerThread (
18481911 self , sleep_time , daemon = daemon , exception_handler = exception_handler
0 commit comments