Skip to content

Commit 33c3dbb

Browse files
committed
add prototype for rpc worker and proxy
Signed-off-by: Superjomn <[email protected]>
1 parent a61f1ca commit 33c3dbb

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed

tensorrt_llm/executor/rpc_proxy.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
import atexit
2+
import os
3+
import threading
4+
import time
5+
from typing import Optional
6+
7+
from ..llmapi.mpi_session import MpiPoolSession, MpiSession
8+
from ..llmapi.tracer import global_tracer
9+
from ..llmapi.utils import _SyncQueue, print_colored_debug
10+
from .executor import GenerationExecutor
11+
from .postproc_worker import PostprocWorkerConfig
12+
from .request import GenerationRequest
13+
from .result import GenerationResult
14+
from .rpc import RPCClient
15+
from .rpc_worker import rpc_worker_main
16+
from .utils import (ErrorResponse, create_mpi_comm_session,
17+
get_spawn_proxy_process_env, is_llm_response)
18+
19+
20+
class GenerationExecutorRpcProxy(GenerationExecutor):
    """Client-side executor proxy that drives remote RPC workers.

    Spawns the workers through an MPI session, talks to them through an
    RPC client over a per-instance IPC endpoint, and pumps responses back
    into the per-request result queues from a background main-loop thread.
    """

    # NOTE: this is a global counter for the number of instances of this
    # class; it keeps every proxy's IPC address unique within one process
    # (see _gen_rpc_addr).
    INSTANCE_COUNTER = 0

    def __init__(self,
                 worker_kwargs: dict,
                 model_world_size: int = 1,
                 mpi_session: Optional[MpiSession] = None,
                 *,
                 postproc_worker_config: Optional[PostprocWorkerConfig] = None,
                 is_llm_executor: Optional[bool] = None,
                 garbage_collection_gen0_threshold: Optional[int] = None,
                 clock_unit: int = 1):
        """
        Args:
            worker_kwargs: kwargs forwarded to the rpc worker entry point
            model_world_size: the world size of the model
            mpi_session: the mpi session to use; one is created when None
            postproc_worker_config: the postproc worker config
            is_llm_executor: whether this is an llm executor
            garbage_collection_gen0_threshold: the garbage collection gen0
                threshold (currently only stored, not yet forwarded)
            clock_unit: the unit of the clock, 1 means 1 second
        """
        GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1

        # Fix: launch_workers() and main_loop_task() read these attributes,
        # but the original prototype never stored them.
        self.worker_kwargs = worker_kwargs
        self.clock_unit = clock_unit
        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold

        self.rpc_addr = self._gen_rpc_addr()
        self.rpc_client = RPCClient(self.rpc_addr)

        postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
        )

        super().__init__(
            num_postprocess_workers=postproc_worker_config.
            num_postprocess_workers,
            postprocess_tokenizer_dir=postproc_worker_config.
            postprocess_tokenizer_dir,
            is_llm_executor=is_llm_executor,
        )

        # Fix: _create_mpi_session() assigns self.mpi_session itself and
        # returns None; the original rebound self.mpi_session to that None,
        # which broke the assertion in launch_workers().
        self._create_mpi_session(model_world_size, mpi_session)

        self._shutdown_event = threading.Event()

        self.launch_workers()
        # TODO: replace the fixed sleep with an explicit readiness handshake.
        time.sleep(1)  # wait for the workers to launch

        # Invoke model creation on the remote
        # TBD: Move model creation to the mpi task, or left in RPC?
        self.create_engine_remote()

        self.setup_mainloop()

    def launch_workers(self):
        """Submit the rpc worker entry point to every rank of the session."""
        assert self.mpi_session is not None
        self.mpi_session.submit(rpc_worker_main,
                                rpc_addr=self.rpc_addr,
                                **self.worker_kwargs)

    def main_loop_task(self):
        """
        Main loop of the proxy, it will invoke the actions periodically.
        """
        clock = 0
        while not self._shutdown_event.is_set():
            # Fetch and dispatch ready responses on every tick.
            responses = self.await_responses_remote()
            self.handle_responses(responses)
            # Poll stats only every 10th tick.
            if clock % 10 == 0:
                try:
                    stats = self.get_stats_remote()  # TODO
                    self.handle_stats(stats)
                except NotImplementedError:
                    # Fix: stats handling is not implemented yet; the
                    # original let this exception kill the main-loop thread
                    # on the first tick, so responses were never delivered.
                    pass

            clock += 1
            time.sleep(self.clock_unit)

    def setup_mainloop(self):
        """Start the daemon main-loop thread and register shutdown with
        atexit so teardown also runs at interpreter exit."""
        self.main_loop_thread = threading.Thread(target=self.main_loop_task,
                                                 daemon=True)
        self.main_loop_thread.start()
        atexit.register(self.shutdown)

    def handle_responses(self, responses: list[GenerationResult]) -> bool:
        """Route each response to its request's result queue.

        Sync queues are filled immediately and notified in one batch at the
        end (all of them share the same event loop); other queues are fed
        with a blocking put. Final or error responses also evict the
        request's entry from self._results.
        """
        async_queues = []
        event_loop = None

        def process_res(res):
            client_id = res.client_id
            nonlocal event_loop
            nonlocal async_queues

            queue = self._results[client_id].queue
            if isinstance(queue, _SyncQueue):
                queue.put_nowait(res)
                async_queues.append(queue)
                # all the loops are identical
                event_loop = event_loop or queue.loop
            else:
                queue.put(res)

            if (is_llm_response(res) and res.result.is_final) or isinstance(
                    res, ErrorResponse):
                self._results.pop(client_id)

        for res in responses:
            global_tracer().log_instant("RPC.get")
            process_res(res)

        if async_queues:
            _SyncQueue.notify_many(event_loop, async_queues)

    def handle_stats(self, stats: dict):
        # TODO: consume the stats fetched from the workers.
        raise NotImplementedError

    def submit(self, request: GenerationRequest) -> GenerationResult:
        """Send a request to the workers.

        submit is a fire-and-forget operation, don't need to wait for
        response; results arrive later through the main loop.
        """
        return self.rpc_client.submit(request, need_response=False)

    def await_responses_remote(self):
        """Fetch any ready responses from the workers."""
        return self.rpc_client.await_responses()

    def get_stats_remote(self):
        # Fix: main_loop_task() calls this, but the original never defined
        # it; mirror the other *_remote thin wrappers.
        return self.rpc_client.get_stats()  # TODO

    def create_engine_remote(self):
        """Ask the workers to build the engine."""
        return self.rpc_client.create_engine()  # TODO

    def shutdown_remote(self):
        """Ask the workers to shut down."""
        self.rpc_client.shutdown()

    def _create_mpi_session(self, model_world_size: int,
                            mpi_session: Optional[MpiSession]):
        """Assign self.mpi_session: reuse the given session, attach to a
        pre-spawned MPI world, or spawn a fresh worker pool."""
        mpi_process_pre_spawned: bool = get_spawn_proxy_process_env()
        if mpi_session is None:
            if mpi_process_pre_spawned:
                print_colored_debug('create comm session ...\n', "yellow")
                self.mpi_session = create_mpi_comm_session(model_world_size)
            else:
                print_colored_debug('create pool session ...\n', "yellow")
                self.mpi_session = MpiPoolSession(n_workers=model_world_size)
        else:
            print_colored_debug('using external mpi session ...\n', "yellow")
            self.mpi_session = mpi_session

    def _gen_rpc_addr(self):
        """Build a process- and instance-unique IPC endpoint address."""
        process_id = os.getpid()
        return f"ipc:///tmp/rpc-proxy-{process_id}-{GenerationExecutorRpcProxy.INSTANCE_COUNTER}"
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from pathlib import Path
2+
from threading import Event
3+
from typing import Optional, Union
4+
5+
from ..bindings import executor as tllm
6+
from ..builder import Engine
7+
from ..lora_manager import LoraConfig
8+
from ..sampling_params import BatchedLogitsProcessor
9+
from .postproc_worker import PostprocWorkerConfig
10+
from .rpc import RpcService
11+
from .worker_base import WorkerBase
12+
13+
14+
class RpcWorker(WorkerBase):
    """Worker that runs on an MPI rank and is driven remotely over RPC.

    Adds a shutdown_event on top of WorkerBase so that rpc_worker_main can
    block until a remote shutdown() call arrives.
    """

    def __init__(
        self,
        engine: Union[Path, Engine],
        executor_config: Optional[tllm.ExecutorConfig] = None,
        is_llm_executor: Optional[bool] = None,
    ) -> None:
        super().__init__(engine=engine,
                         executor_config=executor_config,
                         is_llm_executor=is_llm_executor)
        # Set exactly once, by shutdown(); waited on by rpc_worker_main.
        self.shutdown_event = Event()

    def shutdown(self):
        """Release anyone waiting on shutdown_event, then run the
        base-class teardown."""
        self.shutdown_event.set()
        super().shutdown()
30+
31+
32+
def rpc_worker_main(
    engine: Union[Path, Engine],
    rpc_addr: str,
    executor_config: Optional[tllm.ExecutorConfig] = None,
    batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
    postproc_worker_config: Optional[PostprocWorkerConfig] = None,
    is_llm_executor: Optional[bool] = None,
    lora_config: Optional[LoraConfig] = None,
    garbage_collection_gen0_threshold: Optional[int] = None,
) -> None:
    """Entry point executed on each MPI rank.

    Creates an RpcWorker, builds its engine, serves the worker's API over
    RPC at ``rpc_addr``, and blocks until a remote shutdown() request
    arrives, after which the RPC service is torn down.

    Args:
        engine: path to a built engine, or an in-memory Engine
        rpc_addr: address the RPC service binds to
        executor_config: the executor config
        batched_logits_processor: the batched logits processor
        postproc_worker_config: the postproc worker config
        is_llm_executor: whether this is an llm executor
        lora_config: the lora config
        garbage_collection_gen0_threshold: the garbage collection gen0
            threshold
    """
    # Step 1: Create the worker instance.
    # Fix: forward is_llm_executor here as well — the original dropped it,
    # leaving the worker's flag at its default even when the caller set it.
    worker = RpcWorker(engine=engine,
                       executor_config=executor_config,
                       is_llm_executor=is_llm_executor)
    worker.create_engine(
        engine=engine,
        executor_config=executor_config,
        batched_logits_processor=batched_logits_processor,
        postproc_worker_config=postproc_worker_config,
        is_llm_executor=is_llm_executor,
        lora_config=lora_config,
        garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)

    # Step 2: Create the RPC service, it will expose all the APIs of the
    # worker as remote calls to the client.
    rpc_service = RpcService(worker)
    rpc_service.bind(rpc_addr)
    rpc_service.start()

    try:
        # Step 3: Wait for the worker to shutdown.
        worker.shutdown_event.wait()
    finally:
        # Fix: tear the service down even if the wait is interrupted, so
        # the bound endpoint is not leaked.
        rpc_service.shutdown()

0 commit comments

Comments
 (0)