Skip to content

Commit 781da75

Browse files
committed
refine rpc timeout
TODO: The timeout error may need a specific Error type
1 parent f887cf0 commit 781da75

File tree

2 files changed

+107
-11
lines changed

2 files changed

+107
-11
lines changed

tensorrt_llm/executor/rpc.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class RPCRequest(NamedTuple):
2727
args: tuple
2828
kwargs: dict
2929
need_response: bool = True
30+
timeout: float = 0.5
3031

3132

3233
class RPCResponse(NamedTuple):
@@ -176,15 +177,35 @@ async def _worker_routine(self, stop_event: threading.Event):
176177
if req.method_name in self._functions:
177178
try:
178179
if self._executor is not None:
179-
# Dispatch to worker thread and await result
180+
# Dispatch to worker thread and await result with timeout
180181
loop = asyncio.get_running_loop()
181-
result = await loop.run_in_executor(
182-
self._executor, self._functions[req.method_name],
183-
*req.args, **req.kwargs)
182+
183+
# Create a wrapper function to handle keyword arguments
184+
def call_with_kwargs():
185+
return self._functions[req.method_name](
186+
*req.args, **req.kwargs)
187+
188+
result = await asyncio.wait_for(loop.run_in_executor(
189+
self._executor, call_with_kwargs),
190+
timeout=req.timeout)
184191
else:
185-
result = self._functions[req.method_name](*req.args,
186-
**req.kwargs)
192+
# For synchronous execution, we need to run in executor to support timeout
193+
loop = asyncio.get_running_loop()
194+
195+
# Create a wrapper function to handle keyword arguments
196+
def call_with_kwargs():
197+
return self._functions[req.method_name](
198+
*req.args, **req.kwargs)
199+
200+
result = await asyncio.wait_for(loop.run_in_executor(
201+
None, call_with_kwargs),
202+
timeout=req.timeout)
187203
response = RPCResponse(req.request_id, 'OK', result)
204+
except asyncio.TimeoutError:
205+
response = RPCResponse(
206+
req.request_id, 'ERROR',
207+
f"Method '{req.method_name}' timed out after {req.timeout} seconds"
208+
)
188209
except Exception:
189210
tb = traceback.format_exc()
190211
response = RPCResponse(req.request_id, 'ERROR', tb)
@@ -313,13 +334,28 @@ async def _start_reader_if_needed(self):
313334
self._reader_task = loop.create_task(self._response_reader())
314335

315336
async def _call_async(self, name, *args, **kwargs):
316-
"""Async version of RPC call."""
337+
"""Async version of RPC call.
338+
Args:
339+
name: Method name to call
340+
*args: Positional arguments
341+
**kwargs: Keyword arguments
342+
__rpc_timeout: The timeout (seconds) for the RPC call.
343+
344+
Returns:
345+
The result of the remote method call
346+
"""
317347
await self._start_reader_if_needed()
318348
need_response = kwargs.pop("need_response", True)
319349

320350
request_id = uuid.uuid4().hex
321351
logger.debug(f"RPC client sending request: {request_id}")
322-
request = RPCRequest(request_id, name, args, kwargs, need_response)
352+
timeout = kwargs.pop("__rpc_timeout", self._timeout)
353+
request = RPCRequest(request_id,
354+
name,
355+
args,
356+
kwargs,
357+
need_response,
358+
timeout=timeout)
323359
logger.debug(f"RPC client sending request: {request}")
324360
await self._client_socket.put_async(request)
325361

@@ -331,9 +367,12 @@ async def _call_async(self, name, *args, **kwargs):
331367
self._pending_futures[request_id] = future
332368

333369
try:
334-
return await asyncio.wait_for(future, self._timeout)
370+
# If timeout, the remote call should return a timeout error timely,
371+
# so we add 1 second to the timeout to ensure the client can get
372+
# that result.
373+
return await asyncio.wait_for(future, timeout + 1)
335374
except asyncio.TimeoutError:
336-
raise RPCError(f"Request '{name}' timed out after {self._timeout}s")
375+
raise RPCTimeout(f"Request '{name}' timed out after {timeout}s")
337376
finally:
338377
self._pending_futures.pop(request_id, None)
339378

tests/unittest/executor/test_rpc.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,66 @@ def cal(self, n: int):
183183

184184
time_start = time.time()
185185
for i in range(10000):
186-
ret = client.cal(i) # sync call
186+
ret = client.cal(i, __rpc_timeout=10) # sync call
187187
assert ret == i * 2, f"{ret} != {i * 2}"
188188
time_end = time.time()
189189
print(
190190
f"Time taken: {time_end - time_start} seconds, {10000 / (time_end - time_start)} calls/second"
191191
)
192+
193+
194+
@pytest.mark.parametrize("use_async", [True, False])
195+
def test_rpc_timeout(use_async: bool):
196+
"""Test RPC timeout functionality.
197+
198+
Args:
199+
use_async: Whether to test async RPC calls or sync RPC calls
200+
"""
201+
202+
class App:
203+
204+
def slow_operation(self, delay: float):
205+
"""A method that takes a long time to complete."""
206+
time.sleep(delay)
207+
return "completed"
208+
209+
with RPCServer(App()) as server:
210+
server.bind("ipc:///tmp/rpc_test_timeout")
211+
server.start()
212+
time.sleep(0.1)
213+
client = RPCClient("ipc:///tmp/rpc_test_timeout")
214+
215+
# Test that a short timeout causes RPCTimeout exception
216+
with pytest.raises(RPCError) as exc_info:
217+
if use_async:
218+
# Test async call with timeout
219+
import asyncio
220+
221+
async def test_async_timeout():
222+
return await client.call_async('slow_operation',
223+
2.0,
224+
__rpc_timeout=0.1)
225+
226+
asyncio.run(test_async_timeout())
227+
else:
228+
# Test sync call with timeout
229+
client.slow_operation(2.0, __rpc_timeout=0.1)
230+
231+
assert "timed out" in str(
232+
exc_info.value), f"Timeout message not found: {exc_info.value}"
233+
234+
# Test that a long timeout allows the operation to complete
235+
if use_async:
236+
# Test async call with sufficient timeout
237+
import asyncio
238+
239+
async def test_async_success():
240+
return await client.call_async('slow_operation',
241+
0.1,
242+
__rpc_timeout=1.0)
243+
244+
result = asyncio.run(test_async_success())
245+
else:
246+
result = client.slow_operation(0.1, __rpc_timeout=1.0)
247+
248+
assert result == "completed"

0 commit comments

Comments
 (0)