11# SPDX-License-Identifier: Apache-2.0
22
33import argparse
4+ import signal
45
56import uvloop
67
8+ import vllm .envs as envs
9+ from vllm import AsyncEngineArgs
710from vllm .entrypoints .cli .types import CLISubcommand
811from vllm .entrypoints .openai .api_server import run_server
912from vllm .entrypoints .openai .cli_args import (make_arg_parser ,
1013 validate_parsed_serve_args )
11- from vllm .utils import FlexibleArgumentParser
14+ from vllm .logger import init_logger
15+ from vllm .usage .usage_lib import UsageContext
16+ from vllm .utils import FlexibleArgumentParser , get_tcp_uri
17+ from vllm .v1 .engine .core import EngineCoreProc
18+ from vllm .v1 .engine .core_client import CoreEngineProcManager
19+ from vllm .v1 .executor .abstract import Executor
20+
21+ logger = init_logger (__name__ )
1222
1323
1424class ServeSubcommand (CLISubcommand ):
@@ -24,7 +34,10 @@ def cmd(args: argparse.Namespace) -> None:
2434 if hasattr (args , 'model_tag' ) and args .model_tag is not None :
2535 args .model = args .model_tag
2636
27- uvloop .run (run_server (args ))
37+ if args .headless :
38+ run_headless (args )
39+ else :
40+ uvloop .run (run_server (args ))
2841
2942 def validate (self , args : argparse .Namespace ) -> None :
3043 validate_parsed_serve_args (args )
@@ -42,6 +55,18 @@ def subparser_init(
4255 nargs = '?' ,
4356 help = "The model tag to serve "
4457 "(optional if specified in config)" )
58+ serve_parser .add_argument (
59+ "--headless" ,
60+ action = 'store_true' ,
61+ default = False ,
62+ help = "Run in headless mode. See multi-node data parallel "
63+ "documentation for more details." )
64+ serve_parser .add_argument (
65+ '--data-parallel-start-rank' ,
66+ '-dpr' ,
67+ type = int ,
68+ default = 0 ,
69+ help = 'Starting data parallel rank for secondary nodes.' )
4570 serve_parser .add_argument (
4671 "--config" ,
4772 type = str ,
@@ -57,3 +82,55 @@ def subparser_init(
5782
def cmd_init() -> list[CLISubcommand]:
    """Return the CLI subcommands contributed by this module.

    Returns:
        A one-element list holding a freshly constructed
        ``ServeSubcommand``.
    """
    subcommands: list[CLISubcommand] = [ServeSubcommand()]
    return subcommands
85+
86+
def run_headless(args: argparse.Namespace):
    """Start this node's data parallel engine processes in headless mode.

    Builds the engine config from the parsed CLI arguments, launches
    ``data_parallel_size_local`` engine-core processes through
    ``CoreEngineProcManager`` using the address derived from the data
    parallel master IP and RPC port, then waits on
    ``engine_manager.join_first()``. The manager is always closed on the
    way out, including when a SIGTERM/SIGINT triggers ``SystemExit``.

    Args:
        args: Parsed ``serve`` arguments. ``data_parallel_start_rank`` is
            passed to the manager as ``start_index``.

    Raises:
        RuntimeError: If V1 is not enabled, or if
            ``data_parallel_size_local`` is not greater than zero.
    """
    # Build the engine configuration from the CLI arguments.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER)

    # Deliberately checked only after the engine config has been created.
    if not envs.VLLM_USE_V1:
        raise RuntimeError("Headless mode is only supported for V1")

    dp_config = vllm_config.parallel_config
    num_local_engines = dp_config.data_parallel_size_local
    # TODO: the RPC port is still read from engine_args; add it to the
    # config too.
    head_address = get_tcp_uri(dp_config.data_parallel_master_ip,
                               engine_args.data_parallel_rpc_port)

    if num_local_engines <= 0:
        raise RuntimeError(
            "data_parallel_size_local must be > 0 in headless mode")

    def _exit_on_signal(signum, frame):
        # Raising SystemExit unwinds through the try/finally below so the
        # engine manager gets closed.
        logger.debug("Received %d signal.", signum)
        raise SystemExit

    # Intercept SIGTERM and SIGINT to allow a graceful shutdown.
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _exit_on_signal)

    logger.info(
        "Launching %d data parallel engine(s) in headless mode, "
        "with head node address %s.",
        num_local_engines,
        head_address,
    )

    # Spawn the engine processes for this node.
    engine_manager = CoreEngineProcManager(
        target_fn=EngineCoreProc.run_engine_core,
        local_engine_count=num_local_engines,
        start_index=args.data_parallel_start_rank,
        local_start_index=0,
        vllm_config=vllm_config,
        on_head_node=False,
        input_address=head_address,
        executor_class=Executor.get_class(vllm_config),
        log_stats=not engine_args.disable_log_stats,
    )

    try:
        engine_manager.join_first()
    finally:
        logger.info("Shutting down.")
        engine_manager.close()
0 commit comments