diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index c17401cc9fa..0bdfc1ca52b 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -235,6 +235,12 @@ def fetch_stats(self) -> list: else: return self.engine.get_latest_iteration_stats() + def fetch_kv_cache_events(self) -> list: + if isinstance(self.engine, tllm.Executor): + return self.engine.get_latest_kv_cache_events() + else: + return self.engine.get_latest_kv_cache_events() + def set_result_queue(self, queue): """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process.""" assert self.postproc_queues is None @@ -548,6 +554,38 @@ def submit(self, request: GenerationRequest) -> GenerationResult: return result + def shutdown(self): + if self.doing_shutdown: + return + else: + self.doing_shutdown = True + + if self.engine is not None and self.engine.can_enqueue_requests(): + self.engine.shutdown() + self.engine = None + + # Define a Callable to join iteration and request stats + @staticmethod + def _stats_serializer( + stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str: + iteration_stats, req_stats = stats + stats_dict = json.loads(iteration_stats.to_json_str()) + + if req_stats is not None and len(req_stats) > 0: + stats_dict["requestStats"] = [] + for req_stat in req_stats: + stats_dict["requestStats"].append( + json.loads(req_stat.to_json_str())) + + # Convert back to JSON string + return json.dumps(stats_dict) + + # Define a Callable to serialize KV cache events + @staticmethod + def _kv_cache_events_serializer(events) -> str: + from .._utils import KVCacheEventSerializer + return json.dumps(KVCacheEventSerializer.serialize(events)) + def _pop_result(self, client_id: int): self._results.pop(client_id, None) self._client_id_to_request_id.pop(client_id, None) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index e8af846c247..c7fc8ea5d0d 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -410,9 +410,22 @@ def create( mpirun_launch = external_mpi_comm_available(model_world_size) # The case where the Python main process utilizes mpi4py to spawn MPI workers spawn_workers = need_spawn_mpi_workers(model_world_size) + orchestrator_is_rpc = llm_args and llm_args.orchestrator_type == "rpc" + if spawn_workers or (mpirun_launch and reuse_mpi_comm): if reuse_mpi_comm: assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session" + + if orchestrator_is_rpc: + from .rpc_proxy import GenerationExecutorRpcProxy + return GenerationExecutorRpcProxy( + worker_kwargs, + model_world_size=model_world_size, + mpi_session=mpi_session, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor, + kv_connector_config=kv_connector_config) + return GenerationExecutorProxy( worker_kwargs, model_world_size=model_world_size, @@ -429,6 +442,16 @@ def create( logger.warning( "Using single process worker for TP1, this may hurt streaming generation performance." 
) + if orchestrator_is_rpc: + from .rpc_proxy import GenerationExecutorRpcProxy + return GenerationExecutorRpcProxy( + worker_kwargs, + model_world_size=model_world_size, + mpi_session=mpi_session, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor, + kv_connector_config=kv_connector_config) + return GenerationExecutorWorker( **worker_kwargs, is_llm_executor=is_llm_executor, @@ -439,6 +462,16 @@ def create( # While this requires uses to protect their entrypoint to # `if __name__ == "__main__":`. if not platform.system() == 'Windows': + if orchestrator_is_rpc: + from .rpc_proxy import GenerationExecutorRpcProxy + return GenerationExecutorRpcProxy( + worker_kwargs, + model_world_size=model_world_size, + mpi_session=mpi_session, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor, + kv_connector_config=kv_connector_config) + return GenerationExecutorProxy( worker_kwargs, model_world_size=model_world_size, @@ -451,6 +484,7 @@ def create( # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot. mpi_session = ProcessPoolExecutorSession(n_workers=1, mp_context=ctx) + # TODO: add rpc worker here return GenerationExecutorProxy( worker_kwargs, model_world_size=model_world_size, diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py index 327dbf4f6f5..a2b01ee96fb 100644 --- a/tensorrt_llm/executor/ipc.py +++ b/tensorrt_llm/executor/ipc.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import hmac import os @@ -179,6 +180,20 @@ async def put_async(self, obj: Any): nvtx_mark("ipc.send", color="blue", category="IPC") + async def put_async_noblock(self, obj: Any): + self.setup_lazily() + try: + if self.use_hmac_encryption: + data = pickle.dumps(obj) # nosec B301 + signed_data = self._sign_data(data) + await self.socket.send(signed_data, flags=zmq.NOBLOCK) + else: + await self.socket.send_pyobj(obj, flags=zmq.NOBLOCK) + except Exception as e: + logger.error(f"Error sending object: {e}") + logger.error(traceback.format_exc()) + raise e + def get(self) -> Any: self.setup_lazily() return self._recv_data() @@ -187,6 +202,9 @@ async def get_async(self) -> Any: self.setup_lazily() return await self._recv_data_async() + async def get_async_noblock(self, timeout: float = 0.5) -> Any: + return await asyncio.wait_for(self.get_async(), timeout) + def close(self): if self.socket: self.socket.close() diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index d19a8368297..927d6b0a9d4 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -14,7 +14,7 @@ from ..bindings import executor as tllm from ..disaggregated_params import DisaggregatedParams from ..llmapi.tracer import global_tracer -from ..llmapi.utils import AsyncQueue +from ..llmapi.utils import AsyncQueue, print_traceback_on_error from ..metrics import MetricNames, MetricsCollector, RequestEventTiming from ..sampling_params import LogprobParams, SamplingParams from .utils import ErrorResponse, has_event_loop, is_llm_response @@ -315,6 +315,7 @@ def _handle_sequence(self, f"Unknown finish reason: {finish_reasons[src_idx]}") self.record_stats(output, req_perf_metrics_dict) + @print_traceback_on_error @nvtx_range_debug("handle_response", color="red", category="GenerationResultBase") diff --git a/tensorrt_llm/executor/rpc/README.md b/tensorrt_llm/executor/rpc/README.md new file mode 100644 index 00000000000..76d7b846ab3 --- /dev/null +++ b/tensorrt_llm/executor/rpc/README.md @@ -0,0 +1,85 @@ +# A Lightweight 
RPC +This is a pure-Python lightweight RPC we build to simplify our existing IPC code in the orchestrator part. It provides multiple call modes (sync, async, future, streaming) and supports both IPC and TCP connections. + +## Examples +### Create Server and Client + +```python +from tensorrt_llm.executor.rpc import RPCServer, RPCClient + +# Define your application +class App: + def add(self, a: int, b: int) -> int: + return a + b + + async def async_multiply(self, x: int, y: int) -> int: + return x * y + +# Create and start server +app = App() +with RPCServer(app) as server: + server.bind("ipc:///tmp/my_rpc") # or "tcp://127.0.0.1:5555" + server.start() + + # Create client and make calls + with RPCClient("ipc:///tmp/my_rpc") as client: + result = client.add(5, 3).remote() + print(result) # Output: 8 +``` + +### Different Remote Calls + +#### Synchronous Call +```python +# Blocking call that waits for result +result = client.add(10, 20).remote() +# or with timeout +result = client.add(10, 20).remote(timeout=5.0) +``` + +#### Asynchronous Call +```python +# Async call that returns a coroutine +result = await client.async_multiply(3, 4).remote_async() +``` + +#### Future-based Call +```python +# Returns a concurrent.futures.Future +future = client.add(1, 2).remote_future() +# Get result later +result = future.result() +``` + +#### Fire-and-Forget Call +```python +# Send request without waiting for response +client.submit_task(task_id=123).remote(need_response=False) +``` + +#### Streaming Call +```python +# For async generator methods +async for value in client.stream_data(n=10).remote_streaming(): + print(f"Received: {value}") +``` + +### Error Handling +```python +from tensorrt_llm.executor.rpc import RPCError, RPCTimeout + +try: + result = client.risky_operation().remote(timeout=1.0) +except RPCTimeout: + print("Operation timed out") +except RPCError as e: + print(f"RPC Error: {e}") + print(f"Original cause: {e.cause}") + print(f"Traceback: {e.traceback}") +``` + +### Graceful Shutdown +```python +# Shutdown server from client +client.shutdown_server() +``` diff --git a/tensorrt_llm/executor/rpc/__init__.py b/tensorrt_llm/executor/rpc/__init__.py new file mode 100644 index 00000000000..6f62051bb41 --- /dev/null +++ b/tensorrt_llm/executor/rpc/__init__.py @@ -0,0 +1,10 @@ +from .rpc_client import RPCClient +from .rpc_common import (RPCCancelled, RPCError, RPCParams, RPCRequest, + RPCResponse, RPCStreamingError, RPCTimeout) +from .rpc_server import RPCServer, Server + +__all__ = [ + "RPCClient", "RPCServer", "Server", "RPCError", "RPCTimeout", + "RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse", + "RPCParams" +] diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py new file mode 100644 index 00000000000..03d43ce1d0b --- /dev/null +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -0,0 +1,497 @@ +import asyncio +import concurrent.futures +import threading +import uuid +from typing import Any, AsyncIterator, Dict, Optional + +from ...llmapi.utils import AsyncQueue, _SyncQueue, logger_debug +from ...logger import logger +from ..ipc import ZeroMqQueue +from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse, + RPCStreamingError, RPCTimeout) + + +class RemoteCall: + """Helper class to enable chained remote call syntax like client.method().remote()""" + + def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs): + self.client = client + self.method_name = method_name + self.args = args + self.kwargs = kwargs 
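The `RemoteCall` helper above is what makes the chained `client.method(args).remote()` syntax from the README work. A minimal end-to-end sketch (illustrative only, not part of this patch; the `App.add` method and the IPC address are modeled on the README example):

```python
# Illustrative sketch, not part of this patch. `App.add` and the IPC address
# follow the README example; RPCClient.__getattr__ builds the RemoteCall.
from tensorrt_llm.executor.rpc import RPCClient, RPCServer

class App:
    def add(self, a: int, b: int) -> int:
        return a + b

with RPCServer(App()) as server:
    server.bind("ipc:///tmp/remote_call_demo")
    server.start()
    with RPCClient("ipc:///tmp/remote_call_demo") as client:
        call = client.add(2, 3)            # __getattr__ returns RemoteCall("add", 2, 3)
        print(call.remote())               # sync mode -> 5
        print(client.add(4, 5).remote_future().result())  # future mode -> 9
```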
+ + def _prepare_and_call(self, timeout: Optional[float], need_response: bool, + mode: str, call_method: str) -> Any: + """Common method to prepare RPC params and make the call. + + Args: + timeout: Timeout for the RPC call + need_response: Whether a response is expected + mode: The RPC mode ("sync", "async", "future") + call_method: The method name to call on the client + + Returns: + The result of the client method call + """ + rpc_params = RPCParams(timeout=timeout, + need_response=need_response, + mode=mode) + self.kwargs["__rpc_params"] = rpc_params + client_method = getattr(self.client, call_method) + return client_method(self.method_name, *self.args, **self.kwargs) + + def remote(self, + timeout: Optional[float] = None, + need_response: bool = True) -> Any: + """Synchronous remote call with optional RPC parameters.""" + return self._prepare_and_call(timeout, need_response, "sync", + "_call_sync") + + def remote_async(self, + timeout: Optional[float] = None, + need_response: bool = True): + """Asynchronous remote call that returns a coroutine.""" + return self._prepare_and_call(timeout, need_response, "async", + "_call_async") + + def remote_future(self, + timeout: Optional[float] = None, + need_response: bool = True) -> concurrent.futures.Future: + """Remote call that returns a Future object.""" + return self._prepare_and_call(timeout, need_response, "future", + "_call_future") + + def remote_streaming(self, + timeout: Optional[float] = None) -> AsyncIterator[Any]: + """Remote call for streaming results.""" + # Streaming always needs a response + return self._prepare_and_call(timeout, True, "async", "_call_streaming") + + +class RPCClient: + """ + An RPC Client that connects to the RPCServer. + """ + + def __init__(self, + address: str, + hmac_key=None, + timeout: Optional[float] = None, + num_workers: int = 4): + ''' + Args: + address: The ZMQ address to connect to. + hmac_key: The HMAC key for encryption. + timeout: The timeout (seconds) for RPC calls. + num_workers: The number of workers for the RPC client. + ''' + self._address = address + self._timeout = timeout + self._client_socket = ZeroMqQueue(address=(address, hmac_key), + is_server=False, + is_async=True, + use_hmac_encryption=False) + self._pending_futures = {} + # map request_id to the queue for streaming responses + self._streaming_queues: Dict[str, AsyncQueue] = {} + self._reader_task = None + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=num_workers, thread_name_prefix="rpc_client_worker") + + self._server_stopped = False + self._closed = False + self._stop_event = None + self._loop = None + self._loop_thread = None + + logger_debug(f"RPC Client initialized. 
Connected to {self._address}") + + def shutdown_server(self): + """Shutdown the server.""" + if self._server_stopped: + return + + self._rpc_shutdown().remote() + + self._server_stopped = True + + def close(self): + """Gracefully close the client, cleaning up background tasks.""" + + if self._closed: + return + # stop the main loop + self._closed = True + + logger_debug("RPC Client closing") + + if self._stop_event and self._loop: + # Use call_soon_threadsafe since set() is not a coroutine + self._loop.call_soon_threadsafe(self._stop_event.set) + + if self._reader_task: + try: + self._reader_task.result(timeout=1.0) + except concurrent.futures.TimeoutError: + logger.warning( + "Reader task did not exit gracefully, cancelling") + self._reader_task.cancel() + except Exception as e: + # Task might have already finished or been cancelled + logger_debug(f"Reader task cleanup: {e}") + self._reader_task = None + + if self._loop and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + if self._loop_thread: + self._loop_thread.join() + self._loop_thread = None + if self._executor: + self._executor.shutdown(wait=True) + + if self._client_socket: + self._client_socket.close() + self._client_socket = None + + logger_debug("RPC Client closed") + + async def _response_reader(self): + """Task to read responses from the socket and set results on futures.""" + logger_debug("Response reader started") + + while not self._stop_event.is_set(): + try: + # Use wait_for with a short timeout to periodically check stop event + try: + response: RPCResponse = await asyncio.wait_for( + self._client_socket.get_async(), + timeout=0.1 # Check stop event every 100ms + ) + except asyncio.TimeoutError: + # Timeout is expected - just check stop event and continue + continue + + logger_debug(f"RPC Client received response: {response}") + logger_debug( + f"Response request_id: {response.request_id}, is_streaming: {response.is_streaming}" + ) + logger_debug( + f"Pending futures: {list(self._pending_futures.keys())}") + + # Handle streaming responses + if response.is_streaming: + assert response.stream_status in [ + 'start', 'data', 'end', 'error' + ], f"Invalid stream status: {response.stream_status}" + queue = self._streaming_queues.get(response.request_id) + if queue: + # put to the sync queue, as the current event loop is + # different from the one in call_async or call_streaming + assert isinstance(queue, AsyncQueue) + logger_debug( + f"RPC Client putting response to AsyncQueue: {response}" + ) + queue.sync_q.put(response) + # Clean up if stream ended + if response.stream_status in ['end', 'error']: + self._streaming_queues.pop(response.request_id, + None) + else: + # Handle regular responses + logger_debug( + f"Handling regular response for request_id: {response.request_id}" + ) + if future_info := self._pending_futures.get( + response.request_id): + future, target_loop = future_info + logger_debug( + f"Found future for request_id: {response.request_id}, future done: {future.done()}" + ) + + if not future.done(): + if response.error is None: + logger_debug( + f"Setting result for request_id: {response.request_id}, result: {response.result}" + ) + target_loop.call_soon_threadsafe( + future.set_result, response.result) + else: + # Use the original RPCError from the response + logger_debug( + f"Setting exception for request_id: {response.request_id}, error: {response.error}" + ) + target_loop.call_soon_threadsafe( + future.set_exception, response.error) + else: + logger_debug( + f"No future found for 
request_id: {response.request_id}" + ) + self._pending_futures.pop(response.request_id, None) + + except asyncio.CancelledError: + # Still handle cancellation for backward compatibility + logger_debug("Response reader cancelled") + break + except Exception as e: + logger.error(f"Exception in RPC response reader: {e}") + # Propagate exception to all pending futures + for (future, target_loop) in self._pending_futures.values(): + + if not future.done(): + target_loop.call_soon_threadsafe( + future.set_exception, e) + # Also signal error to streaming queues + for queue in self._streaming_queues.values(): + await queue.put(RPCResponse("", None, e, False, 0, 'error')) + break + + logger_debug("Response reader exiting gracefully") + self._reader_task = None + + def _start_response_reader_lazily(self): + if self._reader_task is None or self._reader_task.done(): + # Ensure we have a persistent background loop + self._ensure_event_loop() + # Always create the reader task on the persistent loop + future = asyncio.run_coroutine_threadsafe(self._response_reader(), + self._loop) + # Store the concurrent.futures.Future + self._reader_task = future + + async def _call_async(self, method_name, *args, **kwargs): + """Async version of RPC call. + Args: + method_name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + __rpc_params: RPCParams object containing RPC parameters. + + Returns: + The result of the remote method call + """ + logger_debug( + f"RPC client calling method: {method_name} with args: {args} and kwargs: {kwargs}" + ) + if self._server_stopped: + raise RPCCancelled("Server is shutting down, request cancelled") + + self._start_response_reader_lazily() + rpc_params = kwargs.pop("__rpc_params", RPCParams()) + need_response = rpc_params.need_response + timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout + + request_id = uuid.uuid4().hex + request = RPCRequest(request_id, + method_name, + args, + kwargs, + need_response, + timeout=timeout) + logger_debug(f"RPC client sending request: {request}") + await self._client_socket.put_async(request) + + if not need_response: + return None + + loop = asyncio.get_running_loop() + future = loop.create_future() + logger_debug( + f"RPC Client _call_async: Created future for request_id: {request_id} in loop: {id(loop)}" + ) + self._pending_futures[request_id] = (future, loop) + logger_debug( + f"RPC Client _call_async: Stored future in pending_futures") + + try: + # If timeout, the remote call should return a timeout error timely, + # so we add 1 second to the timeout to ensure the client can get + # that result. 
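            # Illustration (not part of this patch): two timeout paths converge
            # on this future. If the server finishes but reports a timeout,
            # response.error carries an RPCTimeout that the reader task sets as
            # the future's exception; if no response arrives at all, the
            # asyncio.wait_for() below expires and RPCTimeout is raised here.
            # Either way a caller such as
            #     await client.some_slow_method().remote_async(timeout=1.0)
            # (where `some_slow_method` is a hypothetical server method) only
            # ever sees RPCTimeout.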
+ logger_debug( + f"RPC Client _call_async: Awaiting future for request_id: {request_id}" + ) + if timeout is None: + res = await future + else: + # Add 1 second to the timeout to ensure the client can get + res = await asyncio.wait_for(future, timeout) + logger_debug( + f"RPC Client _call_async: Got result for request_id: {request_id}: {res}" + ) + return res + except RPCCancelled: + self._server_stopped = True + raise + except asyncio.TimeoutError: + raise RPCTimeout( + f"Request '{method_name}' timed out after {timeout}s") + except Exception as e: + raise e + finally: + self._pending_futures.pop(request_id, None) + + def _ensure_event_loop(self): + """Ensure we have a running event loop in a background thread.""" + if self._loop is None or not self._loop.is_running(): + self._loop = asyncio.new_event_loop() + + def run_loop(): + asyncio.set_event_loop(self._loop) + self._stop_event = asyncio.Event() + self._loop.run_forever() + + self._loop_thread = threading.Thread(target=run_loop, + daemon=True, + name="rpc_client_loop") + self._loop_thread.start() + + # Give the loop a moment to start + import time + time.sleep(0.1) + + def _call_sync(self, method_name, *args, **kwargs): + """Synchronous version of RPC call.""" + logger_debug( + f"RPC Client calling method: {method_name} with args: {args} and kwargs: {kwargs}" + ) + self._ensure_event_loop() + logger_debug( + f"RPC Client _call_sync: Creating future for {method_name}") + future = asyncio.run_coroutine_threadsafe( + self._call_async(method_name, *args, **kwargs), self._loop) + logger_debug( + f"RPC Client _call_sync: Waiting for result of {method_name}") + result = future.result() + logger_debug( + f"RPC Client _call_sync: Got result for {method_name}: {result}") + return result + + def _call_future(self, name: str, *args, + **kwargs) -> concurrent.futures.Future: + """ + Call a remote method and return a Future. + + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + A Future object that can be used to retrieve the result + """ + + def _async_to_sync(): + self._ensure_event_loop() + future = asyncio.run_coroutine_threadsafe( + self._call_async(name, *args, **kwargs), self._loop) + return future.result() + + return self._executor.submit(_async_to_sync) + + async def _call_streaming(self, name: str, *args, + **kwargs) -> AsyncIterator[Any]: + """ + Call a remote async generator method and get streaming results. 
+ + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Yields: + Results from the remote async generator + """ + if self._server_stopped: + raise RPCCancelled("Server is shutting down, request cancelled") + + self._start_response_reader_lazily() + rpc_params = kwargs.pop("__rpc_params", RPCParams()) + timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout + + request_id = uuid.uuid4().hex + # Use AsyncQueue to ensure proper cross-thread communication + queue = AsyncQueue() + # Recreate sync_q with the current running loop for proper cross-thread communication + # This ensures the background _response_reader thread can properly notify this event loop + queue._sync_q = _SyncQueue(queue, asyncio.get_running_loop()) + self._streaming_queues[request_id] = queue + + try: + # Send streaming request + request = RPCRequest(request_id, + name, + args, + kwargs, + need_response=True, + timeout=timeout, + is_streaming=True) + await self._client_socket.put_async(request) + + # Read streaming responses + while True: + logger_debug(f"RPC Client _call_streaming waiting for response", + color="green") + if timeout is None: + response = await queue.get() + else: + response = await asyncio.wait_for(queue.get(), + timeout=timeout) + + logger_debug( + f"RPC Client _call_streaming received [{response.stream_status}] response: {response}", + color="green") + if response.stream_status == 'start': + # Start of stream + continue + elif response.stream_status == 'data': + logger_debug( + f"RPC Client _call_streaming received data: {response.result}", + color="green") + yield response.result + elif response.stream_status == 'end': + # End of stream + break + elif response.stream_status == 'error': + # Error in stream + if response.error: + raise response.error + else: + raise RPCStreamingError("Unknown streaming error") + + except asyncio.TimeoutError: + raise RPCTimeout( + f"Streaming request '{name}' timed out after {timeout}s") + finally: + # Clean up + self._streaming_queues.pop(request_id, None) + + def get_server_attr(self, name: str): + """ Get the attribute of the RPC server. + This is mainly used for testing. """ + return self._rpc_get_attr(name).remote() + + def __getattr__(self, name): + """ + Magically handles calls to non-existent methods. + Returns a callable that when invoked returns a RemoteCall instance. + + This enables the new syntax: + client.method(args).remote() + await client.method(args).remote_async() + client.method(args).remote_future() + async for x in client.method(args).remote_streaming() + """ + logger_debug(f"RPC Client getting attribute: {name}") + + def method_caller(*args, **kwargs): + return RemoteCall(self, name, *args, **kwargs) + + return method_caller + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() diff --git a/tensorrt_llm/executor/rpc/rpc_common.py b/tensorrt_llm/executor/rpc/rpc_common.py new file mode 100644 index 00000000000..5ea809efe71 --- /dev/null +++ b/tensorrt_llm/executor/rpc/rpc_common.py @@ -0,0 +1,77 @@ +import time +from dataclasses import dataclass +from typing import Any, Literal, NamedTuple, Optional + + +class RPCParams(NamedTuple): + """ Parameters for RPC calls. 
""" + + # seconds to wait for the response + timeout: Optional[float] = None + + # whether the client needs the response, if False, it will return immediately + need_response: bool = True + + # mode for RPC calls: "sync", "async", or "future" + mode: str = "sync" + + +# --- Custom Exceptions --- +class RPCError(Exception): + """Custom exception for RPC-related errors raised on the client side. + + Args: + message: The error message. + cause: The original exception that caused this error. + traceback: The traceback of the exception. + """ + + def __init__(self, + message: str, + cause: Optional[Exception] = None, + traceback: Optional[str] = None): + super().__init__(message) + self.cause = cause + self.traceback = traceback + + +class RPCTimeout(RPCError): + """Exception for when a request processing times out.""" + + +class RPCCancelled(RPCError): + """Exception for when a client request is cancelled. + This happens when the server is shutting down and all the pending + requests will be cancelled and return with this error. + """ + + +class RPCStreamingError(RPCError): + """Exception for streaming-related errors.""" + + +@dataclass +class RPCRequest: + request_id: str + method_name: str + args: tuple + kwargs: dict + need_response: bool = True + timeout: float = 0.5 + is_streaming: bool = False + creation_timestamp: Optional[ + float] = None # Unix timestamp when request was created + + def __post_init__(self): + """Initialize creation_timestamp if not provided.""" + if self.creation_timestamp is None: + self.creation_timestamp = time.time() + + +class RPCResponse(NamedTuple): + request_id: str + result: Any + error: Optional[RPCError] = None + is_streaming: bool = False # True if more responses coming + sequence_number: int = 0 # For ordering streaming responses + stream_status: Literal['start', 'data', 'end', 'error'] = 'data' diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py new file mode 100644 index 00000000000..268bb6012f2 --- /dev/null +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -0,0 +1,517 @@ +import asyncio +import inspect +import queue +import threading +import time +import traceback +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +from ...llmapi.utils import ManagedThread, logger_debug +from ...logger import logger +from ..ipc import ZeroMqQueue +from .rpc_common import (RPCError, RPCRequest, RPCResponse, RPCStreamingError, + RPCTimeout) + + +class RPCServer: + """ + An RPC Server that listens for requests and executes them concurrently. + """ + + def __init__(self, + instance, + hmac_key=None, + num_workers: int = 4, + timeout: float = 0.5, + async_run_task: bool = False): + """ + Initializes the server with an instance. + + Args: + instance: The instance whose methods will be exposed via RPC. + hmac_key (bytes, optional): HMAC key for encryption. + num_workers (int): Number of worker threads or worker tasks that help parallelize the task execution. + timeout (int): Timeout for RPC calls. + async_run_task (bool): Whether to run the task asynchronously. + + NOTE: make num_workers larger if there are some streaming tasks runs infinitely. 
+ """ + self._instance = instance + self._hmac_key = hmac_key + self._num_workers = num_workers + self._address = None + self._timeout = timeout + self._client_socket = None + + # set the stop event to True, and all the workers will exit + self._stop_event = threading.Event() + + self._num_pending_requests = 0 + + self._functions = { + "_rpc_shutdown": lambda: self.shutdown(is_remote_call=True), + "_rpc_get_attr": lambda name: self.get_attr(name), + } + self._dispatcher_thread: Optional[ManagedThread] = None + if async_run_task: + self._executor = ThreadPoolExecutor( + max_workers=num_workers, thread_name_prefix="rpc_server_worker") + else: + self._executor = None + + self._queue = None + + # Automatically register the instance + self.register_instance(instance) + + logger_debug(f"RPC Server initialized with {num_workers} workers.", + color="green") + + @property + def address(self) -> str: + assert self._client_socket is not None, "Client socket is not bound" + return self._client_socket.address[0] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + + def bind(self, address="tcp://*:5555"): + """ + Bind the server to the specified address. + + Args: + address (str): The ZMQ address to bind the client-facing socket. + """ + self._address = address + self._client_socket = ZeroMqQueue(address=(address, self._hmac_key), + is_server=True, + is_async=True, + use_hmac_encryption=False) + logger.info(f"RPC Server bound to {self._address}") + + def shutdown(self, is_remote_call: bool = False): + """Internal method to trigger server shutdown. + + Args: + is_remote_call: Whether the shutdown is called by a remote call. + This should be True when client.server_shutdown() is called. + """ + # NOTE: shutdown is also a remote method, so it could be executed by + # a thread in a worker executor thread + + if self._stop_event.is_set(): + return + + logger_debug( + "RPC Server shutdown signal received. Terminating server...") + + # Set the stop event to True, this will trigger the dispatcher routine and + # the worker routine to prepare for exit, like stopping accepting new requests, + # and continue to process the pending requests. + self._stop_event.set() + + # The worker routine should process the pending requests + logger_debug( + f"RPC Server shutdown: {self._num_pending_requests} pending requests" + ) + + while self._num_pending_requests > 0: + time.sleep(0.01) + logger_debug(f"RPC Server shutdown finished pending requests") + + if not is_remote_call: + # Block the thread until shutdown is finished + + # 1. Wait for the dispatcher thread to exit, so that no new requests are accepted + logger_debug(f"RPC Server dispatcher thread joining") + if self._dispatcher_thread: + self._dispatcher_thread.join() + self._dispatcher_thread = None + logger_debug(f"RPC Server dispatcher thread joined") + + # 2. Wait for the executor to exit, it will wait for the pending requests to be processed + if self._executor: + self._executor.shutdown(wait=True) + self._executor = None + + # 3. (Optionally) Close the client socket, this doesn't affect + # anything since zmq client will not timeout even if the target is not available + if self._client_socket: + self._client_socket.close() + else: + # if the shutdown is called by a remote call, this method itself will + # be executed in a executor thread, so we cannot join the dispatcher thread as + # the dispatcher thread is awaiting for the shutdown result. 
+ logger_debug( + f"RPC Server to shutdown: {self._num_pending_requests} pending requests" + ) + + while self._num_pending_requests > 0: + time.sleep(0.01) + logger_debug(f"RPC Server shutdown finished pending requests") + + def register_function(self, func, name=None): + """Exposes a single function to clients.""" + fname = name or func.__name__ + if fname in self._functions: + logger.warning( + f"Function '{fname}' is already registered. Overwriting.") + self._functions[fname] = func + logger_debug(f"Registered function: {fname}") + + def register_instance(self, instance): + """Exposes all public methods of a class instance.""" + logger_debug( + f"Registering instance of class: {instance.__class__.__name__}") + for name in dir(instance): + if not name.startswith('_'): + attr = getattr(instance, name) + if callable(attr): + self.register_function(attr, name) + + def get_attr(self, name: str): + """ Get the attribute of the RPC server. + This is mainly used for testing. """ + return getattr(self, name) + + async def _dispatcher_routine(self, stop_event: threading.Event): + assert self._client_socket is not None, "Client socket is not bound" + assert self._queue is not None, "RPC queue is not initialized" + + # Once shutdown, the dispatcher will exit first, and the workers will + # continue to process the pending requests. + while not stop_event.is_set(): + try: + req: RPCRequest = await self._client_socket.get_async_noblock( + timeout=0.5) + logger_debug(f"RPC dispatcher got request: {req}") + except asyncio.TimeoutError: + await asyncio.sleep(0) + continue + except Exception as e: + logger.error(f"RPC dispatcher caught an exception: {e}") + logger.error(traceback.format_exc()) + continue + + await self._queue.put(req) # type: ignore + + # shutdown methods depend on _num_pending_requests, so + # they should not be counted + if req.method_name not in ["_rpc_shutdown", "shutdown"]: + self._num_pending_requests += 1 + logger_debug( + f"Dispatcher received request {req}, pending: {self._num_pending_requests}" + ) + + # TODO optimization: resolve the sequential scheduling for the remote calls + # Suppose tons of submit remote call block the FIFO queue, and the later get_stats remote calls may be blocked + # There could be two dispatch modes: + # 1. (current) mix mode, share the same routine/pool + # 2. 
(promising) stream mode, specific remote_call -> stream -> specific routine/pool + # - get_stats() - 1, remote_call -> dedicated queue -> dedicated routine/pool + # - submit() - 3 -> dedicated queue -> dedicated routine/pool + # TODO potential optimization: for submit(), batch the ad-hoc requests in an interval like 5ms, reduce the IPC count + async def _worker_routine(self, stop_event: threading.Event): + """The routine executed by each worker thread.""" + assert self._client_socket is not None, "Client socket is not bound" + assert self._queue is not None, "RPC queue is not initialized" + + while (not stop_event.is_set()) or self._num_pending_requests > 0: + try: + req: RPCRequest = await asyncio.wait_for( + self._queue.get(), # type: ignore + timeout=self._timeout) + except asyncio.TimeoutError: + await asyncio.sleep(0) + continue + + # check if the method name is in the functions + if req.method_name not in self._functions: + logger.error( + f"Method '{req.method_name}' not found in RPC server.") + self._num_pending_requests -= 1 + + if not req.need_response: + continue + if req.is_streaming: + await self._client_socket.put_async( + RPCResponse( + req.request_id, + None, + RPCStreamingError( + f"Method '{req.method_name}' not found in RPC server.", + traceback=traceback.format_exc()), + stream_status='error')) + else: + response = RPCResponse( + req.request_id, + None, + RPCError( + f"Method '{req.method_name}' not found in RPC server.", + traceback=traceback.format_exc()), + ) + await self._client_socket.put_async(response) + + continue + + func = self._functions[req.method_name] + if req.is_streaming: + if inspect.isasyncgenfunction(func): + await self._process_streaming_request(req) + else: + # Non-streaming function called with streaming flag + response = RPCResponse( + req.request_id, + None, + RPCStreamingError( + f"Method '{req.method_name}' is not a streaming function." + ), + # need to redirect the error to the client's streaming queue + is_streaming=True, + stream_status='error', + ) + await self._client_socket.put_async(response) + else: + # Process regular request + response = await self._process_request(req) + + # Some tasks don't need response, e.g. submit_request or shutdown + if req.need_response and response is not None: + logger_debug( + f"RPC Server sending response for request {req}, pending: {self._num_pending_requests}" + ) + if await self._send_response(req, response): + logger_debug( + f"RPC Server sent response for request {req}") + + # Only decrement if this request was counted in the first place + if req.method_name not in ["_rpc_shutdown", "shutdown"]: + self._num_pending_requests -= 1 + + def _calculate_adjusted_timeout(self, + req: RPCRequest, + is_streaming: bool = False) -> float: + """Calculate adjusted timeout based on pending overhead. 
+ + Args: + req: The RPC request + is_streaming: Whether this is for a streaming request + + Returns: + The adjusted timeout value + """ + adjusted_timeout = req.timeout + if req.creation_timestamp is not None and req.timeout is not None and req.timeout > 0: + pending_time = time.time() - req.creation_timestamp + adjusted_timeout = max(0.1, req.timeout - + pending_time) # Keep at least 0.1s timeout + if pending_time > 0.1: # Only log if significant pending time + method_type = "streaming " if is_streaming else "" + logger_debug( + f"RPC Server adjusted timeout for {method_type}{req.method_name}: " + f"original={req.timeout}s, pending={pending_time:.3f}s, adjusted={adjusted_timeout:.3f}s" + ) + return adjusted_timeout + + async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]: + """Process a request. Returns None for streaming requests (handled separately).""" + func = self._functions[req.method_name] + + # Calculate adjusted timeout based on pending overhead + adjusted_timeout = self._calculate_adjusted_timeout(req) + + try: + if inspect.iscoroutinefunction(func): + # Execute async function directly in event loop, no need to run in executor due to the GIL + logger_debug( + f"RPC Server running async task {req.method_name} in dispatcher" + ) + result = await asyncio.wait_for(func(*req.args, **req.kwargs), + timeout=adjusted_timeout) + else: + # Execute sync function in thread executor + loop = asyncio.get_running_loop() + + def call_with_kwargs(): + return func(*req.args, **req.kwargs) + + logger_debug( + f"RPC Server running async task {req.method_name} in worker" + ) + # TODO: let num worker control the pool size + result = await asyncio.wait_for(loop.run_in_executor( + self._executor, call_with_kwargs), + timeout=adjusted_timeout) + + logger_debug(f"RPC Server returned result for request {req}") + response = RPCResponse(req.request_id, result) + + except asyncio.TimeoutError: + response = RPCResponse( + req.request_id, None, + RPCTimeout( + f"Method '{req.method_name}' timed out after {req.timeout} seconds", + traceback=traceback.format_exc())) + + except Exception as e: + response = RPCResponse( + req.request_id, None, + RPCError(str(e), cause=e, traceback=traceback.format_exc())) + + return response + + async def _process_streaming_request(self, req: RPCRequest): + """Process a streaming request by sending multiple responses.""" + func = self._functions[req.method_name] + + if not inspect.isasyncgenfunction(func): + await self._client_socket.put_async( + RPCResponse( + req.request_id, + None, + RPCStreamingError( + f"Method '{req.method_name}' is not an async generator.", + traceback=traceback.format_exc()), + # need to redirect the error to the client's streaming queue + stream_status='error')) + return + + sequence_number = 0 + + # Calculate adjusted timeout based on pending overhead + adjusted_timeout = self._calculate_adjusted_timeout(req, + is_streaming=True) + + try: + logger_debug(f"RPC Server running streaming task {req.method_name}") + # Send start signal + await self._client_socket.put_async( + RPCResponse(req.request_id, None, None, True, sequence_number, + 'start')) + sequence_number += 1 + + # Apply timeout to the entire streaming operation if specified + if adjusted_timeout is not None and adjusted_timeout > 0: + # Create a task for the async generator with timeout + async def stream_with_timeout(): + nonlocal sequence_number + async for result in func(*req.args, **req.kwargs): + logger_debug( + f"RPC Server got data and ready to send result {result}" 
+ ) + response = RPCResponse(req.request_id, result, None, + True, sequence_number, 'data') + if not await self._send_response(req, response): + # Stop streaming after a pickle error + return + sequence_number += 1 + + # Use wait_for for timeout handling + await asyncio.wait_for(stream_with_timeout(), + timeout=adjusted_timeout) + else: + # No timeout specified, stream normally + async for result in func(*req.args, **req.kwargs): + logger_debug( + f"RPC Server got data and ready to send result {result}" + ) + response = RPCResponse(req.request_id, result, None, True, + sequence_number, 'data') + if not await self._send_response(req, response): + # Stop streaming after a pickle error + return + sequence_number += 1 + + # Send end signal + await self._client_socket.put_async( + RPCResponse(req.request_id, None, None, True, sequence_number, + 'end')) + + except asyncio.TimeoutError: + await self._client_socket.put_async( + RPCResponse( + req.request_id, None, + RPCTimeout( + f"Streaming method '{req.method_name}' timed out", + traceback=traceback.format_exc()), True, + sequence_number, 'error')) + + except Exception as e: + response = RPCResponse( + req.request_id, None, + RPCStreamingError(str(e), traceback=traceback.format_exc()), + True, sequence_number, 'error') + await self._send_response(req, response) + + async def _send_response(self, req: RPCRequest, + response: RPCResponse) -> bool: + """Safely sends a response, handling pickle errors.""" + try: + await self._client_socket.put_async(response) + return True + except Exception as e: + logger.error( + f"Failed to pickle response for request {req.request_id}: {e}") + error_msg = f"Failed to pickle response: {e}" + if req.is_streaming: + error_cls = RPCStreamingError + # For streaming, we also need sequence number. The original response has it. + sequence_number = response.sequence_number if response else None + error_response = RPCResponse( + req.request_id, + None, + error_cls(error_msg, traceback=traceback.format_exc()), + is_streaming=True, + sequence_number=sequence_number, + stream_status='error') + else: + error_cls = RPCError + error_response = RPCResponse( + req.request_id, None, + error_cls(error_msg, traceback=traceback.format_exc())) + + try: + await self._client_socket.put_async(error_response) + except Exception as e_inner: + logger.error( + f"Failed to send error response for request {req.request_id}: {e_inner}" + ) + return False + + def start(self): + """Binds sockets, starts workers, and begins proxying messages.""" + if self._client_socket is None: + raise RuntimeError( + "Server must be bound to an address before starting. Call bind() first." 
+ ) + + self._client_socket.setup_lazily() + logger.info(f"RPC Server started and listening on {self._address}") + + async def tasks(): + self._queue = asyncio.Queue() + await asyncio.gather( + self._dispatcher_routine(self._stop_event), *[ + self._worker_routine(self._stop_event) + for i in range(self._num_workers) + ]) + + def loop() -> bool: + asyncio.run(tasks()) + return True # ManagedThread + + error_queue = queue.Queue() + self._dispatcher_thread = ManagedThread(task=loop, + stop_event=self._stop_event, + name="rpc_dispatcher_thread", + error_queue=error_queue) + self._dispatcher_thread.start() + + logger.info("RPC Server has started.") + + +Server = RPCServer diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py new file mode 100644 index 00000000000..8e1375a1811 --- /dev/null +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -0,0 +1,375 @@ +import asyncio +import atexit +import json +import os +import threading +from typing import Optional + +from ..llmapi.llm_args import KvCacheConnectorConfig +from ..llmapi.mpi_session import MpiPoolSession, MpiSession +from ..llmapi.tracer import global_tracer +from ..llmapi.utils import (AsyncQueue, _SyncQueue, logger_debug, + print_colored_debug) +from ..logger import logger +from .executor import GenerationExecutor +from .postproc_worker import PostprocWorkerConfig +from .request import GenerationRequest +from .result import GenerationResult +from .rpc import RPCClient +from .rpc_worker import RpcWorker +from .utils import (ErrorResponse, create_mpi_comm_session, + get_spawn_proxy_process_env, is_llm_response) + + +class GenerationExecutorRpcProxy(GenerationExecutor): + # NOTE: this is a global counter for the number of instances of this class + INSTANCE_COUNTER = 0 + + def __init__( + self, + worker_kwargs: dict, + model_world_size: int = 1, + mpi_session: Optional[MpiSession] = None, + *, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + is_llm_executor: Optional[bool] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + ): + """ + Args: + worker_kwargs: kwargs for the rpc worker + model_world_size: the world size of the model + mpi_session: the mpi session to use + postproc_worker_config: the postproc worker config + is_llm_executor: whether this is an llm executor + kv_connector_config: the kv cache connector config + """ + GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1 + self.rpc_addr = self.gen_uniq_rpc_addr() + self.rpc_client = RPCClient(self.rpc_addr) + + postproc_worker_config = postproc_worker_config or PostprocWorkerConfig( + ) + + super().__init__( + num_postprocess_workers=postproc_worker_config. + num_postprocess_workers, + postprocess_tokenizer_dir=postproc_worker_config. + postprocess_tokenizer_dir, + is_llm_executor=is_llm_executor, + ) + + self._results = {} + + self._create_mpi_session(model_world_size, mpi_session) + + self._shutdown_event = threading.Event() + self.worker_kwargs = worker_kwargs + + self.main_loop_task_obj = None + self.main_loop = None + + self.launch_workers() + + # Invoke model creation on the remote + # TBD: Move model creation to the mpi task, or left in RPC? 
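        # Rough lifecycle sketch (illustrative, not part of this patch):
        #   addr = self.gen_uniq_rpc_addr()   # e.g. ipc:///tmp/rpc-proxy-<pid>-<counter>
        #   self.rpc_client = RPCClient(addr)
        #   self.launch_workers()             # mpi_session.submit(RpcWorker.main_task, rpc_addr=addr, ...)
        #   self.setup_engine_remote()        # blocking RPC below; rank 0 builds the engine
        #   self.setup_mainloop()             # background thread streaming responses/stats/kv events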
+ self.setup_engine_remote() + + # Setup main loop after engine is ready + self.setup_mainloop() + + def launch_workers(self): + logger.debug(f"Launching workers") + assert self.mpi_session is not None + self.mpi_session.submit(RpcWorker.main_task, + rpc_addr=self.rpc_addr, + **self.worker_kwargs) + + async def _generic_fetch_loop_async(self, fetch_method_name: str, + handler_method, method_name: str): + """Generic method for fetching data in a loop from RPC worker. + + Args: + fetch_method_name: Name of the RPC client method to call + handler_method: The handler method to call with the fetched data + method_name: Name of the method for logging + """ + try: + fetch_method = getattr(self.rpc_client, fetch_method_name) + async for data in fetch_method().remote_streaming(): + if self._shutdown_event.is_set(): + return + handler_method(data) + except asyncio.CancelledError: + logger.debug(f"{method_name} task cancelled") + except Exception as e: + logger.error(f"Error in {method_name}: {e}") + raise + + async def _fetch_responses_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_responses_loop_async", + handler_method=self.handle_responses, + method_name="_fetch_responses_loop_async") + + async def _fetch_stats_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_stats_loop_async", + handler_method=self.handle_stats, + method_name="_fetch_stats_loop_async") + + async def _fetch_kv_cache_events_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_kv_cache_events_loop_async", + handler_method=self.handle_kv_cache_events, + method_name="_fetch_kv_cache_events_loop_async") + + def setup_mainloop(self): + + async def main_loop_task(): + tasks = [ + self._fetch_responses_loop_async(), + self._fetch_stats_loop_async(), + self._fetch_kv_cache_events_loop_async(), + ] + # Only add kv_cache_events loop if it's enabled + if self._iter_kv_events_result: + tasks.append(self._fetch_kv_cache_events_loop_async()) + await asyncio.gather(*tasks) + + def _run_main_loop_task(): + """Local method to run the main loop task.""" + self.main_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.main_loop) + + self.main_loop_task_obj = self.main_loop.create_task( + main_loop_task()) + try: + self.main_loop.run_until_complete(self.main_loop_task_obj) + except asyncio.CancelledError: + pass # Task cancellation is expected during shutdown + finally: + self.main_loop.close() + + self.main_loop_thread = threading.Thread(target=_run_main_loop_task, + daemon=True) + self.main_loop_thread.start() + atexit.register(self.shutdown) + + def handle_responses(self, responses: list[GenerationResult]) -> bool: + async_queues = [] + event_loop = None + + def process_res(res: list): + for r in res: + client_id = r.client_id + nonlocal event_loop + nonlocal async_queues + + if client_id not in self._results: + logger.warning( + f"Received response for unknown client_id: {client_id}") + continue + + queue = self._results[client_id].queue + if isinstance(queue, _SyncQueue): + queue.put_nowait(r) + async_queues.append(queue) + # all the loops are identical + event_loop = event_loop or queue.loop + else: + queue.put(r) + + if (is_llm_response(r) and r.result.is_final) or isinstance( + r, ErrorResponse): + self._results.pop(client_id) + + # Handle the case where responses might not be a list of lists + if responses and not isinstance(responses[0], list): + # If responses is a flat list, wrap it + responses = [responses] + + for res in 
responses: + global_tracer().log_instant("RPC.get") + process_res(res) + + if async_queues: + _SyncQueue.notify_many(event_loop, async_queues) + + def _handle_iteration_data(self, data, result_singleton, data_type: str): + """Generic method to handle iteration data received from RPC worker. + + Args: + data: Data from the RPC worker (can be dict, str, or list) + result_singleton: The iteration result singleton to put data into + data_type: Type of data for logging (e.g., "stats", "kv_cache_events") + """ + # Make sure we have initialized the iteration results + self._maybe_initialize_iteration_results() + + if not result_singleton: + logger.debug( + f"Skipping {data_type} handling while result_singleton=None") + return + + # Get the queue from the result singleton + queue = result_singleton.queue + async_queues = [] + + # Clear old data if queue is full (similar to _iteration_result_task) + while queue.full(): + queue.get() + + try: + # Handle different types of data + if isinstance(data, str): + # Already JSON serialized + data_json = data + elif isinstance(data, list): + # Skip empty lists to avoid putting nothing in the queue + if not data: + logger.debug( + f"rpc_proxy.py: Skipping empty {data_type} list") + return + + # Handle list of data (multiple iterations) + for item in data: + if isinstance(item, str): + item_json = item + else: + item_json = json.dumps(item) + + if isinstance(queue, _SyncQueue): + queue.put_nowait(item_json) + async_queues.append(queue) + else: + queue.put(item_json) + + if async_queues: + _SyncQueue.notify_many(queue.loop, async_queues) + return + else: + # Convert dict/other to JSON string as expected by IterationResult + data_json = json.dumps(data) + + if isinstance(queue, _SyncQueue): + queue.put_nowait(data_json) + async_queues.append(queue) + else: + queue.put(data_json) + + if async_queues: + _SyncQueue.notify_many(queue.loop, async_queues) + + except AsyncQueue.EventLoopShutdownError: + # This happens when the event loop is already closed + logger.debug( + f"rpc_proxy.py: EventLoopShutdownError in handle_{data_type}") + except Exception as e: + logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}") + raise e + + def handle_stats(self, stats): + """Handle stats received from RPC worker and put them into the stats result queue. + + Args: + stats: Statistics data from the RPC worker (can be dict, str, or list) + """ + self._handle_iteration_data(stats, self._iter_stats_result, "stats") + + def handle_kv_cache_events(self, events): + """Handle KV cache events received from RPC worker and put them into the events result queue. 
+ + Args: + events: KV cache events data from the RPC worker (can be dict, str, or list) + """ + self._handle_iteration_data(events, self._iter_kv_events_result, + "kv_cache_events") + + def submit(self, request: GenerationRequest) -> GenerationResult: + request.set_id(self._get_next_client_id()) + logprob_params = self._get_logprob_params(request) + + # submit is a fire-and-forget operation, don't need to wait for response + self.rpc_client.submit(request).remote(need_response=False) + + result = GenerationResult( + request, + background_error_handler=self._handle_background_error, + executor=self, + disaggregated_params=request.disaggregated_params, + logprob_params=logprob_params) + self._results[request.id] = result + + return result + + def fetch_stats_remote(self): + return self.rpc_client.fetch_stats().remote() + + def setup_engine_remote(self): + return self.rpc_client.setup_engine().remote(need_response=True) + + def shutdown_remote(self): + logger_debug(f"Shutting down rpc remote", color="yellow") + self.rpc_client.shutdown().remote() + + def abort_request(self, request_id: int) -> None: + return self.rpc_client.abort_request(request_id).remote() + + def shutdown(self): + if self._shutdown_event.is_set(): + return + self._shutdown_event.set() + logger_debug(f"Shutting down GenerationExecutorRpcProxy", + color="yellow") + + # 1. shutdown the rpc server (PyExecutor Rank 0 + RPC server) + self.shutdown_remote() + + # 2. stop the main loop, so that no new rpc requests + if self.main_loop and self.main_loop_task_obj: + logger_debug("Cancelling main loop task.", color="yellow") + # The cancel() is thread-safe + try: + self.main_loop.call_soon_threadsafe( + self.main_loop_task_obj.cancel) + except Exception as e: + logger_debug(f"Error cancelling main loop task: {e}", + color="yellow") + + self.main_loop_thread.join() + + # 3. 
shutdown the mpi session, this should wait until all the PyExecutor + # processes are shutdown + if self.mpi_session is not None: + logger_debug(f"Shutting down mpi session", color="yellow") + self.mpi_session.shutdown() + logger_debug(f"Mpi session shutdown", color="yellow") + self.mpi_session = None + + self.rpc_client.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + + def _create_mpi_session(self, model_world_size: int, + mpi_session: Optional[MpiSession]): + mpi_process_pre_spawned: bool = get_spawn_proxy_process_env() + if mpi_session is None: + if mpi_process_pre_spawned: + print_colored_debug('create comm session ...\n', "yellow") + self.mpi_session = create_mpi_comm_session(model_world_size) + else: + print_colored_debug('create pool session ...\n', "yellow") + self.mpi_session = MpiPoolSession(n_workers=model_world_size) + else: + print_colored_debug('using external mpi session ...\n', "yellow") + self.mpi_session = mpi_session + + @staticmethod + def gen_uniq_rpc_addr() -> str: + process_id = os.getpid() + return f"ipc:///tmp/rpc-proxy-{process_id}-{GenerationExecutorRpcProxy.INSTANCE_COUNTER}" diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py new file mode 100644 index 00000000000..a9ef9f435d3 --- /dev/null +++ b/tensorrt_llm/executor/rpc_worker.py @@ -0,0 +1,253 @@ +import asyncio +from pathlib import Path +from queue import Queue +from threading import Event +from typing import AsyncGenerator, Optional, Union + +from tensorrt_llm._utils import mpi_comm +from tensorrt_llm.llmapi.utils import enable_llm_debug, logger_debug + +from .._utils import mpi_rank +from ..bindings import executor as tllm +from ..builder import Engine +from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig +from ..llmapi.tokenizer import TokenizerBase +from ..logger import set_level +from ..lora_manager import LoraConfig +from ..sampling_params import BatchedLogitsProcessor +from .base_worker import BaseWorker +from .postproc_worker import PostprocWorkerConfig +from .request import GenerationRequest +from .rpc import RPCServer + + +class RpcWorker(BaseWorker): + """ + A RPC wrapper for the BaseWorker class. + + Actions: + - `setup_engine`: Setup the engine. + - `submit`: Submit a request to the worker. + - `fetch_responses`: Fetch the latest responses from engine. + - `fetch_stats`: Fetch the latest stats from engine. + - `fetch_kv_cache_events`: Fetch the latest kv cache events from engine. + - `shutdown`: Shutdown the worker. 
+ """ + + # Number of RPC server workers + NUM_WORKERS = 6 + + def __init__( + self, + engine: Union[Path, Engine], + executor_config: Optional[tllm.ExecutorConfig] = None, + is_llm_executor: Optional[bool] = None, + lora_config: Optional[LoraConfig] = None, + batched_logits_processor: Optional[BatchedLogitsProcessor] = None, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + llm_args: Optional[BaseLlmArgs] = None, + ) -> None: + super().__init__( + engine=engine, + executor_config=executor_config, + is_llm_executor=is_llm_executor, + lora_config=lora_config, + llm_args=llm_args, + batched_logits_processor=batched_logits_processor, + postproc_worker_config=postproc_worker_config, + kv_connector_config=kv_connector_config, + hf_model_dir=hf_model_dir, + tokenizer=tokenizer, + ) + # Extract garbage_collection_gen0_threshold from llm_args if available + self.garbage_collection_gen0_threshold = ( + llm_args.garbage_collection_gen0_threshold if llm_args is not None + and hasattr(llm_args, 'garbage_collection_gen0_threshold') else + None) + self.shutdown_event = Event() + + self._response_queue = Queue() + self.set_result_queue(self._response_queue) + + def submit(self, request: GenerationRequest): + """ Submits a request to the worker. """ + super().submit(request) + + def fetch_responses(self, timeout: Optional[float] = None) -> list: + logger_debug(f"RpcWorker {mpi_rank()} is fetching responses", + color="yellow") + # NOTE: This is a blocking call, it will wait for the responses to be available. + responses = super().await_responses(timeout) + self._await_response_helper.responses_handler(responses) + + qsize = self._response_queue.qsize() + logger_debug(f"RpcWorker returning {qsize} responses", color="yellow") + + all_responses = [] + for _ in range(qsize): + # The queue contains batches of responses, so extend the list + all_responses.extend(self._response_queue.get()) + return all_responses + + async def fetch_responses_async(self, + timeout: Optional[float] = None) -> list: + # A really async version of fetch_responses + logger_debug(f"RpcWorker {mpi_rank()} is fetching responses async", + color="yellow") + + # First, await any pending responses without blocking the event loop + responses = await asyncio.to_thread(self.fetch_responses, + timeout=timeout) + return responses + + async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: + return await asyncio.to_thread(self.fetch_stats) + + async def fetch_kv_cache_events_async(self, + timeout: Optional[float] = None + ) -> list: + return await asyncio.to_thread(self.fetch_kv_cache_events) + + # for streaming performance + async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: + while not self.shutdown_event.is_set(): + responses = await self.fetch_responses_async() + if responses: # Only yield if there are actual responses + logger_debug( + f"RpcWorker {mpi_rank()} is yielding responses: {responses}", + color="yellow") + yield responses # batching the responses to opt IPC performance + else: + # Small delay to prevent busy waiting when no responses + await asyncio.sleep(0) + logger_debug( + f"RpcWorker {mpi_rank()} quitting fetch_responses_loop_async", + color="yellow") + + async def _generic_fetch_loop_async( + self, + fetch_method, + serializer, + method_name: str, + timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + 
"""Generic method for fetching data in a loop. + + Args: + fetch_method: The async method to call for fetching data + serializer: The serializer function to apply to each item + method_name: Name of the method for logging + timeout: Optional timeout between fetches + """ + while not self.shutdown_event.is_set(): + timeout = timeout or 0.1 + await asyncio.sleep(timeout) + data = await fetch_method() + # Always yield data, even if empty, to prevent the client looks like hanging + # TODO: Remove the empty data to reduce the IPC overhead + yield [serializer(item) for item in data] + logger_debug(f"RpcWorker {mpi_rank()} quitting {method_name}", + color="yellow") + + async def fetch_stats_loop_async( + self, + timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + async for data in self._generic_fetch_loop_async( + fetch_method=self.fetch_stats_async, + serializer=self._stats_serializer, + method_name="fetch_stats_loop_async", + timeout=timeout): + yield data + + async def fetch_kv_cache_events_loop_async( + self, + timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + async for data in self._generic_fetch_loop_async( + fetch_method=self.fetch_kv_cache_events_async, + serializer=self._kv_cache_events_serializer, + method_name="fetch_kv_cache_events_loop_async", + timeout=timeout): + yield data + + def setup_engine(self): + # Force all the ranks to wait here, and start creating the executor simultaneously. + # Only call barrier if we have multiple ranks to avoid hanging in single-process tests + if mpi_comm().Get_size() > 1: + mpi_comm().barrier() + + super().setup_engine() + + def shutdown(self): + logger_debug(f"RPC worker {mpi_rank()} is shutting down", + color="yellow") + self.shutdown_event.set() + super().shutdown() + logger_debug(f"RPC worker {mpi_rank()} is shutdown", color="yellow") + + def start(self): + pass + + @staticmethod + def main_task( + engine: Union[Path, Engine], + rpc_addr: str, + *, + executor_config: Optional[tllm.ExecutorConfig] = None, + batched_logits_processor: Optional[BatchedLogitsProcessor] = None, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + is_llm_executor: Optional[bool] = None, + lora_config: Optional[LoraConfig] = None, + llm_args: Optional[BaseLlmArgs] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + **kwargs, + ) -> None: + if enable_llm_debug(): + set_level("debug") + + # Step 1: Create the worker instance + worker = RpcWorker( + engine=engine, + executor_config=executor_config, + is_llm_executor=is_llm_executor, + lora_config=lora_config, + llm_args=llm_args, + batched_logits_processor=batched_logits_processor, + postproc_worker_config=postproc_worker_config, + kv_connector_config=kv_connector_config, + hf_model_dir=hf_model_dir, + tokenizer=tokenizer, + ) + + if mpi_rank() != 0: + # The non-leader worker will setup the engine immediately. + # The leader worker will wait for the RPC call to propagate the + # potential error. + logger_debug(f"Worker {mpi_rank()} is setting up the engine", + color="yellow") + worker.setup_engine() + + else: + logger_debug(f"Worker {mpi_rank()} is creating the RPC service", + color="yellow") + # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client + # Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async. 
+ rpc_server = RPCServer(worker, num_workers=RpcWorker.NUM_WORKERS) + rpc_server.bind(rpc_addr) + rpc_server.start() + + # Step 3: Wait for the worker to shutdown + logger_debug( + f"Worker {mpi_rank()} is waiting for the worker to shutdown") + worker.shutdown_event.wait() + rpc_server.shutdown() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + return True diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 148cdcf038c..859e5687448 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -1,18 +1,17 @@ import gc -import json import os import time import traceback from concurrent.futures import ProcessPoolExecutor from pathlib import Path from queue import Queue -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import zmq from tensorrt_llm.logger import logger -from .._utils import KVCacheEventSerializer, mpi_comm, mpi_rank +from .._utils import mpi_comm, mpi_rank from ..bindings import executor as tllm from ..builder import Engine from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig @@ -144,25 +143,9 @@ def _iteration_result_task(self, it_result_queue: IterationResultQueue, return True # success def dispatch_stats_task(self) -> bool: - - # Define a Callable to join iteration and request stats - def stats_serializer( - stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str: - iteration_stats, req_stats = stats - stats_dict = json.loads(iteration_stats.to_json_str()) - - if req_stats is not None and len(req_stats) > 0: - stats_dict["requestStats"] = [] - for req_stat in req_stats: - stats_dict["requestStats"].append( - json.loads(req_stat.to_json_str())) - - # Convert back to JSON string - return json.dumps(stats_dict) - return self._iteration_result_task(self.stats_queues, self.fetch_stats, self._iter_stats_result, - stats_serializer) + self._stats_serializer) def dispatch_kv_cache_events_task(self) -> bool: if isinstance(self.engine, tllm.Executor): @@ -173,14 +156,14 @@ def dispatch_kv_cache_events_task(self) -> bool: events_api = lambda: [None] else: events_api = event_manager.get_latest_events - return self._iteration_result_task( - self.kv_events_queues, events_api, self._iter_kv_events_result, - lambda x: json.dumps(KVCacheEventSerializer.serialize(x))) + return self._iteration_result_task(self.kv_events_queues, + events_api, + self._iter_kv_events_result, + self._kv_cache_events_serializer) else: return self._iteration_result_task( self.kv_events_queues, self.engine.get_latest_kv_cache_events, - self._iter_kv_events_result, - lambda x: json.dumps(KVCacheEventSerializer.serialize(x))) + self._iter_kv_events_result, self._kv_cache_events_serializer) def start(self): # create iteration result queues diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 2ea038935e1..b64c11246dd 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1550,6 +1550,13 @@ class BaseLlmArgs(StrictBaseModel): description="Return perf metrics.", status="prototype") + orchestrator_type: Optional[Literal["rpc"]] = Field( + default=None, + description= + "The orchestrator type to use. 
Defaults to None, which uses MPI.", + status="prototype", + ) + _parallel_config: Optional[object] = PrivateAttr(default=None) _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None) _speculative_model: Optional[str] = PrivateAttr(default=None) diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py index 5e9f2e69aa4..4a9c10f0d7a 100644 --- a/tensorrt_llm/llmapi/utils.py +++ b/tensorrt_llm/llmapi/utils.py @@ -1,6 +1,8 @@ import asyncio import collections +import datetime import hashlib +import inspect import io import os import re @@ -64,10 +66,63 @@ def print_colored(message, def print_colored_debug(message, color: Optional[str] = None, writer: io.TextIOWrapper = sys.stderr): - if enable_llm_debug(): + if enable_llmapi_debug(): print_colored(message, color, writer) +def get_current_location(skip_frames: int = 2) -> str: + """ + Get the current execution location in format 'module.class.function'. + + Args: + skip_frames: Number of stack frames to skip (default 2 to skip this function and its caller) + + Returns: + String in format 'module.class.function' or 'module.function' if not in a class + """ + stack = inspect.stack() + if len(stack) <= skip_frames: + return "unknown" + + frame = stack[skip_frames] + module_name = frame.frame.f_globals.get('__name__', 'unknown') + function_name = frame.function + + # Try to determine if we're in a class method + class_name = None + if 'self' in frame.frame.f_locals: + # This is likely an instance method + obj = frame.frame.f_locals['self'] + class_name = obj.__class__.__name__ + elif 'cls' in frame.frame.f_locals: + # This might be a class method + cls = frame.frame.f_locals['cls'] + if inspect.isclass(cls): + class_name = cls.__name__ + + # Build the location string + if class_name: + return f"{module_name}.{class_name}.{function_name}" + else: + return f"{module_name}.{function_name}" + + +def logger_debug(message, + color: Optional[str] = None, + writer: io.TextIOWrapper = sys.stderr): + """ Print the message if the llmapi debug mode is enabled. Fallback to logger.debug if not. """ + if enable_llmapi_debug(): + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + location = get_current_location() + cur_dualname = "..." + location[-47:] if len( + location) > 50 else location + print_colored(f"{timestamp} [{cur_dualname}]", "bold_green", writer) + print_colored(f" {message}\n", color, writer) + else: + # Fallback to logger.debug + logger.debug(message) + + def file_with_glob_exists(directory, glob) -> bool: path = Path(directory) for file_path in path.glob(glob): @@ -244,14 +299,14 @@ def __init__(self, task: Callable[..., bool], error_queue: Queue, name: Optional[str] = None, + stop_event: Optional[threading.Event] = None, **kwargs): super().__init__(name=name) self.task = task self.error_queue = error_queue self.kwargs = kwargs self.daemon = True - - self.stop_event = threading.Event() + self.stop_event = stop_event or threading.Event() def run(self): @@ -290,6 +345,17 @@ def enable_llm_debug() -> bool: return _enable_llm_debug_ +_enable_llmapi_debug_ = None + + +def enable_llmapi_debug() -> bool: + global _enable_llmapi_debug_ + if _enable_llmapi_debug_ is None: + _enable_llmapi_debug_ = os.environ.get("TLLM_LLMAPI_ENABLE_DEBUG", + "0") == "1" + return _enable_llmapi_debug_ + + @cache def enable_worker_single_process_for_tp1() -> bool: ''' Tell whether to make worker use single process for TP1. 
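
For context on how the new `orchestrator_type` knob on `BaseLlmArgs` is meant to be used, here is a minimal sketch mirroring the `test_llm_rpc` test added later in this patch; the checkpoint path is a placeholder, and setting `TLLM_LLMAPI_ENABLE_DEBUG=1` enables the colored `logger_debug` output introduced above. This is an illustrative usage sketch, not part of the patch itself.

```python
# Sketch: opting in to the RPC orchestrator from the LLM API.
# The model path is a placeholder; sampling settings follow the tests below.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.sampling_params import SamplingParams

with LLM(model="/path/to/TinyLlama-1.1B-Chat-v1.0",
         kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4),
         orchestrator_type="rpc") as llm:  # default None keeps the MPI path
    out = llm.generate("Tell me a joke",
                       sampling_params=SamplingParams(max_tokens=10, end_id=-1))
    print(out.outputs[0].text)
```
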
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 37aac5a6b87..d45fc865cb4 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -46,6 +46,8 @@ l0_a10: - unittest/llmapi/test_serialization.py - unittest/llmapi/test_utils.py - unittest/llmapi/test_llm_args.py + # executor + - unittest/executor/test_rpc.py - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/test-db/l0_a100.yml b/tests/integration/test_lists/test-db/l0_a100.yml index acc25bf2e48..951231de7bd 100644 --- a/tests/integration/test_lists/test-db/l0_a100.yml +++ b/tests/integration/test_lists/test-db/l0_a100.yml @@ -16,6 +16,10 @@ l0_a100: - unittest/llmapi/test_llm_pytorch.py - unittest/llmapi/test_mpi_session.py # generic tests - unittest/trt/model_api/test_model_quantization.py + # executor + - unittest/executor/test_base_worker.py + - unittest/executor/test_rpc_proxy.py + - unittest/executor/test_rpc_worker.py - condition: ranges: system_gpu_count: diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 642729fc406..26f333eaad2 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -179,6 +179,10 @@ methods: annotation: bool default: False status: prototype + orchestrator_type: + annotation: Optional[Literal["rpc"]] + default: null + status: prototype return_annotation: None generate: parameters: diff --git a/tests/unittest/executor/test_base_worker.py b/tests/unittest/executor/test_base_worker.py index 664beefd270..6e661d1e4e5 100644 --- a/tests/unittest/executor/test_base_worker.py +++ b/tests/unittest/executor/test_base_worker.py @@ -12,6 +12,7 @@ # isort: off sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") from utils.llm_data import llm_models_root +from utils.util import skip_single_gpu # isort: on from tensorrt_llm._torch.pyexecutor.config import update_executor_config @@ -117,9 +118,12 @@ def test_fetch_responses_timeout(self, timeout: float): def create_fake_executor_config(model_path, tp_size=1): # Use TorchLlmArgs for PyTorch backend tests - llm_args = TorchLlmArgs(model=model_path, - tensor_parallel_size=tp_size, - backend='pytorch') + llm_args = TorchLlmArgs( + model=model_path, + tensor_parallel_size=tp_size, + backend='pytorch', + enable_iter_perf_stats=True, + ) executor_config = tllm.ExecutorConfig(1) executor_config.max_batch_size = 1 @@ -153,6 +157,8 @@ def create_worker_session(self): session = MpiPoolSession(n_workers=2) return session + @pytest.mark.gpu2 + @skip_single_gpu def test_create_executor(self): futures = self.session.submit( TestRpcWorkerBaseTP2.create_executor, diff --git a/tests/unittest/executor/test_rpc.py b/tests/unittest/executor/test_rpc.py new file mode 100644 index 00000000000..8a04d7534c2 --- /dev/null +++ b/tests/unittest/executor/test_rpc.py @@ -0,0 +1,661 @@ +import asyncio +import time + +import pytest + +from tensorrt_llm.executor.rpc import (RPCCancelled, RPCClient, RPCError, + RPCServer, RPCStreamingError, RPCTimeout) + + +class RpcServerWrapper(RPCServer): + + def __init__(self, *args, addr: str, **kwargs): + super().__init__(*args, **kwargs) + self.addr = addr + + def __enter__(self): + self.bind(self.addr) + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + + +class TestRpcBasics: + + def test_rpc_server_basics(self): + + 
class App: + + def hello(self): + print("hello") + + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + pass + + def test_remote_call_without_arg(self): + + class App: + + def hello(self): + print("hello") + return "world" + + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + with RPCClient("ipc:///tmp/rpc_test") as client: + ret = client.hello().remote() # sync call + assert ret == "world" + + def test_remote_call_with_args(self): + + class App: + + def hello(self, name: str, location: str): + print("hello") + return f"hello {name} from {location}" + + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + with RPCClient("ipc:///tmp/rpc_test") as client: + ret = client.hello("app", "Marvel").remote() + assert ret == "hello app from Marvel" + + def test_remote_call_with_kwargs(self): + + class App: + + def hello(self, name: str, location: str): + print("hello") + return f"hello {name} from {location}" + + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + with RPCClient("ipc:///tmp/rpc_test") as client: + ret = client.hello(name="app", location="Marvel").remote() + assert ret == "hello app from Marvel" + + def test_remote_call_with_args_and_kwargs(self): + + class App: + + def hello(self, name: str, location: str): + print("hello") + return f"hello {name} from {location}" + + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + with RPCClient("ipc:///tmp/rpc_test") as client: + ret = client.hello(name="app", location="Marvel").remote() + assert ret == "hello app from Marvel" + + def test_rpc_server_address(self): + + class App: + pass + + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + assert server.address == "ipc:///tmp/rpc_test" + + def test_rpc_with_error(self): + + class App: + + def hello(self): + raise ValueError("hello") + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_error") as server: + with RPCClient("ipc:///tmp/rpc_test_error") as client: + with pytest.raises(RPCError): + client.hello().remote() + + def test_rpc_without_wait_response(self): + + class App: + + def __init__(self): + self.task_submitted = False + + def send_task(self) -> None: + # Just submit the task and return immediately + # The result is not important + self.task_submitted = True + return None + + def get_task_submitted(self) -> bool: + return self.task_submitted + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_no_wait") as server: + with RPCClient("ipc:///tmp/rpc_test_no_wait") as client: + client.send_task().remote(need_response=False) + time.sleep( + 0.1 + ) # wait for some time to make sure the task is submitted + assert client.get_task_submitted().remote() + + +class TestRpcError: + + class CustomError(Exception): + pass + + def test_task_error(self): + """Test that server-side exceptions are properly wrapped in RPCError with details.""" + + class App: + + def hello(self): + raise ValueError("Test error message") + + def divide_by_zero(self): + return 1 / 0 + + def custom_exception(self): + raise TestRpcError.CustomError("Custom error occurred") + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test_error") + server.start() + time.sleep(0.1) + with RPCClient("ipc:///tmp/rpc_test_error") as client: + # Test ValueError handling + with pytest.raises(RPCError) as exc_info: + client.hello().remote() + + error = exc_info.value + assert "Test error message" in str(error) + assert error.cause is not None + assert isinstance(error.cause, ValueError) + assert 
error.traceback is not None + assert "ValueError: Test error message" in error.traceback + + # Test ZeroDivisionError handling + with pytest.raises(RPCError) as exc_info: + client.divide_by_zero().remote() + + error = exc_info.value + assert "division by zero" in str(error) + assert error.cause is not None + assert isinstance(error.cause, ZeroDivisionError) + assert error.traceback is not None + + # Test custom exception handling + with pytest.raises(RPCError) as exc_info: + client.custom_exception().remote() + + error = exc_info.value + assert "Custom error occurred" in str(error) + assert error.cause is not None + assert error.traceback is not None + + def test_shutdown_cancelled_error(self): + """Test that pending requests are cancelled with RPCCancelled when server shuts down.""" + + class App: + + def task(self): + time.sleep(10) + return True + + addr = "ipc:///tmp/rpc_test_cancelled" + + server = RPCServer( + App(), + # only one worker to make it easier to pend requests + num_workers=1) + server.bind(addr) + server.start() + time.sleep(0.1) + + with RPCClient(addr) as client: + client.shutdown_server() + pending_futures = [client.task().remote_future() for _ in range(10)] + + for future in pending_futures: + with pytest.raises(RPCCancelled): + future.result() + + time.sleep(5) + + client.close() + + def test_timeout_error(self): + """Test that requests that exceed timeout are handled with proper error.""" + + class App: + + def slow_method(self): + # Sleep longer than the timeout + time.sleep(2.0) + return "completed" + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_timeout") as server: + time.sleep(0.1) + + # Create client with short timeout + with RPCClient("ipc:///tmp/rpc_test_timeout", + timeout=0.5) as client: + with pytest.raises(RPCError) as exc_info: + client.slow_method().remote(timeout=0.5) + + error = exc_info.value + # Should be either a timeout error or RPC error indicating timeout + assert "timed out" in str(error).lower() or "timeout" in str( + error).lower() + + def test_method_not_found_error(self): + """Test that calling non-existent methods returns proper error.""" + + class App: + + def existing_method(self): + return "exists" + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_not_found") as server: + time.sleep(0.1) + + with RPCClient("ipc:///tmp/rpc_test_not_found") as client: + with pytest.raises(RPCError) as exc_info: + client.non_existent_method().remote() + + error = exc_info.value + assert "not found" in str(error) + assert error.traceback is not None + + +def test_rpc_shutdown_server(): + + class App: + + def hello(self): + return "world" + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test_shutdown") + server.start() + time.sleep(0.1) + with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: + ret = client.hello().remote() + assert ret == "world" + + client.shutdown_server() + + time.sleep(5) # the server dispatcher thread need some time to quit + + +def test_rpc_without_response_performance(): + # At any circumstances, the RPC call without response should be faster than the one with response + class App: + + def __init__(self): + self.task_submitted = False + + def send_task(self) -> None: + # Just submit the task and return immediately + # The result is not important + time.sleep(0.001) + return None + + with RPCServer(App(), num_workers=10) as server: + server.bind("ipc:///tmp/rpc_test_no_wait") + server.start() + time.sleep(0.1) + with RPCClient("ipc:///tmp/rpc_test_no_wait") as client: + time_start = 
time.time() + for i in range(100): + client.send_task().remote(need_response=False) + time_end = time.time() + + no_wait_time = time_end - time_start + + time_start = time.time() + for i in range(100): + client.send_task().remote(need_response=True) + time_end = time.time() + wait_time = time_end - time_start + + assert no_wait_time < wait_time, f"{no_wait_time} > {wait_time}" + + +@pytest.mark.parametrize("async_run_task", [True, False]) +@pytest.mark.parametrize("use_ipc_addr", [True, False]) +def test_rpc_benchmark(async_run_task: bool, use_ipc_addr: bool): + + class App: + + def cal(self, n: int): + return n * 2 + + with RPCServer(App(), async_run_task=async_run_task) as server: + address = "ipc:///tmp/rpc_test" if use_ipc_addr else "tcp://127.0.0.1:*" + + server.bind(address) + server.start() + time.sleep(0.1) + + with RPCClient(server.address) as client: + + time_start = time.time() + for i in range(100): + ret = client.cal(i).remote(timeout=10) # sync call + assert ret == i * 2, f"{ret} != {i * 2}" + time_end = time.time() + print( + f"Time taken: {time_end - time_start} seconds, {10000 / (time_end - time_start)} calls/second" + ) + + +class TestRpcTimeout: + """Test RPC timeout functionality for both sync and async calls, sharing server/client.""" + + class App: + + def slow_operation(self, delay: float): + """A method that takes a long time to complete.""" + time.sleep(delay) + return "completed" + + def setup_method(self, method): + """Setup RPC server and client for timeout tests.""" + # Use unique address based on the test parameter to avoid socket conflicts + test_name = method.__name__ + self.address = f"ipc:///tmp/rpc_test_timeout_{test_name}_{id(self)}" + self.server = RPCServer(self.App()) + self.server.bind(self.address) + self.server.start() + time.sleep(0.1) + self.client = RPCClient(self.address) + + def teardown_method(self): + """Shutdown server and close client.""" + self.client.close() + self.server.shutdown() + # Add a small delay to ensure the socket is fully released before the next test + time.sleep(0.5) + + def run_sync_timeout_test(self): + with pytest.raises(RPCTimeout) as exc_info: + self.client.slow_operation(2.0).remote(timeout=0.1) + assert "timed out" in str( + exc_info.value), f"Timeout message not found: {exc_info.value}" + + def run_async_timeout_test(self): + import asyncio + + async def async_timeout(): + with pytest.raises(RPCTimeout) as exc_info: + await self.client.slow_operation(2.0).remote_async(timeout=0.1) + assert "timed out" in str( + exc_info.value), f"Timeout message not found: {exc_info.value}" + + asyncio.run(async_timeout()) + + def run_sync_success_test(self): + result = self.client.slow_operation(0.1).remote(timeout=10.0) + assert result == "completed" + print(f"final result: {result}") + + def run_async_success_test(self): + import asyncio + + async def async_success(): + result = await self.client.slow_operation(0.1).remote_async( + timeout=10.0) + assert result == "completed" + print(f"final result: {result}") + return result + + return asyncio.run(async_success()) + + @pytest.mark.parametrize("use_async", [True, False]) + def test_rpc_timeout(self, use_async): + if use_async: + self.run_async_timeout_test() + self.run_async_success_test() + else: + self.run_sync_timeout_test() + self.run_sync_success_test() + + +class TestRpcShutdown: + + def test_duplicate_shutdown(self): + + class App: + + def quick_task(self, task_id: int): + return f"quick_task_{task_id}" + + with RpcServerWrapper(App(), + 
addr="ipc:///tmp/rpc_test_shutdown") as server: + time.sleep(0.1) + with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: + client.quick_task(1).remote() + + # repeated shutdown should not raise an error + for i in range(10): + server.shutdown() + + def test_submit_request_after_server_shutdown(self): + + class App: + + def foo(self, delay: int): + time.sleep(delay) + return "foo" + + server = RPCServer(App()) + server.bind("ipc:///tmp/rpc_test_shutdown") + server.start() + + time.sleep(0.1) + with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: + # This task should be continued after server shutdown + res = client.foo(10).remote_future(timeout=12) + + # The shutdown will block until all pending requests are finished + server.shutdown() + + assert res.result() == "foo" + + +class TestApp: + """Test application with various method types.""" + + def __init__(self): + self.call_count = 0 + + def sync_add(self, a: int, b: int) -> int: + """Sync method.""" + self.call_count += 1 + return a + b + + async def async_multiply(self, x: int, y: int) -> int: + """Async method.""" + await asyncio.sleep(0.01) + self.call_count += 1 + return x * y + + async def streaming_range(self, n: int): + """Streaming generator.""" + for i in range(n): + await asyncio.sleep(0.01) + yield i + + async def streaming_error(self, n: int): + """Streaming generator that raises error.""" + for i in range(n): + if i == 2: + raise ValueError("Test error at i=2") + yield i + + async def streaming_timeout(self, delay: float): + """Streaming generator with configurable delay.""" + for i in range(10): + await asyncio.sleep(delay) + yield i + + +class TestRpcAsync: + # Use setup_method/teardown_method for pytest class-based setup/teardown + def setup_method(self): + """Setup RPC server and client for tests.""" + self.app = TestApp() + self.server = RPCServer(self.app, num_workers=2, async_run_task=True) + self.server.bind("tcp://127.0.0.1:0") # Use random port + self.server.start() + # Get actual address after binding + address = f"tcp://127.0.0.1:{self.server.address.split(':')[-1]}" + self.client = RPCClient(address) + + def teardown_method(self): + self.server.shutdown() + self.client.close() + + @pytest.mark.asyncio + async def test_sync_method(self): + """Test traditional sync method still works.""" + app, client, server = self.app, self.client, self.server + + # Test sync call + result = client.sync_add(5, 3).remote() + assert result == 8 + assert app.call_count == 1 + + @pytest.mark.asyncio + async def test_async_method(self): + """Test async method execution.""" + app, client, server = self.app, self.client, self.server + + # Test async call + result = await client.async_multiply(4, 7).remote_async() + assert result == 28 + assert app.call_count == 1 + + @pytest.mark.asyncio + async def test_streaming_basic(self): + """Test basic streaming functionality.""" + app, client, server = self.app, self.client, self.server + + results = [] + async for value in client.streaming_range(5).remote_streaming(): + results.append(value) + + assert results == [0, 1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_streaming_concurrent(self): + """Test concurrent streaming calls.""" + app, client, server = self.app, self.client, self.server + + async def collect_stream(n): + results = [] + async for value in client.streaming_range(n).remote_streaming(): + results.append(value) + return results + + # Run 3 concurrent streams + results = await asyncio.gather(collect_stream(3), collect_stream(4), + collect_stream(5)) + + assert 
results[0] == [0, 1, 2] + assert results[1] == [0, 1, 2, 3] + assert results[2] == [0, 1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_streaming_error_handling(self): + """Test error handling in streaming.""" + app, client, server = self.app, self.client, self.server + + results = [] + with pytest.raises(RPCStreamingError, match="Test error at i=2"): + async for value in client.streaming_error(5).remote_streaming(): + results.append(value) + + # Should have received values before error + assert results == [0, 1] + + @pytest.mark.asyncio + async def test_streaming_timeout(self): + """Test timeout handling in streaming.""" + app, client, server = self.app, self.client, self.server + + # Set short timeout + with pytest.raises(RPCTimeout): + async for value in client.streaming_timeout( + delay=2.0).remote_streaming(timeout=0.5): + pass # Should timeout before first yield + + @pytest.mark.asyncio + async def test_mixed_calls(self): + """Test mixing different call types.""" + app, client, server = self.app, self.client, self.server + + # Run sync, async, and streaming calls together + sync_result = client.sync_add(1, 2).remote() + async_future = client.async_multiply(3, 4).remote_future() + + streaming_results = [] + async for value in client.streaming_range(3).remote_streaming(): + streaming_results.append(value) + + async_result = async_future.result() + + assert sync_result == 3 + assert async_result == 12 + assert streaming_results == [0, 1, 2] + assert app.call_count == 2 # sync + async (streaming doesn't increment) + + @pytest.mark.asyncio + async def test_invalid_streaming_call(self): + """Test calling non-streaming method with streaming.""" + app, client, server = self.app, self.client, self.server + + # This should fail because sync_add is not an async generator + with pytest.raises(RPCStreamingError): + async for value in client.sync_add(1, 2).remote_streaming(): + pass + + +class TestResponsePickleError: + """ The pickle error will break the whole server, test the error handling. 
""" + + class App: + + def unpickleable_return(self): + # Functions defined locally are not pickleable + def nested_function(): + pass + + return nested_function + + async def unpickleable_streaming_return(self): + # Functions defined locally are not pickleable + def nested_function(): + pass + + yield nested_function + + def test_unpickleable_error(self): + with RpcServerWrapper( + self.App(), addr="ipc:///tmp/rpc_test_pickle_error") as server: + with RPCClient("ipc:///tmp/rpc_test_pickle_error") as client: + with pytest.raises(RPCError) as exc_info: + client.unpickleable_return().remote() + + assert "Failed to pickle response" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_unpickleable_streaming_error(self): + with RpcServerWrapper(self.App(), + addr="ipc:///tmp/rpc_test_pickle_error_streaming", + async_run_task=True) as server: + with RPCClient( + "ipc:///tmp/rpc_test_pickle_error_streaming") as client: + with pytest.raises(RPCStreamingError) as exc_info: + async for _ in client.unpickleable_streaming_return( + ).remote_streaming(): + pass + + assert "Failed to pickle response" in str(exc_info.value) diff --git a/tests/unittest/executor/test_rpc_proxy.py b/tests/unittest/executor/test_rpc_proxy.py new file mode 100644 index 00000000000..17d99fd24d7 --- /dev/null +++ b/tests/unittest/executor/test_rpc_proxy.py @@ -0,0 +1,99 @@ +import os +import sys +import time + +import pytest +from test_base_worker import create_fake_executor_config + +from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy +from tensorrt_llm.llmapi.llm_args import KvCacheConfig +from tensorrt_llm.llmapi.mpi_session import MpiPoolSession +from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer +from tensorrt_llm.sampling_params import SamplingParams + +# isort: off +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") +from utils.llm_data import llm_models_root +from utils.util import similar, skip_single_gpu +# isort: on + +model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +class TestRpcProxy: + + def create_proxy(self, tp_size: int): + # Create executor config with the correct tp_size + llm_args, executor_config = create_fake_executor_config(model_path, + tp_size=tp_size) + + # Enable KV cache events + llm_args.kv_cache_config = KvCacheConfig( + event_buffer_max_size=1000, # Enable event buffer + enable_block_reuse=True, # Required for KV cache events + ) + + mpi_session = MpiPoolSession(n_workers=tp_size) + proxy = GenerationExecutorRpcProxy( + worker_kwargs={ + "engine": model_path, + "executor_config": None, + "llm_args": llm_args, + "model_world_size": tp_size, + "hf_model_dir": model_path, + }, + model_world_size=tp_size, + mpi_session=mpi_session, + is_llm_executor=True, # Enable stats collection + ) + + # Add additional wait for PyTorch backend with multi-rank setup + if tp_size > 1: + print(f"[Test] Waiting for {tp_size} ranks to initialize...") + time.sleep( + 5) # Give more time for multi-rank PyTorch initialization + + return proxy + + @pytest.mark.parametrize("num_reqs", [1, 10]) + def test_tp1(self, num_reqs): + tokenizer = TransformersTokenizer.from_pretrained(model_path) + prompt = "A B C D" + prompt_token_ids = tokenizer.encode(prompt) + max_tokens = 8 + + with self.create_proxy(tp_size=1) as proxy: + sampling_params = SamplingParams(max_tokens=max_tokens) + for _ in range(num_reqs): + result = proxy.generate(prompt_token_ids, sampling_params) + print(f"get result: {result}") + assert 
similar(tokenizer.decode(result.outputs[0].token_ids), + 'E F G H I J K L') + + stats = proxy.get_stats(timeout=2) + assert stats + + kv_cache_events = proxy.get_kv_events(timeout=2) + # KV cache events may be empty if no cache operations occurred + assert isinstance(kv_cache_events, list) + + @pytest.mark.parametrize("num_reqs", [1, 10]) + @skip_single_gpu + @pytest.mark.gpu2 + def test_tp2(self, num_reqs): + tokenizer = TransformersTokenizer.from_pretrained(model_path) + prompt = "A B C D" + prompt_token_ids = tokenizer.encode(prompt) + max_tokens = 8 + + with self.create_proxy(tp_size=2) as proxy: + sampling_params = SamplingParams(max_tokens=max_tokens) + for _ in range(num_reqs): + result = proxy.generate(prompt_token_ids, sampling_params) + print(f"get result: {result}") + assert similar(tokenizer.decode(result.outputs[0].token_ids), + 'E F G H I J K L') + + +if __name__ == "__main__": + TestRpcProxy().test_tp1(1) diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py new file mode 100644 index 00000000000..623c124c92b --- /dev/null +++ b/tests/unittest/executor/test_rpc_worker.py @@ -0,0 +1,251 @@ +import asyncio +import multiprocessing +import os +import sys +import time +from concurrent.futures import ProcessPoolExecutor + +import pytest +from test_base_worker import create_fake_executor_config + +from tensorrt_llm.executor.request import GenerationRequest +from tensorrt_llm.executor.rpc import RPCClient +from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy +from tensorrt_llm.executor.rpc_worker import RpcWorker +from tensorrt_llm.llmapi.mpi_session import MpiPoolSession +from tensorrt_llm.sampling_params import SamplingParams + +# isort: off +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") +from utils.llm_data import llm_models_root +from utils.util import skip_single_gpu +# isort: on + +model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" +assert model_path.exists() + + +class TestRpcWorkerTP1: + + def setup_method(self): + self.llm_args, self.executor_config = create_fake_executor_config( + model_path) + self.pool, self.addr = self.create_worker_pool() + self.client = self.create_rpc_client(self.addr) + self.client.setup_engine().remote() + print(f"Worker setup engine done") + time.sleep(10) + + def teardown_method(self): + self.client.shutdown().remote() + self.pool.shutdown() + self.client.close() + + def create_worker_pool(self): + addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr() + mp_context = multiprocessing.get_context( + 'spawn') # spawn for CUDA context + pool = ProcessPoolExecutor(max_workers=1, mp_context=mp_context) + pool.submit( + RpcWorker.main_task, + engine=model_path, + rpc_addr=addr, + executor_config=self.executor_config, + llm_args=self.llm_args, + hf_model_dir=model_path, + ) + return pool, addr + + def create_rpc_client(self, addr: str): + client = RPCClient(addr) + return client + + def test_create_shutdown(self): + pass + + def test_fetch_responses_sync(self): + # Wait a bit to ensure engine is ready + time.sleep(1) + + print(f"start to submit") + self.client.submit( + GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams( + max_tokens=5)), ).remote(need_response=False) + print(f"submit done") + + time.sleep(3) + + results = [] + # Fetch responses + results.extend(self.client.fetch_responses().remote()) + assert len(results) == 1 + + def test_fetch_responses_streaming_sync(self): + self.client.submit( + 
GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5), + streaming=True), ).remote(need_response=False) + + results = [] + for i in range(10): + res = self.client.fetch_responses().remote(timeout=1.0) + results.extend(res) + print(f"fetch_responses {i} result: {results}") + # If we've received enough results, break early + if len(results) >= 5: + break + assert 0 < len(results) <= 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize("req_count", [10]) + async def test_main_loop_async(self, req_count: int): + await asyncio.sleep(1) + + async def process_request_streaming(): + for i in range(req_count): + ret = self.client.submit( + GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5), + streaming=True), ).remote(need_response=False) + assert ret is None + print("submit result: ", ret) + + # NOTE: known issue, the responses should be fetched before shutdown, + # or the shutdown will hang. + results = [] + responses_per_client = {} + expected_responses_per_client = 5 # max_tokens=5 + + print(f"start to fetch_responses_async") + no = 0 + async for result in self.client.fetch_responses_loop_async( + ).remote_streaming(): + if result: # result is already a list of lists + print( + f"fetch_responses_async batch {no}, received {len(result)} sub-batches" + ) + for batch in result: + if isinstance(batch, list): + print(f" Sub-batch has {len(batch)} responses") + results.extend(batch) + # Track responses per client + for response in batch: + client_id = response.client_id + if client_id not in responses_per_client: + responses_per_client[client_id] = 0 + responses_per_client[client_id] += 1 + else: + # Single response + results.append(batch) + client_id = batch.client_id + if client_id not in responses_per_client: + responses_per_client[client_id] = 0 + responses_per_client[client_id] += 1 + + no += 1 + + # Check if all clients have received their expected responses + completed_clients = sum( + 1 for count in responses_per_client.values() + if count >= expected_responses_per_client) + + print(f"Responses per client: {responses_per_client}") + print(f"Completed clients: {completed_clients}/{req_count}") + + # Break when we've received all expected responses + if completed_clients >= req_count: + print( + f"All {completed_clients} clients completed after {no} batches" + ) + break + + # Safety break to prevent infinite loop + if no >= req_count * 20: # Much higher limit as safety + print(f"Safety break after {no} batches") + break + + print(f"Received {no} batches of streaming responses") + print(f"Total responses received: {len(results)}") + print(f"Final responses per client: {responses_per_client}") + assert results + assert len(responses_per_client) >= req_count + + await process_request_streaming() + + @pytest.mark.asyncio + async def test_fetch_stats_loop_async(self): + await asyncio.sleep(1) + results = [] + max_batches = 5 + + async def consume_stats(): + async for stats in self.client.fetch_stats_loop_async( + ).remote_streaming(): + results.append(stats) + assert not stats # empty stats + if len(results) >= max_batches: + break + + await asyncio.wait_for(consume_stats(), timeout=5) + + assert len(results) == max_batches + assert all(not stats for stats in results) + + +class TestRpcWorkerTP2: + + def setup_method(self): + self.llm_args, self.executor_config = create_fake_executor_config( + model_path, tp_size=2) + self.session, self.addr, self.futures = self.create_worker_session() + self.client = 
self.create_rpc_client(self.addr) + self.client.setup_engine().remote() + time.sleep(10) + + def teardown_method(self): + self.client.shutdown().remote() + self.session.shutdown() + self.client.close() + + def create_worker_session(self): + session = MpiPoolSession(n_workers=2) + addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr() + futures = session.submit(RpcWorker.main_task, + engine=model_path, + rpc_addr=addr, + executor_config=self.executor_config, + llm_args=self.llm_args, + hf_model_dir=model_path, + model_world_size=2) + return session, addr, futures + + def create_rpc_client(self, addr: str): + return RPCClient(addr) + + @skip_single_gpu + @pytest.mark.gpu2 + def test_create_shutdown(self): + # Invoke setup_engine in rank 0, and that will unblock all the ranks to + # invoke setup_engine simultaneously. + pass + + @skip_single_gpu + @pytest.mark.gpu2 + def test_fetch_responses_sync(self): + # Wait a bit to ensure engine is ready + time.sleep(1) + + self.client.submit( + GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams( + max_tokens=5)), ).remote(need_response=False) + + # Wait for generation to complete + time.sleep(3) + + results = [] + # Fetch responses with timeout + results.extend(self.client.fetch_responses().remote(timeout=5)) + assert len(results) == 1 diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index 28d6bedf1ba..b5d34dcd735 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -1,7 +1,7 @@ import pytest # isort: off -from .test_llm import tinyllama_logits_processor_test_harness +from .test_llm import tinyllama_logits_processor_test_harness, llama_model_path from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.lora_helper import LoraConfig @@ -9,6 +9,8 @@ from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness from .test_llm import _test_llm_capture_request_error # isort: on +from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy +from tensorrt_llm.sampling_params import SamplingParams global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) @@ -57,3 +59,35 @@ def test_llama_7b_multi_lora_tp2(): # Disable CUDA graph # TODO: remove this once we have a proper fix for CUDA graph in LoRA cuda_graph_config=None) + + +@pytest.mark.gpu2 +def test_llm_rpc_tp2(): + with LLM(model=llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + orchestrator_type="rpc", + tensor_parallel_size=2) as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + res = llm.generate("Tell me a joke", + sampling_params=SamplingParams(max_tokens=10, + end_id=-1)) + print(f"get result: {res}") + + assert len(res.outputs) == 1 + assert len(res.outputs[0].token_ids) == 10 + + +@pytest.mark.gpu2 +@pytest.mark.asyncio +async def test_llm_rpc_streaming_tp2(): + with LLM(model=llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + orchestrator_type="rpc", + tensor_parallel_size=2) as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + async for output in llm.generate_async("Tell me a joke", + sampling_params=SamplingParams( + max_tokens=10, end_id=-1)): + print(f"get result: {output}") diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 62253df45a5..0ed1faab2c7 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ 
b/tests/unittest/llmapi/test_llm_pytorch.py @@ -6,6 +6,7 @@ from tensorrt_llm import LLM from tensorrt_llm.executor import GenerationExecutorWorker +from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer @@ -940,3 +941,37 @@ def test_max_num_token_check(self): match="should not exceed max_num_tokens"): ids = [random.randint(10, 100) for _ in range(101)] llm.generate([ids]) + + +def test_llm_rpc(): + # TODO: remove the with-statement when shutdown hang issue is fixed + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + res = llm.generate("Tell me a joke", + sampling_params=SamplingParams(max_tokens=10, + end_id=-1)) + print(f"get result: {res}") + + assert len(res.outputs) == 1 + assert len(res.outputs[0].token_ids) == 10 + + +@pytest.mark.asyncio +async def test_llm_rpc_streaming(): + # TODO: remove the with-statement when shutdown hang issue is fixed + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + outputs = [] + async for output in llm.generate_async("Tell me a joke", + sampling_params=SamplingParams( + max_tokens=10, end_id=-1), + streaming=True): + outputs.append(output.outputs[0].text) + "".join(outputs) + print(f"get result: {outputs}") diff --git a/tests/unittest/llmapi/test_llm_utils.py b/tests/unittest/llmapi/test_llm_utils.py index 5155e158e62..58608556c9d 100644 --- a/tests/unittest/llmapi/test_llm_utils.py +++ b/tests/unittest/llmapi/test_llm_utils.py @@ -1,9 +1,13 @@ +import asyncio import tempfile +import threading +import time from pathlib import Path import torch from tensorrt_llm.llmapi.llm_utils import * +from tensorrt_llm.llmapi.utils import AsyncQueue # isort: off from .test_llm import llama_model_path @@ -58,3 +62,25 @@ def test_LlmArgs_default_gpus_per_node(): # set explicitly llm_args = TrtLlmArgs(model=llama_model_path, gpus_per_node=6) assert llm_args.gpus_per_node == 6 + + +def test_AsyncQueue(): + queue = AsyncQueue() + + # put data to queue sync in a thread + # async get data from queue in the current event loop + # NOTE: the event loop in the two threads are different + + def put_data_to_queue(): + for i in range(10): + time.sleep(0.1) + queue.put(i) + + async def get_data_from_queue(): + for i in range(10): + print(f"get: {queue.get()}") + + thread = threading.Thread(target=put_data_to_queue) + thread.start() + asyncio.run(get_data_from_queue()) + thread.join() diff --git a/tests/unittest/pytest.ini b/tests/unittest/pytest.ini index 1e93eae77ab..f78281579df 100644 --- a/tests/unittest/pytest.ini +++ b/tests/unittest/pytest.ini @@ -2,7 +2,7 @@ xdist_start_method = spawn asyncio_default_fixture_loop_scope = module threadleak = True -threadleak_exclude = asyncio_\d+ +threadleak_exclude = asyncio_\d+|rpc_client_loop|rpc_client_worker_\d+|rpc_server_worker_\d+ addopts = --durations=0 -W ignore::DeprecationWarning pythonpath = _torch/auto_deploy/_utils_test
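
Taken together, the new unit tests pin down the client-side calling conventions of the RPC layer: blocking `.remote()`, awaitable `.remote_async()`, future-style `.remote_future()`, async-generator `.remote_streaming()`, and fire-and-forget via `need_response=False`. The following condensed sketch is not taken from the patch; the `App` class, address, and timings are illustrative, and it assumes the mixing of sync and async calls on one client that the tests exercise.

```python
import asyncio
import time

from tensorrt_llm.executor.rpc import RPCClient, RPCServer


class App:
    """Illustrative app: a plain object's public methods become RPC endpoints."""

    def add(self, a: int, b: int) -> int:
        return a + b

    async def stream(self, n: int):
        # Server-side async generators back remote_streaming() calls.
        for i in range(n):
            await asyncio.sleep(0.01)
            yield i


async def async_calls(client: RPCClient) -> None:
    # Awaitable call and async-generator streaming.
    assert await client.add(5, 6).remote_async() == 11
    assert [x async for x in client.stream(3).remote_streaming()] == [0, 1, 2]


if __name__ == "__main__":
    addr = "ipc:///tmp/rpc_demo"                 # illustrative address
    server = RPCServer(App(), async_run_task=True)
    server.bind(addr)
    server.start()
    time.sleep(0.1)                              # give the server a moment, as the tests do

    with RPCClient(addr) as client:
        client.add(0, 0).remote(need_response=False)      # fire-and-forget
        assert client.add(1, 2).remote(timeout=10) == 3   # blocking call
        fut = client.add(3, 4).remote_future()            # future-style call
        assert fut.result() == 7
        asyncio.run(async_calls(client))

    server.shutdown()
```
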