3737)
3838from redis .asyncio .lock import Lock
3939from redis .asyncio .retry import Retry
40+ from redis .cache import (
41+ DEFAULT_BLACKLIST ,
42+ DEFAULT_EVICTION_POLICY ,
43+ DEFAULT_WHITELIST ,
44+ _LocalCache ,
45+ )
4046from redis .client import (
4147 EMPTY_RESPONSE ,
4248 NEVER_DECODE ,
6066 TimeoutError ,
6167 WatchError ,
6268)
63- from redis .typing import ChannelT , EncodableT , KeyT
69+ from redis .typing import ChannelT , EncodableT , KeysT , KeyT , ResponseT
6470from redis .utils import (
6571 HIREDIS_AVAILABLE ,
6672 _set_info_logger ,
@@ -231,6 +237,13 @@ def __init__(
231237 redis_connect_func = None ,
232238 credential_provider : Optional [CredentialProvider ] = None ,
233239 protocol : Optional [int ] = 2 ,
240+ cache_enable : bool = False ,
241+ client_cache : Optional [_LocalCache ] = None ,
242+ cache_max_size : int = 100 ,
243+ cache_ttl : int = 0 ,
244+ cache_eviction_policy : str = DEFAULT_EVICTION_POLICY ,
245+ cache_blacklist : List [str ] = DEFAULT_BLACKLIST ,
246+ cache_whitelist : List [str ] = DEFAULT_WHITELIST ,
234247 ):
235248 """
236249 Initialize a new Redis client.
@@ -336,6 +349,16 @@ def __init__(
336349 # on a set of redis commands
337350 self ._single_conn_lock = asyncio .Lock ()
338351
352+ self .client_cache = client_cache
353+ if cache_enable :
354+ self .client_cache = _LocalCache (
355+ cache_max_size , cache_ttl , cache_eviction_policy
356+ )
357+ if self .client_cache is not None :
358+ self .cache_blacklist = cache_blacklist
359+ self .cache_whitelist = cache_whitelist
360+ self .client_cache_initialized = False
361+
339362 def __repr__ (self ):
340363 return (
341364 f"<{ self .__class__ .__module__ } .{ self .__class__ .__name__ } "
@@ -350,6 +373,10 @@ async def initialize(self: _RedisT) -> _RedisT:
350373 async with self ._single_conn_lock :
351374 if self .connection is None :
352375 self .connection = await self .connection_pool .get_connection ("_" )
376+ if self .client_cache is not None :
377+ self .connection ._parser .set_invalidation_push_handler (
378+ self ._cache_invalidation_process
379+ )
353380 return self
354381
355382 def set_response_callback (self , command : str , callback : ResponseCallbackT ):
@@ -568,6 +595,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
568595 close_connection_pool is None and self .auto_close_connection_pool
569596 ):
570597 await self .connection_pool .disconnect ()
598+ if self .client_cache :
599+ self .client_cache .flush ()
571600
572601 @deprecated_function (version = "5.0.1" , reason = "Use aclose() instead" , name = "close" )
573602 async def close (self , close_connection_pool : Optional [bool ] = None ) -> None :
@@ -596,29 +625,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
596625 ):
597626 raise error
598627
628+ def _cache_invalidation_process (
629+ self , data : List [Union [str , Optional [List [str ]]]]
630+ ) -> None :
631+ """
632+ Invalidate (delete) all redis commands associated with a specific key.
633+ `data` is a list of strings, where the first string is the invalidation message
634+ and the second string is the list of keys to invalidate.
635+ (if the list of keys is None, then all keys are invalidated)
636+ """
637+ if data [1 ] is not None :
638+ for key in data [1 ]:
639+ self .client_cache .invalidate (str_if_bytes (key ))
640+ else :
641+ self .client_cache .flush ()
642+
643+ async def _get_from_local_cache (self , command : str ):
644+ """
645+ If the command is in the local cache, return the response
646+ """
647+ if (
648+ self .client_cache is None
649+ or command [0 ] in self .cache_blacklist
650+ or command [0 ] not in self .cache_whitelist
651+ ):
652+ return None
653+ while not self .connection ._is_socket_empty ():
654+ await self .connection .read_response (push_request = True )
655+ return self .client_cache .get (command )
656+
657+ def _add_to_local_cache (
658+ self , command : Tuple [str ], response : ResponseT , keys : List [KeysT ]
659+ ):
660+ """
661+ Add the command and response to the local cache if the command
662+ is allowed to be cached
663+ """
664+ if (
665+ self .client_cache is not None
666+ and (self .cache_blacklist == [] or command [0 ] not in self .cache_blacklist )
667+ and (self .cache_whitelist == [] or command [0 ] in self .cache_whitelist )
668+ ):
669+ self .client_cache .set (command , response , keys )
670+
671+ def delete_from_local_cache (self , command : str ):
672+ """
673+ Delete the command from the local cache
674+ """
675+ try :
676+ self .client_cache .delete (command )
677+ except AttributeError :
678+ pass
679+
599680 # COMMAND EXECUTION AND PROTOCOL PARSING
600681 async def execute_command (self , * args , ** options ):
601682 """Execute a command and return a parsed response"""
602683 await self .initialize ()
603- options .pop ("keys" , None ) # the keys are used only for client side caching
604- pool = self .connection_pool
605684 command_name = args [0 ]
606- conn = self .connection or await pool .get_connection (command_name , ** options )
685+ keys = options .pop ("keys" , None ) # keys are used only for client side caching
686+ response_from_cache = await self ._get_from_local_cache (args )
687+ if response_from_cache is not None :
688+ return response_from_cache
689+ else :
690+ pool = self .connection_pool
691+ conn = self .connection or await pool .get_connection (command_name , ** options )
607692
608- if self .single_connection_client :
609- await self ._single_conn_lock .acquire ()
610- try :
611- return await conn .retry .call_with_retry (
612- lambda : self ._send_command_parse_response (
613- conn , command_name , * args , ** options
614- ),
615- lambda error : self ._disconnect_raise (conn , error ),
616- )
617- finally :
618693 if self .single_connection_client :
619- self ._single_conn_lock .release ()
620- if not self .connection :
621- await pool .release (conn )
694+ await self ._single_conn_lock .acquire ()
695+ try :
696+ if self .client_cache is not None and not self .client_cache_initialized :
697+ await conn .retry .call_with_retry (
698+ lambda : self ._send_command_parse_response (
699+ conn , "CLIENT" , * ("CLIENT" , "TRACKING" , "ON" )
700+ ),
701+ lambda error : self ._disconnect_raise (conn , error ),
702+ )
703+ self .client_cache_initialized = True
704+ response = await conn .retry .call_with_retry (
705+ lambda : self ._send_command_parse_response (
706+ conn , command_name , * args , ** options
707+ ),
708+ lambda error : self ._disconnect_raise (conn , error ),
709+ )
710+ self ._add_to_local_cache (args , response , keys )
711+ return response
712+ finally :
713+ if self .single_connection_client :
714+ self ._single_conn_lock .release ()
715+ if not self .connection :
716+ await pool .release (conn )
622717
623718 async def parse_response (
624719 self , connection : Connection , command_name : Union [str , bytes ], ** options
@@ -866,7 +961,7 @@ async def connect(self):
866961 else :
867962 await self .connection .connect ()
868963 if self .push_handler_func is not None and not HIREDIS_AVAILABLE :
869- self .connection ._parser .set_push_handler (self .push_handler_func )
964+ self .connection ._parser .set_pubsub_push_handler (self .push_handler_func )
870965
871966 async def _disconnect_raise_connect (self , conn , error ):
872967 """
0 commit comments