
Commit ddb12ed

enhance rpc_worker and test

Signed-off-by: Superjomn <[email protected]>
1 parent 781da75 commit ddb12ed

File tree

5 files changed: +84, -18 lines changed

tensorrt_llm/executor/rpc.py

Lines changed: 9 additions & 6 deletions
@@ -340,16 +340,21 @@ async def _call_async(self, name, *args, **kwargs):
             *args: Positional arguments
             **kwargs: Keyword arguments
             __rpc_timeout: The timeout (seconds) for the RPC call.
+            __rpc_need_response: Whether the RPC call needs a response.
+                If set to False, the remote call will return immediately.

         Returns:
             The result of the remote method call
         """
+        logger.debug(
+            f"RPC client calling method: {name} with args: {args} and kwargs: {kwargs}"
+        )
         await self._start_reader_if_needed()
-        need_response = kwargs.pop("need_response", True)
+        need_response = kwargs.pop("__rpc_need_response", True)
+        timeout = kwargs.pop("__rpc_timeout", self._timeout)

         request_id = uuid.uuid4().hex
         logger.debug(f"RPC client sending request: {request_id}")
-        timeout = kwargs.pop("__rpc_timeout", self._timeout)
         request = RPCRequest(request_id,
                              name,
                              args,
@@ -395,7 +400,7 @@ def call_async(self, name: str, *args, **kwargs):
         Example:
             result = await client.call_async('remote_method', arg1, arg2, key=value)
         """
-        return self._call_async(name, *args, **kwargs, need_response=True)
+        return self._call_async(name, *args, **kwargs, __rpc_need_response=True)

     def call_future(self, name: str, *args,
                     **kwargs) -> concurrent.futures.Future:
@@ -457,9 +462,7 @@ def __call__(self, *args, **kwargs):

     def call_async(self, *args, **kwargs):
         """Async call - returns coroutine"""
-        return self.client._call_async(self.method_name,
-                                       *args,
-                                       need_response=True,
+        return self.client._call_async(self.method_name, *args,
                                        **kwargs)

     def call_future(self, *args, **kwargs) -> concurrent.futures.Future:
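The core change here is a reserved-prefix convention for RPC control options: `__rpc_need_response` and `__rpc_timeout` are popped out of `**kwargs` before the remaining keywords are forwarded to the remote method, so they cannot collide with a user-level argument that happens to be called `need_response` or `timeout`. A minimal standalone sketch of that idea (generic Python, not the tensorrt_llm implementation; the request dict layout is assumed for illustration):

import uuid


def _extract_rpc_options(kwargs: dict):
    # Pop the reserved "__rpc_*" control keys so they are never forwarded
    # to the remote method as ordinary keyword arguments.
    need_response = kwargs.pop("__rpc_need_response", True)
    timeout = kwargs.pop("__rpc_timeout", None)
    return need_response, timeout


def build_request(name: str, *args, **kwargs) -> dict:
    # Hypothetical request layout, only for illustration.
    need_response, timeout = _extract_rpc_options(kwargs)
    return {
        "id": uuid.uuid4().hex,
        "method": name,
        "args": args,
        "kwargs": kwargs,  # user kwargs only; control keys already removed
        "need_response": need_response,
        "timeout": timeout,
    }


if __name__ == "__main__":
    req = build_request("send_task", 1, 2, priority=5, __rpc_need_response=False)
    assert req["need_response"] is False
    assert "priority" in req["kwargs"]
    assert "__rpc_need_response" not in req["kwargs"]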

tensorrt_llm/executor/rpc_worker.py

Lines changed: 2 additions & 0 deletions
@@ -41,11 +41,13 @@ def __init__(
         self.set_result_queue(self._response_queue)

     def fetch_responses(self) -> list:
+        logger.debug(f"RPC worker {mpi_rank()} is fetching responses")
         super().await_responses()
         qsize = self._response_queue.qsize()
         return [self._response_queue.get() for _ in range(qsize)]

     def shutdown(self):
+        logger.debug(f"RPC worker {mpi_rank()} is shutting down")
         self.shutdown_event.set()
         super().shutdown()
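For context, `fetch_responses` drains whatever is already queued without blocking: it snapshots `qsize()` once and then pops exactly that many items, leaving anything that arrives later for the next call. A self-contained sketch of the pattern with a plain `queue.Queue` (not the worker's actual response queue):

import queue


def drain(q: "queue.Queue") -> list:
    # Snapshot the current size, then pop exactly that many items.
    # Never blocks waiting for new items; anything enqueued afterwards is
    # left for the next call. Safe with a single consumer (qsize() is only
    # approximate when several consumers race on the same queue).
    n = q.qsize()
    return [q.get() for _ in range(n)]


if __name__ == "__main__":
    q = queue.Queue()
    for i in range(3):
        q.put(i)
    assert drain(q) == [0, 1, 2]
    assert drain(q) == []  # returns immediately when the queue is empty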

tensorrt_llm/executor/worker_base.py

Lines changed: 15 additions & 1 deletion
@@ -18,6 +18,7 @@
 from ..llmapi.llm_args import PybindMirror
 from ..llmapi.tracer import global_tracer
 from ..llmapi.utils import _SyncQueue, print_colored_debug
+from ..logger import logger
 from ..lora_manager import LoraConfig, LoraManager
 from ..prompt_adapter_manager import PromptAdapterManager
 from ..runtime import ModelConfig
@@ -29,7 +30,11 @@
 from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
 from .result import (GenerationResult, LogProbsResult, ResponseWrapper,
                      compute_logprobs)
-from .utils import ErrorResponse, RequestError, is_llm_response
+from .utils import (ErrorResponse, RequestError, enable_llm_debug,
+                    is_llm_response)
+
+if enable_llm_debug():
+    logger.set_level("debug")

 __all__ = [
     "WorkerBase",
@@ -405,6 +410,7 @@ def __exit__(self, exc_type, exc_value, traceback):

     def await_responses(self) -> None:
         self._await_response_helper()
+        logger.debug(f"worker done await_responses")

     def fetch_kv_cache_events(self) -> list:
         if isinstance(self.engine, tllm.Executor):
@@ -472,6 +478,11 @@ def shutdown(self):
         # Check if there are any errors from the threads before shutdown.
         self._handle_background_error()

+    def _has_background_error(self) -> bool:
+        # TODO[Superjomn]: The worker background error should be deprecated once
+        # RPC approach is supported.
+        return not self._error_queue.empty()
+

 class AwaitResponseHelper:
     ''' Multiple-implementations for await_response for performance. '''
@@ -518,8 +529,11 @@ def responses_handler(self, responses: List[tllm.Response]):

     def __call__(self) -> bool:
         ''' This method should be called by a ManagedThread. '''
+        logger.debug(f"await_response: {self.worker.engine}")
         responses = self.worker.engine.await_responses(
             timeout=datetime.timedelta(milliseconds=100))
+        logger.debug(f"PyExecutor returned {len(responses)} responses")
+
         # filter since The _engine_response_callback may return None
         responses = list(
             filter(
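The import-time `enable_llm_debug()` check switches the module's logger to debug level once, so the new `logger.debug(...)` calls become visible without touching any call site. A rough generic sketch of this environment-gated pattern follows; the `TLLM_DEBUG` variable name below is invented for illustration and is not the flag the library actually reads:

import logging
import os

logger = logging.getLogger("worker_sketch")


def debug_enabled() -> bool:
    # Hypothetical flag name, for illustration only; the real code calls
    # enable_llm_debug() from tensorrt_llm's executor utils instead.
    return os.environ.get("TLLM_DEBUG", "0") == "1"


# Evaluated once at import time: raise verbosity so every later
# logger.debug(...) call in the module becomes visible.
if debug_enabled():
    logging.basicConfig(level=logging.DEBUG)
    logger.setLevel(logging.DEBUG)

if __name__ == "__main__":
    logger.debug("printed only when TLLM_DEBUG=1 is exported")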

tests/unittest/executor/test_rpc.py

Lines changed: 3 additions & 3 deletions
@@ -123,7 +123,7 @@ def get_task_submitted(self) -> bool:
     server.start()
     time.sleep(0.1)
     client = RPCClient("ipc:///tmp/rpc_test_no_wait")
-    client.send_task(need_response=False)
+    client.send_task(__rpc_need_response=False)
     time.sleep(0.1)  # wait for some time to make sure the task is submitted
     assert client.get_task_submitted()

@@ -149,14 +149,14 @@ def send_task(self) -> None:

     time_start = time.time()
     for i in range(100):
-        client.send_task(need_response=False)
+        client.send_task(__rpc_need_response=False)
     time_end = time.time()

     no_wait_time = time_end - time_start

     time_start = time.time()
     for i in range(100):
-        client.send_task(need_response=True)
+        client.send_task(__rpc_need_response=True)
     time_end = time.time()
     wait_time = time_end - time_start
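This timing comparison works because a call sent with `__rpc_need_response=False` returns as soon as the request is dispatched, while a waited call pays the full round trip on every iteration. A toy, self-contained illustration of that difference using plain asyncio (not the RPC client under test; the delay and iteration count are arbitrary):

import asyncio
import time


async def fake_rpc(delay: float = 0.01) -> str:
    # Stand-in for a remote call whose reply arrives after `delay` seconds.
    await asyncio.sleep(delay)
    return "ok"


async def main() -> None:
    n = 20

    # "Fire and forget": schedule the calls and move on immediately.
    start = time.perf_counter()
    tasks = [asyncio.create_task(fake_rpc()) for _ in range(n)]
    no_wait_time = time.perf_counter() - start

    # Waited calls: each one blocks until its response comes back.
    start = time.perf_counter()
    for _ in range(n):
        await fake_rpc()
    wait_time = time.perf_counter() - start

    await asyncio.gather(*tasks)  # let the scheduled calls finish cleanly
    print(f"no_wait={no_wait_time:.4f}s wait={wait_time:.4f}s")
    assert no_wait_time < wait_time


if __name__ == "__main__":
    asyncio.run(main())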

tests/unittest/executor/test_rpc_worker.py

Lines changed: 55 additions & 8 deletions
@@ -4,6 +4,8 @@
 import time
 from concurrent.futures import ProcessPoolExecutor

+from test_worker_base import TestWorkerBase
+
 from tensorrt_llm.executor.request import GenerationRequest
 from tensorrt_llm.executor.rpc import RPCClient
 from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy
@@ -20,12 +22,19 @@

 class TestRpcWorker:

+    def __init__(self):
+        self.executor_config = TestWorkerBase.create_fake_executor_config(
+            model_path)
+
     def create_tp1_worker_process(self):
         addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr()
         # Use spawn method instead of fork
         mp_context = multiprocessing.get_context('spawn')
         pool = ProcessPoolExecutor(max_workers=1, mp_context=mp_context)
-        pool.submit(RpcWorker.main_task, engine=model_path, rpc_addr=addr)
+        pool.submit(RpcWorker.main_task,
+                    engine=model_path,
+                    rpc_addr=addr,
+                    executor_config=self.executor_config)
         return pool, addr

     def create_rpc_client(self, addr: str):
@@ -35,15 +44,53 @@ def create_rpc_client(self, addr: str):
     def test_main(self):
         pool, addr = self.create_tp1_worker_process()
         client = self.create_rpc_client(addr)
-        client.setup_engine(engine=model_path)
+        print("call setup_engine")
+        client.setup_engine(engine=model_path,
+                            executor_config=self.executor_config,
+                            __rpc_timeout=120)
+        print("call submit")
         time.sleep(1)
-        client.submit(
-            GenerationRequest(prompt_token_ids=[3, 4, 5],
-                              sampling_params=SamplingParams(max_tokens=10)))
-        responses = client.fetch_responses()
-        assert responses

-        client.shutdown()
+        def process_request():
+            ret = client.submit(GenerationRequest(
+                prompt_token_ids=[3, 4, 5],
+                sampling_params=SamplingParams(max_tokens=10)),
+                                __rpc_need_response=False)
+            assert ret is None
+
+            print(f"submit result: {ret}")
+            print("call fetch_responses")
+            # NOTE: known issue, the responses should be fetched before shutdown,
+            # or the shutdown will hang.
+            results = []
+            for i in range(3):
+                time.sleep(3)
+                results.extend(client.fetch_responses())
+            print(f"fetch_responses result: {results}")
+            assert len(results) == 1
+
+        def process_request_streaming():
+            ret = client.submit(prompt_token_ids=[3, 4, 5],
+                                sampling_params=SamplingParams(max_tokens=10),
+                                streaming=True,
+                                __rpc_need_response=False)
+            assert ret is None
+
+            print("call fetch_responses")
+            # NOTE: known issue, the responses should be fetched before shutdown,
+            # or the shutdown will hang.
+            results = []
+            for i in range(3):
+                time.sleep(3)
+                results.extend(client.fetch_responses())
+            print(f"fetch_responses result: {results}")
+            print(f"generate_async result: {results}")
+
+        process_request()
+        process_request_streaming()
+
+        print("call shutdown")
+        client.shutdown(__rpc_timeout=10)
         pool.shutdown()
4996
