3737)
3838from redis .asyncio .lock import Lock
3939from redis .asyncio .retry import Retry
40+ from redis .cache import (
41+ DEFAULT_BLACKLIST ,
42+ DEFAULT_EVICTION_POLICY ,
43+ DEFAULT_WHITELIST ,
44+ _LocalCache ,
45+ )
4046from redis .client import (
4147 EMPTY_RESPONSE ,
4248 NEVER_DECODE ,
6066 TimeoutError ,
6167 WatchError ,
6268)
63- from redis .typing import ChannelT , EncodableT , KeyT
69+ from redis .typing import ChannelT , EncodableT , KeysT , KeyT , ResponseT
6470from redis .utils import (
6571 HIREDIS_AVAILABLE ,
6672 _set_info_logger ,
@@ -231,6 +237,13 @@ def __init__(
231237 redis_connect_func = None ,
232238 credential_provider : Optional [CredentialProvider ] = None ,
233239 protocol : Optional [int ] = 2 ,
240+ cache_enable : bool = False ,
241+ client_cache : Optional [_LocalCache ] = None ,
242+ cache_max_size : int = 100 ,
243+ cache_ttl : int = 0 ,
244+ cache_eviction_policy : str = DEFAULT_EVICTION_POLICY ,
245+ cache_blacklist : List [str ] = DEFAULT_BLACKLIST ,
246+ cache_whitelist : List [str ] = DEFAULT_WHITELIST ,
234247 ):
235248 """
236249 Initialize a new Redis client.
@@ -336,6 +349,16 @@ def __init__(
336349 # on a set of redis commands
337350 self ._single_conn_lock = asyncio .Lock ()
338351
352+ self .client_cache = client_cache
353+ if cache_enable :
354+ self .client_cache = _LocalCache (
355+ cache_max_size , cache_ttl , cache_eviction_policy
356+ )
357+ if self .client_cache is not None :
358+ self .cache_blacklist = cache_blacklist
359+ self .cache_whitelist = cache_whitelist
360+ self .client_cache_initialized = False
361+
339362 def __repr__ (self ):
340363 return f"{ self .__class__ .__name__ } <{ self .connection_pool !r} >"
341364
@@ -347,6 +370,10 @@ async def initialize(self: _RedisT) -> _RedisT:
347370 async with self ._single_conn_lock :
348371 if self .connection is None :
349372 self .connection = await self .connection_pool .get_connection ("_" )
373+ if self .client_cache is not None :
374+ self .connection ._parser .set_invalidation_push_handler (
375+ self ._cache_invalidation_process
376+ )
350377 return self
351378
352379 def set_response_callback (self , command : str , callback : ResponseCallbackT ):
@@ -565,6 +592,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
565592 close_connection_pool is None and self .auto_close_connection_pool
566593 ):
567594 await self .connection_pool .disconnect ()
595+ if self .client_cache :
596+ self .client_cache .flush ()
568597
569598 @deprecated_function (version = "5.0.1" , reason = "Use aclose() instead" , name = "close" )
570599 async def close (self , close_connection_pool : Optional [bool ] = None ) -> None :
@@ -593,29 +622,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
593622 ):
594623 raise error
595624
625+ def _cache_invalidation_process (
626+ self , data : List [Union [str , Optional [List [str ]]]]
627+ ) -> None :
628+ """
629+ Invalidate (delete) all redis commands associated with a specific key.
630+ `data` is a list of strings, where the first string is the invalidation message
631+ and the second string is the list of keys to invalidate.
632+ (if the list of keys is None, then all keys are invalidated)
633+ """
634+ if data [1 ] is not None :
635+ for key in data [1 ]:
636+ self .client_cache .invalidate (str_if_bytes (key ))
637+ else :
638+ self .client_cache .flush ()
639+
640+ async def _get_from_local_cache (self , command : str ):
641+ """
642+ If the command is in the local cache, return the response
643+ """
644+ if (
645+ self .client_cache is None
646+ or command [0 ] in self .cache_blacklist
647+ or command [0 ] not in self .cache_whitelist
648+ ):
649+ return None
650+ while not self .connection ._is_socket_empty ():
651+ await self .connection .read_response (push_request = True )
652+ return self .client_cache .get (command )
653+
654+ def _add_to_local_cache (
655+ self , command : Tuple [str ], response : ResponseT , keys : List [KeysT ]
656+ ):
657+ """
658+ Add the command and response to the local cache if the command
659+ is allowed to be cached
660+ """
661+ if (
662+ self .client_cache is not None
663+ and (self .cache_blacklist == [] or command [0 ] not in self .cache_blacklist )
664+ and (self .cache_whitelist == [] or command [0 ] in self .cache_whitelist )
665+ ):
666+ self .client_cache .set (command , response , keys )
667+
668+ def delete_from_local_cache (self , command : str ):
669+ """
670+ Delete the command from the local cache
671+ """
672+ try :
673+ self .client_cache .delete (command )
674+ except AttributeError :
675+ pass
676+
596677 # COMMAND EXECUTION AND PROTOCOL PARSING
597678 async def execute_command (self , * args , ** options ):
598679 """Execute a command and return a parsed response"""
599680 await self .initialize ()
600- options .pop ("keys" , None ) # the keys are used only for client side caching
601- pool = self .connection_pool
602681 command_name = args [0 ]
603- conn = self .connection or await pool .get_connection (command_name , ** options )
682+ keys = options .pop ("keys" , None ) # keys are used only for client side caching
683+ response_from_cache = await self ._get_from_local_cache (args )
684+ if response_from_cache is not None :
685+ return response_from_cache
686+ else :
687+ pool = self .connection_pool
688+ conn = self .connection or await pool .get_connection (command_name , ** options )
604689
605- if self .single_connection_client :
606- await self ._single_conn_lock .acquire ()
607- try :
608- return await conn .retry .call_with_retry (
609- lambda : self ._send_command_parse_response (
610- conn , command_name , * args , ** options
611- ),
612- lambda error : self ._disconnect_raise (conn , error ),
613- )
614- finally :
615690 if self .single_connection_client :
616- self ._single_conn_lock .release ()
617- if not self .connection :
618- await pool .release (conn )
691+ await self ._single_conn_lock .acquire ()
692+ try :
693+ if self .client_cache is not None and not self .client_cache_initialized :
694+ await conn .retry .call_with_retry (
695+ lambda : self ._send_command_parse_response (
696+ conn , "CLIENT" , * ("CLIENT" , "TRACKING" , "ON" )
697+ ),
698+ lambda error : self ._disconnect_raise (conn , error ),
699+ )
700+ self .client_cache_initialized = True
701+ response = await conn .retry .call_with_retry (
702+ lambda : self ._send_command_parse_response (
703+ conn , command_name , * args , ** options
704+ ),
705+ lambda error : self ._disconnect_raise (conn , error ),
706+ )
707+ self ._add_to_local_cache (args , response , keys )
708+ return response
709+ finally :
710+ if self .single_connection_client :
711+ self ._single_conn_lock .release ()
712+ if not self .connection :
713+ await pool .release (conn )
619714
620715 async def parse_response (
621716 self , connection : Connection , command_name : Union [str , bytes ], ** options
@@ -863,7 +958,7 @@ async def connect(self):
863958 else :
864959 await self .connection .connect ()
865960 if self .push_handler_func is not None and not HIREDIS_AVAILABLE :
866- self .connection ._parser .set_push_handler (self .push_handler_func )
961+ self .connection ._parser .set_pubsub_push_handler (self .push_handler_func )
867962
868963 async def _disconnect_raise_connect (self , conn , error ):
869964 """
0 commit comments