@@ -826,7 +826,7 @@ async def on_connect(self) -> None:
826826 if str_if_bytes (await self .read_response ()) != "OK" :
827827 raise ConnectionError ("Invalid Database" )
828828
829- async def disconnect (self ) -> None :
829+ async def disconnect (self , nowait : bool = False ) -> None :
830830 """Disconnects from the Redis server"""
831831 try :
832832 async with async_timeout .timeout (self .socket_connect_timeout ):
@@ -836,8 +836,9 @@ async def disconnect(self) -> None:
836836 try :
837837 if os .getpid () == self .pid :
838838 self ._writer .close () # type: ignore[union-attr]
839- # py3.6 doesn't have this method
840- if hasattr (self ._writer , "wait_closed" ):
839+ # wait for close to finish, except when handling errors and
840+ # forcecully disconnecting.
841+ if not nowait :
841842 await self ._writer .wait_closed () # type: ignore[union-attr]
842843 except OSError :
843844 pass
@@ -938,10 +939,10 @@ async def read_response(self, disable_decoding: bool = False):
938939 disable_decoding = disable_decoding
939940 )
940941 except asyncio .TimeoutError :
941- await self .disconnect ()
942+ await self .disconnect (nowait = True )
942943 raise TimeoutError (f"Timeout reading from { self .host } :{ self .port } " )
943944 except OSError as e :
944- await self .disconnect ()
945+ await self .disconnect (nowait = True )
945946 raise ConnectionError (
946947 f"Error while reading from { self .host } :{ self .port } : { e .args } "
947948 )
@@ -950,7 +951,7 @@ async def read_response(self, disable_decoding: bool = False):
950951 # is subclass of Exception, not BaseException
951952 raise
952953 except Exception :
953- await self .disconnect ()
954+ await self .disconnect (nowait = True )
954955 raise
955956
956957 if self .health_check_interval :
0 commit comments