Skip to content

Commit f887cf0

Browse files
committed
amend test_worker_base
Signed-off-by: Superjomn <[email protected]>
1 parent 4e56982 commit f887cf0

File tree

6 files changed

+137
-55
lines changed

6 files changed

+137
-55
lines changed

tensorrt_llm/executor/rpc_proxy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from .request import GenerationRequest
1313
from .result import GenerationResult
1414
from .rpc import RPCClient
15-
from .rpc_worker import rpc_worker_main
1615
from .utils import (ErrorResponse, create_mpi_comm_session,
1716
get_spawn_proxy_process_env, is_llm_response)
1817

@@ -42,7 +41,7 @@ def __init__(self,
4241
"""
4342

4443
GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1
45-
self.rpc_addr = self._gen_rpc_addr()
44+
self.rpc_addr = self.gen_uniq_rpc_addr()
4645
self.rpc_client = RPCClient(self.rpc_addr)
4746

4847
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
@@ -157,6 +156,7 @@ def _create_mpi_session(self, model_world_size: int,
157156
print_colored_debug('using external mpi session ...\n', "yellow")
158157
self.mpi_session = mpi_session
159158

160-
def _gen_rpc_addr(self):
159+
@staticmethod
160+
def gen_uniq_rpc_addr() -> str:
161161
process_id = os.getpid()
162162
return f"ipc:///tmp/rpc-proxy-{process_id}-{GenerationExecutorRpcProxy.INSTANCE_COUNTER}"
Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,30 @@
11
from pathlib import Path
2+
from queue import Queue
23
from threading import Event
34
from typing import Optional, Union
45

6+
from .._utils import mpi_rank
57
from ..bindings import executor as tllm
68
from ..builder import Engine
9+
from ..logger import logger
710
from ..lora_manager import LoraConfig
811
from ..sampling_params import BatchedLogitsProcessor
912
from .postproc_worker import PostprocWorkerConfig
10-
from .rpc import RpcService
13+
from .rpc import RPCServer
1114
from .worker_base import WorkerBase
1215

1316

1417
class RpcWorker(WorkerBase):
18+
"""
19+
An RPC wrapper for the WorkerBase class.
20+
21+
Actions:
22+
- `setup_engine`: Setup the engine.
23+
- `fetch_responses`: Fetch the latest responses from engine.
24+
- `fetch_stats`: Fetch the latest stats from engine.
25+
- `fetch_kv_cache_events`: Fetch the latest kv cache events from engine.
26+
- `shutdown`: Shutdown the worker.
27+
"""
1528

1629
def __init__(
1730
self,
@@ -24,37 +37,54 @@ def __init__(
2437
is_llm_executor=is_llm_executor)
2538
self.shutdown_event = Event()
2639

40+
self._response_queue = Queue()
41+
self.set_result_queue(self._response_queue)
42+
43+
def fetch_responses(self) -> list:
44+
super().await_responses()
45+
qsize = self._response_queue.qsize()
46+
return [self._response_queue.get() for _ in range(qsize)]
47+
2748
def shutdown(self):
2849
self.shutdown_event.set()
2950
super().shutdown()
3051

52+
@staticmethod
53+
def main_task(
54+
engine: Union[Path, Engine],
55+
rpc_addr: str,
56+
*,
57+
executor_config: Optional[tllm.ExecutorConfig] = None,
58+
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
59+
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
60+
is_llm_executor: Optional[bool] = None,
61+
lora_config: Optional[LoraConfig] = None,
62+
garbage_collection_gen0_threshold: Optional[int] = None,
63+
) -> None:
64+
# Step 1: Create the worker instance
65+
worker = RpcWorker(engine=engine, executor_config=executor_config)
66+
67+
if mpi_rank() != 0:
68+
logger.debug(f"Worker {mpi_rank()} is setting up the engine")
69+
# The non-leader worker will setup the engine immediately.
70+
# The leader worker will wait for the RPC call to propagate the
71+
# potential error.
72+
worker.setup_engine(
73+
engine=engine,
74+
executor_config=executor_config,
75+
batched_logits_processor=batched_logits_processor,
76+
postproc_worker_config=postproc_worker_config,
77+
is_llm_executor=is_llm_executor,
78+
lora_config=lora_config,
79+
garbage_collection_gen0_threshold=
80+
garbage_collection_gen0_threshold)
81+
82+
if mpi_rank() == 0:
83+
# Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client
84+
rpc_server = RPCServer(worker)
85+
rpc_server.bind(rpc_addr)
86+
rpc_server.start()
3187

32-
def rpc_worker_main(
33-
engine: Union[Path, Engine],
34-
rpc_addr: str,
35-
executor_config: Optional[tllm.ExecutorConfig] = None,
36-
batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
37-
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
38-
is_llm_executor: Optional[bool] = None,
39-
lora_config: Optional[LoraConfig] = None,
40-
garbage_collection_gen0_threshold: Optional[int] = None,
41-
) -> None:
42-
# Step 1: Create the worker instance
43-
worker = RpcWorker(engine=engine, executor_config=executor_config)
44-
worker.create_engine(
45-
engine=engine,
46-
executor_config=executor_config,
47-
batched_logits_processor=batched_logits_processor,
48-
postproc_worker_config=postproc_worker_config,
49-
is_llm_executor=is_llm_executor,
50-
lora_config=lora_config,
51-
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
52-
53-
# Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client
54-
rpc_service = RpcService(worker)
55-
rpc_service.bind(rpc_addr)
56-
rpc_service.start()
57-
58-
# Step 3: Wait for the worker to shutdown
59-
worker.shutdown_event.wait()
60-
rpc_service.shutdown()
88+
# Step 3: Wait for the worker to shutdown
89+
worker.shutdown_event.wait()
90+
rpc_server.shutdown()

tensorrt_llm/executor/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
6767
processor_batched=batched_logits_processor, replicate=False)
6868

69-
self.create_engine(
69+
self.setup_engine(
7070
engine=engine,
7171
executor_config=executor_config,
7272
lora_config=lora_config,

tensorrt_llm/executor/worker_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(
7878
self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
7979
self._runtime_model_config: Optional[ModelConfig] = None
8080

81-
def create_engine(
81+
def setup_engine(
8282
self,
8383
engine: Union[Path, Engine],
8484
executor_config: Optional[tllm.ExecutorConfig] = None,
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import multiprocessing
2+
import os
3+
import sys
4+
import time
5+
from concurrent.futures import ProcessPoolExecutor
6+
7+
from tensorrt_llm.executor.request import GenerationRequest
8+
from tensorrt_llm.executor.rpc import RPCClient
9+
from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
10+
from tensorrt_llm.executor.rpc_worker import RpcWorker
11+
from tensorrt_llm.sampling_params import SamplingParams
12+
13+
# isort: off
14+
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
15+
from utils.llm_data import llm_models_root
16+
# isort: on
17+
18+
model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
19+
20+
class TestRpcWorker:
    """Integration test driving RpcWorker.main_task through the RPC client."""

    def create_tp1_worker_process(self):
        """Spawn a TP=1 worker process running RpcWorker.main_task.

        Returns:
            (pool, addr): the process pool owning the worker process and the
            RPC address the worker serves on.
        """
        addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr()
        # Use spawn instead of fork: the parent may hold CUDA/MPI state that
        # is not fork-safe.
        mp_context = multiprocessing.get_context('spawn')
        pool = ProcessPoolExecutor(max_workers=1, mp_context=mp_context)
        pool.submit(RpcWorker.main_task, engine=model_path, rpc_addr=addr)
        return pool, addr

    def create_rpc_client(self, addr: str):
        """Connect an RPC client to the worker listening at *addr*."""
        return RPCClient(addr)

    def test_main(self):
        pool, addr = self.create_tp1_worker_process()
        client = self.create_rpc_client(addr)
        try:
            client.setup_engine(engine=model_path)
            time.sleep(1)
            client.submit(
                GenerationRequest(prompt_token_ids=[3, 4, 5],
                                  sampling_params=SamplingParams(max_tokens=10)))
            # Poll with a deadline instead of a single immediate fetch:
            # generation latency varies across machines, and one fetch right
            # after submit() is flaky.
            responses = []
            deadline = time.monotonic() + 60
            while time.monotonic() < deadline:
                responses.extend(client.fetch_responses())
                if responses:
                    break
                time.sleep(0.5)
            assert responses
        finally:
            # Always tear down the remote worker and the pool, even when the
            # assertion fails, so the test never leaks a worker process.
            client.shutdown()
            pool.shutdown()
if __name__ == '__main__':
    # Allow running the test directly without pytest.
    TestRpcWorker().test_main()

tests/unittest/executor/test_worker_base.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import sys
33
import time
4-
from queue import Queue
54

65
# isort: off
76
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..")
@@ -21,37 +20,37 @@
2120

2221
class TestWorkerBase:
2322

23+
class FakeWorker(WorkerBase):
24+
25+
def __init__(self, engine: str):
26+
super().__init__(engine=engine)
27+
executor_config = TestWorkerBase.create_fake_executor_config(engine)
28+
self.setup_engine(engine=engine, executor_config=executor_config)
29+
2430
def test_create_engine(self):
25-
with WorkerBase(engine=model_path) as worker:
26-
pass
31+
with self.FakeWorker(engine=model_path) as worker:
32+
print(f"Created engine: {worker.engine}")
2733

2834
def test_submit_request(self):
2935
sampling_params = SamplingParams(max_tokens=10)
3036
request = GenerationRequest(prompt_token_ids=[3, 4, 5],
3137
sampling_params=sampling_params)
32-
with WorkerBase(engine=model_path) as worker:
33-
worker.submit(request)
34-
35-
def test_await_responses(self):
36-
sampling_params = SamplingParams(max_tokens=10)
37-
request = GenerationRequest(prompt_token_ids=[3, 4, 5],
38-
sampling_params=sampling_params)
39-
with WorkerBase(engine=model_path) as worker:
40-
result_queue = Queue()
41-
worker.set_result_queue(result_queue)
42-
38+
with self.FakeWorker(engine=model_path) as worker:
39+
print(f"Created engine: {worker.engine}")
4340
worker.submit(request)
4441
for i in range(10):
42+
time.sleep(0.5)
4543
worker.await_responses()
46-
47-
assert result_queue.qsize() > 0
44+
print(f"Submitted request: {request}")
45+
time.sleep(6)
4846

4947
def test_fetch_stats(self):
5048
request = GenerationRequest(
5149
prompt_token_ids=[3, 4, 5],
5250
sampling_params=SamplingParams(max_tokens=10))
53-
with WorkerBase(engine=model_path) as worker:
51+
with self.FakeWorker(engine=model_path) as worker:
5452
worker.submit(request)
53+
time.sleep(1)
5554
worker.await_responses()
5655
stats = worker.fetch_stats()
5756
assert len(stats) > 0
@@ -60,15 +59,16 @@ def test_dispatch_stats_task(self):
6059
request = GenerationRequest(
6160
prompt_token_ids=[3, 4, 5],
6261
sampling_params=SamplingParams(max_tokens=10))
63-
with WorkerBase(engine=model_path) as worker:
62+
with self.FakeWorker(engine=model_path) as worker:
6463
worker.submit(request)
6564
worker.await_responses()
6665
worker.dispatch_stats_task()
6766
time.sleep(10)
6867
stats = worker.fetch_stats()
6968
assert len(stats) == 1
7069

71-
def _create_executor_config(self):
70+
@staticmethod
71+
def create_fake_executor_config(model_path):
7272
llm_args = LlmArgs(model=model_path, cuda_graph_config=None)
7373

7474
executor_config = tllm.ExecutorConfig(1)
@@ -92,4 +92,4 @@ def _create_executor_config(self):
9292

if __name__ == "__main__":
    # Allow running the test directly without pytest.
    TestWorkerBase().test_fetch_stats()

0 commit comments

Comments
 (0)