@@ -815,7 +815,7 @@ async def on_connect(self) -> None:
815815 if str_if_bytes (await self .read_response ()) != "OK" :
816816 raise ConnectionError ("Invalid Database" )
817817
818- async def disconnect (self ) -> None :
818+ async def disconnect (self , nowait : bool = False ) -> None :
819819 """Disconnects from the Redis server"""
820820 try :
821821 async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -825,8 +825,9 @@ async def disconnect(self) -> None:
825825 try :
826826 if os .getpid () == self .pid :
827827 self ._writer .close () # type: ignore[union-attr]
828- # py3.6 doesn't have this method
829- if hasattr (self ._writer , "wait_closed" ):
828+ # wait for close to finish, except when handling errors and
829+ # forcecully disconnecting.
830+ if not nowait :
830831 await self ._writer .wait_closed () # type: ignore[union-attr]
831832 except OSError :
832833 pass
@@ -927,10 +928,10 @@ async def read_response(self, disable_decoding: bool = False):
927928 disable_decoding = disable_decoding
928929 )
929930 except asyncio .TimeoutError :
930- await self .disconnect ()
931+ await self .disconnect (nowait = True )
931932 raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
932933 except OSError as e :
933- await self .disconnect ()
934+ await self .disconnect (nowait = True )
934935 raise ConnectionError (
935936 f"Error while reading from { self .host } :{ self .port } : { e .args } "
936937 )
@@ -939,7 +940,7 @@ async def read_response(self, disable_decoding: bool = False):
939940 # is subclass of Exception, not BaseException
940941 raise
941942 except Exception :
942- await self .disconnect ()
943+ await self .disconnect (nowait = True )
943944 raise
944945
945946 if self .health_check_interval :
0 commit comments