Centralize asyncio multiprocess client selection in a single factory.

`__init__` in vllm/v1/engine/async_llm.py and `make_client` in
vllm/v1/engine/core_client.py each used to choose between AsyncMPClient,
DPAsyncMPClient and RayDPClient on their own. Both now delegate to the new
static method EngineCoreClient.make_async_mp_client, which additionally
accepts `client_addresses` and `client_index` and passes them on to the
client it constructs.

diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 4b235c596ed6..607d15d6f7d9 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -27,8 +27,7 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import (AsyncMPClient, DPAsyncMPClient,
-                                        RayDPClient)
+from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.output_processor import (OutputProcessor,
                                              RequestOutputCollector)
@@ -120,15 +119,8 @@ def __init__(
                                                 log_stats=self.log_stats)
 
         # EngineCore (starts the engine in background process).
-        core_client_class: type[AsyncMPClient]
-        if vllm_config.parallel_config.data_parallel_size == 1:
-            core_client_class = AsyncMPClient
-        elif vllm_config.parallel_config.data_parallel_backend == "ray":
-            core_client_class = RayDPClient
-        else:
-            core_client_class = DPAsyncMPClient
-
-        self.engine_core = core_client_class(
+
+        self.engine_core = EngineCoreClient.make_async_mp_client(
             vllm_config=vllm_config,
             executor_class=executor_class,
             log_stats=self.log_stats,
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index fa01998aa9fe..c89ba032a1a2 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -67,18 +67,39 @@ def make_client(
                 "is not currently supported.")
 
         if multiprocess_mode and asyncio_mode:
-            if vllm_config.parallel_config.data_parallel_size > 1:
-                if vllm_config.parallel_config.data_parallel_backend == "ray":
-                    return RayDPClient(vllm_config, executor_class, log_stats)
-                return DPAsyncMPClient(vllm_config, executor_class, log_stats)
-
-            return AsyncMPClient(vllm_config, executor_class, log_stats)
+            return EngineCoreClient.make_async_mp_client(
+                vllm_config, executor_class, log_stats)
 
         if multiprocess_mode and not asyncio_mode:
             return SyncMPClient(vllm_config, executor_class, log_stats)
 
         return InprocClient(vllm_config, executor_class, log_stats)
 
+    @staticmethod
+    def make_async_mp_client(
+        vllm_config: VllmConfig,
+        executor_class: type[Executor],
+        log_stats: bool,
+        client_addresses: Optional[dict[str, str]] = None,
+        client_index: int = 0,
+    ) -> "MPClient":
+        """Create the asyncio multiprocess client matching the config.
+
+        Picks ``AsyncMPClient`` when ``data_parallel_size`` is at most 1,
+        ``RayDPClient`` when the data-parallel backend is ``"ray"``, and
+        ``DPAsyncMPClient`` otherwise. All five arguments are passed
+        positionally, in order, to the chosen client's constructor.
+        """
+        parallel_config = vllm_config.parallel_config
+        if parallel_config.data_parallel_size <= 1:
+            client_cls = AsyncMPClient
+        elif parallel_config.data_parallel_backend == "ray":
+            client_cls = RayDPClient
+        else:
+            client_cls = DPAsyncMPClient
+        return client_cls(vllm_config, executor_class, log_stats,
+                          client_addresses, client_index)
+
     @abstractmethod
     def shutdown(self):
         ...