diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 48e2e31e5db8..b6f44871497c 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -41,7 +41,7 @@ def __init__(self): self.abort_request_calls = 0 self.request_id = None # Ugly, remove dependency when possible - self.parallel_config = ParallelConfig(1, 1, False) + self.parallel_config = ParallelConfig() self.model_config = MockModelConfig() async def step_async(self, virtual_engine): diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index fd8d1fd7ff48..452fe1e37e2c 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -18,9 +18,10 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore -from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine, - EngineCoreClient, SyncMPClient) +from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, + SyncMPClient) from vllm.v1.executor.abstract import Executor +from vllm.v1.utils import CoreEngineProcManager from ...distributed.conftest import MockSubscriber from ...utils import create_new_process_for_each_test @@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch): # Monkey-patch to extract core process pid while it's starting. core_proc_pid = [None] - ce_ctor = CoreEngine.__init__ + cepm_ctor = CoreEngineProcManager.__init__ - def patched_ce_ctor(self, *args, **kwargs): - ce_ctor(self, *args, **kwargs) - core_proc_pid[0] = self.proc_handle.proc.pid + def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs): + cepm_ctor(self, *args, **kwargs) + core_proc_pid[0] = self.processes[0].pid - m.setattr(CoreEngine, "__init__", patched_ce_ctor) + m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor) t = time.time() engine_args = EngineArgs(model=MODEL_NAME) diff --git a/vllm/config.py b/vllm/config.py index dd0791537b96..d8eabfb2e4f0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1668,25 +1668,17 @@ class ParallelConfig: data_parallel_size: int = 1 """Number of data parallel groups. 
MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.""" + data_parallel_size_local: int = 1 + """Number of local data parallel groups.""" data_parallel_rank: int = 0 """Rank of the data parallel group.""" - _data_parallel_rank_local: Optional[int] = field(default=None, init=False) - """Private field to store the local rank of the data parallel group.""" - - @property - def data_parallel_rank_local(self) -> int: - """Local rank of the data parallel group, defaults to global rank.""" - if self._data_parallel_rank_local is None: - return self.data_parallel_rank - return self._data_parallel_rank_local - - @data_parallel_rank_local.setter - def data_parallel_rank_local(self, value: int) -> None: - """Set the local rank of the data parallel group.""" - self._data_parallel_rank_local = value - + data_parallel_rank_local: Optional[int] = None + """Local rank of the data parallel group, + set only in SPMD mode.""" data_parallel_master_ip: str = "127.0.0.1" """IP of the data parallel master.""" + data_parallel_rpc_port: int = 29550 + """Port for data parallel messaging.""" data_parallel_master_port: int = 29500 """Port of the data parallel master.""" enable_expert_parallel: bool = False @@ -1734,13 +1726,16 @@ class is dynamically inherited by the worker class. This is used to inject world_size: int = field(init=False) """world_size is TPxPP, it affects the number of workers we create.""" - world_size_across_dp: int = field(init=False) - """world_size_across_dp is TPxPPxDP, it is the size of the world - including data parallelism.""" rank: int = 0 """Global rank in distributed setup.""" + @property + def world_size_across_dp(self) -> int: + """world_size_across_dp is TPxPPxDP, it is the size of the world + including data parallelism.""" + return self.world_size * self.data_parallel_size + def get_next_dp_init_port(self) -> int: """ We might need to initialize process groups in multiple @@ -1800,10 +1795,14 @@ def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size - if self.data_parallel_size > 1: + if self.data_parallel_size_local > self.data_parallel_size: + raise ValueError( + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})") + + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. self.data_parallel_master_port = get_open_port() - # TODO multi-node else: # Otherwise fall back to env vars (e.g. for offline SPMD case). self.data_parallel_size = envs.VLLM_DP_SIZE @@ -1812,8 +1811,6 @@ def __post_init__(self) -> None: self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT - self.world_size_across_dp = self.world_size * self.data_parallel_size - if self.distributed_executor_backend == "external_launcher": import os os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index e4d4008cd0a6..a8f292c6e31f 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -22,6 +22,7 @@ import vllm.envs as envs from vllm.logger import init_logger +from vllm.utils import get_tcp_uri logger = init_logger(__name__) @@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group( always formed with process 1, 2, ..., 8, and the additional communication channel is formed with process 9 and 10. 
""" - init_method = f"tcp://{host}:{port}" + init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string timeout = _get_default_timeout(backend) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bba05c4c3e1b..240142a1c5d1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -283,6 +283,9 @@ class EngineArgs: pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size + data_parallel_size_local: Optional[int] = None + data_parallel_address: Optional[str] = None + data_parallel_rpc_port: Optional[int] = None enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel max_parallel_loading_workers: Optional[ int] = ParallelConfig.max_parallel_loading_workers @@ -596,6 +599,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **parallel_kwargs["tensor_parallel_size"]) parallel_group.add_argument("--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]) + parallel_group.add_argument('--data-parallel-size-local', + '-dpl', + type=int, + help='Number of data parallel replicas ' + 'to run on this node.') + parallel_group.add_argument('--data-parallel-address', + '-dpa', + type=str, + help='Address of data parallel cluster ' + 'head-node.') + parallel_group.add_argument('--data-parallel-rpc-port', + '-dpp', + type=int, + help='Port for data parallel RPC ' + 'communication.') parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) @@ -1019,10 +1037,30 @@ def create_engine_config( # but we should not do this here. placement_group = ray.util.get_current_placement_group() + # Local DP size defaults to global DP size if not set. + data_parallel_size_local = self.data_parallel_size if ( + self.data_parallel_size_local + is None) else self.data_parallel_size_local + + # DP address, used in multi-node case for torch distributed group + # and ZMQ sockets. + data_parallel_address = self.data_parallel_address if ( + self.data_parallel_address + is not None) else ParallelConfig.data_parallel_master_ip + + # This port is only used when there are remote data parallel engines, + # otherwise the local IPC transport is used. 
+ data_parallel_rpc_port = self.data_parallel_rpc_port if ( + self.data_parallel_rpc_port + is not None) else ParallelConfig.data_parallel_rpc_port + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, + data_parallel_size_local=data_parallel_size_local, + data_parallel_master_ip=data_parallel_address, + data_parallel_rpc_port=data_parallel_rpc_port, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 5c8781b50d2c..04be7c033998 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,14 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import signal import uvloop +import vllm.envs as envs +from vllm import AsyncEngineArgs from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) -from vllm.utils import FlexibleArgumentParser +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, get_tcp_uri +from vllm.v1.engine.core import EngineCoreProc +from vllm.v1.engine.core_client import CoreEngineProcManager +from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) class ServeSubcommand(CLISubcommand): @@ -24,7 +34,10 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - uvloop.run(run_server(args)) + if args.headless: + run_headless(args) + else: + uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: validate_parsed_serve_args(args) @@ -42,6 +55,18 @@ def subparser_init( nargs='?', help="The model tag to serve " "(optional if specified in config)") + serve_parser.add_argument( + "--headless", + action='store_true', + default=False, + help="Run in headless mode. See multi-node data parallel " + "documentation for more details.") + serve_parser.add_argument( + '--data-parallel-start-rank', + '-dpr', + type=int, + default=0, + help='Starting data parallel rank for secondary nodes.') serve_parser.add_argument( "--config", type=str, @@ -57,3 +82,55 @@ def subparser_init( def cmd_init() -> list[CLISubcommand]: return [ServeSubcommand()] + + +def run_headless(args: argparse.Namespace): + + # Create the EngineConfig. + engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + + if not envs.VLLM_USE_V1: + raise RuntimeError("Headless mode is only supported for V1") + + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + port = engine_args.data_parallel_rpc_port # add to config too + input_address = get_tcp_uri(host, port) + + if local_engine_count <= 0: + raise RuntimeError("data_parallel_size_local must be > 0 in " + "headless mode") + + # Catch SIGTERM and SIGINT to allow graceful shutdown. 
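The headless path below pairs with a front-end node running the regular server. As a usage sketch (hedged: $MODEL and 10.0.0.1 are placeholders, and this pairing of flags is inferred from the arguments added in this diff), a two-node DP=4 deployment would look like:

```
# Head node: API server plus two local engines (DP ranks 0 and 1).
vllm serve $MODEL -dp 4 -dpl 2 -dpa 10.0.0.1 -dpp 29550

# Second node: headless, two engines starting at DP rank 2.
vllm serve $MODEL --headless -dp 4 -dpl 2 -dpr 2 -dpa 10.0.0.1 -dpp 29550
```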
+ def signal_handler(signum, frame): + logger.debug("Received %d signal.", signum) + raise SystemExit + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + logger.info( + "Launching %d data parallel engine(s) in headless mode, " + "with head node address %s.", local_engine_count, input_address) + + # Create the engines. + engine_manager = CoreEngineProcManager( + target_fn=EngineCoreProc.run_engine_core, + local_engine_count=local_engine_count, + start_index=args.data_parallel_start_rank, + local_start_index=0, + vllm_config=vllm_config, + on_head_node=False, + input_address=input_address, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + ) + + try: + engine_manager.join_first() + finally: + logger.info("Shutting down.") + engine_manager.close() diff --git a/vllm/utils.py b/vllm/utils.py index 59635a25eb32..9a7da8067ba4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool: def get_distributed_init_method(ip: str, port: int) -> str: + return get_tcp_uri(ip, port) + + +def get_tcp_uri(ip: str, port: int) -> str: # Brackets are not permitted in ipv4 addresses, # see https://github.com/python/cpython/issues/103848 return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index bc410befbdad..edc79ae20b9f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -import json import os import queue import signal @@ -23,7 +22,7 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import resolve_obj_by_qualname, zmq_socket_ctx +from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -43,6 +42,7 @@ logger = init_logger(__name__) POLLING_TIMEOUT_S = 2.5 +HANDSHAKE_TIMEOUT_MINS = 5 _R = TypeVar('_R') # Return type for collective_rpc @@ -348,9 +348,9 @@ class EngineCoreProc(EngineCore): def __init__( self, - input_path: str, - output_path: str, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -360,28 +360,91 @@ def __init__( executor_fail_callback = lambda: input_queue.put_nowait( (EngineCoreRequestType.EXECUTOR_FAILED, b'')) - super().__init__(vllm_config, executor_class, log_stats, - executor_fail_callback) - - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - self.engines_running = False - - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue = input_queue - self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() - threading.Thread(target=self.process_input_socket, - args=(input_path, engine_index), - daemon=True).start() - self.output_thread = threading.Thread( - target=self.process_output_socket, - args=(output_path, engine_index), - daemon=True) - self.output_thread.start() + # Create input socket. 
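The socket created next identifies itself to the front-end's ROUTER with a fixed two-byte identity derived from the engine's DP rank; a minimal sketch of that encoding (values illustrative):

```python
# A DP rank is encoded as a 2-byte little-endian ZMQ routing identity,
# so the front-end ROUTER can address frames to a specific engine.
engine_index = 3
identity = engine_index.to_bytes(length=2, byteorder="little")
assert identity == b"\x03\x00"
assert int.from_bytes(identity, byteorder="little") == engine_index
```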
+ input_ctx = zmq.Context() + identity = engine_index.to_bytes(length=2, byteorder="little") + input_socket = make_zmq_socket(input_ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False) + try: + # Register engine with front-end. + output_address = self.startup_handshake( + input_socket, on_head_node, vllm_config.parallel_config) + + # Update config which may have changed from the handshake. + vllm_config.__post_init__() + + # Set up data parallel environment. + self._init_data_parallel(vllm_config) + + # Initialize engine core and model. + super().__init__(vllm_config, executor_class, log_stats, + executor_fail_callback) + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + self.engines_running = False + + # Send ready message. + num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks + input_socket.send( + msgspec.msgpack.encode({ + "status": "READY", + "local": on_head_node, + "num_gpu_blocks": num_gpu_blocks, + })) + + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + self.input_queue = input_queue + self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]() + threading.Thread(target=self.process_input_socket, + args=(input_socket, ), + daemon=True).start() + input_socket = None + self.output_thread = threading.Thread( + target=self.process_output_socket, + args=(output_address, engine_index), + daemon=True) + self.output_thread.start() + finally: + if input_socket is not None: + input_socket.close(linger=0) + + @staticmethod + def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, + parallel_config: ParallelConfig) -> str: + + # Send registration message. + input_socket.send( + msgspec.msgpack.encode({ + "status": "HELLO", + "local": on_head_node, + })) + + # Receive initialization message. + logger.info("Waiting for init message from front-end.") + if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): + raise RuntimeError("Did not receive response from front-end " + f"process within {HANDSHAKE_TIMEOUT_MINS} " + f"minutes") + init_bytes = input_socket.recv() + init_message = msgspec.msgpack.decode(init_bytes) + logger.debug("Received init message: %s", init_message) + + output_socket_address = init_message["output_socket_address"] + # TBD(nick) maybe replace IP with configured head node address + + received_parallel_config = init_message["parallel_config"] + for key, value in received_parallel_config.items(): + setattr(parallel_config, key, value) + + return output_socket_address @staticmethod def run_engine_core(*args, @@ -412,7 +475,7 @@ def signal_handler(signum, frame): try: parallel_config: ParallelConfig = kwargs[ "vllm_config"].parallel_config - if parallel_config.data_parallel_size > 1: + if parallel_config.data_parallel_size > 1 or dp_rank > 0: # Set data parallel rank for this engine process. parallel_config.data_parallel_rank = dp_rank parallel_config.data_parallel_rank_local = local_dp_rank @@ -436,6 +499,9 @@ def signal_handler(signum, frame): if engine_core is not None: engine_core.shutdown() + def _init_data_parallel(self, vllm_config: VllmConfig): + pass + def run_busy_loop(self): """Core busy loop of the EngineCore.""" @@ -527,40 +593,25 @@ def _send_engine_dead(self): logger.fatal("vLLM shutdown signal from EngineCore failed " "to send.
Please report this issue.") - def process_input_socket(self, input_path: str, engine_index: int): + def process_input_socket(self, input_socket: zmq.Socket): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - identity = engine_index.to_bytes(length=2, byteorder="little") - with zmq_socket_ctx(input_path, - zmq.DEALER, - identity=identity, - bind=False) as socket: - - # Send ready message to front-end once input socket is connected. - message_dict = { - 'type': 'READY', - 'num_gpu_blocks': self.vllm_config.cache_config.num_gpu_blocks, - } - message = json.dumps(message_dict).encode('utf-8') - socket.send(message) - - while True: - # (RequestType, RequestData) - type_frame, *data_frames = socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) + while True: + # (RequestType, RequestData) + type_frame, *data_frames = input_socket.recv_multipart(copy=False) + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) - # Deserialize the request data. - decoder = add_request_decoder if ( - request_type - == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frames) + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frames) - # Push to input queue for core busy loop. - self.input_queue.put_nowait((request_type, request)) + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) def process_output_socket(self, output_path: str, engine_index: int): """Output socket IO thread.""" @@ -609,9 +660,9 @@ class DPEngineCoreProc(EngineCoreProc): def __init__( self, - input_path: str, - output_path: str, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -623,8 +674,20 @@ def __init__( _add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stderr, process_name, pid) - dp_size = vllm_config.parallel_config.data_parallel_size + # Counts forward-passes of the model so that we can synchronize + # finished with DP peers every N steps. + self.counter = 0 + + # Initialize the engine. + dp_rank = vllm_config.parallel_config.data_parallel_rank + super().__init__(vllm_config, on_head_node, input_address, + executor_class, log_stats, dp_rank) + + def _init_data_parallel(self, vllm_config: VllmConfig): + + # Configure GPUs and stateless process group for data parallel. dp_rank = vllm_config.parallel_config.data_parallel_rank + dp_size = vllm_config.parallel_config.data_parallel_size local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local assert dp_size > 1 @@ -632,24 +695,16 @@ def __init__( from vllm.platforms import current_platform device_control_env_var = current_platform.device_control_env_var - tp_size = vllm_config.parallel_config.tensor_parallel_size + world_size = vllm_config.parallel_config.world_size os.environ[device_control_env_var] = ",".join( str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * - tp_size)) + for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * + world_size)) self.local_dp_rank = local_dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() self.current_wave = 0 - # Initialize the engine after setting up environment. 
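Note the slicing change above: each local DP rank now reserves world_size (TPxPP) devices rather than tp_size. A worked sketch, assuming CUDA_VISIBLE_DEVICES is the platform's device-control variable and physical IDs map one-to-one:

```python
import os

def assign_devices(local_dp_rank: int, world_size: int) -> str:
    # Local DP rank i owns devices [i * world_size, (i + 1) * world_size).
    devices = ",".join(
        str(i) for i in range(local_dp_rank * world_size,
                              (local_dp_rank + 1) * world_size))
    os.environ["CUDA_VISIBLE_DEVICES"] = devices
    return devices

assert assign_devices(local_dp_rank=0, world_size=2) == "0,1"
assert assign_devices(local_dp_rank=1, world_size=2) == "2,3"
```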
- super().__init__(input_path, output_path, vllm_config, executor_class, - log_stats, dp_rank) - - # Counts forward-passes of the model so that we can synchronize - # finished with DP peers every N steps. - self.counter = 0 - def shutdown(self): super().shutdown() if dp_group := getattr(self, "dp_group", None): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c33317edcbb0..0d52bc9a6814 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio import contextlib -import json import queue import uuid import weakref @@ -9,25 +8,27 @@ from collections import deque from collections.abc import Awaitable, Sequence from concurrent.futures import Future -from dataclasses import dataclass, field +from dataclasses import dataclass +from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union +import msgspec import zmq import zmq.asyncio -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, - make_zmq_socket) +from vllm.utils import (get_open_port, get_open_zmq_inproc_path, + get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr -from vllm.v1.utils import BackgroundProcHandle +from vllm.v1.utils import CoreEngineProcManager logger = init_logger(__name__) @@ -264,45 +265,22 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + class CoreEngine: """One per data parallel rank.""" - def __init__( - self, - vllm_config: VllmConfig, - executor_class: type[Executor], - log_stats: bool, - input_path: str, - output_path: str, - index: int = 0, - local_dp_rank: int = 0, - ): + def __init__(self, index: int = 0, local: bool = True): + self.local = local self.index = index self.identity = index.to_bytes(length=2, byteorder="little") - try: - # Start EngineCore in background process. - self.proc_handle = BackgroundProcHandle( - input_path=input_path, - output_path=output_path, - process_name=f"EngineCore_{index}", - target_fn=EngineCoreProc.run_engine_core, - process_kwargs={ - "vllm_config": vllm_config, - "dp_rank": index, - "local_dp_rank": local_dp_rank, - "executor_class": executor_class, - "log_stats": log_stats, - }) - - self.num_reqs_in_flight = 0 - finally: - if not hasattr(self, "num_reqs_in_flight"): - # Ensure socket is closed if process fails to start. 
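The new CoreEngineState enum above implies a small per-engine handshake state machine that the client enforces during startup (see _wait_for_engine_startup later in this diff); a self-contained sketch of the intended transitions:

```python
from enum import Enum, auto

class CoreEngineState(Enum):
    NEW = auto()
    CONNECTED = auto()
    READY = auto()

# NEW -(HELLO)-> CONNECTED -(READY)-> READY; anything else is an error.
_TRANSITIONS = {
    (CoreEngineState.NEW, "HELLO"): CoreEngineState.CONNECTED,
    (CoreEngineState.CONNECTED, "READY"): CoreEngineState.READY,
}

def advance(state: CoreEngineState, status: str) -> CoreEngineState:
    next_state = _TRANSITIONS.get((state, status))
    if next_state is None:
        raise RuntimeError(f"Unexpected {status} message in state {state}")
    return next_state

state = advance(CoreEngineState.NEW, "HELLO")
assert advance(state, "READY") is CoreEngineState.READY
```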
- self.close() - def close(self): - if proc_handle := getattr(self, "proc_handle", None): - proc_handle.shutdown() + self.state = CoreEngineState.NEW + self.num_reqs_in_flight = 0 @dataclass @@ -311,7 +289,7 @@ class BackgroundResources: circular reference back to the client object.""" ctx: Union[zmq.Context] - core_engines: list[CoreEngine] = field(default_factory=list) + local_engine_manager: Optional[CoreEngineProcManager] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None output_queue_task: Optional[asyncio.Task] = None @@ -325,8 +303,8 @@ def __call__(self): """Clean up background resources.""" self.engine_dead = True - for core_engine in self.core_engines: - core_engine.close() + if self.local_engine_manager is not None: + self.local_engine_manager.close() if self.output_queue_task is not None: self.output_queue_task.cancel() @@ -388,25 +366,56 @@ def __init__( self._finalizer = weakref.finalize(self, self.resources) success = False try: - # Paths and sockets for IPC. - self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(self.ctx, - input_path, - zmq.ROUTER, - bind=True) - self.resources.input_socket = self.input_socket - - new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, input_path, self. - output_path, index, local_dp_rank) - - # Start engine core process(es). - self._init_core_engines(vllm_config, new_core_engine, - self.resources.core_engines) + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + local_start_index = parallel_config.data_parallel_rank_local + + # SPMD mode is where there is an LLM instance per DP rank and + # one core engine per LLM, see + # examples/offline_inference/data_parallel.py. + spmd_mode = local_start_index is not None + if spmd_mode: + assert local_engine_count == 1 + self.core_engines = [ + CoreEngine(index=local_start_index, local=True) + ] + else: + assert start_index == 0 + local_start_index = 0 + self.core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(parallel_config.data_parallel_size) + ] + + input_address, output_address = self._get_zmq_addresses( + parallel_config, spmd_mode) + + # Create input and output sockets. + self.input_socket = self.resources.input_socket = make_zmq_socket( + self.ctx, input_address, zmq.ROUTER, bind=True) + + self.resources.output_socket = make_zmq_socket( + self.ctx, output_address, zmq.constants.PULL) + # Start local engines. + if local_engine_count: + # In server mode, start_index and local_start_index will + # both be 0. + self.resources.local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + input_address=input_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=start_index, + local_start_index=local_start_index) + + self.core_engine = self.core_engines[0] # Wait for engine core process(es) to start. 
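A worked example of the local/remote split established above, assuming data_parallel_size=4 with data_parallel_size_local=2 in server (non-SPMD) mode:

```python
# Engines 0-1 are spawned locally by CoreEngineProcManager; engines
# 2-3 are expected to connect over TCP from a headless node.
dp_size, local_engine_count = 4, 2
engines = {i: "local" if i < local_engine_count else "remote"
           for i in range(dp_size)}
assert engines == {0: "local", 1: "local", 2: "remote", 3: "remote"}
```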
- self._wait_for_engine_startup() + self._wait_for_engine_startup(output_address, parallel_config) self.utility_results: dict[int, AnyFuture] = {} @@ -420,56 +429,116 @@ def __init__( if not success: self._finalizer() - def _wait_for_engine_startup(self): + @staticmethod + def _get_zmq_addresses(parallel_config: ParallelConfig, + spmd_mode: bool) -> tuple[str, str]: + """Returns (input_address, output_address).""" + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + + if local_engine_count == dp_size or spmd_mode: + input_address = get_open_zmq_ipc_path() + output_address = get_open_zmq_ipc_path() + else: + host = parallel_config.data_parallel_master_ip + input_port = parallel_config.data_parallel_rpc_port + output_port = get_open_port() + input_address = get_tcp_uri(host, input_port) + output_address = get_tcp_uri(host, output_port) + + return input_address, output_address + + def _wait_for_engine_startup(self, output_address: str, + parallel_config: ParallelConfig): # Get a sync handle to the socket which can be sync or async. sync_input_socket = zmq.Socket.shadow(self.input_socket) # Wait for engine core process(es) to send ready messages. - identities = set(eng.index for eng in self.resources.core_engines) + local_count = parallel_config.data_parallel_size_local + remote_count = len(self.core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + poller = zmq.Poller() poller.register(sync_input_socket, zmq.POLLIN) - for eng in self.resources.core_engines: - poller.register(eng.proc_handle, zmq.POLLIN) - while identities: + proc_manager = self.resources.local_engine_manager + if proc_manager is not None: + for sentinel in proc_manager.sentinels(): + poller.register(sentinel, zmq.POLLIN) + while any(conn_pending) or any(start_pending): events = poller.poll(STARTUP_POLL_PERIOD_MS) if not events: - logger.debug("Waiting for %d core engine proc(s) to start: %s", - len(identities), identities) + if any(conn_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.debug( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) continue if len(events) > 1 or events[0][0] != sync_input_socket: - # One of the core processes exited. + # One of the local core processes exited. + finished = proc_manager.finished_procs( + ) if proc_manager else {} raise RuntimeError("Engine core initialization failed. " - "See root cause above.") - - eng_id_bytes, data = sync_input_socket.recv_multipart() - eng_id = int.from_bytes(eng_id_bytes, byteorder="little") - if eng_id not in identities: - raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}") - message_dict = json.loads(data.decode('utf-8')) - if message_dict['type'] != 'READY': - raise RuntimeError(f"Engine {eng_id} failed: {data.decode()}") - logger.info("Core engine process %d ready.", eng_id) - identities.discard(eng_id) - # Setup KV cache config with initialization state from - # engine core process. Sum values from all engines in DP case. 
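The accounting this comment describes survives the refactor below: each engine reports its num_gpu_blocks in its READY message, and in the DP case the client sums the reports into the shared cache config. A tiny worked example:

```python
# Two DP engines each report 8192 free KV-cache blocks when READY;
# the front-end records the total across engines.
ready_msgs = [{"num_gpu_blocks": 8192}, {"num_gpu_blocks": 8192}]
num_gpu_blocks = sum(m["num_gpu_blocks"] for m in ready_msgs)
assert num_gpu_blocks == 16384
```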
- num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks or 0 - num_gpu_blocks += message_dict['num_gpu_blocks'] - self.vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks - - def _init_core_engines( - self, - vllm_config: VllmConfig, - new_core_engine: Callable[[int, Optional[int]], CoreEngine], - core_engines: list[CoreEngine], - ) -> None: - - # Default case - single core engine. - core_engine = new_core_engine( - vllm_config.parallel_config.data_parallel_rank, - vllm_config.parallel_config.data_parallel_rank_local, - ) - core_engines.append(core_engine) - self.core_engine = core_engine + "See root cause above. " + f"Failed core proc(s): {finished}") + + # Receive HELLO and READY messages from the input socket. + eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() + eng_index = int.from_bytes(eng_identity, byteorder="little") + engine = next( + (e for e in self.core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f"engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + + if status == "HELLO" and engine.state == CoreEngineState.NEW: + + # Send init message with DP config info. + init_message = self.encoder.encode({ + "output_socket_address": output_address, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": + parallel_config.data_parallel_size, + }, + }) + sync_input_socket.send_multipart((eng_identity, *init_message), + copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state + == CoreEngineState.CONNECTED): + # Setup KV cache config with initialization state from + # engine core process. Sum values from all engines in DP case. + cache_config = self.vllm_config.cache_config + num_gpu_blocks = cache_config.num_gpu_blocks or 0 + num_gpu_blocks += msg['num_gpu_blocks'] + cache_config.num_gpu_blocks = num_gpu_blocks + + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) def shutdown(self): # Terminate background resources. @@ -520,7 +589,8 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. 
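The gc constraint in this comment is a recurring pattern in these clients: a long-lived IO thread must capture only the resources it needs (or a weakref), never self. A minimal sketch of the weakref variant (Client and make_handler are hypothetical; the final assertion relies on CPython's immediate refcount-based collection):

```python
import weakref

class Client:
    def make_handler(self):
        self_ref = weakref.ref(self)  # capture a weakref, not self

        def handler() -> bool:
            client = self_ref()
            if client is None:
                return False  # client was gc'd; thread should exit
            return True       # real output processing elided
        return handler

handler = Client().make_handler()
assert handler() is False  # the temporary Client is already collected
```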
ctx = self.ctx - output_path = self.output_path + out_socket = self.resources.output_socket + assert out_socket is not None decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue @@ -531,7 +601,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], def process_outputs_socket(): shutdown_socket = ctx.socket(zmq.PAIR) - out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() @@ -566,6 +635,9 @@ def process_outputs_socket(): daemon=True) self.output_queue_thread.start() + # The thread takes on responsibility for closing the socket. + self.resources.output_socket = None + def get_output(self) -> EngineCoreOutputs: # If an exception arises in process_outputs_socket task, # it is forwarded to the outputs_queue so we can raise it @@ -693,10 +765,8 @@ def _ensure_output_queue_task(self): self.__class__, "process_engine_outputs", None) _self_ref = weakref.ref(self) if output_handler else None - output_path = self.output_path - output_socket = make_zmq_socket(self.ctx, output_path, - zmq.constants.PULL) - resources.output_socket = output_socket + output_socket = resources.output_socket + assert output_socket is not None async def process_outputs_socket(): try: @@ -861,21 +931,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], assert len(self.core_engines) > 1 - def _init_core_engines( - self, - vllm_config: VllmConfig, - new_core_engine: Callable[[int, Optional[int]], CoreEngine], - core_engines: list[CoreEngine], - ) -> None: - - # Launch a core engine for each data parallel rank. - dp_size = vllm_config.parallel_config.data_parallel_size - for i in range(dp_size): - # Multi-node not yet supported so local_dp_rank == dp_rank. - core_engines.append(new_core_engine(i, i)) - - self.core_engines = core_engines - async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 9c238c3aad8e..0758747a83cc 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,20 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 import os +import time import weakref from collections import defaultdict from collections.abc import Sequence -from multiprocessing import Process -from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) +from multiprocessing import Process, connection +from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, + overload) import torch +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import get_mp_context, kill_process_tree +from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: from vllm.attention.layer import Attention @@ -92,7 +95,7 @@ def __repr__(self): return f"ConstantList({self._x})" -class BackgroundProcHandle: +class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown of background processes used by the AsyncLLM and LLMEngine. 
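Before the implementation below, a worked example of its rank bookkeeping, assuming DP=4 with two engines per node: the head node uses start_index=0 while a headless node is launched with --data-parallel-start-rank 2, and local_dp_rank restarts from 0 on every node:

```python
def ranks(start_index: int, local_start_index: int, count: int):
    # (dp_rank, local_dp_rank) pairs handed to each spawned engine.
    return [(start_index + i, local_start_index + i)
            for i in range(count)]

assert ranks(0, 0, 2) == [(0, 0), (1, 1)]  # head node
assert ranks(2, 0, 2) == [(2, 0), (3, 1)]  # headless node, -dpr 2
```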
@@ -100,49 +103,91 @@ class BackgroundProcHandle: def __init__( self, - input_path: str, - output_path: str, - process_name: str, target_fn: Callable, - process_kwargs: dict[Any, Any], + local_engine_count: int, + start_index: int, + local_start_index: int, + vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, + executor_class: type[Executor], + log_stats: bool, ): context = get_mp_context() + common_kwargs = { + "vllm_config": vllm_config, + "on_head_node": on_head_node, + "input_address": input_address, + "executor_class": executor_class, + "log_stats": log_stats, + } + + self.processes: list[Process] = [] + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index + # Start EngineCore in background process. + self.processes.append( + context.Process(target=target_fn, + name=f"EngineCore_{global_index}", + kwargs=common_kwargs | { + "dp_rank": global_index, + "local_dp_rank": local_index, + })) + + self._finalizer = weakref.finalize(self, shutdown, self.processes, + input_address) + try: + for proc in self.processes: + proc.start() + finally: + # Kill other procs if not all are running. + if self.finished_procs(): + self.close() + + def close(self): + """Shutdown all procs.""" + self._finalizer() - assert ("input_path" not in process_kwargs - and "output_path" not in process_kwargs) - process_kwargs["input_path"] = input_path - process_kwargs["output_path"] = output_path - - # Run busy loop in background process. - self.proc: Process = context.Process(target=target_fn, - kwargs=process_kwargs, - name=process_name) - self._finalizer = weakref.finalize(self, shutdown, self.proc, - input_path, output_path) - self.proc.start() + def join_first(self): + """Wait for any process to exit.""" + connection.wait(proc.sentinel for proc in self.processes) - def fileno(self): - return self.proc.sentinel + def sentinels(self) -> list: + return [proc.sentinel for proc in self.processes] - def shutdown(self): - self._finalizer() + def finished_procs(self) -> dict[str, int]: + """Returns dict of proc name -> exit code for any finished procs.""" + return { + proc.name: proc.exitcode + for proc in self.processes if proc.exitcode is not None + } # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. -def shutdown(proc: Process, input_path: str, output_path: str): +def shutdown(procs: list[Process], input_address: str): # Shutdown the process. - if proc.is_alive(): - proc.terminate() - proc.join(5) - + for proc in procs: + if proc.is_alive(): + proc.terminate() + + # Allow 5 seconds for remaining procs to terminate. + deadline = time.monotonic() + 5 + for proc in procs: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + if proc.is_alive(): + proc.join(remaining) + + for proc in procs: if proc.is_alive() and (pid := proc.pid) is not None: kill_process_tree(pid) # Remove zmq ipc socket files. - ipc_sockets = [output_path, input_path] - for ipc_socket in ipc_sockets: - socket_file = ipc_socket.replace("ipc://", "") + if input_address.startswith("ipc://"): + socket_file = input_address[len("ipc://"):] if os and os.path.exists(socket_file): os.remove(socket_file)
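Finally, a runnable sketch of the primitive join_first() and finished_procs() build on: multiprocessing sentinels, where connection.wait returns as soon as any child's sentinel becomes ready (the sleep durations and names are arbitrary):

```python
import time
from multiprocessing import Process, connection

if __name__ == "__main__":
    procs = [Process(target=time.sleep, args=(t,), name=f"sleep_{t}")
             for t in (0.1, 5.0)]
    for p in procs:
        p.start()

    # Blocks only until the first child (the 0.1 s sleeper) exits.
    connection.wait([p.sentinel for p in procs])
    finished = {p.name: p.exitcode
                for p in procs if p.exitcode is not None}
    print(finished)  # -> {'sleep_0.1': 0}

    for p in procs:  # clean up the remaining child
        p.terminate()
        p.join()
```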