38 changes: 38 additions & 0 deletions tensorrt_llm/executor/base_worker.py
@@ -235,6 +235,12 @@ def fetch_stats(self) -> list:
        else:
            return self.engine.get_latest_iteration_stats()

    def fetch_kv_cache_events(self) -> list:
        # Both engine variants expose the same accessor.
        return self.engine.get_latest_kv_cache_events()

    def set_result_queue(self, queue):
        """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
        assert self.postproc_queues is None
@@ -548,6 +554,38 @@ def submit(self, request: GenerationRequest) -> GenerationResult:

        return result

    def shutdown(self):
        if self.doing_shutdown:
            return
        self.doing_shutdown = True

        if self.engine is not None and self.engine.can_enqueue_requests():
            self.engine.shutdown()
            self.engine = None

    # Serializer that joins iteration and request stats into one JSON string.
    @staticmethod
    def _stats_serializer(
            stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str:
        iteration_stats, req_stats = stats
        stats_dict = json.loads(iteration_stats.to_json_str())

        if req_stats is not None and len(req_stats) > 0:
            stats_dict["requestStats"] = []
            for req_stat in req_stats:
                stats_dict["requestStats"].append(
                    json.loads(req_stat.to_json_str()))

        # Convert back to a JSON string.
        return json.dumps(stats_dict)

    # Serializer for KV cache events.
    @staticmethod
    def _kv_cache_events_serializer(events) -> str:
        from .._utils import KVCacheEventSerializer
        return json.dumps(KVCacheEventSerializer.serialize(events))

    def _pop_result(self, client_id: int):
        self._results.pop(client_id, None)
        self._client_id_to_request_id.pop(client_id, None)
34 changes: 34 additions & 0 deletions tensorrt_llm/executor/executor.py
@@ -410,9 +410,22 @@ def create(
        mpirun_launch = external_mpi_comm_available(model_world_size)
        # The case where the Python main process utilizes mpi4py to spawn MPI workers
        spawn_workers = need_spawn_mpi_workers(model_world_size)
        orchestrator_is_rpc = llm_args and llm_args.orchestrator_type == "rpc"

        if spawn_workers or (mpirun_launch and reuse_mpi_comm):
            if reuse_mpi_comm:
                assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session"

            if orchestrator_is_rpc:
                from .rpc_proxy import GenerationExecutorRpcProxy
                return GenerationExecutorRpcProxy(
                    worker_kwargs,
                    model_world_size=model_world_size,
                    mpi_session=mpi_session,
                    postproc_worker_config=postproc_worker_config,
                    is_llm_executor=is_llm_executor,
                    kv_connector_config=kv_connector_config)

            return GenerationExecutorProxy(
                worker_kwargs,
                model_world_size=model_world_size,
@@ -429,6 +442,16 @@
            logger.warning(
                "Using single process worker for TP1, this may hurt streaming generation performance."
            )
            if orchestrator_is_rpc:
                from .rpc_proxy import GenerationExecutorRpcProxy
                return GenerationExecutorRpcProxy(
                    worker_kwargs,
                    model_world_size=model_world_size,
                    mpi_session=mpi_session,
                    postproc_worker_config=postproc_worker_config,
                    is_llm_executor=is_llm_executor,
                    kv_connector_config=kv_connector_config)

            return GenerationExecutorWorker(
                **worker_kwargs,
                is_llm_executor=is_llm_executor,
@@ -439,6 +462,16 @@
        # While this requires users to protect their entrypoint with
        # `if __name__ == "__main__":`.
        if not platform.system() == 'Windows':
            if orchestrator_is_rpc:
                from .rpc_proxy import GenerationExecutorRpcProxy
                return GenerationExecutorRpcProxy(
                    worker_kwargs,
                    model_world_size=model_world_size,
                    mpi_session=mpi_session,
                    postproc_worker_config=postproc_worker_config,
                    is_llm_executor=is_llm_executor,
                    kv_connector_config=kv_connector_config)

            return GenerationExecutorProxy(
                worker_kwargs,
                model_world_size=model_world_size,
@@ -451,6 +484,7 @@
            # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot be used there.
            mpi_session = ProcessPoolExecutorSession(n_workers=1,
                                                     mp_context=ctx)
            # TODO: add rpc worker here
            return GenerationExecutorProxy(
                worker_kwargs,
                model_world_size=model_world_size,
18 changes: 18 additions & 0 deletions tensorrt_llm/executor/ipc.py
@@ -1,3 +1,4 @@
import asyncio
import hashlib
import hmac
import os
@@ -179,6 +180,20 @@ async def put_async(self, obj: Any):

        nvtx_mark("ipc.send", color="blue", category="IPC")

    async def put_async_noblock(self, obj: Any):
        self.setup_lazily()
        try:
            if self.use_hmac_encryption:
                data = pickle.dumps(obj)  # nosec B301
                signed_data = self._sign_data(data)
                await self.socket.send(signed_data, flags=zmq.NOBLOCK)
            else:
                await self.socket.send_pyobj(obj, flags=zmq.NOBLOCK)
        except Exception as e:
            logger.error(f"Error sending object: {e}")
            logger.error(traceback.format_exc())
            raise

    def get(self) -> Any:
        self.setup_lazily()
        return self._recv_data()
@@ -187,6 +202,9 @@ async def get_async(self) -> Any:
        self.setup_lazily()
        return await self._recv_data_async()

    async def get_async_noblock(self, timeout: float = 0.5) -> Any:
        return await asyncio.wait_for(self.get_async(), timeout)

    def close(self):
        if self.socket:
            self.socket.close()
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/result.py
@@ -14,7 +14,7 @@
from ..bindings import executor as tllm
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue
from ..llmapi.utils import AsyncQueue, print_traceback_on_error
from ..metrics import MetricNames, MetricsCollector, RequestEventTiming
from ..sampling_params import LogprobParams, SamplingParams
from .utils import ErrorResponse, has_event_loop, is_llm_response
@@ -315,6 +315,7 @@ def _handle_sequence(self,
                f"Unknown finish reason: {finish_reasons[src_idx]}")
        self.record_stats(output, req_perf_metrics_dict)

    @print_traceback_on_error
    @nvtx_range_debug("handle_response",
                      color="red",
                      category="GenerationResultBase")
85 changes: 85 additions & 0 deletions tensorrt_llm/executor/rpc/README.md
@@ -0,0 +1,85 @@
# A Lightweight RPC
This is a lightweight, pure-Python RPC layer built to simplify the existing IPC code in the orchestrator. It provides multiple call modes (sync, async, future, streaming) and supports both IPC and TCP connections.

## Examples
### Create Server and Client

```python
from tensorrt_llm.executor.rpc import RPCServer, RPCClient

# Define your application
class App:
    def add(self, a: int, b: int) -> int:
        return a + b

    async def async_multiply(self, x: int, y: int) -> int:
        return x * y

# Create and start the server
app = App()
with RPCServer(app) as server:
    server.bind("ipc:///tmp/my_rpc")  # or "tcp://127.0.0.1:5555"
    server.start()

    # Create a client and make calls while the server is running
    with RPCClient("ipc:///tmp/my_rpc") as client:
        result = client.add(5, 3).remote()
        print(result)  # Output: 8
```

### Different Remote Calls

#### Synchronous Call
```python
# Blocking call that waits for result
result = client.add(10, 20).remote()
# or with timeout
result = client.add(10, 20).remote(timeout=5.0)
```

#### Asynchronous Call
```python
# Async call that returns a coroutine
result = await client.async_multiply(3, 4).remote_async()
```

#### Future-based Call
```python
# Returns a concurrent.futures.Future
future = client.add(1, 2).remote_future()
# Get result later
result = future.result()
```
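
Because `remote_future()` returns a standard `concurrent.futures.Future`, in-flight calls compose with the stdlib helpers. A small sketch, reusing the `add` method from the example above:

```python
import concurrent.futures

# Launch several calls without blocking, then gather whatever finishes in time.
futures = [client.add(i, i).remote_future() for i in range(4)]
done, _ = concurrent.futures.wait(futures, timeout=5.0)
print(sorted(f.result() for f in done))  # [0, 2, 4, 6]
```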

#### Fire-and-Forget Call
```python
# Send request without waiting for response
client.submit_task(task_id=123).remote(need_response=False)
```

#### Streaming Call
```python
# For async generator methods
async for value in client.stream_data(n=10).remote_streaming():
    print(f"Received: {value}")
```
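
On the server side, a streaming call is backed by an async generator method on the application object. A minimal sketch of what `stream_data` might look like (a hypothetical implementation matching the client call above):

```python
class App:
    async def stream_data(self, n: int):
        # Each yielded value is delivered to the client as one stream item.
        for i in range(n):
            yield i
```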

### Error Handling
```python
from tensorrt_llm.executor.rpc import RPCError, RPCTimeout

try:
    result = client.risky_operation().remote(timeout=1.0)
except RPCTimeout:
    print("Operation timed out")
except RPCError as e:
    print(f"RPC Error: {e}")
    print(f"Original cause: {e.cause}")
    print(f"Traceback: {e.traceback}")
```

### Graceful Shutdown
```python
# Shutdown server from client
client.shutdown_server()
```
10 changes: 10 additions & 0 deletions tensorrt_llm/executor/rpc/__init__.py
@@ -0,0 +1,10 @@
from .rpc_client import RPCClient
from .rpc_common import (RPCCancelled, RPCError, RPCParams, RPCRequest,
                         RPCResponse, RPCStreamingError, RPCTimeout)
from .rpc_server import RPCServer, Server

__all__ = [
    "RPCClient", "RPCServer", "Server", "RPCError", "RPCTimeout",
    "RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse",
    "RPCParams"
]