Skip to content
Closed
1 change: 1 addition & 0 deletions examples/llm-api/llm_logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# This simple callback will output a specific token at each step irrespective of prompt.
# Refer to ../bindings/executor/example_logits_processor.py for a more
# sophisticated callback that generates JSON structured output.
# See also sampling_params.py: subclasses of LogitsProcessor are registered there
# as approved classes for deserialization across IPC boundaries.
class MyLogitsProcessor(LogitsProcessor):

def __init__(self, allowed_token_id: int):
Expand Down
7 changes: 4 additions & 3 deletions tensorrt_llm/auto_parallel/parallelization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import contextlib
import copy
import itertools
import pickle # nosec B403
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
Expand All @@ -12,6 +11,7 @@
import torch
from filelock import FileLock

import tensorrt_llm.serialization as serialization
from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np,
trt_dtype_to_torch)
from tensorrt_llm.functional import AllReduceParams, create_allreduce_plugin
Expand Down Expand Up @@ -55,12 +55,13 @@ class ParallelConfig:

def save(self, filename):
with open(filename, 'wb') as file:
pickle.dump(self, file)
serialization.dump(self, file)

@staticmethod
def from_file(filename) -> "ParallelConfig":
with open(filename, "rb") as file:
return pickle.load(file) # nosec B301
return serialization.load(
file, approved_imports=serialization.BASE_PARALLEL_CLASSES)

def print_graph_strategy(self, file=None):
for index, (node_name,
Expand Down
24 changes: 15 additions & 9 deletions tensorrt_llm/executor/ipc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import hashlib
import hmac
import os
import pickle # nosec B403
import time
import traceback
from queue import Queue
Expand All @@ -10,6 +9,7 @@
import zmq
import zmq.asyncio

import tensorrt_llm.serialization as serialization
from tensorrt_llm.logger import logger

from .._utils import nvtx_mark, nvtx_range_debug
Expand Down Expand Up @@ -116,26 +116,26 @@ def poll(self, timeout: int) -> bool:
def put(self, obj: Any):
self.setup_lazily()
with nvtx_range_debug("send", color="blue", category="IPC"):
data = serialization.dumps(obj)
if self.use_hmac_encryption:
# Send serialized data with HMAC appended
data = pickle.dumps(obj) # nosec B301
signed_data = self._sign_data(data)
self.socket.send(signed_data)
else:
# Send data without HMAC
self.socket.send_pyobj(obj)
self.socket.send(data)

async def put_async(self, obj: Any):
self.setup_lazily()
try:
data = serialization.dumps(obj)
if self.use_hmac_encryption:
# Send serialized data with HMAC appended
data = pickle.dumps(obj) # nosec B301
signed_data = self._sign_data(data)
await self.socket.send(signed_data)
else:
# Send data without HMAC
await self.socket.send_pyobj(obj)
await self.socket.send(data)
except TypeError as e:
logger.error(f"Cannot pickle {obj}")
raise e
Expand All @@ -161,10 +161,13 @@ def get(self) -> Any:
if not self._verify_hmac(data, actual_hmac):
raise RuntimeError("HMAC verification failed")

obj = pickle.loads(data) # nosec B301
obj = serialization.loads(
data, approved_imports=serialization.BASE_ZMQ_CLASSES)
else:
# Receive data without HMAC
obj = self.socket.recv_pyobj()
data = self.socket.recv()
obj = serialization.loads(
data, approved_imports=serialization.BASE_ZMQ_CLASSES)
return obj

async def get_async(self) -> Any:
Expand All @@ -182,10 +185,13 @@ async def get_async(self) -> Any:
if not self._verify_hmac(data, actual_hmac):
raise RuntimeError("HMAC verification failed")

obj = pickle.loads(data) # nosec B301
obj = serialization.loads(
data, approved_imports=serialization.BASE_ZMQ_CLASSES)
else:
# Receive data without HMAC
obj = await self.socket.recv_pyobj()
data = await self.socket.recv()
obj = serialization.loads(
data, approved_imports=serialization.BASE_ZMQ_CLASSES)
return obj

def close(self):
Expand Down
11 changes: 8 additions & 3 deletions tensorrt_llm/executor/postproc_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import zmq
import zmq.asyncio

import tensorrt_llm.serialization as serialization

from .._utils import nvtx_range_debug
from ..bindings import executor as tllm
from ..llmapi.tokenizer import TransformersTokenizer, load_hf_tokenizer
Expand Down Expand Up @@ -77,6 +79,7 @@ def __init__(
tokenizer_dir: str,
record_creator: Callable[
["PostprocWorker.Input", TransformersTokenizer], Any],
BASE_ZMQ_CLASSES: Dict,
):
'''
Args:
Expand All @@ -86,7 +89,7 @@ def __init__(
record_creator (Callable[["PostprocWorker.Input", TransformersTokenizer], Any]): A creator for creating a record for a request.
result_handler (Optional[Callable[[GenerationResultBase], Any]]): A callback that handles the final result.
'''

serialization.BASE_ZMQ_CLASSES = BASE_ZMQ_CLASSES
self._records: Dict[int, GenerationResult] = {}
self._record_creator = record_creator
self._pull_pipe = ZeroMqQueue(address=pull_pipe_addr,
Expand Down Expand Up @@ -213,9 +216,11 @@ async def main():
@print_traceback_on_error
def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]],
feedout_ipc_addr: tuple[str, Optional[bytes]],
tokenizer_dir: str, record_creator: Callable):
tokenizer_dir: str, record_creator: Callable,
BASE_ZMQ_CLASSES: Dict):
worker = PostprocWorker(feedin_ipc_addr,
feedout_ipc_addr,
tokenizer_dir=tokenizer_dir,
record_creator=record_creator)
record_creator=record_creator,
BASE_ZMQ_CLASSES=BASE_ZMQ_CLASSES)
worker.start()
3 changes: 2 additions & 1 deletion tensorrt_llm/executor/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import zmq
import zmq.asyncio

import tensorrt_llm.serialization as serialization
from tensorrt_llm.logger import logger

from .._utils import mpi_rank
Expand Down Expand Up @@ -297,7 +298,7 @@ def mpi_done_callback(future: concurrent.futures.Future):
tracer_init_kwargs=tracer_init_kwargs,
_torch_model_class_mapping=MODEL_CLASS_MAPPING,
ready_signal=ExecutorBindingsProxy.READY_SIGNAL,
)
BASE_ZMQ_CLASSES=serialization.BASE_ZMQ_CLASSES)
for fut in self.mpi_futures:
fut.add_done_callback(mpi_done_callback)

Expand Down
8 changes: 5 additions & 3 deletions tensorrt_llm/executor/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch

import tensorrt_llm.serialization as serialization
from tensorrt_llm.logger import logger

from .._utils import (KVCacheEventSerializer, global_mpi_rank, mpi_comm,
Expand Down Expand Up @@ -528,7 +529,9 @@ def worker_main(
is_llm_executor: Optional[
bool] = True, # whether it's the main executor instance
lora_config: Optional[LoraConfig] = None,
BASE_ZMQ_CLASSES: Dict = serialization.BASE_ZMQ_CLASSES,
) -> None:
serialization.BASE_ZMQ_CLASSES = BASE_ZMQ_CLASSES
mpi_comm().barrier()
print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
"green")
Expand Down Expand Up @@ -623,12 +626,11 @@ def notify_proxy_threads_to_quit():
assert isinstance(proxy_result_queue, tuple)
for i in range(postproc_worker_config.num_postprocess_workers):
fut = postproc_worker_pool.submit(
postproc_worker_main,
result_queues[i].address,
postproc_worker_main, result_queues[i].address,
proxy_result_queue,
postproc_worker_config.postprocess_tokenizer_dir,
PostprocWorker.default_record_creator,
)
serialization.BASE_ZMQ_CLASSES)
postprocess_worker_futures.append(fut)

# Error handling in the Worker/MPI process
Expand Down
9 changes: 9 additions & 0 deletions tensorrt_llm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import BaseModel

from tensorrt_llm.bindings import executor as tllme
from tensorrt_llm.serialization import register_approved_ipc_class


@dataclass(slots=True, kw_only=True)
Expand Down Expand Up @@ -69,6 +70,14 @@ def __call__(self, req_id: int, logits: torch.Tensor,
"""
pass # noqa

def __init_subclass__(cls, **kwargs):
"""
This method is called when a class inherits from LogitsProcessor.
"""
# Register subclass as an approved class for deserialization across IPC boundaries.
super().__init_subclass__(**kwargs)
register_approved_ipc_class(cls)


class BatchedLogitsProcessor(ABC):
"""Base class for batched logits processor.
Expand Down
Loading