# TODO ray_utils wraps everything in a try except, we could do the same here.

import os
import asyncio as aio
import rpyc
from rpyc.utils.server import ThreadedServer
from rpyc.utils.classic import obtain
from contextlib import closing
import socket
from datetime import timedelta
import time


def find_free_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR has to be set before bind() for it to affect the bind
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]


class RPyCWorkerService(rpyc.Service):
    def on_connect(self, conn):
        pass

    def on_disconnect(self, conn):
        pass

    def exposed_get_addr_and_port(self):
        # equivalent of
        # addr = ray.util.get_node_ip_address()
        # port = find_free_port()
        addr = "127.0.0.1"  # we should be local, I think
        port = find_free_port()
        return addr, port

    def exposed_init_torch_distributed(self, master_addr, master_port, gpu_ids, world_size, rank):
        # https://github.com/ray-project/ray/blob/7a3ae5ba5dbd6704f435bde8dba91a8a8d207ae4/python/ray/air/util/torch_dist.py#L95
        # for reference

        os.environ["MASTER_ADDR"] = str(master_addr)
        os.environ["MASTER_PORT"] = str(master_port)

        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_id) for gpu_id in gpu_ids)
        if "NCCL_SOCKET_IFNAME" not in os.environ:
            os.environ["NCCL_SOCKET_IFNAME"] = "^lo,docker,veth"

        import torch
        import torch.distributed as dist

        # ray makes a call to init_process_group here
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size,
            timeout=timedelta(seconds=1800),
        )

        # running on one node, so local_{rank,world_size} is the same as {rank,world_size}
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)
        os.environ["LOCAL_RANK"] = str(rank)

    def exposed_init_worker(self, model_config, parallel_config, scheduler_config):
        # We import Worker explicitly here rather than exposing some generic init_worker_fn()
        # API, since such a function can't be pickled and sent over the connection.
        # The import also has to happen inside the worker process; importing it in the engine
        # process breaks it (probably the same reason _init_workers_ray imports this so late).
        from vllm.worker.worker import Worker
        model_config, parallel_config, scheduler_config = (
            obtain(model_config),
            obtain(parallel_config),
            obtain(scheduler_config),
        )
        self.worker = Worker(
            model_config,
            parallel_config,
            scheduler_config,
            None,
            None,
        )

    def exposed_execute_method(self, method: str, *args, **kwargs):
        # obtain() has to call back into the engine process to pull the argument values over,
        # which may be a bottleneck. Two possible ways around it:
        #   1. a faster serialization of args/kwargs that avoids the extra round trips, or
        #   2. sticking args/kwargs into shared memory.
        # With prints, this seems to take about 0.0025 seconds with 4 workers, which is significant.
        args, kwargs = obtain(args), obtain(kwargs)
        executor = getattr(self.worker, method)
        retval = executor(*args, **kwargs)
        return retval
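
    # A minimal sketch of option 1 above (assumed, not part of the original interface): have the
    # caller pickle args/kwargs into a single bytes payload, which rpyc ships by value, so the
    # service never has to reach back across the connection with obtain(). The method name and
    # the matching client-side call (`conn.root.execute_method_pickled(method, pickle.dumps((args, kwargs)))`)
    # are hypothetical.
    def exposed_execute_method_pickled(self, method: str, payload: bytes):
        import pickle
        args, kwargs = pickle.loads(payload)
        # pickle the return value too, so the caller gets it in one message instead of a netref
        return pickle.dumps(getattr(self.worker, method)(*args, **kwargs))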

class RPyCWorkerClient:
    def __init__(self, conn):
        self.conn = conn

        def async_wrap(f):
            # wrap an rpyc remote method so it can be awaited: fire the async request, then
            # wait for completion in a thread so the event loop isn't blocked
            f = rpyc.async_(f)

            async def _func(*args, **kwargs):
                ans = f(*args, **kwargs)
                await aio.to_thread(ans.wait)
                # accessing .value re-raises any exception from the remote side
                return ans.value

            return _func

        self._ainit_torch_distributed = async_wrap(self.conn.root.init_torch_distributed)
        self._ainit_worker = async_wrap(self.conn.root.init_worker)
        self._aexecute_method = async_wrap(self.conn.root.execute_method)
        self._get_addr_and_port = self.conn.root.get_addr_and_port

    def get_addr_and_port(self):
        return self._get_addr_and_port()

    async def aexecute_method(self, method, *args, **kwargs):
        ans = await self._aexecute_method(method, *args, **kwargs)
        # pull the result over by value; otherwise we would be holding a netref into the worker
        new_ans = obtain(ans)
        return new_ans

    async def ainit_torch_distributed(self, master_addr, master_port, gpu_ids, world_size, rank):
        return await self._ainit_torch_distributed(master_addr, master_port, gpu_ids, world_size, rank)

    async def ainit_worker(self, model_config, parallel_config, scheduler_config):
        return await self._ainit_worker(model_config, parallel_config, scheduler_config)


def init_rpyc_env(port):
    # We need to import torch here, otherwise torch won't recognize CUDA devices as available.
    # Not sure why, unfortunately; it seems related to the ordering of imports / environment setup.
    import torch
    # The following print is necessary for the workers to start up; without it we get a weird
    # error about torch not recognizing GPUs. Running `torch.cuda.is_available()` /
    # `.device_count()` is probably what actually matters here.
    print("init_rpyc_env cuda support:", torch.cuda.is_available(), ":", torch.cuda.device_count(), "devices")
    t = ThreadedServer(RPyCWorkerService(), port=port, protocol_config={"allow_pickle": True})
    t.start()
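

# A minimal end-to-end sketch (assumed usage, not part of the original file) of how an engine
# process might tie these pieces together: spawn one RPyC server per GPU with multiprocessing,
# connect an RPyCWorkerClient to each, and bootstrap torch.distributed across the workers.
# The helper name and the gpu_ids/sleep choices are illustrative only.
async def _example_launch_workers(gpu_ids=(0, 1)):
    import multiprocessing as mp

    world_size = len(gpu_ids)
    ports = [find_free_port() for _ in gpu_ids]
    procs = [mp.Process(target=init_rpyc_env, args=(port,), daemon=True) for port in ports]
    for p in procs:
        p.start()
    await aio.sleep(2)  # crude wait for the servers to come up; a retry loop around connect() would be more robust

    clients = [
        RPyCWorkerClient(rpyc.connect("localhost", port, config={"allow_pickle": True}))
        for port in ports
    ]
    master_addr, master_port = clients[0].get_addr_and_port()
    # every rank has to join init_process_group before any of them returns, so launch them together
    await aio.gather(*[
        client.ainit_torch_distributed(master_addr, master_port, gpu_ids, world_size, rank)
        for rank, client in enumerate(clients)
    ])
    return clients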