@@ -27,6 +27,7 @@ class RPCRequest(NamedTuple):
2727 args : tuple
2828 kwargs : dict
2929 need_response : bool = True
30+ timeout : float = 0.5
3031
3132
3233class RPCResponse (NamedTuple ):
@@ -176,15 +177,35 @@ async def _worker_routine(self, stop_event: threading.Event):
176177 if req .method_name in self ._functions :
177178 try :
178179 if self ._executor is not None :
179- # Dispatch to worker thread and await result
180+ # Dispatch to worker thread and await result with timeout
180181 loop = asyncio .get_running_loop ()
181- result = await loop .run_in_executor (
182- self ._executor , self ._functions [req .method_name ],
183- * req .args , ** req .kwargs )
182+
183+ # Create a wrapper function to handle keyword arguments
184+ def call_with_kwargs ():
185+ return self ._functions [req .method_name ](
186+ * req .args , ** req .kwargs )
187+
188+ result = await asyncio .wait_for (loop .run_in_executor (
189+ self ._executor , call_with_kwargs ),
190+ timeout = req .timeout )
184191 else :
185- result = self ._functions [req .method_name ](* req .args ,
186- ** req .kwargs )
192+ # For synchronous execution, we need to run in executor to support timeout
193+ loop = asyncio .get_running_loop ()
194+
195+ # Create a wrapper function to handle keyword arguments
196+ def call_with_kwargs ():
197+ return self ._functions [req .method_name ](
198+ * req .args , ** req .kwargs )
199+
200+ result = await asyncio .wait_for (loop .run_in_executor (
201+ None , call_with_kwargs ),
202+ timeout = req .timeout )
187203 response = RPCResponse (req .request_id , 'OK' , result )
204+ except asyncio .TimeoutError :
205+ response = RPCResponse (
206+ req .request_id , 'ERROR' ,
207+ f"Method '{ req .method_name } ' timed out after { req .timeout } seconds"
208+ )
188209 except Exception :
189210 tb = traceback .format_exc ()
190211 response = RPCResponse (req .request_id , 'ERROR' , tb )
@@ -313,13 +334,28 @@ async def _start_reader_if_needed(self):
313334 self ._reader_task = loop .create_task (self ._response_reader ())
314335
315336 async def _call_async (self , name , * args , ** kwargs ):
316- """Async version of RPC call."""
337+ """Async version of RPC call.
338+ Args:
339+ name: Method name to call
340+ *args: Positional arguments
341+ **kwargs: Keyword arguments
342+ __rpc_timeout: The timeout (seconds) for the RPC call.
343+
344+ Returns:
345+ The result of the remote method call
346+ """
317347 await self ._start_reader_if_needed ()
318348 need_response = kwargs .pop ("need_response" , True )
319349
320350 request_id = uuid .uuid4 ().hex
321351 logger .debug (f"RPC client sending request: { request_id } " )
322- request = RPCRequest (request_id , name , args , kwargs , need_response )
352+ timeout = kwargs .pop ("__rpc_timeout" , self ._timeout )
353+ request = RPCRequest (request_id ,
354+ name ,
355+ args ,
356+ kwargs ,
357+ need_response ,
358+ timeout = timeout )
323359 logger .debug (f"RPC client sending request: { request } " )
324360 await self ._client_socket .put_async (request )
325361
@@ -331,9 +367,12 @@ async def _call_async(self, name, *args, **kwargs):
331367 self ._pending_futures [request_id ] = future
332368
333369 try :
334- return await asyncio .wait_for (future , self ._timeout )
370+ # If timeout, the remote call should return a timeout error timely,
371+ # so we add 1 second to the timeout to ensure the client can get
372+ # that result.
373+ return await asyncio .wait_for (future , timeout + 1 )
335374 except asyncio .TimeoutError :
336- raise RPCError (f"Request '{ name } ' timed out after { self . _timeout } s" )
375+ raise RPCTimeout (f"Request '{ name } ' timed out after { timeout } s" )
337376 finally :
338377 self ._pending_futures .pop (request_id , None )
339378
0 commit comments