From 994f72cf356bd6041c6de353ac242d264d259937 Mon Sep 17 00:00:00 2001 From: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Date: Thu, 18 Sep 2025 15:38:55 +0800 Subject: [PATCH 01/13] init Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Signed-off-by: chunweiy --- tensorrt_llm/executor/worker_base.py | 834 ++++++++++++++++++++ tests/unittest/executor/test_worker_base.py | 165 ++++ 2 files changed, 999 insertions(+) create mode 100644 tensorrt_llm/executor/worker_base.py create mode 100644 tests/unittest/executor/test_worker_base.py diff --git a/tensorrt_llm/executor/worker_base.py b/tensorrt_llm/executor/worker_base.py new file mode 100644 index 00000000000..24a43e261d5 --- /dev/null +++ b/tensorrt_llm/executor/worker_base.py @@ -0,0 +1,834 @@ +import copy +import datetime +import enum +import json +import weakref +from pathlib import Path +from queue import Queue +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from tensorrt_llm.logger import logger + +from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, + nvtx_range_debug) +from ..bindings import executor as tllm +from ..builder import ConfigEncoder, Engine, EngineConfig +from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror +from ..llmapi.tokenizer import TokenizerBase +from ..llmapi.tracer import global_tracer +from ..llmapi.utils import _SyncQueue, print_colored_debug +from ..lora_helper import LoraConfig +from ..lora_manager import LoraManager +from ..metrics import RequestEventTiming +from ..prompt_adapter_manager import PromptAdapterManager +from ..runtime import ModelConfig +from ..runtime.model_runner import _engine_config_to_model_config +from ..sampling_params import BatchedLogitsProcessor, SamplingParams +from .executor import GenerationExecutor, IterationResultQueue +from .ipc import FusedIpcQueue, IpcQueue +from .postproc_worker import (PostprocParams, PostprocWorker, + PostprocWorkerConfig) +from .request import GenerationRequest, LoRARequest, PromptAdapterRequest +from .result import (GenerationResult, LogProbsResult, ResponseWrapper, + compute_logprobs) +from .utils import (ErrorResponse, IntraProcessQueue, RequestError, + is_llm_response) + +__all__ = [ + "WorkerBase", +] + + +class WorkerBase(GenerationExecutor): + + class WorkerExit(GeneratorExit): + pass + + def __init__( + self, + engine: Union[Path, Engine], + executor_config: Optional[tllm.ExecutorConfig] = None, + batched_logits_processor: Optional[BatchedLogitsProcessor] = None, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + is_llm_executor: Optional[bool] = None, + lora_config: Optional[LoraConfig] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + llm_args: Optional[BaseLlmArgs] = None, + ) -> None: + postproc_config = postproc_worker_config or PostprocWorkerConfig() + super().__init__( + num_postprocess_workers=postproc_config.num_postprocess_workers, + postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir, + is_llm_executor=is_llm_executor, + ) + + # inputs + self._engine = engine + self._executor_config = executor_config + self._batched_logits_processor = batched_logits_processor + self._postproc_worker_config = postproc_worker_config + self._is_llm_executor = is_llm_executor + self._lora_config = lora_config + self._kv_connector_config = kv_connector_config + self._hf_model_dir = hf_model_dir + self._tokenizer = tokenizer + self.llm_args = llm_args + + self.engine = None + self.result_queue: Optional[IpcQueue] = None + self.postproc_queues: Optional[List[IpcQueue]] = None + self.rank = mpi_rank() + self.global_rank = global_mpi_rank() + # mapping: client_id -> GenerationResult + self._results: Dict[int, GenerationResult] = {} + # mapping: client_id from Proxy -> request_id returned from runtime backend + self._client_id_to_request_id: Dict[int, int] = {} + self._await_response_helper = AwaitResponseHelper(weakref.proxy(self)) + self._is_pytorch_backend = llm_args is not None and llm_args.backend in [ + "pytorch", "_autodeploy" + ] + + if not self._is_pytorch_backend and kv_connector_config is not None: + raise ValueError( + "KV connector config is only supported for PyTorch backend") + + if global_mpi_size() > 1: + logger.set_rank(self.global_rank) + + def setup_engine(self): + """ + Setup the engine for the worker. + """ + + if isinstance(self._engine, list): + self._engine[self.rank] + + def _get_comm_ranks_device_id(): + device_id = self.global_rank % torch.cuda.device_count() + torch.cuda.set_device(device_id) + # Make sure C++ executor would use same devices/ranks as py_executor + global_rank = global_mpi_rank() + comm_ranks = mpi_comm().allgather(global_rank) + device_ids = mpi_comm().allgather(device_id) + return comm_ranks, device_ids + + def _create_py_executor(): + args = {} + assert hasattr( + self.llm_args, "backend" + ), "llm_args should be with backend in _create_py_executor" + _ = _get_comm_ranks_device_id() + if self.llm_args.backend == "pytorch": + from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ + create_py_executor + create_executor = create_py_executor + args["llm_args"] = self.llm_args + args["checkpoint_dir"] = self._hf_model_dir + args["tokenizer"] = self._tokenizer + args["lora_config"] = self._lora_config + args["kv_connector_config"] = self._kv_connector_config + elif self.llm_args.backend == "_autodeploy": + from tensorrt_llm._torch.auto_deploy.llm_args import \ + LlmArgs as ADLlmArgs + from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ + create_autodeploy_executor + create_executor = create_autodeploy_executor + assert isinstance(self.llm_args, ADLlmArgs) + args["ad_config"] = self.llm_args.get_pytorch_backend_config() + else: + raise ValueError( + f"Unsupported backend config: {self.llm_args.backend}") + + # Define additional attributes that can be used later, such as in _deduce_max_tokens + self.mapping = self.llm_args.parallel_config.to_mapping() + self.checkpoint_loader = None + if self.llm_args.backend == "pytorch": + from tensorrt_llm._torch.pyexecutor.config import \ + _construct_checkpoint_loader + self.checkpoint_loader = _construct_checkpoint_loader( + self.llm_args.backend, self.llm_args.checkpoint_loader, + self.llm_args.checkpoint_format) + + _executor = create_executor(**args) + self.max_seq_len = self.llm_args.max_seq_len + if _executor.max_seq_len is not None: + # max_seq_len might be updated by model engine as in create_py_executor + self.max_seq_len = _executor.max_seq_len + return _executor + + def _create_engine(executor_config): + engine = self._engine + if executor_config is None: + executor_config = tllm.ExecutorConfig(1) + executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( + processor_batched=self._batched_logits_processor, + replicate=False) + comm_ranks, device_ids = _get_comm_ranks_device_id() + executor_config.parallel_config = tllm.ParallelConfig( + participant_ids=comm_ranks, device_ids=device_ids) + + if isinstance(engine, Engine): + return tllm.Executor(engine.engine, + json.dumps(engine.config.to_dict(), + cls=ConfigEncoder), + tllm.ModelType.DECODER_ONLY, + executor_config=executor_config, + managed_weights=engine.managed_weights) + + assert not hasattr(executor_config, "backend") + return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, + executor_config) + + self.engine = _create_py_executor( + ) if self.llm_args is not None else _create_engine( + self._executor_config) + + self._lora_manager: Optional[LoraManager] = None + self._prompt_adapter_manager: Optional[PromptAdapterManager] = None + self._runtime_model_config: Optional[ModelConfig] = None + if self.rank == 0 and isinstance(self.engine, tllm.Executor): + if isinstance(self.engine, Engine): + engine_config = self.engine.config + else: + engine_config = EngineConfig.from_json_file( + f"{self._engine}/config.json") + self._runtime_model_config = _engine_config_to_model_config( + engine_config) + if engine_config.build_config.plugin_config.lora_plugin: + # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization + # (see LoraManager constructor docstring). Getting the peft cache manager from this + # point in the TRT flow is currently not supported (it's at the CPP + # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA + # optimization is not available in TRT-python flow. + self._lora_manager = LoraManager(cpp_peft_cache_manager=None) + if engine_config.build_config.max_prompt_embedding_table_size > 0: + self._prompt_adapter_manager = PromptAdapterManager() + + if self.llm_args and getattr( + self.llm_args, "backend", + "") == "pytorch" and self._lora_config is not None: + from tensorrt_llm._torch.pyexecutor.resource_manager import \ + ResourceManagerType + peft_cache_manager = self.engine.resource_manager.resource_managers.get( + ResourceManagerType.PEFT_CACHE_MANAGER) + self._lora_manager = LoraManager( + cpp_peft_cache_manager=peft_cache_manager.impl) + lora_model_config = self.engine.model_engine.lora_model_config + assert lora_model_config is not None + self._lora_model_config = lora_model_config + + def await_responses(self, timeout: Optional[float] = None) -> list: + return self.engine.await_responses(timeout=datetime.timedelta( + seconds=timeout) if timeout is not None else None) + + def fetch_stats(self) -> list: + if isinstance(self.engine, tllm.Executor): + iter_stats = self.engine.get_latest_iteration_stats() + #TODO: Support req stats with TRT engine + # This would require ensuring iter and req stats have same size + return [(iter_stat, None) for iter_stat in iter_stats] + else: + return self.engine.get_latest_iteration_stats() + + def set_result_queue(self, queue): + """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process.""" + assert self.postproc_queues is None + self.result_queue = queue + + def set_postproc_queues(self, queues: List["IpcQueue"]): + """ Set the IPC queues for feeding post-processing processes. """ + assert self.result_queue is None + self.postproc_queues = queues + + def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue, + queue: Union[Queue, FusedIpcQueue, + IntraProcessQueue]): + assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized." + it_result_queue.is_initialized = True + it_result_queue.queue = queue + it_result_queue.aqueue = None + + def return_queue(self, client_id: int): + """ If a centralized result queue is registered (used for communication with the proxy) + send the message there. + Otherwise, push the result directly in the GenerationResult queue. + """ + if self.result_queue is not None: + return self.result_queue + return self._results[client_id].queue + + def abort_request(self, client_id: int) -> None: + # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id + if self.engine.can_enqueue_requests(): + request_id = self._client_id_to_request_id.get(client_id, None) + if request_id is None: + logger.warning( + f"Request of client_id {client_id} is finished, cannot abort it." + ) + return + self.engine.cancel_request(request_id) + + def _engine_response_callback(self, response: tllm.Response): + return response + + def await_response_task(self) -> bool: + return self._await_response_helper() + + def _has_background_error(self) -> bool: + return not self._error_queue.empty() + + def _create_error_response(self, response: tllm.Response) -> ErrorResponse: + bck_error = self._error_queue.get_nowait() + assert isinstance(bck_error, Exception) + return ErrorResponse(response.client_id, str(bck_error), + response.request_id) + + def start(self): + raise NotImplementedError( + "start method is not implemented in WorkerBase") + + def _load_lora_adapter(self, lora_request: LoRARequest) -> bool: + """Returns True if the adapter was loaded by this call, False if it was already loaded""" + adapter_id = str(lora_request.adapter_id) + newly_loaded_uids = self._lora_manager.load_from_ckpt( + [lora_request.path], + model_config=self._runtime_model_config if + self._runtime_model_config is not None else self._lora_model_config, + runtime_mapping=None, + uids=[adapter_id], + ckpt_source=lora_request.ckpt_source) + return adapter_id in newly_loaded_uids + + def _load_prompt_adapter(self, + prompt_adapter_request: PromptAdapterRequest): + self._prompt_adapter_manager.load_from_ckpt( + [prompt_adapter_request.local_path], + model_config=self._runtime_model_config, + uids=[str(prompt_adapter_request.adapter_id)]) + + def _enqueue_request(self, request: GenerationRequest) -> int: + assert request.id is not None + py_lora_path = None + if self._lora_manager is not None and request.lora_request is not None: + adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache( + request.lora_request.adapter_id) + self._load_lora_adapter(request.lora_request) + uid = str(request.lora_request.adapter_id) + lora_config = tllm.LoraConfig( + task_id=request.lora_request.adapter_id, + weights=self._lora_manager.cpp_lora_weights[uid] + if not adapter_in_cache else None, + config=self._lora_manager.cpp_lora_config[uid]) + py_lora_path = request.lora_request.lora_path + else: + lora_config = None + + prompt_token_ids = copy.deepcopy(request.prompt_token_ids) + prompt_tuning_config = None + if request.prompt_adapter_request is not None: + self._load_prompt_adapter(request.prompt_adapter_request) + uid = str(request.prompt_adapter_request.adapter_id) + prompt_tuning_config = tllm.PromptTuningConfig( + self._prompt_adapter_manager.uid_to_weights[uid]) + vocab_size = self._runtime_model_config.vocab_size + pa_length = prompt_tuning_config.embedding_table.size(0) + prompt_token_ids = list(range( + vocab_size, vocab_size + pa_length)) + prompt_token_ids + + # MULTIMODAL + # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field + # except `multimodal_input` as it needs to go through the C++ runtime. + multimodal_input = None + if request.multimodal_params is not None and request.multimodal_params.has_content( + ): + if request.multimodal_params.multimodal_input is not None: + multimodal_input = tllm.MultimodalInput( + multimodal_hashes=request.multimodal_params. + multimodal_input.multimodal_hashes, + multimodal_positions=request.multimodal_params. + multimodal_input.multimodal_positions, + multimodal_lengths=request.multimodal_params. + multimodal_input.multimodal_lengths) + # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field + request.multimodal_params.multimodal_input = None + + context_phase_params = None + request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION + if request.disaggregated_params is not None: + assert ( + not self._is_pytorch_backend + or self.engine.kv_cache_transceiver is not None + ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:` in config file for disaggregated serving" + request_type = request.disaggregated_params.get_request_type() + if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY: + context_phase_params = request.disaggregated_params.get_context_phase_params( + ) + + if self._is_pytorch_backend: + if not self.llm_args.disable_overlap_scheduler: + is_disaggregated = self.engine.kv_cache_transceiver is not None + if is_disaggregated and ( + request_type + == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): + raise ValueError( + "Context only requests are not supported in pytorch backend when overlap is enabled." + ) + + assert request.id is not None + + def _deduce_max_tokens(request: GenerationRequest, + executor_config: tllm.ExecutorConfig, + llm_args: Optional[BaseLlmArgs] = None) -> int: + # deduce max_tokens when it's not set by user + max_tokens = request.sampling_params.max_tokens + query_token_len = len( + request.query_token_ids) if request.query_token_ids else 0 + + cp_size = 1 + max_seq_len = None + if llm_args is not None: + # deduce max_tokens by llm args + assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined." + if hasattr(self, + "mapping") and self.mapping.cp_size is not None: + cp_size = self.mapping.cp_size + max_seq_len = getattr(self, "max_seq_len", None) + else: + # deduce max_tokens by executor config + if hasattr(executor_config, "mapping" + ) and executor_config.mapping.cp_size is not None: + cp_size = executor_config.mapping.cp_size + max_seq_len = getattr(executor_config, "max_seq_len", None) + if max_seq_len is None: + logger.warning("`default_max_tokens` cannot be deduced") + if max_tokens is None: + raise ValueError( + "`max_tokens` must be set when `default_max_tokens` cannot be deduced" + ) + else: + # use max_tokens if can't deduce default_max_tokens + return max_tokens + if executor_config is not None: + assert ( + len(prompt_token_ids) <= executor_config.max_seq_len + ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})" + splited_prompt_len = int(len(prompt_token_ids) / cp_size) + default_max_tokens = max_seq_len - splited_prompt_len - query_token_len + if default_max_tokens <= 0: + logger.warning( + f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, " + f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})" + f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})" + ) + if max_tokens is None: + raise ValueError( + "`max_tokens` must be set when `default_max_tokens` is illegal" + ) + # default_max_tokens is the biggest available value + if max_tokens is None: + return default_max_tokens + elif max_tokens > default_max_tokens: + logger.warning( + f"User-specified `max_tokens` ({max_tokens}) is greater than deduced " + f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead." + ) + return default_max_tokens + return max_tokens + + try: + executor_request = tllm.Request( + client_id=request.id, + input_token_ids=prompt_token_ids, + max_tokens=_deduce_max_tokens(request, self._executor_config, + self.llm_args), + streaming=request.streaming, + sampling_config=request.sampling_params._get_sampling_config(), + end_id=-1 if request.sampling_params.ignore_eos else + request.sampling_params.end_id, + pad_id=request.sampling_params.pad_id, + output_config=request.sampling_params._get_output_config( + is_pytorch_backend=self._is_pytorch_backend), + # Beam search enforces return_all_generated_tokens=True regardless of the passed value + return_all_generated_tokens=False, + # convert python config into pybind config + lookahead_config=PybindMirror.maybe_to_pybind( + request.sampling_params.lookahead_config), + guided_decoding_params=request.sampling_params. + _get_guided_decoding_params(), + bad_words=request.sampling_params._get_bad_words(), + stop_words=request.sampling_params._get_stop_words(), + embedding_bias=request.sampling_params.embedding_bias, + lora_config=lora_config, + prompt_tuning_config=prompt_tuning_config, + multimodal_input=multimodal_input, + # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`. + multimodal_embedding=None, + mrope_config=None, + logits_post_processor_name=( + tllm.Request.BATCHED_POST_PROCESSOR_NAME + if request.sampling_params.apply_batched_logits_processor + else None), + logits_post_processor=None if self._is_pytorch_backend else + request.sampling_params.logits_processor, + kv_cache_retention_config=request.kv_cache_retention_config, + context_phase_params=context_phase_params, + type=request_type, + cache_salt_id=request.cache_salt_id) + executor_request.py_lora_path = py_lora_path + + if self._is_pytorch_backend and request.multimodal_params is not None: + if request.multimodal_params.multimodal_data is not None: + # NOTE: Deserialize SharedTensor handle to actual tensor + request.multimodal_params.to_tensor("multimodal_data") + executor_request.py_multimodal_data = request.multimodal_params.multimodal_data + + if self._is_pytorch_backend and request.sampling_params.logits_processor: + # For PyTorch backend, we attach logits processors as a dynamic Python attribute + # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues. + lp = request.sampling_params.logits_processor + executor_request.py_logits_post_processors = lp if isinstance( + lp, list) else [lp] + + executor_request.py_scheduling_params = None + if self._is_pytorch_backend and request.scheduling_params is not None: + executor_request.py_scheduling_params = request.scheduling_params + + if request.arrival_time is not None: + executor_request.py_arrival_time = request.arrival_time + + if request.query_token_ids is not None: + # pytorch star attention workflow + # a workaround to avoid public interface update + req_id = self.engine.enqueue_request(executor_request, + request.query_token_ids) + else: + req_id = self.engine.enqueue_request(executor_request) + return req_id + except Exception as e: + raise RequestError(str(e)) from e + + def submit(self, request: GenerationRequest) -> GenerationResult: + """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """ + self.start() + + if self.rank != 0: + raise RuntimeError( + "Only rank 0 can submit requests.\n" + "To fix this, ensure that the llm.generate(...) method is " + "guarded with the `if __name__ == '__main__':` block.") + + client_id = request.id if request.id is not None else self._get_next_client_id( + ) + if request.id is None: + request.set_id(client_id) + + logprob_params = self._get_logprob_params(request) + + result = GenerationResult( + request, + background_error_handler=self._handle_background_error, + executor=self, + disaggregated_params=request.disaggregated_params, + logprob_params=logprob_params) + + self._results[client_id] = result + + request_id = self._enqueue_request(request) + # request_id returned from backend is necessary for the abort_request method. + self._client_id_to_request_id[client_id] = request_id + + self._handle_background_error() + + return result + + def _pop_result(self, client_id: int): + self._results.pop(client_id, None) + self._client_id_to_request_id.pop(client_id, None) + + def block_subordinates(self): + if self.rank != 0: + if isinstance(self.engine, tllm.Executor): + self.shutdown() + raise self.WorkerExit( + "block_subordinates() should be used in a `with GenerationExecutorWorker() as ...:` block" + ) + from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor + if isinstance(self.engine, PyExecutor): + self.engine.wait_shutdown() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool: + self.shutdown() + return exc_type is None or exc_type == WorkerBase.WorkerExit + + def __del__(self): + self.shutdown() + + +class AwaitResponseHelper: + ''' Multiple-implementations for await_response for performance. ''' + + class HandlerKind(enum.Enum): + unknown = 0 + single_process_worker = 1 + ipc_batched = 2 + + def __init__(self, worker: "WorkerBase"): + # TODO: make worker weakref + self.worker = worker + self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown + self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel + # The error responses when submit request failed will be put here + self.temp_error_responses = Queue() + + def responses_handler(self, responses: List[tllm.Response]): + HandlerKind = AwaitResponseHelper.HandlerKind + + if self.handler_kind is HandlerKind.unknown: + if not (self.worker.result_queue is not None + or self.worker.postproc_queues is not None): + print_colored_debug( + f"creating await_response helper for Worker\n", + color="yellow") + # When ExecutorBindingWorker is used in the main process + # aka the single process mode + self.handler_kind = HandlerKind.single_process_worker + elif self.worker.result_queue is not None or self.worker.postproc_queues is not None: + # The ExecutorBindingProxy is used + print_colored_debug(f"creating await_response helper for IPC\n", + color="yellow") + self.handler_kind = HandlerKind.ipc_batched + else: + raise NotImplementedError + + match self.handler_kind: + case HandlerKind.single_process_worker: + return self.handle_for_worker(responses) + case HandlerKind.ipc_batched: + return self.handle_for_ipc_batched(responses) + case _: + raise NotImplementedError + + def __call__(self, timeout: Optional[float] = None) -> bool: + ''' This method should be called by a ManagedThread. ''' + timeout = timeout or 0.1 + responses = self.worker.engine.await_responses( + timeout=datetime.timedelta(seconds=timeout)) + # filter since The _engine_response_callback may return None + responses = list( + filter( + lambda _: _, + [self.worker._engine_response_callback(r) for r in responses])) + + # append the error responses to the temp_error_responses + while not self.temp_error_responses.empty(): + responses.append(self.temp_error_responses.get()) + + with nvtx_range_debug(f"await_response-{len(responses)}", + color="red", + category="Worker"): + self.responses_handler(responses) + return True + + def handle_for_worker(self, responses: List[tllm.Response]) -> None: + ''' Return the responses to asyncio.event_loop. ''' + event_loop = None + async_queues = [] + for response in responses: + assert response is not None + queue = self.worker.return_queue(response.client_id) + + response = _maybe_wrap_response(self.worker, response, + self.worker._is_pytorch_backend) + + # For AsyncQueue.sync_q, we will batch the events to avoid too many + # event notifications, thus put without wait here. + if isinstance(queue, _SyncQueue): + global_tracer().log_instant("worker-rsp.put") + queue.put_nowait(response) + async_queues.append(queue) + # all the loops are identical + event_loop = event_loop or queue.loop + else: + queue.put(response) + + if response.has_error() or response.result.is_final: + self.worker._pop_result(response.client_id) + + # Notify the events in bulk for performance. + if async_queues: + _SyncQueue.notify_many(event_loop, async_queues) + + def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None: + ''' Perform the IPC in batch explicitly. ''' + postproc_batches = [ + [] + for _ in range(self.worker.postproc_config.num_postprocess_workers) + ] if self.enable_postprocprocess_parallel else None + rsp_batch = [] if not self.enable_postprocprocess_parallel else None + + for response in responses: + + if isinstance(response, ErrorResponse): + pass # send ErrorResponse directly + elif self.worker._has_background_error(): + response = self.worker._create_error_response(response) + elif response.has_error(): + # Convert to ErrorResponse, because tllm.Response cannot be + # serialized when it has error. + response = ErrorResponse(response.client_id, response.error_msg, + response.request_id) + else: + response = _maybe_wrap_response(self.worker, response, + self.worker._is_pytorch_backend) + + _send_rsp(self.worker, + response, + postproc_batches=postproc_batches, + rsp_batch=rsp_batch) + + if postproc_batches: + for wid, batch in enumerate(postproc_batches): + if batch: + self.worker.postproc_queues[wid].put(batch) + + if rsp_batch: + self.worker.result_queue.put(rsp_batch) + + +def _get_params_for_first_rsp( + worker, + client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]: + res = worker._results.get(client_id, None) + assert res is not None + if not res._params_transmitted: + res._params_transmitted = True + return res.sampling_params, res.postproc_params + return None, None + + +def _get_logprobs(worker, + response: tllm.Response, + is_pytorch_backend=False) -> Optional[LogProbsResult]: + """Compute logprob and prompt logprob and clear out logits if applicable. + """ + if is_pytorch_backend: + # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime. + # In the PyTorch backend, logprobs are already computed during runtime if requested. + return None + + logprobs_result = None + generation_result = worker._results.get(response.client_id, None) + + if not generation_result: + return + + logprob_params = getattr(generation_result, "_logprob_params", None) + if logprob_params: + logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, + logprob_params.logprobs, + response.result.context_logits, + response.result.generation_logits, + response.result.output_token_ids[0]) + + if logprob_params.drop_context_logits: + response.clear_context_logits() + + if logprob_params.drop_generation_logits: + response.clear_generation_logits() + + if response.result.is_final: + generation_result.clear_logprob_params() + + return logprobs_result + + +def _send_rsp( + worker, + response: Union[tllm.Response, ResponseWrapper, ErrorResponse], + postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None, + rsp_batch: Optional[List[tllm.Response]] = None): + # if postproc_batches is set, append to batch instead of putting to IpcQueue + + if worker.result_queue is not None: + if rsp_batch is not None: + rsp_batch.append(response) + else: + worker.result_queue.put(response) + else: + sampling_params, postproc_params = _get_params_for_first_rsp( + worker, response.client_id) + inp = PostprocWorker.Input( + response, + # sampling_params is necessary for creating fake GenerationResult + # instances in the postproc processes. They are for incremental + # detokenize. They should be transmitted only once for each + # Request. + sampling_params=sampling_params, + postproc_params=postproc_params, + streaming=worker._results.get(response.client_id, None)._streaming) + + pid = response.client_id % worker.postproc_config.num_postprocess_workers + + if not postproc_batches: + # Group the responses into buckets for the postprocessing steps. + # Bucketing is used instead of random dispatching because the + # incremental detokenization during postprocessing relies on the + # prior CompletionOutput of a given request. + worker.postproc_queues[pid].put(inp) + else: + postproc_batches[pid].append(inp) + + # Eliminate the finished GenerationRequest instances timely, which may + # take considerable memory. + if is_llm_response(response): + if response.has_error() or response.result.is_final: + worker._pop_result(response.client_id) + elif isinstance(response, ErrorResponse): + worker._pop_result(response.client_id) + else: + raise ValueError(f"Unknown response type: {response}") + + +def _get_metrics_dict( + response: tllm.Response) -> dict[RequestEventTiming, float]: + req_perf_metrics, metrics_dict = None, {} + res = response.result + if res: + if hasattr(res, '_result'): + if result := res.get_result(): + req_perf_metrics = result.request_perf_metrics + else: + req_perf_metrics = res.request_perf_metrics + if req_perf_metrics and req_perf_metrics.timing_metrics: + metrics_dict = { + RequestEventTiming.ARRIVAL_TIME: + req_perf_metrics.timing_metrics.arrival_time.total_seconds(), + RequestEventTiming.FIRST_TOKEN_TIME: + req_perf_metrics.timing_metrics.first_token_time.total_seconds( + ), + RequestEventTiming.FIRST_SCHEDULED_TIME: + req_perf_metrics.timing_metrics.first_scheduled_time. + total_seconds(), + RequestEventTiming.LAST_TOKEN_TIME: + req_perf_metrics.timing_metrics.last_token_time.total_seconds() + } + return metrics_dict + + +def _maybe_wrap_response( + worker, + response: tllm.Response, + is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]: + + logprobs_result = _get_logprobs(worker, response, is_pytorch_backend) + req_perf_metrics = _get_metrics_dict(response) + if logprobs_result or req_perf_metrics: + response = ResponseWrapper(response, logprobs_result, req_perf_metrics) + return response diff --git a/tests/unittest/executor/test_worker_base.py b/tests/unittest/executor/test_worker_base.py new file mode 100644 index 00000000000..698c43cb1ad --- /dev/null +++ b/tests/unittest/executor/test_worker_base.py @@ -0,0 +1,165 @@ +import os +import sys +import time + +import pytest +import torch + +from tensorrt_llm._utils import mpi_comm, mpi_rank, mpi_world_size +from tensorrt_llm.bindings import executor as tllm +from tensorrt_llm.llmapi.mpi_session import MpiPoolSession + +# isort: off +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") +from utils.llm_data import llm_models_root +# isort: on + +from tensorrt_llm._torch.pyexecutor.config import update_executor_config +from tensorrt_llm.executor.request import GenerationRequest +from tensorrt_llm.executor.worker_base import WorkerBase +from tensorrt_llm.llmapi.llm_args import LlmArgs +from tensorrt_llm.sampling_params import SamplingParams + +default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" +model_path = llm_models_root() / default_model_name + + +class FakeWorker(WorkerBase): + + def __init__(self, engine: str, tp_size: int = 1): + llm_args, executor_config = create_fake_executor_config(engine, tp_size) + super().__init__( + engine=engine, + llm_args=llm_args, + hf_model_dir=engine, + ) + # Pass config in constructor and finalize with parameterless setup + self._executor_config = executor_config + self.llm_args = llm_args + self.setup_engine() + + def start(self): + pass + + def shutdown(self): + if self.engine is not None: + self.engine.shutdown() + self.engine = None + + +class TestWorkerBase: + + def test_create_engine(self): + with self.FakeWorker(engine=model_path) as worker: + print(f"Created engine: {worker.engine}") + + def test_submit_request(self): + sampling_params = SamplingParams(max_tokens=10) + request = GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=sampling_params) + with self.FakeWorker(engine=model_path) as worker: + print(f"Created engine: {worker.engine}") + worker.submit(request) + for i in range(10): + time.sleep(0.5) + worker.await_responses() + print(f"Submitted request: {request}") + time.sleep(6) + + def test_fetch_stats(self): + request = GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=10)) + with self.FakeWorker(engine=model_path) as worker: + worker.submit(request) + time.sleep(1) + worker.await_responses() + stats = worker.fetch_stats() + print(stats) + + @pytest.mark.parametrize("timeout", [0.1, 0.2, 1]) + def test_fetch_responses_timeout(self, timeout: float): + with self.FakeWorker(engine=model_path) as worker: + # Not submit any request, and let the await_responses timeout. + start_time = time.time() + results = worker.await_responses(timeout=timeout) + elapsed = time.time() - start_time + print(f"await_responses latency: {elapsed:.3f} seconds") + assert timeout / 2 <= elapsed <= timeout * 2, f"Latency out of expected range: {elapsed}" + + +def create_fake_executor_config(model_path, tp_size=1): + llm_args = LlmArgs(model=model_path, + cuda_graph_config=None, + tensor_parallel_size=tp_size) + + executor_config = tllm.ExecutorConfig(1) + executor_config.max_batch_size = 1 + executor_config.model_world_size = tp_size + + update_executor_config( + executor_config, + pytorch_backend_config=llm_args.get_pytorch_backend_config(), + mapping=llm_args.parallel_config.to_mapping(), + speculative_config=llm_args.speculative_config, + hf_model_dir=model_path, + max_input_len=20, + max_seq_len=40, + checkpoint_format=llm_args.checkpoint_format, + checkpoint_loader=llm_args.checkpoint_loader, + ) + + return llm_args, executor_config + + +class TestRpcWorkerBaseTP2: + + def setup_method(self): + self.llm_args = LlmArgs(model=model_path, tensor_parallel_size=2) + self.session = self.create_worker_session() + + def create_worker_session(self): + session = MpiPoolSession(n_workers=2) + return session + + def test_create_executor(self): + futures = self.session.submit( + TestRpcWorkerBaseTP2.create_executor, + engine=model_path, + llm_args=self.llm_args, + ) + # Wait for completion + for future in futures: + future.result() + + self.session.shutdown() + + @staticmethod + def create_executor(engine, llm_args): + rank = mpi_rank() + world_size = mpi_world_size() + device_id = rank % torch.cuda.device_count() + torch.cuda.set_device(device_id) + + print(f"[Test] Rank {rank}/{world_size} using device {device_id}") + + # Synchronize all workers before creating executor + mpi_comm().barrier() + + print(f"[Test] Rank {rank} creating WorkerBase...") + executor = FakeWorker(engine=engine, tp_size=2) + + # For PyTorch backend, all ranks need to participate in setup + print(f"[Test] Rank {rank} calling setup_engine...") + + # Setup the engine which contains another barrier + executor.setup_engine() + + print(f"[Test] Rank {rank} setup_engine completed successfully") + + executor.shutdown() + + +if __name__ == "__main__": + test_worker_base = TestWorkerBase() + test_worker_base.test_fetch_stats() From 3cd210335e1142feb4693f972d25827014caab7b Mon Sep 17 00:00:00 2001 From: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Date: Fri, 19 Sep 2025 11:09:51 +0800 Subject: [PATCH 02/13] fix test Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Signed-off-by: chunweiy init Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> Signed-off-by: chunweiy Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/ipc.py | 4 + tensorrt_llm/executor/rpc.py | 431 ++++++++++ tensorrt_llm/executor/worker_base.py | 834 -------------------- tensorrt_llm/llmapi/utils.py | 4 +- tests/unittest/executor/test_rpc.py | 191 +++++ tests/unittest/executor/test_worker_base.py | 165 ---- 6 files changed, 628 insertions(+), 1001 deletions(-) create mode 100644 tensorrt_llm/executor/rpc.py delete mode 100644 tensorrt_llm/executor/worker_base.py create mode 100644 tests/unittest/executor/test_rpc.py delete mode 100644 tests/unittest/executor/test_worker_base.py diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py index 327dbf4f6f5..d4318eb379e 100644 --- a/tensorrt_llm/executor/ipc.py +++ b/tensorrt_llm/executor/ipc.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import hmac import os @@ -187,6 +188,9 @@ async def get_async(self) -> Any: self.setup_lazily() return await self._recv_data_async() + async def get_async_noblock(self, timeout: float = 0.5) -> Any: + return await asyncio.wait_for(self.get_async(), timeout) + def close(self): if self.socket: self.socket.close() diff --git a/tensorrt_llm/executor/rpc.py b/tensorrt_llm/executor/rpc.py new file mode 100644 index 00000000000..a863bc93942 --- /dev/null +++ b/tensorrt_llm/executor/rpc.py @@ -0,0 +1,431 @@ +import asyncio +import concurrent.futures +import queue +import threading +import traceback +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Any, NamedTuple, Optional + +from ..llmapi.utils import ManagedThread +from ..logger import logger +from .ipc import ZeroMqQueue + + +# --- Custom Exceptions --- +class RPCError(Exception): + """Custom exception for RPC-related errors raised on the client side.""" + + +class RPCTimeout(RPCError): + """Custom exception for when a client request times out.""" + + +class RPCRequest(NamedTuple): + request_id: str + method_name: str + args: tuple + kwargs: dict + need_response: bool = True + + +class RPCResponse(NamedTuple): + request_id: str + status: str + result: Any + + +class RPCServer: + """ + An RPC Server that listens for requests and executes them concurrently. + """ + + def __init__(self, + instance, + hmac_key=None, + num_workers: int = 1, + timeout: float = 0.5, + async_run_task: bool = False): + """ + Initializes the server with an instance. + + Args: + instance: The instance whose methods will be exposed via RPC. + hmac_key (bytes, optional): HMAC key for encryption. + num_workers (int): Number of worker threads. + timeout (int): Timeout for RPC calls. + async_run_task (bool): Whether to run the task asynchronously. + """ + self._instance = instance + self._hmac_key = hmac_key + self._num_workers = num_workers + self._address = None + self._timeout = timeout + self._client_socket = None + + # set the stop event to True, and all the workers will exit + self._stop_event = threading.Event() + + self._functions = {"shutdown": self.shutdown} + self._dispatcher_thread: Optional[ManagedThread] = None + if async_run_task: + self._executor = ThreadPoolExecutor(max_workers=num_workers, + thread_name_prefix="rpc_worker") + else: + self._executor = None + + self._queue = None + + # Automatically register the instance + self.register_instance(instance) + + logger.debug(f"RPC Server initialized with {num_workers} workers.") + + @property + def address(self) -> str: + assert self._client_socket is not None, "Client socket is not bound" + return self._client_socket.address[0] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + + def bind(self, address="tcp://*:5555"): + """ + Bind the server to the specified address. + + Args: + address (str): The ZMQ address to bind the client-facing socket. + """ + self._address = address + self._client_socket = ZeroMqQueue(address=(address, self._hmac_key), + is_server=True, + is_async=True, + use_hmac_encryption=False) + logger.info(f"RPC Server bound to {self._address}") + + def shutdown(self): + """Internal method to trigger server shutdown.""" + logger.debug( + "RPC Server shutdown signal received. Terminating server...") + + if self._dispatcher_thread and self._dispatcher_thread.is_alive(): + self._stop_event.set() + self._dispatcher_thread.join() + self._dispatcher_thread = None + + if self._executor: + self._executor.shutdown(wait=False) + + if self._client_socket: + self._client_socket.close() + + self._client_socket = None + self._queue = None + + def register_function(self, func, name=None): + """Exposes a single function to clients.""" + fname = name or func.__name__ + if fname in self._functions: + logger.warning( + f"Function '{fname}' is already registered. Overwriting.") + self._functions[fname] = func + logger.debug(f"Registered function: {fname}") + + def register_instance(self, instance): + """Exposes all public methods of a class instance.""" + logger.debug( + f"Registering instance of class: {instance.__class__.__name__}") + for name in dir(instance): + if not name.startswith('_'): + attr = getattr(instance, name) + if callable(attr): + self.register_function(attr, name) + + async def _dispatcher_routine(self, stop_event: threading.Event): + assert self._client_socket is not None, "Client socket is not bound" + assert self._queue is not None, "RPC queue is not initialized" + + while not stop_event.is_set(): + try: + req: RPCRequest = await self._client_socket.get_async_noblock( + timeout=0.5) + logger.debug(f"RPC dispatcher got request: {req}") + except asyncio.TimeoutError: + logger.debug("RPC dispatcher get request timeout") + continue + + await self._queue.put(req) # type: ignore + + async def _worker_routine(self, stop_event: threading.Event): + """The routine executed by each worker thread.""" + assert self._client_socket is not None, "Client socket is not bound" + assert self._queue is not None, "RPC queue is not initialized" + + while not stop_event.is_set(): + try: + req: RPCRequest = await asyncio.wait_for( + self._queue.get(), # type: ignore + timeout=self._timeout) + except asyncio.TimeoutError: + logger.debug("RPC worker get request timeout") + continue + + if req.method_name in self._functions: + try: + if self._executor is not None: + # Dispatch to worker thread and await result + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + self._executor, self._functions[req.method_name], + *req.args, **req.kwargs) + else: + result = self._functions[req.method_name](*req.args, + **req.kwargs) + response = RPCResponse(req.request_id, 'OK', result) + except Exception: + tb = traceback.format_exc() + response = RPCResponse(req.request_id, 'ERROR', tb) + else: + response = RPCResponse( + req.request_id, 'ERROR', + f"Method '{req.method_name}' not found.") + + # Some tasks don't need response, e.g. submit_request or shutdown + if req.need_response: + await self._client_socket.put_async(response) + + def start(self): + """Binds sockets, starts workers, and begins proxying messages.""" + if self._client_socket is None: + raise RuntimeError( + "Server must be bound to an address before starting. Call bind() first." + ) + + self._client_socket.setup_lazily() + logger.info(f"RPC Server started and listening on {self._address}") + + async def tasks(): + self._queue = asyncio.Queue() + await asyncio.gather( + self._dispatcher_routine(self._stop_event), *[ + self._worker_routine(self._stop_event) + for i in range(self._num_workers) + ]) + + def loop() -> bool: + asyncio.run(tasks()) + return True # ManagedThread + + error_queue = queue.Queue() + self._dispatcher_thread = ManagedThread(task=loop, + stop_event=self._stop_event, + name="rpc_dispatcher_thread", + error_queue=error_queue) + self._dispatcher_thread.start() + + logger.info("RPC Server has started.") + + +Server = RPCServer + + +class RPCClient: + """ + An RPC Client that connects to the RPCServer. + """ + + def __init__(self, + address: str, + hmac_key=None, + timeout: float = 10, + num_workers: int = 4): + ''' + Args: + address: The ZMQ address to connect to. + hmac_key: The HMAC key for encryption. + timeout: The timeout (seconds) for RPC calls. + ''' + self._address = address + self._timeout = timeout + self._client_socket = ZeroMqQueue(address=(address, hmac_key), + is_server=False, + is_async=True, + use_hmac_encryption=False) + self._pending_futures = {} + self._reader_task = None + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=num_workers, thread_name_prefix="rpc_client") + logger.info(f"RPC Client initialized. Connected to {self._address}") + + def __del__(self): + """Cleanup executor when client is destroyed.""" + self.close() + + def close(self): + """Gracefully close the client, cleaning up background tasks.""" + if self._reader_task: + self._reader_task.cancel() + self._reader_task = None + if self._executor: + self._executor.shutdown(wait=False) + + async def _response_reader(self): + """Task to read responses from the socket and set results on futures.""" + + while True: + try: + response: RPCResponse = await self._client_socket.get_async() + future = self._pending_futures.get(response.request_id) + if future and not future.done(): + if response.status == 'OK': + future.set_result(response.result) + elif response.status == 'ERROR': + # TODO: Maybe keep the original Error type? + future.set_exception( + RPCError( + f"Server-side exception:\n{response.result}")) + else: + future.set_exception( + RPCError( + f"Unknown response status: {response.status}")) + self._pending_futures.pop(response.request_id, None) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Exception in RPC response reader: {e}") + # Propagate exception to all pending futures + for future in self._pending_futures.values(): + if not future.done(): + future.set_exception(e) + break + + await asyncio.sleep(0) + + self._reader_task = None + + async def _start_reader_if_needed(self): + if self._reader_task is None or self._reader_task.done(): + loop = asyncio.get_running_loop() + self._reader_task = loop.create_task(self._response_reader()) + + async def _call_async(self, name, *args, **kwargs): + """Async version of RPC call.""" + await self._start_reader_if_needed() + need_response = kwargs.pop("need_response", True) + + request_id = uuid.uuid4().hex + logger.debug(f"RPC client sending request: {request_id}") + request = RPCRequest(request_id, name, args, kwargs, need_response) + logger.debug(f"RPC client sending request: {request}") + await self._client_socket.put_async(request) + + if not need_response: + return None + + loop = asyncio.get_running_loop() + future = loop.create_future() + self._pending_futures[request_id] = future + + try: + return await asyncio.wait_for(future, self._timeout) + except asyncio.TimeoutError: + raise RPCError(f"Request '{name}' timed out after {self._timeout}s") + finally: + self._pending_futures.pop(request_id, None) + + def _call_sync(self, name, *args, **kwargs): + """Synchronous version of RPC call.""" + return asyncio.run(self._call_async(name, *args, **kwargs)) + + def call_async(self, name: str, *args, **kwargs): + """ + Call a remote method asynchronously. + + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Coroutine that can be awaited + + Example: + result = await client.call_async('remote_method', arg1, arg2, key=value) + """ + return self._call_async(name, *args, **kwargs, need_response=True) + + def call_future(self, name: str, *args, + **kwargs) -> concurrent.futures.Future: + """ + Call a remote method and return a Future. + + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + A Future object that can be used to retrieve the result + + Example: + future = client.call_future('remote_method', arg1, arg2, key=value) + result = future.result() # blocks until complete + # or + future.add_done_callback(lambda f: print(f.result())) + """ + + def _async_to_sync(): + return asyncio.run(self._call_async(name, *args, **kwargs)) + + return self._executor.submit(_async_to_sync) + + def call_sync(self, name: str, *args, **kwargs): + """ + Call a remote method synchronously (blocking). + + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + The result of the remote method call + + Example: + result = client.call_sync('remote_method', arg1, arg2, key=value) + """ + return self._call_sync(name, *args, **kwargs) + + def __getattr__(self, name): + """ + Magically handles calls to non-existent methods. + Returns a proxy object that supports multiple calling patterns. + """ + + class MethodProxy: + + def __init__(self, client, method_name): + self.client = client + self.method_name = method_name + + def __call__(self, *args, **kwargs): + """Default synchronous call""" + return self.client._call_sync(self.method_name, *args, **kwargs) + + def call_async(self, *args, **kwargs): + """Async call - returns coroutine""" + return self.client._call_async(self.method_name, + *args, + need_response=True, + **kwargs) + + def call_future(self, *args, **kwargs) -> concurrent.futures.Future: + """Future call - returns Future object""" + return self.client.call_future(self.method_name, *args, + **kwargs) + + return MethodProxy(self, name) diff --git a/tensorrt_llm/executor/worker_base.py b/tensorrt_llm/executor/worker_base.py deleted file mode 100644 index 24a43e261d5..00000000000 --- a/tensorrt_llm/executor/worker_base.py +++ /dev/null @@ -1,834 +0,0 @@ -import copy -import datetime -import enum -import json -import weakref -from pathlib import Path -from queue import Queue -from typing import Dict, List, Optional, Tuple, Union - -import torch - -from tensorrt_llm.logger import logger - -from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, - nvtx_range_debug) -from ..bindings import executor as tllm -from ..builder import ConfigEncoder, Engine, EngineConfig -from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror -from ..llmapi.tokenizer import TokenizerBase -from ..llmapi.tracer import global_tracer -from ..llmapi.utils import _SyncQueue, print_colored_debug -from ..lora_helper import LoraConfig -from ..lora_manager import LoraManager -from ..metrics import RequestEventTiming -from ..prompt_adapter_manager import PromptAdapterManager -from ..runtime import ModelConfig -from ..runtime.model_runner import _engine_config_to_model_config -from ..sampling_params import BatchedLogitsProcessor, SamplingParams -from .executor import GenerationExecutor, IterationResultQueue -from .ipc import FusedIpcQueue, IpcQueue -from .postproc_worker import (PostprocParams, PostprocWorker, - PostprocWorkerConfig) -from .request import GenerationRequest, LoRARequest, PromptAdapterRequest -from .result import (GenerationResult, LogProbsResult, ResponseWrapper, - compute_logprobs) -from .utils import (ErrorResponse, IntraProcessQueue, RequestError, - is_llm_response) - -__all__ = [ - "WorkerBase", -] - - -class WorkerBase(GenerationExecutor): - - class WorkerExit(GeneratorExit): - pass - - def __init__( - self, - engine: Union[Path, Engine], - executor_config: Optional[tllm.ExecutorConfig] = None, - batched_logits_processor: Optional[BatchedLogitsProcessor] = None, - postproc_worker_config: Optional[PostprocWorkerConfig] = None, - is_llm_executor: Optional[bool] = None, - lora_config: Optional[LoraConfig] = None, - kv_connector_config: Optional[KvCacheConnectorConfig] = None, - hf_model_dir: Optional[Path] = None, - tokenizer: Optional[TokenizerBase] = None, - llm_args: Optional[BaseLlmArgs] = None, - ) -> None: - postproc_config = postproc_worker_config or PostprocWorkerConfig() - super().__init__( - num_postprocess_workers=postproc_config.num_postprocess_workers, - postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir, - is_llm_executor=is_llm_executor, - ) - - # inputs - self._engine = engine - self._executor_config = executor_config - self._batched_logits_processor = batched_logits_processor - self._postproc_worker_config = postproc_worker_config - self._is_llm_executor = is_llm_executor - self._lora_config = lora_config - self._kv_connector_config = kv_connector_config - self._hf_model_dir = hf_model_dir - self._tokenizer = tokenizer - self.llm_args = llm_args - - self.engine = None - self.result_queue: Optional[IpcQueue] = None - self.postproc_queues: Optional[List[IpcQueue]] = None - self.rank = mpi_rank() - self.global_rank = global_mpi_rank() - # mapping: client_id -> GenerationResult - self._results: Dict[int, GenerationResult] = {} - # mapping: client_id from Proxy -> request_id returned from runtime backend - self._client_id_to_request_id: Dict[int, int] = {} - self._await_response_helper = AwaitResponseHelper(weakref.proxy(self)) - self._is_pytorch_backend = llm_args is not None and llm_args.backend in [ - "pytorch", "_autodeploy" - ] - - if not self._is_pytorch_backend and kv_connector_config is not None: - raise ValueError( - "KV connector config is only supported for PyTorch backend") - - if global_mpi_size() > 1: - logger.set_rank(self.global_rank) - - def setup_engine(self): - """ - Setup the engine for the worker. - """ - - if isinstance(self._engine, list): - self._engine[self.rank] - - def _get_comm_ranks_device_id(): - device_id = self.global_rank % torch.cuda.device_count() - torch.cuda.set_device(device_id) - # Make sure C++ executor would use same devices/ranks as py_executor - global_rank = global_mpi_rank() - comm_ranks = mpi_comm().allgather(global_rank) - device_ids = mpi_comm().allgather(device_id) - return comm_ranks, device_ids - - def _create_py_executor(): - args = {} - assert hasattr( - self.llm_args, "backend" - ), "llm_args should be with backend in _create_py_executor" - _ = _get_comm_ranks_device_id() - if self.llm_args.backend == "pytorch": - from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ - create_py_executor - create_executor = create_py_executor - args["llm_args"] = self.llm_args - args["checkpoint_dir"] = self._hf_model_dir - args["tokenizer"] = self._tokenizer - args["lora_config"] = self._lora_config - args["kv_connector_config"] = self._kv_connector_config - elif self.llm_args.backend == "_autodeploy": - from tensorrt_llm._torch.auto_deploy.llm_args import \ - LlmArgs as ADLlmArgs - from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ - create_autodeploy_executor - create_executor = create_autodeploy_executor - assert isinstance(self.llm_args, ADLlmArgs) - args["ad_config"] = self.llm_args.get_pytorch_backend_config() - else: - raise ValueError( - f"Unsupported backend config: {self.llm_args.backend}") - - # Define additional attributes that can be used later, such as in _deduce_max_tokens - self.mapping = self.llm_args.parallel_config.to_mapping() - self.checkpoint_loader = None - if self.llm_args.backend == "pytorch": - from tensorrt_llm._torch.pyexecutor.config import \ - _construct_checkpoint_loader - self.checkpoint_loader = _construct_checkpoint_loader( - self.llm_args.backend, self.llm_args.checkpoint_loader, - self.llm_args.checkpoint_format) - - _executor = create_executor(**args) - self.max_seq_len = self.llm_args.max_seq_len - if _executor.max_seq_len is not None: - # max_seq_len might be updated by model engine as in create_py_executor - self.max_seq_len = _executor.max_seq_len - return _executor - - def _create_engine(executor_config): - engine = self._engine - if executor_config is None: - executor_config = tllm.ExecutorConfig(1) - executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( - processor_batched=self._batched_logits_processor, - replicate=False) - comm_ranks, device_ids = _get_comm_ranks_device_id() - executor_config.parallel_config = tllm.ParallelConfig( - participant_ids=comm_ranks, device_ids=device_ids) - - if isinstance(engine, Engine): - return tllm.Executor(engine.engine, - json.dumps(engine.config.to_dict(), - cls=ConfigEncoder), - tllm.ModelType.DECODER_ONLY, - executor_config=executor_config, - managed_weights=engine.managed_weights) - - assert not hasattr(executor_config, "backend") - return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, - executor_config) - - self.engine = _create_py_executor( - ) if self.llm_args is not None else _create_engine( - self._executor_config) - - self._lora_manager: Optional[LoraManager] = None - self._prompt_adapter_manager: Optional[PromptAdapterManager] = None - self._runtime_model_config: Optional[ModelConfig] = None - if self.rank == 0 and isinstance(self.engine, tllm.Executor): - if isinstance(self.engine, Engine): - engine_config = self.engine.config - else: - engine_config = EngineConfig.from_json_file( - f"{self._engine}/config.json") - self._runtime_model_config = _engine_config_to_model_config( - engine_config) - if engine_config.build_config.plugin_config.lora_plugin: - # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization - # (see LoraManager constructor docstring). Getting the peft cache manager from this - # point in the TRT flow is currently not supported (it's at the CPP - # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA - # optimization is not available in TRT-python flow. - self._lora_manager = LoraManager(cpp_peft_cache_manager=None) - if engine_config.build_config.max_prompt_embedding_table_size > 0: - self._prompt_adapter_manager = PromptAdapterManager() - - if self.llm_args and getattr( - self.llm_args, "backend", - "") == "pytorch" and self._lora_config is not None: - from tensorrt_llm._torch.pyexecutor.resource_manager import \ - ResourceManagerType - peft_cache_manager = self.engine.resource_manager.resource_managers.get( - ResourceManagerType.PEFT_CACHE_MANAGER) - self._lora_manager = LoraManager( - cpp_peft_cache_manager=peft_cache_manager.impl) - lora_model_config = self.engine.model_engine.lora_model_config - assert lora_model_config is not None - self._lora_model_config = lora_model_config - - def await_responses(self, timeout: Optional[float] = None) -> list: - return self.engine.await_responses(timeout=datetime.timedelta( - seconds=timeout) if timeout is not None else None) - - def fetch_stats(self) -> list: - if isinstance(self.engine, tllm.Executor): - iter_stats = self.engine.get_latest_iteration_stats() - #TODO: Support req stats with TRT engine - # This would require ensuring iter and req stats have same size - return [(iter_stat, None) for iter_stat in iter_stats] - else: - return self.engine.get_latest_iteration_stats() - - def set_result_queue(self, queue): - """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process.""" - assert self.postproc_queues is None - self.result_queue = queue - - def set_postproc_queues(self, queues: List["IpcQueue"]): - """ Set the IPC queues for feeding post-processing processes. """ - assert self.result_queue is None - self.postproc_queues = queues - - def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue, - queue: Union[Queue, FusedIpcQueue, - IntraProcessQueue]): - assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized." - it_result_queue.is_initialized = True - it_result_queue.queue = queue - it_result_queue.aqueue = None - - def return_queue(self, client_id: int): - """ If a centralized result queue is registered (used for communication with the proxy) - send the message there. - Otherwise, push the result directly in the GenerationResult queue. - """ - if self.result_queue is not None: - return self.result_queue - return self._results[client_id].queue - - def abort_request(self, client_id: int) -> None: - # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id - if self.engine.can_enqueue_requests(): - request_id = self._client_id_to_request_id.get(client_id, None) - if request_id is None: - logger.warning( - f"Request of client_id {client_id} is finished, cannot abort it." - ) - return - self.engine.cancel_request(request_id) - - def _engine_response_callback(self, response: tllm.Response): - return response - - def await_response_task(self) -> bool: - return self._await_response_helper() - - def _has_background_error(self) -> bool: - return not self._error_queue.empty() - - def _create_error_response(self, response: tllm.Response) -> ErrorResponse: - bck_error = self._error_queue.get_nowait() - assert isinstance(bck_error, Exception) - return ErrorResponse(response.client_id, str(bck_error), - response.request_id) - - def start(self): - raise NotImplementedError( - "start method is not implemented in WorkerBase") - - def _load_lora_adapter(self, lora_request: LoRARequest) -> bool: - """Returns True if the adapter was loaded by this call, False if it was already loaded""" - adapter_id = str(lora_request.adapter_id) - newly_loaded_uids = self._lora_manager.load_from_ckpt( - [lora_request.path], - model_config=self._runtime_model_config if - self._runtime_model_config is not None else self._lora_model_config, - runtime_mapping=None, - uids=[adapter_id], - ckpt_source=lora_request.ckpt_source) - return adapter_id in newly_loaded_uids - - def _load_prompt_adapter(self, - prompt_adapter_request: PromptAdapterRequest): - self._prompt_adapter_manager.load_from_ckpt( - [prompt_adapter_request.local_path], - model_config=self._runtime_model_config, - uids=[str(prompt_adapter_request.adapter_id)]) - - def _enqueue_request(self, request: GenerationRequest) -> int: - assert request.id is not None - py_lora_path = None - if self._lora_manager is not None and request.lora_request is not None: - adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache( - request.lora_request.adapter_id) - self._load_lora_adapter(request.lora_request) - uid = str(request.lora_request.adapter_id) - lora_config = tllm.LoraConfig( - task_id=request.lora_request.adapter_id, - weights=self._lora_manager.cpp_lora_weights[uid] - if not adapter_in_cache else None, - config=self._lora_manager.cpp_lora_config[uid]) - py_lora_path = request.lora_request.lora_path - else: - lora_config = None - - prompt_token_ids = copy.deepcopy(request.prompt_token_ids) - prompt_tuning_config = None - if request.prompt_adapter_request is not None: - self._load_prompt_adapter(request.prompt_adapter_request) - uid = str(request.prompt_adapter_request.adapter_id) - prompt_tuning_config = tllm.PromptTuningConfig( - self._prompt_adapter_manager.uid_to_weights[uid]) - vocab_size = self._runtime_model_config.vocab_size - pa_length = prompt_tuning_config.embedding_table.size(0) - prompt_token_ids = list(range( - vocab_size, vocab_size + pa_length)) + prompt_token_ids - - # MULTIMODAL - # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field - # except `multimodal_input` as it needs to go through the C++ runtime. - multimodal_input = None - if request.multimodal_params is not None and request.multimodal_params.has_content( - ): - if request.multimodal_params.multimodal_input is not None: - multimodal_input = tllm.MultimodalInput( - multimodal_hashes=request.multimodal_params. - multimodal_input.multimodal_hashes, - multimodal_positions=request.multimodal_params. - multimodal_input.multimodal_positions, - multimodal_lengths=request.multimodal_params. - multimodal_input.multimodal_lengths) - # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field - request.multimodal_params.multimodal_input = None - - context_phase_params = None - request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION - if request.disaggregated_params is not None: - assert ( - not self._is_pytorch_backend - or self.engine.kv_cache_transceiver is not None - ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:` in config file for disaggregated serving" - request_type = request.disaggregated_params.get_request_type() - if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY: - context_phase_params = request.disaggregated_params.get_context_phase_params( - ) - - if self._is_pytorch_backend: - if not self.llm_args.disable_overlap_scheduler: - is_disaggregated = self.engine.kv_cache_transceiver is not None - if is_disaggregated and ( - request_type - == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): - raise ValueError( - "Context only requests are not supported in pytorch backend when overlap is enabled." - ) - - assert request.id is not None - - def _deduce_max_tokens(request: GenerationRequest, - executor_config: tllm.ExecutorConfig, - llm_args: Optional[BaseLlmArgs] = None) -> int: - # deduce max_tokens when it's not set by user - max_tokens = request.sampling_params.max_tokens - query_token_len = len( - request.query_token_ids) if request.query_token_ids else 0 - - cp_size = 1 - max_seq_len = None - if llm_args is not None: - # deduce max_tokens by llm args - assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined." - if hasattr(self, - "mapping") and self.mapping.cp_size is not None: - cp_size = self.mapping.cp_size - max_seq_len = getattr(self, "max_seq_len", None) - else: - # deduce max_tokens by executor config - if hasattr(executor_config, "mapping" - ) and executor_config.mapping.cp_size is not None: - cp_size = executor_config.mapping.cp_size - max_seq_len = getattr(executor_config, "max_seq_len", None) - if max_seq_len is None: - logger.warning("`default_max_tokens` cannot be deduced") - if max_tokens is None: - raise ValueError( - "`max_tokens` must be set when `default_max_tokens` cannot be deduced" - ) - else: - # use max_tokens if can't deduce default_max_tokens - return max_tokens - if executor_config is not None: - assert ( - len(prompt_token_ids) <= executor_config.max_seq_len - ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})" - splited_prompt_len = int(len(prompt_token_ids) / cp_size) - default_max_tokens = max_seq_len - splited_prompt_len - query_token_len - if default_max_tokens <= 0: - logger.warning( - f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, " - f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})" - f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})" - ) - if max_tokens is None: - raise ValueError( - "`max_tokens` must be set when `default_max_tokens` is illegal" - ) - # default_max_tokens is the biggest available value - if max_tokens is None: - return default_max_tokens - elif max_tokens > default_max_tokens: - logger.warning( - f"User-specified `max_tokens` ({max_tokens}) is greater than deduced " - f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead." - ) - return default_max_tokens - return max_tokens - - try: - executor_request = tllm.Request( - client_id=request.id, - input_token_ids=prompt_token_ids, - max_tokens=_deduce_max_tokens(request, self._executor_config, - self.llm_args), - streaming=request.streaming, - sampling_config=request.sampling_params._get_sampling_config(), - end_id=-1 if request.sampling_params.ignore_eos else - request.sampling_params.end_id, - pad_id=request.sampling_params.pad_id, - output_config=request.sampling_params._get_output_config( - is_pytorch_backend=self._is_pytorch_backend), - # Beam search enforces return_all_generated_tokens=True regardless of the passed value - return_all_generated_tokens=False, - # convert python config into pybind config - lookahead_config=PybindMirror.maybe_to_pybind( - request.sampling_params.lookahead_config), - guided_decoding_params=request.sampling_params. - _get_guided_decoding_params(), - bad_words=request.sampling_params._get_bad_words(), - stop_words=request.sampling_params._get_stop_words(), - embedding_bias=request.sampling_params.embedding_bias, - lora_config=lora_config, - prompt_tuning_config=prompt_tuning_config, - multimodal_input=multimodal_input, - # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`. - multimodal_embedding=None, - mrope_config=None, - logits_post_processor_name=( - tllm.Request.BATCHED_POST_PROCESSOR_NAME - if request.sampling_params.apply_batched_logits_processor - else None), - logits_post_processor=None if self._is_pytorch_backend else - request.sampling_params.logits_processor, - kv_cache_retention_config=request.kv_cache_retention_config, - context_phase_params=context_phase_params, - type=request_type, - cache_salt_id=request.cache_salt_id) - executor_request.py_lora_path = py_lora_path - - if self._is_pytorch_backend and request.multimodal_params is not None: - if request.multimodal_params.multimodal_data is not None: - # NOTE: Deserialize SharedTensor handle to actual tensor - request.multimodal_params.to_tensor("multimodal_data") - executor_request.py_multimodal_data = request.multimodal_params.multimodal_data - - if self._is_pytorch_backend and request.sampling_params.logits_processor: - # For PyTorch backend, we attach logits processors as a dynamic Python attribute - # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues. - lp = request.sampling_params.logits_processor - executor_request.py_logits_post_processors = lp if isinstance( - lp, list) else [lp] - - executor_request.py_scheduling_params = None - if self._is_pytorch_backend and request.scheduling_params is not None: - executor_request.py_scheduling_params = request.scheduling_params - - if request.arrival_time is not None: - executor_request.py_arrival_time = request.arrival_time - - if request.query_token_ids is not None: - # pytorch star attention workflow - # a workaround to avoid public interface update - req_id = self.engine.enqueue_request(executor_request, - request.query_token_ids) - else: - req_id = self.engine.enqueue_request(executor_request) - return req_id - except Exception as e: - raise RequestError(str(e)) from e - - def submit(self, request: GenerationRequest) -> GenerationResult: - """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """ - self.start() - - if self.rank != 0: - raise RuntimeError( - "Only rank 0 can submit requests.\n" - "To fix this, ensure that the llm.generate(...) method is " - "guarded with the `if __name__ == '__main__':` block.") - - client_id = request.id if request.id is not None else self._get_next_client_id( - ) - if request.id is None: - request.set_id(client_id) - - logprob_params = self._get_logprob_params(request) - - result = GenerationResult( - request, - background_error_handler=self._handle_background_error, - executor=self, - disaggregated_params=request.disaggregated_params, - logprob_params=logprob_params) - - self._results[client_id] = result - - request_id = self._enqueue_request(request) - # request_id returned from backend is necessary for the abort_request method. - self._client_id_to_request_id[client_id] = request_id - - self._handle_background_error() - - return result - - def _pop_result(self, client_id: int): - self._results.pop(client_id, None) - self._client_id_to_request_id.pop(client_id, None) - - def block_subordinates(self): - if self.rank != 0: - if isinstance(self.engine, tllm.Executor): - self.shutdown() - raise self.WorkerExit( - "block_subordinates() should be used in a `with GenerationExecutorWorker() as ...:` block" - ) - from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor - if isinstance(self.engine, PyExecutor): - self.engine.wait_shutdown() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback) -> bool: - self.shutdown() - return exc_type is None or exc_type == WorkerBase.WorkerExit - - def __del__(self): - self.shutdown() - - -class AwaitResponseHelper: - ''' Multiple-implementations for await_response for performance. ''' - - class HandlerKind(enum.Enum): - unknown = 0 - single_process_worker = 1 - ipc_batched = 2 - - def __init__(self, worker: "WorkerBase"): - # TODO: make worker weakref - self.worker = worker - self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown - self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel - # The error responses when submit request failed will be put here - self.temp_error_responses = Queue() - - def responses_handler(self, responses: List[tllm.Response]): - HandlerKind = AwaitResponseHelper.HandlerKind - - if self.handler_kind is HandlerKind.unknown: - if not (self.worker.result_queue is not None - or self.worker.postproc_queues is not None): - print_colored_debug( - f"creating await_response helper for Worker\n", - color="yellow") - # When ExecutorBindingWorker is used in the main process - # aka the single process mode - self.handler_kind = HandlerKind.single_process_worker - elif self.worker.result_queue is not None or self.worker.postproc_queues is not None: - # The ExecutorBindingProxy is used - print_colored_debug(f"creating await_response helper for IPC\n", - color="yellow") - self.handler_kind = HandlerKind.ipc_batched - else: - raise NotImplementedError - - match self.handler_kind: - case HandlerKind.single_process_worker: - return self.handle_for_worker(responses) - case HandlerKind.ipc_batched: - return self.handle_for_ipc_batched(responses) - case _: - raise NotImplementedError - - def __call__(self, timeout: Optional[float] = None) -> bool: - ''' This method should be called by a ManagedThread. ''' - timeout = timeout or 0.1 - responses = self.worker.engine.await_responses( - timeout=datetime.timedelta(seconds=timeout)) - # filter since The _engine_response_callback may return None - responses = list( - filter( - lambda _: _, - [self.worker._engine_response_callback(r) for r in responses])) - - # append the error responses to the temp_error_responses - while not self.temp_error_responses.empty(): - responses.append(self.temp_error_responses.get()) - - with nvtx_range_debug(f"await_response-{len(responses)}", - color="red", - category="Worker"): - self.responses_handler(responses) - return True - - def handle_for_worker(self, responses: List[tllm.Response]) -> None: - ''' Return the responses to asyncio.event_loop. ''' - event_loop = None - async_queues = [] - for response in responses: - assert response is not None - queue = self.worker.return_queue(response.client_id) - - response = _maybe_wrap_response(self.worker, response, - self.worker._is_pytorch_backend) - - # For AsyncQueue.sync_q, we will batch the events to avoid too many - # event notifications, thus put without wait here. - if isinstance(queue, _SyncQueue): - global_tracer().log_instant("worker-rsp.put") - queue.put_nowait(response) - async_queues.append(queue) - # all the loops are identical - event_loop = event_loop or queue.loop - else: - queue.put(response) - - if response.has_error() or response.result.is_final: - self.worker._pop_result(response.client_id) - - # Notify the events in bulk for performance. - if async_queues: - _SyncQueue.notify_many(event_loop, async_queues) - - def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None: - ''' Perform the IPC in batch explicitly. ''' - postproc_batches = [ - [] - for _ in range(self.worker.postproc_config.num_postprocess_workers) - ] if self.enable_postprocprocess_parallel else None - rsp_batch = [] if not self.enable_postprocprocess_parallel else None - - for response in responses: - - if isinstance(response, ErrorResponse): - pass # send ErrorResponse directly - elif self.worker._has_background_error(): - response = self.worker._create_error_response(response) - elif response.has_error(): - # Convert to ErrorResponse, because tllm.Response cannot be - # serialized when it has error. - response = ErrorResponse(response.client_id, response.error_msg, - response.request_id) - else: - response = _maybe_wrap_response(self.worker, response, - self.worker._is_pytorch_backend) - - _send_rsp(self.worker, - response, - postproc_batches=postproc_batches, - rsp_batch=rsp_batch) - - if postproc_batches: - for wid, batch in enumerate(postproc_batches): - if batch: - self.worker.postproc_queues[wid].put(batch) - - if rsp_batch: - self.worker.result_queue.put(rsp_batch) - - -def _get_params_for_first_rsp( - worker, - client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]: - res = worker._results.get(client_id, None) - assert res is not None - if not res._params_transmitted: - res._params_transmitted = True - return res.sampling_params, res.postproc_params - return None, None - - -def _get_logprobs(worker, - response: tllm.Response, - is_pytorch_backend=False) -> Optional[LogProbsResult]: - """Compute logprob and prompt logprob and clear out logits if applicable. - """ - if is_pytorch_backend: - # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime. - # In the PyTorch backend, logprobs are already computed during runtime if requested. - return None - - logprobs_result = None - generation_result = worker._results.get(response.client_id, None) - - if not generation_result: - return - - logprob_params = getattr(generation_result, "_logprob_params", None) - if logprob_params: - logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, - logprob_params.logprobs, - response.result.context_logits, - response.result.generation_logits, - response.result.output_token_ids[0]) - - if logprob_params.drop_context_logits: - response.clear_context_logits() - - if logprob_params.drop_generation_logits: - response.clear_generation_logits() - - if response.result.is_final: - generation_result.clear_logprob_params() - - return logprobs_result - - -def _send_rsp( - worker, - response: Union[tllm.Response, ResponseWrapper, ErrorResponse], - postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None, - rsp_batch: Optional[List[tllm.Response]] = None): - # if postproc_batches is set, append to batch instead of putting to IpcQueue - - if worker.result_queue is not None: - if rsp_batch is not None: - rsp_batch.append(response) - else: - worker.result_queue.put(response) - else: - sampling_params, postproc_params = _get_params_for_first_rsp( - worker, response.client_id) - inp = PostprocWorker.Input( - response, - # sampling_params is necessary for creating fake GenerationResult - # instances in the postproc processes. They are for incremental - # detokenize. They should be transmitted only once for each - # Request. - sampling_params=sampling_params, - postproc_params=postproc_params, - streaming=worker._results.get(response.client_id, None)._streaming) - - pid = response.client_id % worker.postproc_config.num_postprocess_workers - - if not postproc_batches: - # Group the responses into buckets for the postprocessing steps. - # Bucketing is used instead of random dispatching because the - # incremental detokenization during postprocessing relies on the - # prior CompletionOutput of a given request. - worker.postproc_queues[pid].put(inp) - else: - postproc_batches[pid].append(inp) - - # Eliminate the finished GenerationRequest instances timely, which may - # take considerable memory. - if is_llm_response(response): - if response.has_error() or response.result.is_final: - worker._pop_result(response.client_id) - elif isinstance(response, ErrorResponse): - worker._pop_result(response.client_id) - else: - raise ValueError(f"Unknown response type: {response}") - - -def _get_metrics_dict( - response: tllm.Response) -> dict[RequestEventTiming, float]: - req_perf_metrics, metrics_dict = None, {} - res = response.result - if res: - if hasattr(res, '_result'): - if result := res.get_result(): - req_perf_metrics = result.request_perf_metrics - else: - req_perf_metrics = res.request_perf_metrics - if req_perf_metrics and req_perf_metrics.timing_metrics: - metrics_dict = { - RequestEventTiming.ARRIVAL_TIME: - req_perf_metrics.timing_metrics.arrival_time.total_seconds(), - RequestEventTiming.FIRST_TOKEN_TIME: - req_perf_metrics.timing_metrics.first_token_time.total_seconds( - ), - RequestEventTiming.FIRST_SCHEDULED_TIME: - req_perf_metrics.timing_metrics.first_scheduled_time. - total_seconds(), - RequestEventTiming.LAST_TOKEN_TIME: - req_perf_metrics.timing_metrics.last_token_time.total_seconds() - } - return metrics_dict - - -def _maybe_wrap_response( - worker, - response: tllm.Response, - is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]: - - logprobs_result = _get_logprobs(worker, response, is_pytorch_backend) - req_perf_metrics = _get_metrics_dict(response) - if logprobs_result or req_perf_metrics: - response = ResponseWrapper(response, logprobs_result, req_perf_metrics) - return response diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py index 5e9f2e69aa4..a08ad9fd03f 100644 --- a/tensorrt_llm/llmapi/utils.py +++ b/tensorrt_llm/llmapi/utils.py @@ -244,14 +244,14 @@ def __init__(self, task: Callable[..., bool], error_queue: Queue, name: Optional[str] = None, + stop_event: Optional[threading.Event] = None, **kwargs): super().__init__(name=name) self.task = task self.error_queue = error_queue self.kwargs = kwargs self.daemon = True - - self.stop_event = threading.Event() + self.stop_event = stop_event or threading.Event() def run(self): diff --git a/tests/unittest/executor/test_rpc.py b/tests/unittest/executor/test_rpc.py new file mode 100644 index 00000000000..69546cd5d0d --- /dev/null +++ b/tests/unittest/executor/test_rpc.py @@ -0,0 +1,191 @@ +import time + +import pytest + +from tensorrt_llm.executor.rpc import RPCClient, RPCError, RPCServer + + +def test_rpc_server_basics(): + + class App: + + def hello(self): + print("hello") + + server = RPCServer(App()) + print("bind") + server.bind("ipc:///tmp/rpc_test") + print("start") + server.start() + print("sleep") + + time.sleep(1) + print("shutdown") + server.shutdown() + + +def test_rpc_client_context_manager(): + + class App: + + def hello(self): + print("hello") + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test") + server.start() + time.sleep(1) + + +def test_rpc_hello_without_arg(): + + class App: + + def hello(self): + print("hello") + return "world" + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test") + server.start() + time.sleep(0.1) + client = RPCClient("ipc:///tmp/rpc_test") + ret = client.hello() # sync call + assert ret == "world" + + +def test_rpc_hello_with_arg(): + + class App: + + def hello(self, name: str, location: str): + print("hello") + return f"hello {name} from {location}" + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test") + server.start() + time.sleep(0.1) + client = RPCClient("ipc:///tmp/rpc_test") + ret = client.hello("app", location="Marvel") # sync call + assert ret == "hello app from Marvel" + + +def test_rpc_server_address(): + + class App: + + def hello(self): + print("hello") + return "world" + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test") + server.start() + time.sleep(0.1) + assert server.address == "ipc:///tmp/rpc_test" + + +def test_rpc_with_error(): + + class App: + + def hello(self): + raise ValueError("hello") + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test_error") + server.start() + time.sleep(0.1) + client = RPCClient("ipc:///tmp/rpc_test_error") + with pytest.raises(RPCError): + client.hello() + + +def test_rpc_without_wait_response(): + + class App: + + def __init__(self): + self.task_submitted = False + + def send_task(self) -> None: + # Just submit the task and return immediately + # The result is not important + self.task_submitted = True + return None + + def get_task_submitted(self) -> bool: + return self.task_submitted + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test_no_wait") + server.start() + time.sleep(0.1) + client = RPCClient("ipc:///tmp/rpc_test_no_wait") + client.send_task(need_response=False) + time.sleep(0.1) # wait for some time to make sure the task is submitted + assert client.get_task_submitted() + + +def test_rpc_without_response_performance(): + # At any circumstances, the RPC call without response should be faster than the one with response + class App: + + def __init__(self): + self.task_submitted = False + + def send_task(self) -> None: + # Just submit the task and return immediately + # The result is not important + time.sleep(0.001) + return None + + with RPCServer(App(), num_workers=10) as server: + server.bind("ipc:///tmp/rpc_test_no_wait") + server.start() + time.sleep(0.1) + client = RPCClient("ipc:///tmp/rpc_test_no_wait") + + time_start = time.time() + for i in range(100): + client.send_task(need_response=False) + time_end = time.time() + + no_wait_time = time_end - time_start + + time_start = time.time() + for i in range(100): + client.send_task(need_response=True) + time_end = time.time() + wait_time = time_end - time_start + + assert no_wait_time < wait_time, f"{no_wait_time} > {wait_time}" + + +@pytest.mark.parametrize("async_run_task", [True, False]) +@pytest.mark.parametrize("use_ipc_addr", [True, False]) +def test_rpc_benchmark(async_run_task: bool, use_ipc_addr: bool): + + class App: + + def cal(self, n: int): + return n * 2 + + with RPCServer(App(), async_run_task=async_run_task) as server: + address = "ipc:///tmp/rpc_test" if use_ipc_addr else "tcp://127.0.0.1:*" + + server.bind(address) + server.start() + time.sleep(0.1) + + client = RPCClient(server.address) + + time_start = time.time() + for i in range(10000): + ret = client.cal(i) # sync call + assert ret == i * 2, f"{ret} != {i * 2}" + time_end = time.time() + print( + f"Time taken: {time_end - time_start} seconds, {10000 / (time_end - time_start)} calls/second" + ) diff --git a/tests/unittest/executor/test_worker_base.py b/tests/unittest/executor/test_worker_base.py deleted file mode 100644 index 698c43cb1ad..00000000000 --- a/tests/unittest/executor/test_worker_base.py +++ /dev/null @@ -1,165 +0,0 @@ -import os -import sys -import time - -import pytest -import torch - -from tensorrt_llm._utils import mpi_comm, mpi_rank, mpi_world_size -from tensorrt_llm.bindings import executor as tllm -from tensorrt_llm.llmapi.mpi_session import MpiPoolSession - -# isort: off -sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") -from utils.llm_data import llm_models_root -# isort: on - -from tensorrt_llm._torch.pyexecutor.config import update_executor_config -from tensorrt_llm.executor.request import GenerationRequest -from tensorrt_llm.executor.worker_base import WorkerBase -from tensorrt_llm.llmapi.llm_args import LlmArgs -from tensorrt_llm.sampling_params import SamplingParams - -default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" -model_path = llm_models_root() / default_model_name - - -class FakeWorker(WorkerBase): - - def __init__(self, engine: str, tp_size: int = 1): - llm_args, executor_config = create_fake_executor_config(engine, tp_size) - super().__init__( - engine=engine, - llm_args=llm_args, - hf_model_dir=engine, - ) - # Pass config in constructor and finalize with parameterless setup - self._executor_config = executor_config - self.llm_args = llm_args - self.setup_engine() - - def start(self): - pass - - def shutdown(self): - if self.engine is not None: - self.engine.shutdown() - self.engine = None - - -class TestWorkerBase: - - def test_create_engine(self): - with self.FakeWorker(engine=model_path) as worker: - print(f"Created engine: {worker.engine}") - - def test_submit_request(self): - sampling_params = SamplingParams(max_tokens=10) - request = GenerationRequest(prompt_token_ids=[3, 4, 5], - sampling_params=sampling_params) - with self.FakeWorker(engine=model_path) as worker: - print(f"Created engine: {worker.engine}") - worker.submit(request) - for i in range(10): - time.sleep(0.5) - worker.await_responses() - print(f"Submitted request: {request}") - time.sleep(6) - - def test_fetch_stats(self): - request = GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10)) - with self.FakeWorker(engine=model_path) as worker: - worker.submit(request) - time.sleep(1) - worker.await_responses() - stats = worker.fetch_stats() - print(stats) - - @pytest.mark.parametrize("timeout", [0.1, 0.2, 1]) - def test_fetch_responses_timeout(self, timeout: float): - with self.FakeWorker(engine=model_path) as worker: - # Not submit any request, and let the await_responses timeout. - start_time = time.time() - results = worker.await_responses(timeout=timeout) - elapsed = time.time() - start_time - print(f"await_responses latency: {elapsed:.3f} seconds") - assert timeout / 2 <= elapsed <= timeout * 2, f"Latency out of expected range: {elapsed}" - - -def create_fake_executor_config(model_path, tp_size=1): - llm_args = LlmArgs(model=model_path, - cuda_graph_config=None, - tensor_parallel_size=tp_size) - - executor_config = tllm.ExecutorConfig(1) - executor_config.max_batch_size = 1 - executor_config.model_world_size = tp_size - - update_executor_config( - executor_config, - pytorch_backend_config=llm_args.get_pytorch_backend_config(), - mapping=llm_args.parallel_config.to_mapping(), - speculative_config=llm_args.speculative_config, - hf_model_dir=model_path, - max_input_len=20, - max_seq_len=40, - checkpoint_format=llm_args.checkpoint_format, - checkpoint_loader=llm_args.checkpoint_loader, - ) - - return llm_args, executor_config - - -class TestRpcWorkerBaseTP2: - - def setup_method(self): - self.llm_args = LlmArgs(model=model_path, tensor_parallel_size=2) - self.session = self.create_worker_session() - - def create_worker_session(self): - session = MpiPoolSession(n_workers=2) - return session - - def test_create_executor(self): - futures = self.session.submit( - TestRpcWorkerBaseTP2.create_executor, - engine=model_path, - llm_args=self.llm_args, - ) - # Wait for completion - for future in futures: - future.result() - - self.session.shutdown() - - @staticmethod - def create_executor(engine, llm_args): - rank = mpi_rank() - world_size = mpi_world_size() - device_id = rank % torch.cuda.device_count() - torch.cuda.set_device(device_id) - - print(f"[Test] Rank {rank}/{world_size} using device {device_id}") - - # Synchronize all workers before creating executor - mpi_comm().barrier() - - print(f"[Test] Rank {rank} creating WorkerBase...") - executor = FakeWorker(engine=engine, tp_size=2) - - # For PyTorch backend, all ranks need to participate in setup - print(f"[Test] Rank {rank} calling setup_engine...") - - # Setup the engine which contains another barrier - executor.setup_engine() - - print(f"[Test] Rank {rank} setup_engine completed successfully") - - executor.shutdown() - - -if __name__ == "__main__": - test_worker_base = TestWorkerBase() - test_worker_base.test_fetch_stats() From df1ee9eb2f556cebbafa71262da7d35ea672ad55 Mon Sep 17 00:00:00 2001 From: Superjomn <328693+Superjomn@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:00:24 +0000 Subject: [PATCH 03/13] refactor worker Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> refine WorkerBase interface Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> add test for BaseWorker Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> add prototype for rpc worker and proxy Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> add fetch_stats and kvcache events to WorkerBase Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> amend test_worker_base Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> refine rpc timeout TODO: The timeout error may need a specific Error type enhance rpc_worker and test Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/rpc.py | 72 +- tensorrt_llm/executor/rpc_proxy.py | 162 +++++ tensorrt_llm/executor/rpc_worker.py | 92 +++ tensorrt_llm/executor/worker.py | 36 +- tensorrt_llm/executor/worker_base.py | 746 ++++++++++++++++++++ tests/unittest/executor/test_rpc.py | 65 +- tests/unittest/executor/test_rpc_worker.py | 99 +++ tests/unittest/executor/test_worker_base.py | 95 +++ 8 files changed, 1317 insertions(+), 50 deletions(-) create mode 100644 tensorrt_llm/executor/rpc_proxy.py create mode 100644 tensorrt_llm/executor/rpc_worker.py create mode 100644 tensorrt_llm/executor/worker_base.py create mode 100644 tests/unittest/executor/test_rpc_worker.py create mode 100644 tests/unittest/executor/test_worker_base.py diff --git a/tensorrt_llm/executor/rpc.py b/tensorrt_llm/executor/rpc.py index a863bc93942..7c6bc0fcbde 100644 --- a/tensorrt_llm/executor/rpc.py +++ b/tensorrt_llm/executor/rpc.py @@ -27,6 +27,7 @@ class RPCRequest(NamedTuple): args: tuple kwargs: dict need_response: bool = True + timeout: float = 0.5 class RPCResponse(NamedTuple): @@ -176,15 +177,35 @@ async def _worker_routine(self, stop_event: threading.Event): if req.method_name in self._functions: try: if self._executor is not None: - # Dispatch to worker thread and await result + # Dispatch to worker thread and await result with timeout loop = asyncio.get_running_loop() - result = await loop.run_in_executor( - self._executor, self._functions[req.method_name], - *req.args, **req.kwargs) + + # Create a wrapper function to handle keyword arguments + def call_with_kwargs(): + return self._functions[req.method_name]( + *req.args, **req.kwargs) + + result = await asyncio.wait_for(loop.run_in_executor( + self._executor, call_with_kwargs), + timeout=req.timeout) else: - result = self._functions[req.method_name](*req.args, - **req.kwargs) + # For synchronous execution, we need to run in executor to support timeout + loop = asyncio.get_running_loop() + + # Create a wrapper function to handle keyword arguments + def call_with_kwargs(): + return self._functions[req.method_name]( + *req.args, **req.kwargs) + + result = await asyncio.wait_for(loop.run_in_executor( + None, call_with_kwargs), + timeout=req.timeout) response = RPCResponse(req.request_id, 'OK', result) + except asyncio.TimeoutError: + response = RPCResponse( + req.request_id, 'ERROR', + f"Method '{req.method_name}' timed out after {req.timeout} seconds" + ) except Exception: tb = traceback.format_exc() response = RPCResponse(req.request_id, 'ERROR', tb) @@ -313,13 +334,33 @@ async def _start_reader_if_needed(self): self._reader_task = loop.create_task(self._response_reader()) async def _call_async(self, name, *args, **kwargs): - """Async version of RPC call.""" + """Async version of RPC call. + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + __rpc_timeout: The timeout (seconds) for the RPC call. + __rpc_need_response: Whether the RPC call needs a response. + If set to False, the remote call will return immediately. + + Returns: + The result of the remote method call + """ + logger.debug( + f"RPC client calling method: {name} with args: {args} and kwargs: {kwargs}" + ) await self._start_reader_if_needed() - need_response = kwargs.pop("need_response", True) + need_response = kwargs.pop("__rpc_need_response", True) + timeout = kwargs.pop("__rpc_timeout", self._timeout) request_id = uuid.uuid4().hex logger.debug(f"RPC client sending request: {request_id}") - request = RPCRequest(request_id, name, args, kwargs, need_response) + request = RPCRequest(request_id, + name, + args, + kwargs, + need_response, + timeout=timeout) logger.debug(f"RPC client sending request: {request}") await self._client_socket.put_async(request) @@ -331,9 +372,12 @@ async def _call_async(self, name, *args, **kwargs): self._pending_futures[request_id] = future try: - return await asyncio.wait_for(future, self._timeout) + # If timeout, the remote call should return a timeout error timely, + # so we add 1 second to the timeout to ensure the client can get + # that result. + return await asyncio.wait_for(future, timeout + 1) except asyncio.TimeoutError: - raise RPCError(f"Request '{name}' timed out after {self._timeout}s") + raise RPCTimeout(f"Request '{name}' timed out after {timeout}s") finally: self._pending_futures.pop(request_id, None) @@ -356,7 +400,7 @@ def call_async(self, name: str, *args, **kwargs): Example: result = await client.call_async('remote_method', arg1, arg2, key=value) """ - return self._call_async(name, *args, **kwargs, need_response=True) + return self._call_async(name, *args, **kwargs, __rpc_need_response=True) def call_future(self, name: str, *args, **kwargs) -> concurrent.futures.Future: @@ -418,9 +462,7 @@ def __call__(self, *args, **kwargs): def call_async(self, *args, **kwargs): """Async call - returns coroutine""" - return self.client._call_async(self.method_name, - *args, - need_response=True, + return self.client._call_async(self.method_name, *args, **kwargs) def call_future(self, *args, **kwargs) -> concurrent.futures.Future: diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py new file mode 100644 index 00000000000..6f120a9a3fd --- /dev/null +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -0,0 +1,162 @@ +import atexit +import os +import threading +import time +from typing import Optional + +from ..llmapi.mpi_session import MpiPoolSession, MpiSession +from ..llmapi.tracer import global_tracer +from ..llmapi.utils import _SyncQueue, print_colored_debug +from .executor import GenerationExecutor +from .postproc_worker import PostprocWorkerConfig +from .request import GenerationRequest +from .result import GenerationResult +from .rpc import RPCClient +from .utils import (ErrorResponse, create_mpi_comm_session, + get_spawn_proxy_process_env, is_llm_response) + + +class GenerationExecutorRpcProxy(GenerationExecutor): + # NOTE: this is a global counter for the number of instances of this class + INSTANCE_COUNTER = 0 + + def __init__(self, + worker_kwargs: dict, + model_world_size: int = 1, + mpi_session: Optional[MpiSession] = None, + *, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + is_llm_executor: Optional[bool] = None, + garbage_collection_gen0_threshold: Optional[int] = None, + clock_unit: int = 1): + """ + Args: + worker_kwargs: kwargs for the rpc worker + model_world_size: the world size of the model + mpi_session: the mpi session to use + postproc_worker_config: the postproc worker config + is_llm_executor: whether this is an llm executor + garbage_collection_gen0_threshold: the garbage collection gen0 threshold + clock_unit: the unit of the clock, 1 means 1 second + """ + + GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1 + self.rpc_addr = self.gen_uniq_rpc_addr() + self.rpc_client = RPCClient(self.rpc_addr) + + postproc_worker_config = postproc_worker_config or PostprocWorkerConfig( + ) + + super().__init__( + num_postprocess_workers=postproc_worker_config. + num_postprocess_workers, + postprocess_tokenizer_dir=postproc_worker_config. + postprocess_tokenizer_dir, + is_llm_executor=is_llm_executor, + ) + + self.mpi_session = self._create_mpi_session(model_world_size, + mpi_session) + + self._shutdown_event = threading.Event() + + self.launch_workers() + time.sleep(1) # wait for the workers to launch + + # Invoke model creation on the remote + # TBD: Move model creation to the mpi task, or left in RPC? + self.create_engine_remote() + + self.setup_mainloop() + + def launch_workers(self): + assert self.mpi_session is not None + self.mpi_session.submit(rpc_worker_main, + rpc_addr=self.rpc_addr, + **self.worker_kwargs) + + def main_loop_task(self): + """ + Main loop of the proxy, it will invoke the actions periodically. + """ + clock = 0 + while not self._shutdown_event.is_set(): + if clock % 1 == 0: + responses = self.await_responses_remote() + self.handle_responses(responses) + if clock % 10 == 0: + stats = self.get_stats_remote() # TODO + self.handle_stats(stats) + + clock += 1 + time.sleep(self.clock_unit) + + def setup_mainloop(self): + self.main_loop_thread = threading.Thread(target=self.main_loop_task, + daemon=True) + self.main_loop_thread.start() + atexit.register(self.shutdown) + + def handle_responses(self, responses: list[GenerationResult]) -> bool: + async_queues = [] + event_loop = None + + def process_res(res): + client_id = res.client_id + nonlocal event_loop + nonlocal async_queues + + queue = self._results[client_id].queue + if isinstance(queue, _SyncQueue): + queue.put_nowait(res) + async_queues.append(queue) + # all the loops are identical + event_loop = event_loop or queue.loop + else: + queue.put(res) + + if (is_llm_response(res) and res.result.is_final) or isinstance( + res, ErrorResponse): + self._results.pop(client_id) + + for res in responses: + global_tracer().log_instant("RPC.get") + process_res(res) + + if async_queues: + _SyncQueue.notify_many(event_loop, async_queues) + + def handle_stats(self, stats: dict): + raise NotImplementedError + + def submit(self, request: GenerationRequest) -> GenerationResult: + # submit is a fire-and-forget operation, don't need to wait for response + return self.rpc_client.submit(request, need_response=False) + + def await_responses_remote(self): + return self.rpc_client.await_responses() + + def create_engine_remote(self): + return self.rpc_client.create_engine() # TODO + + def shutdown_remote(self): + self.rpc_client.shutdown() + + def _create_mpi_session(self, model_world_size: int, + mpi_session: Optional[MpiSession]): + mpi_process_pre_spawned: bool = get_spawn_proxy_process_env() + if mpi_session is None: + if mpi_process_pre_spawned: + print_colored_debug('create comm session ...\n', "yellow") + self.mpi_session = create_mpi_comm_session(model_world_size) + else: + print_colored_debug('create pool session ...\n', "yellow") + self.mpi_session = MpiPoolSession(n_workers=model_world_size) + else: + print_colored_debug('using external mpi session ...\n', "yellow") + self.mpi_session = mpi_session + + @staticmethod + def gen_uniq_rpc_addr() -> str: + process_id = os.getpid() + return f"ipc:///tmp/rpc-proxy-{process_id}-{GenerationExecutorRpcProxy.INSTANCE_COUNTER}" diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py new file mode 100644 index 00000000000..6c748428229 --- /dev/null +++ b/tensorrt_llm/executor/rpc_worker.py @@ -0,0 +1,92 @@ +from pathlib import Path +from queue import Queue +from threading import Event +from typing import Optional, Union + +from .._utils import mpi_rank +from ..bindings import executor as tllm +from ..builder import Engine +from ..logger import logger +from ..lora_manager import LoraConfig +from ..sampling_params import BatchedLogitsProcessor +from .postproc_worker import PostprocWorkerConfig +from .rpc import RPCServer +from .worker_base import WorkerBase + + +class RpcWorker(WorkerBase): + """ + A RPC wrapper for the WorkerBase class. + + Actions: + - `setup_engine`: Setup the engine. + - `fetch_responses`: Fetch the latest responses from engine. + - `fetch_stats`: Fetch the latest stats from engine. + - `fetch_kv_cache_events`: Fetch the latest kv cache events from engine. + - `shutdown`: Shutdown the worker. + """ + + def __init__( + self, + engine: Union[Path, Engine], + executor_config: Optional[tllm.ExecutorConfig] = None, + is_llm_executor: Optional[bool] = None, + ) -> None: + super().__init__(engine=engine, + executor_config=executor_config, + is_llm_executor=is_llm_executor) + self.shutdown_event = Event() + + self._response_queue = Queue() + self.set_result_queue(self._response_queue) + + def fetch_responses(self) -> list: + logger.debug(f"RPC worker {mpi_rank()} is fetching responses") + super().await_responses() + qsize = self._response_queue.qsize() + return [self._response_queue.get() for _ in range(qsize)] + + def shutdown(self): + logger.debug(f"RPC worker {mpi_rank()} is shutting down") + self.shutdown_event.set() + super().shutdown() + + @staticmethod + def main_task( + engine: Union[Path, Engine], + rpc_addr: str, + *, + executor_config: Optional[tllm.ExecutorConfig] = None, + batched_logits_processor: Optional[BatchedLogitsProcessor] = None, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + is_llm_executor: Optional[bool] = None, + lora_config: Optional[LoraConfig] = None, + garbage_collection_gen0_threshold: Optional[int] = None, + ) -> None: + # Step 1: Create the worker instance + worker = RpcWorker(engine=engine, executor_config=executor_config) + + if mpi_rank() != 0: + logger.debug(f"Worker {mpi_rank()} is setting up the engine") + # The non-leader worker will setup the engine immediately. + # The leader worker will wait for the RPC call to propagate the + # potential error. + worker.setup_engine( + engine=engine, + executor_config=executor_config, + batched_logits_processor=batched_logits_processor, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor, + lora_config=lora_config, + garbage_collection_gen0_threshold= + garbage_collection_gen0_threshold) + + if mpi_rank() == 0: + # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client + rpc_server = RPCServer(worker) + rpc_server.bind(rpc_addr) + rpc_server.start() + + # Step 3: Wait for the worker to shutdown + worker.shutdown_event.wait() + rpc_server.shutdown() diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 148cdcf038c..941348c122c 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -15,9 +15,7 @@ from .._utils import KVCacheEventSerializer, mpi_comm, mpi_rank from ..bindings import executor as tllm from ..builder import Engine -from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig from ..llmapi.mpi_session import set_mpi_session_cpp -from ..llmapi.tokenizer import TokenizerBase from ..llmapi.tracer import VizTracer, set_global_tracer from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, clear_sched_affinity, print_colored_debug, @@ -52,10 +50,7 @@ def __init__( postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, - kv_connector_config: Optional[KvCacheConnectorConfig] = None, - hf_model_dir: Optional[Path] = None, - tokenizer: Optional[TokenizerBase] = None, - llm_args: Optional[BaseLlmArgs] = None, + garbage_collection_gen0_threshold: Optional[int] = None, ) -> None: super().__init__( engine=engine, @@ -162,7 +157,7 @@ def stats_serializer( return self._iteration_result_task(self.stats_queues, self.fetch_stats, self._iter_stats_result, - stats_serializer) + self._stats_serializer) def dispatch_kv_cache_events_task(self) -> bool: if isinstance(self.engine, tllm.Executor): @@ -215,22 +210,7 @@ def shutdown(self): self.dispatch_kv_cache_events_thread.stop() self.dispatch_kv_cache_events_thread.join() - self.engine.shutdown() - self.engine = None - - if self.llm_args is not None: - assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined." - if (self.llm_args.backend == "pytorch" - and hasattr(self, "checkpoint_loader") - and self.checkpoint_loader is not None): - self.checkpoint_loader.cleanup() - self.checkpoint_loader = None - else: - if hasattr( - self._executor_config, "checkpoint_loader" - ) and self._executor_config.checkpoint_loader is not None: - self._executor_config.checkpoint_loader.cleanup() - self._executor_config.checkpoint_loader = None + super().shutdown() # Check if there are any errors from the threads before shutdown. self._handle_background_error() @@ -268,10 +248,7 @@ def worker_main( is_llm_executor: Optional[ bool] = True, # whether it's the main executor instance lora_config: Optional[LoraConfig] = None, - kv_connector_config: Optional[KvCacheConnectorConfig] = None, - hf_model_dir: Optional[Path] = None, - tokenizer: Optional[TokenizerBase] = None, - llm_args: Optional[BaseLlmArgs] = None, + garbage_collection_gen0_threshold: Optional[int] = None, ) -> None: mpi_comm().barrier() print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n", @@ -399,10 +376,7 @@ def notify_proxy_threads_to_quit(): postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, lora_config=lora_config, - kv_connector_config=kv_connector_config, - hf_model_dir=hf_model_dir, - tokenizer=tokenizer, - llm_args=llm_args) + garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) except Exception as e: logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}") logger.error(traceback.format_exc()) diff --git a/tensorrt_llm/executor/worker_base.py b/tensorrt_llm/executor/worker_base.py new file mode 100644 index 00000000000..ef3494816b7 --- /dev/null +++ b/tensorrt_llm/executor/worker_base.py @@ -0,0 +1,746 @@ +import copy +import datetime +import enum +import json +import weakref +from pathlib import Path +from queue import Queue +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from tensorrt_llm.logger import logger + +from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, + nvtx_range_debug) +from ..bindings import executor as tllm +from ..builder import ConfigEncoder, Engine, EngineConfig +from ..llmapi.llm_args import PybindMirror +from ..llmapi.tracer import global_tracer +from ..llmapi.utils import _SyncQueue, print_colored_debug +from ..logger import logger +from ..lora_manager import LoraConfig, LoraManager +from ..metrics import RequestEventTiming +from ..prompt_adapter_manager import PromptAdapterManager +from ..runtime import ModelConfig +from ..runtime.model_runner import _engine_config_to_model_config +from ..sampling_params import SamplingParams +from .executor import GenerationExecutor +from .ipc import IpcQueue +from .postproc_worker import PostprocParams, PostprocWorker +from .request import GenerationRequest, LoRARequest, PromptAdapterRequest +from .result import (GenerationResult, LogProbsResult, ResponseWrapper, + compute_logprobs) +from .utils import (ErrorResponse, RequestError, enable_llm_debug, + is_llm_response) + +if enable_llm_debug(): + logger.set_level("debug") + +__all__ = [ + "WorkerBase", +] + + +class WorkerBase(GenerationExecutor): + """ + Base class for all workers. + + It contains all the core logic for the worker, without any specific logic for + cross-process communication such as IPC or RPC. + """ + + def __init__( + self, + engine: Union[Path, Engine], + executor_config: Optional[tllm.ExecutorConfig] = None, + is_llm_executor: Optional[bool] = None, + ) -> None: + super().__init__(is_llm_executor=is_llm_executor) + + self.engine = None + self.rank = mpi_rank() + self.global_rank = global_mpi_rank() + # mapping: client_id -> GenerationResult + self._results: Dict[int, GenerationResult] = {} + # mapping: client_id from Proxy -> request_id returned from runtime backend + self._client_id_to_request_id: Dict[int, int] = {} + self._executor_config = executor_config + self._is_pytorch_backend = getattr(self._executor_config, "backend", + None) == "pytorch" + + if global_mpi_size() > 1: + logger.set_rank(self.global_rank) + + if isinstance(engine, list): + self.engine = engine[self.rank] + + self._await_response_helper = AwaitResponseHelper(weakref.proxy(self)) + + self.postproc_queues = None + self.result_queue = None + + self._lora_manager: Optional[LoraManager] = None + self._prompt_adapter_manager: Optional[PromptAdapterManager] = None + self._runtime_model_config: Optional[ModelConfig] = None + + def setup_engine( + self, + engine: Union[Path, Engine], + executor_config: Optional[tllm.ExecutorConfig] = None, + lora_config: Optional[LoraConfig] = None, + garbage_collection_gen0_threshold: Optional[int] = None) -> None: + + device_id = self.global_rank % torch.cuda.device_count() + torch.cuda.set_device(device_id) + + # Make sure C++ executor would use same devices/ranks as py_executor + global_rank = global_mpi_rank() + comm_ranks = mpi_comm().allgather(global_rank) + device_ids = mpi_comm().allgather(device_id) + + if executor_config is None: + executor_config = tllm.ExecutorConfig(1) + + executor_config.parallel_config = tllm.ParallelConfig( + participant_ids=comm_ranks, device_ids=device_ids) + + if isinstance(engine, list): + engine = engine[self.rank] + + if isinstance(engine, Engine): + self.engine = tllm.Executor(engine.engine, + json.dumps(engine.config.to_dict(), + cls=ConfigEncoder), + tllm.ModelType.DECODER_ONLY, + executor_config=executor_config, + managed_weights=engine.managed_weights) + elif not hasattr(executor_config, "backend"): + self.engine = tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, + executor_config) + else: + args = { + "executor_config": executor_config, + "checkpoint_dir": executor_config.hf_model_dir, + } + if executor_config.backend == "pytorch": + from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ + create_py_executor + create_executor = create_py_executor + args["lora_config"] = lora_config + args[ + "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold + elif executor_config.backend == "_autodeploy": + from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ + create_autodeploy_executor + create_executor = create_autodeploy_executor + else: + raise ValueError( + f"Unsupported backend config: {executor_config.backend}") + self.engine = create_executor(**args) + + self._setup_lora(engine, executor_config, lora_config) + + def _setup_lora(self, engine: Union[Path, Engine], + executor_config: tllm.ExecutorConfig, + lora_config: Optional[LoraConfig]) -> None: + """Setup LoRA and prompt adapter managers.""" + # LoRA setup + if self.rank == 0 and isinstance(self.engine, tllm.Executor): + if isinstance(engine, Engine): + engine_config = engine.config + else: + engine_config = EngineConfig.from_json_file( + f"{engine}/config.json") + self._runtime_model_config = _engine_config_to_model_config( + engine_config) + if engine_config.build_config.plugin_config.lora_plugin: + # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization + # (see LoraManager constructor docstring). Getting the peft cache manager from this + # point in the TRT flow is currently not supported (it's at the CPP + # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA + # optimization is not available in TRT-python flow. + self._lora_manager = LoraManager(cpp_peft_cache_manager=None) + if engine_config.build_config.max_prompt_embedding_table_size > 0: + self._prompt_adapter_manager = PromptAdapterManager() + + if getattr(executor_config, "backend", + "") == "pytorch" and lora_config is not None: + from tensorrt_llm._torch.pyexecutor.resource_manager import \ + ResourceManagerType + peft_cache_manager = self.engine.resource_manager.resource_managers.get( + ResourceManagerType.PEFT_CACHE_MANAGER) + self._lora_manager = LoraManager( + cpp_peft_cache_manager=peft_cache_manager.impl) + lora_model_config = self.engine.model_engine.lora_model_config + assert lora_model_config is not None + self._lora_model_config = lora_model_config + + def abort_request(self, client_id: int) -> None: + # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id + if self.engine.can_enqueue_requests(): + request_id = self._client_id_to_request_id.get(client_id, None) + if request_id is None: + logger.warning( + f"Request of client_id {client_id} is finished, cannot abort it." + ) + return + self.engine.cancel_request(request_id) + + def _engine_response_callback(self, response: tllm.Response): + return response + + def _load_lora_adapter(self, lora_request: LoRARequest) -> bool: + """Returns True if the adapter was loaded by this call, False if it was already loaded""" + adapter_id = str(lora_request.adapter_id) + newly_loaded_uids = self._lora_manager.load_from_ckpt( + [lora_request.path], + model_config=self._runtime_model_config if + self._runtime_model_config is not None else self._lora_model_config, + runtime_mapping=None, + uids=[adapter_id], + ckpt_source=lora_request.ckpt_source) + return adapter_id in newly_loaded_uids + + def _load_prompt_adapter(self, + prompt_adapter_request: PromptAdapterRequest): + self._prompt_adapter_manager.load_from_ckpt( + [prompt_adapter_request.local_path], + model_config=self._runtime_model_config, + uids=[str(prompt_adapter_request.adapter_id)]) + + def _enqueue_request(self, request: GenerationRequest) -> int: + assert request.id is not None + py_lora_path = None + if self._lora_manager is not None and request.lora_request is not None: + adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache( + request.lora_request.adapter_id) + self._load_lora_adapter(request.lora_request) + uid = str(request.lora_request.adapter_id) + lora_config = tllm.LoraConfig( + task_id=request.lora_request.adapter_id, + weights=self._lora_manager.cpp_lora_weights[uid] + if not adapter_in_cache else None, + config=self._lora_manager.cpp_lora_config[uid]) + py_lora_path = request.lora_request.lora_path + else: + lora_config = None + + prompt_token_ids = copy.deepcopy(request.prompt_token_ids) + prompt_tuning_config = None + if request.prompt_adapter_request is not None: + self._load_prompt_adapter(request.prompt_adapter_request) + uid = str(request.prompt_adapter_request.adapter_id) + prompt_tuning_config = tllm.PromptTuningConfig( + self._prompt_adapter_manager.uid_to_weights[uid]) + vocab_size = self._runtime_model_config.vocab_size + pa_length = prompt_tuning_config.embedding_table.size(0) + prompt_token_ids = list(range( + vocab_size, vocab_size + pa_length)) + prompt_token_ids + + # MULTIMODAL + # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field + # except `multimodal_input` as it needs to go through the C++ runtime. + multimodal_input = None + if request.multimodal_params is not None and request.multimodal_params.has_content( + ): + if request.multimodal_params.multimodal_input is not None: + multimodal_input = tllm.MultimodalInput( + multimodal_hashes=request.multimodal_params. + multimodal_input.multimodal_hashes, + multimodal_positions=request.multimodal_params. + multimodal_input.multimodal_positions, + multimodal_lengths=request.multimodal_params. + multimodal_input.multimodal_lengths) + # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field + request.multimodal_params.multimodal_input = None + + context_phase_params = None + request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION + if request.disaggregated_params is not None: + assert ( + not self._is_pytorch_backend + or self.engine.kv_cache_transceiver is not None + ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:` in config file for disaggregated serving" + request_type = request.disaggregated_params.get_request_type() + if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY: + context_phase_params = request.disaggregated_params.get_context_phase_params( + ) + + is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler + if is_overlap_enabled: + is_disaggregated = self.engine.kv_cache_transceiver is not None + if is_disaggregated and ( + request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): + raise ValueError( + "Context only requests are not supported in pytorch backend when overlap is enabled." + ) + + assert request.id is not None + + def _deduce_max_tokens(request: GenerationRequest, + executor_config: tllm.ExecutorConfig) -> int: + if request.sampling_params.max_tokens: + return request.sampling_params.max_tokens + # deduce max_tokens when it's not set by user + query_token_len = len( + request.query_token_ids) if request.query_token_ids else 0 + cp_size = 1 if (not hasattr(executor_config, "mapping") + or executor_config.mapping.cp_size + is None) else executor_config.mapping.cp_size + if not hasattr(executor_config, "max_seq_len"): + raise RuntimeError( + "max_tokens for sampling is not set and cannot be deduced") + splited_prompt_len = int(len(prompt_token_ids) / cp_size) + default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len + if default_max_tokens < 0: + raise ValueError( + f"Deduced max_tokens {default_max_tokens} is less than 0, because" + f"prompt length {splited_prompt_len} plus query length {query_token_len} " + f"is larger than max_seq_len {executor_config.max_seq_len}") + return default_max_tokens + + try: + executor_request = tllm.Request( + client_id=request.id, + input_token_ids=prompt_token_ids, + max_tokens=_deduce_max_tokens(request, self._executor_config), + streaming=request.streaming, + sampling_config=request.sampling_params._get_sampling_config(), + end_id=-1 if request.sampling_params.ignore_eos else + request.sampling_params.end_id, + pad_id=request.sampling_params.pad_id, + output_config=request.sampling_params._get_output_config( + is_pytorch_backend=self._is_pytorch_backend), + # Beam search enforces return_all_generated_tokens=True regardless of the passed value + return_all_generated_tokens=False, + # convert python config into pybind config + lookahead_config=PybindMirror.maybe_to_pybind( + request.sampling_params.lookahead_config), + guided_decoding_params=request.sampling_params. + _get_guided_decoding_params(), + bad_words=request.sampling_params._get_bad_words(), + stop_words=request.sampling_params._get_stop_words(), + embedding_bias=request.sampling_params.embedding_bias, + lora_config=lora_config, + prompt_tuning_config=prompt_tuning_config, + multimodal_input=multimodal_input, + # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`. + multimodal_embedding=None, + mrope_config=None, + logits_post_processor_name=( + tllm.Request.BATCHED_POST_PROCESSOR_NAME + if request.sampling_params.apply_batched_logits_processor + else None), + logits_post_processor=None if self._is_pytorch_backend else + request.sampling_params.logits_processor, + kv_cache_retention_config=request.kv_cache_retention_config, + context_phase_params=context_phase_params, + type=request_type) + executor_request.py_lora_path = py_lora_path + + if self._is_pytorch_backend and request.multimodal_params is not None: + if request.multimodal_params.multimodal_data is not None: + # NOTE: Deserialize SharedTensor handle to actual tensor + request.multimodal_params.to_tensor("multimodal_data") + executor_request.py_multimodal_data = request.multimodal_params.multimodal_data + + if self._is_pytorch_backend and request.sampling_params.logits_processor: + # For PyTorch backend, we attach logits processors as a dynamic Python attribute + # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues. + lp = request.sampling_params.logits_processor + executor_request.py_logits_post_processors = lp if isinstance( + lp, list) else [lp] + + executor_request.py_scheduling_params = None + if self._is_pytorch_backend and request.scheduling_params is not None: + executor_request.py_scheduling_params = request.scheduling_params + + if request.query_token_ids is not None: + # pytorch star attention workflow + # a workaround to avoid public interface update + req_id = self.engine.enqueue_request(executor_request, + request.query_token_ids) + else: + req_id = self.engine.enqueue_request(executor_request) + return req_id + except Exception as e: + raise RequestError(str(e)) from e + + def submit(self, request: GenerationRequest) -> GenerationResult: + """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """ + if self.rank != 0: + raise RuntimeError( + "Only rank 0 can submit requests.\n" + "To fix this, ensure that the llm.generate(...) method is " + "guarded with the `if __name__ == '__main__':` block.") + + client_id = request.id if request.id is not None else self._get_next_client_id( + ) + if request.id is None: + request.set_id(client_id) + + logprob_params = self._get_logprob_params(request) + + result = GenerationResult( + request, + background_error_handler=self._handle_background_error, + executor=self, + disaggregated_params=request.disaggregated_params, + logprob_params=logprob_params) + + self._results[client_id] = result + + request_id = self._enqueue_request(request) + # request_id returned from backend is necessary for the abort_request method. + self._client_id_to_request_id[client_id] = request_id + + self._handle_background_error() + + return result + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + return True + + def await_responses(self) -> None: + self._await_response_helper() + logger.debug(f"worker done await_responses") + + def fetch_kv_cache_events(self) -> list: + if isinstance(self.engine, tllm.Executor): + # Check if the engine has a kv cache event manager + # If not, return an empty list for the events which will cause the thread to exit early. + event_manager = self.engine.get_kv_cache_event_manager() + if event_manager is None: + return [] + else: + return event_manager.get_latest_events() + else: + return self.engine.get_latest_kv_cache_events() + + def fetch_stats( + self) -> List[Tuple[tllm.IterationStats, tllm.RequestStats]]: + if isinstance(self.engine, tllm.Executor): + iter_stats = self.engine.get_latest_iteration_stats() + #TODO: Support req stats with TRT engine + # This would require ensuring iter and req stats have same size + return [(iter_stat, None) for iter_stat in iter_stats] + else: + return self.engine.get_latest_iteration_stats() + + # Define a Callable to join iteration and request stats + @staticmethod + def _stats_serializer( + stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str: + iteration_stats, req_stats = stats + stats_dict = json.loads(iteration_stats.to_json_str()) + + if req_stats is not None and len(req_stats) > 0: + stats_dict["requestStats"] = [] + for req_stat in req_stats: + stats_dict["requestStats"].append( + json.loads(req_stat.to_json_str())) + + # Convert back to JSON string + return json.dumps(stats_dict) + + def set_result_queue(self, queue: Queue | IpcQueue): + """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process.""" + assert self.postproc_queues is None + self.result_queue = queue + + def set_postproc_queues(self, queues: list[Queue | IpcQueue]): + """ Set the IPC queues for feeding post-processing processes. """ + assert self.result_queue is None + self.postproc_queues = queues + + def _pop_result(self, client_id: int): + self._results.pop(client_id, None) + self._client_id_to_request_id.pop(client_id, None) + + def shutdown(self): + if self.engine is not None: + if self.engine.can_enqueue_requests(): + self.engine.shutdown() + self.engine = None + + if hasattr( + self._executor_config, "checkpoint_loader" + ) and self._executor_config.checkpoint_loader is not None: + self._executor_config.checkpoint_loader.cleanup() + self._executor_config.checkpoint_loader = None + # Check if there are any errors from the threads before shutdown. + self._handle_background_error() + + def _has_background_error(self) -> bool: + # TODO[Superjomn]: The worker background error should be deprecated once + # RPC approach is supported. + return not self._error_queue.empty() + + +class AwaitResponseHelper: + ''' Multiple-implementations for await_response for performance. ''' + + class HandlerKind(enum.Enum): + unknown = 0 + single_process_worker = 1 + ipc_batched = 2 + + def __init__(self, worker: "WorkerBase"): + self.worker = worker + self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown + self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel + # The error responses when submit request failed will be put here + self.temp_error_responses = Queue() + + def responses_handler(self, responses: List[tllm.Response]): + HandlerKind = AwaitResponseHelper.HandlerKind + + if self.handler_kind is HandlerKind.unknown: + if not (self.worker.result_queue is not None + or self.worker.postproc_queues is not None): + print_colored_debug( + f"creating await_response helper for Worker\n", + color="yellow") + # When ExecutorBindingWorker is used in the main process + # aka the single process mode + self.handler_kind = HandlerKind.single_process_worker + elif self.worker.result_queue is not None or self.worker.postproc_queues is not None: + # The ExecutorBindingProxy is used + print_colored_debug(f"creating await_response helper for IPC\n", + color="yellow") + self.handler_kind = HandlerKind.ipc_batched + else: + raise NotImplementedError + + match self.handler_kind: + case HandlerKind.single_process_worker: + return self.handle_for_worker(responses) + case HandlerKind.ipc_batched: + return self.handle_for_ipc_batched(responses) + case _: + raise NotImplementedError + + def __call__(self) -> bool: + ''' This method should be called by a ManagedThread. ''' + logger.debug(f"await_response: {self.worker.engine}") + responses = self.worker.engine.await_responses( + timeout=datetime.timedelta(milliseconds=100)) + logger.debug(f"PyExecutor returned {len(responses)} responses") + + # filter since The _engine_response_callback may return None + responses = list( + filter( + lambda _: _, + [self.worker._engine_response_callback(r) for r in responses])) + + # append the error responses to the temp_error_responses + while not self.temp_error_responses.empty(): + responses.append(self.temp_error_responses.get()) + + with nvtx_range_debug(f"await_response-{len(responses)}", + color="red", + category="Worker"): + self.responses_handler(responses) + return True + + def handle_for_worker(self, responses: List[tllm.Response]) -> None: + ''' Return the responses to asyncio.event_loop. ''' + event_loop = None + async_queues = [] + for response in responses: + assert response is not None + queue = self.worker.return_queue(response.client_id) + + response = _maybe_wrap_response(self.worker, response, + self.worker._is_pytorch_backend) + + # For AsyncQueue.sync_q, we will batch the events to avoid too many + # event notifications, thus put without wait here. + if isinstance(queue, _SyncQueue): + global_tracer().log_instant("worker-rsp.put") + queue.put_nowait(response) + async_queues.append(queue) + # all the loops are identical + event_loop = event_loop or queue.loop + else: + queue.put(response) + + if response.has_error() or response.result.is_final: + self.worker._pop_result(response.client_id) + + # Notify the events in bulk for performance. + if async_queues: + _SyncQueue.notify_many(event_loop, async_queues) + + def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None: + ''' Perform the IPC in batch explicitly. ''' + postproc_batches = [ + [] + for _ in range(self.worker.postproc_config.num_postprocess_workers) + ] if self.enable_postprocprocess_parallel else None + rsp_batch = [] if not self.enable_postprocprocess_parallel else None + + for response in responses: + + if isinstance(response, ErrorResponse): + pass # send ErrorResponse directly + elif self.worker._has_background_error(): + response = self.worker._create_error_response(response) + elif response.has_error(): + # Convert to ErrorResponse, because tllm.Response cannot be + # serialized when it has error. + response = ErrorResponse(response.client_id, response.error_msg, + response.request_id) + else: + response = _maybe_wrap_response(self.worker, response, + self.worker._is_pytorch_backend) + + _send_rsp(self.worker, + response, + postproc_batches=postproc_batches, + rsp_batch=rsp_batch) + + if postproc_batches: + for wid, batch in enumerate(postproc_batches): + if batch: + self.worker.postproc_queues[wid].put(batch) + + if rsp_batch: + self.worker.result_queue.put(rsp_batch) + + +def _get_params_for_first_rsp( + worker, + client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]: + res = worker._results.get(client_id, None) + assert res is not None + if not res._params_transmitted: + res._params_transmitted = True + return res.sampling_params, res.postproc_params + return None, None + + +def _get_logprobs(worker, + response: tllm.Response, + is_pytorch_backend=False) -> Optional[LogProbsResult]: + """Compute logprob and prompt logprob and clear out logits if applicable. + """ + if is_pytorch_backend: + # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime. + # In the PyTorch backend, logprobs are already computed during runtime if requested. + return None + + logprobs_result = None + generation_result = worker._results.get(response.client_id, None) + + if not generation_result: + return + + logprob_params = getattr(generation_result, "_logprob_params", None) + if logprob_params: + logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, + logprob_params.logprobs, + response.result.context_logits, + response.result.generation_logits, + response.result.output_token_ids[0]) + + if logprob_params.drop_context_logits: + response.clear_context_logits() + + if logprob_params.drop_generation_logits: + response.clear_generation_logits() + + if response.result.is_final: + generation_result.clear_logprob_params() + + return logprobs_result + + +def _send_rsp( + worker, + response: Union[tllm.Response, ResponseWrapper, ErrorResponse], + postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None, + rsp_batch: Optional[List[tllm.Response]] = None): + # if postproc_batches is set, append to batch instead of putting to IpcQueue + + if worker.result_queue is not None: + if rsp_batch is not None: + rsp_batch.append(response) + else: + worker.result_queue.put(response) + else: + sampling_params, postproc_params = _get_params_for_first_rsp( + worker, response.client_id) + inp = PostprocWorker.Input( + response, + # sampling_params is necessary for creating fake GenerationResult + # instances in the postproc processes. They are for incremental + # detokenize. They should be transmitted only once for each + # Request. + sampling_params=sampling_params, + postproc_params=postproc_params, + streaming=worker._results.get(response.client_id, None)._streaming) + + pid = response.client_id % worker.postproc_config.num_postprocess_workers + + if not postproc_batches: + # Group the responses into buckets for the postprocessing steps. + # Bucketing is used instead of random dispatching because the + # incremental detokenization during postprocessing relies on the + # prior CompletionOutput of a given request. + worker.postproc_queues[pid].put(inp) + else: + postproc_batches[pid].append(inp) + + # Eliminate the finished GenerationRequest instances timely, which may + # take considerable memory. + if is_llm_response(response): + if response.has_error() or response.result.is_final: + worker._pop_result(response.client_id) + elif isinstance(response, ErrorResponse): + worker._pop_result(response.client_id) + else: + raise ValueError(f"Unknown response type: {response}") + + +def _get_metrics_dict( + response: tllm.Response) -> dict[RequestEventTiming, float]: + req_perf_metrics, metrics_dict = None, {} + res = response.result + if res: + if hasattr(res, '_result'): + if result := res.get_result(): + req_perf_metrics = result.request_perf_metrics + else: + req_perf_metrics = res.request_perf_metrics + if req_perf_metrics and req_perf_metrics.timing_metrics: + metrics_dict = { + RequestEventTiming.ARRIVAL_TIME: + req_perf_metrics.timing_metrics.arrival_time.total_seconds(), + RequestEventTiming.FIRST_TOKEN_TIME: + req_perf_metrics.timing_metrics.first_token_time.total_seconds( + ), + RequestEventTiming.FIRST_SCHEDULED_TIME: + req_perf_metrics.timing_metrics.first_scheduled_time. + total_seconds(), + RequestEventTiming.LAST_TOKEN_TIME: + req_perf_metrics.timing_metrics.last_token_time.total_seconds() + } + return metrics_dict + + +def _maybe_wrap_response( + worker, + response: tllm.Response, + is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]: + + logprobs_result = _get_logprobs(worker, response, is_pytorch_backend) + req_perf_metrics = _get_metrics_dict(response) + if logprobs_result or req_perf_metrics: + response = ResponseWrapper(response, logprobs_result, req_perf_metrics) + return response diff --git a/tests/unittest/executor/test_rpc.py b/tests/unittest/executor/test_rpc.py index 69546cd5d0d..d00a4cc673b 100644 --- a/tests/unittest/executor/test_rpc.py +++ b/tests/unittest/executor/test_rpc.py @@ -123,7 +123,7 @@ def get_task_submitted(self) -> bool: server.start() time.sleep(0.1) client = RPCClient("ipc:///tmp/rpc_test_no_wait") - client.send_task(need_response=False) + client.send_task(__rpc_need_response=False) time.sleep(0.1) # wait for some time to make sure the task is submitted assert client.get_task_submitted() @@ -149,14 +149,14 @@ def send_task(self) -> None: time_start = time.time() for i in range(100): - client.send_task(need_response=False) + client.send_task(__rpc_need_response=False) time_end = time.time() no_wait_time = time_end - time_start time_start = time.time() for i in range(100): - client.send_task(need_response=True) + client.send_task(__rpc_need_response=True) time_end = time.time() wait_time = time_end - time_start @@ -183,9 +183,66 @@ def cal(self, n: int): time_start = time.time() for i in range(10000): - ret = client.cal(i) # sync call + ret = client.cal(i, __rpc_timeout=10) # sync call assert ret == i * 2, f"{ret} != {i * 2}" time_end = time.time() print( f"Time taken: {time_end - time_start} seconds, {10000 / (time_end - time_start)} calls/second" ) + + +@pytest.mark.parametrize("use_async", [True, False]) +def test_rpc_timeout(use_async: bool): + """Test RPC timeout functionality. + + Args: + use_async: Whether to test async RPC calls or sync RPC calls + """ + + class App: + + def slow_operation(self, delay: float): + """A method that takes a long time to complete.""" + time.sleep(delay) + return "completed" + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test_timeout") + server.start() + time.sleep(0.1) + client = RPCClient("ipc:///tmp/rpc_test_timeout") + + # Test that a short timeout causes RPCTimeout exception + with pytest.raises(RPCError) as exc_info: + if use_async: + # Test async call with timeout + import asyncio + + async def test_async_timeout(): + return await client.call_async('slow_operation', + 2.0, + __rpc_timeout=0.1) + + asyncio.run(test_async_timeout()) + else: + # Test sync call with timeout + client.slow_operation(2.0, __rpc_timeout=0.1) + + assert "timed out" in str( + exc_info.value), f"Timeout message not found: {exc_info.value}" + + # Test that a long timeout allows the operation to complete + if use_async: + # Test async call with sufficient timeout + import asyncio + + async def test_async_success(): + return await client.call_async('slow_operation', + 0.1, + __rpc_timeout=1.0) + + result = asyncio.run(test_async_success()) + else: + result = client.slow_operation(0.1, __rpc_timeout=1.0) + + assert result == "completed" diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py new file mode 100644 index 00000000000..6d0be4bf531 --- /dev/null +++ b/tests/unittest/executor/test_rpc_worker.py @@ -0,0 +1,99 @@ +import multiprocessing +import os +import sys +import time +from concurrent.futures import ProcessPoolExecutor + +from test_worker_base import TestWorkerBase + +from tensorrt_llm.executor.request import GenerationRequest +from tensorrt_llm.executor.rpc import RPCClient +from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy +from tensorrt_llm.executor.rpc_worker import RpcWorker +from tensorrt_llm.sampling_params import SamplingParams + +# isort: off +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") +from utils.llm_data import llm_models_root +# isort: on + +model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +class TestRpcWorker: + + def __init__(self): + self.executor_config = TestWorkerBase.create_fake_executor_config( + model_path) + + def create_tp1_worker_process(self): + addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr() + # Use spawn method instead of fork + mp_context = multiprocessing.get_context('spawn') + pool = ProcessPoolExecutor(max_workers=1, mp_context=mp_context) + pool.submit(RpcWorker.main_task, + engine=model_path, + rpc_addr=addr, + executor_config=self.executor_config) + return pool, addr + + def create_rpc_client(self, addr: str): + client = RPCClient(addr) + return client + + def test_main(self): + pool, addr = self.create_tp1_worker_process() + client = self.create_rpc_client(addr) + print("call setup_engine") + client.setup_engine(engine=model_path, + executor_config=self.executor_config, + __rpc_timeout=120) + print("call submit") + time.sleep(1) + + def process_request(): + ret = client.submit(GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=10)), + __rpc_need_response=False) + assert ret is None + + print(f"submit result: {ret}") + print("call fetch_responses") + # NOTE: known issue, the responses should be fetched before shutdown, + # or the shutdown will hang. + results = [] + for i in range(3): + time.sleep(3) + results.extend(client.fetch_responses()) + print(f"fetch_responses result: {results}") + assert len(results) == 1 + + def process_request_streaming(): + ret = client.submit(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=10), + streaming=True, + __rpc_need_response=False) + assert ret is None + + print("call fetch_responses") + # NOTE: known issue, the responses should be fetched before shutdown, + # or the shutdown will hang. + results = [] + for i in range(3): + time.sleep(3) + results.extend(client.fetch_responses()) + print(f"fetch_responses result: {results}") + print(f"generate_async result: {results}") + + process_request() + process_request_streaming() + + print("call shutdown") + client.shutdown(__rpc_timeout=10) + pool.shutdown() + + +if __name__ == '__main__': + worker = TestRpcWorker() + worker.test_main() diff --git a/tests/unittest/executor/test_worker_base.py b/tests/unittest/executor/test_worker_base.py new file mode 100644 index 00000000000..d825d077d18 --- /dev/null +++ b/tests/unittest/executor/test_worker_base.py @@ -0,0 +1,95 @@ +import os +import sys +import time + +# isort: off +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") +from utils.llm_data import llm_models_root +from tensorrt_llm.bindings import executor as tllm +# isort: on + +from tensorrt_llm._torch.pyexecutor.config import update_executor_config +from tensorrt_llm.executor.request import GenerationRequest +from tensorrt_llm.executor.worker_base import WorkerBase +from tensorrt_llm.llmapi.llm_args import LlmArgs +from tensorrt_llm.sampling_params import SamplingParams + +default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" +model_path = llm_models_root() / default_model_name + + +class TestWorkerBase: + + class FakeWorker(WorkerBase): + + def __init__(self, engine: str): + super().__init__(engine=engine) + executor_config = TestWorkerBase.create_fake_executor_config(engine) + self.setup_engine(engine=engine, executor_config=executor_config) + + def test_create_engine(self): + with self.FakeWorker(engine=model_path) as worker: + print(f"Created engine: {worker.engine}") + + def test_submit_request(self): + sampling_params = SamplingParams(max_tokens=10) + request = GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=sampling_params) + with self.FakeWorker(engine=model_path) as worker: + print(f"Created engine: {worker.engine}") + worker.submit(request) + for i in range(10): + time.sleep(0.5) + worker.await_responses() + print(f"Submitted request: {request}") + time.sleep(6) + + def test_fetch_stats(self): + request = GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=10)) + with self.FakeWorker(engine=model_path) as worker: + worker.submit(request) + time.sleep(1) + worker.await_responses() + stats = worker.fetch_stats() + assert len(stats) > 0 + + def test_dispatch_stats_task(self): + request = GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=10)) + with self.FakeWorker(engine=model_path) as worker: + worker.submit(request) + worker.await_responses() + worker.dispatch_stats_task() + time.sleep(10) + stats = worker.fetch_stats() + assert len(stats) == 1 + + @staticmethod + def create_fake_executor_config(model_path): + llm_args = LlmArgs(model=model_path, cuda_graph_config=None) + + executor_config = tllm.ExecutorConfig(1) + executor_config.max_batch_size = 1 + + update_executor_config( + executor_config, + backend="pytorch", + pytorch_backend_config=llm_args.get_pytorch_backend_config(), + mapping=llm_args.parallel_config.to_mapping(), + speculative_config=llm_args.speculative_config, + hf_model_dir=model_path, + max_input_len=20, + max_seq_len=40, + checkpoint_format=llm_args.checkpoint_format, + checkpoint_loader=llm_args.checkpoint_loader, + ) + + return executor_config + + +if __name__ == "__main__": + test_worker_base = TestWorkerBase() + test_worker_base.test_fetch_stats() From b4ab04148e388b6312f2b8191b2ba4f51fa19f91 Mon Sep 17 00:00:00 2001 From: Superjomn <328693+Superjomn@users.noreply.github.com> Date: Sat, 2 Aug 2025 10:25:09 +0800 Subject: [PATCH 04/13] support client shutdown_server and cancel requests client.shutdown_server Once the server shutdown, all the pending and new requests will return RPCCancelled error. enhance rpc test_rpc.py pass add rpc proxy and worker enhancing rpc, some issue with shutdown mechanism partition rpc into multiple files Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> . address multiple asyncio.run timeout issue The RPCClient cannot recv the message since the event loop doesn't run Created a persist loop instead. Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/ipc.py | 46 ++ tensorrt_llm/executor/rpc.py | 473 -------------------- tensorrt_llm/executor/rpc/__init__.py | 9 + tensorrt_llm/executor/rpc/rpc_client.py | 300 +++++++++++++ tensorrt_llm/executor/rpc/rpc_common.py | 46 ++ tensorrt_llm/executor/rpc/rpc_server.py | 286 ++++++++++++ tensorrt_llm/executor/rpc_proxy.py | 115 +++-- tensorrt_llm/executor/rpc_worker.py | 54 ++- tensorrt_llm/executor/worker_base.py | 33 +- tests/unittest/executor/test_rpc.py | 472 +++++++++++++------ tests/unittest/executor/test_rpc_proxy.py | 52 +++ tests/unittest/executor/test_rpc_worker.py | 55 +-- tests/unittest/executor/test_worker_base.py | 62 ++- 13 files changed, 1289 insertions(+), 714 deletions(-) delete mode 100644 tensorrt_llm/executor/rpc.py create mode 100644 tensorrt_llm/executor/rpc/__init__.py create mode 100644 tensorrt_llm/executor/rpc/rpc_client.py create mode 100644 tensorrt_llm/executor/rpc/rpc_common.py create mode 100644 tensorrt_llm/executor/rpc/rpc_server.py create mode 100644 tests/unittest/executor/test_rpc_proxy.py diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py index d4318eb379e..00e9b4d336b 100644 --- a/tensorrt_llm/executor/ipc.py +++ b/tensorrt_llm/executor/ipc.py @@ -180,6 +180,52 @@ async def put_async(self, obj: Any): nvtx_mark("ipc.send", color="blue", category="IPC") + async def put_async_noblock(self, obj: Any): + self.setup_lazily() + try: + if self.use_hmac_encryption: + data = pickle.dumps(obj) # nosec B301 + signed_data = self._sign_data(data) + await self.socket.send(signed_data, flags=zmq.NOBLOCK) + else: + await self.socket.send_pyobj(obj, flags=zmq.NOBLOCK) + except Exception as e: + logger.error(f"Error sending object: {e}") + logger.error(traceback.format_exc()) + raise e + + async def put_async_with_timeout(self, obj: Any, timeout: float = 5.0): + """ + Send an object with timeout to detect connection failures. + + Args: + obj: The object to send + timeout: Timeout in seconds for the send operation + + Raises: + zmq.Again: If send operation times out (peer may be disconnected) + Exception: Other send errors + """ + self.setup_lazily() + try: + if self.use_hmac_encryption: + data = pickle.dumps(obj) # nosec B301 + signed_data = self._sign_data(data) + # Use asyncio.wait_for to implement timeout instead of zmq.NOBLOCK + await asyncio.wait_for(self.socket.send(signed_data), + timeout=timeout) + else: + await asyncio.wait_for(self.socket.send_pyobj(obj), + timeout=timeout) + except asyncio.TimeoutError: + # Convert timeout to zmq.Again to maintain compatibility with existing error handling + raise zmq.Again( + "Send operation timed out - peer may be disconnected") + except Exception as e: + logger.error(f"Error sending object: {e}") + logger.error(traceback.format_exc()) + raise e + def get(self) -> Any: self.setup_lazily() return self._recv_data() diff --git a/tensorrt_llm/executor/rpc.py b/tensorrt_llm/executor/rpc.py deleted file mode 100644 index 7c6bc0fcbde..00000000000 --- a/tensorrt_llm/executor/rpc.py +++ /dev/null @@ -1,473 +0,0 @@ -import asyncio -import concurrent.futures -import queue -import threading -import traceback -import uuid -from concurrent.futures import ThreadPoolExecutor -from typing import Any, NamedTuple, Optional - -from ..llmapi.utils import ManagedThread -from ..logger import logger -from .ipc import ZeroMqQueue - - -# --- Custom Exceptions --- -class RPCError(Exception): - """Custom exception for RPC-related errors raised on the client side.""" - - -class RPCTimeout(RPCError): - """Custom exception for when a client request times out.""" - - -class RPCRequest(NamedTuple): - request_id: str - method_name: str - args: tuple - kwargs: dict - need_response: bool = True - timeout: float = 0.5 - - -class RPCResponse(NamedTuple): - request_id: str - status: str - result: Any - - -class RPCServer: - """ - An RPC Server that listens for requests and executes them concurrently. - """ - - def __init__(self, - instance, - hmac_key=None, - num_workers: int = 1, - timeout: float = 0.5, - async_run_task: bool = False): - """ - Initializes the server with an instance. - - Args: - instance: The instance whose methods will be exposed via RPC. - hmac_key (bytes, optional): HMAC key for encryption. - num_workers (int): Number of worker threads. - timeout (int): Timeout for RPC calls. - async_run_task (bool): Whether to run the task asynchronously. - """ - self._instance = instance - self._hmac_key = hmac_key - self._num_workers = num_workers - self._address = None - self._timeout = timeout - self._client_socket = None - - # set the stop event to True, and all the workers will exit - self._stop_event = threading.Event() - - self._functions = {"shutdown": self.shutdown} - self._dispatcher_thread: Optional[ManagedThread] = None - if async_run_task: - self._executor = ThreadPoolExecutor(max_workers=num_workers, - thread_name_prefix="rpc_worker") - else: - self._executor = None - - self._queue = None - - # Automatically register the instance - self.register_instance(instance) - - logger.debug(f"RPC Server initialized with {num_workers} workers.") - - @property - def address(self) -> str: - assert self._client_socket is not None, "Client socket is not bound" - return self._client_socket.address[0] - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.shutdown() - - def bind(self, address="tcp://*:5555"): - """ - Bind the server to the specified address. - - Args: - address (str): The ZMQ address to bind the client-facing socket. - """ - self._address = address - self._client_socket = ZeroMqQueue(address=(address, self._hmac_key), - is_server=True, - is_async=True, - use_hmac_encryption=False) - logger.info(f"RPC Server bound to {self._address}") - - def shutdown(self): - """Internal method to trigger server shutdown.""" - logger.debug( - "RPC Server shutdown signal received. Terminating server...") - - if self._dispatcher_thread and self._dispatcher_thread.is_alive(): - self._stop_event.set() - self._dispatcher_thread.join() - self._dispatcher_thread = None - - if self._executor: - self._executor.shutdown(wait=False) - - if self._client_socket: - self._client_socket.close() - - self._client_socket = None - self._queue = None - - def register_function(self, func, name=None): - """Exposes a single function to clients.""" - fname = name or func.__name__ - if fname in self._functions: - logger.warning( - f"Function '{fname}' is already registered. Overwriting.") - self._functions[fname] = func - logger.debug(f"Registered function: {fname}") - - def register_instance(self, instance): - """Exposes all public methods of a class instance.""" - logger.debug( - f"Registering instance of class: {instance.__class__.__name__}") - for name in dir(instance): - if not name.startswith('_'): - attr = getattr(instance, name) - if callable(attr): - self.register_function(attr, name) - - async def _dispatcher_routine(self, stop_event: threading.Event): - assert self._client_socket is not None, "Client socket is not bound" - assert self._queue is not None, "RPC queue is not initialized" - - while not stop_event.is_set(): - try: - req: RPCRequest = await self._client_socket.get_async_noblock( - timeout=0.5) - logger.debug(f"RPC dispatcher got request: {req}") - except asyncio.TimeoutError: - logger.debug("RPC dispatcher get request timeout") - continue - - await self._queue.put(req) # type: ignore - - async def _worker_routine(self, stop_event: threading.Event): - """The routine executed by each worker thread.""" - assert self._client_socket is not None, "Client socket is not bound" - assert self._queue is not None, "RPC queue is not initialized" - - while not stop_event.is_set(): - try: - req: RPCRequest = await asyncio.wait_for( - self._queue.get(), # type: ignore - timeout=self._timeout) - except asyncio.TimeoutError: - logger.debug("RPC worker get request timeout") - continue - - if req.method_name in self._functions: - try: - if self._executor is not None: - # Dispatch to worker thread and await result with timeout - loop = asyncio.get_running_loop() - - # Create a wrapper function to handle keyword arguments - def call_with_kwargs(): - return self._functions[req.method_name]( - *req.args, **req.kwargs) - - result = await asyncio.wait_for(loop.run_in_executor( - self._executor, call_with_kwargs), - timeout=req.timeout) - else: - # For synchronous execution, we need to run in executor to support timeout - loop = asyncio.get_running_loop() - - # Create a wrapper function to handle keyword arguments - def call_with_kwargs(): - return self._functions[req.method_name]( - *req.args, **req.kwargs) - - result = await asyncio.wait_for(loop.run_in_executor( - None, call_with_kwargs), - timeout=req.timeout) - response = RPCResponse(req.request_id, 'OK', result) - except asyncio.TimeoutError: - response = RPCResponse( - req.request_id, 'ERROR', - f"Method '{req.method_name}' timed out after {req.timeout} seconds" - ) - except Exception: - tb = traceback.format_exc() - response = RPCResponse(req.request_id, 'ERROR', tb) - else: - response = RPCResponse( - req.request_id, 'ERROR', - f"Method '{req.method_name}' not found.") - - # Some tasks don't need response, e.g. submit_request or shutdown - if req.need_response: - await self._client_socket.put_async(response) - - def start(self): - """Binds sockets, starts workers, and begins proxying messages.""" - if self._client_socket is None: - raise RuntimeError( - "Server must be bound to an address before starting. Call bind() first." - ) - - self._client_socket.setup_lazily() - logger.info(f"RPC Server started and listening on {self._address}") - - async def tasks(): - self._queue = asyncio.Queue() - await asyncio.gather( - self._dispatcher_routine(self._stop_event), *[ - self._worker_routine(self._stop_event) - for i in range(self._num_workers) - ]) - - def loop() -> bool: - asyncio.run(tasks()) - return True # ManagedThread - - error_queue = queue.Queue() - self._dispatcher_thread = ManagedThread(task=loop, - stop_event=self._stop_event, - name="rpc_dispatcher_thread", - error_queue=error_queue) - self._dispatcher_thread.start() - - logger.info("RPC Server has started.") - - -Server = RPCServer - - -class RPCClient: - """ - An RPC Client that connects to the RPCServer. - """ - - def __init__(self, - address: str, - hmac_key=None, - timeout: float = 10, - num_workers: int = 4): - ''' - Args: - address: The ZMQ address to connect to. - hmac_key: The HMAC key for encryption. - timeout: The timeout (seconds) for RPC calls. - ''' - self._address = address - self._timeout = timeout - self._client_socket = ZeroMqQueue(address=(address, hmac_key), - is_server=False, - is_async=True, - use_hmac_encryption=False) - self._pending_futures = {} - self._reader_task = None - self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=num_workers, thread_name_prefix="rpc_client") - logger.info(f"RPC Client initialized. Connected to {self._address}") - - def __del__(self): - """Cleanup executor when client is destroyed.""" - self.close() - - def close(self): - """Gracefully close the client, cleaning up background tasks.""" - if self._reader_task: - self._reader_task.cancel() - self._reader_task = None - if self._executor: - self._executor.shutdown(wait=False) - - async def _response_reader(self): - """Task to read responses from the socket and set results on futures.""" - - while True: - try: - response: RPCResponse = await self._client_socket.get_async() - future = self._pending_futures.get(response.request_id) - if future and not future.done(): - if response.status == 'OK': - future.set_result(response.result) - elif response.status == 'ERROR': - # TODO: Maybe keep the original Error type? - future.set_exception( - RPCError( - f"Server-side exception:\n{response.result}")) - else: - future.set_exception( - RPCError( - f"Unknown response status: {response.status}")) - self._pending_futures.pop(response.request_id, None) - - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Exception in RPC response reader: {e}") - # Propagate exception to all pending futures - for future in self._pending_futures.values(): - if not future.done(): - future.set_exception(e) - break - - await asyncio.sleep(0) - - self._reader_task = None - - async def _start_reader_if_needed(self): - if self._reader_task is None or self._reader_task.done(): - loop = asyncio.get_running_loop() - self._reader_task = loop.create_task(self._response_reader()) - - async def _call_async(self, name, *args, **kwargs): - """Async version of RPC call. - Args: - name: Method name to call - *args: Positional arguments - **kwargs: Keyword arguments - __rpc_timeout: The timeout (seconds) for the RPC call. - __rpc_need_response: Whether the RPC call needs a response. - If set to False, the remote call will return immediately. - - Returns: - The result of the remote method call - """ - logger.debug( - f"RPC client calling method: {name} with args: {args} and kwargs: {kwargs}" - ) - await self._start_reader_if_needed() - need_response = kwargs.pop("__rpc_need_response", True) - timeout = kwargs.pop("__rpc_timeout", self._timeout) - - request_id = uuid.uuid4().hex - logger.debug(f"RPC client sending request: {request_id}") - request = RPCRequest(request_id, - name, - args, - kwargs, - need_response, - timeout=timeout) - logger.debug(f"RPC client sending request: {request}") - await self._client_socket.put_async(request) - - if not need_response: - return None - - loop = asyncio.get_running_loop() - future = loop.create_future() - self._pending_futures[request_id] = future - - try: - # If timeout, the remote call should return a timeout error timely, - # so we add 1 second to the timeout to ensure the client can get - # that result. - return await asyncio.wait_for(future, timeout + 1) - except asyncio.TimeoutError: - raise RPCTimeout(f"Request '{name}' timed out after {timeout}s") - finally: - self._pending_futures.pop(request_id, None) - - def _call_sync(self, name, *args, **kwargs): - """Synchronous version of RPC call.""" - return asyncio.run(self._call_async(name, *args, **kwargs)) - - def call_async(self, name: str, *args, **kwargs): - """ - Call a remote method asynchronously. - - Args: - name: Method name to call - *args: Positional arguments - **kwargs: Keyword arguments - - Returns: - Coroutine that can be awaited - - Example: - result = await client.call_async('remote_method', arg1, arg2, key=value) - """ - return self._call_async(name, *args, **kwargs, __rpc_need_response=True) - - def call_future(self, name: str, *args, - **kwargs) -> concurrent.futures.Future: - """ - Call a remote method and return a Future. - - Args: - name: Method name to call - *args: Positional arguments - **kwargs: Keyword arguments - - Returns: - A Future object that can be used to retrieve the result - - Example: - future = client.call_future('remote_method', arg1, arg2, key=value) - result = future.result() # blocks until complete - # or - future.add_done_callback(lambda f: print(f.result())) - """ - - def _async_to_sync(): - return asyncio.run(self._call_async(name, *args, **kwargs)) - - return self._executor.submit(_async_to_sync) - - def call_sync(self, name: str, *args, **kwargs): - """ - Call a remote method synchronously (blocking). - - Args: - name: Method name to call - *args: Positional arguments - **kwargs: Keyword arguments - - Returns: - The result of the remote method call - - Example: - result = client.call_sync('remote_method', arg1, arg2, key=value) - """ - return self._call_sync(name, *args, **kwargs) - - def __getattr__(self, name): - """ - Magically handles calls to non-existent methods. - Returns a proxy object that supports multiple calling patterns. - """ - - class MethodProxy: - - def __init__(self, client, method_name): - self.client = client - self.method_name = method_name - - def __call__(self, *args, **kwargs): - """Default synchronous call""" - return self.client._call_sync(self.method_name, *args, **kwargs) - - def call_async(self, *args, **kwargs): - """Async call - returns coroutine""" - return self.client._call_async(self.method_name, *args, - **kwargs) - - def call_future(self, *args, **kwargs) -> concurrent.futures.Future: - """Future call - returns Future object""" - return self.client.call_future(self.method_name, *args, - **kwargs) - - return MethodProxy(self, name) diff --git a/tensorrt_llm/executor/rpc/__init__.py b/tensorrt_llm/executor/rpc/__init__.py new file mode 100644 index 00000000000..38c4924e1ab --- /dev/null +++ b/tensorrt_llm/executor/rpc/__init__.py @@ -0,0 +1,9 @@ +from .rpc_client import RPCClient +from .rpc_common import (RPCCancelled, RPCError, RPCRequest, RPCResponse, + RPCTimeout) +from .rpc_server import RPCServer, Server + +__all__ = [ + "RPCClient", "RPCServer", "Server", "RPCError", "RPCTimeout", + "RPCCancelled", "RPCRequest", "RPCResponse" +] diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py new file mode 100644 index 00000000000..35ac2d42366 --- /dev/null +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -0,0 +1,300 @@ +import asyncio +import concurrent.futures +import threading +import uuid + +from ...logger import logger +from ..ipc import ZeroMqQueue +from .rpc_common import RPCCancelled, RPCRequest, RPCResponse, RPCTimeout + + +class RPCClient: + """ + An RPC Client that connects to the RPCServer. + """ + + def __init__(self, + address: str, + hmac_key=None, + timeout: float = 10, + num_workers: int = 4): + ''' + Args: + address: The ZMQ address to connect to. + hmac_key: The HMAC key for encryption. + timeout: The timeout (seconds) for RPC calls. + ''' + self._address = address + self._timeout = timeout + self._client_socket = ZeroMqQueue(address=(address, hmac_key), + is_server=False, + is_async=True, + use_hmac_encryption=False) + self._pending_futures = {} + self._reader_task = None + self._executor = concurrent.futures.ThreadPoolExecutor( + max_workers=num_workers, thread_name_prefix="rpc_client") + + self._server_stopped = False + self._loop = None + self._loop_thread = None + + logger.debug(f"RPC Client initialized. Connected to {self._address}") + + def shutdown_server(self): + """Shutdown the server.""" + if self._server_stopped: + return + + self.call_sync("__rpc_shutdown") + + self._server_stopped = True + + def close(self): + """Gracefully close the client, cleaning up background tasks.""" + if self._reader_task: + self._reader_task.cancel() + self._reader_task = None + if self._loop and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + if self._loop_thread: + self._loop_thread.join() + self._loop_thread = None + if self._executor: + self._executor.shutdown(wait=True) + + async def _response_reader(self): + """Task to read responses from the socket and set results on futures.""" + + while True: + try: + response: RPCResponse = await self._client_socket.get_async() + logger.debug(f"RPC Client received response: {response}") + future = self._pending_futures.get(response.request_id) + if future and not future.done(): + if response.error is None: + future.set_result(response.result) + else: + # Use the original RPCError from the response + future.set_exception(response.error) + self._pending_futures.pop(response.request_id, None) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Exception in RPC response reader: {e}") + # Propagate exception to all pending futures + for future in self._pending_futures.values(): + if not future.done(): + future.set_exception(e) + break + + await asyncio.sleep(0) + + self._reader_task = None + + def _start_response_reader_lazily(self): + if self._reader_task is None or self._reader_task.done(): + # Ensure we have a persistent background loop + self._ensure_event_loop() + # Always create the reader task on the persistent loop + future = asyncio.run_coroutine_threadsafe(self._response_reader(), + self._loop) + # Store the concurrent.futures.Future + self._reader_task = future + + async def _call_async(self, __rpc_method_name, *args, **kwargs): + """Async version of RPC call. + Args: + __rpc_method_name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + __rpc_timeout: The timeout (seconds) for the RPC call. + __rpc_need_response: Whether the RPC call needs a response. + If set to False, the remote call will return immediately. + + Returns: + The result of the remote method call + """ + logger.debug( + f"RPC client calling method: {__rpc_method_name} with args: {args} and kwargs: {kwargs}" + ) + if self._server_stopped: + raise RPCCancelled("Server is shutting down, request cancelled") + + self._start_response_reader_lazily() + need_response = kwargs.pop("__rpc_need_response", True) + timeout = kwargs.pop("__rpc_timeout", self._timeout) + + request_id = uuid.uuid4().hex + logger.debug(f"RPC client sending request: {request_id}") + request = RPCRequest(request_id, + __rpc_method_name, + args, + kwargs, + need_response, + timeout=timeout) + logger.debug(f"RPC client sending request: {request}") + await self._client_socket.put_async(request) + + if not need_response: + return None + + loop = asyncio.get_running_loop() + future = loop.create_future() + self._pending_futures[request_id] = future + + try: + # If timeout, the remote call should return a timeout error timely, + # so we add 1 second to the timeout to ensure the client can get + # that result. + res = await asyncio.wait_for(future, timeout + 1) + return res + except RPCCancelled: + self._server_stopped = True + raise + except asyncio.TimeoutError: + raise RPCTimeout( + f"Request '{__rpc_method_name}' timed out after {timeout}s") + except Exception as e: + raise e + finally: + self._pending_futures.pop(request_id, None) + + def _ensure_event_loop(self): + """Ensure we have a running event loop in a background thread.""" + if self._loop is None or not self._loop.is_running(): + self._loop = asyncio.new_event_loop() + + def run_loop(): + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + self._loop_thread = threading.Thread(target=run_loop, daemon=True) + self._loop_thread.start() + + # Give the loop a moment to start + import time + time.sleep(0.1) + + def _call_sync(self, __rpc_method_name, *args, **kwargs): + """Synchronous version of RPC call.""" + self._ensure_event_loop() + future = asyncio.run_coroutine_threadsafe( + self._call_async(__rpc_method_name, *args, **kwargs), self._loop) + return future.result() + + def call_async(self, name: str, *args, **kwargs): + """ + Call a remote method asynchronously. + + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Coroutine that can be awaited + + Example: + result = await client.call_async('remote_method', arg1, arg2, key=value) + """ + return self._call_async(name, *args, **kwargs, __rpc_need_response=True) + + def call_future(self, name: str, *args, + **kwargs) -> concurrent.futures.Future: + """ + Call a remote method and return a Future. + + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + A Future object that can be used to retrieve the result + + Example: + future = client.call_future('remote_method', arg1, arg2, key=value) + result = future.result() # blocks until complete + # or + future.add_done_callback(lambda f: print(f.result())) + """ + + def _async_to_sync(): + self._ensure_event_loop() + future = asyncio.run_coroutine_threadsafe( + self._call_async(name, *args, **kwargs), self._loop) + return future.result() + + return self._executor.submit(_async_to_sync) + + def call_sync(self, name: str, *args, **kwargs): + """ + Call a remote method synchronously (blocking). + + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + The result of the remote method call + + Example: + result = client.call_sync('remote_method', arg1, arg2, key=value) + """ + return self._call_sync(name, *args, **kwargs) + + def get_server_attr(self, name: str): + """ Get the attribute of the RPC server. + This is mainly used for testing. """ + return self._call_sync("__rpc_get_attr", name, __rpc_timeout=10) + + def __getattr__(self, name): + """ + Magically handles calls to non-existent methods. + Returns a proxy object that supports multiple calling patterns. + """ + + class MethodProxy: + + def __init__(self, client, method_name): + self.client = client + self.method_name = method_name + + def __call__(self, *args, **kwargs): + """Default synchronous call""" + mode = kwargs.pop("__rpc_mode", "sync") + if mode == "sync": + return self.client._call_sync(self.method_name, *args, + **kwargs) + elif mode == "async": + return self.client._call_async(self.method_name, *args, + **kwargs) + elif mode == "future": + return self.client.call_future(self.method_name, *args, + **kwargs) + else: + raise ValueError(f"Invalid RPC mode: {mode}") + + def call_async(self, *args, **kwargs): + """Async call - returns coroutine""" + return self.client._call_async(self.method_name, *args, + **kwargs) + + def call_future(self, *args, **kwargs) -> concurrent.futures.Future: + """Future call - returns Future object""" + return self.client.call_future(self.method_name, *args, + **kwargs) + + return MethodProxy(self, name) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() diff --git a/tensorrt_llm/executor/rpc/rpc_common.py b/tensorrt_llm/executor/rpc/rpc_common.py new file mode 100644 index 00000000000..22b85097555 --- /dev/null +++ b/tensorrt_llm/executor/rpc/rpc_common.py @@ -0,0 +1,46 @@ +from typing import Any, NamedTuple, Optional + + +# --- Custom Exceptions --- +class RPCError(Exception): + """Custom exception for RPC-related errors raised on the client side. + + Args: + message: The error message. + cause: The original exception that caused this error. + traceback: The traceback of the exception. + """ + + def __init__(self, + message: str, + cause: Optional[Exception] = None, + traceback: Optional[str] = None): + super().__init__(message) + self.cause = cause + self.traceback = traceback + + +class RPCTimeout(RPCError): + """Exception for when a request processing times out.""" + + +class RPCCancelled(RPCError): + """Exception for when a client request is cancelled. + This happens when the server is shutting down and all the pending + requests will be cancelled and return with this error. + """ + + +class RPCRequest(NamedTuple): + request_id: str + method_name: str + args: tuple + kwargs: dict + need_response: bool = True + timeout: float = 0.5 + + +class RPCResponse(NamedTuple): + request_id: str + result: Any + error: Optional[RPCError] = None diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py new file mode 100644 index 00000000000..b0d1377569e --- /dev/null +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -0,0 +1,286 @@ +import asyncio +import queue +import threading +import time +import traceback +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +from ...llmapi.utils import ManagedThread +from ...logger import logger +from ..ipc import ZeroMqQueue +from .rpc_common import RPCError, RPCRequest, RPCResponse, RPCTimeout + + +class RPCServer: + """ + An RPC Server that listens for requests and executes them concurrently. + """ + + def __init__(self, + instance, + hmac_key=None, + num_workers: int = 1, + timeout: float = 0.5, + async_run_task: bool = False): + """ + Initializes the server with an instance. + + Args: + instance: The instance whose methods will be exposed via RPC. + hmac_key (bytes, optional): HMAC key for encryption. + num_workers (int): Number of worker threads. + timeout (int): Timeout for RPC calls. + async_run_task (bool): Whether to run the task asynchronously. + """ + self._instance = instance + self._hmac_key = hmac_key + self._num_workers = num_workers + self._address = None + self._timeout = timeout + self._client_socket = None + + # set the stop event to True, and all the workers will exit + self._stop_event = threading.Event() + + self._num_pending_requests = 0 + + self._functions = { + "__rpc_shutdown": lambda: self.shutdown(is_remote_call=True), + "__rpc_get_attr": lambda name: self.get_attr(name), + } + self._dispatcher_thread: Optional[ManagedThread] = None + if async_run_task: + self._executor = ThreadPoolExecutor(max_workers=num_workers, + thread_name_prefix="rpc_worker") + else: + self._executor = None + + self._queue = None + + # Automatically register the instance + self.register_instance(instance) + + logger.debug(f"RPC Server initialized with {num_workers} workers.") + + @property + def address(self) -> str: + assert self._client_socket is not None, "Client socket is not bound" + return self._client_socket.address[0] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + + def bind(self, address="tcp://*:5555"): + """ + Bind the server to the specified address. + + Args: + address (str): The ZMQ address to bind the client-facing socket. + """ + self._address = address + self._client_socket = ZeroMqQueue(address=(address, self._hmac_key), + is_server=True, + is_async=True, + use_hmac_encryption=False) + logger.info(f"RPC Server bound to {self._address}") + + def shutdown(self, is_remote_call: bool = False): + """Internal method to trigger server shutdown. + + Args: + is_remote_call: Whether the shutdown is called by a remote call. + This should be True when client.server_shutdown() is called. + """ + # NOTE: shutdown is also a remote method, so it could be executed by + # a thread in a worker executor thread + + if self._stop_event.is_set(): + return + + logger.debug( + "RPC Server shutdown signal received. Terminating server...") + + # Set the stop event to True, this will trigger the dispatcher routine and + # the worker routine to prepare for exit, like stopping accepting new requests, + # and continue to process the pending requests. + self._stop_event.set() + + # The worker routine should process the pending requests + logger.debug( + f"RPC Server shutdown: {self._num_pending_requests} pending requests" + ) + while self._num_pending_requests > 0: + time.sleep(0.01) + logger.debug(f"RPC Server shutdown finished pending requests") + + if not is_remote_call: + # Block the thread until shutdown is finished + + # 1. Wait for the dispatcher thread to exit, so that no new requests are accepted + logger.debug(f"RPC Server dispatcher thread joining") + if self._dispatcher_thread: + self._dispatcher_thread.join() + self._dispatcher_thread = None + logger.debug(f"RPC Server dispatcher thread joined") + + # 2. Wait for the executor to exit, it will wait for the pending requests to be processed + if self._executor: + self._executor.shutdown(wait=True) + self._executor = None + + # 3. (Optionally) Close the client socket, this doesn't affect + # anything since zmq client will not timeout even if the target is not available + if self._client_socket: + self._client_socket.close() + else: + # if the shutdown is called by a remote call, this method itself will + # be executed in a executor thread, so we cannot join the dispatcher thread as + # the dispatcher thread is awaiting for the shutdown result. + logger.debug( + f"RPC Server to shutdown: {self._num_pending_requests} pending requests" + ) + + while self._num_pending_requests > 0: + time.sleep(0.01) + logger.debug(f"RPC Server shutdown finished pending requests") + + def register_function(self, func, name=None): + """Exposes a single function to clients.""" + fname = name or func.__name__ + if fname in self._functions: + logger.warning( + f"Function '{fname}' is already registered. Overwriting.") + self._functions[fname] = func + logger.debug(f"Registered function: {fname}") + + def register_instance(self, instance): + """Exposes all public methods of a class instance.""" + logger.debug( + f"Registering instance of class: {instance.__class__.__name__}") + for name in dir(instance): + if not name.startswith('_'): + attr = getattr(instance, name) + if callable(attr): + self.register_function(attr, name) + + def get_attr(self, name: str): + """ Get the attribute of the RPC server. + This is mainly used for testing. """ + return getattr(self, name) + + async def _dispatcher_routine(self, stop_event: threading.Event): + assert self._client_socket is not None, "Client socket is not bound" + assert self._queue is not None, "RPC queue is not initialized" + + # Once shutdown, the dispatcher will exit first, and the workers will + # continue to process the pending requests. + while not stop_event.is_set(): + try: + req: RPCRequest = await self._client_socket.get_async_noblock( + timeout=0.5) + logger.debug(f"RPC dispatcher got request: {req}") + except asyncio.TimeoutError: + await asyncio.sleep(0) + continue + + await self._queue.put(req) # type: ignore + + # shutdown is a builtin method depends on _num_pending_requests, so + # it should not be counted + if req.method_name != "__rpc_shutdown": + self._num_pending_requests += 1 + + async def _worker_routine(self, stop_event: threading.Event): + """The routine executed by each worker thread.""" + assert self._client_socket is not None, "Client socket is not bound" + assert self._queue is not None, "RPC queue is not initialized" + + while (not stop_event.is_set()) or self._num_pending_requests > 0: + try: + req: RPCRequest = await asyncio.wait_for( + self._queue.get(), # type: ignore + timeout=self._timeout) + except asyncio.TimeoutError: + await asyncio.sleep(0) + continue + + response = await self._process_request(req) + + # Some tasks don't need response, e.g. submit_request or shutdown + if req.need_response: + logger.debug(f"RPC Server sending response for request {req}") + await self._client_socket.put_async(response) + logger.debug(f"RPC Server sent response for request {req}") + + self._num_pending_requests -= 1 + + async def _process_request(self, req: RPCRequest) -> RPCResponse: + if req.method_name not in self._functions: + return RPCResponse( + req.request_id, None, + RPCError(f"Method '{req.method_name}' not found in RPC server.", + traceback=traceback.format_exc())) + + try: + loop = asyncio.get_running_loop() + + def call_with_kwargs(): + return self._functions[req.method_name](*req.args, **req.kwargs) + + result = await asyncio.wait_for(loop.run_in_executor( + self._executor, call_with_kwargs), + timeout=req.timeout) + logger.debug(f"RPC Server returned result for request {req}") + response = RPCResponse(req.request_id, result) + + except asyncio.TimeoutError: + response = RPCResponse( + req.request_id, None, + RPCTimeout( + f"Method '{req.method_name}' timed out after {req.timeout} seconds", + traceback=traceback.format_exc())) + + except Exception as e: + response = RPCResponse( + req.request_id, None, + RPCError(str(e), cause=e, traceback=traceback.format_exc())) + + return response + + def start(self): + """Binds sockets, starts workers, and begins proxying messages.""" + if self._client_socket is None: + raise RuntimeError( + "Server must be bound to an address before starting. Call bind() first." + ) + + self._client_socket.setup_lazily() + logger.info(f"RPC Server started and listening on {self._address}") + + async def tasks(): + self._queue = asyncio.Queue() + await asyncio.gather( + self._dispatcher_routine(self._stop_event), *[ + self._worker_routine(self._stop_event) + for i in range(self._num_workers) + ]) + + def loop() -> bool: + asyncio.run(tasks()) + return True # ManagedThread + + error_queue = queue.Queue() + self._dispatcher_thread = ManagedThread(task=loop, + stop_event=self._stop_event, + name="rpc_dispatcher_thread", + error_queue=error_queue) + self._dispatcher_thread.start() + + logger.info("RPC Server has started.") + + +Server = RPCServer diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 6f120a9a3fd..ea637e5d6c7 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -6,12 +6,15 @@ from ..llmapi.mpi_session import MpiPoolSession, MpiSession from ..llmapi.tracer import global_tracer -from ..llmapi.utils import _SyncQueue, print_colored_debug +from ..llmapi.utils import (_SyncQueue, print_colored_debug, + print_traceback_on_error) +from ..logger import logger from .executor import GenerationExecutor from .postproc_worker import PostprocWorkerConfig from .request import GenerationRequest from .result import GenerationResult from .rpc import RPCClient +from .rpc_worker import RpcWorker from .utils import (ErrorResponse, create_mpi_comm_session, get_spawn_proxy_process_env, is_llm_response) @@ -39,6 +42,7 @@ def __init__(self, garbage_collection_gen0_threshold: the garbage collection gen0 threshold clock_unit: the unit of the clock, 1 means 1 second """ + self.clock_unit = clock_unit GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1 self.rpc_addr = self.gen_uniq_rpc_addr() @@ -55,26 +59,30 @@ def __init__(self, is_llm_executor=is_llm_executor, ) - self.mpi_session = self._create_mpi_session(model_world_size, - mpi_session) + self._results = {} + + self._create_mpi_session(model_world_size, mpi_session) self._shutdown_event = threading.Event() + self.worker_kwargs = worker_kwargs self.launch_workers() time.sleep(1) # wait for the workers to launch # Invoke model creation on the remote # TBD: Move model creation to the mpi task, or left in RPC? - self.create_engine_remote() + self.setup_engine_remote() self.setup_mainloop() def launch_workers(self): + logger.debug(f"Launching workers") assert self.mpi_session is not None - self.mpi_session.submit(rpc_worker_main, + self.mpi_session.submit(RpcWorker.main_task, rpc_addr=self.rpc_addr, **self.worker_kwargs) + @print_traceback_on_error def main_loop_task(self): """ Main loop of the proxy, it will invoke the actions periodically. @@ -82,10 +90,10 @@ def main_loop_task(self): clock = 0 while not self._shutdown_event.is_set(): if clock % 1 == 0: - responses = self.await_responses_remote() + responses = self.fetch_responses_remote() self.handle_responses(responses) if clock % 10 == 0: - stats = self.get_stats_remote() # TODO + stats = self.fetch_stats_remote() # TODO self.handle_stats(stats) clock += 1 @@ -101,23 +109,24 @@ def handle_responses(self, responses: list[GenerationResult]) -> bool: async_queues = [] event_loop = None - def process_res(res): - client_id = res.client_id - nonlocal event_loop - nonlocal async_queues - - queue = self._results[client_id].queue - if isinstance(queue, _SyncQueue): - queue.put_nowait(res) - async_queues.append(queue) - # all the loops are identical - event_loop = event_loop or queue.loop - else: - queue.put(res) - - if (is_llm_response(res) and res.result.is_final) or isinstance( - res, ErrorResponse): - self._results.pop(client_id) + def process_res(res: list): + for r in res: + client_id = r.client_id + nonlocal event_loop + nonlocal async_queues + + queue = self._results[client_id].queue + if isinstance(queue, _SyncQueue): + queue.put_nowait(r) + async_queues.append(queue) + # all the loops are identical + event_loop = event_loop or queue.loop + else: + queue.put(r) + + if (is_llm_response(r) and r.result.is_final) or isinstance( + r, ErrorResponse): + self._results.pop(client_id) for res in responses: global_tracer().log_instant("RPC.get") @@ -127,20 +136,64 @@ def process_res(res): _SyncQueue.notify_many(event_loop, async_queues) def handle_stats(self, stats: dict): - raise NotImplementedError + # raise NotImplementedError + pass def submit(self, request: GenerationRequest) -> GenerationResult: + request.set_id(self._get_next_client_id()) + logprob_params = self._get_logprob_params(request) + # submit is a fire-and-forget operation, don't need to wait for response - return self.rpc_client.submit(request, need_response=False) + self.rpc_client.submit(request, __rpc_need_response=False) - def await_responses_remote(self): - return self.rpc_client.await_responses() + result = GenerationResult( + request, + background_error_handler=self._handle_background_error, + executor=self, + disaggregated_params=request.disaggregated_params, + logprob_params=logprob_params) + self._results[request.id] = result - def create_engine_remote(self): - return self.rpc_client.create_engine() # TODO + return result + + def fetch_responses_remote(self): + return self.rpc_client.fetch_responses(__rpc_timeout=20) + + def fetch_stats_remote(self): + return self.rpc_client.fetch_stats() + + def setup_engine_remote(self): + return self.rpc_client.setup_engine(__rpc_timeout=60 * 20) # 20 min def shutdown_remote(self): - self.rpc_client.shutdown() + self.rpc_client.shutdown(__rpc_timeout=60 * 20) # 20 min + + def abort_request(self, request_id: int) -> None: + return self.rpc_client.abort_request(request_id) + + def shutdown(self): + if self._shutdown_event.is_set(): + return + + # 1. stop the main loop, so that no new rpc requests + self._shutdown_event.set() + self.main_loop_thread.join() + + # 2. shutdown the rpc server (PyExecutor Rank 0 + RPC server) + self.shutdown_remote() + + # 3. shutdown the mpi session, this should wait until all the PyExecutor + # processes are shutdown + if self.mpi_session is not None: + self.mpi_session.shutdown() + + self.rpc_client.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() def _create_mpi_session(self, model_world_size: int, mpi_session: Optional[MpiSession]): diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 6c748428229..8ff41a7fc34 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -3,10 +3,12 @@ from threading import Event from typing import Optional, Union +from tensorrt_llm.llmapi.utils import enable_llm_debug + from .._utils import mpi_rank from ..bindings import executor as tllm from ..builder import Engine -from ..logger import logger +from ..logger import logger, set_level from ..lora_manager import LoraConfig from ..sampling_params import BatchedLogitsProcessor from .postproc_worker import PostprocWorkerConfig @@ -31,18 +33,28 @@ def __init__( engine: Union[Path, Engine], executor_config: Optional[tllm.ExecutorConfig] = None, is_llm_executor: Optional[bool] = None, + lora_config: Optional[LoraConfig] = None, + garbage_collection_gen0_threshold: Optional[int] = None, ) -> None: - super().__init__(engine=engine, - executor_config=executor_config, - is_llm_executor=is_llm_executor) + super().__init__( + engine=engine, + executor_config=executor_config, + is_llm_executor=is_llm_executor, + lora_config=lora_config, + garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) self.shutdown_event = Event() self._response_queue = Queue() self.set_result_queue(self._response_queue) + def fetch_stats(self) -> list: + return super().fetch_stats() + def fetch_responses(self) -> list: - logger.debug(f"RPC worker {mpi_rank()} is fetching responses") + logger.debug(f"RpcWorker {mpi_rank()} is fetching responses") + # NOTE: This is a blocking call, it will wait for the responses to be available. super().await_responses() + logger.debug(f"RpcWorker returning responses") qsize = self._response_queue.qsize() return [self._response_queue.get() for _ in range(qsize)] @@ -62,31 +74,43 @@ def main_task( is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, + **kwargs, ) -> None: + if enable_llm_debug(): + set_level("debug") + # Step 1: Create the worker instance - worker = RpcWorker(engine=engine, executor_config=executor_config) + worker = RpcWorker( + engine=engine, + executor_config=executor_config, + is_llm_executor=is_llm_executor, + lora_config=lora_config, + garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) if mpi_rank() != 0: logger.debug(f"Worker {mpi_rank()} is setting up the engine") # The non-leader worker will setup the engine immediately. # The leader worker will wait for the RPC call to propagate the # potential error. - worker.setup_engine( - engine=engine, - executor_config=executor_config, - batched_logits_processor=batched_logits_processor, - postproc_worker_config=postproc_worker_config, - is_llm_executor=is_llm_executor, - lora_config=lora_config, - garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + logger.debug(f"Worker {mpi_rank()} is setting up the engine") + worker.setup_engine() if mpi_rank() == 0: + logger.debug(f"Worker {mpi_rank()} is creating the RPC service") # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client rpc_server = RPCServer(worker) rpc_server.bind(rpc_addr) rpc_server.start() # Step 3: Wait for the worker to shutdown + logger.debug( + f"Worker {mpi_rank()} is waiting for the worker to shutdown") worker.shutdown_event.wait() rpc_server.shutdown() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() + return True diff --git a/tensorrt_llm/executor/worker_base.py b/tensorrt_llm/executor/worker_base.py index ef3494816b7..8353129d2d9 100644 --- a/tensorrt_llm/executor/worker_base.py +++ b/tensorrt_llm/executor/worker_base.py @@ -55,9 +55,16 @@ def __init__( engine: Union[Path, Engine], executor_config: Optional[tllm.ExecutorConfig] = None, is_llm_executor: Optional[bool] = None, + lora_config: Optional[LoraConfig] = None, + garbage_collection_gen0_threshold: Optional[int] = None, ) -> None: super().__init__(is_llm_executor=is_llm_executor) + # Persist constructor arguments for deferred setup + self._engine_input = engine + self._lora_config = lora_config + self._garbage_collection_gen0_threshold = garbage_collection_gen0_threshold + self.engine = None self.rank = mpi_rank() self.global_rank = global_mpi_rank() @@ -84,12 +91,7 @@ def __init__( self._prompt_adapter_manager: Optional[PromptAdapterManager] = None self._runtime_model_config: Optional[ModelConfig] = None - def setup_engine( - self, - engine: Union[Path, Engine], - executor_config: Optional[tllm.ExecutorConfig] = None, - lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None) -> None: + def setup_engine(self) -> None: device_id = self.global_rank % torch.cuda.device_count() torch.cuda.set_device(device_id) @@ -99,12 +101,15 @@ def setup_engine( comm_ranks = mpi_comm().allgather(global_rank) device_ids = mpi_comm().allgather(device_id) + executor_config = self._executor_config if executor_config is None: executor_config = tllm.ExecutorConfig(1) + self._executor_config = executor_config executor_config.parallel_config = tllm.ParallelConfig( participant_ids=comm_ranks, device_ids=device_ids) + engine = self._engine_input if isinstance(engine, list): engine = engine[self.rank] @@ -127,9 +132,9 @@ def setup_engine( from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ create_py_executor create_executor = create_py_executor - args["lora_config"] = lora_config + args["lora_config"] = self._lora_config args[ - "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold + "garbage_collection_gen0_threshold"] = self._garbage_collection_gen0_threshold elif executor_config.backend == "_autodeploy": from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ create_autodeploy_executor @@ -139,7 +144,7 @@ def setup_engine( f"Unsupported backend config: {executor_config.backend}") self.engine = create_executor(**args) - self._setup_lora(engine, executor_config, lora_config) + self._setup_lora(engine, executor_config, self._lora_config) def _setup_lora(self, engine: Union[Path, Engine], executor_config: tllm.ExecutorConfig, @@ -406,8 +411,8 @@ def __exit__(self, exc_type, exc_value, traceback): self.shutdown() return True - def await_responses(self) -> None: - self._await_response_helper() + def await_responses(self, timeout: Optional[float] = None) -> None: + self._await_response_helper(timeout) logger.debug(f"worker done await_responses") def fetch_kv_cache_events(self) -> list: @@ -525,11 +530,11 @@ def responses_handler(self, responses: List[tllm.Response]): case _: raise NotImplementedError - def __call__(self) -> bool: + def __call__(self, timeout: Optional[float] = None) -> bool: ''' This method should be called by a ManagedThread. ''' logger.debug(f"await_response: {self.worker.engine}") - responses = self.worker.engine.await_responses( - timeout=datetime.timedelta(milliseconds=100)) + timeout = datetime.timedelta(seconds=timeout or 0.1) + responses = self.worker.engine.await_responses(timeout=timeout) logger.debug(f"PyExecutor returned {len(responses)} responses") # filter since The _engine_response_callback may return None diff --git a/tests/unittest/executor/test_rpc.py b/tests/unittest/executor/test_rpc.py index d00a4cc673b..b421a1d99a7 100644 --- a/tests/unittest/executor/test_rpc.py +++ b/tests/unittest/executor/test_rpc.py @@ -2,130 +2,287 @@ import pytest -from tensorrt_llm.executor.rpc import RPCClient, RPCError, RPCServer +from tensorrt_llm.executor.rpc import (RPCCancelled, RPCClient, RPCError, + RPCServer, RPCTimeout) -def test_rpc_server_basics(): +class RpcServerWrapper(RPCServer): - class App: + def __init__(self, *args, addr: str, **kwargs): + super().__init__(*args, **kwargs) + self.addr = addr - def hello(self): - print("hello") + def __enter__(self): + self.bind(self.addr) + self.start() + return self - server = RPCServer(App()) - print("bind") - server.bind("ipc:///tmp/rpc_test") - print("start") - server.start() - print("sleep") + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() - time.sleep(1) - print("shutdown") - server.shutdown() +class TestRpcBasics: -def test_rpc_client_context_manager(): + def test_rpc_server_basics(self): - class App: + class App: - def hello(self): - print("hello") + def hello(self): + print("hello") - with RPCServer(App()) as server: - server.bind("ipc:///tmp/rpc_test") - server.start() - time.sleep(1) + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + pass + def test_remote_call_without_arg(self): -def test_rpc_hello_without_arg(): + class App: - class App: + def hello(self): + print("hello") + return "world" - def hello(self): - print("hello") - return "world" + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + with RPCClient("ipc:///tmp/rpc_test") as client: + ret = client.hello() # sync call + assert ret == "world" - with RPCServer(App()) as server: - server.bind("ipc:///tmp/rpc_test") - server.start() - time.sleep(0.1) - client = RPCClient("ipc:///tmp/rpc_test") - ret = client.hello() # sync call - assert ret == "world" + def test_remote_call_with_args(self): + class App: -def test_rpc_hello_with_arg(): + def hello(self, name: str, location: str): + print("hello") + return f"hello {name} from {location}" - class App: + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + with RPCClient("ipc:///tmp/rpc_test") as client: + ret = client.hello("app", "Marvel") + assert ret == "hello app from Marvel" - def hello(self, name: str, location: str): - print("hello") - return f"hello {name} from {location}" + def test_remote_call_with_kwargs(self): - with RPCServer(App()) as server: - server.bind("ipc:///tmp/rpc_test") - server.start() - time.sleep(0.1) - client = RPCClient("ipc:///tmp/rpc_test") - ret = client.hello("app", location="Marvel") # sync call - assert ret == "hello app from Marvel" + class App: + def hello(self, name: str, location: str): + print("hello") + return f"hello {name} from {location}" -def test_rpc_server_address(): + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + with RPCClient("ipc:///tmp/rpc_test") as client: + ret = client.hello(name="app", location="Marvel") + assert ret == "hello app from Marvel" - class App: + def test_remote_call_with_args_and_kwargs(self): - def hello(self): - print("hello") - return "world" + class App: - with RPCServer(App()) as server: - server.bind("ipc:///tmp/rpc_test") - server.start() - time.sleep(0.1) - assert server.address == "ipc:///tmp/rpc_test" + def hello(self, name: str, location: str): + print("hello") + return f"hello {name} from {location}" + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + with RPCClient("ipc:///tmp/rpc_test") as client: + ret = client.hello(name="app", location="Marvel") + assert ret == "hello app from Marvel" -def test_rpc_with_error(): + def test_rpc_server_address(self): - class App: + class App: + pass - def hello(self): - raise ValueError("hello") + with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: + assert server.address == "ipc:///tmp/rpc_test" - with RPCServer(App()) as server: - server.bind("ipc:///tmp/rpc_test_error") + def test_rpc_with_error(self): + + class App: + + def hello(self): + raise ValueError("hello") + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_error") as server: + with RPCClient("ipc:///tmp/rpc_test_error") as client: + with pytest.raises(RPCError): + client.hello() + + def test_rpc_without_wait_response(self): + + class App: + + def __init__(self): + self.task_submitted = False + + def send_task(self) -> None: + # Just submit the task and return immediately + # The result is not important + self.task_submitted = True + return None + + def get_task_submitted(self) -> bool: + return self.task_submitted + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_no_wait") as server: + with RPCClient("ipc:///tmp/rpc_test_no_wait") as client: + client.send_task(__rpc_need_response=False) + time.sleep( + 0.1 + ) # wait for some time to make sure the task is submitted + assert client.get_task_submitted() + + +class TestRpcError: + + class CustomError(Exception): + pass + + def test_task_error(self): + """Test that server-side exceptions are properly wrapped in RPCError with details.""" + + class App: + + def hello(self): + raise ValueError("Test error message") + + def divide_by_zero(self): + return 1 / 0 + + def custom_exception(self): + raise TestRpcError.CustomError("Custom error occurred") + + with RPCServer(App()) as server: + server.bind("ipc:///tmp/rpc_test_error") + server.start() + time.sleep(0.1) + with RPCClient("ipc:///tmp/rpc_test_error") as client: + # Test ValueError handling + with pytest.raises(RPCError) as exc_info: + client.hello() + + error = exc_info.value + assert "Test error message" in str(error) + assert error.cause is not None + assert isinstance(error.cause, ValueError) + assert error.traceback is not None + assert "ValueError: Test error message" in error.traceback + + # Test ZeroDivisionError handling + with pytest.raises(RPCError) as exc_info: + client.divide_by_zero() + + error = exc_info.value + assert "division by zero" in str(error) + assert error.cause is not None + assert isinstance(error.cause, ZeroDivisionError) + assert error.traceback is not None + + # Test custom exception handling + with pytest.raises(RPCError) as exc_info: + client.custom_exception() + + error = exc_info.value + assert "Custom error occurred" in str(error) + assert error.cause is not None + assert error.traceback is not None + + def test_shutdown_cancelled_error(self): + """Test that pending requests are cancelled with RPCCancelled when server shuts down.""" + + class App: + + def task(self): + time.sleep(10) + return True + + addr = "ipc:///tmp/rpc_test_cancelled" + + server = RPCServer( + App(), + # only one worker to make it easier to pend requests + num_workers=1) + server.bind(addr) server.start() time.sleep(0.1) - client = RPCClient("ipc:///tmp/rpc_test_error") - with pytest.raises(RPCError): - client.hello() + with RPCClient(addr) as client: + client.shutdown_server() + pending_futures = [ + client.task(__rpc_mode="future") for _ in range(10) + ] -def test_rpc_without_wait_response(): + for future in pending_futures: + with pytest.raises(RPCCancelled): + future.result() - class App: + time.sleep(5) - def __init__(self): - self.task_submitted = False + client.close() - def send_task(self) -> None: - # Just submit the task and return immediately - # The result is not important - self.task_submitted = True - return None + def test_timeout_error(self): + """Test that requests that exceed timeout are handled with proper error.""" + + class App: + + def slow_method(self): + # Sleep longer than the timeout + time.sleep(2.0) + return "completed" + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_timeout") as server: + time.sleep(0.1) + + # Create client with short timeout + with RPCClient("ipc:///tmp/rpc_test_timeout", + timeout=0.5) as client: + with pytest.raises(RPCError) as exc_info: + client.slow_method(__rpc_timeout=0.5) + + error = exc_info.value + # Should be either a timeout error or RPC error indicating timeout + assert "timed out" in str( + error).lower() or "timeout" in str(error).lower() + + def test_method_not_found_error(self): + """Test that calling non-existent methods returns proper error.""" + + class App: - def get_task_submitted(self) -> bool: - return self.task_submitted + def existing_method(self): + return "exists" + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_not_found") as server: + time.sleep(0.1) + + with RPCClient("ipc:///tmp/rpc_test_not_found") as client: + with pytest.raises(RPCError) as exc_info: + client.non_existent_method() + + error = exc_info.value + assert "not found" in str(error) + assert error.traceback is not None + + +def test_rpc_shutdown_server(): + + class App: + + def hello(self): + return "world" with RPCServer(App()) as server: - server.bind("ipc:///tmp/rpc_test_no_wait") + server.bind("ipc:///tmp/rpc_test_shutdown") server.start() time.sleep(0.1) - client = RPCClient("ipc:///tmp/rpc_test_no_wait") - client.send_task(__rpc_need_response=False) - time.sleep(0.1) # wait for some time to make sure the task is submitted - assert client.get_task_submitted() + with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: + ret = client.hello() + assert ret == "world" + + client.shutdown_server() + + time.sleep(5) # the server dispatcher thread need some time to quit def test_rpc_without_response_performance(): @@ -145,22 +302,21 @@ def send_task(self) -> None: server.bind("ipc:///tmp/rpc_test_no_wait") server.start() time.sleep(0.1) - client = RPCClient("ipc:///tmp/rpc_test_no_wait") + with RPCClient("ipc:///tmp/rpc_test_no_wait") as client: + time_start = time.time() + for i in range(100): + client.send_task(__rpc_need_response=False) + time_end = time.time() - time_start = time.time() - for i in range(100): - client.send_task(__rpc_need_response=False) - time_end = time.time() + no_wait_time = time_end - time_start - no_wait_time = time_end - time_start + time_start = time.time() + for i in range(100): + client.send_task(__rpc_need_response=True) + time_end = time.time() + wait_time = time_end - time_start - time_start = time.time() - for i in range(100): - client.send_task(__rpc_need_response=True) - time_end = time.time() - wait_time = time_end - time_start - - assert no_wait_time < wait_time, f"{no_wait_time} > {wait_time}" + assert no_wait_time < wait_time, f"{no_wait_time} > {wait_time}" @pytest.mark.parametrize("async_run_task", [True, False]) @@ -179,16 +335,16 @@ def cal(self, n: int): server.start() time.sleep(0.1) - client = RPCClient(server.address) + with RPCClient(server.address) as client: - time_start = time.time() - for i in range(10000): - ret = client.cal(i, __rpc_timeout=10) # sync call - assert ret == i * 2, f"{ret} != {i * 2}" - time_end = time.time() - print( - f"Time taken: {time_end - time_start} seconds, {10000 / (time_end - time_start)} calls/second" - ) + time_start = time.time() + for i in range(100): + ret = client.cal(i, __rpc_timeout=10) # sync call + assert ret == i * 2, f"{ret} != {i * 2}" + time_end = time.time() + print( + f"Time taken: {time_end - time_start} seconds, {10000 / (time_end - time_start)} calls/second" + ) @pytest.mark.parametrize("use_async", [True, False]) @@ -206,43 +362,99 @@ def slow_operation(self, delay: float): time.sleep(delay) return "completed" - with RPCServer(App()) as server: - server.bind("ipc:///tmp/rpc_test_timeout") - server.start() + # Use manual server lifecycle management to ensure server stays alive + server = RPCServer(App()) + server.bind("ipc:///tmp/rpc_test_timeout") + server.start() + + try: time.sleep(0.1) - client = RPCClient("ipc:///tmp/rpc_test_timeout") + with RPCClient("ipc:///tmp/rpc_test_timeout") as client: + + # Test that a short timeout causes RPCTimeout exception + with pytest.raises(RPCTimeout) as exc_info: + import asyncio + if use_async: + + async def test_async_timeout(): + return await client.call_async('slow_operation', + 2.0, + __rpc_timeout=0.1) + + asyncio.run(test_async_timeout()) + else: + assert client.slow_operation( + 2.0, __rpc_timeout=0.1) # small timeout + + assert "timed out" in str( + exc_info.value + ), f"Timeout message not found: {exc_info.value}" - # Test that a short timeout causes RPCTimeout exception - with pytest.raises(RPCError) as exc_info: + # Test that a long timeout allows the operation to complete if use_async: - # Test async call with timeout import asyncio - async def test_async_timeout(): + async def test_async_success(): return await client.call_async('slow_operation', - 2.0, - __rpc_timeout=0.1) + 0.1, + __rpc_timeout=10.0) - asyncio.run(test_async_timeout()) + result = asyncio.run(test_async_success()) else: - # Test sync call with timeout - client.slow_operation(2.0, __rpc_timeout=0.1) + result = client.slow_operation(0.1, __rpc_timeout=10.0) + + assert result == "completed" + + print(f"final result: {result}") + + finally: + server.shutdown() - assert "timed out" in str( - exc_info.value), f"Timeout message not found: {exc_info.value}" - # Test that a long timeout allows the operation to complete - if use_async: - # Test async call with sufficient timeout - import asyncio +class TestRpcShutdown: + + def test_duplicate_shutdown(self): + + class App: + + def quick_task(self, task_id: int): + return f"quick_task_{task_id}" + + with RpcServerWrapper(App(), + addr="ipc:///tmp/rpc_test_shutdown") as server: + time.sleep(0.1) + with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: + client.quick_task(1) + + # repeated shutdown should not raise an error + for i in range(10): + server.shutdown() + + def test_submit_request_after_server_shutdown(self): + + class App: + + def foo(self, delay: int): + time.sleep(delay) + return "foo" + + server = RPCServer(App()) + server.bind("ipc:///tmp/rpc_test_shutdown") + server.start() + + time.sleep(0.1) + with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: + # This task should be continued after server shutdown + res = client.foo(10, __rpc_timeout=12, __rpc_mode="future") + + # The shutdown will block until all pending requests are finished + server.shutdown() - async def test_async_success(): - return await client.call_async('slow_operation', - 0.1, - __rpc_timeout=1.0) + assert res.result() == "foo" - result = asyncio.run(test_async_success()) - else: - result = client.slow_operation(0.1, __rpc_timeout=1.0) - assert result == "completed" +if __name__ == "__main__": + #TestRpcError().test_shutdown_cancelled_error() + #test_rpc_shutdown_server() + #TestRpcShutdown().test_submit_request_after_server_shutdown() + test_rpc_timeout(True) diff --git a/tests/unittest/executor/test_rpc_proxy.py b/tests/unittest/executor/test_rpc_proxy.py new file mode 100644 index 00000000000..5251d66d7b9 --- /dev/null +++ b/tests/unittest/executor/test_rpc_proxy.py @@ -0,0 +1,52 @@ +import os +import sys + +from test_worker_base import create_fake_executor_config + +from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy +from tensorrt_llm.llmapi.mpi_session import MpiPoolSession +from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer +from tensorrt_llm.sampling_params import SamplingParams + +# isort: off +sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") +from utils.llm_data import llm_models_root +from utils.util import similar +# isort: on + +model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +class TestRpcProxyTp1: + + def setup_method(self): + self.executor_config = create_fake_executor_config(model_path) + + def create_proxy(self, tp_size: int): + mpi_session = MpiPoolSession(n_workers=tp_size) + proxy = GenerationExecutorRpcProxy( + worker_kwargs={ + "engine": model_path, + "executor_config": self.executor_config, + "model_world_size": tp_size, + }, + mpi_session=mpi_session, + ) + return proxy + + def test_tp1(self): + tokenizer = TransformersTokenizer.from_pretrained(model_path) + prompt = "A B C D" + prompt_token_ids = tokenizer.encode(prompt) + max_tokens = 8 + + with self.create_proxy(tp_size=1) as proxy: + sampling_params = SamplingParams(max_tokens=max_tokens) + result = proxy.generate(prompt_token_ids, sampling_params) + print(f"get result: {result}") + assert similar(tokenizer.decode(result.outputs[0].token_ids), + 'E F G H I J K L') + + +if __name__ == "__main__": + TestRpcProxyTp1().test_tp1() diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index 6d0be4bf531..97ff0c9b997 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -4,7 +4,7 @@ import time from concurrent.futures import ProcessPoolExecutor -from test_worker_base import TestWorkerBase +from test_worker_base import create_fake_executor_config from tensorrt_llm.executor.request import GenerationRequest from tensorrt_llm.executor.rpc import RPCClient @@ -22,9 +22,8 @@ class TestRpcWorker: - def __init__(self): - self.executor_config = TestWorkerBase.create_fake_executor_config( - model_path) + def setup_method(self): + self.executor_config = create_fake_executor_config(model_path) def create_tp1_worker_process(self): addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr() @@ -41,14 +40,10 @@ def create_rpc_client(self, addr: str): client = RPCClient(addr) return client - def test_main(self): + def test_main_loop(self): pool, addr = self.create_tp1_worker_process() client = self.create_rpc_client(addr) - print("call setup_engine") - client.setup_engine(engine=model_path, - executor_config=self.executor_config, - __rpc_timeout=120) - print("call submit") + client.setup_engine(__rpc_timeout=120) time.sleep(1) def process_request(): @@ -56,44 +51,50 @@ def process_request(): prompt_token_ids=[3, 4, 5], sampling_params=SamplingParams(max_tokens=10)), __rpc_need_response=False) - assert ret is None + assert ret is None # need_response = False print(f"submit result: {ret}") print("call fetch_responses") # NOTE: known issue, the responses should be fetched before shutdown, # or the shutdown will hang. results = [] - for i in range(3): - time.sleep(3) - results.extend(client.fetch_responses()) - print(f"fetch_responses result: {results}") - assert len(results) == 1 + time.sleep(8) # wait for PyExecutor to finish the generation + results.extend( + client.fetch_responses()) # fetch_responses will block + print(f"fetch_responses result: {results}") + assert len(results) == 1 # one request, one response def process_request_streaming(): - ret = client.submit(prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10), - streaming=True, + ret = client.submit(GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=10), + streaming=True), __rpc_need_response=False) assert ret is None + print("submit result: ", ret) - print("call fetch_responses") # NOTE: known issue, the responses should be fetched before shutdown, # or the shutdown will hang. results = [] - for i in range(3): - time.sleep(3) - results.extend(client.fetch_responses()) - print(f"fetch_responses result: {results}") - print(f"generate_async result: {results}") + time.sleep(8) + + while not results: + time.sleep(1) + results.extend(client.fetch_responses(__rpc_timeout=10)) + print(f"try fetch_responses result: {results}") + print(f"fetch_responses result: {results}") + assert results - process_request() + for i in range(5): + process_request() process_request_streaming() print("call shutdown") client.shutdown(__rpc_timeout=10) pool.shutdown() + client.close() if __name__ == '__main__': worker = TestRpcWorker() - worker.test_main() + worker.test_main_loop() diff --git a/tests/unittest/executor/test_worker_base.py b/tests/unittest/executor/test_worker_base.py index d825d077d18..d40efe756b5 100644 --- a/tests/unittest/executor/test_worker_base.py +++ b/tests/unittest/executor/test_worker_base.py @@ -2,6 +2,8 @@ import sys import time +import pytest + # isort: off sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") from utils.llm_data import llm_models_root @@ -24,8 +26,10 @@ class FakeWorker(WorkerBase): def __init__(self, engine: str): super().__init__(engine=engine) - executor_config = TestWorkerBase.create_fake_executor_config(engine) - self.setup_engine(engine=engine, executor_config=executor_config) + executor_config = create_fake_executor_config(engine) + # Pass config in constructor and finalize with parameterless setup + self._executor_config = executor_config + self.setup_engine() def test_create_engine(self): with self.FakeWorker(engine=model_path) as worker: @@ -62,32 +66,42 @@ def test_dispatch_stats_task(self): with self.FakeWorker(engine=model_path) as worker: worker.submit(request) worker.await_responses() - worker.dispatch_stats_task() time.sleep(10) stats = worker.fetch_stats() assert len(stats) == 1 - @staticmethod - def create_fake_executor_config(model_path): - llm_args = LlmArgs(model=model_path, cuda_graph_config=None) - - executor_config = tllm.ExecutorConfig(1) - executor_config.max_batch_size = 1 - - update_executor_config( - executor_config, - backend="pytorch", - pytorch_backend_config=llm_args.get_pytorch_backend_config(), - mapping=llm_args.parallel_config.to_mapping(), - speculative_config=llm_args.speculative_config, - hf_model_dir=model_path, - max_input_len=20, - max_seq_len=40, - checkpoint_format=llm_args.checkpoint_format, - checkpoint_loader=llm_args.checkpoint_loader, - ) - - return executor_config + @pytest.mark.parametrize("timeout", [0.1, 0.2, 1]) + def test_fetch_responses_timeout(self, timeout: float): + with self.FakeWorker(engine=model_path) as worker: + # Not submit any request, and let the await_responses timeout. + start_time = time.time() + results = worker.await_responses(timeout=timeout) + elapsed = time.time() - start_time + print(f"await_responses latency: {elapsed:.3f} seconds") + assert timeout / 2 <= elapsed <= timeout * 2, f"Latency out of expected range: {elapsed}" + assert results is None + + +def create_fake_executor_config(model_path): + llm_args = LlmArgs(model=model_path, cuda_graph_config=None) + + executor_config = tllm.ExecutorConfig(1) + executor_config.max_batch_size = 1 + + update_executor_config( + executor_config, + backend="pytorch", + pytorch_backend_config=llm_args.get_pytorch_backend_config(), + mapping=llm_args.parallel_config.to_mapping(), + speculative_config=llm_args.speculative_config, + hf_model_dir=model_path, + max_input_len=20, + max_seq_len=40, + checkpoint_format=llm_args.checkpoint_format, + checkpoint_loader=llm_args.checkpoint_loader, + ) + + return executor_config if __name__ == "__main__": From d34ce1bf4e887d39e0a418905237f5eeb2ce7558 Mon Sep 17 00:00:00 2001 From: Superjomn <328693+Superjomn@users.noreply.github.com> Date: Fri, 22 Aug 2025 18:36:34 +0800 Subject: [PATCH 05/13] support rpc streaming call remove default timeout fix streaming task differ event_loop bug fix pending count shutdown bug fix test_rpc_worker.py Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com> add tp2 creation test for TP2 fix rpc_worker TP2 creation hang fix proxy Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/rpc/__init__.py | 7 +- tensorrt_llm/executor/rpc/rpc_client.py | 253 +++++++++++++++--- tensorrt_llm/executor/rpc/rpc_common.py | 23 +- tensorrt_llm/executor/rpc/rpc_server.py | 211 ++++++++++++--- tensorrt_llm/executor/rpc_proxy.py | 51 ++-- tensorrt_llm/executor/rpc_worker.py | 58 ++++- tensorrt_llm/executor/worker_base.py | 41 ++- tensorrt_llm/llmapi/utils.py | 68 ++++- tests/unittest/executor/test_rpc.py | 273 ++++++++++++++++---- tests/unittest/executor/test_rpc_proxy.py | 42 ++- tests/unittest/executor/test_rpc_worker.py | 175 +++++++++++-- tests/unittest/executor/test_worker_base.py | 77 +++++- tests/unittest/llmapi/test_llm_utils.py | 26 ++ tests/unittest/pytest.ini | 2 +- 14 files changed, 1087 insertions(+), 220 deletions(-) diff --git a/tensorrt_llm/executor/rpc/__init__.py b/tensorrt_llm/executor/rpc/__init__.py index 38c4924e1ab..6f62051bb41 100644 --- a/tensorrt_llm/executor/rpc/__init__.py +++ b/tensorrt_llm/executor/rpc/__init__.py @@ -1,9 +1,10 @@ from .rpc_client import RPCClient -from .rpc_common import (RPCCancelled, RPCError, RPCRequest, RPCResponse, - RPCTimeout) +from .rpc_common import (RPCCancelled, RPCError, RPCParams, RPCRequest, + RPCResponse, RPCStreamingError, RPCTimeout) from .rpc_server import RPCServer, Server __all__ = [ "RPCClient", "RPCServer", "Server", "RPCError", "RPCTimeout", - "RPCCancelled", "RPCRequest", "RPCResponse" + "RPCCancelled", "RPCStreamingError", "RPCRequest", "RPCResponse", + "RPCParams" ] diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index 35ac2d42366..723dfeff944 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -2,10 +2,13 @@ import concurrent.futures import threading import uuid +from typing import Any, AsyncIterator, Dict, Optional +from ...llmapi.utils import AsyncQueue, _SyncQueue, logger_debug from ...logger import logger from ..ipc import ZeroMqQueue -from .rpc_common import RPCCancelled, RPCRequest, RPCResponse, RPCTimeout +from .rpc_common import (RPCCancelled, RPCParams, RPCRequest, RPCResponse, + RPCStreamingError, RPCTimeout) class RPCClient: @@ -16,7 +19,7 @@ class RPCClient: def __init__(self, address: str, hmac_key=None, - timeout: float = 10, + timeout: Optional[float] = None, num_workers: int = 4): ''' Args: @@ -31,15 +34,19 @@ def __init__(self, is_async=True, use_hmac_encryption=False) self._pending_futures = {} + # map request_id to the queue for streaming responses + self._streaming_queues: Dict[str, AsyncQueue] = {} self._reader_task = None self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=num_workers, thread_name_prefix="rpc_client") + max_workers=num_workers, thread_name_prefix="rpc_client_worker") self._server_stopped = False + self._closed = False + self._stop_event = None self._loop = None self._loop_thread = None - logger.debug(f"RPC Client initialized. Connected to {self._address}") + logger_debug(f"RPC Client initialized. Connected to {self._address}") def shutdown_server(self): """Shutdown the server.""" @@ -52,9 +59,30 @@ def shutdown_server(self): def close(self): """Gracefully close the client, cleaning up background tasks.""" + + if self._closed: + return + # stop the main loop + self._closed = True + + logger_debug("RPC Client closing") + + if self._stop_event and self._loop: + # Use call_soon_threadsafe since set() is not a coroutine + self._loop.call_soon_threadsafe(self._stop_event.set) + if self._reader_task: - self._reader_task.cancel() + try: + self._reader_task.result(timeout=1.0) + except concurrent.futures.TimeoutError: + logger.warning( + "Reader task did not exit gracefully, cancelling") + self._reader_task.cancel() + except Exception as e: + # Task might have already finished or been cancelled + logger_debug(f"Reader task cleanup: {e}") self._reader_task = None + if self._loop and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) if self._loop_thread: @@ -63,34 +91,81 @@ def close(self): if self._executor: self._executor.shutdown(wait=True) + if self._client_socket: + self._client_socket.close() + self._client_socket = None + + logger_debug("RPC Client closed") + async def _response_reader(self): """Task to read responses from the socket and set results on futures.""" - while True: + while not self._stop_event.is_set(): try: - response: RPCResponse = await self._client_socket.get_async() - logger.debug(f"RPC Client received response: {response}") - future = self._pending_futures.get(response.request_id) - if future and not future.done(): - if response.error is None: - future.set_result(response.result) - else: - # Use the original RPCError from the response - future.set_exception(response.error) - self._pending_futures.pop(response.request_id, None) + # Use wait_for with a short timeout to periodically check stop event + try: + response: RPCResponse = await asyncio.wait_for( + self._client_socket.get_async(), + timeout=0.1 # Check stop event every 100ms + ) + except asyncio.TimeoutError: + # Timeout is expected - just check stop event and continue + continue + + logger_debug(f"RPC Client received response: {response}") + + # Handle streaming responses + if response.is_streaming: + assert response.stream_status in [ + 'start', 'data', 'end', 'error' + ], f"Invalid stream status: {response.stream_status}" + queue = self._streaming_queues.get(response.request_id) + if queue: + # put to the sync queue, as the current event loop is + # different from the one in call_async or call_streaming + assert isinstance(queue, AsyncQueue) + logger_debug( + f"RPC Client putting response to AsyncQueue: {response}" + ) + queue.sync_q.put(response) + # Clean up if stream ended + if response.stream_status in ['end', 'error']: + self._streaming_queues.pop(response.request_id, + None) + else: + # Handle regular responses + if future_info := self._pending_futures.get( + response.request_id): + future, target_loop = future_info + + if not future.done(): + if response.error is None: + target_loop.call_soon_threadsafe( + future.set_result, response.result) + else: + # Use the original RPCError from the response + target_loop.call_soon_threadsafe( + future.set_exception, response.error) + self._pending_futures.pop(response.request_id, None) except asyncio.CancelledError: + # Still handle cancellation for backward compatibility + logger_debug("Response reader cancelled") break except Exception as e: logger.error(f"Exception in RPC response reader: {e}") # Propagate exception to all pending futures - for future in self._pending_futures.values(): + for (future, target_loop) in self._pending_futures.values(): + if not future.done(): - future.set_exception(e) + target_loop.call_soon_threadsafe( + future.set_exception, e) + # Also signal error to streaming queues + for queue in self._streaming_queues.values(): + await queue.put(RPCResponse("", None, e, False, 0, 'error')) break - await asyncio.sleep(0) - + logger_debug("Response reader exiting gracefully") self._reader_task = None def _start_response_reader_lazily(self): @@ -103,38 +178,36 @@ def _start_response_reader_lazily(self): # Store the concurrent.futures.Future self._reader_task = future - async def _call_async(self, __rpc_method_name, *args, **kwargs): + async def _call_async(self, method_name, *args, **kwargs): """Async version of RPC call. Args: - __rpc_method_name: Method name to call + method_name: Method name to call *args: Positional arguments **kwargs: Keyword arguments - __rpc_timeout: The timeout (seconds) for the RPC call. - __rpc_need_response: Whether the RPC call needs a response. - If set to False, the remote call will return immediately. + __rpc_params: RPCParams object containing RPC parameters. Returns: The result of the remote method call """ - logger.debug( - f"RPC client calling method: {__rpc_method_name} with args: {args} and kwargs: {kwargs}" + logger_debug( + f"RPC client calling method: {method_name} with args: {args} and kwargs: {kwargs}" ) if self._server_stopped: raise RPCCancelled("Server is shutting down, request cancelled") self._start_response_reader_lazily() - need_response = kwargs.pop("__rpc_need_response", True) - timeout = kwargs.pop("__rpc_timeout", self._timeout) + rpc_params = kwargs.pop("__rpc_params", RPCParams()) + need_response = rpc_params.need_response + timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout request_id = uuid.uuid4().hex - logger.debug(f"RPC client sending request: {request_id}") request = RPCRequest(request_id, - __rpc_method_name, + method_name, args, kwargs, need_response, timeout=timeout) - logger.debug(f"RPC client sending request: {request}") + logger_debug(f"RPC client sending request: {request}") await self._client_socket.put_async(request) if not need_response: @@ -142,20 +215,23 @@ async def _call_async(self, __rpc_method_name, *args, **kwargs): loop = asyncio.get_running_loop() future = loop.create_future() - self._pending_futures[request_id] = future + self._pending_futures[request_id] = (future, loop) try: # If timeout, the remote call should return a timeout error timely, # so we add 1 second to the timeout to ensure the client can get # that result. - res = await asyncio.wait_for(future, timeout + 1) + if timeout is None: + res = await future + else: + res = await asyncio.wait_for(future, timeout + 1) return res except RPCCancelled: self._server_stopped = True raise except asyncio.TimeoutError: raise RPCTimeout( - f"Request '{__rpc_method_name}' timed out after {timeout}s") + f"Request '{method_name}' timed out after {timeout}s") except Exception as e: raise e finally: @@ -168,20 +244,23 @@ def _ensure_event_loop(self): def run_loop(): asyncio.set_event_loop(self._loop) + self._stop_event = asyncio.Event() self._loop.run_forever() - self._loop_thread = threading.Thread(target=run_loop, daemon=True) + self._loop_thread = threading.Thread(target=run_loop, + daemon=True, + name="rpc_client_loop") self._loop_thread.start() # Give the loop a moment to start import time time.sleep(0.1) - def _call_sync(self, __rpc_method_name, *args, **kwargs): + def _call_sync(self, method_name, *args, **kwargs): """Synchronous version of RPC call.""" self._ensure_event_loop() future = asyncio.run_coroutine_threadsafe( - self._call_async(__rpc_method_name, *args, **kwargs), self._loop) + self._call_async(method_name, *args, **kwargs), self._loop) return future.result() def call_async(self, name: str, *args, **kwargs): @@ -199,7 +278,9 @@ def call_async(self, name: str, *args, **kwargs): Example: result = await client.call_async('remote_method', arg1, arg2, key=value) """ - return self._call_async(name, *args, **kwargs, __rpc_need_response=True) + if "__rpc_params" not in kwargs: + kwargs["__rpc_params"] = RPCParams(need_response=True) + return self._call_async(name, *args, **kwargs) def call_future(self, name: str, *args, **kwargs) -> concurrent.futures.Future: @@ -246,10 +327,94 @@ def call_sync(self, name: str, *args, **kwargs): """ return self._call_sync(name, *args, **kwargs) + async def call_streaming(self, name: str, *args, + **kwargs) -> AsyncIterator[Any]: + """ + Call a remote async generator method and get streaming results. + + Args: + name: Method name to call + *args: Positional arguments + **kwargs: Keyword arguments + + Yields: + Results from the remote async generator + + Example: + async for result in client.call_streaming('streaming_task'): + print(result) + """ + if self._server_stopped: + raise RPCCancelled("Server is shutting down, request cancelled") + + self._start_response_reader_lazily() + rpc_params = kwargs.pop("__rpc_params", RPCParams()) + timeout = rpc_params.timeout if rpc_params.timeout is not None else self._timeout + + request_id = uuid.uuid4().hex + # Use AsyncQueue to ensure proper cross-thread communication + queue = AsyncQueue() + # Recreate sync_q with the current running loop for proper cross-thread communication + # This ensures the background _response_reader thread can properly notify this event loop + queue._sync_q = _SyncQueue(queue, asyncio.get_running_loop()) + self._streaming_queues[request_id] = queue + + try: + # Send streaming request + request = RPCRequest(request_id, + name, + args, + kwargs, + need_response=True, + timeout=timeout, + is_streaming=True) + await self._client_socket.put_async(request) + + # Read streaming responses + while True: + logger_debug(f"RPC Client call_streaming waiting for response", + color="green") + if timeout is None: + response = await queue.get() + else: + response = await asyncio.wait_for(queue.get(), + timeout=timeout + 1) + + logger_debug( + f"RPC Client call_streaming received [{response.stream_status}] response: {response}", + color="green") + if response.stream_status == 'start': + # Start of stream + continue + elif response.stream_status == 'data': + logger_debug( + f"RPC Client call_streaming received data: {response.result}", + color="green") + # Yield data + yield response.result + elif response.stream_status == 'end': + # End of stream + break + elif response.stream_status == 'error': + # Error in stream + if response.error: + raise response.error + else: + raise RPCStreamingError("Unknown streaming error") + + except asyncio.TimeoutError: + raise RPCTimeout( + f"Streaming request '{name}' timed out after {timeout}s") + finally: + # Clean up + self._streaming_queues.pop(request_id, None) + def get_server_attr(self, name: str): """ Get the attribute of the RPC server. This is mainly used for testing. """ - return self._call_sync("__rpc_get_attr", name, __rpc_timeout=10) + return self._call_sync("__rpc_get_attr", + name, + __rpc_params=RPCParams(timeout=10)) def __getattr__(self, name): """ @@ -265,7 +430,8 @@ def __init__(self, client, method_name): def __call__(self, *args, **kwargs): """Default synchronous call""" - mode = kwargs.pop("__rpc_mode", "sync") + rpc_params = kwargs.get("__rpc_params", RPCParams()) + mode = rpc_params.mode if mode == "sync": return self.client._call_sync(self.method_name, *args, **kwargs) @@ -288,6 +454,11 @@ def call_future(self, *args, **kwargs) -> concurrent.futures.Future: return self.client.call_future(self.method_name, *args, **kwargs) + def call_streaming(self, *args, **kwargs) -> AsyncIterator[Any]: + """Streaming call - returns async iterator""" + return self.client.call_streaming(self.method_name, *args, + **kwargs) + return MethodProxy(self, name) def __enter__(self): diff --git a/tensorrt_llm/executor/rpc/rpc_common.py b/tensorrt_llm/executor/rpc/rpc_common.py index 22b85097555..4c81911cf89 100644 --- a/tensorrt_llm/executor/rpc/rpc_common.py +++ b/tensorrt_llm/executor/rpc/rpc_common.py @@ -1,4 +1,17 @@ -from typing import Any, NamedTuple, Optional +from typing import Any, Literal, NamedTuple, Optional + + +class RPCParams(NamedTuple): + """ Parameters for RPC calls. """ + + # seconds to wait for the response + timeout: Optional[float] = None + + # whether the client needs the response, if False, it will return immediately + need_response: bool = True + + # mode for RPC calls: "sync", "async", or "future" + mode: str = "sync" # --- Custom Exceptions --- @@ -31,6 +44,10 @@ class RPCCancelled(RPCError): """ +class RPCStreamingError(RPCError): + """Exception for streaming-related errors.""" + + class RPCRequest(NamedTuple): request_id: str method_name: str @@ -38,9 +55,13 @@ class RPCRequest(NamedTuple): kwargs: dict need_response: bool = True timeout: float = 0.5 + is_streaming: bool = False class RPCResponse(NamedTuple): request_id: str result: Any error: Optional[RPCError] = None + is_streaming: bool = False # True if more responses coming + sequence_number: int = 0 # For ordering streaming responses + stream_status: Literal['start', 'data', 'end', 'error'] = 'data' diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py index b0d1377569e..cf6e05486d3 100644 --- a/tensorrt_llm/executor/rpc/rpc_server.py +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -1,4 +1,5 @@ import asyncio +import inspect import queue import threading import time @@ -6,10 +7,11 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional -from ...llmapi.utils import ManagedThread +from ...llmapi.utils import ManagedThread, logger_debug from ...logger import logger from ..ipc import ZeroMqQueue -from .rpc_common import RPCError, RPCRequest, RPCResponse, RPCTimeout +from .rpc_common import (RPCError, RPCRequest, RPCResponse, RPCStreamingError, + RPCTimeout) class RPCServer: @@ -20,7 +22,7 @@ class RPCServer: def __init__(self, instance, hmac_key=None, - num_workers: int = 1, + num_workers: int = 4, timeout: float = 0.5, async_run_task: bool = False): """ @@ -32,6 +34,8 @@ def __init__(self, num_workers (int): Number of worker threads. timeout (int): Timeout for RPC calls. async_run_task (bool): Whether to run the task asynchronously. + + NOTE: make num_workers larger if there are some streaming tasks runs infinitely. """ self._instance = instance self._hmac_key = hmac_key @@ -51,8 +55,8 @@ def __init__(self, } self._dispatcher_thread: Optional[ManagedThread] = None if async_run_task: - self._executor = ThreadPoolExecutor(max_workers=num_workers, - thread_name_prefix="rpc_worker") + self._executor = ThreadPoolExecutor( + max_workers=num_workers, thread_name_prefix="rpc_server_worker") else: self._executor = None @@ -61,7 +65,8 @@ def __init__(self, # Automatically register the instance self.register_instance(instance) - logger.debug(f"RPC Server initialized with {num_workers} workers.") + logger_debug(f"RPC Server initialized with {num_workers} workers.", + color="green") @property def address(self) -> str: @@ -101,7 +106,7 @@ def shutdown(self, is_remote_call: bool = False): if self._stop_event.is_set(): return - logger.debug( + logger_debug( "RPC Server shutdown signal received. Terminating server...") # Set the stop event to True, this will trigger the dispatcher routine and @@ -110,22 +115,23 @@ def shutdown(self, is_remote_call: bool = False): self._stop_event.set() # The worker routine should process the pending requests - logger.debug( + logger_debug( f"RPC Server shutdown: {self._num_pending_requests} pending requests" ) + while self._num_pending_requests > 0: time.sleep(0.01) - logger.debug(f"RPC Server shutdown finished pending requests") + logger_debug(f"RPC Server shutdown finished pending requests") if not is_remote_call: # Block the thread until shutdown is finished # 1. Wait for the dispatcher thread to exit, so that no new requests are accepted - logger.debug(f"RPC Server dispatcher thread joining") + logger_debug(f"RPC Server dispatcher thread joining") if self._dispatcher_thread: self._dispatcher_thread.join() self._dispatcher_thread = None - logger.debug(f"RPC Server dispatcher thread joined") + logger_debug(f"RPC Server dispatcher thread joined") # 2. Wait for the executor to exit, it will wait for the pending requests to be processed if self._executor: @@ -140,13 +146,13 @@ def shutdown(self, is_remote_call: bool = False): # if the shutdown is called by a remote call, this method itself will # be executed in a executor thread, so we cannot join the dispatcher thread as # the dispatcher thread is awaiting for the shutdown result. - logger.debug( + logger_debug( f"RPC Server to shutdown: {self._num_pending_requests} pending requests" ) while self._num_pending_requests > 0: time.sleep(0.01) - logger.debug(f"RPC Server shutdown finished pending requests") + logger_debug(f"RPC Server shutdown finished pending requests") def register_function(self, func, name=None): """Exposes a single function to clients.""" @@ -155,11 +161,11 @@ def register_function(self, func, name=None): logger.warning( f"Function '{fname}' is already registered. Overwriting.") self._functions[fname] = func - logger.debug(f"Registered function: {fname}") + logger_debug(f"Registered function: {fname}") def register_instance(self, instance): """Exposes all public methods of a class instance.""" - logger.debug( + logger_debug( f"Registering instance of class: {instance.__class__.__name__}") for name in dir(instance): if not name.startswith('_'): @@ -182,17 +188,20 @@ async def _dispatcher_routine(self, stop_event: threading.Event): try: req: RPCRequest = await self._client_socket.get_async_noblock( timeout=0.5) - logger.debug(f"RPC dispatcher got request: {req}") + logger_debug(f"RPC dispatcher got request: {req}") except asyncio.TimeoutError: await asyncio.sleep(0) continue await self._queue.put(req) # type: ignore - # shutdown is a builtin method depends on _num_pending_requests, so - # it should not be counted - if req.method_name != "__rpc_shutdown": + # shutdown methods depend on _num_pending_requests, so + # they should not be counted + if req.method_name not in ["__rpc_shutdown", "shutdown"]: self._num_pending_requests += 1 + logger_debug( + f"Dispatcher received request {req}, pending: {self._num_pending_requests}" + ) async def _worker_routine(self, stop_event: threading.Event): """The routine executed by each worker thread.""" @@ -208,33 +217,93 @@ async def _worker_routine(self, stop_event: threading.Event): await asyncio.sleep(0) continue - response = await self._process_request(req) + # check if the method name is in the functions + if req.method_name not in self._functions: + logger.error( + f"Method '{req.method_name}' not found in RPC server.") + if not req.need_response: + continue + if req.is_streaming: + await self._client_socket.put_async( + RPCResponse( + req.request_id, + None, + RPCStreamingError( + f"Method '{req.method_name}' not found in RPC server.", + traceback=traceback.format_exc()), + stream_status='error')) + else: + response = RPCResponse( + req.request_id, + None, + RPCError( + f"Method '{req.method_name}' not found in RPC server.", + traceback=traceback.format_exc()), + ) + await self._client_socket.put_async(response) - # Some tasks don't need response, e.g. submit_request or shutdown - if req.need_response: - logger.debug(f"RPC Server sending response for request {req}") - await self._client_socket.put_async(response) - logger.debug(f"RPC Server sent response for request {req}") - - self._num_pending_requests -= 1 + continue - async def _process_request(self, req: RPCRequest) -> RPCResponse: - if req.method_name not in self._functions: - return RPCResponse( - req.request_id, None, - RPCError(f"Method '{req.method_name}' not found in RPC server.", - traceback=traceback.format_exc())) + func = self._functions[req.method_name] + if req.is_streaming: + if inspect.isasyncgenfunction(func): + await self._process_streaming_request(req) + else: + # Non-streaming function called with streaming flag + response = RPCResponse( + req.request_id, + None, + RPCStreamingError( + f"Method '{req.method_name}' is not a streaming function." + ), + # need to redirect the error to the client's streaming queue + is_streaming=True, + stream_status='error', + ) + await self._client_socket.put_async(response) + else: + # Process regular request + response = await self._process_request(req) + + # Some tasks don't need response, e.g. submit_request or shutdown + if req.need_response and response is not None: + logger_debug( + f"RPC Server sending response for request {req}, pending: {self._num_pending_requests}" + ) + await self._client_socket.put_async(response) + logger_debug(f"RPC Server sent response for request {req}") + + # Only decrement if this request was counted in the first place + if req.method_name not in ["__rpc_shutdown", "shutdown"]: + self._num_pending_requests -= 1 + + async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]: + """Process a request. Returns None for streaming requests (handled separately).""" + func = self._functions[req.method_name] try: - loop = asyncio.get_running_loop() - - def call_with_kwargs(): - return self._functions[req.method_name](*req.args, **req.kwargs) - - result = await asyncio.wait_for(loop.run_in_executor( - self._executor, call_with_kwargs), - timeout=req.timeout) - logger.debug(f"RPC Server returned result for request {req}") + if inspect.iscoroutinefunction(func): + # Execute async function directly in event loop, no need to run in executor due to the GIL + logger_debug( + f"RPC Server running async task {req.method_name} in dispatcher" + ) + result = await asyncio.wait_for(func(*req.args, **req.kwargs), + timeout=req.timeout) + else: + # Execute sync function in thread executor + loop = asyncio.get_running_loop() + + def call_with_kwargs(): + return func(*req.args, **req.kwargs) + + logger_debug( + f"RPC Server running async task {req.method_name} in worker" + ) + result = await asyncio.wait_for(loop.run_in_executor( + self._executor, call_with_kwargs), + timeout=req.timeout) + + logger_debug(f"RPC Server returned result for request {req}") response = RPCResponse(req.request_id, result) except asyncio.TimeoutError: @@ -251,6 +320,64 @@ def call_with_kwargs(): return response + async def _process_streaming_request(self, req: RPCRequest): + """Process a streaming request by sending multiple responses.""" + func = self._functions[req.method_name] + + if not inspect.isasyncgenfunction(func): + await self._client_socket.put_async( + RPCResponse( + req.request_id, + None, + RPCStreamingError( + f"Method '{req.method_name}' is not an async generator.", + traceback=traceback.format_exc()), + # need to redirect the error to the client's streaming queue + stream_status='error')) + return + + sequence_number = 0 + + try: + logger_debug(f"RPC Server running streaming task {req.method_name}") + # Send start signal + await self._client_socket.put_async( + RPCResponse(req.request_id, None, None, True, sequence_number, + 'start')) + sequence_number += 1 + + # Stream the results + async for result in func(*req.args, **req.kwargs): + logger_debug( + f"RPC Server got data and ready to send result {result}") + await self._client_socket.put_async( + RPCResponse(req.request_id, result, None, True, + sequence_number, 'data')) + sequence_number += 1 + + # Send end signal + await self._client_socket.put_async( + RPCResponse(req.request_id, None, None, True, sequence_number, + 'end')) + + except asyncio.TimeoutError: + await self._client_socket.put_async( + RPCResponse( + req.request_id, None, + RPCTimeout( + f"Streaming method '{req.method_name}' timed out", + traceback=traceback.format_exc()), True, + sequence_number, 'error')) + + except Exception as e: + await self._client_socket.put_async( + RPCResponse( + req.request_id, None, + RPCStreamingError(str(e), + cause=e, + traceback=traceback.format_exc()), True, + sequence_number, 'error')) + def start(self): """Binds sockets, starts workers, and begins proxying messages.""" if self._client_socket is None: diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index ea637e5d6c7..63783033ef8 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -1,3 +1,4 @@ +import asyncio import atexit import os import threading @@ -6,7 +7,7 @@ from ..llmapi.mpi_session import MpiPoolSession, MpiSession from ..llmapi.tracer import global_tracer -from ..llmapi.utils import (_SyncQueue, print_colored_debug, +from ..llmapi.utils import (_SyncQueue, logger_debug, print_colored_debug, print_traceback_on_error) from ..logger import logger from .executor import GenerationExecutor @@ -14,6 +15,7 @@ from .request import GenerationRequest from .result import GenerationResult from .rpc import RPCClient +from .rpc.rpc_common import RPCParams from .rpc_worker import RpcWorker from .utils import (ErrorResponse, create_mpi_comm_session, get_spawn_proxy_process_env, is_llm_response) @@ -83,24 +85,23 @@ def launch_workers(self): **self.worker_kwargs) @print_traceback_on_error - def main_loop_task(self): + async def main_loop_task(self): """ Main loop of the proxy, it will invoke the actions periodically. """ - clock = 0 - while not self._shutdown_event.is_set(): - if clock % 1 == 0: - responses = self.fetch_responses_remote() - self.handle_responses(responses) - if clock % 10 == 0: - stats = self.fetch_stats_remote() # TODO - self.handle_stats(stats) - - clock += 1 - time.sleep(self.clock_unit) + async for responses in self.rpc_client.fetch_responses_loop_async.call_streaming( + ): + if self._shutdown_event.is_set(): + return + self.handle_responses(responses) def setup_mainloop(self): - self.main_loop_thread = threading.Thread(target=self.main_loop_task, + + def _run_main_loop_task(): + """Local method to run the main loop task.""" + asyncio.run(self.main_loop_task()) + + self.main_loop_thread = threading.Thread(target=_run_main_loop_task, daemon=True) self.main_loop_thread.start() atexit.register(self.shutdown) @@ -144,7 +145,8 @@ def submit(self, request: GenerationRequest) -> GenerationResult: logprob_params = self._get_logprob_params(request) # submit is a fire-and-forget operation, don't need to wait for response - self.rpc_client.submit(request, __rpc_need_response=False) + self.rpc_client.submit(request, + __rpc_params=RPCParams(need_response=False)) result = GenerationResult( request, @@ -157,16 +159,19 @@ def submit(self, request: GenerationRequest) -> GenerationResult: return result def fetch_responses_remote(self): - return self.rpc_client.fetch_responses(__rpc_timeout=20) + return self.rpc_client.fetch_responses(__rpc_params=RPCParams( + timeout=20)) def fetch_stats_remote(self): return self.rpc_client.fetch_stats() def setup_engine_remote(self): - return self.rpc_client.setup_engine(__rpc_timeout=60 * 20) # 20 min + return self.rpc_client.setup_engine(__rpc_params=RPCParams( + need_response=True)) def shutdown_remote(self): - self.rpc_client.shutdown(__rpc_timeout=60 * 20) # 20 min + logger_debug(f"Shutting down rpc remote", color="yellow") + self.rpc_client.shutdown() def abort_request(self, request_id: int) -> None: return self.rpc_client.abort_request(request_id) @@ -174,14 +179,16 @@ def abort_request(self, request_id: int) -> None: def shutdown(self): if self._shutdown_event.is_set(): return + logger_debug(f"Shutting down GenerationExecutorRpcProxy", + color="yellow") + + # 1. shutdown the rpc server (PyExecutor Rank 0 + RPC server) + self.shutdown_remote() - # 1. stop the main loop, so that no new rpc requests + # 2. stop the main loop, so that no new rpc requests self._shutdown_event.set() self.main_loop_thread.join() - # 2. shutdown the rpc server (PyExecutor Rank 0 + RPC server) - self.shutdown_remote() - # 3. shutdown the mpi session, this should wait until all the PyExecutor # processes are shutdown if self.mpi_session is not None: diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 8ff41a7fc34..4ffe866ae9b 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -1,14 +1,16 @@ +import asyncio from pathlib import Path from queue import Queue from threading import Event -from typing import Optional, Union +from typing import AsyncGenerator, Optional, Union -from tensorrt_llm.llmapi.utils import enable_llm_debug +from tensorrt_llm._utils import mpi_comm +from tensorrt_llm.llmapi.utils import enable_llm_debug, logger_debug from .._utils import mpi_rank from ..bindings import executor as tllm from ..builder import Engine -from ..logger import logger, set_level +from ..logger import set_level from ..lora_manager import LoraConfig from ..sampling_params import BatchedLogitsProcessor from .postproc_worker import PostprocWorkerConfig @@ -50,16 +52,44 @@ def __init__( def fetch_stats(self) -> list: return super().fetch_stats() - def fetch_responses(self) -> list: - logger.debug(f"RpcWorker {mpi_rank()} is fetching responses") + def fetch_responses(self, timeout: Optional[float] = None) -> list: + logger_debug(f"RpcWorker {mpi_rank()} is fetching responses", + color="yellow") # NOTE: This is a blocking call, it will wait for the responses to be available. - super().await_responses() - logger.debug(f"RpcWorker returning responses") + super().await_responses(timeout) + logger_debug(f"RpcWorker returning responses", color="yellow") qsize = self._response_queue.qsize() return [self._response_queue.get() for _ in range(qsize)] + async def fetch_responses_async(self) -> list: + return await asyncio.to_thread(self.fetch_responses) + + # for streaming performance + async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: + while not self.shutdown_event.is_set(): + responses = await asyncio.to_thread(self.fetch_responses + ) # run blocking call in thread + if responses: # Only yield if there are actual responses + logger_debug( + f"RpcWorker {mpi_rank()} is yielding responses: {responses}", + color="yellow") + yield responses # batching the responses to opt IPC performance + else: + # Small delay to prevent busy waiting when no responses + await asyncio.sleep(0) + logger_debug( + f"RpcWorker {mpi_rank()} quitting fetch_responses_loop_async", + color="yellow") + + def setup_engine(self): + # Force all the ranks to wait here, and start creating the executor simultaneously. + mpi_comm().barrier() + + super().setup_engine() + def shutdown(self): - logger.debug(f"RPC worker {mpi_rank()} is shutting down") + logger_debug(f"RPC worker {mpi_rank()} is shutting down", + color="yellow") self.shutdown_event.set() super().shutdown() @@ -88,22 +118,24 @@ def main_task( garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) if mpi_rank() != 0: - logger.debug(f"Worker {mpi_rank()} is setting up the engine") # The non-leader worker will setup the engine immediately. # The leader worker will wait for the RPC call to propagate the # potential error. - logger.debug(f"Worker {mpi_rank()} is setting up the engine") + logger_debug(f"Worker {mpi_rank()} is setting up the engine", + color="yellow") worker.setup_engine() if mpi_rank() == 0: - logger.debug(f"Worker {mpi_rank()} is creating the RPC service") + logger_debug(f"Worker {mpi_rank()} is creating the RPC service", + color="yellow") # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client - rpc_server = RPCServer(worker) + # Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async. + rpc_server = RPCServer(worker, num_workers=6) rpc_server.bind(rpc_addr) rpc_server.start() # Step 3: Wait for the worker to shutdown - logger.debug( + logger_debug( f"Worker {mpi_rank()} is waiting for the worker to shutdown") worker.shutdown_event.wait() rpc_server.shutdown() diff --git a/tensorrt_llm/executor/worker_base.py b/tensorrt_llm/executor/worker_base.py index 8353129d2d9..73d0ae2a519 100644 --- a/tensorrt_llm/executor/worker_base.py +++ b/tensorrt_llm/executor/worker_base.py @@ -17,7 +17,7 @@ from ..builder import ConfigEncoder, Engine, EngineConfig from ..llmapi.llm_args import PybindMirror from ..llmapi.tracer import global_tracer -from ..llmapi.utils import _SyncQueue, print_colored_debug +from ..llmapi.utils import _SyncQueue, logger_debug, print_colored_debug from ..logger import logger from ..lora_manager import LoraConfig, LoraManager from ..metrics import RequestEventTiming @@ -92,6 +92,11 @@ def __init__( self._runtime_model_config: Optional[ModelConfig] = None def setup_engine(self) -> None: + logger_debug(f"WorkerBase {self.rank} is setting up the engine", + color="yellow") + + # Force all the ranks to wait here, and start creating the executor simultaneously. + mpi_comm().barrier() device_id = self.global_rank % torch.cuda.device_count() torch.cuda.set_device(device_id) @@ -142,7 +147,14 @@ def setup_engine(self) -> None: else: raise ValueError( f"Unsupported backend config: {executor_config.backend}") + + logger_debug(f"WorkerBase {self.rank} creating py_executor", + color="yellow") self.engine = create_executor(**args) + logger_debug(f"WorkerBase {self.rank} created py_executor", + color="yellow") + logger_debug(f"WorkerBase {self.rank} setup engine done", + color="yellow") self._setup_lora(engine, executor_config, self._lora_config) @@ -468,16 +480,23 @@ def _pop_result(self, client_id: int): self._client_id_to_request_id.pop(client_id, None) def shutdown(self): - if self.engine is not None: - if self.engine.can_enqueue_requests(): - self.engine.shutdown() - self.engine = None - - if hasattr( - self._executor_config, "checkpoint_loader" - ) and self._executor_config.checkpoint_loader is not None: - self._executor_config.checkpoint_loader.cleanup() - self._executor_config.checkpoint_loader = None + if self.engine is None: + return + + logger_debug(f"WorkerBase {self.rank} is shutting down", color="yellow") + if self.engine.can_enqueue_requests(): + self.engine.shutdown() + self.engine = None + + if hasattr(self._executor_config, "checkpoint_loader" + ) and self._executor_config.checkpoint_loader is not None: + self._executor_config.checkpoint_loader.cleanup() + self._executor_config.checkpoint_loader = None + + self.engine = None + + logger_debug(f"WorkerBase {self.rank} shutdown done", color="yellow") + # Check if there are any errors from the threads before shutdown. self._handle_background_error() diff --git a/tensorrt_llm/llmapi/utils.py b/tensorrt_llm/llmapi/utils.py index a08ad9fd03f..4a9c10f0d7a 100644 --- a/tensorrt_llm/llmapi/utils.py +++ b/tensorrt_llm/llmapi/utils.py @@ -1,6 +1,8 @@ import asyncio import collections +import datetime import hashlib +import inspect import io import os import re @@ -64,10 +66,63 @@ def print_colored(message, def print_colored_debug(message, color: Optional[str] = None, writer: io.TextIOWrapper = sys.stderr): - if enable_llm_debug(): + if enable_llmapi_debug(): print_colored(message, color, writer) +def get_current_location(skip_frames: int = 2) -> str: + """ + Get the current execution location in format 'module.class.function'. + + Args: + skip_frames: Number of stack frames to skip (default 2 to skip this function and its caller) + + Returns: + String in format 'module.class.function' or 'module.function' if not in a class + """ + stack = inspect.stack() + if len(stack) <= skip_frames: + return "unknown" + + frame = stack[skip_frames] + module_name = frame.frame.f_globals.get('__name__', 'unknown') + function_name = frame.function + + # Try to determine if we're in a class method + class_name = None + if 'self' in frame.frame.f_locals: + # This is likely an instance method + obj = frame.frame.f_locals['self'] + class_name = obj.__class__.__name__ + elif 'cls' in frame.frame.f_locals: + # This might be a class method + cls = frame.frame.f_locals['cls'] + if inspect.isclass(cls): + class_name = cls.__name__ + + # Build the location string + if class_name: + return f"{module_name}.{class_name}.{function_name}" + else: + return f"{module_name}.{function_name}" + + +def logger_debug(message, + color: Optional[str] = None, + writer: io.TextIOWrapper = sys.stderr): + """ Print the message if the llmapi debug mode is enabled. Fallback to logger.debug if not. """ + if enable_llmapi_debug(): + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + location = get_current_location() + cur_dualname = "..." + location[-47:] if len( + location) > 50 else location + print_colored(f"{timestamp} [{cur_dualname}]", "bold_green", writer) + print_colored(f" {message}\n", color, writer) + else: + # Fallback to logger.debug + logger.debug(message) + + def file_with_glob_exists(directory, glob) -> bool: path = Path(directory) for file_path in path.glob(glob): @@ -290,6 +345,17 @@ def enable_llm_debug() -> bool: return _enable_llm_debug_ +_enable_llmapi_debug_ = None + + +def enable_llmapi_debug() -> bool: + global _enable_llmapi_debug_ + if _enable_llmapi_debug_ is None: + _enable_llmapi_debug_ = os.environ.get("TLLM_LLMAPI_ENABLE_DEBUG", + "0") == "1" + return _enable_llmapi_debug_ + + @cache def enable_worker_single_process_for_tp1() -> bool: ''' Tell whether to make worker use single process for TP1. diff --git a/tests/unittest/executor/test_rpc.py b/tests/unittest/executor/test_rpc.py index b421a1d99a7..daf15026322 100644 --- a/tests/unittest/executor/test_rpc.py +++ b/tests/unittest/executor/test_rpc.py @@ -1,9 +1,11 @@ +import asyncio import time import pytest from tensorrt_llm.executor.rpc import (RPCCancelled, RPCClient, RPCError, - RPCServer, RPCTimeout) + RPCParams, RPCServer, RPCStreamingError, + RPCTimeout) class RpcServerWrapper(RPCServer): @@ -125,7 +127,7 @@ def get_task_submitted(self) -> bool: with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test_no_wait") as server: with RPCClient("ipc:///tmp/rpc_test_no_wait") as client: - client.send_task(__rpc_need_response=False) + client.send_task(__rpc_params=RPCParams(need_response=False)) time.sleep( 0.1 ) # wait for some time to make sure the task is submitted @@ -208,7 +210,8 @@ def task(self): with RPCClient(addr) as client: client.shutdown_server() pending_futures = [ - client.task(__rpc_mode="future") for _ in range(10) + client.task(__rpc_params=RPCParams(mode="future")) + for _ in range(10) ] for future in pending_futures: @@ -237,7 +240,7 @@ def slow_method(self): with RPCClient("ipc:///tmp/rpc_test_timeout", timeout=0.5) as client: with pytest.raises(RPCError) as exc_info: - client.slow_method(__rpc_timeout=0.5) + client.slow_method(__rpc_params=RPCParams(timeout=0.5)) error = exc_info.value # Should be either a timeout error or RPC error indicating timeout @@ -305,14 +308,14 @@ def send_task(self) -> None: with RPCClient("ipc:///tmp/rpc_test_no_wait") as client: time_start = time.time() for i in range(100): - client.send_task(__rpc_need_response=False) + client.send_task(__rpc_params=RPCParams(need_response=False)) time_end = time.time() no_wait_time = time_end - time_start time_start = time.time() for i in range(100): - client.send_task(__rpc_need_response=True) + client.send_task(__rpc_params=RPCParams(need_response=True)) time_end = time.time() wait_time = time_end - time_start @@ -339,7 +342,8 @@ def cal(self, n: int): time_start = time.time() for i in range(100): - ret = client.cal(i, __rpc_timeout=10) # sync call + ret = client.cal( + i, __rpc_params=RPCParams(timeout=10)) # sync call assert ret == i * 2, f"{ret} != {i * 2}" time_end = time.time() print( @@ -347,13 +351,8 @@ def cal(self, n: int): ) -@pytest.mark.parametrize("use_async", [True, False]) -def test_rpc_timeout(use_async: bool): - """Test RPC timeout functionality. - - Args: - use_async: Whether to test async RPC calls or sync RPC calls - """ +class TestRpcTimeout: + """Test RPC timeout functionality for both sync and async calls, sharing server/client.""" class App: @@ -362,53 +361,68 @@ def slow_operation(self, delay: float): time.sleep(delay) return "completed" - # Use manual server lifecycle management to ensure server stays alive - server = RPCServer(App()) - server.bind("ipc:///tmp/rpc_test_timeout") - server.start() - - try: + def setup_method(self, method): + """Setup RPC server and client for timeout tests.""" + # Use unique address based on the test parameter to avoid socket conflicts + test_name = method.__name__ + self.address = f"ipc:///tmp/rpc_test_timeout_{test_name}_{id(self)}" + self.server = RPCServer(self.App()) + self.server.bind(self.address) + self.server.start() time.sleep(0.1) - with RPCClient("ipc:///tmp/rpc_test_timeout") as client: + self.client = RPCClient(self.address) - # Test that a short timeout causes RPCTimeout exception - with pytest.raises(RPCTimeout) as exc_info: - import asyncio - if use_async: + def teardown_method(self): + """Shutdown server and close client.""" + self.client.close() + self.server.shutdown() + # Add a small delay to ensure the socket is fully released before the next test + time.sleep(0.5) - async def test_async_timeout(): - return await client.call_async('slow_operation', - 2.0, - __rpc_timeout=0.1) + def run_sync_timeout_test(self): + with pytest.raises(RPCTimeout) as exc_info: + self.client.slow_operation(2.0, __rpc_params=RPCParams(timeout=0.1)) + assert "timed out" in str( + exc_info.value), f"Timeout message not found: {exc_info.value}" - asyncio.run(test_async_timeout()) - else: - assert client.slow_operation( - 2.0, __rpc_timeout=0.1) # small timeout + def run_async_timeout_test(self): + import asyncio - assert "timed out" in str( - exc_info.value - ), f"Timeout message not found: {exc_info.value}" + async def async_timeout(): + with pytest.raises(RPCTimeout) as exc_info: + await self.client.call_async( + 'slow_operation', 2.0, __rpc_params=RPCParams(timeout=0.1)) + assert "timed out" in str( + exc_info.value), f"Timeout message not found: {exc_info.value}" - # Test that a long timeout allows the operation to complete - if use_async: - import asyncio + asyncio.run(async_timeout()) - async def test_async_success(): - return await client.call_async('slow_operation', - 0.1, - __rpc_timeout=10.0) + def run_sync_success_test(self): + result = self.client.slow_operation( + 0.1, __rpc_params=RPCParams(timeout=10.0)) + assert result == "completed" + print(f"final result: {result}") - result = asyncio.run(test_async_success()) - else: - result = client.slow_operation(0.1, __rpc_timeout=10.0) + def run_async_success_test(self): + import asyncio + async def async_success(): + result = await self.client.call_async( + 'slow_operation', 0.1, __rpc_params=RPCParams(timeout=10.0)) assert result == "completed" - print(f"final result: {result}") + return result + + return asyncio.run(async_success()) - finally: - server.shutdown() + @pytest.mark.parametrize("use_async", [True, False]) + def test_rpc_timeout(self, use_async): + if use_async: + self.run_async_timeout_test() + self.run_async_success_test() + else: + self.run_sync_timeout_test() + self.run_sync_success_test() class TestRpcShutdown: @@ -445,7 +459,8 @@ def foo(self, delay: int): time.sleep(0.1) with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: # This task should be continued after server shutdown - res = client.foo(10, __rpc_timeout=12, __rpc_mode="future") + res = client.foo(10, + __rpc_params=RPCParams(timeout=12, mode="future")) # The shutdown will block until all pending requests are finished server.shutdown() @@ -453,6 +468,164 @@ def foo(self, delay: int): assert res.result() == "foo" +class TestApp: + """Test application with various method types.""" + + def __init__(self): + self.call_count = 0 + + def sync_add(self, a: int, b: int) -> int: + """Sync method.""" + self.call_count += 1 + return a + b + + async def async_multiply(self, x: int, y: int) -> int: + """Async method.""" + await asyncio.sleep(0.01) + self.call_count += 1 + return x * y + + async def streaming_range(self, n: int): + """Streaming generator.""" + for i in range(n): + await asyncio.sleep(0.01) + yield i + + async def streaming_error(self, n: int): + """Streaming generator that raises error.""" + for i in range(n): + if i == 2: + raise ValueError("Test error at i=2") + yield i + + async def streaming_timeout(self, delay: float): + """Streaming generator with configurable delay.""" + for i in range(10): + await asyncio.sleep(delay) + yield i + + +class TestRpcAsync: + # Use setup_method/teardown_method for pytest class-based setup/teardown + def setup_method(self): + """Setup RPC server and client for tests.""" + self.app = TestApp() + self.server = RPCServer(self.app, num_workers=2, async_run_task=True) + self.server.bind("tcp://127.0.0.1:0") # Use random port + self.server.start() + # Get actual address after binding + address = f"tcp://127.0.0.1:{self.server.address.split(':')[-1]}" + self.client = RPCClient(address) + + def teardown_method(self): + self.server.shutdown() + self.client.close() + + @pytest.mark.asyncio + async def test_sync_method(self): + """Test traditional sync method still works.""" + app, client, server = self.app, self.client, self.server + + # Test sync call + result = client.sync_add(5, 3) + assert result == 8 + assert app.call_count == 1 + + @pytest.mark.asyncio + async def test_async_method(self): + """Test async method execution.""" + app, client, server = self.app, self.client, self.server + + # Test async call + result = await client.async_multiply.call_async(4, 7) + assert result == 28 + assert app.call_count == 1 + + @pytest.mark.asyncio + async def test_streaming_basic(self): + """Test basic streaming functionality.""" + app, client, server = self.app, self.client, self.server + + results = [] + async for value in client.streaming_range.call_streaming(5): + results.append(value) + + assert results == [0, 1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_streaming_concurrent(self): + """Test concurrent streaming calls.""" + app, client, server = self.app, self.client, self.server + + async def collect_stream(n): + results = [] + async for value in client.streaming_range.call_streaming(n): + results.append(value) + return results + + # Run 3 concurrent streams + results = await asyncio.gather(collect_stream(3), collect_stream(4), + collect_stream(5)) + + assert results[0] == [0, 1, 2] + assert results[1] == [0, 1, 2, 3] + assert results[2] == [0, 1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_streaming_error_handling(self): + """Test error handling in streaming.""" + app, client, server = self.app, self.client, self.server + + results = [] + with pytest.raises(RPCStreamingError, match="Test error at i=2"): + async for value in client.streaming_error.call_streaming(5): + results.append(value) + + # Should have received values before error + assert results == [0, 1] + + @pytest.mark.asyncio + async def test_streaming_timeout(self): + """Test timeout handling in streaming.""" + app, client, server = self.app, self.client, self.server + + # Set short timeout + with pytest.raises(RPCTimeout): + async for value in client.streaming_timeout.call_streaming( + delay=2.0, __rpc_params=RPCParams(timeout=0.5)): + pass # Should timeout before first yield + + @pytest.mark.asyncio + async def test_mixed_calls(self): + """Test mixing different call types.""" + app, client, server = self.app, self.client, self.server + + # Run sync, async, and streaming calls together + sync_result = client.sync_add(1, 2) + async_future = client.async_multiply.call_future(3, 4) + + streaming_results = [] + async for value in client.streaming_range.call_streaming(3): + streaming_results.append(value) + + async_result = async_future.result() + + assert sync_result == 3 + assert async_result == 12 + assert streaming_results == [0, 1, 2] + assert app.call_count == 2 # sync + async (streaming doesn't increment) + + @pytest.mark.asyncio + async def test_invalid_streaming_call(self): + """Test calling non-streaming method with streaming.""" + app, client, server = self.app, self.client, self.server + + # This should fail because sync_add is not an async generator + with pytest.raises(RPCStreamingError): + async for value in client.call_streaming('sync_add', 1, 2): + pass + + if __name__ == "__main__": #TestRpcError().test_shutdown_cancelled_error() #test_rpc_shutdown_server() diff --git a/tests/unittest/executor/test_rpc_proxy.py b/tests/unittest/executor/test_rpc_proxy.py index 5251d66d7b9..90e1036c5f7 100644 --- a/tests/unittest/executor/test_rpc_proxy.py +++ b/tests/unittest/executor/test_rpc_proxy.py @@ -1,6 +1,8 @@ import os import sys +import time +import pytest from test_worker_base import create_fake_executor_config from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy @@ -17,24 +19,34 @@ model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" -class TestRpcProxyTp1: - - def setup_method(self): - self.executor_config = create_fake_executor_config(model_path) +class TestRpcProxy: def create_proxy(self, tp_size: int): + # Create executor config with the correct tp_size + executor_config = create_fake_executor_config(model_path, + tp_size=tp_size) + mpi_session = MpiPoolSession(n_workers=tp_size) proxy = GenerationExecutorRpcProxy( worker_kwargs={ "engine": model_path, - "executor_config": self.executor_config, + "executor_config": executor_config, "model_world_size": tp_size, }, + model_world_size=tp_size, mpi_session=mpi_session, ) + + # Add additional wait for PyTorch backend with multi-rank setup + if tp_size > 1: + print(f"[Test] Waiting for {tp_size} ranks to initialize...") + time.sleep( + 5) # Give more time for multi-rank PyTorch initialization + return proxy - def test_tp1(self): + @pytest.mark.parametrize("num_reqs", [1, 10]) + def test_tp1(self, num_reqs): tokenizer = TransformersTokenizer.from_pretrained(model_path) prompt = "A B C D" prompt_token_ids = tokenizer.encode(prompt) @@ -42,7 +54,23 @@ def test_tp1(self): with self.create_proxy(tp_size=1) as proxy: sampling_params = SamplingParams(max_tokens=max_tokens) - result = proxy.generate(prompt_token_ids, sampling_params) + for _ in range(num_reqs): + result = proxy.generate(prompt_token_ids, sampling_params) + print(f"get result: {result}") + assert similar(tokenizer.decode(result.outputs[0].token_ids), + 'E F G H I J K L') + + @pytest.mark.parametrize("num_reqs", [1, 10]) + def test_tp2(self, num_reqs): + tokenizer = TransformersTokenizer.from_pretrained(model_path) + prompt = "A B C D" + prompt_token_ids = tokenizer.encode(prompt) + max_tokens = 8 + + with self.create_proxy(tp_size=2) as proxy: + sampling_params = SamplingParams(max_tokens=max_tokens) + for _ in range(num_reqs): + result = proxy.generate(prompt_token_ids, sampling_params) print(f"get result: {result}") assert similar(tokenizer.decode(result.outputs[0].token_ids), 'E F G H I J K L') diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index 97ff0c9b997..7c08a63db34 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -1,15 +1,18 @@ +import asyncio import multiprocessing import os import sys import time from concurrent.futures import ProcessPoolExecutor +import pytest from test_worker_base import create_fake_executor_config from tensorrt_llm.executor.request import GenerationRequest -from tensorrt_llm.executor.rpc import RPCClient +from tensorrt_llm.executor.rpc import RPCClient, RPCParams from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy from tensorrt_llm.executor.rpc_worker import RpcWorker +from tensorrt_llm.llmapi.mpi_session import MpiPoolSession from tensorrt_llm.sampling_params import SamplingParams # isort: off @@ -20,15 +23,24 @@ model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" -class TestRpcWorker: +class TestRpcWorkerTP1: def setup_method(self): self.executor_config = create_fake_executor_config(model_path) + self.pool, self.addr = self.create_worker_pool() + self.client = self.create_rpc_client(self.addr) + self.client.setup_engine() + time.sleep(10) - def create_tp1_worker_process(self): + def teardown_method(self): + self.client.shutdown() + self.pool.shutdown() + self.client.close() + + def create_worker_pool(self): addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr() - # Use spawn method instead of fork - mp_context = multiprocessing.get_context('spawn') + mp_context = multiprocessing.get_context( + 'spawn') # spawn for CUDA context pool = ProcessPoolExecutor(max_workers=1, mp_context=mp_context) pool.submit(RpcWorker.main_task, engine=model_path, @@ -40,17 +52,94 @@ def create_rpc_client(self, addr: str): client = RPCClient(addr) return client + def test_create_shutdown(self): + pass + + def test_fetch_responses_sync(self): + self.client.submit(GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5)), + __rpc_params=RPCParams(need_response=False)) + results = self.client.fetch_responses() + assert len(results) == 1 + + def test_fetch_responses_streaming_sync(self): + self.client.submit(GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5), + streaming=True), + __rpc_params=RPCParams(need_response=False)) + + results = [] + for i in range(10): + res = self.client.fetch_responses() + results.extend(res) + print(f"fetch_responses {i} result: {results}") + assert 0 < len(results) <= 5 + + time.sleep(5) + + @pytest.mark.asyncio + async def test_fetch_responses_streaming_async(self): + self.client.submit(GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5), + streaming=True), + __rpc_params=RPCParams(need_response=False)) + + results = [] + # Must fetch all the responses, or the PyExecutor will hang + for i in range(10): + res = await self.client.fetch_responses_async.call_async() + results.extend(res) + print(f"fetch_responses_async {i} result: {results}") + assert 0 < len(results) <= 5 + + @pytest.mark.asyncio + @pytest.mark.parametrize("req_count", [10]) + async def test_main_loop_async(self, req_count: int): + await asyncio.sleep(1) + + async def process_request_streaming(): + for i in range(req_count): + ret = self.client.submit( + GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5), + streaming=True), + __rpc_params=RPCParams(need_response=False)) + assert ret is None + print("submit result: ", ret) + + # NOTE: known issue, the responses should be fetched before shutdown, + # or the shutdown will hang. + results = [] + + print(f"start to fetch_responses_async") + no = 0 + async for result in self.client.fetch_responses_loop_async.call_streaming( + ): + print(f"fetch_responses_async {no} result: {result}") + results.extend(result) # result is a list of responses + no += 1 + if no >= req_count * 5: # Break after receiving 5 batches + print(f"break after receiving {no} batches") + break + print(f"Received {no} batches of streaming responses") + print(f"fetch_responses result: {results}") + assert results + + await process_request_streaming() + def test_main_loop(self): - pool, addr = self.create_tp1_worker_process() - client = self.create_rpc_client(addr) - client.setup_engine(__rpc_timeout=120) time.sleep(1) def process_request(): - ret = client.submit(GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10)), - __rpc_need_response=False) + ret = self.client.submit( + GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=10)), + __rpc_params=RPCParams(need_response=False)) assert ret is None # need_response = False print(f"submit result: {ret}") @@ -60,16 +149,16 @@ def process_request(): results = [] time.sleep(8) # wait for PyExecutor to finish the generation results.extend( - client.fetch_responses()) # fetch_responses will block + self.client.fetch_responses()) # fetch_responses will block print(f"fetch_responses result: {results}") assert len(results) == 1 # one request, one response def process_request_streaming(): - ret = client.submit(GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10), - streaming=True), - __rpc_need_response=False) + ret = self.client.submit( + GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=10), + streaming=True), + __rpc_params=RPCParams(need_response=False)) assert ret is None print("submit result: ", ret) @@ -80,7 +169,9 @@ def process_request_streaming(): while not results: time.sleep(1) - results.extend(client.fetch_responses(__rpc_timeout=10)) + results.extend( + self.client.fetch_responses(__rpc_params=RPCParams( + timeout=10))) print(f"try fetch_responses result: {results}") print(f"fetch_responses result: {results}") assert results @@ -89,12 +180,44 @@ def process_request_streaming(): process_request() process_request_streaming() - print("call shutdown") - client.shutdown(__rpc_timeout=10) - pool.shutdown() - client.close() +class TestRpcWorkerTP2: -if __name__ == '__main__': - worker = TestRpcWorker() - worker.test_main_loop() + def setup_method(self): + self.executor_config = create_fake_executor_config(model_path, + tp_size=2) + self.session, self.addr, self.futures = self.create_worker_session() + self.client = self.create_rpc_client(self.addr) + self.client.setup_engine() + time.sleep(10) + + def teardown_method(self): + self.client.shutdown() + self.session.shutdown() + self.client.close() + + def create_worker_session(self): + session = MpiPoolSession(n_workers=2) + addr = GenerationExecutorRpcProxy.gen_uniq_rpc_addr() + futures = session.submit(RpcWorker.main_task, + engine=model_path, + rpc_addr=addr, + executor_config=self.executor_config, + model_world_size=2) + return session, addr, futures + + def create_rpc_client(self, addr: str): + return RPCClient(addr) + + def test_create_shutdown(self): + # Invoke setup_engine in rank 0, and that will unblock all the ranks to + # invoke setup_engine simultaneously. + pass + + def test_fetch_responses_sync(self): + self.client.submit(GenerationRequest( + prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5)), + __rpc_params=RPCParams(need_response=False)) + results = self.client.fetch_responses() + assert len(results) == 1 diff --git a/tests/unittest/executor/test_worker_base.py b/tests/unittest/executor/test_worker_base.py index d40efe756b5..a5fbdb7ae75 100644 --- a/tests/unittest/executor/test_worker_base.py +++ b/tests/unittest/executor/test_worker_base.py @@ -3,11 +3,15 @@ import time import pytest +import torch + +from tensorrt_llm._utils import mpi_comm, mpi_rank, mpi_world_size +from tensorrt_llm.bindings import executor as tllm +from tensorrt_llm.llmapi.mpi_session import MpiPoolSession, set_mpi_session_cpp # isort: off sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") from utils.llm_data import llm_models_root -from tensorrt_llm.bindings import executor as tllm # isort: on from tensorrt_llm._torch.pyexecutor.config import update_executor_config @@ -82,11 +86,16 @@ def test_fetch_responses_timeout(self, timeout: float): assert results is None -def create_fake_executor_config(model_path): +def create_fake_executor_config(model_path, tp_size=1): llm_args = LlmArgs(model=model_path, cuda_graph_config=None) executor_config = tllm.ExecutorConfig(1) executor_config.max_batch_size = 1 + executor_config.model_world_size = tp_size + + # For PyTorch backend with TP > 1, we need proper parallel config + if tp_size > 1: + llm_args.parallel_config.tp_size = tp_size update_executor_config( executor_config, @@ -104,6 +113,70 @@ def create_fake_executor_config(model_path): return executor_config +class TestRpcWorkerBaseTP2: + + def setup_method(self): + self.executor_config = create_fake_executor_config(model_path, + tp_size=2) + self.session = self.create_worker_session() + # No need to sleep here - the session is ready immediately + + def create_worker_session(self): + session = MpiPoolSession(n_workers=2) + return session + + def test_create_executor(self): + futures = self.session.submit(TestRpcWorkerBaseTP2.create_executor, + engine=model_path, + executor_config=self.executor_config) + # Wait for completion + for future in futures: + future.result() + + self.session.shutdown() + + @staticmethod + def create_executor(engine, executor_config): + # Set MPI session for C++ backend + set_mpi_session_cpp(mpi_comm()) + + # Set CUDA device for this rank + rank = mpi_rank() + world_size = mpi_world_size() + device_id = rank % torch.cuda.device_count() + torch.cuda.set_device(device_id) + + # Don't set CUDA_VISIBLE_DEVICES as it interferes with MPI multi-GPU setup + + print(f"[Test] Rank {rank}/{world_size} using device {device_id}") + + # Synchronize all workers before creating executor + mpi_comm().barrier() + + try: + print(f"[Test] Rank {rank} creating WorkerBase...") + executor = WorkerBase(engine=engine, + executor_config=executor_config) + + # For PyTorch backend, all ranks need to participate in setup + print(f"[Test] Rank {rank} calling setup_engine...") + + # Setup the engine which contains another barrier + executor.setup_engine() + + print(f"[Test] Rank {rank} setup_engine completed successfully") + + executor.shutdown() + + except Exception as e: + print(f"[Test] Rank {rank} failed with error: {e}") + import traceback + traceback.print_exc() + raise + + return None # executor cannot be picked and returned + + if __name__ == "__main__": test_worker_base = TestWorkerBase() test_worker_base.test_fetch_stats() diff --git a/tests/unittest/llmapi/test_llm_utils.py b/tests/unittest/llmapi/test_llm_utils.py index 5155e158e62..58608556c9d 100644 --- a/tests/unittest/llmapi/test_llm_utils.py +++ b/tests/unittest/llmapi/test_llm_utils.py @@ -1,9 +1,13 @@ +import asyncio import tempfile +import threading +import time from pathlib import Path import torch from tensorrt_llm.llmapi.llm_utils import * +from tensorrt_llm.llmapi.utils import AsyncQueue # isort: off from .test_llm import llama_model_path @@ -58,3 +62,25 @@ def test_LlmArgs_default_gpus_per_node(): # set explicitly llm_args = TrtLlmArgs(model=llama_model_path, gpus_per_node=6) assert llm_args.gpus_per_node == 6 + + +def test_AsyncQueue(): + queue = AsyncQueue() + + # put data to queue sync in a thread + # async get data from queue in the current event loop + # NOTE: the event loop in the two threads are different + + def put_data_to_queue(): + for i in range(10): + time.sleep(0.1) + queue.put(i) + + async def get_data_from_queue(): + for i in range(10): + print(f"get: {queue.get()}") + + thread = threading.Thread(target=put_data_to_queue) + thread.start() + asyncio.run(get_data_from_queue()) + thread.join() diff --git a/tests/unittest/pytest.ini b/tests/unittest/pytest.ini index 1e93eae77ab..f78281579df 100644 --- a/tests/unittest/pytest.ini +++ b/tests/unittest/pytest.ini @@ -2,7 +2,7 @@ xdist_start_method = spawn asyncio_default_fixture_loop_scope = module threadleak = True -threadleak_exclude = asyncio_\d+ +threadleak_exclude = asyncio_\d+|rpc_client_loop|rpc_client_worker_\d+|rpc_server_worker_\d+ addopts = --durations=0 -W ignore::DeprecationWarning pythonpath = _torch/auto_deploy/_utils_test From dc9f7453dcbf7ceb776e81b0024e7552f55605d0 Mon Sep 17 00:00:00 2001 From: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Date: Mon, 15 Sep 2025 13:35:43 +0800 Subject: [PATCH 06/13] refactor RPC client interface Change to something similar to Ray but with more configuration: client.xxx_api(args).remote(rpc_params) Signed-off-by: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/rpc/rpc_client.py | 111 +++++++++++++-------- tensorrt_llm/executor/rpc/rpc_server.py | 10 +- tensorrt_llm/executor/rpc_proxy.py | 20 ++-- tensorrt_llm/executor/rpc_worker.py | 2 +- tests/unittest/executor/test_rpc.py | 80 +++++++-------- tests/unittest/executor/test_rpc_worker.py | 79 ++++++++------- 6 files changed, 159 insertions(+), 143 deletions(-) diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index 723dfeff944..baa8305f64d 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -11,6 +11,59 @@ RPCStreamingError, RPCTimeout) +class RemoteCall: + """Helper class to enable chained remote call syntax like client.method().remote()""" + + def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs): + self.client = client + self.method_name = method_name + self.args = args + self.kwargs = kwargs + + def remote(self, + timeout: Optional[float] = None, + need_response: bool = True) -> Any: + """Synchronous remote call with optional RPC parameters.""" + rpc_params = RPCParams(timeout=timeout, + need_response=need_response, + mode="sync") + self.kwargs["__rpc_params"] = rpc_params + return self.client._call_sync(self.method_name, *self.args, + **self.kwargs) + + def remote_async(self, + timeout: Optional[float] = None, + need_response: bool = True): + """Asynchronous remote call that returns a coroutine.""" + rpc_params = RPCParams(timeout=timeout, + need_response=need_response, + mode="async") + self.kwargs["__rpc_params"] = rpc_params + return self.client._call_async(self.method_name, *self.args, + **self.kwargs) + + def remote_future(self, + timeout: Optional[float] = None, + need_response: bool = True) -> concurrent.futures.Future: + """Remote call that returns a Future object.""" + rpc_params = RPCParams(timeout=timeout, + need_response=need_response, + mode="future") + self.kwargs["__rpc_params"] = rpc_params + return self.client.call_future(self.method_name, *self.args, + **self.kwargs) + + def remote_streaming(self, + timeout: Optional[float] = None) -> AsyncIterator[Any]: + """Remote call for streaming results.""" + rpc_params = RPCParams(timeout=timeout, + need_response=True, + mode="async") + self.kwargs["__rpc_params"] = rpc_params + return self.client.call_streaming(self.method_name, *self.args, + **self.kwargs) + + class RPCClient: """ An RPC Client that connects to the RPCServer. @@ -53,7 +106,7 @@ def shutdown_server(self): if self._server_stopped: return - self.call_sync("__rpc_shutdown") + self._rpc_shutdown().remote() self._server_stopped = True @@ -258,6 +311,9 @@ def run_loop(): def _call_sync(self, method_name, *args, **kwargs): """Synchronous version of RPC call.""" + logger_debug( + f"RPC Client calling method: {method_name} with args: {args} and kwargs: {kwargs}" + ) self._ensure_event_loop() future = asyncio.run_coroutine_threadsafe( self._call_async(method_name, *args, **kwargs), self._loop) @@ -412,54 +468,25 @@ async def call_streaming(self, name: str, *args, def get_server_attr(self, name: str): """ Get the attribute of the RPC server. This is mainly used for testing. """ - return self._call_sync("__rpc_get_attr", - name, - __rpc_params=RPCParams(timeout=10)) + return self._rpc_get_attr(name).remote() def __getattr__(self, name): """ Magically handles calls to non-existent methods. - Returns a proxy object that supports multiple calling patterns. - """ + Returns a callable that when invoked returns a RemoteCall instance. - class MethodProxy: - - def __init__(self, client, method_name): - self.client = client - self.method_name = method_name - - def __call__(self, *args, **kwargs): - """Default synchronous call""" - rpc_params = kwargs.get("__rpc_params", RPCParams()) - mode = rpc_params.mode - if mode == "sync": - return self.client._call_sync(self.method_name, *args, - **kwargs) - elif mode == "async": - return self.client._call_async(self.method_name, *args, - **kwargs) - elif mode == "future": - return self.client.call_future(self.method_name, *args, - **kwargs) - else: - raise ValueError(f"Invalid RPC mode: {mode}") - - def call_async(self, *args, **kwargs): - """Async call - returns coroutine""" - return self.client._call_async(self.method_name, *args, - **kwargs) - - def call_future(self, *args, **kwargs) -> concurrent.futures.Future: - """Future call - returns Future object""" - return self.client.call_future(self.method_name, *args, - **kwargs) + This enables the new syntax: + client.method(args).remote() + await client.method(args).remote_async() + client.method(args).remote_future() + async for x in client.method(args).remote_streaming() + """ + logger_debug(f"RPC Client getting attribute: {name}") - def call_streaming(self, *args, **kwargs) -> AsyncIterator[Any]: - """Streaming call - returns async iterator""" - return self.client.call_streaming(self.method_name, *args, - **kwargs) + def method_caller(*args, **kwargs): + return RemoteCall(self, name, *args, **kwargs) - return MethodProxy(self, name) + return method_caller def __enter__(self): return self diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py index cf6e05486d3..dd7246b4ed3 100644 --- a/tensorrt_llm/executor/rpc/rpc_server.py +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -50,8 +50,8 @@ def __init__(self, self._num_pending_requests = 0 self._functions = { - "__rpc_shutdown": lambda: self.shutdown(is_remote_call=True), - "__rpc_get_attr": lambda name: self.get_attr(name), + "_rpc_shutdown": lambda: self.shutdown(is_remote_call=True), + "_rpc_get_attr": lambda name: self.get_attr(name), } self._dispatcher_thread: Optional[ManagedThread] = None if async_run_task: @@ -197,7 +197,7 @@ async def _dispatcher_routine(self, stop_event: threading.Event): # shutdown methods depend on _num_pending_requests, so # they should not be counted - if req.method_name not in ["__rpc_shutdown", "shutdown"]: + if req.method_name not in ["_rpc_shutdown", "shutdown"]: self._num_pending_requests += 1 logger_debug( f"Dispatcher received request {req}, pending: {self._num_pending_requests}" @@ -221,6 +221,8 @@ async def _worker_routine(self, stop_event: threading.Event): if req.method_name not in self._functions: logger.error( f"Method '{req.method_name}' not found in RPC server.") + self._num_pending_requests -= 1 + if not req.need_response: continue if req.is_streaming: @@ -274,7 +276,7 @@ async def _worker_routine(self, stop_event: threading.Event): logger_debug(f"RPC Server sent response for request {req}") # Only decrement if this request was counted in the first place - if req.method_name not in ["__rpc_shutdown", "shutdown"]: + if req.method_name not in ["_rpc_shutdown", "shutdown"]: self._num_pending_requests -= 1 async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]: diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 63783033ef8..ef08e73e1d0 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -15,7 +15,6 @@ from .request import GenerationRequest from .result import GenerationResult from .rpc import RPCClient -from .rpc.rpc_common import RPCParams from .rpc_worker import RpcWorker from .utils import (ErrorResponse, create_mpi_comm_session, get_spawn_proxy_process_env, is_llm_response) @@ -89,8 +88,8 @@ async def main_loop_task(self): """ Main loop of the proxy, it will invoke the actions periodically. """ - async for responses in self.rpc_client.fetch_responses_loop_async.call_streaming( - ): + async for responses in self.rpc_client.fetch_responses_loop_async( + ).remote_streaming(): if self._shutdown_event.is_set(): return self.handle_responses(responses) @@ -145,8 +144,7 @@ def submit(self, request: GenerationRequest) -> GenerationResult: logprob_params = self._get_logprob_params(request) # submit is a fire-and-forget operation, don't need to wait for response - self.rpc_client.submit(request, - __rpc_params=RPCParams(need_response=False)) + self.rpc_client.submit(request).remote(need_response=False) result = GenerationResult( request, @@ -159,22 +157,20 @@ def submit(self, request: GenerationRequest) -> GenerationResult: return result def fetch_responses_remote(self): - return self.rpc_client.fetch_responses(__rpc_params=RPCParams( - timeout=20)) + return self.rpc_client.fetch_responses().remote(timeout=20) def fetch_stats_remote(self): - return self.rpc_client.fetch_stats() + return self.rpc_client.fetch_stats().remote() def setup_engine_remote(self): - return self.rpc_client.setup_engine(__rpc_params=RPCParams( - need_response=True)) + return self.rpc_client.setup_engine().remote(need_response=True) def shutdown_remote(self): logger_debug(f"Shutting down rpc remote", color="yellow") - self.rpc_client.shutdown() + self.rpc_client.shutdown().remote() def abort_request(self, request_id: int) -> None: - return self.rpc_client.abort_request(request_id) + return self.rpc_client.abort_request(request_id).remote() def shutdown(self): if self._shutdown_event.is_set(): diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 4ffe866ae9b..85a799dc1f9 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -57,8 +57,8 @@ def fetch_responses(self, timeout: Optional[float] = None) -> list: color="yellow") # NOTE: This is a blocking call, it will wait for the responses to be available. super().await_responses(timeout) - logger_debug(f"RpcWorker returning responses", color="yellow") qsize = self._response_queue.qsize() + logger_debug(f"RpcWorker returning {qsize} responses", color="yellow") return [self._response_queue.get() for _ in range(qsize)] async def fetch_responses_async(self) -> list: diff --git a/tests/unittest/executor/test_rpc.py b/tests/unittest/executor/test_rpc.py index daf15026322..e84be8fad34 100644 --- a/tests/unittest/executor/test_rpc.py +++ b/tests/unittest/executor/test_rpc.py @@ -4,8 +4,7 @@ import pytest from tensorrt_llm.executor.rpc import (RPCCancelled, RPCClient, RPCError, - RPCParams, RPCServer, RPCStreamingError, - RPCTimeout) + RPCServer, RPCStreamingError, RPCTimeout) class RpcServerWrapper(RPCServer): @@ -45,7 +44,7 @@ def hello(self): with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: with RPCClient("ipc:///tmp/rpc_test") as client: - ret = client.hello() # sync call + ret = client.hello().remote() # sync call assert ret == "world" def test_remote_call_with_args(self): @@ -58,7 +57,7 @@ def hello(self, name: str, location: str): with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: with RPCClient("ipc:///tmp/rpc_test") as client: - ret = client.hello("app", "Marvel") + ret = client.hello("app", "Marvel").remote() assert ret == "hello app from Marvel" def test_remote_call_with_kwargs(self): @@ -71,7 +70,7 @@ def hello(self, name: str, location: str): with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: with RPCClient("ipc:///tmp/rpc_test") as client: - ret = client.hello(name="app", location="Marvel") + ret = client.hello(name="app", location="Marvel").remote() assert ret == "hello app from Marvel" def test_remote_call_with_args_and_kwargs(self): @@ -84,7 +83,7 @@ def hello(self, name: str, location: str): with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test") as server: with RPCClient("ipc:///tmp/rpc_test") as client: - ret = client.hello(name="app", location="Marvel") + ret = client.hello(name="app", location="Marvel").remote() assert ret == "hello app from Marvel" def test_rpc_server_address(self): @@ -106,7 +105,7 @@ def hello(self): addr="ipc:///tmp/rpc_test_error") as server: with RPCClient("ipc:///tmp/rpc_test_error") as client: with pytest.raises(RPCError): - client.hello() + client.hello().remote() def test_rpc_without_wait_response(self): @@ -127,11 +126,11 @@ def get_task_submitted(self) -> bool: with RpcServerWrapper(App(), addr="ipc:///tmp/rpc_test_no_wait") as server: with RPCClient("ipc:///tmp/rpc_test_no_wait") as client: - client.send_task(__rpc_params=RPCParams(need_response=False)) + client.send_task().remote(need_response=False) time.sleep( 0.1 ) # wait for some time to make sure the task is submitted - assert client.get_task_submitted() + assert client.get_task_submitted().remote() class TestRpcError: @@ -160,7 +159,7 @@ def custom_exception(self): with RPCClient("ipc:///tmp/rpc_test_error") as client: # Test ValueError handling with pytest.raises(RPCError) as exc_info: - client.hello() + client.hello().remote() error = exc_info.value assert "Test error message" in str(error) @@ -171,7 +170,7 @@ def custom_exception(self): # Test ZeroDivisionError handling with pytest.raises(RPCError) as exc_info: - client.divide_by_zero() + client.divide_by_zero().remote() error = exc_info.value assert "division by zero" in str(error) @@ -181,7 +180,7 @@ def custom_exception(self): # Test custom exception handling with pytest.raises(RPCError) as exc_info: - client.custom_exception() + client.custom_exception().remote() error = exc_info.value assert "Custom error occurred" in str(error) @@ -209,10 +208,7 @@ def task(self): with RPCClient(addr) as client: client.shutdown_server() - pending_futures = [ - client.task(__rpc_params=RPCParams(mode="future")) - for _ in range(10) - ] + pending_futures = [client.task().remote_future() for _ in range(10)] for future in pending_futures: with pytest.raises(RPCCancelled): @@ -240,7 +236,7 @@ def slow_method(self): with RPCClient("ipc:///tmp/rpc_test_timeout", timeout=0.5) as client: with pytest.raises(RPCError) as exc_info: - client.slow_method(__rpc_params=RPCParams(timeout=0.5)) + client.slow_method().remote(timeout=0.5) error = exc_info.value # Should be either a timeout error or RPC error indicating timeout @@ -261,7 +257,7 @@ def existing_method(self): with RPCClient("ipc:///tmp/rpc_test_not_found") as client: with pytest.raises(RPCError) as exc_info: - client.non_existent_method() + client.non_existent_method().remote() error = exc_info.value assert "not found" in str(error) @@ -280,7 +276,7 @@ def hello(self): server.start() time.sleep(0.1) with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: - ret = client.hello() + ret = client.hello().remote() assert ret == "world" client.shutdown_server() @@ -308,14 +304,14 @@ def send_task(self) -> None: with RPCClient("ipc:///tmp/rpc_test_no_wait") as client: time_start = time.time() for i in range(100): - client.send_task(__rpc_params=RPCParams(need_response=False)) + client.send_task().remote(need_response=False) time_end = time.time() no_wait_time = time_end - time_start time_start = time.time() for i in range(100): - client.send_task(__rpc_params=RPCParams(need_response=True)) + client.send_task().remote(need_response=True) time_end = time.time() wait_time = time_end - time_start @@ -342,8 +338,7 @@ def cal(self, n: int): time_start = time.time() for i in range(100): - ret = client.cal( - i, __rpc_params=RPCParams(timeout=10)) # sync call + ret = client.cal(i).remote(timeout=10) # sync call assert ret == i * 2, f"{ret} != {i * 2}" time_end = time.time() print( @@ -381,7 +376,7 @@ def teardown_method(self): def run_sync_timeout_test(self): with pytest.raises(RPCTimeout) as exc_info: - self.client.slow_operation(2.0, __rpc_params=RPCParams(timeout=0.1)) + self.client.slow_operation(2.0).remote(timeout=0.1) assert "timed out" in str( exc_info.value), f"Timeout message not found: {exc_info.value}" @@ -390,16 +385,14 @@ def run_async_timeout_test(self): async def async_timeout(): with pytest.raises(RPCTimeout) as exc_info: - await self.client.call_async( - 'slow_operation', 2.0, __rpc_params=RPCParams(timeout=0.1)) + await self.client.slow_operation(2.0).remote_async(timeout=0.1) assert "timed out" in str( exc_info.value), f"Timeout message not found: {exc_info.value}" asyncio.run(async_timeout()) def run_sync_success_test(self): - result = self.client.slow_operation( - 0.1, __rpc_params=RPCParams(timeout=10.0)) + result = self.client.slow_operation(0.1).remote(timeout=10.0) assert result == "completed" print(f"final result: {result}") @@ -407,8 +400,8 @@ def run_async_success_test(self): import asyncio async def async_success(): - result = await self.client.call_async( - 'slow_operation', 0.1, __rpc_params=RPCParams(timeout=10.0)) + result = await self.client.slow_operation(0.1).remote_async( + timeout=10.0) assert result == "completed" print(f"final result: {result}") return result @@ -438,7 +431,7 @@ def quick_task(self, task_id: int): addr="ipc:///tmp/rpc_test_shutdown") as server: time.sleep(0.1) with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: - client.quick_task(1) + client.quick_task(1).remote() # repeated shutdown should not raise an error for i in range(10): @@ -459,8 +452,7 @@ def foo(self, delay: int): time.sleep(0.1) with RPCClient("ipc:///tmp/rpc_test_shutdown") as client: # This task should be continued after server shutdown - res = client.foo(10, - __rpc_params=RPCParams(timeout=12, mode="future")) + res = client.foo(10).remote_future(timeout=12) # The shutdown will block until all pending requests are finished server.shutdown() @@ -527,7 +519,7 @@ async def test_sync_method(self): app, client, server = self.app, self.client, self.server # Test sync call - result = client.sync_add(5, 3) + result = client.sync_add(5, 3).remote() assert result == 8 assert app.call_count == 1 @@ -537,7 +529,7 @@ async def test_async_method(self): app, client, server = self.app, self.client, self.server # Test async call - result = await client.async_multiply.call_async(4, 7) + result = await client.async_multiply(4, 7).remote_async() assert result == 28 assert app.call_count == 1 @@ -547,7 +539,7 @@ async def test_streaming_basic(self): app, client, server = self.app, self.client, self.server results = [] - async for value in client.streaming_range.call_streaming(5): + async for value in client.streaming_range(5).remote_streaming(): results.append(value) assert results == [0, 1, 2, 3, 4] @@ -559,7 +551,7 @@ async def test_streaming_concurrent(self): async def collect_stream(n): results = [] - async for value in client.streaming_range.call_streaming(n): + async for value in client.streaming_range(n).remote_streaming(): results.append(value) return results @@ -578,7 +570,7 @@ async def test_streaming_error_handling(self): results = [] with pytest.raises(RPCStreamingError, match="Test error at i=2"): - async for value in client.streaming_error.call_streaming(5): + async for value in client.streaming_error(5).remote_streaming(): results.append(value) # Should have received values before error @@ -591,8 +583,8 @@ async def test_streaming_timeout(self): # Set short timeout with pytest.raises(RPCTimeout): - async for value in client.streaming_timeout.call_streaming( - delay=2.0, __rpc_params=RPCParams(timeout=0.5)): + async for value in client.streaming_timeout( + delay=2.0).remote_streaming(timeout=0.5): pass # Should timeout before first yield @pytest.mark.asyncio @@ -601,11 +593,11 @@ async def test_mixed_calls(self): app, client, server = self.app, self.client, self.server # Run sync, async, and streaming calls together - sync_result = client.sync_add(1, 2) - async_future = client.async_multiply.call_future(3, 4) + sync_result = client.sync_add(1, 2).remote() + async_future = client.async_multiply(3, 4).remote_future() streaming_results = [] - async for value in client.streaming_range.call_streaming(3): + async for value in client.streaming_range(3).remote_streaming(): streaming_results.append(value) async_result = async_future.result() @@ -622,7 +614,7 @@ async def test_invalid_streaming_call(self): # This should fail because sync_add is not an async generator with pytest.raises(RPCStreamingError): - async for value in client.call_streaming('sync_add', 1, 2): + async for value in client.sync_add(1, 2).remote_streaming(): pass diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index 7c08a63db34..11e5304ab5b 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -9,7 +9,7 @@ from test_worker_base import create_fake_executor_config from tensorrt_llm.executor.request import GenerationRequest -from tensorrt_llm.executor.rpc import RPCClient, RPCParams +from tensorrt_llm.executor.rpc import RPCClient from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy from tensorrt_llm.executor.rpc_worker import RpcWorker from tensorrt_llm.llmapi.mpi_session import MpiPoolSession @@ -29,11 +29,11 @@ def setup_method(self): self.executor_config = create_fake_executor_config(model_path) self.pool, self.addr = self.create_worker_pool() self.client = self.create_rpc_client(self.addr) - self.client.setup_engine() + self.client.setup_engine().remote() time.sleep(10) def teardown_method(self): - self.client.shutdown() + self.client.shutdown().remote() self.pool.shutdown() self.client.close() @@ -56,23 +56,24 @@ def test_create_shutdown(self): pass def test_fetch_responses_sync(self): - self.client.submit(GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=5)), - __rpc_params=RPCParams(need_response=False)) - results = self.client.fetch_responses() + self.client.submit( + GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams( + max_tokens=5)), ).remote(need_response=False) + results = [] + while not results: + results.extend(self.client.fetch_responses().remote()) assert len(results) == 1 def test_fetch_responses_streaming_sync(self): - self.client.submit(GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=5), - streaming=True), - __rpc_params=RPCParams(need_response=False)) + self.client.submit( + GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5), + streaming=True), ).remote(need_response=False) results = [] for i in range(10): - res = self.client.fetch_responses() + res = self.client.fetch_responses().remote() results.extend(res) print(f"fetch_responses {i} result: {results}") assert 0 < len(results) <= 5 @@ -81,16 +82,15 @@ def test_fetch_responses_streaming_sync(self): @pytest.mark.asyncio async def test_fetch_responses_streaming_async(self): - self.client.submit(GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=5), - streaming=True), - __rpc_params=RPCParams(need_response=False)) + self.client.submit( + GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams(max_tokens=5), + streaming=True), ).remote(need_response=False) results = [] # Must fetch all the responses, or the PyExecutor will hang for i in range(10): - res = await self.client.fetch_responses_async.call_async() + res = await self.client.fetch_responses_async().remote_async() results.extend(res) print(f"fetch_responses_async {i} result: {results}") assert 0 < len(results) <= 5 @@ -106,8 +106,7 @@ async def process_request_streaming(): GenerationRequest( prompt_token_ids=[3, 4, 5], sampling_params=SamplingParams(max_tokens=5), - streaming=True), - __rpc_params=RPCParams(need_response=False)) + streaming=True), ).remote(need_response=False) assert ret is None print("submit result: ", ret) @@ -117,8 +116,8 @@ async def process_request_streaming(): print(f"start to fetch_responses_async") no = 0 - async for result in self.client.fetch_responses_loop_async.call_streaming( - ): + async for result in self.client.fetch_responses_loop_async( + ).remote_streaming(): print(f"fetch_responses_async {no} result: {result}") results.extend(result) # result is a list of responses no += 1 @@ -138,8 +137,8 @@ def process_request(): ret = self.client.submit( GenerationRequest( prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10)), - __rpc_params=RPCParams(need_response=False)) + sampling_params=SamplingParams(max_tokens=10)), ).remote( + need_response=False) assert ret is None # need_response = False print(f"submit result: {ret}") @@ -148,8 +147,8 @@ def process_request(): # or the shutdown will hang. results = [] time.sleep(8) # wait for PyExecutor to finish the generation - results.extend( - self.client.fetch_responses()) # fetch_responses will block + results.extend(self.client.fetch_responses().remote() + ) # fetch_responses will block print(f"fetch_responses result: {results}") assert len(results) == 1 # one request, one response @@ -157,8 +156,7 @@ def process_request_streaming(): ret = self.client.submit( GenerationRequest(prompt_token_ids=[3, 4, 5], sampling_params=SamplingParams(max_tokens=10), - streaming=True), - __rpc_params=RPCParams(need_response=False)) + streaming=True), ).remote(need_response=False) assert ret is None print("submit result: ", ret) @@ -169,9 +167,7 @@ def process_request_streaming(): while not results: time.sleep(1) - results.extend( - self.client.fetch_responses(__rpc_params=RPCParams( - timeout=10))) + results.extend(self.client.fetch_responses().remote(timeout=10)) print(f"try fetch_responses result: {results}") print(f"fetch_responses result: {results}") assert results @@ -188,11 +184,11 @@ def setup_method(self): tp_size=2) self.session, self.addr, self.futures = self.create_worker_session() self.client = self.create_rpc_client(self.addr) - self.client.setup_engine() + self.client.setup_engine().remote() time.sleep(10) def teardown_method(self): - self.client.shutdown() + self.client.shutdown().remote() self.session.shutdown() self.client.close() @@ -215,9 +211,12 @@ def test_create_shutdown(self): pass def test_fetch_responses_sync(self): - self.client.submit(GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=5)), - __rpc_params=RPCParams(need_response=False)) - results = self.client.fetch_responses() + self.client.submit( + GenerationRequest(prompt_token_ids=[3, 4, 5], + sampling_params=SamplingParams( + max_tokens=5)), )\ + .remote(need_response=False) + results = [] + while not results: + results.extend(self.client.fetch_responses().remote()) assert len(results) == 1 From 1d1fd7c1278d48d366eed58aa4047f00a7eaa59a Mon Sep 17 00:00:00 2001 From: Yan Chunwei <328693+Superjomn@users.noreply.github.com> Date: Mon, 15 Sep 2025 21:56:00 +0800 Subject: [PATCH 07/13] fix WorkerBase and test Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/result.py | 3 +- tensorrt_llm/executor/rpc/rpc_client.py | 1 - tensorrt_llm/executor/rpc_worker.py | 12 +- tensorrt_llm/executor/worker.py | 19 +- tensorrt_llm/executor/worker_base.py | 288 +++++++++++++------- tests/unittest/executor/test_rpc_proxy.py | 7 +- tests/unittest/executor/test_rpc_worker.py | 63 ++++- tests/unittest/executor/test_worker_base.py | 25 +- 8 files changed, 283 insertions(+), 135 deletions(-) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index d19a8368297..927d6b0a9d4 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -14,7 +14,7 @@ from ..bindings import executor as tllm from ..disaggregated_params import DisaggregatedParams from ..llmapi.tracer import global_tracer -from ..llmapi.utils import AsyncQueue +from ..llmapi.utils import AsyncQueue, print_traceback_on_error from ..metrics import MetricNames, MetricsCollector, RequestEventTiming from ..sampling_params import LogprobParams, SamplingParams from .utils import ErrorResponse, has_event_loop, is_llm_response @@ -315,6 +315,7 @@ def _handle_sequence(self, f"Unknown finish reason: {finish_reasons[src_idx]}") self.record_stats(output, req_perf_metrics_dict) + @print_traceback_on_error @nvtx_range_debug("handle_response", color="red", category="GenerationResultBase") diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index baa8305f64d..86bb8b74584 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -446,7 +446,6 @@ async def call_streaming(self, name: str, *args, logger_debug( f"RPC Client call_streaming received data: {response.result}", color="green") - # Yield data yield response.result elif response.stream_status == 'end': # End of stream diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 85a799dc1f9..145407e6b18 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -10,6 +10,7 @@ from .._utils import mpi_rank from ..bindings import executor as tllm from ..builder import Engine +from ..llmapi.llm_args import BaseLlmArgs from ..logger import set_level from ..lora_manager import LoraConfig from ..sampling_params import BatchedLogitsProcessor @@ -37,13 +38,17 @@ def __init__( is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, + llm_args: Optional[BaseLlmArgs] = None, + batched_logits_processor: Optional[BatchedLogitsProcessor] = None, ) -> None: super().__init__( engine=engine, executor_config=executor_config, is_llm_executor=is_llm_executor, lora_config=lora_config, - garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, + llm_args=llm_args, + batched_logits_processor=batched_logits_processor) self.shutdown_event = Event() self._response_queue = Queue() @@ -104,6 +109,7 @@ def main_task( is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, + llm_args: Optional[BaseLlmArgs] = None, **kwargs, ) -> None: if enable_llm_debug(): @@ -115,7 +121,9 @@ def main_task( executor_config=executor_config, is_llm_executor=is_llm_executor, lora_config=lora_config, - garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) + garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, + llm_args=llm_args, + batched_logits_processor=batched_logits_processor) if mpi_rank() != 0: # The non-leader worker will setup the engine immediately. diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 941348c122c..91cb400ef95 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -15,7 +15,9 @@ from .._utils import KVCacheEventSerializer, mpi_comm, mpi_rank from ..bindings import executor as tllm from ..builder import Engine +from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig from ..llmapi.mpi_session import set_mpi_session_cpp +from ..llmapi.tokenizer import TokenizerBase from ..llmapi.tracer import VizTracer, set_global_tracer from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, clear_sched_affinity, print_colored_debug, @@ -50,7 +52,10 @@ def __init__( postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + llm_args: Optional[BaseLlmArgs] = None, ) -> None: super().__init__( engine=engine, @@ -82,6 +87,8 @@ def __init__( error_queue=self._error_queue, name="dispatch_kv_cache_events_thread") + self.setup_engine() + def _create_iteration_result_queue(self, it_result_queue: IterationResultQueue): if not it_result_queue.is_initialized: @@ -248,7 +255,10 @@ def worker_main( is_llm_executor: Optional[ bool] = True, # whether it's the main executor instance lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + llm_args: Optional[BaseLlmArgs] = None, ) -> None: mpi_comm().barrier() print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n", @@ -376,7 +386,10 @@ def notify_proxy_threads_to_quit(): postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, lora_config=lora_config, - garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) + kv_connector_config=kv_connector_config, + hf_model_dir=hf_model_dir, + tokenizer=tokenizer, + llm_args=llm_args) except Exception as e: logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}") logger.error(traceback.format_exc()) diff --git a/tensorrt_llm/executor/worker_base.py b/tensorrt_llm/executor/worker_base.py index 73d0ae2a519..939efbc0543 100644 --- a/tensorrt_llm/executor/worker_base.py +++ b/tensorrt_llm/executor/worker_base.py @@ -2,7 +2,6 @@ import datetime import enum import json -import weakref from pathlib import Path from queue import Queue from typing import Dict, List, Optional, Tuple, Union @@ -15,7 +14,8 @@ nvtx_range_debug) from ..bindings import executor as tllm from ..builder import ConfigEncoder, Engine, EngineConfig -from ..llmapi.llm_args import PybindMirror +from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror +from ..llmapi.tokenizer import TokenizerBase from ..llmapi.tracer import global_tracer from ..llmapi.utils import _SyncQueue, logger_debug, print_colored_debug from ..logger import logger @@ -24,10 +24,11 @@ from ..prompt_adapter_manager import PromptAdapterManager from ..runtime import ModelConfig from ..runtime.model_runner import _engine_config_to_model_config -from ..sampling_params import SamplingParams +from ..sampling_params import BatchedLogitsProcessor, SamplingParams from .executor import GenerationExecutor from .ipc import IpcQueue -from .postproc_worker import PostprocParams, PostprocWorker +from .postproc_worker import (PostprocParams, PostprocWorker, + PostprocWorkerConfig) from .request import GenerationRequest, LoRARequest, PromptAdapterRequest from .result import (GenerationResult, LogProbsResult, ResponseWrapper, compute_logprobs) @@ -54,115 +55,152 @@ def __init__( self, engine: Union[Path, Engine], executor_config: Optional[tllm.ExecutorConfig] = None, + batched_logits_processor: Optional[BatchedLogitsProcessor] = None, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + llm_args: Optional[BaseLlmArgs] = None, ) -> None: - super().__init__(is_llm_executor=is_llm_executor) + postproc_config = postproc_worker_config or PostprocWorkerConfig() + super().__init__( + num_postprocess_workers=postproc_config.num_postprocess_workers, + postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir, + is_llm_executor=is_llm_executor, + ) - # Persist constructor arguments for deferred setup - self._engine_input = engine + # inputs + self._engine = engine + self._executor_config = executor_config + self._batched_logits_processor = batched_logits_processor + self._hf_model_dir = hf_model_dir + self._tokenizer = tokenizer + self._kv_connector_config = kv_connector_config self._lora_config = lora_config - self._garbage_collection_gen0_threshold = garbage_collection_gen0_threshold + self.llm_args = llm_args - self.engine = None + self.result_queue: Optional[IpcQueue] = None + self.postproc_queues: Optional[List[IpcQueue]] = None self.rank = mpi_rank() self.global_rank = global_mpi_rank() # mapping: client_id -> GenerationResult self._results: Dict[int, GenerationResult] = {} # mapping: client_id from Proxy -> request_id returned from runtime backend self._client_id_to_request_id: Dict[int, int] = {} - self._executor_config = executor_config - self._is_pytorch_backend = getattr(self._executor_config, "backend", - None) == "pytorch" - - if global_mpi_size() > 1: - logger.set_rank(self.global_rank) - - if isinstance(engine, list): - self.engine = engine[self.rank] - - self._await_response_helper = AwaitResponseHelper(weakref.proxy(self)) + self._await_response_helper = AwaitResponseHelper( + self) # TODO: make it weakref + self._is_pytorch_backend = llm_args is not None and llm_args.backend in [ + "pytorch", "_autodeploy" + ] - self.postproc_queues = None - self.result_queue = None - - self._lora_manager: Optional[LoraManager] = None - self._prompt_adapter_manager: Optional[PromptAdapterManager] = None - self._runtime_model_config: Optional[ModelConfig] = None + if not self._is_pytorch_backend and kv_connector_config is not None: + raise ValueError( + "KV connector config is only supported for PyTorch backend") def setup_engine(self) -> None: logger_debug(f"WorkerBase {self.rank} is setting up the engine", color="yellow") - # Force all the ranks to wait here, and start creating the executor simultaneously. mpi_comm().barrier() - device_id = self.global_rank % torch.cuda.device_count() - torch.cuda.set_device(device_id) - - # Make sure C++ executor would use same devices/ranks as py_executor - global_rank = global_mpi_rank() - comm_ranks = mpi_comm().allgather(global_rank) - device_ids = mpi_comm().allgather(device_id) + if global_mpi_size() > 1: + logger.set_rank(self.global_rank) + engine = self._engine + self.llm_args + batched_logits_processor = self._batched_logits_processor + hf_model_dir = self._hf_model_dir + tokenizer = self._tokenizer + lora_config = self._lora_config + kv_connector_config = self._kv_connector_config executor_config = self._executor_config - if executor_config is None: - executor_config = tllm.ExecutorConfig(1) - self._executor_config = executor_config - - executor_config.parallel_config = tllm.ParallelConfig( - participant_ids=comm_ranks, device_ids=device_ids) + assert hf_model_dir is not None - engine = self._engine_input if isinstance(engine, list): engine = engine[self.rank] - if isinstance(engine, Engine): - self.engine = tllm.Executor(engine.engine, - json.dumps(engine.config.to_dict(), - cls=ConfigEncoder), - tllm.ModelType.DECODER_ONLY, - executor_config=executor_config, - managed_weights=engine.managed_weights) - elif not hasattr(executor_config, "backend"): - self.engine = tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, - executor_config) - else: - args = { - "executor_config": executor_config, - "checkpoint_dir": executor_config.hf_model_dir, - } - if executor_config.backend == "pytorch": + def _get_comm_ranks_device_id(): + device_id = self.global_rank % torch.cuda.device_count() + torch.cuda.set_device(device_id) + # Make sure C++ executor would use same devices/ranks as py_executor + global_rank = global_mpi_rank() + comm_ranks = mpi_comm().allgather(global_rank) + device_ids = mpi_comm().allgather(device_id) + return comm_ranks, device_ids + + def _create_py_executor(): + args = {} + assert hasattr( + self.llm_args, "backend" + ), "llm_args should be with backend in _create_py_executor" + _ = _get_comm_ranks_device_id() + if self.llm_args.backend == "pytorch": from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ create_py_executor create_executor = create_py_executor - args["lora_config"] = self._lora_config - args[ - "garbage_collection_gen0_threshold"] = self._garbage_collection_gen0_threshold - elif executor_config.backend == "_autodeploy": + args["llm_args"] = self.llm_args + args["checkpoint_dir"] = hf_model_dir + args["tokenizer"] = tokenizer + args["lora_config"] = lora_config + args["kv_connector_config"] = kv_connector_config + elif self.llm_args.backend == "_autodeploy": + from tensorrt_llm._torch.auto_deploy.llm_args import \ + LlmArgs as ADLlmArgs from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ create_autodeploy_executor create_executor = create_autodeploy_executor + assert isinstance(self.llm_args, ADLlmArgs) + args["ad_config"] = self.llm_args.get_pytorch_backend_config() else: raise ValueError( - f"Unsupported backend config: {executor_config.backend}") - - logger_debug(f"WorkerBase {self.rank} creating py_executor", - color="yellow") - self.engine = create_executor(**args) - logger_debug(f"WorkerBase {self.rank} created py_executor", - color="yellow") - logger_debug(f"WorkerBase {self.rank} setup engine done", - color="yellow") + f"Unsupported backend config: {self.llm_args.backend}") + + # Define additional attributes that can be used later, such as in _deduce_max_tokens + self.mapping = self.llm_args.parallel_config.to_mapping() + self.checkpoint_loader = None + if self.llm_args.backend == "pytorch": + from tensorrt_llm._torch.pyexecutor.config import \ + _construct_checkpoint_loader + self.checkpoint_loader = _construct_checkpoint_loader( + self.llm_args.backend, self.llm_args.checkpoint_loader, + self.llm_args.checkpoint_format) + + _executor = create_executor(**args) + self.max_seq_len = self.llm_args.max_seq_len + if _executor.max_seq_len is not None: + # max_seq_len might be updated by model engine as in create_py_executor + self.max_seq_len = _executor.max_seq_len + return _executor + + def _create_engine(executor_config): + if executor_config is None: + executor_config = tllm.ExecutorConfig(1) + executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( + processor_batched=batched_logits_processor, replicate=False) + comm_ranks, device_ids = _get_comm_ranks_device_id() + executor_config.parallel_config = tllm.ParallelConfig( + participant_ids=comm_ranks, device_ids=device_ids) + + if isinstance(engine, Engine): + return tllm.Executor(engine.engine, + json.dumps(engine.config.to_dict(), + cls=ConfigEncoder), + tllm.ModelType.DECODER_ONLY, + executor_config=executor_config, + managed_weights=engine.managed_weights) - self._setup_lora(engine, executor_config, self._lora_config) + assert not hasattr(executor_config, "backend") + return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, + executor_config) - def _setup_lora(self, engine: Union[Path, Engine], - executor_config: tllm.ExecutorConfig, - lora_config: Optional[LoraConfig]) -> None: - """Setup LoRA and prompt adapter managers.""" - # LoRA setup + self.engine = _create_py_executor( + ) if self.llm_args is not None else _create_engine(executor_config) + + self._lora_manager: Optional[LoraManager] = None + self._prompt_adapter_manager: Optional[PromptAdapterManager] = None + self._runtime_model_config: Optional[ModelConfig] = None if self.rank == 0 and isinstance(self.engine, tllm.Executor): if isinstance(engine, Engine): engine_config = engine.config @@ -181,8 +219,9 @@ def _setup_lora(self, engine: Union[Path, Engine], if engine_config.build_config.max_prompt_embedding_table_size > 0: self._prompt_adapter_manager = PromptAdapterManager() - if getattr(executor_config, "backend", - "") == "pytorch" and lora_config is not None: + if self.llm_args and getattr( + self.llm_args, "backend", + "") == "pytorch" and lora_config is not None: from tensorrt_llm._torch.pyexecutor.resource_manager import \ ResourceManagerType peft_cache_manager = self.engine.resource_manager.resource_managers.get( @@ -284,44 +323,79 @@ def _enqueue_request(self, request: GenerationRequest) -> int: context_phase_params = request.disaggregated_params.get_context_phase_params( ) - is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler - if is_overlap_enabled: - is_disaggregated = self.engine.kv_cache_transceiver is not None - if is_disaggregated and ( - request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): - raise ValueError( - "Context only requests are not supported in pytorch backend when overlap is enabled." - ) + if self._is_pytorch_backend: + if not self.llm_args.disable_overlap_scheduler: + is_disaggregated = self.engine.kv_cache_transceiver is not None + if is_disaggregated and ( + request_type + == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): + raise ValueError( + "Context only requests are not supported in pytorch backend when overlap is enabled." + ) assert request.id is not None def _deduce_max_tokens(request: GenerationRequest, - executor_config: tllm.ExecutorConfig) -> int: - if request.sampling_params.max_tokens: - return request.sampling_params.max_tokens + executor_config: tllm.ExecutorConfig, + llm_args: Optional[BaseLlmArgs] = None) -> int: # deduce max_tokens when it's not set by user + max_tokens = request.sampling_params.max_tokens query_token_len = len( request.query_token_ids) if request.query_token_ids else 0 - cp_size = 1 if (not hasattr(executor_config, "mapping") - or executor_config.mapping.cp_size - is None) else executor_config.mapping.cp_size - if not hasattr(executor_config, "max_seq_len"): - raise RuntimeError( - "max_tokens for sampling is not set and cannot be deduced") + + cp_size = 1 + max_seq_len = None + if llm_args is not None: + # deduce max_tokens by llm args + assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined." + if hasattr(self, + "mapping") and self.mapping.cp_size is not None: + cp_size = self.mapping.cp_size + max_seq_len = getattr(self, "max_seq_len", None) + else: + # deduce max_tokens by executor config + if hasattr(executor_config, "mapping" + ) and executor_config.mapping.cp_size is not None: + cp_size = executor_config.mapping.cp_size + max_seq_len = getattr(executor_config, "max_seq_len", None) + if max_seq_len is None: + logger.warning("`default_max_tokens` cannot be deduced") + if max_tokens is None: + raise ValueError( + "`max_tokens` must be set when `default_max_tokens` cannot be deduced" + ) + else: + # use max_tokens if can't deduce default_max_tokens + return max_tokens splited_prompt_len = int(len(prompt_token_ids) / cp_size) - default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len - if default_max_tokens < 0: - raise ValueError( - f"Deduced max_tokens {default_max_tokens} is less than 0, because" - f"prompt length {splited_prompt_len} plus query length {query_token_len} " - f"is larger than max_seq_len {executor_config.max_seq_len}") - return default_max_tokens + default_max_tokens = max_seq_len - splited_prompt_len - query_token_len + if default_max_tokens <= 0: + logger.warning( + f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, " + f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})" + f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})" + ) + if max_tokens is None: + raise ValueError( + "`max_tokens` must be set when `default_max_tokens` is illegal" + ) + # default_max_tokens is the biggest available value + if max_tokens is None: + return default_max_tokens + elif max_tokens > default_max_tokens: + logger.warning( + f"User-specified `max_tokens` ({max_tokens}) is greater than deduced " + f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead." + ) + return default_max_tokens + return max_tokens try: executor_request = tllm.Request( client_id=request.id, input_token_ids=prompt_token_ids, - max_tokens=_deduce_max_tokens(request, self._executor_config), + max_tokens=_deduce_max_tokens(request, self._executor_config, + self.llm_args), streaming=request.streaming, sampling_config=request.sampling_params._get_sampling_config(), end_id=-1 if request.sampling_params.ignore_eos else @@ -353,7 +427,8 @@ def _deduce_max_tokens(request: GenerationRequest, request.sampling_params.logits_processor, kv_cache_retention_config=request.kv_cache_retention_config, context_phase_params=context_phase_params, - type=request_type) + type=request_type, + cache_salt_id=request.cache_salt_id) executor_request.py_lora_path = py_lora_path if self._is_pytorch_backend and request.multimodal_params is not None: @@ -373,6 +448,9 @@ def _deduce_max_tokens(request: GenerationRequest, if self._is_pytorch_backend and request.scheduling_params is not None: executor_request.py_scheduling_params = request.scheduling_params + if request.arrival_time is not None: + executor_request.py_arrival_time = request.arrival_time + if request.query_token_ids is not None: # pytorch star attention workflow # a workaround to avoid public interface update diff --git a/tests/unittest/executor/test_rpc_proxy.py b/tests/unittest/executor/test_rpc_proxy.py index 90e1036c5f7..be165b41a6f 100644 --- a/tests/unittest/executor/test_rpc_proxy.py +++ b/tests/unittest/executor/test_rpc_proxy.py @@ -23,14 +23,15 @@ class TestRpcProxy: def create_proxy(self, tp_size: int): # Create executor config with the correct tp_size - executor_config = create_fake_executor_config(model_path, - tp_size=tp_size) + llm_args, executor_config = create_fake_executor_config(model_path, + tp_size=tp_size) mpi_session = MpiPoolSession(n_workers=tp_size) proxy = GenerationExecutorRpcProxy( worker_kwargs={ "engine": model_path, - "executor_config": executor_config, + "executor_config": None, + "llm_args": llm_args, "model_world_size": tp_size, }, model_world_size=tp_size, diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index 11e5304ab5b..8c7affb228a 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -26,7 +26,8 @@ class TestRpcWorkerTP1: def setup_method(self): - self.executor_config = create_fake_executor_config(model_path) + self.llm_args, self.executor_config = create_fake_executor_config( + model_path) self.pool, self.addr = self.create_worker_pool() self.client = self.create_rpc_client(self.addr) self.client.setup_engine().remote() @@ -45,7 +46,8 @@ def create_worker_pool(self): pool.submit(RpcWorker.main_task, engine=model_path, rpc_addr=addr, - executor_config=self.executor_config) + executor_config=self.executor_config, + llm_args=self.llm_args) return pool, addr def create_rpc_client(self, addr: str): @@ -113,20 +115,62 @@ async def process_request_streaming(): # NOTE: known issue, the responses should be fetched before shutdown, # or the shutdown will hang. results = [] + responses_per_client = {} + expected_responses_per_client = 5 # max_tokens=5 print(f"start to fetch_responses_async") no = 0 async for result in self.client.fetch_responses_loop_async( ).remote_streaming(): - print(f"fetch_responses_async {no} result: {result}") - results.extend(result) # result is a list of responses + if result: # result is already a list of lists + print( + f"fetch_responses_async batch {no}, received {len(result)} sub-batches" + ) + for batch in result: + if isinstance(batch, list): + print(f" Sub-batch has {len(batch)} responses") + results.extend(batch) + # Track responses per client + for response in batch: + client_id = response.client_id + if client_id not in responses_per_client: + responses_per_client[client_id] = 0 + responses_per_client[client_id] += 1 + else: + # Single response + results.append(batch) + client_id = batch.client_id + if client_id not in responses_per_client: + responses_per_client[client_id] = 0 + responses_per_client[client_id] += 1 + no += 1 - if no >= req_count * 5: # Break after receiving 5 batches - print(f"break after receiving {no} batches") + + # Check if all clients have received their expected responses + completed_clients = sum( + 1 for count in responses_per_client.values() + if count >= expected_responses_per_client) + + print(f"Responses per client: {responses_per_client}") + print(f"Completed clients: {completed_clients}/{req_count}") + + # Break when we've received all expected responses + if completed_clients >= req_count: + print( + f"All {completed_clients} clients completed after {no} batches" + ) break + + # Safety break to prevent infinite loop + if no >= req_count * 20: # Much higher limit as safety + print(f"Safety break after {no} batches") + break + print(f"Received {no} batches of streaming responses") - print(f"fetch_responses result: {results}") + print(f"Total responses received: {len(results)}") + print(f"Final responses per client: {responses_per_client}") assert results + assert len(responses_per_client) >= req_count await process_request_streaming() @@ -180,8 +224,8 @@ def process_request_streaming(): class TestRpcWorkerTP2: def setup_method(self): - self.executor_config = create_fake_executor_config(model_path, - tp_size=2) + self.llm_args, self.executor_config = create_fake_executor_config( + model_path, tp_size=2) self.session, self.addr, self.futures = self.create_worker_session() self.client = self.create_rpc_client(self.addr) self.client.setup_engine().remote() @@ -199,6 +243,7 @@ def create_worker_session(self): engine=model_path, rpc_addr=addr, executor_config=self.executor_config, + llm_args=self.llm_args, model_world_size=2) return session, addr, futures diff --git a/tests/unittest/executor/test_worker_base.py b/tests/unittest/executor/test_worker_base.py index a5fbdb7ae75..919d34c6a93 100644 --- a/tests/unittest/executor/test_worker_base.py +++ b/tests/unittest/executor/test_worker_base.py @@ -29,10 +29,11 @@ class TestWorkerBase: class FakeWorker(WorkerBase): def __init__(self, engine: str): - super().__init__(engine=engine) - executor_config = create_fake_executor_config(engine) + super().__init__(engine=engine, hf_model_dir=engine) + llm_args, executor_config = create_fake_executor_config(engine) # Pass config in constructor and finalize with parameterless setup self._executor_config = executor_config + self.llm_args = llm_args self.setup_engine() def test_create_engine(self): @@ -99,7 +100,6 @@ def create_fake_executor_config(model_path, tp_size=1): update_executor_config( executor_config, - backend="pytorch", pytorch_backend_config=llm_args.get_pytorch_backend_config(), mapping=llm_args.parallel_config.to_mapping(), speculative_config=llm_args.speculative_config, @@ -110,14 +110,14 @@ def create_fake_executor_config(model_path, tp_size=1): checkpoint_loader=llm_args.checkpoint_loader, ) - return executor_config + return llm_args, executor_config class TestRpcWorkerBaseTP2: def setup_method(self): - self.executor_config = create_fake_executor_config(model_path, - tp_size=2) + self.llm_args, self.executor_config = create_fake_executor_config( + model_path, tp_size=2) self.session = self.create_worker_session() # No need to sleep here - the session is ready immediately @@ -126,9 +126,11 @@ def create_worker_session(self): return session def test_create_executor(self): - futures = self.session.submit(TestRpcWorkerBaseTP2.create_executor, - engine=model_path, - executor_config=self.executor_config) + futures = self.session.submit( + TestRpcWorkerBaseTP2.create_executor, + engine=model_path, + llm_args=self.llm_args, + ) # Wait for completion for future in futures: future.result() @@ -136,7 +138,7 @@ def test_create_executor(self): self.session.shutdown() @staticmethod - def create_executor(engine, executor_config): + def create_executor(engine, llm_args): # Set MPI session for C++ backend set_mpi_session_cpp(mpi_comm()) @@ -156,7 +158,8 @@ def create_executor(engine, executor_config): try: print(f"[Test] Rank {rank} creating WorkerBase...") executor = WorkerBase(engine=engine, - executor_config=executor_config) + llm_args=llm_args, + hf_model_dir=engine) # For PyTorch backend, all ranks need to participate in setup print(f"[Test] Rank {rank} calling setup_engine...") From 0f646e4378973e126098267c02c3e4ec493613b1 Mon Sep 17 00:00:00 2001 From: chunweiy Date: Sat, 20 Sep 2025 00:05:37 +0000 Subject: [PATCH 08/13] fix rpc server pickle error Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/base_worker.py | 26 + tensorrt_llm/executor/rpc/rpc_client.py | 42 +- tensorrt_llm/executor/rpc/rpc_server.py | 62 +- tensorrt_llm/executor/rpc_proxy.py | 25 +- tensorrt_llm/executor/rpc_worker.py | 83 +- tensorrt_llm/executor/worker.py | 19 +- tensorrt_llm/executor/worker_base.py | 848 -------------------- tests/unittest/executor/test_rpc.py | 46 +- tests/unittest/executor/test_rpc_proxy.py | 5 +- tests/unittest/executor/test_rpc_worker.py | 95 +-- tests/unittest/executor/test_worker_base.py | 185 ----- 11 files changed, 305 insertions(+), 1131 deletions(-) delete mode 100644 tensorrt_llm/executor/worker_base.py delete mode 100644 tests/unittest/executor/test_worker_base.py diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index c17401cc9fa..1dccf5dcf66 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -548,6 +548,32 @@ def submit(self, request: GenerationRequest) -> GenerationResult: return result + def shutdown(self): + if self.doing_shutdown: + return + else: + self.doing_shutdown = True + + if self.engine is not None and self.engine.can_enqueue_requests(): + self.engine.shutdown() + self.engine = None + + # Define a Callable to join iteration and request stats + @staticmethod + def _stats_serializer( + stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str: + iteration_stats, req_stats = stats + stats_dict = json.loads(iteration_stats.to_json_str()) + + if req_stats is not None and len(req_stats) > 0: + stats_dict["requestStats"] = [] + for req_stat in req_stats: + stats_dict["requestStats"].append( + json.loads(req_stat.to_json_str())) + + # Convert back to JSON string + return json.dumps(stats_dict) + def _pop_result(self, client_id: int): self._results.pop(client_id, None) self._client_id_to_request_id.pop(client_id, None) diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index 86bb8b74584..6767291523c 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -152,6 +152,7 @@ def close(self): async def _response_reader(self): """Task to read responses from the socket and set results on futures.""" + logger_debug("Response reader started") while not self._stop_event.is_set(): try: @@ -166,6 +167,11 @@ async def _response_reader(self): continue logger_debug(f"RPC Client received response: {response}") + logger_debug( + f"Response request_id: {response.request_id}, is_streaming: {response.is_streaming}" + ) + logger_debug( + f"Pending futures: {list(self._pending_futures.keys())}") # Handle streaming responses if response.is_streaming: @@ -187,18 +193,34 @@ async def _response_reader(self): None) else: # Handle regular responses + logger_debug( + f"Handling regular response for request_id: {response.request_id}" + ) if future_info := self._pending_futures.get( response.request_id): future, target_loop = future_info + logger_debug( + f"Found future for request_id: {response.request_id}, future done: {future.done()}" + ) if not future.done(): if response.error is None: + logger_debug( + f"Setting result for request_id: {response.request_id}, result: {response.result}" + ) target_loop.call_soon_threadsafe( future.set_result, response.result) else: # Use the original RPCError from the response + logger_debug( + f"Setting exception for request_id: {response.request_id}, error: {response.error}" + ) target_loop.call_soon_threadsafe( future.set_exception, response.error) + else: + logger_debug( + f"No future found for request_id: {response.request_id}" + ) self._pending_futures.pop(response.request_id, None) except asyncio.CancelledError: @@ -268,16 +290,27 @@ async def _call_async(self, method_name, *args, **kwargs): loop = asyncio.get_running_loop() future = loop.create_future() + logger_debug( + f"RPC Client _call_async: Created future for request_id: {request_id} in loop: {id(loop)}" + ) self._pending_futures[request_id] = (future, loop) + logger_debug( + f"RPC Client _call_async: Stored future in pending_futures") try: # If timeout, the remote call should return a timeout error timely, # so we add 1 second to the timeout to ensure the client can get # that result. + logger_debug( + f"RPC Client _call_async: Awaiting future for request_id: {request_id}" + ) if timeout is None: res = await future else: res = await asyncio.wait_for(future, timeout + 1) + logger_debug( + f"RPC Client _call_async: Got result for request_id: {request_id}: {res}" + ) return res except RPCCancelled: self._server_stopped = True @@ -315,9 +348,16 @@ def _call_sync(self, method_name, *args, **kwargs): f"RPC Client calling method: {method_name} with args: {args} and kwargs: {kwargs}" ) self._ensure_event_loop() + logger_debug( + f"RPC Client _call_sync: Creating future for {method_name}") future = asyncio.run_coroutine_threadsafe( self._call_async(method_name, *args, **kwargs), self._loop) - return future.result() + logger_debug( + f"RPC Client _call_sync: Waiting for result of {method_name}") + result = future.result() + logger_debug( + f"RPC Client _call_sync: Got result for {method_name}: {result}") + return result def call_async(self, name: str, *args, **kwargs): """ diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py index dd7246b4ed3..eed47f273b3 100644 --- a/tensorrt_llm/executor/rpc/rpc_server.py +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -192,6 +192,10 @@ async def _dispatcher_routine(self, stop_event: threading.Event): except asyncio.TimeoutError: await asyncio.sleep(0) continue + except Exception as e: + logger.error(f"RPC dispatcher caught an exception: {e}") + logger.error(traceback.format_exc()) + continue await self._queue.put(req) # type: ignore @@ -272,8 +276,9 @@ async def _worker_routine(self, stop_event: threading.Event): logger_debug( f"RPC Server sending response for request {req}, pending: {self._num_pending_requests}" ) - await self._client_socket.put_async(response) - logger_debug(f"RPC Server sent response for request {req}") + if await self._send_response(req, response): + logger_debug( + f"RPC Server sent response for request {req}") # Only decrement if this request was counted in the first place if req.method_name not in ["_rpc_shutdown", "shutdown"]: @@ -352,9 +357,11 @@ async def _process_streaming_request(self, req: RPCRequest): async for result in func(*req.args, **req.kwargs): logger_debug( f"RPC Server got data and ready to send result {result}") - await self._client_socket.put_async( - RPCResponse(req.request_id, result, None, True, - sequence_number, 'data')) + response = RPCResponse(req.request_id, result, None, True, + sequence_number, 'data') + if not await self._send_response(req, response): + # Stop streaming after a pickle error + return sequence_number += 1 # Send end signal @@ -372,13 +379,46 @@ async def _process_streaming_request(self, req: RPCRequest): sequence_number, 'error')) except Exception as e: - await self._client_socket.put_async( - RPCResponse( + response = RPCResponse( + req.request_id, None, + RPCStreamingError(str(e), traceback=traceback.format_exc()), + True, sequence_number, 'error') + await self._send_response(req, response) + + async def _send_response(self, req: RPCRequest, + response: RPCResponse) -> bool: + """Safely sends a response, handling pickle errors.""" + try: + await self._client_socket.put_async(response) + return True + except Exception as e: + logger.error( + f"Failed to pickle response for request {req.request_id}: {e}") + error_msg = f"Failed to pickle response: {e}" + if req.is_streaming: + error_cls = RPCStreamingError + # For streaming, we also need sequence number. The original response has it. + sequence_number = response.sequence_number if response else None + error_response = RPCResponse( + req.request_id, + None, + error_cls(error_msg, traceback=traceback.format_exc()), + is_streaming=True, + sequence_number=sequence_number, + stream_status='error') + else: + error_cls = RPCError + error_response = RPCResponse( req.request_id, None, - RPCStreamingError(str(e), - cause=e, - traceback=traceback.format_exc()), True, - sequence_number, 'error')) + error_cls(error_msg, traceback=traceback.format_exc())) + + try: + await self._client_socket.put_async(error_response) + except Exception as e_inner: + logger.error( + f"Failed to send error response for request {req.request_id}: {e_inner}" + ) + return False def start(self): """Binds sockets, starts workers, and begins proxying messages.""" diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index ef08e73e1d0..1d26c4e70b0 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -74,6 +74,7 @@ def __init__(self, # TBD: Move model creation to the mpi task, or left in RPC? self.setup_engine_remote() + # Setup main loop after engine is ready self.setup_mainloop() def launch_workers(self): @@ -88,11 +89,15 @@ async def main_loop_task(self): """ Main loop of the proxy, it will invoke the actions periodically. """ - async for responses in self.rpc_client.fetch_responses_loop_async( - ).remote_streaming(): - if self._shutdown_event.is_set(): - return - self.handle_responses(responses) + try: + async for responses in self.rpc_client.fetch_responses_loop_async( + ).remote_streaming(): + if self._shutdown_event.is_set(): + return + self.handle_responses(responses) + except Exception as e: + logger.error(f"Error in main_loop_task: {e}") + raise def setup_mainloop(self): @@ -115,6 +120,11 @@ def process_res(res: list): nonlocal event_loop nonlocal async_queues + if client_id not in self._results: + logger.warning( + f"Received response for unknown client_id: {client_id}") + continue + queue = self._results[client_id].queue if isinstance(queue, _SyncQueue): queue.put_nowait(r) @@ -128,6 +138,11 @@ def process_res(res: list): r, ErrorResponse): self._results.pop(client_id) + # Handle the case where responses might not be a list of lists + if responses and not isinstance(responses[0], list): + # If responses is a flat list, wrap it + responses = [responses] + for res in responses: global_tracer().log_instant("RPC.get") process_res(res) diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 145407e6b18..33543e4760e 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -10,21 +10,24 @@ from .._utils import mpi_rank from ..bindings import executor as tllm from ..builder import Engine -from ..llmapi.llm_args import BaseLlmArgs +from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig +from ..llmapi.tokenizer import TokenizerBase from ..logger import set_level from ..lora_manager import LoraConfig from ..sampling_params import BatchedLogitsProcessor +from .base_worker import BaseWorker from .postproc_worker import PostprocWorkerConfig +from .request import GenerationRequest from .rpc import RPCServer -from .worker_base import WorkerBase -class RpcWorker(WorkerBase): +class RpcWorker(BaseWorker): """ - A RPC wrapper for the WorkerBase class. + A RPC wrapper for the BaseWorker class. Actions: - `setup_engine`: Setup the engine. + - `submit`: Submit a request to the worker. - `fetch_responses`: Fetch the latest responses from engine. - `fetch_stats`: Fetch the latest stats from engine. - `fetch_kv_cache_events`: Fetch the latest kv cache events from engine. @@ -38,22 +41,36 @@ def __init__( is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, - llm_args: Optional[BaseLlmArgs] = None, batched_logits_processor: Optional[BatchedLogitsProcessor] = None, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + llm_args: Optional[BaseLlmArgs] = None, ) -> None: super().__init__( engine=engine, executor_config=executor_config, is_llm_executor=is_llm_executor, lora_config=lora_config, - garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, llm_args=llm_args, - batched_logits_processor=batched_logits_processor) + batched_logits_processor=batched_logits_processor, + postproc_worker_config=postproc_worker_config, + kv_connector_config=kv_connector_config, + hf_model_dir=hf_model_dir, + tokenizer=tokenizer, + ) + # Store garbage_collection_gen0_threshold if needed in the future + self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold self.shutdown_event = Event() self._response_queue = Queue() self.set_result_queue(self._response_queue) + def submit(self, request: GenerationRequest): + """ Submits a request to the worker. """ + super().submit(request) + def fetch_stats(self) -> list: return super().fetch_stats() @@ -61,13 +78,42 @@ def fetch_responses(self, timeout: Optional[float] = None) -> list: logger_debug(f"RpcWorker {mpi_rank()} is fetching responses", color="yellow") # NOTE: This is a blocking call, it will wait for the responses to be available. - super().await_responses(timeout) + responses = super().await_responses(timeout) + self._await_response_helper.responses_handler(responses) + qsize = self._response_queue.qsize() logger_debug(f"RpcWorker returning {qsize} responses", color="yellow") - return [self._response_queue.get() for _ in range(qsize)] + + if qsize == 0: + return [] + + all_responses = [] + for _ in range(qsize): + # The queue contains batches of responses, so extend the list + all_responses.extend(self._response_queue.get()) + return all_responses async def fetch_responses_async(self) -> list: - return await asyncio.to_thread(self.fetch_responses) + # A really async version of fetch_responses + logger_debug(f"RpcWorker {mpi_rank()} is fetching responses async", + color="yellow") + + # First, await any pending responses without blocking the event loop + responses = await asyncio.to_thread(self.await_responses, 0.001) + # Handle the responses that are ready + self._await_response_helper.responses_handler(responses) + + qsize = self._response_queue.qsize() + logger_debug(f"RpcWorker returning {qsize} async responses", + color="yellow") + + if qsize == 0: + return [] + + all_responses = [] + for _ in range(qsize): + all_responses.extend(self._response_queue.get()) + return all_responses # for streaming performance async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: @@ -88,7 +134,9 @@ async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: def setup_engine(self): # Force all the ranks to wait here, and start creating the executor simultaneously. - mpi_comm().barrier() + # Only call barrier if we have multiple ranks to avoid hanging in single-process tests + if mpi_comm().Get_size() > 1: + mpi_comm().barrier() super().setup_engine() @@ -98,6 +146,9 @@ def shutdown(self): self.shutdown_event.set() super().shutdown() + def start(self): + pass + @staticmethod def main_task( engine: Union[Path, Engine], @@ -110,6 +161,9 @@ def main_task( lora_config: Optional[LoraConfig] = None, garbage_collection_gen0_threshold: Optional[int] = None, llm_args: Optional[BaseLlmArgs] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, **kwargs, ) -> None: if enable_llm_debug(): @@ -123,7 +177,12 @@ def main_task( lora_config=lora_config, garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, llm_args=llm_args, - batched_logits_processor=batched_logits_processor) + batched_logits_processor=batched_logits_processor, + postproc_worker_config=postproc_worker_config, + kv_connector_config=kv_connector_config, + hf_model_dir=hf_model_dir, + tokenizer=tokenizer, + ) if mpi_rank() != 0: # The non-leader worker will setup the engine immediately. diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 91cb400ef95..b744913ace6 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -87,8 +87,6 @@ def __init__( error_queue=self._error_queue, name="dispatch_kv_cache_events_thread") - self.setup_engine() - def _create_iteration_result_queue(self, it_result_queue: IterationResultQueue): if not it_result_queue.is_initialized: @@ -217,7 +215,22 @@ def shutdown(self): self.dispatch_kv_cache_events_thread.stop() self.dispatch_kv_cache_events_thread.join() - super().shutdown() + self.engine.shutdown() + self.engine = None + + if self.llm_args is not None: + assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined." + if (self.llm_args.backend == "pytorch" + and hasattr(self, "checkpoint_loader") + and self.checkpoint_loader is not None): + self.checkpoint_loader.cleanup() + self.checkpoint_loader = None + else: + if hasattr( + self._executor_config, "checkpoint_loader" + ) and self._executor_config.checkpoint_loader is not None: + self._executor_config.checkpoint_loader.cleanup() + self._executor_config.checkpoint_loader = None # Check if there are any errors from the threads before shutdown. self._handle_background_error() diff --git a/tensorrt_llm/executor/worker_base.py b/tensorrt_llm/executor/worker_base.py deleted file mode 100644 index 939efbc0543..00000000000 --- a/tensorrt_llm/executor/worker_base.py +++ /dev/null @@ -1,848 +0,0 @@ -import copy -import datetime -import enum -import json -from pathlib import Path -from queue import Queue -from typing import Dict, List, Optional, Tuple, Union - -import torch - -from tensorrt_llm.logger import logger - -from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank, - nvtx_range_debug) -from ..bindings import executor as tllm -from ..builder import ConfigEncoder, Engine, EngineConfig -from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror -from ..llmapi.tokenizer import TokenizerBase -from ..llmapi.tracer import global_tracer -from ..llmapi.utils import _SyncQueue, logger_debug, print_colored_debug -from ..logger import logger -from ..lora_manager import LoraConfig, LoraManager -from ..metrics import RequestEventTiming -from ..prompt_adapter_manager import PromptAdapterManager -from ..runtime import ModelConfig -from ..runtime.model_runner import _engine_config_to_model_config -from ..sampling_params import BatchedLogitsProcessor, SamplingParams -from .executor import GenerationExecutor -from .ipc import IpcQueue -from .postproc_worker import (PostprocParams, PostprocWorker, - PostprocWorkerConfig) -from .request import GenerationRequest, LoRARequest, PromptAdapterRequest -from .result import (GenerationResult, LogProbsResult, ResponseWrapper, - compute_logprobs) -from .utils import (ErrorResponse, RequestError, enable_llm_debug, - is_llm_response) - -if enable_llm_debug(): - logger.set_level("debug") - -__all__ = [ - "WorkerBase", -] - - -class WorkerBase(GenerationExecutor): - """ - Base class for all workers. - - It contains all the core logic for the worker, without any specific logic for - cross-process communication such as IPC or RPC. - """ - - def __init__( - self, - engine: Union[Path, Engine], - executor_config: Optional[tllm.ExecutorConfig] = None, - batched_logits_processor: Optional[BatchedLogitsProcessor] = None, - postproc_worker_config: Optional[PostprocWorkerConfig] = None, - is_llm_executor: Optional[bool] = None, - lora_config: Optional[LoraConfig] = None, - kv_connector_config: Optional[KvCacheConnectorConfig] = None, - hf_model_dir: Optional[Path] = None, - tokenizer: Optional[TokenizerBase] = None, - llm_args: Optional[BaseLlmArgs] = None, - ) -> None: - postproc_config = postproc_worker_config or PostprocWorkerConfig() - super().__init__( - num_postprocess_workers=postproc_config.num_postprocess_workers, - postprocess_tokenizer_dir=postproc_config.postprocess_tokenizer_dir, - is_llm_executor=is_llm_executor, - ) - - # inputs - self._engine = engine - self._executor_config = executor_config - self._batched_logits_processor = batched_logits_processor - self._hf_model_dir = hf_model_dir - self._tokenizer = tokenizer - self._kv_connector_config = kv_connector_config - self._lora_config = lora_config - self.llm_args = llm_args - - self.result_queue: Optional[IpcQueue] = None - self.postproc_queues: Optional[List[IpcQueue]] = None - self.rank = mpi_rank() - self.global_rank = global_mpi_rank() - # mapping: client_id -> GenerationResult - self._results: Dict[int, GenerationResult] = {} - # mapping: client_id from Proxy -> request_id returned from runtime backend - self._client_id_to_request_id: Dict[int, int] = {} - self._await_response_helper = AwaitResponseHelper( - self) # TODO: make it weakref - self._is_pytorch_backend = llm_args is not None and llm_args.backend in [ - "pytorch", "_autodeploy" - ] - - if not self._is_pytorch_backend and kv_connector_config is not None: - raise ValueError( - "KV connector config is only supported for PyTorch backend") - - def setup_engine(self) -> None: - logger_debug(f"WorkerBase {self.rank} is setting up the engine", - color="yellow") - # Force all the ranks to wait here, and start creating the executor simultaneously. - mpi_comm().barrier() - - if global_mpi_size() > 1: - logger.set_rank(self.global_rank) - - engine = self._engine - self.llm_args - batched_logits_processor = self._batched_logits_processor - hf_model_dir = self._hf_model_dir - tokenizer = self._tokenizer - lora_config = self._lora_config - kv_connector_config = self._kv_connector_config - executor_config = self._executor_config - assert hf_model_dir is not None - - if isinstance(engine, list): - engine = engine[self.rank] - - def _get_comm_ranks_device_id(): - device_id = self.global_rank % torch.cuda.device_count() - torch.cuda.set_device(device_id) - # Make sure C++ executor would use same devices/ranks as py_executor - global_rank = global_mpi_rank() - comm_ranks = mpi_comm().allgather(global_rank) - device_ids = mpi_comm().allgather(device_id) - return comm_ranks, device_ids - - def _create_py_executor(): - args = {} - assert hasattr( - self.llm_args, "backend" - ), "llm_args should be with backend in _create_py_executor" - _ = _get_comm_ranks_device_id() - if self.llm_args.backend == "pytorch": - from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ - create_py_executor - create_executor = create_py_executor - args["llm_args"] = self.llm_args - args["checkpoint_dir"] = hf_model_dir - args["tokenizer"] = tokenizer - args["lora_config"] = lora_config - args["kv_connector_config"] = kv_connector_config - elif self.llm_args.backend == "_autodeploy": - from tensorrt_llm._torch.auto_deploy.llm_args import \ - LlmArgs as ADLlmArgs - from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ - create_autodeploy_executor - create_executor = create_autodeploy_executor - assert isinstance(self.llm_args, ADLlmArgs) - args["ad_config"] = self.llm_args.get_pytorch_backend_config() - else: - raise ValueError( - f"Unsupported backend config: {self.llm_args.backend}") - - # Define additional attributes that can be used later, such as in _deduce_max_tokens - self.mapping = self.llm_args.parallel_config.to_mapping() - self.checkpoint_loader = None - if self.llm_args.backend == "pytorch": - from tensorrt_llm._torch.pyexecutor.config import \ - _construct_checkpoint_loader - self.checkpoint_loader = _construct_checkpoint_loader( - self.llm_args.backend, self.llm_args.checkpoint_loader, - self.llm_args.checkpoint_format) - - _executor = create_executor(**args) - self.max_seq_len = self.llm_args.max_seq_len - if _executor.max_seq_len is not None: - # max_seq_len might be updated by model engine as in create_py_executor - self.max_seq_len = _executor.max_seq_len - return _executor - - def _create_engine(executor_config): - if executor_config is None: - executor_config = tllm.ExecutorConfig(1) - executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( - processor_batched=batched_logits_processor, replicate=False) - comm_ranks, device_ids = _get_comm_ranks_device_id() - executor_config.parallel_config = tllm.ParallelConfig( - participant_ids=comm_ranks, device_ids=device_ids) - - if isinstance(engine, Engine): - return tllm.Executor(engine.engine, - json.dumps(engine.config.to_dict(), - cls=ConfigEncoder), - tllm.ModelType.DECODER_ONLY, - executor_config=executor_config, - managed_weights=engine.managed_weights) - - assert not hasattr(executor_config, "backend") - return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, - executor_config) - - self.engine = _create_py_executor( - ) if self.llm_args is not None else _create_engine(executor_config) - - self._lora_manager: Optional[LoraManager] = None - self._prompt_adapter_manager: Optional[PromptAdapterManager] = None - self._runtime_model_config: Optional[ModelConfig] = None - if self.rank == 0 and isinstance(self.engine, tllm.Executor): - if isinstance(engine, Engine): - engine_config = engine.config - else: - engine_config = EngineConfig.from_json_file( - f"{engine}/config.json") - self._runtime_model_config = _engine_config_to_model_config( - engine_config) - if engine_config.build_config.plugin_config.lora_plugin: - # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization - # (see LoraManager constructor docstring). Getting the peft cache manager from this - # point in the TRT flow is currently not supported (it's at the CPP - # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA - # optimization is not available in TRT-python flow. - self._lora_manager = LoraManager(cpp_peft_cache_manager=None) - if engine_config.build_config.max_prompt_embedding_table_size > 0: - self._prompt_adapter_manager = PromptAdapterManager() - - if self.llm_args and getattr( - self.llm_args, "backend", - "") == "pytorch" and lora_config is not None: - from tensorrt_llm._torch.pyexecutor.resource_manager import \ - ResourceManagerType - peft_cache_manager = self.engine.resource_manager.resource_managers.get( - ResourceManagerType.PEFT_CACHE_MANAGER) - self._lora_manager = LoraManager( - cpp_peft_cache_manager=peft_cache_manager.impl) - lora_model_config = self.engine.model_engine.lora_model_config - assert lora_model_config is not None - self._lora_model_config = lora_model_config - - def abort_request(self, client_id: int) -> None: - # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id - if self.engine.can_enqueue_requests(): - request_id = self._client_id_to_request_id.get(client_id, None) - if request_id is None: - logger.warning( - f"Request of client_id {client_id} is finished, cannot abort it." - ) - return - self.engine.cancel_request(request_id) - - def _engine_response_callback(self, response: tllm.Response): - return response - - def _load_lora_adapter(self, lora_request: LoRARequest) -> bool: - """Returns True if the adapter was loaded by this call, False if it was already loaded""" - adapter_id = str(lora_request.adapter_id) - newly_loaded_uids = self._lora_manager.load_from_ckpt( - [lora_request.path], - model_config=self._runtime_model_config if - self._runtime_model_config is not None else self._lora_model_config, - runtime_mapping=None, - uids=[adapter_id], - ckpt_source=lora_request.ckpt_source) - return adapter_id in newly_loaded_uids - - def _load_prompt_adapter(self, - prompt_adapter_request: PromptAdapterRequest): - self._prompt_adapter_manager.load_from_ckpt( - [prompt_adapter_request.local_path], - model_config=self._runtime_model_config, - uids=[str(prompt_adapter_request.adapter_id)]) - - def _enqueue_request(self, request: GenerationRequest) -> int: - assert request.id is not None - py_lora_path = None - if self._lora_manager is not None and request.lora_request is not None: - adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache( - request.lora_request.adapter_id) - self._load_lora_adapter(request.lora_request) - uid = str(request.lora_request.adapter_id) - lora_config = tllm.LoraConfig( - task_id=request.lora_request.adapter_id, - weights=self._lora_manager.cpp_lora_weights[uid] - if not adapter_in_cache else None, - config=self._lora_manager.cpp_lora_config[uid]) - py_lora_path = request.lora_request.lora_path - else: - lora_config = None - - prompt_token_ids = copy.deepcopy(request.prompt_token_ids) - prompt_tuning_config = None - if request.prompt_adapter_request is not None: - self._load_prompt_adapter(request.prompt_adapter_request) - uid = str(request.prompt_adapter_request.adapter_id) - prompt_tuning_config = tllm.PromptTuningConfig( - self._prompt_adapter_manager.uid_to_weights[uid]) - vocab_size = self._runtime_model_config.vocab_size - pa_length = prompt_tuning_config.embedding_table.size(0) - prompt_token_ids = list(range( - vocab_size, vocab_size + pa_length)) + prompt_token_ids - - # MULTIMODAL - # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field - # except `multimodal_input` as it needs to go through the C++ runtime. - multimodal_input = None - if request.multimodal_params is not None and request.multimodal_params.has_content( - ): - if request.multimodal_params.multimodal_input is not None: - multimodal_input = tllm.MultimodalInput( - multimodal_hashes=request.multimodal_params. - multimodal_input.multimodal_hashes, - multimodal_positions=request.multimodal_params. - multimodal_input.multimodal_positions, - multimodal_lengths=request.multimodal_params. - multimodal_input.multimodal_lengths) - # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field - request.multimodal_params.multimodal_input = None - - context_phase_params = None - request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION - if request.disaggregated_params is not None: - assert ( - not self._is_pytorch_backend - or self.engine.kv_cache_transceiver is not None - ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:` in config file for disaggregated serving" - request_type = request.disaggregated_params.get_request_type() - if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY: - context_phase_params = request.disaggregated_params.get_context_phase_params( - ) - - if self._is_pytorch_backend: - if not self.llm_args.disable_overlap_scheduler: - is_disaggregated = self.engine.kv_cache_transceiver is not None - if is_disaggregated and ( - request_type - == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): - raise ValueError( - "Context only requests are not supported in pytorch backend when overlap is enabled." - ) - - assert request.id is not None - - def _deduce_max_tokens(request: GenerationRequest, - executor_config: tllm.ExecutorConfig, - llm_args: Optional[BaseLlmArgs] = None) -> int: - # deduce max_tokens when it's not set by user - max_tokens = request.sampling_params.max_tokens - query_token_len = len( - request.query_token_ids) if request.query_token_ids else 0 - - cp_size = 1 - max_seq_len = None - if llm_args is not None: - # deduce max_tokens by llm args - assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined." - if hasattr(self, - "mapping") and self.mapping.cp_size is not None: - cp_size = self.mapping.cp_size - max_seq_len = getattr(self, "max_seq_len", None) - else: - # deduce max_tokens by executor config - if hasattr(executor_config, "mapping" - ) and executor_config.mapping.cp_size is not None: - cp_size = executor_config.mapping.cp_size - max_seq_len = getattr(executor_config, "max_seq_len", None) - if max_seq_len is None: - logger.warning("`default_max_tokens` cannot be deduced") - if max_tokens is None: - raise ValueError( - "`max_tokens` must be set when `default_max_tokens` cannot be deduced" - ) - else: - # use max_tokens if can't deduce default_max_tokens - return max_tokens - splited_prompt_len = int(len(prompt_token_ids) / cp_size) - default_max_tokens = max_seq_len - splited_prompt_len - query_token_len - if default_max_tokens <= 0: - logger.warning( - f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, " - f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})" - f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})" - ) - if max_tokens is None: - raise ValueError( - "`max_tokens` must be set when `default_max_tokens` is illegal" - ) - # default_max_tokens is the biggest available value - if max_tokens is None: - return default_max_tokens - elif max_tokens > default_max_tokens: - logger.warning( - f"User-specified `max_tokens` ({max_tokens}) is greater than deduced " - f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead." - ) - return default_max_tokens - return max_tokens - - try: - executor_request = tllm.Request( - client_id=request.id, - input_token_ids=prompt_token_ids, - max_tokens=_deduce_max_tokens(request, self._executor_config, - self.llm_args), - streaming=request.streaming, - sampling_config=request.sampling_params._get_sampling_config(), - end_id=-1 if request.sampling_params.ignore_eos else - request.sampling_params.end_id, - pad_id=request.sampling_params.pad_id, - output_config=request.sampling_params._get_output_config( - is_pytorch_backend=self._is_pytorch_backend), - # Beam search enforces return_all_generated_tokens=True regardless of the passed value - return_all_generated_tokens=False, - # convert python config into pybind config - lookahead_config=PybindMirror.maybe_to_pybind( - request.sampling_params.lookahead_config), - guided_decoding_params=request.sampling_params. - _get_guided_decoding_params(), - bad_words=request.sampling_params._get_bad_words(), - stop_words=request.sampling_params._get_stop_words(), - embedding_bias=request.sampling_params.embedding_bias, - lora_config=lora_config, - prompt_tuning_config=prompt_tuning_config, - multimodal_input=multimodal_input, - # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`. - multimodal_embedding=None, - mrope_config=None, - logits_post_processor_name=( - tllm.Request.BATCHED_POST_PROCESSOR_NAME - if request.sampling_params.apply_batched_logits_processor - else None), - logits_post_processor=None if self._is_pytorch_backend else - request.sampling_params.logits_processor, - kv_cache_retention_config=request.kv_cache_retention_config, - context_phase_params=context_phase_params, - type=request_type, - cache_salt_id=request.cache_salt_id) - executor_request.py_lora_path = py_lora_path - - if self._is_pytorch_backend and request.multimodal_params is not None: - if request.multimodal_params.multimodal_data is not None: - # NOTE: Deserialize SharedTensor handle to actual tensor - request.multimodal_params.to_tensor("multimodal_data") - executor_request.py_multimodal_data = request.multimodal_params.multimodal_data - - if self._is_pytorch_backend and request.sampling_params.logits_processor: - # For PyTorch backend, we attach logits processors as a dynamic Python attribute - # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues. - lp = request.sampling_params.logits_processor - executor_request.py_logits_post_processors = lp if isinstance( - lp, list) else [lp] - - executor_request.py_scheduling_params = None - if self._is_pytorch_backend and request.scheduling_params is not None: - executor_request.py_scheduling_params = request.scheduling_params - - if request.arrival_time is not None: - executor_request.py_arrival_time = request.arrival_time - - if request.query_token_ids is not None: - # pytorch star attention workflow - # a workaround to avoid public interface update - req_id = self.engine.enqueue_request(executor_request, - request.query_token_ids) - else: - req_id = self.engine.enqueue_request(executor_request) - return req_id - except Exception as e: - raise RequestError(str(e)) from e - - def submit(self, request: GenerationRequest) -> GenerationResult: - """ Low-level API to the executor. Return a "future" GenerationResult which can be waited. """ - if self.rank != 0: - raise RuntimeError( - "Only rank 0 can submit requests.\n" - "To fix this, ensure that the llm.generate(...) method is " - "guarded with the `if __name__ == '__main__':` block.") - - client_id = request.id if request.id is not None else self._get_next_client_id( - ) - if request.id is None: - request.set_id(client_id) - - logprob_params = self._get_logprob_params(request) - - result = GenerationResult( - request, - background_error_handler=self._handle_background_error, - executor=self, - disaggregated_params=request.disaggregated_params, - logprob_params=logprob_params) - - self._results[client_id] = result - - request_id = self._enqueue_request(request) - # request_id returned from backend is necessary for the abort_request method. - self._client_id_to_request_id[client_id] = request_id - - self._handle_background_error() - - return result - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.shutdown() - return True - - def await_responses(self, timeout: Optional[float] = None) -> None: - self._await_response_helper(timeout) - logger.debug(f"worker done await_responses") - - def fetch_kv_cache_events(self) -> list: - if isinstance(self.engine, tllm.Executor): - # Check if the engine has a kv cache event manager - # If not, return an empty list for the events which will cause the thread to exit early. - event_manager = self.engine.get_kv_cache_event_manager() - if event_manager is None: - return [] - else: - return event_manager.get_latest_events() - else: - return self.engine.get_latest_kv_cache_events() - - def fetch_stats( - self) -> List[Tuple[tllm.IterationStats, tllm.RequestStats]]: - if isinstance(self.engine, tllm.Executor): - iter_stats = self.engine.get_latest_iteration_stats() - #TODO: Support req stats with TRT engine - # This would require ensuring iter and req stats have same size - return [(iter_stat, None) for iter_stat in iter_stats] - else: - return self.engine.get_latest_iteration_stats() - - # Define a Callable to join iteration and request stats - @staticmethod - def _stats_serializer( - stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str: - iteration_stats, req_stats = stats - stats_dict = json.loads(iteration_stats.to_json_str()) - - if req_stats is not None and len(req_stats) > 0: - stats_dict["requestStats"] = [] - for req_stat in req_stats: - stats_dict["requestStats"].append( - json.loads(req_stat.to_json_str())) - - # Convert back to JSON string - return json.dumps(stats_dict) - - def set_result_queue(self, queue: Queue | IpcQueue): - """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process.""" - assert self.postproc_queues is None - self.result_queue = queue - - def set_postproc_queues(self, queues: list[Queue | IpcQueue]): - """ Set the IPC queues for feeding post-processing processes. """ - assert self.result_queue is None - self.postproc_queues = queues - - def _pop_result(self, client_id: int): - self._results.pop(client_id, None) - self._client_id_to_request_id.pop(client_id, None) - - def shutdown(self): - if self.engine is None: - return - - logger_debug(f"WorkerBase {self.rank} is shutting down", color="yellow") - if self.engine.can_enqueue_requests(): - self.engine.shutdown() - self.engine = None - - if hasattr(self._executor_config, "checkpoint_loader" - ) and self._executor_config.checkpoint_loader is not None: - self._executor_config.checkpoint_loader.cleanup() - self._executor_config.checkpoint_loader = None - - self.engine = None - - logger_debug(f"WorkerBase {self.rank} shutdown done", color="yellow") - - # Check if there are any errors from the threads before shutdown. - self._handle_background_error() - - def _has_background_error(self) -> bool: - # TODO[Superjomn]: The worker background error should be deprecated once - # RPC approach is supported. - return not self._error_queue.empty() - - -class AwaitResponseHelper: - ''' Multiple-implementations for await_response for performance. ''' - - class HandlerKind(enum.Enum): - unknown = 0 - single_process_worker = 1 - ipc_batched = 2 - - def __init__(self, worker: "WorkerBase"): - self.worker = worker - self.handler_kind: AwaitResponseHelper.HandlerKind = AwaitResponseHelper.HandlerKind.unknown - self.enable_postprocprocess_parallel = self.worker.enable_postprocess_parallel - # The error responses when submit request failed will be put here - self.temp_error_responses = Queue() - - def responses_handler(self, responses: List[tllm.Response]): - HandlerKind = AwaitResponseHelper.HandlerKind - - if self.handler_kind is HandlerKind.unknown: - if not (self.worker.result_queue is not None - or self.worker.postproc_queues is not None): - print_colored_debug( - f"creating await_response helper for Worker\n", - color="yellow") - # When ExecutorBindingWorker is used in the main process - # aka the single process mode - self.handler_kind = HandlerKind.single_process_worker - elif self.worker.result_queue is not None or self.worker.postproc_queues is not None: - # The ExecutorBindingProxy is used - print_colored_debug(f"creating await_response helper for IPC\n", - color="yellow") - self.handler_kind = HandlerKind.ipc_batched - else: - raise NotImplementedError - - match self.handler_kind: - case HandlerKind.single_process_worker: - return self.handle_for_worker(responses) - case HandlerKind.ipc_batched: - return self.handle_for_ipc_batched(responses) - case _: - raise NotImplementedError - - def __call__(self, timeout: Optional[float] = None) -> bool: - ''' This method should be called by a ManagedThread. ''' - logger.debug(f"await_response: {self.worker.engine}") - timeout = datetime.timedelta(seconds=timeout or 0.1) - responses = self.worker.engine.await_responses(timeout=timeout) - logger.debug(f"PyExecutor returned {len(responses)} responses") - - # filter since The _engine_response_callback may return None - responses = list( - filter( - lambda _: _, - [self.worker._engine_response_callback(r) for r in responses])) - - # append the error responses to the temp_error_responses - while not self.temp_error_responses.empty(): - responses.append(self.temp_error_responses.get()) - - with nvtx_range_debug(f"await_response-{len(responses)}", - color="red", - category="Worker"): - self.responses_handler(responses) - return True - - def handle_for_worker(self, responses: List[tllm.Response]) -> None: - ''' Return the responses to asyncio.event_loop. ''' - event_loop = None - async_queues = [] - for response in responses: - assert response is not None - queue = self.worker.return_queue(response.client_id) - - response = _maybe_wrap_response(self.worker, response, - self.worker._is_pytorch_backend) - - # For AsyncQueue.sync_q, we will batch the events to avoid too many - # event notifications, thus put without wait here. - if isinstance(queue, _SyncQueue): - global_tracer().log_instant("worker-rsp.put") - queue.put_nowait(response) - async_queues.append(queue) - # all the loops are identical - event_loop = event_loop or queue.loop - else: - queue.put(response) - - if response.has_error() or response.result.is_final: - self.worker._pop_result(response.client_id) - - # Notify the events in bulk for performance. - if async_queues: - _SyncQueue.notify_many(event_loop, async_queues) - - def handle_for_ipc_batched(self, responses: List[tllm.Response]) -> None: - ''' Perform the IPC in batch explicitly. ''' - postproc_batches = [ - [] - for _ in range(self.worker.postproc_config.num_postprocess_workers) - ] if self.enable_postprocprocess_parallel else None - rsp_batch = [] if not self.enable_postprocprocess_parallel else None - - for response in responses: - - if isinstance(response, ErrorResponse): - pass # send ErrorResponse directly - elif self.worker._has_background_error(): - response = self.worker._create_error_response(response) - elif response.has_error(): - # Convert to ErrorResponse, because tllm.Response cannot be - # serialized when it has error. - response = ErrorResponse(response.client_id, response.error_msg, - response.request_id) - else: - response = _maybe_wrap_response(self.worker, response, - self.worker._is_pytorch_backend) - - _send_rsp(self.worker, - response, - postproc_batches=postproc_batches, - rsp_batch=rsp_batch) - - if postproc_batches: - for wid, batch in enumerate(postproc_batches): - if batch: - self.worker.postproc_queues[wid].put(batch) - - if rsp_batch: - self.worker.result_queue.put(rsp_batch) - - -def _get_params_for_first_rsp( - worker, - client_id) -> Tuple[Optional[SamplingParams], Optional[PostprocParams]]: - res = worker._results.get(client_id, None) - assert res is not None - if not res._params_transmitted: - res._params_transmitted = True - return res.sampling_params, res.postproc_params - return None, None - - -def _get_logprobs(worker, - response: tllm.Response, - is_pytorch_backend=False) -> Optional[LogProbsResult]: - """Compute logprob and prompt logprob and clear out logits if applicable. - """ - if is_pytorch_backend: - # _get_logprobs() is a WAR for the TRT backend, where top-k logprobs are computed post runtime. - # In the PyTorch backend, logprobs are already computed during runtime if requested. - return None - - logprobs_result = None - generation_result = worker._results.get(response.client_id, None) - - if not generation_result: - return - - logprob_params = getattr(generation_result, "_logprob_params", None) - if logprob_params: - logprobs_result = compute_logprobs(logprob_params.prompt_logprobs, - logprob_params.logprobs, - response.result.context_logits, - response.result.generation_logits, - response.result.output_token_ids[0]) - - if logprob_params.drop_context_logits: - response.clear_context_logits() - - if logprob_params.drop_generation_logits: - response.clear_generation_logits() - - if response.result.is_final: - generation_result.clear_logprob_params() - - return logprobs_result - - -def _send_rsp( - worker, - response: Union[tllm.Response, ResponseWrapper, ErrorResponse], - postproc_batches: Optional[List[List["PostprocWorker.Input"]]] = None, - rsp_batch: Optional[List[tllm.Response]] = None): - # if postproc_batches is set, append to batch instead of putting to IpcQueue - - if worker.result_queue is not None: - if rsp_batch is not None: - rsp_batch.append(response) - else: - worker.result_queue.put(response) - else: - sampling_params, postproc_params = _get_params_for_first_rsp( - worker, response.client_id) - inp = PostprocWorker.Input( - response, - # sampling_params is necessary for creating fake GenerationResult - # instances in the postproc processes. They are for incremental - # detokenize. They should be transmitted only once for each - # Request. - sampling_params=sampling_params, - postproc_params=postproc_params, - streaming=worker._results.get(response.client_id, None)._streaming) - - pid = response.client_id % worker.postproc_config.num_postprocess_workers - - if not postproc_batches: - # Group the responses into buckets for the postprocessing steps. - # Bucketing is used instead of random dispatching because the - # incremental detokenization during postprocessing relies on the - # prior CompletionOutput of a given request. - worker.postproc_queues[pid].put(inp) - else: - postproc_batches[pid].append(inp) - - # Eliminate the finished GenerationRequest instances timely, which may - # take considerable memory. - if is_llm_response(response): - if response.has_error() or response.result.is_final: - worker._pop_result(response.client_id) - elif isinstance(response, ErrorResponse): - worker._pop_result(response.client_id) - else: - raise ValueError(f"Unknown response type: {response}") - - -def _get_metrics_dict( - response: tllm.Response) -> dict[RequestEventTiming, float]: - req_perf_metrics, metrics_dict = None, {} - res = response.result - if res: - if hasattr(res, '_result'): - if result := res.get_result(): - req_perf_metrics = result.request_perf_metrics - else: - req_perf_metrics = res.request_perf_metrics - if req_perf_metrics and req_perf_metrics.timing_metrics: - metrics_dict = { - RequestEventTiming.ARRIVAL_TIME: - req_perf_metrics.timing_metrics.arrival_time.total_seconds(), - RequestEventTiming.FIRST_TOKEN_TIME: - req_perf_metrics.timing_metrics.first_token_time.total_seconds( - ), - RequestEventTiming.FIRST_SCHEDULED_TIME: - req_perf_metrics.timing_metrics.first_scheduled_time. - total_seconds(), - RequestEventTiming.LAST_TOKEN_TIME: - req_perf_metrics.timing_metrics.last_token_time.total_seconds() - } - return metrics_dict - - -def _maybe_wrap_response( - worker, - response: tllm.Response, - is_pytorch_backend=False) -> Union[tllm.Response, ResponseWrapper]: - - logprobs_result = _get_logprobs(worker, response, is_pytorch_backend) - req_perf_metrics = _get_metrics_dict(response) - if logprobs_result or req_perf_metrics: - response = ResponseWrapper(response, logprobs_result, req_perf_metrics) - return response diff --git a/tests/unittest/executor/test_rpc.py b/tests/unittest/executor/test_rpc.py index e84be8fad34..f628bc8b2b3 100644 --- a/tests/unittest/executor/test_rpc.py +++ b/tests/unittest/executor/test_rpc.py @@ -618,8 +618,44 @@ async def test_invalid_streaming_call(self): pass -if __name__ == "__main__": - #TestRpcError().test_shutdown_cancelled_error() - #test_rpc_shutdown_server() - #TestRpcShutdown().test_submit_request_after_server_shutdown() - test_rpc_timeout(True) +class TestResponsePickleError: + """ The pickle error will break the whole server, test the error handling. """ + + class App: + + def unpickleable_return(self): + # Functions defined locally are not pickleable + def nested_function(): + pass + + return nested_function + + async def unpickleable_streaming_return(self): + # Functions defined locally are not pickleable + def nested_function(): + pass + + yield nested_function + + def test_unpickleable_error(self): + with RpcServerWrapper( + self.App(), addr="ipc:///tmp/rpc_test_pickle_error") as server: + with RPCClient("ipc:///tmp/rpc_test_pickle_error") as client: + with pytest.raises(RPCError) as exc_info: + client.unpickleable_return().remote() + + assert "Failed to pickle response" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_unpickleable_streaming_error(self): + with RpcServerWrapper(self.App(), + addr="ipc:///tmp/rpc_test_pickle_error_streaming", + async_run_task=True) as server: + with RPCClient( + "ipc:///tmp/rpc_test_pickle_error_streaming") as client: + with pytest.raises(RPCStreamingError) as exc_info: + async for _ in client.unpickleable_streaming_return( + ).remote_streaming(): + pass + + assert "Failed to pickle response" in str(exc_info.value) diff --git a/tests/unittest/executor/test_rpc_proxy.py b/tests/unittest/executor/test_rpc_proxy.py index be165b41a6f..22e5cbb7eee 100644 --- a/tests/unittest/executor/test_rpc_proxy.py +++ b/tests/unittest/executor/test_rpc_proxy.py @@ -3,7 +3,7 @@ import time import pytest -from test_worker_base import create_fake_executor_config +from test_base_worker import create_fake_executor_config from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy from tensorrt_llm.llmapi.mpi_session import MpiPoolSession @@ -33,6 +33,7 @@ def create_proxy(self, tp_size: int): "executor_config": None, "llm_args": llm_args, "model_world_size": tp_size, + "hf_model_dir": model_path, }, model_world_size=tp_size, mpi_session=mpi_session, @@ -78,4 +79,4 @@ def test_tp2(self, num_reqs): if __name__ == "__main__": - TestRpcProxyTp1().test_tp1() + TestRpcProxy().test_tp1(1) diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index 8c7affb228a..4b814d345d9 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -6,7 +6,7 @@ from concurrent.futures import ProcessPoolExecutor import pytest -from test_worker_base import create_fake_executor_config +from test_base_worker import create_fake_executor_config from tensorrt_llm.executor.request import GenerationRequest from tensorrt_llm.executor.rpc import RPCClient @@ -21,6 +21,7 @@ # isort: on model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" +assert model_path.exists() class TestRpcWorkerTP1: @@ -31,6 +32,7 @@ def setup_method(self): self.pool, self.addr = self.create_worker_pool() self.client = self.create_rpc_client(self.addr) self.client.setup_engine().remote() + print(f"Worker setup engine done") time.sleep(10) def teardown_method(self): @@ -43,11 +45,14 @@ def create_worker_pool(self): mp_context = multiprocessing.get_context( 'spawn') # spawn for CUDA context pool = ProcessPoolExecutor(max_workers=1, mp_context=mp_context) - pool.submit(RpcWorker.main_task, - engine=model_path, - rpc_addr=addr, - executor_config=self.executor_config, - llm_args=self.llm_args) + pool.submit( + RpcWorker.main_task, + engine=model_path, + rpc_addr=addr, + executor_config=self.executor_config, + llm_args=self.llm_args, + hf_model_dir=model_path, + ) return pool, addr def create_rpc_client(self, addr: str): @@ -58,13 +63,21 @@ def test_create_shutdown(self): pass def test_fetch_responses_sync(self): + # Wait a bit to ensure engine is ready + time.sleep(1) + + print(f"start to submit") self.client.submit( GenerationRequest(prompt_token_ids=[3, 4, 5], sampling_params=SamplingParams( max_tokens=5)), ).remote(need_response=False) + print(f"submit done") + + time.sleep(3) + results = [] - while not results: - results.extend(self.client.fetch_responses().remote()) + # Fetch responses + results.extend(self.client.fetch_responses().remote()) assert len(results) == 1 def test_fetch_responses_streaming_sync(self): @@ -75,9 +88,12 @@ def test_fetch_responses_streaming_sync(self): results = [] for i in range(10): - res = self.client.fetch_responses().remote() + res = self.client.fetch_responses().remote(timeout=1.0) results.extend(res) print(f"fetch_responses {i} result: {results}") + # If we've received enough results, break early + if len(results) >= 5: + break assert 0 < len(results) <= 5 time.sleep(5) @@ -174,52 +190,6 @@ async def process_request_streaming(): await process_request_streaming() - def test_main_loop(self): - time.sleep(1) - - def process_request(): - ret = self.client.submit( - GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10)), ).remote( - need_response=False) - assert ret is None # need_response = False - - print(f"submit result: {ret}") - print("call fetch_responses") - # NOTE: known issue, the responses should be fetched before shutdown, - # or the shutdown will hang. - results = [] - time.sleep(8) # wait for PyExecutor to finish the generation - results.extend(self.client.fetch_responses().remote() - ) # fetch_responses will block - print(f"fetch_responses result: {results}") - assert len(results) == 1 # one request, one response - - def process_request_streaming(): - ret = self.client.submit( - GenerationRequest(prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10), - streaming=True), ).remote(need_response=False) - assert ret is None - print("submit result: ", ret) - - # NOTE: known issue, the responses should be fetched before shutdown, - # or the shutdown will hang. - results = [] - time.sleep(8) - - while not results: - time.sleep(1) - results.extend(self.client.fetch_responses().remote(timeout=10)) - print(f"try fetch_responses result: {results}") - print(f"fetch_responses result: {results}") - assert results - - for i in range(5): - process_request() - process_request_streaming() - class TestRpcWorkerTP2: @@ -244,6 +214,7 @@ def create_worker_session(self): rpc_addr=addr, executor_config=self.executor_config, llm_args=self.llm_args, + hf_model_dir=model_path, model_world_size=2) return session, addr, futures @@ -256,12 +227,18 @@ def test_create_shutdown(self): pass def test_fetch_responses_sync(self): + # Wait a bit to ensure engine is ready + time.sleep(1) + self.client.submit( GenerationRequest(prompt_token_ids=[3, 4, 5], sampling_params=SamplingParams( - max_tokens=5)), )\ - .remote(need_response=False) + max_tokens=5)), ).remote(need_response=False) + + # Wait for generation to complete + time.sleep(3) + results = [] - while not results: - results.extend(self.client.fetch_responses().remote()) + # Fetch responses with timeout + results.extend(self.client.fetch_responses().remote(timeout=5)) assert len(results) == 1 diff --git a/tests/unittest/executor/test_worker_base.py b/tests/unittest/executor/test_worker_base.py deleted file mode 100644 index 919d34c6a93..00000000000 --- a/tests/unittest/executor/test_worker_base.py +++ /dev/null @@ -1,185 +0,0 @@ -import os -import sys -import time - -import pytest -import torch - -from tensorrt_llm._utils import mpi_comm, mpi_rank, mpi_world_size -from tensorrt_llm.bindings import executor as tllm -from tensorrt_llm.llmapi.mpi_session import MpiPoolSession, set_mpi_session_cpp - -# isort: off -sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") -from utils.llm_data import llm_models_root -# isort: on - -from tensorrt_llm._torch.pyexecutor.config import update_executor_config -from tensorrt_llm.executor.request import GenerationRequest -from tensorrt_llm.executor.worker_base import WorkerBase -from tensorrt_llm.llmapi.llm_args import LlmArgs -from tensorrt_llm.sampling_params import SamplingParams - -default_model_name = "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" -model_path = llm_models_root() / default_model_name - - -class TestWorkerBase: - - class FakeWorker(WorkerBase): - - def __init__(self, engine: str): - super().__init__(engine=engine, hf_model_dir=engine) - llm_args, executor_config = create_fake_executor_config(engine) - # Pass config in constructor and finalize with parameterless setup - self._executor_config = executor_config - self.llm_args = llm_args - self.setup_engine() - - def test_create_engine(self): - with self.FakeWorker(engine=model_path) as worker: - print(f"Created engine: {worker.engine}") - - def test_submit_request(self): - sampling_params = SamplingParams(max_tokens=10) - request = GenerationRequest(prompt_token_ids=[3, 4, 5], - sampling_params=sampling_params) - with self.FakeWorker(engine=model_path) as worker: - print(f"Created engine: {worker.engine}") - worker.submit(request) - for i in range(10): - time.sleep(0.5) - worker.await_responses() - print(f"Submitted request: {request}") - time.sleep(6) - - def test_fetch_stats(self): - request = GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10)) - with self.FakeWorker(engine=model_path) as worker: - worker.submit(request) - time.sleep(1) - worker.await_responses() - stats = worker.fetch_stats() - assert len(stats) > 0 - - def test_dispatch_stats_task(self): - request = GenerationRequest( - prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=10)) - with self.FakeWorker(engine=model_path) as worker: - worker.submit(request) - worker.await_responses() - time.sleep(10) - stats = worker.fetch_stats() - assert len(stats) == 1 - - @pytest.mark.parametrize("timeout", [0.1, 0.2, 1]) - def test_fetch_responses_timeout(self, timeout: float): - with self.FakeWorker(engine=model_path) as worker: - # Not submit any request, and let the await_responses timeout. - start_time = time.time() - results = worker.await_responses(timeout=timeout) - elapsed = time.time() - start_time - print(f"await_responses latency: {elapsed:.3f} seconds") - assert timeout / 2 <= elapsed <= timeout * 2, f"Latency out of expected range: {elapsed}" - assert results is None - - -def create_fake_executor_config(model_path, tp_size=1): - llm_args = LlmArgs(model=model_path, cuda_graph_config=None) - - executor_config = tllm.ExecutorConfig(1) - executor_config.max_batch_size = 1 - executor_config.model_world_size = tp_size - - # For PyTorch backend with TP > 1, we need proper parallel config - if tp_size > 1: - llm_args.parallel_config.tp_size = tp_size - - update_executor_config( - executor_config, - pytorch_backend_config=llm_args.get_pytorch_backend_config(), - mapping=llm_args.parallel_config.to_mapping(), - speculative_config=llm_args.speculative_config, - hf_model_dir=model_path, - max_input_len=20, - max_seq_len=40, - checkpoint_format=llm_args.checkpoint_format, - checkpoint_loader=llm_args.checkpoint_loader, - ) - - return llm_args, executor_config - - -class TestRpcWorkerBaseTP2: - - def setup_method(self): - self.llm_args, self.executor_config = create_fake_executor_config( - model_path, tp_size=2) - self.session = self.create_worker_session() - # No need to sleep here - the session is ready immediately - - def create_worker_session(self): - session = MpiPoolSession(n_workers=2) - return session - - def test_create_executor(self): - futures = self.session.submit( - TestRpcWorkerBaseTP2.create_executor, - engine=model_path, - llm_args=self.llm_args, - ) - # Wait for completion - for future in futures: - future.result() - - self.session.shutdown() - - @staticmethod - def create_executor(engine, llm_args): - # Set MPI session for C++ backend - set_mpi_session_cpp(mpi_comm()) - - # Set CUDA device for this rank - rank = mpi_rank() - world_size = mpi_world_size() - device_id = rank % torch.cuda.device_count() - torch.cuda.set_device(device_id) - - # Don't set CUDA_VISIBLE_DEVICES as it interferes with MPI multi-GPU setup - - print(f"[Test] Rank {rank}/{world_size} using device {device_id}") - - # Synchronize all workers before creating executor - mpi_comm().barrier() - - try: - print(f"[Test] Rank {rank} creating WorkerBase...") - executor = WorkerBase(engine=engine, - llm_args=llm_args, - hf_model_dir=engine) - - # For PyTorch backend, all ranks need to participate in setup - print(f"[Test] Rank {rank} calling setup_engine...") - - # Setup the engine which contains another barrier - executor.setup_engine() - - print(f"[Test] Rank {rank} setup_engine completed successfully") - - executor.shutdown() - - except Exception as e: - print(f"[Test] Rank {rank} failed with error: {e}") - import traceback - traceback.print_exc() - raise - - return None # executor cannot be picked and returned - - -if __name__ == "__main__": - test_worker_base = TestWorkerBase() - test_worker_base.test_fetch_stats() From 6269ae456eecd1ed2c2374538754612e28bc17a9 Mon Sep 17 00:00:00 2001 From: chunweiy Date: Mon, 22 Sep 2025 08:33:02 +0000 Subject: [PATCH 09/13] add llm tests Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/executor.py | 33 +++++++++++ tensorrt_llm/executor/rpc_proxy.py | 57 ++++++++++++++----- tensorrt_llm/executor/rpc_worker.py | 25 ++------ tensorrt_llm/llmapi/llm_args.py | 7 +++ tests/unittest/executor/test_rpc_worker.py | 17 ------ .../llmapi/test_llm_multi_gpu_pytorch.py | 35 +++++++++++- tests/unittest/llmapi/test_llm_pytorch.py | 16 ++++++ 7 files changed, 138 insertions(+), 52 deletions(-) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index e8af846c247..1a2e552f355 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -410,9 +410,22 @@ def create( mpirun_launch = external_mpi_comm_available(model_world_size) # The case where the Python main process utilizes mpi4py to spawn MPI workers spawn_workers = need_spawn_mpi_workers(model_world_size) + orchestrator_is_rpc = llm_args and llm_args.orchestrator_type == "rpc" + if spawn_workers or (mpirun_launch and reuse_mpi_comm): if reuse_mpi_comm: assert mpi_session is not None, "reuse_mpi_comm requires an external MPI session" + + if orchestrator_is_rpc: + from .rpc_proxy import GenerationExecutorRpcProxy + return GenerationExecutorRpcProxy( + worker_kwargs, + model_world_size=model_world_size, + mpi_session=mpi_session, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor, + kv_connector_config=kv_connector_config) + return GenerationExecutorProxy( worker_kwargs, model_world_size=model_world_size, @@ -429,6 +442,15 @@ def create( logger.warning( "Using single process worker for TP1, this may hurt streaming generation performance." ) + if orchestrator_is_rpc: + from .rpc_proxy import GenerationExecutorRpcProxy + return GenerationExecutorRpcProxy( + worker_kwargs, + model_world_size=model_world_size, + mpi_session=mpi_session, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor, + kv_connector_config=kv_connector_config) return GenerationExecutorWorker( **worker_kwargs, is_llm_executor=is_llm_executor, @@ -439,6 +461,16 @@ def create( # While this requires uses to protect their entrypoint to # `if __name__ == "__main__":`. if not platform.system() == 'Windows': + if orchestrator_is_rpc: + from .rpc_proxy import GenerationExecutorRpcProxy + return GenerationExecutorRpcProxy( + worker_kwargs, + model_world_size=model_world_size, + mpi_session=mpi_session, + postproc_worker_config=postproc_worker_config, + is_llm_executor=is_llm_executor, + kv_connector_config=kv_connector_config) + return GenerationExecutorProxy( worker_kwargs, model_world_size=model_world_size, @@ -451,6 +483,7 @@ def create( # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot. mpi_session = ProcessPoolExecutorSession(n_workers=1, mp_context=ctx) + # TODO: add rpc worker here return GenerationExecutorProxy( worker_kwargs, model_world_size=model_world_size, diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 1d26c4e70b0..b43b25a9499 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -5,6 +5,7 @@ import time from typing import Optional +from ..llmapi.llm_args import KvCacheConnectorConfig from ..llmapi.mpi_session import MpiPoolSession, MpiSession from ..llmapi.tracer import global_tracer from ..llmapi.utils import (_SyncQueue, logger_debug, print_colored_debug, @@ -24,15 +25,16 @@ class GenerationExecutorRpcProxy(GenerationExecutor): # NOTE: this is a global counter for the number of instances of this class INSTANCE_COUNTER = 0 - def __init__(self, - worker_kwargs: dict, - model_world_size: int = 1, - mpi_session: Optional[MpiSession] = None, - *, - postproc_worker_config: Optional[PostprocWorkerConfig] = None, - is_llm_executor: Optional[bool] = None, - garbage_collection_gen0_threshold: Optional[int] = None, - clock_unit: int = 1): + def __init__( + self, + worker_kwargs: dict, + model_world_size: int = 1, + mpi_session: Optional[MpiSession] = None, + *, + postproc_worker_config: Optional[PostprocWorkerConfig] = None, + is_llm_executor: Optional[bool] = None, + kv_connector_config: Optional[KvCacheConnectorConfig] = None, + ): """ Args: worker_kwargs: kwargs for the rpc worker @@ -40,11 +42,8 @@ def __init__(self, mpi_session: the mpi session to use postproc_worker_config: the postproc worker config is_llm_executor: whether this is an llm executor - garbage_collection_gen0_threshold: the garbage collection gen0 threshold - clock_unit: the unit of the clock, 1 means 1 second + kv_connector_config: the kv cache connector config """ - self.clock_unit = clock_unit - GenerationExecutorRpcProxy.INSTANCE_COUNTER += 1 self.rpc_addr = self.gen_uniq_rpc_addr() self.rpc_client = RPCClient(self.rpc_addr) @@ -67,6 +66,9 @@ def __init__(self, self._shutdown_event = threading.Event() self.worker_kwargs = worker_kwargs + self.main_loop_task_obj = None + self.main_loop = None + self.launch_workers() time.sleep(1) # wait for the workers to launch @@ -95,6 +97,8 @@ async def main_loop_task(self): if self._shutdown_event.is_set(): return self.handle_responses(responses) + except asyncio.CancelledError: + logger.debug("Main loop task cancelled") except Exception as e: logger.error(f"Error in main_loop_task: {e}") raise @@ -103,7 +107,17 @@ def setup_mainloop(self): def _run_main_loop_task(): """Local method to run the main loop task.""" - asyncio.run(self.main_loop_task()) + self.main_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.main_loop) + + self.main_loop_task_obj = self.main_loop.create_task( + self.main_loop_task()) + try: + self.main_loop.run_until_complete(self.main_loop_task_obj) + except asyncio.CancelledError: + pass # Task cancellation is expected during shutdown + finally: + self.main_loop.close() self.main_loop_thread = threading.Thread(target=_run_main_loop_task, daemon=True) @@ -190,6 +204,7 @@ def abort_request(self, request_id: int) -> None: def shutdown(self): if self._shutdown_event.is_set(): return + self._shutdown_event.set() logger_debug(f"Shutting down GenerationExecutorRpcProxy", color="yellow") @@ -197,13 +212,25 @@ def shutdown(self): self.shutdown_remote() # 2. stop the main loop, so that no new rpc requests - self._shutdown_event.set() + if self.main_loop and self.main_loop_task_obj: + logger_debug("Cancelling main loop task.", color="yellow") + # The cancel() is thread-safe + try: + self.main_loop.call_soon_threadsafe( + self.main_loop_task_obj.cancel) + except Exception as e: + logger_debug(f"Error cancelling main loop task: {e}", + color="yellow") + self.main_loop_thread.join() # 3. shutdown the mpi session, this should wait until all the PyExecutor # processes are shutdown if self.mpi_session is not None: + logger_debug(f"Shutting down mpi session", color="yellow") self.mpi_session.shutdown() + logger_debug(f"Mpi session shutdown", color="yellow") + self.mpi_session = None self.rpc_client.close() diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 33543e4760e..17ae96ae02b 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -84,36 +84,22 @@ def fetch_responses(self, timeout: Optional[float] = None) -> list: qsize = self._response_queue.qsize() logger_debug(f"RpcWorker returning {qsize} responses", color="yellow") - if qsize == 0: - return [] - all_responses = [] for _ in range(qsize): # The queue contains batches of responses, so extend the list all_responses.extend(self._response_queue.get()) return all_responses - async def fetch_responses_async(self) -> list: + async def fetch_responses_async(self, + timeout: Optional[float] = None) -> list: # A really async version of fetch_responses logger_debug(f"RpcWorker {mpi_rank()} is fetching responses async", color="yellow") # First, await any pending responses without blocking the event loop - responses = await asyncio.to_thread(self.await_responses, 0.001) - # Handle the responses that are ready - self._await_response_helper.responses_handler(responses) - - qsize = self._response_queue.qsize() - logger_debug(f"RpcWorker returning {qsize} async responses", - color="yellow") - - if qsize == 0: - return [] - - all_responses = [] - for _ in range(qsize): - all_responses.extend(self._response_queue.get()) - return all_responses + responses = await asyncio.to_thread(self.fetch_responses, + timeout=timeout) + return responses # for streaming performance async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: @@ -145,6 +131,7 @@ def shutdown(self): color="yellow") self.shutdown_event.set() super().shutdown() + logger_debug(f"RPC worker {mpi_rank()} is shutdown", color="yellow") def start(self): pass diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 2ea038935e1..7efec37b401 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2437,6 +2437,13 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) + orchestrator_type: Optional[Literal["ray", "rpc"]] = Field( + default=None, + description= + "The orchestrator type to use. Defaults to None, which uses MPI.", + status="prototype", + ) + # PrivateVars _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index 4b814d345d9..c27347cb4df 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -96,23 +96,6 @@ def test_fetch_responses_streaming_sync(self): break assert 0 < len(results) <= 5 - time.sleep(5) - - @pytest.mark.asyncio - async def test_fetch_responses_streaming_async(self): - self.client.submit( - GenerationRequest(prompt_token_ids=[3, 4, 5], - sampling_params=SamplingParams(max_tokens=5), - streaming=True), ).remote(need_response=False) - - results = [] - # Must fetch all the responses, or the PyExecutor will hang - for i in range(10): - res = await self.client.fetch_responses_async().remote_async() - results.extend(res) - print(f"fetch_responses_async {i} result: {results}") - assert 0 < len(results) <= 5 - @pytest.mark.asyncio @pytest.mark.parametrize("req_count", [10]) async def test_main_loop_async(self, req_count: int): diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index 28d6bedf1ba..df36ef3076b 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -1,7 +1,7 @@ import pytest # isort: off -from .test_llm import tinyllama_logits_processor_test_harness +from .test_llm import tinyllama_logits_processor_test_harness, llama_model_path from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.lora_helper import LoraConfig @@ -9,6 +9,8 @@ from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness from .test_llm import _test_llm_capture_request_error # isort: on +from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy +from tensorrt_llm.sampling_params import SamplingParams global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) @@ -57,3 +59,34 @@ def test_llama_7b_multi_lora_tp2(): # Disable CUDA graph # TODO: remove this once we have a proper fix for CUDA graph in LoRA cuda_graph_config=None) + + +@pytest.mark.gpu2 +def test_llm_rpc_tp2(): + llm = LLM(model=llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + orchestrator_type="rpc", + tensor_parallel_size=2) + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + res = llm.generate("Tell me a joke", + sampling_params=SamplingParams(max_tokens=10, end_id=-1)) + print(f"get result: {res}") + + assert len(res.outputs) == 1 + assert len(res.outputs[0].token_ids) == 10 + + +@pytest.mark.gpu2 +@pytest.mark.asyncio +async def test_llm_rpc_streaming_tp2(): + llm = LLM(model=llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + orchestrator_type="rpc", + tensor_parallel_size=2) + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + async for output in llm.generate_async("Tell me a joke", + sampling_params=SamplingParams( + max_tokens=10, end_id=-1)): + print(f"get result: {output}") diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 62253df45a5..ac710c53f15 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -6,6 +6,7 @@ from tensorrt_llm import LLM from tensorrt_llm.executor import GenerationExecutorWorker +from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer @@ -940,3 +941,18 @@ def test_max_num_token_check(self): match="should not exceed max_num_tokens"): ids = [random.randint(10, 100) for _ in range(101)] llm.generate([ids]) + + +def test_llm_rpc(): + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + res = llm.generate("Tell me a joke", + sampling_params=SamplingParams(max_tokens=10, + end_id=-1)) + print(f"get result: {res}") + + assert len(res.outputs) == 1 + assert len(res.outputs[0].token_ids) == 10 From 78b686919d01ba4106ae12ded91cf1cfd6a16c55 Mon Sep 17 00:00:00 2001 From: chunweiy Date: Tue, 23 Sep 2025 14:59:46 +0000 Subject: [PATCH 10/13] fix orchestrator_type Signed-off-by: chunweiy fix comment Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/rpc/rpc_server.py | 9 +++++ tensorrt_llm/executor/rpc_proxy.py | 3 -- tensorrt_llm/executor/rpc_worker.py | 35 ++++++++++++---- tensorrt_llm/llmapi/llm_args.py | 14 +++---- .../api_stability/references/llm.yaml | 4 ++ tests/unittest/executor/test_rpc_worker.py | 12 ++++++ tests/unittest/llmapi/test_llm_pytorch.py | 40 +++++++++++++------ 7 files changed, 87 insertions(+), 30 deletions(-) diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py index eed47f273b3..eb2fde70f62 100644 --- a/tensorrt_llm/executor/rpc/rpc_server.py +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -207,6 +207,14 @@ async def _dispatcher_routine(self, stop_event: threading.Event): f"Dispatcher received request {req}, pending: {self._num_pending_requests}" ) + # TODO optimization: resolve the sequential scheduling for the remote calls + # Suppose tons of submit remote call block the FIFO queue, and the later get_stats remote calls may be blocked + # There could be two dispatch modes: + # 1. (current) mix mode, share the same routine/pool + # 2. (promising) stream mode, specific remote_call -> stream -> specific routine/pool + # - get_stats() - 1, remote_call -> dedicated queue -> dedicated routine/pool + # - submit() - 3 -> dedicated queue -> dedicated routine/pool + # TODO potential optimization: for submit(), batch the ad-hoc requests in an interval like 5ms, reduce the IPC count async def _worker_routine(self, stop_event: threading.Event): """The routine executed by each worker thread.""" assert self._client_socket is not None, "Client socket is not bound" @@ -306,6 +314,7 @@ def call_with_kwargs(): logger_debug( f"RPC Server running async task {req.method_name} in worker" ) + # TODO: let num worker control the pool size result = await asyncio.wait_for(loop.run_in_executor( self._executor, call_with_kwargs), timeout=req.timeout) diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index b43b25a9499..02d3a65c29a 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -185,9 +185,6 @@ def submit(self, request: GenerationRequest) -> GenerationResult: return result - def fetch_responses_remote(self): - return self.rpc_client.fetch_responses().remote(timeout=20) - def fetch_stats_remote(self): return self.rpc_client.fetch_stats().remote() diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 17ae96ae02b..16ceeb488af 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -34,13 +34,15 @@ class RpcWorker(BaseWorker): - `shutdown`: Shutdown the worker. """ + # Number of RPC server workers + NUM_WORKERS = 6 + def __init__( self, engine: Union[Path, Engine], executor_config: Optional[tllm.ExecutorConfig] = None, is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, batched_logits_processor: Optional[BatchedLogitsProcessor] = None, postproc_worker_config: Optional[PostprocWorkerConfig] = None, kv_connector_config: Optional[KvCacheConnectorConfig] = None, @@ -60,8 +62,11 @@ def __init__( hf_model_dir=hf_model_dir, tokenizer=tokenizer, ) - # Store garbage_collection_gen0_threshold if needed in the future - self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold + # Extract garbage_collection_gen0_threshold from llm_args if available + self.garbage_collection_gen0_threshold = ( + llm_args.garbage_collection_gen0_threshold if llm_args is not None + and hasattr(llm_args, 'garbage_collection_gen0_threshold') else + None) self.shutdown_event = Event() self._response_queue = Queue() @@ -101,11 +106,13 @@ async def fetch_responses_async(self, timeout=timeout) return responses + async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: + return await asyncio.to_thread(self.fetch_stats) + # for streaming performance async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: while not self.shutdown_event.is_set(): - responses = await asyncio.to_thread(self.fetch_responses - ) # run blocking call in thread + responses = await self.fetch_responses_async() if responses: # Only yield if there are actual responses logger_debug( f"RpcWorker {mpi_rank()} is yielding responses: {responses}", @@ -118,6 +125,20 @@ async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: f"RpcWorker {mpi_rank()} quitting fetch_responses_loop_async", color="yellow") + async def fetch_stats_loop_async( + self, + timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + while not self.shutdown_event.is_set(): + logger_debug(f"RpcWorker {mpi_rank()} is fetching stats async") + timeout = timeout or 0.1 + await asyncio.sleep(timeout) + stats = await self.fetch_stats_async() + # Always yield stats, even if empty, to prevent the client looks like hanging + # TODO: Remove the empty stats to reduce the IPC overhead + yield stats + logger_debug(f"RpcWorker {mpi_rank()} quitting fetch_stats_loop_async", + color="yellow") + def setup_engine(self): # Force all the ranks to wait here, and start creating the executor simultaneously. # Only call barrier if we have multiple ranks to avoid hanging in single-process tests @@ -146,7 +167,6 @@ def main_task( postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, llm_args: Optional[BaseLlmArgs] = None, kv_connector_config: Optional[KvCacheConnectorConfig] = None, hf_model_dir: Optional[Path] = None, @@ -162,7 +182,6 @@ def main_task( executor_config=executor_config, is_llm_executor=is_llm_executor, lora_config=lora_config, - garbage_collection_gen0_threshold=garbage_collection_gen0_threshold, llm_args=llm_args, batched_logits_processor=batched_logits_processor, postproc_worker_config=postproc_worker_config, @@ -184,7 +203,7 @@ def main_task( color="yellow") # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client # Set num_workers to larger than 1 since there are some streaming tasks runs infinitely, such as await_responses_async. - rpc_server = RPCServer(worker, num_workers=6) + rpc_server = RPCServer(worker, num_workers=RpcWorker.NUM_WORKERS) rpc_server.bind(rpc_addr) rpc_server.start() diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 7efec37b401..b64c11246dd 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1550,6 +1550,13 @@ class BaseLlmArgs(StrictBaseModel): description="Return perf metrics.", status="prototype") + orchestrator_type: Optional[Literal["rpc"]] = Field( + default=None, + description= + "The orchestrator type to use. Defaults to None, which uses MPI.", + status="prototype", + ) + _parallel_config: Optional[object] = PrivateAttr(default=None) _model_format: Optional[_ModelFormatKind] = PrivateAttr(default=None) _speculative_model: Optional[str] = PrivateAttr(default=None) @@ -2437,13 +2444,6 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) - orchestrator_type: Optional[Literal["ray", "rpc"]] = Field( - default=None, - description= - "The orchestrator type to use. Defaults to None, which uses MPI.", - status="prototype", - ) - # PrivateVars _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 642729fc406..26f333eaad2 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -179,6 +179,10 @@ methods: annotation: bool default: False status: prototype + orchestrator_type: + annotation: Optional[Literal["rpc"]] + default: null + status: prototype return_annotation: None generate: parameters: diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index c27347cb4df..cacaf34e5d6 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -173,6 +173,18 @@ async def process_request_streaming(): await process_request_streaming() + @pytest.mark.asyncio + async def test_fetch_stats_loop_async(self): + await asyncio.sleep(1) + results = [] + async for stats in self.client.fetch_stats_loop_async( + ).remote_streaming(): + results.append(stats) # empty stats + print(f"fetch_stats_async batch, received {len(stats)} stats") + if len(results) >= 10: + break + assert len(results) >= 10 + class TestRpcWorkerTP2: diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index ac710c53f15..f7175c311d1 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -944,15 +944,31 @@ def test_max_num_token_check(self): def test_llm_rpc(): - with LLM(model=llama_model_path, - kv_cache_config=global_kvcache_config, - orchestrator_type="rpc") as llm: - assert isinstance(llm._executor, GenerationExecutorRpcProxy) - - res = llm.generate("Tell me a joke", - sampling_params=SamplingParams(max_tokens=10, - end_id=-1)) - print(f"get result: {res}") - - assert len(res.outputs) == 1 - assert len(res.outputs[0].token_ids) == 10 + llm = LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + orchestrator_type="rpc") + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + res = llm.generate("Tell me a joke", + sampling_params=SamplingParams(max_tokens=10, end_id=-1)) + print(f"get result: {res}") + + assert len(res.outputs) == 1 + assert len(res.outputs[0].token_ids) == 10 + + +@pytest.mark.asyncio +async def test_llm_rpc_streaming(): + llm = LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + orchestrator_type="rpc") + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + outputs = [] + async for output in llm.generate_async("Tell me a joke", + sampling_params=SamplingParams( + max_tokens=10, end_id=-1), + streaming=True): + outputs.append(output.outputs[0].text) + "".join(outputs) + print(f"get result: {outputs}") From 4a903ce1f8c0ae0cd1ac3224596685c6a6769bfd Mon Sep 17 00:00:00 2001 From: chunweiy Date: Fri, 26 Sep 2025 10:31:42 +0000 Subject: [PATCH 11/13] add get_stats and kv_cache_event Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/base_worker.py | 12 ++ tensorrt_llm/executor/executor.py | 1 + tensorrt_llm/executor/rpc/rpc_client.py | 2 + tensorrt_llm/executor/rpc/rpc_server.py | 2 +- tensorrt_llm/executor/rpc_proxy.py | 150 ++++++++++++++++++-- tensorrt_llm/executor/rpc_worker.py | 52 +++++-- tensorrt_llm/executor/worker.py | 31 +--- tests/unittest/executor/test_base_worker.py | 9 +- tests/unittest/executor/test_rpc_proxy.py | 21 ++- tests/unittest/executor/test_rpc_worker.py | 9 +- 10 files changed, 228 insertions(+), 61 deletions(-) diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 1dccf5dcf66..0bdfc1ca52b 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -235,6 +235,12 @@ def fetch_stats(self) -> list: else: return self.engine.get_latest_iteration_stats() + def fetch_kv_cache_events(self) -> list: + if isinstance(self.engine, tllm.Executor): + return self.engine.get_latest_kv_cache_events() + else: + return self.engine.get_latest_kv_cache_events() + def set_result_queue(self, queue): """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process.""" assert self.postproc_queues is None @@ -574,6 +580,12 @@ def _stats_serializer( # Convert back to JSON string return json.dumps(stats_dict) + # Define a Callable to serialize KV cache events + @staticmethod + def _kv_cache_events_serializer(events) -> str: + from .._utils import KVCacheEventSerializer + return json.dumps(KVCacheEventSerializer.serialize(events)) + def _pop_result(self, client_id: int): self._results.pop(client_id, None) self._client_id_to_request_id.pop(client_id, None) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 1a2e552f355..c7fc8ea5d0d 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -451,6 +451,7 @@ def create( postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, kv_connector_config=kv_connector_config) + return GenerationExecutorWorker( **worker_kwargs, is_llm_executor=is_llm_executor, diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index 6767291523c..f3f8ecfa01d 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -79,6 +79,7 @@ def __init__(self, address: The ZMQ address to connect to. hmac_key: The HMAC key for encryption. timeout: The timeout (seconds) for RPC calls. + num_workers: The number of workers for the RPC client. ''' self._address = address self._timeout = timeout @@ -307,6 +308,7 @@ async def _call_async(self, method_name, *args, **kwargs): if timeout is None: res = await future else: + # Add 1 second to the timeout to ensure the client can get res = await asyncio.wait_for(future, timeout + 1) logger_debug( f"RPC Client _call_async: Got result for request_id: {request_id}: {res}" diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py index eb2fde70f62..b816a500774 100644 --- a/tensorrt_llm/executor/rpc/rpc_server.py +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -31,7 +31,7 @@ def __init__(self, Args: instance: The instance whose methods will be exposed via RPC. hmac_key (bytes, optional): HMAC key for encryption. - num_workers (int): Number of worker threads. + num_workers (int): Number of worker threads or worker tasks that help parallelize the task execution. timeout (int): Timeout for RPC calls. async_run_task (bool): Whether to run the task asynchronously. diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 02d3a65c29a..83ca5af0665 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -1,5 +1,6 @@ import asyncio import atexit +import json import os import threading import time @@ -8,8 +9,8 @@ from ..llmapi.llm_args import KvCacheConnectorConfig from ..llmapi.mpi_session import MpiPoolSession, MpiSession from ..llmapi.tracer import global_tracer -from ..llmapi.utils import (_SyncQueue, logger_debug, print_colored_debug, - print_traceback_on_error) +from ..llmapi.utils import (AsyncQueue, _SyncQueue, logger_debug, + print_colored_debug) from ..logger import logger from .executor import GenerationExecutor from .postproc_worker import PostprocWorkerConfig @@ -86,32 +87,65 @@ def launch_workers(self): rpc_addr=self.rpc_addr, **self.worker_kwargs) - @print_traceback_on_error - async def main_loop_task(self): - """ - Main loop of the proxy, it will invoke the actions periodically. + async def _generic_fetch_loop_async(self, fetch_method_name: str, + handler_method, method_name: str): + """Generic method for fetching data in a loop from RPC worker. + + Args: + fetch_method_name: Name of the RPC client method to call + handler_method: The handler method to call with the fetched data + method_name: Name of the method for logging """ try: - async for responses in self.rpc_client.fetch_responses_loop_async( - ).remote_streaming(): + fetch_method = getattr(self.rpc_client, fetch_method_name) + async for data in fetch_method().remote_streaming(): if self._shutdown_event.is_set(): return - self.handle_responses(responses) + handler_method(data) except asyncio.CancelledError: - logger.debug("Main loop task cancelled") + logger.debug(f"{method_name} task cancelled") except Exception as e: - logger.error(f"Error in main_loop_task: {e}") + logger.error(f"Error in {method_name}: {e}") raise + async def _fetch_responses_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_responses_loop_async", + handler_method=self.handle_responses, + method_name="_fetch_responses_loop_async") + + async def _fetch_stats_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_stats_loop_async", + handler_method=self.handle_stats, + method_name="_fetch_stats_loop_async") + + async def _fetch_kv_cache_events_loop_async(self): + await self._generic_fetch_loop_async( + fetch_method_name="fetch_kv_cache_events_loop_async", + handler_method=self.handle_kv_cache_events, + method_name="_fetch_kv_cache_events_loop_async") + def setup_mainloop(self): + async def main_loop_task(): + tasks = [ + self._fetch_responses_loop_async(), + self._fetch_stats_loop_async(), + self._fetch_kv_cache_events_loop_async(), + ] + # Only add kv_cache_events loop if it's enabled + if self._iter_kv_events_result: + tasks.append(self._fetch_kv_cache_events_loop_async()) + await asyncio.gather(*tasks) + def _run_main_loop_task(): """Local method to run the main loop task.""" self.main_loop = asyncio.new_event_loop() asyncio.set_event_loop(self.main_loop) self.main_loop_task_obj = self.main_loop.create_task( - self.main_loop_task()) + main_loop_task()) try: self.main_loop.run_until_complete(self.main_loop_task_obj) except asyncio.CancelledError: @@ -164,9 +198,95 @@ def process_res(res: list): if async_queues: _SyncQueue.notify_many(event_loop, async_queues) - def handle_stats(self, stats: dict): - # raise NotImplementedError - pass + def _handle_iteration_data(self, data, result_singleton, data_type: str): + """Generic method to handle iteration data received from RPC worker. + + Args: + data: Data from the RPC worker (can be dict, str, or list) + result_singleton: The iteration result singleton to put data into + data_type: Type of data for logging (e.g., "stats", "kv_cache_events") + """ + # Make sure we have initialized the iteration results + self._maybe_initialize_iteration_results() + + if not result_singleton: + logger.debug( + f"Skipping {data_type} handling while result_singleton=None") + return + + # Get the queue from the result singleton + queue = result_singleton.queue + async_queues = [] + + # Clear old data if queue is full (similar to _iteration_result_task) + while queue.full(): + queue.get() + + try: + # Handle different types of data + if isinstance(data, str): + # Already JSON serialized + data_json = data + elif isinstance(data, list): + # Skip empty lists to avoid putting nothing in the queue + if not data: + logger.debug( + f"rpc_proxy.py: Skipping empty {data_type} list") + return + + # Handle list of data (multiple iterations) + for item in data: + if isinstance(item, str): + item_json = item + else: + item_json = json.dumps(item) + + if isinstance(queue, _SyncQueue): + queue.put_nowait(item_json) + async_queues.append(queue) + else: + queue.put(item_json) + + if async_queues: + _SyncQueue.notify_many(queue.loop, async_queues) + return + else: + # Convert dict/other to JSON string as expected by IterationResult + data_json = json.dumps(data) + + if isinstance(queue, _SyncQueue): + queue.put_nowait(data_json) + async_queues.append(queue) + else: + queue.put(data_json) + + if async_queues: + _SyncQueue.notify_many(queue.loop, async_queues) + + except AsyncQueue.EventLoopShutdownError: + # This happens when the event loop is already closed + logger.debug( + f"rpc_proxy.py: EventLoopShutdownError in handle_{data_type}") + except Exception as e: + logger.error(f"rpc_proxy.py: Error in handle_{data_type}: {e}") + raise e + + def handle_stats(self, stats): + """Handle stats received from RPC worker and put them into the stats result queue. + + Args: + stats: Statistics data from the RPC worker (can be dict, str, or list) + """ + self._handle_iteration_data(stats, self._iter_stats_result, "stats") + + def handle_kv_cache_events(self, events): + """Handle KV cache events received from RPC worker and put them into the events result queue. + + Args: + events: KV cache events data from the RPC worker (can be dict, str, or list) + """ + self._handle_iteration_data(events, self._iter_kv_events_result, + "kv_cache_events") def submit(self, request: GenerationRequest) -> GenerationResult: request.set_id(self._get_next_client_id()) diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index 16ceeb488af..f41c2c3257f 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -76,9 +76,6 @@ def submit(self, request: GenerationRequest): """ Submits a request to the worker. """ super().submit(request) - def fetch_stats(self) -> list: - return super().fetch_stats() - def fetch_responses(self, timeout: Optional[float] = None) -> list: logger_debug(f"RpcWorker {mpi_rank()} is fetching responses", color="yellow") @@ -109,6 +106,11 @@ async def fetch_responses_async(self, async def fetch_stats_async(self, timeout: Optional[float] = None) -> list: return await asyncio.to_thread(self.fetch_stats) + async def fetch_kv_cache_events_async(self, + timeout: Optional[float] = None + ) -> list: + return await asyncio.to_thread(self.fetch_kv_cache_events) + # for streaming performance async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: while not self.shutdown_event.is_set(): @@ -125,20 +127,50 @@ async def fetch_responses_loop_async(self) -> AsyncGenerator[list, None]: f"RpcWorker {mpi_rank()} quitting fetch_responses_loop_async", color="yellow") - async def fetch_stats_loop_async( + async def _generic_fetch_loop_async( self, + fetch_method, + serializer, + method_name: str, timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + """Generic method for fetching data in a loop. + + Args: + fetch_method: The async method to call for fetching data + serializer: The serializer function to apply to each item + method_name: Name of the method for logging + timeout: Optional timeout between fetches + """ while not self.shutdown_event.is_set(): - logger_debug(f"RpcWorker {mpi_rank()} is fetching stats async") timeout = timeout or 0.1 await asyncio.sleep(timeout) - stats = await self.fetch_stats_async() - # Always yield stats, even if empty, to prevent the client looks like hanging - # TODO: Remove the empty stats to reduce the IPC overhead - yield stats - logger_debug(f"RpcWorker {mpi_rank()} quitting fetch_stats_loop_async", + data = await fetch_method() + # Always yield data, even if empty, to prevent the client looks like hanging + # TODO: Remove the empty data to reduce the IPC overhead + yield [serializer(item) for item in data] + logger_debug(f"RpcWorker {mpi_rank()} quitting {method_name}", color="yellow") + async def fetch_stats_loop_async( + self, + timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + async for data in self._generic_fetch_loop_async( + fetch_method=self.fetch_stats_async, + serializer=self._stats_serializer, + method_name="fetch_stats_loop_async", + timeout=timeout): + yield data + + async def fetch_kv_cache_events_loop_async( + self, + timeout: Optional[float] = None) -> AsyncGenerator[list, None]: + async for data in self._generic_fetch_loop_async( + fetch_method=self.fetch_kv_cache_events_async, + serializer=self._kv_cache_events_serializer, + method_name="fetch_kv_cache_events_loop_async", + timeout=timeout): + yield data + def setup_engine(self): # Force all the ranks to wait here, and start creating the executor simultaneously. # Only call barrier if we have multiple ranks to avoid hanging in single-process tests diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index b744913ace6..859e5687448 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -1,18 +1,17 @@ import gc -import json import os import time import traceback from concurrent.futures import ProcessPoolExecutor from pathlib import Path from queue import Queue -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import zmq from tensorrt_llm.logger import logger -from .._utils import KVCacheEventSerializer, mpi_comm, mpi_rank +from .._utils import mpi_comm, mpi_rank from ..bindings import executor as tllm from ..builder import Engine from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig @@ -144,22 +143,6 @@ def _iteration_result_task(self, it_result_queue: IterationResultQueue, return True # success def dispatch_stats_task(self) -> bool: - - # Define a Callable to join iteration and request stats - def stats_serializer( - stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str: - iteration_stats, req_stats = stats - stats_dict = json.loads(iteration_stats.to_json_str()) - - if req_stats is not None and len(req_stats) > 0: - stats_dict["requestStats"] = [] - for req_stat in req_stats: - stats_dict["requestStats"].append( - json.loads(req_stat.to_json_str())) - - # Convert back to JSON string - return json.dumps(stats_dict) - return self._iteration_result_task(self.stats_queues, self.fetch_stats, self._iter_stats_result, self._stats_serializer) @@ -173,14 +156,14 @@ def dispatch_kv_cache_events_task(self) -> bool: events_api = lambda: [None] else: events_api = event_manager.get_latest_events - return self._iteration_result_task( - self.kv_events_queues, events_api, self._iter_kv_events_result, - lambda x: json.dumps(KVCacheEventSerializer.serialize(x))) + return self._iteration_result_task(self.kv_events_queues, + events_api, + self._iter_kv_events_result, + self._kv_cache_events_serializer) else: return self._iteration_result_task( self.kv_events_queues, self.engine.get_latest_kv_cache_events, - self._iter_kv_events_result, - lambda x: json.dumps(KVCacheEventSerializer.serialize(x))) + self._iter_kv_events_result, self._kv_cache_events_serializer) def start(self): # create iteration result queues diff --git a/tests/unittest/executor/test_base_worker.py b/tests/unittest/executor/test_base_worker.py index 664beefd270..c8e4fe2691e 100644 --- a/tests/unittest/executor/test_base_worker.py +++ b/tests/unittest/executor/test_base_worker.py @@ -117,9 +117,12 @@ def test_fetch_responses_timeout(self, timeout: float): def create_fake_executor_config(model_path, tp_size=1): # Use TorchLlmArgs for PyTorch backend tests - llm_args = TorchLlmArgs(model=model_path, - tensor_parallel_size=tp_size, - backend='pytorch') + llm_args = TorchLlmArgs( + model=model_path, + tensor_parallel_size=tp_size, + backend='pytorch', + enable_iter_perf_stats=True, + ) executor_config = tllm.ExecutorConfig(1) executor_config.max_batch_size = 1 diff --git a/tests/unittest/executor/test_rpc_proxy.py b/tests/unittest/executor/test_rpc_proxy.py index 22e5cbb7eee..13d281125df 100644 --- a/tests/unittest/executor/test_rpc_proxy.py +++ b/tests/unittest/executor/test_rpc_proxy.py @@ -6,6 +6,7 @@ from test_base_worker import create_fake_executor_config from tensorrt_llm.executor.rpc_proxy import GenerationExecutorRpcProxy +from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.llmapi.mpi_session import MpiPoolSession from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer from tensorrt_llm.sampling_params import SamplingParams @@ -26,6 +27,12 @@ def create_proxy(self, tp_size: int): llm_args, executor_config = create_fake_executor_config(model_path, tp_size=tp_size) + # Enable KV cache events + llm_args.kv_cache_config = KvCacheConfig( + event_buffer_max_size=1000, # Enable event buffer + enable_block_reuse=True, # Required for KV cache events + ) + mpi_session = MpiPoolSession(n_workers=tp_size) proxy = GenerationExecutorRpcProxy( worker_kwargs={ @@ -37,6 +44,7 @@ def create_proxy(self, tp_size: int): }, model_world_size=tp_size, mpi_session=mpi_session, + is_llm_executor=True, # Enable stats collection ) # Add additional wait for PyTorch backend with multi-rank setup @@ -62,6 +70,13 @@ def test_tp1(self, num_reqs): assert similar(tokenizer.decode(result.outputs[0].token_ids), 'E F G H I J K L') + stats = proxy.get_stats(timeout=2) + assert stats + + kv_cache_events = proxy.get_kv_events(timeout=2) + # KV cache events may be empty if no cache operations occurred + assert isinstance(kv_cache_events, list) + @pytest.mark.parametrize("num_reqs", [1, 10]) def test_tp2(self, num_reqs): tokenizer = TransformersTokenizer.from_pretrained(model_path) @@ -73,9 +88,9 @@ def test_tp2(self, num_reqs): sampling_params = SamplingParams(max_tokens=max_tokens) for _ in range(num_reqs): result = proxy.generate(prompt_token_ids, sampling_params) - print(f"get result: {result}") - assert similar(tokenizer.decode(result.outputs[0].token_ids), - 'E F G H I J K L') + print(f"get result: {result}") + assert similar(tokenizer.decode(result.outputs[0].token_ids), + 'E F G H I J K L') if __name__ == "__main__": diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index cacaf34e5d6..f6b759f9a07 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -179,11 +179,10 @@ async def test_fetch_stats_loop_async(self): results = [] async for stats in self.client.fetch_stats_loop_async( ).remote_streaming(): - results.append(stats) # empty stats - print(f"fetch_stats_async batch, received {len(stats)} stats") - if len(results) >= 10: - break - assert len(results) >= 10 + results.append(stats) + assert not stats # empty stats + + assert len(results) == 0 class TestRpcWorkerTP2: From a86ec76a65c057a1199364ba5d73f0e4010b9bc5 Mon Sep 17 00:00:00 2001 From: chunweiy Date: Fri, 26 Sep 2025 11:31:01 +0000 Subject: [PATCH 12/13] better reuse code between remote methods Signed-off-by: chunweiy Signed-off-by: chunweiy <328693+Superjomn@users.noreply.github.com> Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/rpc/rpc_client.py | 60 ++++++++++--------- tensorrt_llm/executor/rpc/rpc_common.py | 12 +++- tensorrt_llm/executor/rpc/rpc_server.py | 77 +++++++++++++++++++++---- 3 files changed, 108 insertions(+), 41 deletions(-) diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index f3f8ecfa01d..111b51eca1d 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -20,48 +20,52 @@ def __init__(self, client: 'RPCClient', method_name: str, *args, **kwargs): self.args = args self.kwargs = kwargs + def _prepare_and_call(self, timeout: Optional[float], need_response: bool, + mode: str, call_method: str) -> Any: + """Common method to prepare RPC params and make the call. + + Args: + timeout: Timeout for the RPC call + need_response: Whether a response is expected + mode: The RPC mode ("sync", "async", "future") + call_method: The method name to call on the client + + Returns: + The result of the client method call + """ + rpc_params = RPCParams(timeout=timeout, + need_response=need_response, + mode=mode) + self.kwargs["__rpc_params"] = rpc_params + client_method = getattr(self.client, call_method) + return client_method(self.method_name, *self.args, **self.kwargs) + def remote(self, timeout: Optional[float] = None, need_response: bool = True) -> Any: """Synchronous remote call with optional RPC parameters.""" - rpc_params = RPCParams(timeout=timeout, - need_response=need_response, - mode="sync") - self.kwargs["__rpc_params"] = rpc_params - return self.client._call_sync(self.method_name, *self.args, - **self.kwargs) + return self._prepare_and_call(timeout, need_response, "sync", + "_call_sync") def remote_async(self, timeout: Optional[float] = None, need_response: bool = True): """Asynchronous remote call that returns a coroutine.""" - rpc_params = RPCParams(timeout=timeout, - need_response=need_response, - mode="async") - self.kwargs["__rpc_params"] = rpc_params - return self.client._call_async(self.method_name, *self.args, - **self.kwargs) + return self._prepare_and_call(timeout, need_response, "async", + "_call_async") def remote_future(self, timeout: Optional[float] = None, need_response: bool = True) -> concurrent.futures.Future: """Remote call that returns a Future object.""" - rpc_params = RPCParams(timeout=timeout, - need_response=need_response, - mode="future") - self.kwargs["__rpc_params"] = rpc_params - return self.client.call_future(self.method_name, *self.args, - **self.kwargs) + return self._prepare_and_call(timeout, need_response, "future", + "call_future") def remote_streaming(self, timeout: Optional[float] = None) -> AsyncIterator[Any]: """Remote call for streaming results.""" - rpc_params = RPCParams(timeout=timeout, - need_response=True, - mode="async") - self.kwargs["__rpc_params"] = rpc_params - return self.client.call_streaming(self.method_name, *self.args, - **self.kwargs) + # Streaming always needs a response + return self._prepare_and_call(timeout, True, "async", "call_streaming") class RPCClient: @@ -309,7 +313,7 @@ async def _call_async(self, method_name, *args, **kwargs): res = await future else: # Add 1 second to the timeout to ensure the client can get - res = await asyncio.wait_for(future, timeout + 1) + res = await asyncio.wait_for(future, timeout) logger_debug( f"RPC Client _call_async: Got result for request_id: {request_id}: {res}" ) @@ -361,7 +365,7 @@ def _call_sync(self, method_name, *args, **kwargs): f"RPC Client _call_sync: Got result for {method_name}: {result}") return result - def call_async(self, name: str, *args, **kwargs): + def call_async(self, name: str, *args, **kwargs) -> Any: """ Call a remote method asynchronously. @@ -408,7 +412,7 @@ def _async_to_sync(): return self._executor.submit(_async_to_sync) - def call_sync(self, name: str, *args, **kwargs): + def call_sync(self, name: str, *args, **kwargs) -> Any: """ Call a remote method synchronously (blocking). @@ -476,7 +480,7 @@ async def call_streaming(self, name: str, *args, response = await queue.get() else: response = await asyncio.wait_for(queue.get(), - timeout=timeout + 1) + timeout=timeout) logger_debug( f"RPC Client call_streaming received [{response.stream_status}] response: {response}", diff --git a/tensorrt_llm/executor/rpc/rpc_common.py b/tensorrt_llm/executor/rpc/rpc_common.py index 4c81911cf89..5ea809efe71 100644 --- a/tensorrt_llm/executor/rpc/rpc_common.py +++ b/tensorrt_llm/executor/rpc/rpc_common.py @@ -1,3 +1,5 @@ +import time +from dataclasses import dataclass from typing import Any, Literal, NamedTuple, Optional @@ -48,7 +50,8 @@ class RPCStreamingError(RPCError): """Exception for streaming-related errors.""" -class RPCRequest(NamedTuple): +@dataclass +class RPCRequest: request_id: str method_name: str args: tuple @@ -56,6 +59,13 @@ class RPCRequest(NamedTuple): need_response: bool = True timeout: float = 0.5 is_streaming: bool = False + creation_timestamp: Optional[ + float] = None # Unix timestamp when request was created + + def __post_init__(self): + """Initialize creation_timestamp if not provided.""" + if self.creation_timestamp is None: + self.creation_timestamp = time.time() class RPCResponse(NamedTuple): diff --git a/tensorrt_llm/executor/rpc/rpc_server.py b/tensorrt_llm/executor/rpc/rpc_server.py index b816a500774..268bb6012f2 100644 --- a/tensorrt_llm/executor/rpc/rpc_server.py +++ b/tensorrt_llm/executor/rpc/rpc_server.py @@ -292,10 +292,38 @@ async def _worker_routine(self, stop_event: threading.Event): if req.method_name not in ["_rpc_shutdown", "shutdown"]: self._num_pending_requests -= 1 + def _calculate_adjusted_timeout(self, + req: RPCRequest, + is_streaming: bool = False) -> float: + """Calculate adjusted timeout based on pending overhead. + + Args: + req: The RPC request + is_streaming: Whether this is for a streaming request + + Returns: + The adjusted timeout value + """ + adjusted_timeout = req.timeout + if req.creation_timestamp is not None and req.timeout is not None and req.timeout > 0: + pending_time = time.time() - req.creation_timestamp + adjusted_timeout = max(0.1, req.timeout - + pending_time) # Keep at least 0.1s timeout + if pending_time > 0.1: # Only log if significant pending time + method_type = "streaming " if is_streaming else "" + logger_debug( + f"RPC Server adjusted timeout for {method_type}{req.method_name}: " + f"original={req.timeout}s, pending={pending_time:.3f}s, adjusted={adjusted_timeout:.3f}s" + ) + return adjusted_timeout + async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]: """Process a request. Returns None for streaming requests (handled separately).""" func = self._functions[req.method_name] + # Calculate adjusted timeout based on pending overhead + adjusted_timeout = self._calculate_adjusted_timeout(req) + try: if inspect.iscoroutinefunction(func): # Execute async function directly in event loop, no need to run in executor due to the GIL @@ -303,7 +331,7 @@ async def _process_request(self, req: RPCRequest) -> Optional[RPCResponse]: f"RPC Server running async task {req.method_name} in dispatcher" ) result = await asyncio.wait_for(func(*req.args, **req.kwargs), - timeout=req.timeout) + timeout=adjusted_timeout) else: # Execute sync function in thread executor loop = asyncio.get_running_loop() @@ -317,7 +345,7 @@ def call_with_kwargs(): # TODO: let num worker control the pool size result = await asyncio.wait_for(loop.run_in_executor( self._executor, call_with_kwargs), - timeout=req.timeout) + timeout=adjusted_timeout) logger_debug(f"RPC Server returned result for request {req}") response = RPCResponse(req.request_id, result) @@ -354,6 +382,10 @@ async def _process_streaming_request(self, req: RPCRequest): sequence_number = 0 + # Calculate adjusted timeout based on pending overhead + adjusted_timeout = self._calculate_adjusted_timeout(req, + is_streaming=True) + try: logger_debug(f"RPC Server running streaming task {req.method_name}") # Send start signal @@ -362,16 +394,37 @@ async def _process_streaming_request(self, req: RPCRequest): 'start')) sequence_number += 1 - # Stream the results - async for result in func(*req.args, **req.kwargs): - logger_debug( - f"RPC Server got data and ready to send result {result}") - response = RPCResponse(req.request_id, result, None, True, - sequence_number, 'data') - if not await self._send_response(req, response): - # Stop streaming after a pickle error - return - sequence_number += 1 + # Apply timeout to the entire streaming operation if specified + if adjusted_timeout is not None and adjusted_timeout > 0: + # Create a task for the async generator with timeout + async def stream_with_timeout(): + nonlocal sequence_number + async for result in func(*req.args, **req.kwargs): + logger_debug( + f"RPC Server got data and ready to send result {result}" + ) + response = RPCResponse(req.request_id, result, None, + True, sequence_number, 'data') + if not await self._send_response(req, response): + # Stop streaming after a pickle error + return + sequence_number += 1 + + # Use wait_for for timeout handling + await asyncio.wait_for(stream_with_timeout(), + timeout=adjusted_timeout) + else: + # No timeout specified, stream normally + async for result in func(*req.args, **req.kwargs): + logger_debug( + f"RPC Server got data and ready to send result {result}" + ) + response = RPCResponse(req.request_id, result, None, True, + sequence_number, 'data') + if not await self._send_response(req, response): + # Stop streaming after a pickle error + return + sequence_number += 1 # Send end signal await self._client_socket.put_async( From f13a34e4aaf036525a5515c78eee475dcf53e701 Mon Sep 17 00:00:00 2001 From: chunweiy Date: Mon, 29 Sep 2025 11:08:46 +0000 Subject: [PATCH 13/13] add rpc test list Signed-off-by: chunweiy Signed-off-by: Superjomn <328693+Superjomn@users.noreply.github.com> --- tensorrt_llm/executor/ipc.py | 32 ------- tensorrt_llm/executor/rpc/README.md | 85 +++++++++++++++++++ tensorrt_llm/executor/rpc/rpc_client.py | 64 ++------------ tensorrt_llm/executor/rpc_proxy.py | 2 - tensorrt_llm/executor/rpc_worker.py | 2 +- .../integration/test_lists/test-db/l0_a10.yml | 2 + .../test_lists/test-db/l0_a100.yml | 4 + tests/unittest/executor/test_base_worker.py | 3 + tests/unittest/executor/test_rpc.py | 8 +- tests/unittest/executor/test_rpc_proxy.py | 4 +- tests/unittest/executor/test_rpc_worker.py | 23 +++-- .../llmapi/test_llm_multi_gpu_pytorch.py | 41 ++++----- tests/unittest/llmapi/test_llm_pytorch.py | 47 +++++----- 13 files changed, 175 insertions(+), 142 deletions(-) create mode 100644 tensorrt_llm/executor/rpc/README.md diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py index 00e9b4d336b..a2b01ee96fb 100644 --- a/tensorrt_llm/executor/ipc.py +++ b/tensorrt_llm/executor/ipc.py @@ -194,38 +194,6 @@ async def put_async_noblock(self, obj: Any): logger.error(traceback.format_exc()) raise e - async def put_async_with_timeout(self, obj: Any, timeout: float = 5.0): - """ - Send an object with timeout to detect connection failures. - - Args: - obj: The object to send - timeout: Timeout in seconds for the send operation - - Raises: - zmq.Again: If send operation times out (peer may be disconnected) - Exception: Other send errors - """ - self.setup_lazily() - try: - if self.use_hmac_encryption: - data = pickle.dumps(obj) # nosec B301 - signed_data = self._sign_data(data) - # Use asyncio.wait_for to implement timeout instead of zmq.NOBLOCK - await asyncio.wait_for(self.socket.send(signed_data), - timeout=timeout) - else: - await asyncio.wait_for(self.socket.send_pyobj(obj), - timeout=timeout) - except asyncio.TimeoutError: - # Convert timeout to zmq.Again to maintain compatibility with existing error handling - raise zmq.Again( - "Send operation timed out - peer may be disconnected") - except Exception as e: - logger.error(f"Error sending object: {e}") - logger.error(traceback.format_exc()) - raise e - def get(self) -> Any: self.setup_lazily() return self._recv_data() diff --git a/tensorrt_llm/executor/rpc/README.md b/tensorrt_llm/executor/rpc/README.md new file mode 100644 index 00000000000..76d7b846ab3 --- /dev/null +++ b/tensorrt_llm/executor/rpc/README.md @@ -0,0 +1,85 @@ +# A Lightweight RPC +This is a pure-Python lightweight RPC we build to simplify our existing IPC code in the orchestrator part. It provides multiple call modes (sync, async, future, streaming) and supports both IPC and TCP connections. + +## Examples +### Create Server and Client + +```python +from tensorrt_llm.executor.rpc import RPCServer, RPCClient + +# Define your application +class App: + def add(self, a: int, b: int) -> int: + return a + b + + async def async_multiply(self, x: int, y: int) -> int: + return x * y + +# Create and start server +app = App() +with RPCServer(app) as server: + server.bind("ipc:///tmp/my_rpc") # or "tcp://127.0.0.1:5555" + server.start() + + # Create client and make calls + with RPCClient("ipc:///tmp/my_rpc") as client: + result = client.add(5, 3).remote() + print(result) # Output: 8 +``` + +### Different Remote Calls + +#### Synchronous Call +```python +# Blocking call that waits for result +result = client.add(10, 20).remote() +# or with timeout +result = client.add(10, 20).remote(timeout=5.0) +``` + +#### Asynchronous Call +```python +# Async call that returns a coroutine +result = await client.async_multiply(3, 4).remote_async() +``` + +#### Future-based Call +```python +# Returns a concurrent.futures.Future +future = client.add(1, 2).remote_future() +# Get result later +result = future.result() +``` + +#### Fire-and-Forget Call +```python +# Send request without waiting for response +client.submit_task(task_id=123).remote(need_response=False) +``` + +#### Streaming Call +```python +# For async generator methods +async for value in client.stream_data(n=10).remote_streaming(): + print(f"Received: {value}") +``` + +### Error Handling +```python +from tensorrt_llm.executor.rpc import RPCError, RPCTimeout + +try: + result = client.risky_operation().remote(timeout=1.0) +except RPCTimeout: + print("Operation timed out") +except RPCError as e: + print(f"RPC Error: {e}") + print(f"Original cause: {e.cause}") + print(f"Traceback: {e.traceback}") +``` + +### Graceful Shutdown +```python +# Shutdown server from client +client.shutdown_server() +``` diff --git a/tensorrt_llm/executor/rpc/rpc_client.py b/tensorrt_llm/executor/rpc/rpc_client.py index 111b51eca1d..03d43ce1d0b 100644 --- a/tensorrt_llm/executor/rpc/rpc_client.py +++ b/tensorrt_llm/executor/rpc/rpc_client.py @@ -59,13 +59,13 @@ def remote_future(self, need_response: bool = True) -> concurrent.futures.Future: """Remote call that returns a Future object.""" return self._prepare_and_call(timeout, need_response, "future", - "call_future") + "_call_future") def remote_streaming(self, timeout: Optional[float] = None) -> AsyncIterator[Any]: """Remote call for streaming results.""" # Streaming always needs a response - return self._prepare_and_call(timeout, True, "async", "call_streaming") + return self._prepare_and_call(timeout, True, "async", "_call_streaming") class RPCClient: @@ -365,27 +365,8 @@ def _call_sync(self, method_name, *args, **kwargs): f"RPC Client _call_sync: Got result for {method_name}: {result}") return result - def call_async(self, name: str, *args, **kwargs) -> Any: - """ - Call a remote method asynchronously. - - Args: - name: Method name to call - *args: Positional arguments - **kwargs: Keyword arguments - - Returns: - Coroutine that can be awaited - - Example: - result = await client.call_async('remote_method', arg1, arg2, key=value) - """ - if "__rpc_params" not in kwargs: - kwargs["__rpc_params"] = RPCParams(need_response=True) - return self._call_async(name, *args, **kwargs) - - def call_future(self, name: str, *args, - **kwargs) -> concurrent.futures.Future: + def _call_future(self, name: str, *args, + **kwargs) -> concurrent.futures.Future: """ Call a remote method and return a Future. @@ -396,12 +377,6 @@ def call_future(self, name: str, *args, Returns: A Future object that can be used to retrieve the result - - Example: - future = client.call_future('remote_method', arg1, arg2, key=value) - result = future.result() # blocks until complete - # or - future.add_done_callback(lambda f: print(f.result())) """ def _async_to_sync(): @@ -412,25 +387,8 @@ def _async_to_sync(): return self._executor.submit(_async_to_sync) - def call_sync(self, name: str, *args, **kwargs) -> Any: - """ - Call a remote method synchronously (blocking). - - Args: - name: Method name to call - *args: Positional arguments - **kwargs: Keyword arguments - - Returns: - The result of the remote method call - - Example: - result = client.call_sync('remote_method', arg1, arg2, key=value) - """ - return self._call_sync(name, *args, **kwargs) - - async def call_streaming(self, name: str, *args, - **kwargs) -> AsyncIterator[Any]: + async def _call_streaming(self, name: str, *args, + **kwargs) -> AsyncIterator[Any]: """ Call a remote async generator method and get streaming results. @@ -441,10 +399,6 @@ async def call_streaming(self, name: str, *args, Yields: Results from the remote async generator - - Example: - async for result in client.call_streaming('streaming_task'): - print(result) """ if self._server_stopped: raise RPCCancelled("Server is shutting down, request cancelled") @@ -474,7 +428,7 @@ async def call_streaming(self, name: str, *args, # Read streaming responses while True: - logger_debug(f"RPC Client call_streaming waiting for response", + logger_debug(f"RPC Client _call_streaming waiting for response", color="green") if timeout is None: response = await queue.get() @@ -483,14 +437,14 @@ async def call_streaming(self, name: str, *args, timeout=timeout) logger_debug( - f"RPC Client call_streaming received [{response.stream_status}] response: {response}", + f"RPC Client _call_streaming received [{response.stream_status}] response: {response}", color="green") if response.stream_status == 'start': # Start of stream continue elif response.stream_status == 'data': logger_debug( - f"RPC Client call_streaming received data: {response.result}", + f"RPC Client _call_streaming received data: {response.result}", color="green") yield response.result elif response.stream_status == 'end': diff --git a/tensorrt_llm/executor/rpc_proxy.py b/tensorrt_llm/executor/rpc_proxy.py index 83ca5af0665..8e1375a1811 100644 --- a/tensorrt_llm/executor/rpc_proxy.py +++ b/tensorrt_llm/executor/rpc_proxy.py @@ -3,7 +3,6 @@ import json import os import threading -import time from typing import Optional from ..llmapi.llm_args import KvCacheConnectorConfig @@ -71,7 +70,6 @@ def __init__( self.main_loop = None self.launch_workers() - time.sleep(1) # wait for the workers to launch # Invoke model creation on the remote # TBD: Move model creation to the mpi task, or left in RPC? diff --git a/tensorrt_llm/executor/rpc_worker.py b/tensorrt_llm/executor/rpc_worker.py index f41c2c3257f..a9ef9f435d3 100644 --- a/tensorrt_llm/executor/rpc_worker.py +++ b/tensorrt_llm/executor/rpc_worker.py @@ -230,7 +230,7 @@ def main_task( color="yellow") worker.setup_engine() - if mpi_rank() == 0: + else: logger_debug(f"Worker {mpi_rank()} is creating the RPC service", color="yellow") # Step 2: Create the RPC service, it will expose all the APIs of the worker as remote call to the client diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 37aac5a6b87..d45fc865cb4 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -46,6 +46,8 @@ l0_a10: - unittest/llmapi/test_serialization.py - unittest/llmapi/test_utils.py - unittest/llmapi/test_llm_args.py + # executor + - unittest/executor/test_rpc.py - condition: ranges: system_gpu_count: diff --git a/tests/integration/test_lists/test-db/l0_a100.yml b/tests/integration/test_lists/test-db/l0_a100.yml index acc25bf2e48..951231de7bd 100644 --- a/tests/integration/test_lists/test-db/l0_a100.yml +++ b/tests/integration/test_lists/test-db/l0_a100.yml @@ -16,6 +16,10 @@ l0_a100: - unittest/llmapi/test_llm_pytorch.py - unittest/llmapi/test_mpi_session.py # generic tests - unittest/trt/model_api/test_model_quantization.py + # executor + - unittest/executor/test_base_worker.py + - unittest/executor/test_rpc_proxy.py + - unittest/executor/test_rpc_worker.py - condition: ranges: system_gpu_count: diff --git a/tests/unittest/executor/test_base_worker.py b/tests/unittest/executor/test_base_worker.py index c8e4fe2691e..6e661d1e4e5 100644 --- a/tests/unittest/executor/test_base_worker.py +++ b/tests/unittest/executor/test_base_worker.py @@ -12,6 +12,7 @@ # isort: off sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") from utils.llm_data import llm_models_root +from utils.util import skip_single_gpu # isort: on from tensorrt_llm._torch.pyexecutor.config import update_executor_config @@ -156,6 +157,8 @@ def create_worker_session(self): session = MpiPoolSession(n_workers=2) return session + @pytest.mark.gpu2 + @skip_single_gpu def test_create_executor(self): futures = self.session.submit( TestRpcWorkerBaseTP2.create_executor, diff --git a/tests/unittest/executor/test_rpc.py b/tests/unittest/executor/test_rpc.py index f628bc8b2b3..8a04d7534c2 100644 --- a/tests/unittest/executor/test_rpc.py +++ b/tests/unittest/executor/test_rpc.py @@ -238,10 +238,10 @@ def slow_method(self): with pytest.raises(RPCError) as exc_info: client.slow_method().remote(timeout=0.5) - error = exc_info.value - # Should be either a timeout error or RPC error indicating timeout - assert "timed out" in str( - error).lower() or "timeout" in str(error).lower() + error = exc_info.value + # Should be either a timeout error or RPC error indicating timeout + assert "timed out" in str(error).lower() or "timeout" in str( + error).lower() def test_method_not_found_error(self): """Test that calling non-existent methods returns proper error.""" diff --git a/tests/unittest/executor/test_rpc_proxy.py b/tests/unittest/executor/test_rpc_proxy.py index 13d281125df..17d99fd24d7 100644 --- a/tests/unittest/executor/test_rpc_proxy.py +++ b/tests/unittest/executor/test_rpc_proxy.py @@ -14,7 +14,7 @@ # isort: off sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") from utils.llm_data import llm_models_root -from utils.util import similar +from utils.util import similar, skip_single_gpu # isort: on model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" @@ -78,6 +78,8 @@ def test_tp1(self, num_reqs): assert isinstance(kv_cache_events, list) @pytest.mark.parametrize("num_reqs", [1, 10]) + @skip_single_gpu + @pytest.mark.gpu2 def test_tp2(self, num_reqs): tokenizer = TransformersTokenizer.from_pretrained(model_path) prompt = "A B C D" diff --git a/tests/unittest/executor/test_rpc_worker.py b/tests/unittest/executor/test_rpc_worker.py index f6b759f9a07..623c124c92b 100644 --- a/tests/unittest/executor/test_rpc_worker.py +++ b/tests/unittest/executor/test_rpc_worker.py @@ -18,6 +18,7 @@ # isort: off sys.path.append(os.path.dirname(os.path.abspath(__file__)) + "/..") from utils.llm_data import llm_models_root +from utils.util import skip_single_gpu # isort: on model_path = llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" @@ -177,12 +178,20 @@ async def process_request_streaming(): async def test_fetch_stats_loop_async(self): await asyncio.sleep(1) results = [] - async for stats in self.client.fetch_stats_loop_async( - ).remote_streaming(): - results.append(stats) - assert not stats # empty stats + max_batches = 5 - assert len(results) == 0 + async def consume_stats(): + async for stats in self.client.fetch_stats_loop_async( + ).remote_streaming(): + results.append(stats) + assert not stats # empty stats + if len(results) >= max_batches: + break + + await asyncio.wait_for(consume_stats(), timeout=5) + + assert len(results) == max_batches + assert all(not stats for stats in results) class TestRpcWorkerTP2: @@ -215,11 +224,15 @@ def create_worker_session(self): def create_rpc_client(self, addr: str): return RPCClient(addr) + @skip_single_gpu + @pytest.mark.gpu2 def test_create_shutdown(self): # Invoke setup_engine in rank 0, and that will unblock all the ranks to # invoke setup_engine simultaneously. pass + @skip_single_gpu + @pytest.mark.gpu2 def test_fetch_responses_sync(self): # Wait a bit to ensure engine is ready time.sleep(1) diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index df36ef3076b..b5d34dcd735 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -63,30 +63,31 @@ def test_llama_7b_multi_lora_tp2(): @pytest.mark.gpu2 def test_llm_rpc_tp2(): - llm = LLM(model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - orchestrator_type="rpc", - tensor_parallel_size=2) - assert isinstance(llm._executor, GenerationExecutorRpcProxy) + with LLM(model=llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + orchestrator_type="rpc", + tensor_parallel_size=2) as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) - res = llm.generate("Tell me a joke", - sampling_params=SamplingParams(max_tokens=10, end_id=-1)) - print(f"get result: {res}") + res = llm.generate("Tell me a joke", + sampling_params=SamplingParams(max_tokens=10, + end_id=-1)) + print(f"get result: {res}") - assert len(res.outputs) == 1 - assert len(res.outputs[0].token_ids) == 10 + assert len(res.outputs) == 1 + assert len(res.outputs[0].token_ids) == 10 @pytest.mark.gpu2 @pytest.mark.asyncio async def test_llm_rpc_streaming_tp2(): - llm = LLM(model=llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - orchestrator_type="rpc", - tensor_parallel_size=2) - assert isinstance(llm._executor, GenerationExecutorRpcProxy) - - async for output in llm.generate_async("Tell me a joke", - sampling_params=SamplingParams( - max_tokens=10, end_id=-1)): - print(f"get result: {output}") + with LLM(model=llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + orchestrator_type="rpc", + tensor_parallel_size=2) as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + async for output in llm.generate_async("Tell me a joke", + sampling_params=SamplingParams( + max_tokens=10, end_id=-1)): + print(f"get result: {output}") diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index f7175c311d1..0ed1faab2c7 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -944,31 +944,34 @@ def test_max_num_token_check(self): def test_llm_rpc(): - llm = LLM(model=llama_model_path, - kv_cache_config=global_kvcache_config, - orchestrator_type="rpc") - assert isinstance(llm._executor, GenerationExecutorRpcProxy) + # TODO: remove the with-statement when shutdown hang issue is fixed + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) - res = llm.generate("Tell me a joke", - sampling_params=SamplingParams(max_tokens=10, end_id=-1)) - print(f"get result: {res}") + res = llm.generate("Tell me a joke", + sampling_params=SamplingParams(max_tokens=10, + end_id=-1)) + print(f"get result: {res}") - assert len(res.outputs) == 1 - assert len(res.outputs[0].token_ids) == 10 + assert len(res.outputs) == 1 + assert len(res.outputs[0].token_ids) == 10 @pytest.mark.asyncio async def test_llm_rpc_streaming(): - llm = LLM(model=llama_model_path, - kv_cache_config=global_kvcache_config, - orchestrator_type="rpc") - assert isinstance(llm._executor, GenerationExecutorRpcProxy) - - outputs = [] - async for output in llm.generate_async("Tell me a joke", - sampling_params=SamplingParams( - max_tokens=10, end_id=-1), - streaming=True): - outputs.append(output.outputs[0].text) - "".join(outputs) - print(f"get result: {outputs}") + # TODO: remove the with-statement when shutdown hang issue is fixed + with LLM(model=llama_model_path, + kv_cache_config=global_kvcache_config, + orchestrator_type="rpc") as llm: + assert isinstance(llm._executor, GenerationExecutorRpcProxy) + + outputs = [] + async for output in llm.generate_async("Tell me a joke", + sampling_params=SamplingParams( + max_tokens=10, end_id=-1), + streaming=True): + outputs.append(output.outputs[0].text) + "".join(outputs) + print(f"get result: {outputs}")