88import uuid
99import weakref
1010from abc import ABC , abstractmethod
11- from collections .abc import Awaitable , Sequence
11+ from collections .abc import Awaitable
1212from concurrent .futures import Future
1313from dataclasses import dataclass , field
1414from threading import Thread
@@ -243,15 +243,12 @@ def __init__(
243243 vllm_config : VllmConfig ,
244244 executor_class : type [Executor ],
245245 log_stats : bool ,
246- ctx : Union [ zmq . Context , zmq . asyncio . Context ] ,
246+ input_path : str ,
247247 output_path : str ,
248248 index : int = 0 ,
249249 local_dp_rank : int = 0 ,
250250 ):
251- # Paths and sockets for IPC.
252- input_path = get_open_zmq_ipc_path ()
253- self .input_socket = make_zmq_socket (ctx , input_path ,
254- zmq .constants .PUSH )
251+ self .identity = index .to_bytes (length = 2 , byteorder = "little" )
255252 try :
256253 # Start EngineCore in background process.
257254 self .proc_handle = BackgroundProcHandle (
@@ -273,14 +270,9 @@ def __init__(
273270 # Ensure socket is closed if process fails to start.
274271 self .close ()
275272
276- def send_multipart (self , msg_parts : Sequence ):
277- return self .input_socket .send_multipart (msg_parts , copy = False )
278-
279273 def close (self ):
280274 if proc_handle := getattr (self , "proc_handle" , None ):
281275 proc_handle .shutdown ()
282- if socket := getattr (self , "input_socket" , None ):
283- socket .close (linger = 0 )
284276
285277
286278@dataclass
@@ -291,6 +283,7 @@ class BackgroundResources:
291283 ctx : Union [zmq .Context ]
292284 core_engines : list [CoreEngine ] = field (default_factory = list )
293285 output_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
286+ input_socket : Optional [Union [zmq .Socket , zmq .asyncio .Socket ]] = None
294287 shutdown_path : Optional [str ] = None
295288
296289 def __call__ (self ):
@@ -303,6 +296,8 @@ def __call__(self):
303296 # aren't explicitly closed first.
304297 if self .output_socket is not None :
305298 self .output_socket .close (linger = 0 )
299+ if self .input_socket is not None :
300+ self .input_socket .close (linger = 0 )
306301 if self .shutdown_path is not None :
307302 # We must ensure that the sync output socket is
308303 # closed cleanly in its own thread.
@@ -369,10 +364,16 @@ def sigusr1_handler(signum, frame):
369364
370365 # Paths and sockets for IPC.
371366 self .output_path = get_open_zmq_ipc_path ()
367+ input_path = get_open_zmq_ipc_path ()
368+ self .input_socket = make_zmq_socket (self .ctx ,
369+ input_path ,
370+ zmq .ROUTER ,
371+ bind = True )
372+ self .resources .input_socket = self .input_socket
372373
373374 new_core_engine = lambda index , local_dp_rank = None : CoreEngine (
374- vllm_config , executor_class , log_stats , self . ctx , self .output_path ,
375- index , local_dp_rank )
375+ vllm_config , executor_class , log_stats , input_path , self .
376+ output_path , index , local_dp_rank )
376377
377378 # Start engine core process(es).
378379 self ._init_core_engines (vllm_config , new_core_engine ,
@@ -476,9 +477,10 @@ def get_output(self) -> EngineCoreOutputs:
476477 return self .outputs_queue .get ()
477478
478479 def _send_input (self , request_type : EngineCoreRequestType , request : Any ):
479- # (RequestType, SerializedRequest)
480- msg = (request_type .value , self .encoder .encode (request ))
481- self .core_engine .send_multipart (msg )
480+ # (Identity, RequestType, SerializedRequest)
481+ msg = (self .core_engine .identity , request_type .value ,
482+ self .encoder .encode (request ))
483+ self .input_socket .send_multipart (msg , copy = False )
482484
483485 def call_utility (self , method : str , * args ) -> Any :
484486 call_id = uuid .uuid1 ().int >> 64
@@ -601,30 +603,34 @@ async def get_output_async(self) -> EngineCoreOutputs:
601603 assert self .outputs_queue is not None
602604 return await self .outputs_queue .get ()
603605
604- async def _send_input (self , request_type : EngineCoreRequestType ,
605- request : Any ) -> None :
606- await self .core_engine .send_multipart (
607- (request_type .value , self .encoder .encode (request )))
606+ def _send_input (self ,
607+ request_type : EngineCoreRequestType ,
608+ request : Any ,
609+ engine : Optional [CoreEngine ] = None ) -> Awaitable [None ]:
610+ if engine is None :
611+ engine = self .core_engine
608612
609- self ._ensure_output_queue_task ()
613+ message = (request_type .value , self .encoder .encode (request ))
614+ return self ._send_input_message (message , engine )
615+
616+ def _send_input_message (self , message : tuple [bytes , bytes ],
617+ engine : CoreEngine ) -> Awaitable [None ]:
618+ message = (engine .identity , ) + message # type: ignore[assignment]
619+ return self .input_socket .send_multipart (message , copy = False )
610620
611621 async def call_utility_async (self , method : str , * args ) -> Any :
612622 return await self ._call_utility_async (method ,
613623 * args ,
614624 engine = self .core_engine )
615625
616- async def _call_utility_async (
617- self ,
618- method : str ,
619- * args ,
620- engine : CoreEngine ,
621- ) -> Any :
626+ async def _call_utility_async (self , method : str , * args ,
627+ engine : CoreEngine ) -> Any :
622628 call_id = uuid .uuid1 ().int >> 64
623629 future = asyncio .get_running_loop ().create_future ()
624630 self .utility_results [call_id ] = future
625631 message = (EngineCoreRequestType .UTILITY .value ,
626632 self .encoder .encode ((call_id , method , args )))
627- await engine . send_multipart (message )
633+ await self . _send_input_message (message , engine )
628634 self ._ensure_output_queue_task ()
629635 return await future
630636
@@ -633,6 +639,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
633639 # tokenized.
634640 request .prompt = None
635641 await self ._send_input (EngineCoreRequestType .ADD , request )
642+ self ._ensure_output_queue_task ()
636643
637644 async def abort_requests_async (self , request_ids : list [str ]) -> None :
638645 if len (request_ids ) > 0 :
@@ -730,15 +737,15 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
730737 self .reqs_in_flight [request .request_id ] = chosen_engine
731738 chosen_engine .num_reqs_in_flight += 1
732739 if self .num_engines_running >= len (self .core_engines ):
733- await chosen_engine . send_multipart (msg )
740+ await self . _send_input_message (msg , chosen_engine )
734741 else :
735742 # Send request to chosen engine and dp start loop
736743 # control message to all other engines.
737744 self .num_engines_running += len (self .core_engines )
738745 await asyncio .gather (* [
739- engine . send_multipart ( msg if engine is
740- chosen_engine else self .start_dp_msg )
741- for engine in self .core_engines
746+ self . _send_input_message (
747+ msg if engine is chosen_engine else self .start_dp_msg ,
748+ engine ) for engine in self .core_engines
742749 ])
743750
744751 self ._ensure_output_queue_task ()
@@ -763,7 +770,7 @@ async def process_engine_outputs(self: "DPAsyncMPClient",
763770 # sure to start the other engines:
764771 self .num_engines_running = len (self .core_engines )
765772 coros = [
766- engine . send_multipart (self .start_dp_msg )
773+ self . _send_input_message (self .start_dp_msg , engine )
767774 for engine in self .core_engines
768775 if not engine .num_reqs_in_flight
769776 ]
@@ -789,5 +796,5 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
789796
790797 async def _abort_requests (self , request_ids : list [str ],
791798 engine : CoreEngine ) -> None :
792- await engine . send_multipart (( EngineCoreRequestType .ABORT . value ,
793- self . encoder . encode ( request_ids )) )
799+ await self . _send_input ( EngineCoreRequestType .ABORT , request_ids ,
800+ engine )
0 commit comments