|  | 
|  | 1 | +import atexit | 
|  | 2 | +import os | 
|  | 3 | +import threading | 
|  | 4 | +import time | 
|  | 5 | +from typing import Optional | 
|  | 6 | + | 
|  | 7 | +from ..llmapi.mpi_session import MpiPoolSession, MpiSession | 
|  | 8 | +from ..llmapi.tracer import global_tracer | 
|  | 9 | +from ..llmapi.utils import _SyncQueue, print_colored_debug | 
|  | 10 | +from .executor import GenerationExecutor | 
|  | 11 | +from .postproc_worker import PostprocWorkerConfig | 
|  | 12 | +from .request import GenerationRequest | 
|  | 13 | +from .result import GenerationResult | 
|  | 14 | +from .rpc import RPCClient | 
|  | 15 | +from .rpc_worker import rpc_worker_main | 
|  | 16 | +from .utils import (ErrorResponse, create_mpi_comm_session, | 
|  | 17 | +                    get_spawn_proxy_process_env, is_llm_response) | 
|  | 18 | + | 
|  | 19 | + | 
class GenerationExecutorRpcProxy(GenerationExecutor):
    """Executor proxy that talks to ``rpc_worker_main`` workers over RPC.

    On construction the proxy creates (or adopts) an MPI session, submits
    ``rpc_worker_main`` to it, asks the remote side to create the engine and
    starts a daemon thread that periodically polls the remote for responses.

    NOTE(review): ``shutdown`` is registered with ``atexit`` but is not defined
    in this class, so it is presumably provided by ``GenerationExecutor``.
    Nothing in this class sets ``_shutdown_event`` or calls
    ``shutdown_remote``; confirm the inherited ``shutdown`` (or an override)
    does both.
    """
    # NOTE: this is a global counter for the number of instances of this class
    INSTANCE_COUNTER = 0

    # Main-loop cadence, expressed in clock ticks (one tick == ``clock_unit``
    # seconds): responses are polled every tick, stats every tenth tick.
    _RESPONSE_POLL_TICKS = 1
    _STATS_POLL_TICKS = 10

    def __init__(self,
                 worker_kwargs: dict,
                 model_world_size: int = 1,
                 mpi_session: Optional[MpiSession] = None,
                 *,
                 postproc_worker_config: Optional[PostprocWorkerConfig] = None,
                 is_llm_executor: Optional[bool] = None,
                 garbage_collection_gen0_threshold: Optional[int] = None,
                 clock_unit: int = 1):
        """
        Args:
            worker_kwargs: kwargs for the rpc worker
            model_world_size: the world size of the model
            mpi_session: the mpi session to use
            postproc_worker_config: the postproc worker config
            is_llm_executor: whether this is an llm executor
            garbage_collection_gen0_threshold: the garbage collection gen0 threshold
            clock_unit: the unit of the clock, 1 means 1 second
        """

        GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1
        self.rpc_addr = self._gen_rpc_addr()
        self.rpc_client = RPCClient(self.rpc_addr)

        # These are read later by launch_workers() and main_loop_task(); the
        # base-class constructor is never given them, so store them here.
        self.worker_kwargs = worker_kwargs
        self.clock_unit = clock_unit
        # NOTE(review): not consumed anywhere in this class yet; kept so the
        # value is not silently dropped.
        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold

        postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
        )

        super().__init__(
            num_postprocess_workers=postproc_worker_config.
            num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_worker_config.
            postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        self.mpi_session = self._create_mpi_session(model_world_size,
                                                    mpi_session)

        self._shutdown_event = threading.Event()

        self.launch_workers()
        time.sleep(1)  # wait for the workers to launch

        # Invoke model creation on the remote
        # TBD: Move model creation to the mpi task, or left in RPC?
        self.create_engine_remote()

        self.setup_mainloop()

    def launch_workers(self):
        """Submit ``rpc_worker_main`` to the MPI session.

        The worker receives this proxy's ``rpc_addr`` plus every entry of
        ``worker_kwargs`` as keyword arguments.
        """
        assert self.mpi_session is not None, "mpi_session is not initialized"
        self.mpi_session.submit(rpc_worker_main,
                                rpc_addr=self.rpc_addr,
                                **self.worker_kwargs)

    def main_loop_task(self):
        """
        Main loop of the proxy, it will invoke the actions periodically.

        Runs until ``_shutdown_event`` is set. Each tick lasts ``clock_unit``
        seconds; responses are fetched every tick and stats every tenth tick.
        """
        clock = 0
        # ``get_stats_remote`` is not defined in this class, so resolve it
        # defensively: a missing accessor must not kill the polling thread.
        fetch_stats = getattr(self, "get_stats_remote", None)

        while not self._shutdown_event.is_set():
            if clock % self._RESPONSE_POLL_TICKS == 0:
                responses = self.await_responses_remote()
                self.handle_responses(responses)

            if fetch_stats is not None and clock % self._STATS_POLL_TICKS == 0:
                try:
                    self.handle_stats(fetch_stats())  # TODO
                except NotImplementedError:
                    # Stats handling is not implemented yet; stop polling for
                    # stats instead of terminating the whole loop.
                    fetch_stats = None

            clock += 1
            # Wait on the event rather than sleeping so that setting the
            # shutdown event wakes the loop immediately.
            self._shutdown_event.wait(self.clock_unit)

    def setup_mainloop(self):
        """Start ``main_loop_task`` on a daemon thread and register ``shutdown`` with ``atexit``."""
        self.main_loop_thread = threading.Thread(target=self.main_loop_task,
                                                 daemon=True)
        self.main_loop_thread.start()
        atexit.register(self.shutdown)

    def handle_responses(self, responses: list[GenerationResult]) -> None:
        """Route each response to the result queue registered for its client.

        Args:
            responses: the batch returned by ``await_responses_remote``; each
                item must expose ``client_id``. An empty (or ``None``) batch is
                a no-op.

        A response is pushed onto ``self._results[client_id].queue``. When the
        response is final (or is an ``ErrorResponse``) the entry is removed
        from ``self._results``. Queues of type ``_SyncQueue`` are collected and
        notified together once the whole batch has been enqueued.

        NOTE(review): nothing in this class adds entries to ``self._results``;
        presumably the base class or ``submit`` path does - confirm.
        """
        if not responses:
            return

        async_queues = []
        event_loop = None

        for res in responses:
            global_tracer().log_instant("RPC.get")

            client_id = res.client_id
            queue = self._results[client_id].queue
            if isinstance(queue, _SyncQueue):
                queue.put_nowait(res)
                async_queues.append(queue)
                # all the loops are identical
                event_loop = event_loop or queue.loop
            else:
                queue.put(res)

            is_final = is_llm_response(res) and res.result.is_final
            if is_final or isinstance(res, ErrorResponse):
                self._results.pop(client_id)

        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_stats(self, stats: dict):
        """Consume stats fetched from the remote. Not implemented yet."""
        raise NotImplementedError

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """Forward ``request`` to the remote worker without waiting for a reply.

        Returns whatever ``rpc_client.submit`` returns.
        NOTE(review): with ``need_response=False`` this may not be a
        ``GenerationResult`` - confirm against ``RPCClient``.
        """
        # submit is a fire-and-forget operation, don't need to wait for response
        return self.rpc_client.submit(request, need_response=False)

    def await_responses_remote(self):
        """Return the result of ``rpc_client.await_responses()``."""
        return self.rpc_client.await_responses()

    def create_engine_remote(self):
        """Ask the remote side to create the engine via ``rpc_client.create_engine()``."""
        return self.rpc_client.create_engine()  # TODO

    def shutdown_remote(self):
        """Invoke ``rpc_client.shutdown()``."""
        self.rpc_client.shutdown()

    def _create_mpi_session(self, model_world_size: int,
                            mpi_session: Optional[MpiSession]):
        """Return the MPI session the proxy should use.

        Args:
            model_world_size: number of workers for a newly created session
            mpi_session: an externally provided session, used as-is when given

        Returns:
            ``mpi_session`` if provided; otherwise a comm session when
            ``get_spawn_proxy_process_env()`` is truthy, else a new
            ``MpiPoolSession`` with ``model_world_size`` workers.
        """
        mpi_process_pre_spawned: bool = get_spawn_proxy_process_env()

        if mpi_session is not None:
            print_colored_debug('using external mpi session ...\n', "yellow")
            return mpi_session

        if mpi_process_pre_spawned:
            print_colored_debug('create comm session ...\n', "yellow")
            return create_mpi_comm_session(model_world_size)

        print_colored_debug('create pool session ...\n', "yellow")
        return MpiPoolSession(n_workers=model_world_size)

    def _gen_rpc_addr(self):
        """Build an IPC address from the current pid and ``INSTANCE_COUNTER``."""
        process_id = os.getpid()
        return f"ipc:///tmp/rpc-proxy-{process_id}-{GenerationExecutorRpcProxy.INSTANCE_COUNTER}"
0 commit comments