Skip to content

Commit 55aa7af

Browse files
authored
[V1] DP scale-out (2/N): Decouple engine process management and comms (#15977)
Signed-off-by: Nick Hill <[email protected]>
1 parent 0b217da commit 55aa7af

File tree

10 files changed

+525
-252
lines changed

10 files changed

+525
-252
lines changed

tests/async_engine/test_async_llm_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self):
4141
self.abort_request_calls = 0
4242
self.request_id = None
4343
# Ugly, remove dependency when possible
44-
self.parallel_config = ParallelConfig(1, 1, False)
44+
self.parallel_config = ParallelConfig()
4545
self.model_config = MockModelConfig()
4646

4747
async def step_async(self, virtual_engine):

tests/v1/engine/test_engine_core_client.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@
1818
from vllm.usage.usage_lib import UsageContext
1919
from vllm.v1.engine import EngineCoreRequest
2020
from vllm.v1.engine.core import EngineCore
21-
from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine,
22-
EngineCoreClient, SyncMPClient)
21+
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
22+
SyncMPClient)
2323
from vllm.v1.executor.abstract import Executor
24+
from vllm.v1.utils import CoreEngineProcManager
2425

2526
from ...distributed.conftest import MockSubscriber
2627
from ...utils import create_new_process_for_each_test
@@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
348349

349350
# Monkey-patch to extract core process pid while it's starting.
350351
core_proc_pid = [None]
351-
ce_ctor = CoreEngine.__init__
352+
cepm_ctor = CoreEngineProcManager.__init__
352353

353-
def patched_ce_ctor(self, *args, **kwargs):
354-
ce_ctor(self, *args, **kwargs)
355-
core_proc_pid[0] = self.proc_handle.proc.pid
354+
def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs):
355+
cepm_ctor(self, *args, **kwargs)
356+
core_proc_pid[0] = self.processes[0].pid
356357

357-
m.setattr(CoreEngine, "__init__", patched_ce_ctor)
358+
m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor)
358359

359360
t = time.time()
360361
engine_args = EngineArgs(model=MODEL_NAME)

vllm/config.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,25 +1668,17 @@ class ParallelConfig:
16681668
data_parallel_size: int = 1
16691669
"""Number of data parallel groups. MoE layers will be sharded according to
16701670
the product of the tensor parallel size and data parallel size."""
1671+
data_parallel_size_local: int = 1
1672+
"""Number of local data parallel groups."""
16711673
data_parallel_rank: int = 0
16721674
"""Rank of the data parallel group."""
1673-
_data_parallel_rank_local: Optional[int] = field(default=None, init=False)
1674-
"""Private field to store the local rank of the data parallel group."""
1675-
1676-
@property
1677-
def data_parallel_rank_local(self) -> int:
1678-
"""Local rank of the data parallel group, defaults to global rank."""
1679-
if self._data_parallel_rank_local is None:
1680-
return self.data_parallel_rank
1681-
return self._data_parallel_rank_local
1682-
1683-
@data_parallel_rank_local.setter
1684-
def data_parallel_rank_local(self, value: int) -> None:
1685-
"""Set the local rank of the data parallel group."""
1686-
self._data_parallel_rank_local = value
1687-
1675+
data_parallel_rank_local: Optional[int] = None
1676+
"""Local rank of the data parallel group,
1677+
set only in SPMD mode."""
16881678
data_parallel_master_ip: str = "127.0.0.1"
16891679
"""IP of the data parallel master."""
1680+
data_parallel_rpc_port: int = 29550
1681+
"""Port for data parallel messaging."""
16901682
data_parallel_master_port: int = 29500
16911683
"""Port of the data parallel master."""
16921684
enable_expert_parallel: bool = False
@@ -1734,13 +1726,16 @@ class is dynamically inherited by the worker class. This is used to inject
17341726

17351727
world_size: int = field(init=False)
17361728
"""world_size is TPxPP, it affects the number of workers we create."""
1737-
world_size_across_dp: int = field(init=False)
1738-
"""world_size_across_dp is TPxPPxDP, it is the size of the world
1739-
including data parallelism."""
17401729

17411730
rank: int = 0
17421731
"""Global rank in distributed setup."""
17431732

1733+
@property
1734+
def world_size_across_dp(self) -> int:
1735+
"""world_size_across_dp is TPxPPxDP, it is the size of the world
1736+
including data parallelism."""
1737+
return self.world_size * self.data_parallel_size
1738+
17441739
def get_next_dp_init_port(self) -> int:
17451740
"""
17461741
We might need to initialize process groups in multiple
@@ -1800,10 +1795,14 @@ def __post_init__(self) -> None:
18001795
self.world_size = self.pipeline_parallel_size * \
18011796
self.tensor_parallel_size
18021797

1803-
if self.data_parallel_size > 1:
1798+
if self.data_parallel_size_local > self.data_parallel_size:
1799+
raise ValueError(
1800+
f"data_parallel_size_local ({self.data_parallel_size_local}) "
1801+
f"must be <= data_parallel_size ({self.data_parallel_size})")
1802+
1803+
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
18041804
# Data parallel was specified in the engine args.
18051805
self.data_parallel_master_port = get_open_port()
1806-
# TODO multi-node
18071806
else:
18081807
# Otherwise fall back to env vars (e.g. for offline SPMD case).
18091808
self.data_parallel_size = envs.VLLM_DP_SIZE
@@ -1812,8 +1811,6 @@ def __post_init__(self) -> None:
18121811
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
18131812
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
18141813

1815-
self.world_size_across_dp = self.world_size * self.data_parallel_size
1816-
18171814
if self.distributed_executor_backend == "external_launcher":
18181815
import os
18191816
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

vllm/distributed/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import vllm.envs as envs
2424
from vllm.logger import init_logger
25+
from vllm.utils import get_tcp_uri
2526

2627
logger = init_logger(__name__)
2728

@@ -303,7 +304,7 @@ def stateless_init_torch_distributed_process_group(
303304
always formed with process 1, 2, ..., 8, and the additional communication
304305
channel is formed with process 9 and 10.
305306
"""
306-
init_method = f"tcp://{host}:{port}"
307+
init_method = get_tcp_uri(host, port)
307308
backend = Backend(backend) # it is basically string
308309
timeout = _get_default_timeout(backend)
309310

vllm/engine/arg_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@ class EngineArgs:
283283
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
284284
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
285285
data_parallel_size: int = ParallelConfig.data_parallel_size
286+
data_parallel_size_local: Optional[int] = None
287+
data_parallel_address: Optional[str] = None
288+
data_parallel_rpc_port: Optional[int] = None
286289
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
287290
max_parallel_loading_workers: Optional[
288291
int] = ParallelConfig.max_parallel_loading_workers
@@ -596,6 +599,21 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
596599
**parallel_kwargs["tensor_parallel_size"])
597600
parallel_group.add_argument("--data-parallel-size", "-dp",
598601
**parallel_kwargs["data_parallel_size"])
602+
parallel_group.add_argument('--data-parallel-size-local',
603+
'-dpl',
604+
type=int,
605+
help='Number of data parallel replicas '
606+
'to run on this node.')
607+
parallel_group.add_argument('--data-parallel-address',
608+
'-dpa',
609+
type=str,
610+
help='Address of data parallel cluster '
611+
'head-node.')
612+
parallel_group.add_argument('--data-parallel-rpc-port',
613+
'-dpp',
614+
type=int,
615+
help='Port for data parallel RPC '
616+
'communication.')
599617
parallel_group.add_argument(
600618
"--enable-expert-parallel",
601619
**parallel_kwargs["enable_expert_parallel"])
@@ -1019,10 +1037,30 @@ def create_engine_config(
10191037
# but we should not do this here.
10201038
placement_group = ray.util.get_current_placement_group()
10211039

1040+
# Local DP size defaults to global DP size if not set.
1041+
data_parallel_size_local = self.data_parallel_size if (
1042+
self.data_parallel_size_local
1043+
is None) else self.data_parallel_size_local
1044+
1045+
# DP address, used in multi-node case for torch distributed group
1046+
# and ZMQ sockets.
1047+
data_parallel_address = self.data_parallel_address if (
1048+
self.data_parallel_address
1049+
is not None) else ParallelConfig.data_parallel_master_ip
1050+
1051+
# This port is only used when there are remote data parallel engines,
1052+
# otherwise the local IPC transport is used.
1053+
data_parallel_rpc_port = self.data_parallel_rpc_port if (
1054+
self.data_parallel_rpc_port
1055+
is not None) else ParallelConfig.data_parallel_rpc_port
1056+
10221057
parallel_config = ParallelConfig(
10231058
pipeline_parallel_size=self.pipeline_parallel_size,
10241059
tensor_parallel_size=self.tensor_parallel_size,
10251060
data_parallel_size=self.data_parallel_size,
1061+
data_parallel_size_local=data_parallel_size_local,
1062+
data_parallel_master_ip=data_parallel_address,
1063+
data_parallel_rpc_port=data_parallel_rpc_port,
10261064
enable_expert_parallel=self.enable_expert_parallel,
10271065
max_parallel_loading_workers=self.max_parallel_loading_workers,
10281066
disable_custom_all_reduce=self.disable_custom_all_reduce,

vllm/entrypoints/cli/serve.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,24 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import argparse
4+
import signal
45

56
import uvloop
67

8+
import vllm.envs as envs
9+
from vllm import AsyncEngineArgs
710
from vllm.entrypoints.cli.types import CLISubcommand
811
from vllm.entrypoints.openai.api_server import run_server
912
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
1013
validate_parsed_serve_args)
11-
from vllm.utils import FlexibleArgumentParser
14+
from vllm.logger import init_logger
15+
from vllm.usage.usage_lib import UsageContext
16+
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
17+
from vllm.v1.engine.core import EngineCoreProc
18+
from vllm.v1.engine.core_client import CoreEngineProcManager
19+
from vllm.v1.executor.abstract import Executor
20+
21+
logger = init_logger(__name__)
1222

1323

1424
class ServeSubcommand(CLISubcommand):
@@ -24,7 +34,10 @@ def cmd(args: argparse.Namespace) -> None:
2434
if hasattr(args, 'model_tag') and args.model_tag is not None:
2535
args.model = args.model_tag
2636

27-
uvloop.run(run_server(args))
37+
if args.headless:
38+
run_headless(args)
39+
else:
40+
uvloop.run(run_server(args))
2841

2942
def validate(self, args: argparse.Namespace) -> None:
3043
validate_parsed_serve_args(args)
@@ -42,6 +55,18 @@ def subparser_init(
4255
nargs='?',
4356
help="The model tag to serve "
4457
"(optional if specified in config)")
58+
serve_parser.add_argument(
59+
"--headless",
60+
action='store_true',
61+
default=False,
62+
help="Run in headless mode. See multi-node data parallel "
63+
"documentation for more details.")
64+
serve_parser.add_argument(
65+
'--data-parallel-start-rank',
66+
'-dpr',
67+
type=int,
68+
default=0,
69+
help='Starting data parallel rank for secondary nodes.')
4570
serve_parser.add_argument(
4671
"--config",
4772
type=str,
@@ -57,3 +82,55 @@ def subparser_init(
5782

5883
def cmd_init() -> list[CLISubcommand]:
5984
return [ServeSubcommand()]
85+
86+
87+
def run_headless(args: argparse.Namespace):
88+
89+
# Create the EngineConfig.
90+
engine_args = AsyncEngineArgs.from_cli_args(args)
91+
usage_context = UsageContext.OPENAI_API_SERVER
92+
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
93+
94+
if not envs.VLLM_USE_V1:
95+
raise RuntimeError("Headless mode is only supported for V1")
96+
97+
parallel_config = vllm_config.parallel_config
98+
local_engine_count = parallel_config.data_parallel_size_local
99+
host = parallel_config.data_parallel_master_ip
100+
port = engine_args.data_parallel_rpc_port # add to config too
101+
input_address = get_tcp_uri(host, port)
102+
103+
if local_engine_count <= 0:
104+
raise RuntimeError("data_parallel_size_local must be > 0 in "
105+
"headless mode")
106+
107+
# Catch SIGTERM and SIGINT to allow graceful shutdown.
108+
def signal_handler(signum, frame):
109+
logger.debug("Received %d signal.", signum)
110+
raise SystemExit
111+
112+
signal.signal(signal.SIGTERM, signal_handler)
113+
signal.signal(signal.SIGINT, signal_handler)
114+
115+
logger.info(
116+
"Launching %d data parallel engine(s) in headless mode, "
117+
"with head node address %s.", local_engine_count, input_address)
118+
119+
# Create the engines.
120+
engine_manager = CoreEngineProcManager(
121+
target_fn=EngineCoreProc.run_engine_core,
122+
local_engine_count=local_engine_count,
123+
start_index=args.data_parallel_start_rank,
124+
local_start_index=0,
125+
vllm_config=vllm_config,
126+
on_head_node=False,
127+
input_address=input_address,
128+
executor_class=Executor.get_class(vllm_config),
129+
log_stats=not engine_args.disable_log_stats,
130+
)
131+
132+
try:
133+
engine_manager.join_first()
134+
finally:
135+
logger.info("Shutting down.")
136+
engine_manager.close()

vllm/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,10 @@ def is_valid_ipv6_address(address: str) -> bool:
613613

614614

615615
def get_distributed_init_method(ip: str, port: int) -> str:
616+
return get_tcp_uri(ip, port)
617+
618+
619+
def get_tcp_uri(ip: str, port: int) -> str:
616620
# Brackets are not permitted in ipv4 addresses,
617621
# see https://github.com/python/cpython/issues/103848
618622
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"

0 commit comments

Comments (0)