
Commit c6ee7f3

Merge pull request #1 from seanshi-scale/seanshi-scale/rpyc
add rpyc
2 parents bc5a20a + f53e1b9 commit c6ee7f3

6 files changed: +223 -5 lines changed


requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -11,3 +11,4 @@ xformers >= 0.0.22
 fastapi
 uvicorn[standard]
 pydantic < 2 # Required for OpenAI server.
+rpyc >= 5.3.0 # Required if you want to use RPyC. As of 5.3.0, there needs to be a separate change in the source to enable not-terrible performance compared to Ray.

vllm/config.py

Lines changed: 4 additions & 1 deletion
@@ -237,13 +237,16 @@ def __init__(
         pipeline_parallel_size: int,
         tensor_parallel_size: int,
         worker_use_ray: bool,
+        worker_use_rpyc: bool,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
+        self.worker_use_rpyc = worker_use_rpyc

         self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        if self.world_size > 1 and not worker_use_rpyc:
+            # HACK: kinda messy handling of whether we choose to use ray/rpyc/none for the workers
             self.worker_use_ray = True
         self._verify_args()
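
A minimal sketch of the changed behavior (illustrative only, not part of the commit): when worker_use_rpyc is passed, a multi-GPU config no longer forces worker_use_ray on.

from vllm.config import ParallelConfig

# Positional args as in create_engine_configs: pipeline_parallel_size,
# tensor_parallel_size, worker_use_ray, worker_use_rpyc.
config = ParallelConfig(1, 2, False, True)
assert config.world_size == 2
assert not config.worker_use_ray    # no longer auto-enabled for world_size > 1
assert config.worker_use_rpyc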

vllm/engine/arg_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class EngineArgs:
2020
seed: int = 0
2121
max_model_len: Optional[int] = None
2222
worker_use_ray: bool = False
23+
worker_use_rpyc: bool = False
2324
pipeline_parallel_size: int = 1
2425
tensor_parallel_size: int = 1
2526
block_size: int = 16
@@ -109,6 +110,7 @@ def add_cli_args(
109110
action='store_true',
110111
help='use Ray for distributed serving, will be '
111112
'automatically set when using more than 1 GPU')
113+
parser.add_argument('--worker-use-rpyc', action='store_true', help='use rpyc for distributed serving, todo this is kinda hacked in')
112114
parser.add_argument('--pipeline-parallel-size',
113115
'-pp',
114116
type=int,
@@ -181,7 +183,8 @@ def create_engine_configs(
181183
getattr(model_config.hf_config, 'sliding_window', None))
182184
parallel_config = ParallelConfig(self.pipeline_parallel_size,
183185
self.tensor_parallel_size,
184-
self.worker_use_ray)
186+
self.worker_use_ray,
187+
self.worker_use_rpyc)
185188
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
186189
self.max_num_seqs,
187190
model_config.max_model_len)
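
For illustration (not part of the commit; the model name is arbitrary and EngineArgs.from_cli_args is the pre-existing helper assumed here), the new flag parses like any other engine argument:

import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m",
                          "--tensor-parallel-size", "2",
                          "--worker-use-rpyc"])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.worker_use_rpyc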

vllm/engine/async_llm_engine.py

Lines changed: 12 additions & 3 deletions
@@ -210,6 +210,8 @@ async def _run_workers_async(
         for worker in self.workers:
             if self.parallel_config.worker_use_ray:
                 executor = partial(worker.execute_method.remote, method)
+            elif self.parallel_config.worker_use_rpyc:
+                executor = partial(worker.aexecute_method, method)
             else:
                 executor = getattr(worker, method)

@@ -218,14 +220,18 @@ async def _run_workers_async(

         if self.parallel_config.worker_use_ray:
             all_outputs = await asyncio.gather(*all_outputs)
+        elif self.parallel_config.worker_use_rpyc:
+            all_outputs = await asyncio.gather(*all_outputs)

         if get_all_outputs:
             return all_outputs

         # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
+        output = all_outputs[0]  # some "ray objectref" object in ray mode, some list(list(sequence_output)) in one-process mode
+        if not self.parallel_config.worker_use_rpyc:
+            # HACK: if we're using rpyc, we are returned coroutines, and we can't assert equality
+            for other_output in all_outputs[1:]:
+                assert output == other_output
         return output


@@ -257,13 +263,15 @@ class AsyncLLMEngine:

     def __init__(self,
                  worker_use_ray: bool,
+                 worker_use_rpyc: bool,
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  max_log_len: Optional[int] = None,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
+        self.worker_use_rpyc = worker_use_rpyc
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
         self.max_log_len = max_log_len
@@ -484,6 +492,7 @@ def from_engine_args(cls,
                          parallel_config, engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(engine_args.worker_use_ray,
+                     engine_args.worker_use_rpyc,
                      engine_args.engine_use_ray,
                      *engine_configs,
                      distributed_init_method,
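
The RPyC branch mirrors the Ray branch: one awaitable per worker, then a single gather. A stripped-down toy of that dispatch shape (not vLLM code, no RPyC needed):

import asyncio
from functools import partial

class ToyWorker:
    async def aexecute_method(self, method, *args, **kwargs):
        return f"{method}: ok"

async def run_workers_async(workers, method):
    # Build one awaitable per worker, exactly like the hunk above, then gather.
    all_outputs = [partial(w.aexecute_method, method)() for w in workers]
    return await asyncio.gather(*all_outputs)

print(asyncio.run(run_workers_async([ToyWorker(), ToyWorker()], "execute_model")))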

vllm/engine/llm_engine.py

Lines changed: 76 additions & 0 deletions
@@ -1,5 +1,7 @@
 import copy
 import time
+import os
+import asyncio as aio
 from functools import partial
 from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union

@@ -104,6 +106,8 @@ def __init__(
         # Create the parallel GPU workers.
         if self.parallel_config.worker_use_ray:
             self._init_workers_ray(placement_group)
+        elif self.parallel_config.worker_use_rpyc:
+            self._init_workers_rpyc()
         else:
             self._init_workers(distributed_init_method)

@@ -181,6 +185,72 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             get_all_outputs=True,
         )

+    def _init_workers_rpyc(self):
+
+        from multiprocessing import Process, set_start_method
+        import rpyc
+        from vllm.worker.worker import Worker  # pylint: disable=import-outside-toplevel
+
+        from vllm.engine.rpyc_utils import RPyCWorkerClient, init_rpyc_env, find_free_port  # Import here, otherwise we break Ray
+
+        self.workers: List[RPyCWorkerClient] = []
+        ports = []
+        set_start_method("spawn")  # forkserver mode may work too
+        # HACK: There's some messiness with the order of spawning the process, importing torch, and setting env vars,
+        # that cause the gpu to either be recognized or not by the worker process, so we set the env var here to make sure
+        # we've set the gpu correctly.
+        gpu_ids = list(range(self.parallel_config.world_size))
+        # Think we just need to set CUDA_VISIBLE_DEVICES?
+        os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(gpu_id) for gpu_id in gpu_ids])
+
+        for i in range(self.parallel_config.world_size):
+            port = find_free_port()
+            p = Process(target=init_rpyc_env, args=(port,))
+            p.start()
+            ports.append(port)
+        time.sleep(2)
+        for i in range(self.parallel_config.world_size):
+            port = ports[i]
+            for _ in range(20):
+                try:
+                    conn = rpyc.connect("localhost", port, config={"allow_pickle": True})
+                    self.workers.append(RPyCWorkerClient(conn))
+                    break
+                except ConnectionRefusedError:
+                    print(f"Conn refused for worker {i}")
+                    time.sleep(2)
+                    continue
+            else:
+                raise ConnectionRefusedError("Couldn't connect to workers")
+
+        # Initialize torch distributed process group for the workers.
+        addr, port = self.workers[0].get_addr_and_port()
+
+        executors = []
+        for i, worker_client in enumerate(self.workers):
+            exec = worker_client.ainit_torch_distributed(
+                addr,
+                port,
+                list(range(self.parallel_config.world_size)),
+                self.parallel_config.world_size,
+                i,
+            )
+            executors.append(exec)
+        loop = aio.get_event_loop()
+        loop.run_until_complete(aio.gather(*executors))
+
+        executors = []
+        for worker_client in self.workers:
+            exec = worker_client.ainit_worker(
+                self.model_config, self.parallel_config, self.scheduler_config
+            )
+            executors.append(exec)
+        loop.run_until_complete(aio.gather(*executors))
+        self._run_workers(
+            "init_model",
+            get_all_outputs=True,
+        )
+
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
@@ -686,6 +756,8 @@ def _run_workers(
         for worker in self.workers:
             if self.parallel_config.worker_use_ray:
                 executor = partial(worker.execute_method.remote, method)
+            elif self.parallel_config.worker_use_rpyc:
+                executor = partial(worker.aexecute_method, method)
             else:
                 executor = getattr(worker, method)

@@ -694,6 +766,10 @@ def _run_workers(

         if self.parallel_config.worker_use_ray:
             all_outputs = ray.get(all_outputs)
+        elif self.parallel_config.worker_use_rpyc:
+            # There may be a faster way to make all the requests.
+            loop = aio.get_event_loop()
+            all_outputs = loop.run_until_complete(aio.gather(*all_outputs))

         if get_all_outputs:
             return all_outputs
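
_run_workers itself is synchronous, so the RPyC path drains the workers' coroutines with an event loop. A toy of the same shape (the commit calls aio.get_event_loop(); this sketch builds a fresh loop so it stays self-contained):

import asyncio as aio

async def fake_aexecute_method(method):
    await aio.sleep(0.01)   # stand-in for the RPC round trip to a worker process
    return f"{method}: ok"

def run_workers(method, num_workers=2):
    # Collect one coroutine per worker, then drive them all from sync code.
    all_outputs = [fake_aexecute_method(method) for _ in range(num_workers)]
    loop = aio.new_event_loop()
    try:
        return loop.run_until_complete(aio.gather(*all_outputs))
    finally:
        loop.close()

print(run_workers("init_model"))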

vllm/engine/rpyc_utils.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# TODO ray_utils wraps everything in a try except, we could do the same here.
+
+import os
+import asyncio as aio
+import rpyc
+from rpyc.utils.server import ThreadedServer
+from rpyc.utils.classic import obtain
+from contextlib import closing
+import socket
+from datetime import timedelta
+import time
+
+
+def find_free_port():
+    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+        s.bind(("", 0))
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        return s.getsockname()[1]
+
+
+class RPyCWorkerService(rpyc.Service):
+    def on_connect(self, conn):
+        pass
+
+    def on_disconnect(self, conn):
+        pass
+
+    def exposed_get_addr_and_port(self):
+        # equivalent of
+        # addr = ray.util.get_node_ip_address()
+        # port = find_free_port()
+        addr = "127.0.0.1"  # we should be local I think
+        port = find_free_port()
+        return addr, port
+
+    def exposed_init_torch_distributed(self, master_addr, master_port, gpu_ids, world_size, rank):
+        # https://github.com/ray-project/ray/blob/7a3ae5ba5dbd6704f435bde8dba91a8a8d207ae4/python/ray/air/util/torch_dist.py#L95
+        # for reference
+
+        os.environ["MASTER_ADDR"] = str(master_addr)
+        os.environ["MASTER_PORT"] = str(master_port)
+
+        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in gpu_ids)
+        if "NCCL_SOCKET_IFNAME" not in os.environ:
+            os.environ["NCCL_SOCKET_IFNAME"] = "^lo,docker,veth"
+
+        import torch
+        import torch.distributed as dist
+
+        # ray makes a call to init process group here
+        dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size, timeout=timedelta(seconds=1800))
+
+        # running on one node, local_{rank|world_size} is same as {rank|world_size}
+        os.environ["WORLD_SIZE"] = str(world_size)
+        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
+        os.environ["RANK"] = str(rank)
+        os.environ["LOCAL_RANK"] = str(rank)
+
+    def exposed_init_worker(self, model_config, parallel_config, scheduler_config):
+        # we import worker explicitly here as opposed to provide some generic init_worker_fn() api
+        # since the init_worker_fn() can't be pickled and sent over.
+        # also import inside worker process since if not it'll break the engine process
+        # probably same reason as why _init_workers_ray imports this so late?
+        from vllm.worker.worker import Worker
+        model_config, parallel_config, scheduler_config = obtain(model_config), obtain(parallel_config), obtain(scheduler_config)
+        self.worker = Worker(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            None,
+            None,
+        )
+
+    def exposed_execute_method(self, method: str, *args, **kwargs):
+        # I believe this obtain() makes a call to the other process, which may be a bottleneck.
+        # Potentially can try 1. a faster way of serializing the args/kwargs objects + avoiding the call to the other process
+        # or 2. sticking args/kwargs into shared memory
+        args, kwargs = obtain(args), obtain(kwargs)  # with prints, seems like this takes about 0.0025 seconds with 4 workers, which is pretty significant
+        executor = getattr(self.worker, method)
+        retval = executor(*args, **kwargs)
+        return retval
+
+class RPyCWorkerClient:
+    def __init__(self, conn):
+        self.conn = conn
+        def async_wrap(f):
+            f = rpyc.async_(f)
+            async def _func(*args, **kwargs):
+                ans = f(*args, **kwargs)
+                await aio.to_thread(ans.wait)
+                # raise if exception
+                return ans.value
+            return _func
+        self.async_wrap = async_wrap
+        self._ainit_torch_distributed = self.async_wrap(self.conn.root.init_torch_distributed)
+        self._ainit_worker = self.async_wrap(self.conn.root.init_worker)
+        self._aexecute_method = self.async_wrap(self.conn.root.execute_method)
+        self._get_addr_and_port = self.conn.root.get_addr_and_port
+
+    def get_addr_and_port(self):
+        return self._get_addr_and_port()
+
+    async def aexecute_method(self, method, *args, **kwargs):
+        ans = await self._aexecute_method(method, *args, **kwargs)
+        new_ans = obtain(ans)
+        return new_ans
+
+    async def ainit_torch_distributed(self, master_addr, master_port, gpu_ids, world_size, rank):
+        return await self._ainit_torch_distributed(master_addr, master_port, gpu_ids, world_size, rank)
+
+    async def ainit_worker(self, model_config, parallel_config, scheduler_config):
+        return await self._ainit_worker(model_config, parallel_config, scheduler_config)
+
+
+
+def init_rpyc_env(port):
+    # We need to import torch here, otherwise torch won't recognize CUDA devices as available.
+    # Not sure why unfortunately, but I think it's related to some ordering of imports/environment set up
+    import torch
+    # This following print is necessary for the workers to start up, otherwise we get some weird error with torch not recognizing gpus
+    # We probably just need to run `torch.cuda.is_available()/.device_count()`
+    print("init_rpyc_env cuda support:", torch.cuda.is_available(), ":", torch.cuda.device_count(), "devices")
+    t = ThreadedServer(RPyCWorkerService(), port=port, protocol_config={"allow_pickle": True})
+    t.start()
+    return
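
The client's async_wrap pattern (rpyc.async_, waiting on the AsyncResult in a thread, then obtain()-ing the value) can be exercised against a toy service with no GPUs. A self-contained sketch; EchoService, the port, and the in-process server thread are illustrative, not from the commit:

import asyncio as aio
import threading

import rpyc
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer


class EchoService(rpyc.Service):
    def exposed_execute_method(self, method, *args, **kwargs):
        # Stand-in for RPyCWorkerService.exposed_execute_method.
        return {"method": method, "args": args, "kwargs": kwargs}


def async_wrap(remote_callable):
    # Same idea as RPyCWorkerClient.async_wrap: fire the call without blocking,
    # then wait for the AsyncResult in a worker thread so the event loop stays free.
    f = rpyc.async_(remote_callable)

    async def _func(*args, **kwargs):
        ans = f(*args, **kwargs)
        await aio.to_thread(ans.wait)
        return ans.value

    return _func


async def main(port):
    conn = rpyc.connect("localhost", port, config={"allow_pickle": True})
    aexecute_method = async_wrap(conn.root.execute_method)
    result = obtain(await aexecute_method("execute_model", 1, 2, key="value"))
    print(result)


if __name__ == "__main__":
    server = ThreadedServer(EchoService(), port=18861,
                            protocol_config={"allow_pickle": True})
    threading.Thread(target=server.start, daemon=True).start()
    aio.run(main(18861))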
