11import asyncio
2- from contextlib import contextmanager
2+ from contextlib import contextmanager , suppress
33from typing import Any , AsyncGenerator , Mapping , Optional
44from uuid import uuid4
55
1111 ParallelConfig , SchedulerConfig )
1212# yapf: disable
1313from vllm .entrypoints .openai .rpc import (RPC_REQUEST_TYPE ,
14- VLLM_RPC_HEALTH_TIMEOUT_MS ,
15- VLLM_RPC_SERVER_START_TIMEOUT_MS ,
1614 VLLM_RPC_SOCKET_LIMIT_CUTOFF ,
1715 VLLM_RPC_SUCCESS_STR ,
1816 VLLM_RPC_ZMQ_HWM , RPCAbortRequest ,
1917 RPCGenerateRequest , RPCUtilityRequest )
2018# yapf: enable
19+ from vllm .envs import VLLM_RPC_GET_DATA_TIMEOUT_MS
2120from vllm .inputs import PromptInputs
2221from vllm .logger import init_logger
2322from vllm .lora .request import LoRARequest
3231INPROC_PROXY_PATH = f"inproc://{ uuid4 ()} "
3332
3433
34+ class RPCClientClosedError (Exception ):
35+ """Exception class raised when the client is used post-close.
36+
37+ The client can be closed, which closes the ZMQ context. This normally
38+ happens on server shutdown. In some cases, methods like abort and
39+ do_log_stats will still be called and then try to open a socket, which
40+ causes a ZMQError and creates a huge stack trace.
41+ So, we throw this error such that we can suppress it.
42+ """
43+
44+
3545class AsyncEngineRPCClient :
3646 """
3747 RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
@@ -85,6 +95,8 @@ class AsyncEngineRPCClient:
8595
8696 def __init__ (self , rpc_path : str ):
8797 self .context = zmq .asyncio .Context ()
98+ self ._data_timeout = VLLM_RPC_GET_DATA_TIMEOUT_MS
99+ self ._errored = False
88100
89101 # Maximum number of sockets that can be opened (typically 65536).
90102 # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
@@ -143,7 +155,6 @@ async def setup(self):
143155
144156 # Wait until server is ready.
145157 await self ._wait_for_server_rpc ()
146- self ._errored = False
147158
148159 # Get the configs.
149160 self .model_config = await self ._get_model_config_rpc ()
@@ -170,6 +181,15 @@ def close(self):
170181 @contextmanager
171182 def to_proxy_socket (self ):
172183 # Connect to the RPCServer via the proxy.
184+
185+ # Raise a sensible error if the client was already closed.
186+ # This can happen if a server shutdown is triggered but some coroutines
187+ # are still running requests.
188+ # There should not be a race condition with this check because we don't
189+ # yield to the event loop between here and opening the socket.
190+ if self .context .closed :
191+ raise RPCClientClosedError ("The ZMQ client has already shut down" )
192+
173193 # Note that we use DEALER to enable asynchronous communication
174194 # to enable streaming.
175195 socket = self .context .socket (zmq .constants .DEALER )
@@ -189,9 +209,18 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
189209 # Ping RPCServer with a request.
190210 await socket .send_multipart ([cloudpickle .dumps (request )])
191211
212+ # Make sure the server responds
213+ if await socket .poll (timeout = self ._data_timeout ) == 0 :
214+ raise TimeoutError ("Server didn't reply within "
215+ f"{ self ._data_timeout } ms" )
216+
192217 # Await the data from the Server.
193218 data = cloudpickle .loads (await socket .recv ())
194219
220+ if isinstance (data , Exception ):
221+ # Re-raise exceptions returned by the server
222+ raise data
223+
195224 if not isinstance (data , expected_type ):
196225 # LoRAConfig can be None.
197226 if expected_type == LoRAConfig and data is None :
@@ -208,29 +237,28 @@ async def _send_one_way_rpc_request(
208237 self ,
209238 request : RPC_REQUEST_TYPE ,
210239 error_message : str ,
211- timeout : Optional [int ] = None ,
212240 socket : Optional [zmq .asyncio .Socket ] = None ):
213241 """Send one-way RPC request to trigger an action."""
214242
215243 async def do_rpc_call (socket : zmq .asyncio .Socket ,
216- request : RPC_REQUEST_TYPE ,
217- timeout = None ):
244+ request : RPC_REQUEST_TYPE ):
218245
219246 await socket .send_multipart ([cloudpickle .dumps (request )])
220247
221- if timeout is not None and await socket .poll (timeout = timeout ) == 0 :
222- raise TimeoutError (f"Server didn't reply within { timeout } ms" )
248+ if await socket .poll (timeout = self ._data_timeout ) == 0 :
249+ raise TimeoutError ("Server didn't reply within "
250+ f"{ self ._data_timeout } ms" )
223251
224252 return cloudpickle .loads (await socket .recv ())
225253
226254 # Make a new socket connection.
227255 if socket is None :
228256 with self .to_proxy_socket () as socket :
229- response = await do_rpc_call (socket , request , timeout )
257+ response = await do_rpc_call (socket , request )
230258
231259 # Use existing socket connection.
232260 else :
233- response = await do_rpc_call (socket , request , timeout )
261+ response = await do_rpc_call (socket , request )
234262
235263 if not isinstance (response , str ) or response != VLLM_RPC_SUCCESS_STR :
236264 if isinstance (response , Exception ):
@@ -255,8 +283,7 @@ async def _wait_for_server_rpc(self):
255283
256284 await self ._send_one_way_rpc_request (
257285 request = RPCUtilityRequest .IS_SERVER_READY ,
258- error_message = "Unable to start RPC Server" ,
259- timeout = VLLM_RPC_SERVER_START_TIMEOUT_MS )
286+ error_message = "Unable to start RPC Server" )
260287
261288 async def _get_model_config_rpc (self ) -> ModelConfig :
262289 """Get the ModelConfig object from the RPC Server"""
@@ -308,17 +335,17 @@ async def _is_tracing_enabled_rpc(self) -> bool:
308335
309336 async def abort (self , request_id : str ):
310337 """Send an ABORT_REQUEST signal to the RPC Server"""
311-
312- await self ._send_one_way_rpc_request (
313- request = RPCAbortRequest (request_id ),
314- error_message = f"RPCAbortRequest { request_id } failed" )
338+ with suppress ( RPCClientClosedError ):
339+ await self ._send_one_way_rpc_request (
340+ request = RPCAbortRequest (request_id ),
341+ error_message = f"RPCAbortRequest { request_id } failed" )
315342
316343 async def do_log_stats (self ):
317344 """Send a DO_LOG_STATS signal to the RPC Server"""
318-
319- await self ._send_one_way_rpc_request (
320- request = RPCUtilityRequest .DO_LOG_STATS ,
321- error_message = "RPCRequest DO_LOG_STATS failed." )
345+ with suppress ( RPCClientClosedError ):
346+ await self ._send_one_way_rpc_request (
347+ request = RPCUtilityRequest .DO_LOG_STATS ,
348+ error_message = "RPCRequest DO_LOG_STATS failed." )
322349
323350 @property
324351 def is_running (self ) -> bool :
@@ -393,7 +420,6 @@ async def check_health(self,
393420 await self ._send_one_way_rpc_request (
394421 request = RPCUtilityRequest .IS_SERVER_HEALTHY ,
395422 error_message = "Got Unhealthy response from RPC Server" ,
396- timeout = VLLM_RPC_HEALTH_TIMEOUT_MS ,
397423 socket = socket )
398424
399425 async def encode (self , * args ,
0 commit comments