From d26100bd363dd12424cadfa35ed71cf5ee7a0555 Mon Sep 17 00:00:00 2001 From: leslie-fang25 Date: Wed, 20 Aug 2025 19:06:18 -0700 Subject: [PATCH 1/8] [None][chore] Part 1: Create PyExecutor from TorchLlmArgs Signed-off-by: leslie-fang25 --- tensorrt_llm/executor/executor.py | 22 +++---- tensorrt_llm/executor/proxy.py | 9 ++- tensorrt_llm/executor/worker.py | 98 ++++++++++++++++++++----------- tensorrt_llm/llmapi/llm.py | 88 ++------------------------- tensorrt_llm/llmapi/llm_args.py | 77 +++++++++++++++++++++++- 5 files changed, 157 insertions(+), 137 deletions(-) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 14c8eeb3894..1b8bea357eb 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -21,6 +21,7 @@ from ..bindings import executor as tllm from ..builder import Engine from ..disaggregated_params import DisaggregatedParams +from ..llmapi.llm_args import TorchLlmArgs from ..llmapi.llm_utils import KvCacheRetentionConfig from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available, need_spawn_mpi_workers) @@ -354,7 +355,8 @@ def create( postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, + hf_model_dir: Optional[Path] = None, + llm_args: Optional[TorchLlmArgs] = None, ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]: # local imports to avoid cyclic importing from .proxy import GenerationExecutorProxy @@ -381,6 +383,8 @@ def create( "engine": engine, "executor_config": executor_config, "batched_logits_processor": batched_logits_processor, + "hf_model_dir": hf_model_dir, + "llm_args": llm_args, } if lora_config: @@ -398,9 +402,7 @@ def create( model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, - is_llm_executor=is_llm_executor, - garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + is_llm_executor=is_llm_executor) # WAR: For the performance of gathering logits, we use single process worker # for TP1 to avoid the large overhead of IPC. @@ -411,9 +413,7 @@ def create( "Using single process worker for TP1, this may hurt streaming generation performance." ) return GenerationExecutorWorker(**worker_kwargs, - is_llm_executor=is_llm_executor, - garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + is_llm_executor=is_llm_executor) # For single-gpu case: # Partition the workload to multiple process for streaming performance. @@ -425,9 +425,7 @@ def create( model_world_size=model_world_size, mpi_session=None, # use mpi4py postproc_worker_config=postproc_worker_config, - is_llm_executor=is_llm_executor, - garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + is_llm_executor=is_llm_executor) else: ctx = multiprocessing.get_context("spawn") # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot. @@ -438,9 +436,7 @@ def create( model_world_size=model_world_size, mpi_session=mpi_session, postproc_worker_config=postproc_worker_config, - is_llm_executor=is_llm_executor, - garbage_collection_gen0_threshold= - garbage_collection_gen0_threshold) + is_llm_executor=is_llm_executor) def wait_first_completed( self, futures: List[GenerationResult] diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 78a0d076200..4026697e072 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -45,7 +45,6 @@ def __init__( worker_cls: type = GenerationExecutorWorker, postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, - garbage_collection_gen0_threshold: Optional[int] = None, ) -> None: postproc_worker_config = postproc_worker_config or PostprocWorkerConfig( ) @@ -87,14 +86,14 @@ def __init__( self.model_world_size = model_world_size - self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold + self.garbage_collection_gen0_threshold = worker_kwargs[ + "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get( + "llm_args", None) is not None else None worker_kwargs = dict(**worker_kwargs, worker_queues=self._setup_queues(), postproc_worker_config=postproc_worker_config, - is_llm_executor=False, - garbage_collection_gen0_threshold=self. - garbage_collection_gen0_threshold) + is_llm_executor=False) if "log_level" not in worker_kwargs: worker_kwargs["log_level"] = logger.level diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 8a1dab6a237..3fe655d438a 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -18,7 +18,7 @@ mpi_comm, mpi_rank, nvtx_range_debug) from ..bindings import executor as tllm from ..builder import ConfigEncoder, Engine, EngineConfig -from ..llmapi.llm_args import PybindMirror +from ..llmapi.llm_args import PybindMirror, TorchLlmArgs from ..llmapi.mpi_session import set_mpi_session_cpp from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, @@ -60,7 +60,8 @@ def __init__( postproc_worker_config: Optional[PostprocWorkerConfig] = None, is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, + hf_model_dir: Optional[Path] = None, + llm_args: Optional[TorchLlmArgs] = None, ) -> None: postproc_config = postproc_worker_config or PostprocWorkerConfig() super().__init__( @@ -81,8 +82,8 @@ def __init__( self._await_response_helper = AwaitResponseHelper( self) # TODO: make it weakref self._executor_config = executor_config - self._is_pytorch_backend = getattr(self._executor_config, "backend", - None) == "pytorch" + self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch" + self.llm_args = llm_args if global_mpi_size() > 1: logger.set_rank(self.global_rank) @@ -90,20 +91,42 @@ def __init__( if isinstance(engine, list): engine = engine[self.rank] - if executor_config is None: - executor_config = tllm.ExecutorConfig(1) + def _create_py_executor(comm_ranks, device_ids): - executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( - processor_batched=batched_logits_processor, replicate=False) + executor_config = llm_args.get_executor_config(hf_model_dir) + # Persist so downstream code (e.g., default max_tokens deduction) has access + self._executor_config = executor_config - def _create_engine(): - device_id = self.global_rank % torch.cuda.device_count() - torch.cuda.set_device(device_id) + executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( + processor_batched=batched_logits_processor, replicate=False) + executor_config.parallel_config = tllm.ParallelConfig( + participant_ids=comm_ranks, device_ids=device_ids) + args = { + "executor_config": executor_config, + "checkpoint_dir": executor_config.hf_model_dir, + } + assert hasattr( + executor_config, "backend" + ), "executor_config should be with backend in _create_py_executor" + if executor_config.backend == "pytorch": + from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ + create_py_executor + create_executor = create_py_executor + args["lora_config"] = lora_config + args[ + "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold + else: + raise ValueError( + f"Unsupported backend config: {executor_config.backend}") + return create_executor(**args) + + def _create_engine(comm_ranks, device_ids): + if executor_config is None: + executor_config = tllm.ExecutorConfig(1) + + executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( + processor_batched=batched_logits_processor, replicate=False) - # Make sure C++ executor would use same devices/ranks as py_executor - global_rank = global_mpi_rank() - comm_ranks = mpi_comm().allgather(global_rank) - device_ids = mpi_comm().allgather(device_id) executor_config.parallel_config = tllm.ParallelConfig( participant_ids=comm_ranks, device_ids=device_ids) @@ -122,14 +145,7 @@ def _create_engine(): "executor_config": executor_config, "checkpoint_dir": executor_config.hf_model_dir, } - if executor_config.backend == "pytorch": - from tensorrt_llm._torch.pyexecutor.py_executor_creator import \ - create_py_executor - create_executor = create_py_executor - args["lora_config"] = lora_config - args[ - "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold - elif executor_config.backend == "_autodeploy": + if executor_config.backend == "_autodeploy": from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ create_autodeploy_executor create_executor = create_autodeploy_executor @@ -138,7 +154,17 @@ def _create_engine(): f"Unsupported backend config: {executor_config.backend}") return create_executor(**args) - self.engine = _create_engine() + device_id = self.global_rank % torch.cuda.device_count() + torch.cuda.set_device(device_id) + + # Make sure C++ executor would use same devices/ranks as py_executor + global_rank = global_mpi_rank() + comm_ranks = mpi_comm().allgather(global_rank) + device_ids = mpi_comm().allgather(device_id) + + self.engine = _create_py_executor( + comm_ranks, device_ids) if llm_args is not None else _create_engine( + comm_ranks, device_ids) self._lora_manager: Optional[LoraManager] = None self._prompt_adapter_manager: Optional[PromptAdapterManager] = None @@ -430,14 +456,16 @@ def _enqueue_request(self, request: GenerationRequest) -> int: context_phase_params = request.disaggregated_params.get_context_phase_params( ) - is_overlap_enabled = self._is_pytorch_backend and not self._executor_config.pytorch_backend_config.disable_overlap_scheduler - if is_overlap_enabled: - is_disaggregated = self.engine.kv_cache_transceiver is not None - if is_disaggregated and ( - request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): - raise ValueError( - "Context only requests are not supported in pytorch backend when overlap is enabled." - ) + if self._is_pytorch_backend: + assert isinstance(self.llm_args, TorchLlmArgs) + if not self.llm_args.disable_overlap_scheduler: + is_disaggregated = self.engine.kv_cache_transceiver is not None + if is_disaggregated and ( + request_type + == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY): + raise ValueError( + "Context only requests are not supported in pytorch backend when overlap is enabled." + ) assert request.id is not None @@ -641,7 +669,8 @@ def worker_main( is_llm_executor: Optional[ bool] = True, # whether it's the main executor instance lora_config: Optional[LoraConfig] = None, - garbage_collection_gen0_threshold: Optional[int] = None, + hf_model_dir: Optional[Path] = None, + llm_args: Optional[TorchLlmArgs] = None, ) -> None: mpi_comm().barrier() print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n", @@ -768,7 +797,8 @@ def notify_proxy_threads_to_quit(): postproc_worker_config=postproc_worker_config, is_llm_executor=is_llm_executor, lora_config=lora_config, - garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) + hf_model_dir=hf_model_dir, + llm_args=llm_args) except Exception as e: logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}") logger.error(traceback.format_exc()) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 43edb6b62cb..0173bbeb0b0 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -37,8 +37,7 @@ from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig, LlmBuildStats, ModelLoader, _ModelRuntimeContext) from .mpi_session import MpiPoolSession, external_mpi_comm_available -from .tokenizer import (TokenizerBase, _llguidance_tokenizer_info, - _xgrammar_tokenizer_info) +from .tokenizer import TokenizerBase, _xgrammar_tokenizer_info # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import from .utils import (append_docstring, exception_handler, get_device_count, print_colored_debug, set_api_status) @@ -967,90 +966,13 @@ def _build_model(self): self.tokenizer) self._tokenizer = self.input_processor.tokenizer - max_batch_size = self.args.max_batch_size - max_num_tokens = self.args.max_num_tokens - max_seq_len = self.args.max_seq_len - - kwargs = {} - if self._on_trt_backend: - kwargs[ - "batching_type"] = self.args.batching_type or tllm.BatchingType.INFLIGHT - - self._executor_config = tllm.ExecutorConfig( - max_beam_width=self.args.max_beam_width, - scheduler_config=PybindMirror.maybe_to_pybind( - self.args.scheduler_config), - max_batch_size=max_batch_size, - max_num_tokens=max_num_tokens, - gather_generation_logits=self.args.gather_generation_logits, - fail_fast_on_attention_window_too_large=getattr( - self.args, 'fail_fast_on_attention_window_too_large', False), - **kwargs) - - if self.args.kv_cache_config is not None: - self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind( - self.args.kv_cache_config) - if os.getenv("FORCE_DETERMINISTIC", "0") == "1": - # Disable KV cache reuse for deterministic mode - self._executor_config.kv_cache_config.enable_block_reuse = False - self._executor_config.kv_cache_config.enable_partial_reuse = False - if self.args.peft_cache_config is not None: - self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind( - self.args.peft_cache_config) - if self.args.decoding_config is not None: - self._executor_config.decoding_config = self.args.decoding_config - if self.args.guided_decoding_backend == 'xgrammar': - self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig( - backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend. - XGRAMMAR, - **_xgrammar_tokenizer_info(self.tokenizer)) - elif self.args.guided_decoding_backend == 'llguidance': - self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig( - backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend. - LLGUIDANCE, - **_llguidance_tokenizer_info(self.tokenizer)) - elif self.args.guided_decoding_backend is not None: - raise ValueError( - f"Unsupported guided decoding backend {self.args.guided_decoding_backend}" - ) - - if self._on_trt_backend: - self._executor_config.normalize_log_probs = self.args.normalize_log_probs - self._executor_config.enable_chunked_context = self.args.enable_chunked_prefill - self._executor_config.max_beam_width = self.args.max_beam_width - if self.args.cache_transceiver_config is not None: - self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind( - self.args.cache_transceiver_config) - from tensorrt_llm._torch.pyexecutor.config import update_executor_config - - spec_config = self.args.speculative_config - max_batch_size = self._executor_config.max_batch_size - - if spec_config is not None and spec_config.decoding_type == "AUTO": - from tensorrt_llm._torch.speculative import suggest_spec_config - spec_config = suggest_spec_config(max_batch_size) - - update_executor_config( - self._executor_config, - backend=self.args.backend, - pytorch_backend_config=self.args.get_pytorch_backend_config() - if self.args.backend in ["pytorch", "_autodeploy"] else None, - mapping=self.args.parallel_config.to_mapping(), - speculative_config=spec_config, - hf_model_dir=self._hf_model_dir, - max_input_len=self.args.max_input_len, - max_seq_len=max_seq_len, - checkpoint_format=None if self.args.backend == "_autodeploy" else - self.args.checkpoint_format, - checkpoint_loader=None if self.args.backend == "_autodeploy" else - self.args.checkpoint_loader) + assert isinstance(self.args, TorchLlmArgs) # TODO: revisit gather_context_logits return_logits = self.args.gather_generation_logits - self._executor = self._executor_cls.create( self._engine_dir, - executor_config=self._executor_config, + executor_config=None, batched_logits_processor=self.args.batched_logits_processor, model_world_size=self.args.parallel_config.world_size, mpi_session=self.mpi_session, @@ -1063,8 +985,8 @@ def _build_model(self): ), is_llm_executor=True, lora_config=self.args.lora_config, - garbage_collection_gen0_threshold=self.args. - garbage_collection_gen0_threshold) + hf_model_dir=self._hf_model_dir, + llm_args=self.args) def _validate_args_for_torch_backend(self, kwargs: dict) -> None: """Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend. diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 6ed4dea76c7..595af14afea 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -44,7 +44,8 @@ KvCacheConfig as _KvCacheConfig, LookaheadDecodingConfig as _LookaheadDecodingConfig, PeftCacheConfig as _PeftCacheConfig, - SchedulerConfig as _SchedulerConfig) # isort: skip + SchedulerConfig as _SchedulerConfig, + GuidedDecodingConfig as _GuidedDecodingConfig) # isort: skip # isort: on # yapf: enable @@ -56,7 +57,8 @@ SpeculativeDecodingMode) from ..sampling_params import BatchedLogitsProcessor from .build_cache import BuildCacheConfig -from .tokenizer import TokenizerBase, tokenizer_factory +from .tokenizer import (TokenizerBase, _llguidance_tokenizer_info, + _xgrammar_tokenizer_info, tokenizer_factory) from .utils import generate_api_docs_as_docstring, get_type_repr # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import @@ -2374,6 +2376,77 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs': raise ValueError("batch_wait_timeout_ms must be greater than 0") return self + def get_executor_config(self, + _hf_model_dir: Optional[Path] = None + ) -> _ExecutorConfig: + executor_config = _ExecutorConfig( + max_beam_width=self.max_beam_width, + scheduler_config=PybindMirror.maybe_to_pybind( + self.scheduler_config), + max_batch_size=self.max_batch_size, + max_num_tokens=self.max_num_tokens, + gather_generation_logits=self.gather_generation_logits, + fail_fast_on_attention_window_too_large=getattr( + self, 'fail_fast_on_attention_window_too_large', False), + ) + + if self.kv_cache_config is not None: + executor_config.kv_cache_config = PybindMirror.maybe_to_pybind( + self.kv_cache_config) + if os.getenv("FORCE_DETERMINISTIC", "0") == "1": + # Disable KV cache reuse for deterministic mode + executor_config.kv_cache_config.enable_block_reuse = False + executor_config.kv_cache_config.enable_partial_reuse = False + if self.peft_cache_config is not None: + executor_config.peft_cache_config = PybindMirror.maybe_to_pybind( + self.peft_cache_config) + if self.decoding_config is not None: + executor_config.decoding_config = self.decoding_config + if self.guided_decoding_backend == 'xgrammar': + executor_config.guided_decoding_config = _GuidedDecodingConfig( + backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR, + **_xgrammar_tokenizer_info(self.tokenizer)) + elif self.guided_decoding_backend == 'llguidance': + executor_config.guided_decoding_config = _GuidedDecodingConfig( + backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE, + **_llguidance_tokenizer_info(self.tokenizer)) + elif self.guided_decoding_backend is not None: + raise ValueError( + f"Unsupported guided decoding backend {self.guided_decoding_backend}" + ) + + executor_config.enable_chunked_context = self.enable_chunked_prefill + executor_config.max_beam_width = self.max_beam_width + if self.cache_transceiver_config is not None: + executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind( + self.cache_transceiver_config) + + from tensorrt_llm._torch.pyexecutor.config import update_executor_config + + spec_config = self.speculative_config + max_batch_size = executor_config.max_batch_size + + if spec_config is not None and spec_config.decoding_type == "AUTO": + from tensorrt_llm._torch.speculative import suggest_spec_config + spec_config = suggest_spec_config(max_batch_size) + + update_executor_config( + executor_config, + backend=self.backend, + pytorch_backend_config=self.get_pytorch_backend_config() + if self.backend in ["pytorch", "_autodeploy"] else None, + mapping=self.parallel_config.to_mapping(), + speculative_config=spec_config, + hf_model_dir=_hf_model_dir, + max_input_len=self.max_input_len, + max_seq_len=self.max_seq_len, + checkpoint_format=None + if self.backend == "_autodeploy" else self.checkpoint_format, + checkpoint_loader=None + if self.backend == "_autodeploy" else self.checkpoint_loader) + + return executor_config + # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig def get_pytorch_backend_config(self) -> "PyTorchConfig": from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig From 537ab71758d16e6df2d1403297e459caca39c2a4 Mon Sep 17 00:00:00 2001 From: leslie-fang25 Date: Wed, 20 Aug 2025 22:54:21 -0700 Subject: [PATCH 2/8] fix ci failure Signed-off-by: leslie-fang25 --- tensorrt_llm/executor/worker.py | 32 +++++++++++----------- tensorrt_llm/llmapi/llm.py | 2 ++ tensorrt_llm/llmapi/llm_args.py | 16 ++++++++++- tensorrt_llm/llmapi/mm_encoder.py | 44 ++++++------------------------- 4 files changed, 41 insertions(+), 53 deletions(-) diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 3fe655d438a..7a91aca32d2 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -91,14 +91,23 @@ def __init__( if isinstance(engine, list): engine = engine[self.rank] - def _create_py_executor(comm_ranks, device_ids): - + def _get_comm_ranks_device_id(): + device_id = self.global_rank % torch.cuda.device_count() + torch.cuda.set_device(device_id) + # Make sure C++ executor would use same devices/ranks as py_executor + global_rank = global_mpi_rank() + comm_ranks = mpi_comm().allgather(global_rank) + device_ids = mpi_comm().allgather(device_id) + return comm_ranks, device_ids + + def _create_py_executor(executor_config): + assert executor_config is None, "expect an empty executor_config is _create_py_executor" executor_config = llm_args.get_executor_config(hf_model_dir) # Persist so downstream code (e.g., default max_tokens deduction) has access self._executor_config = executor_config - executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( processor_batched=batched_logits_processor, replicate=False) + comm_ranks, device_ids = _get_comm_ranks_device_id() executor_config.parallel_config = tllm.ParallelConfig( participant_ids=comm_ranks, device_ids=device_ids) args = { @@ -120,13 +129,12 @@ def _create_py_executor(comm_ranks, device_ids): f"Unsupported backend config: {executor_config.backend}") return create_executor(**args) - def _create_engine(comm_ranks, device_ids): + def _create_engine(executor_config): if executor_config is None: executor_config = tllm.ExecutorConfig(1) - executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( processor_batched=batched_logits_processor, replicate=False) - + comm_ranks, device_ids = _get_comm_ranks_device_id() executor_config.parallel_config = tllm.ParallelConfig( participant_ids=comm_ranks, device_ids=device_ids) @@ -154,17 +162,9 @@ def _create_engine(comm_ranks, device_ids): f"Unsupported backend config: {executor_config.backend}") return create_executor(**args) - device_id = self.global_rank % torch.cuda.device_count() - torch.cuda.set_device(device_id) - - # Make sure C++ executor would use same devices/ranks as py_executor - global_rank = global_mpi_rank() - comm_ranks = mpi_comm().allgather(global_rank) - device_ids = mpi_comm().allgather(device_id) - self.engine = _create_py_executor( - comm_ranks, device_ids) if llm_args is not None else _create_engine( - comm_ranks, device_ids) + executor_config) if llm_args is not None else _create_engine( + executor_config) self._lora_manager: Optional[LoraManager] = None self._prompt_adapter_manager: Optional[PromptAdapterManager] = None diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 0173bbeb0b0..2aa047b40f5 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -967,6 +967,8 @@ def _build_model(self): self._tokenizer = self.input_processor.tokenizer assert isinstance(self.args, TorchLlmArgs) + # Update the tokenizer in TorchLlmArgs, so it can be used in GenerationExecutorWorker to init executor_config + self.args.set_tokenizer(self.tokenizer) # TODO: revisit gather_context_logits return_logits = self.args.gather_generation_logits diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 595af14afea..bc02400c0d9 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1837,6 +1837,9 @@ def _load_config_from_ckpt(self, ckpt_dir: Path): moe_tp_size=moe_tp_size, moe_ep_size=moe_ep_size) + def set_tokenizer(self, tokenizer): + self.tokenizer = tokenizer + class TrtLlmArgs(BaseLlmArgs): @@ -2185,6 +2188,13 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) + mm_encoder_only: Optional[bool] = Field( + default=False, + description= + "Only load/execute the vision encoder part of the full model.", + status="prototype", + ) + # PrivateVars _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) @@ -2376,6 +2386,9 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs': raise ValueError("batch_wait_timeout_ms must be greater than 0") return self + def set_mm_encoder_only(self, mm_encoder_only): + self.mm_encoder_only = mm_encoder_only + def get_executor_config(self, _hf_model_dir: Optional[Path] = None ) -> _ExecutorConfig: @@ -2443,7 +2456,8 @@ def get_executor_config(self, checkpoint_format=None if self.backend == "_autodeploy" else self.checkpoint_format, checkpoint_loader=None - if self.backend == "_autodeploy" else self.checkpoint_loader) + if self.backend == "_autodeploy" else self.checkpoint_loader, + mm_encoder_only=self.mm_encoder_only) return executor_config diff --git a/tensorrt_llm/llmapi/mm_encoder.py b/tensorrt_llm/llmapi/mm_encoder.py index 541068a9a6d..b3333924967 100644 --- a/tensorrt_llm/llmapi/mm_encoder.py +++ b/tensorrt_llm/llmapi/mm_encoder.py @@ -4,13 +4,12 @@ from tqdm import tqdm from tensorrt_llm._utils import nvtx_range_debug -from tensorrt_llm.bindings import executor as tllm from tensorrt_llm.inputs import create_input_processor, prompt_inputs from tensorrt_llm.inputs.data import PromptInputs from tensorrt_llm.sampling_params import SamplingParams from .llm import BaseLLM, RequestOutput, _TorchLLM -from .llm_args import PybindMirror +from .llm_args import TorchLlmArgs from .mpi_session import external_mpi_comm_available @@ -56,48 +55,21 @@ def _build_model(self): self.tokenizer) self._tokenizer = self.input_processor.tokenizer - max_batch_size = self.args.max_batch_size - max_num_tokens = self.args.max_num_tokens - max_seq_len = self.args.max_seq_len - - kwargs = {} - if self._on_trt_backend: - kwargs[ - "batching_type"] = self.args.batching_type or tllm.BatchingType.INFLIGHT - - self._executor_config = tllm.ExecutorConfig( - scheduler_config=PybindMirror.maybe_to_pybind( - self.args.scheduler_config), - max_batch_size=max_batch_size, - max_num_tokens=max_num_tokens, - **kwargs) - from tensorrt_llm._torch.pyexecutor.config import update_executor_config - max_batch_size = self._executor_config.max_batch_size - update_executor_config( - self._executor_config, - backend=self.args.backend, - pytorch_backend_config=self.args.get_pytorch_backend_config() - if self.args.backend in ["pytorch", "_autodeploy"] else None, - mapping=self.args.parallel_config.to_mapping(), - hf_model_dir=self._hf_model_dir, - max_input_len=self.args.max_input_len, - max_seq_len=max_seq_len, - checkpoint_format=None if self.args.backend == "_autodeploy" else - self.args.checkpoint_format, - checkpoint_loader=None if self.args.backend == "_autodeploy" else - self.args.checkpoint_loader, - mm_encoder_only=True) + assert isinstance(self.args, TorchLlmArgs) + # Update the tokenizer in TorchLlmArgs, so it can be used in GenerationExecutorWorker to init executor_config + self.args.set_tokenizer(self.tokenizer) + self.args.set_mm_encoder_only(True) self._executor = self._executor_cls.create( self._engine_dir, - executor_config=self._executor_config, + executor_config=None, model_world_size=self.args.parallel_config.world_size, mpi_session=self.mpi_session, reuse_mpi_comm=external_mpi_comm_available( self.args.parallel_config.world_size), is_llm_executor=True, # TODO: check if this is correct or needed - garbage_collection_gen0_threshold=self.args. - garbage_collection_gen0_threshold) + hf_model_dir=self._hf_model_dir, + llm_args=self.args) def _validate_mm_args_for_torch_backend(self, kwargs: dict) -> None: """Validate that users don't pass LLM-specific arguments when using MultimodalEncoder (PyTorch). From 41ac17a8388cd5a567f7fdebf2a30a8796a4b314 Mon Sep 17 00:00:00 2001 From: leslie-fang25 Date: Thu, 21 Aug 2025 06:06:59 -0700 Subject: [PATCH 3/8] fix autodeploy ci failure Signed-off-by: leslie-fang25 --- tensorrt_llm/executor/worker.py | 22 ++--- tensorrt_llm/llmapi/llm.py | 1 - tensorrt_llm/llmapi/llm_args.py | 140 +++++++++++++++++--------------- 3 files changed, 80 insertions(+), 83 deletions(-) diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 7a91aca32d2..60df2c0e1ae 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -124,6 +124,10 @@ def _create_py_executor(executor_config): args["lora_config"] = lora_config args[ "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold + elif executor_config.backend == "_autodeploy": + from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ + create_autodeploy_executor + create_executor = create_autodeploy_executor else: raise ValueError( f"Unsupported backend config: {executor_config.backend}") @@ -146,21 +150,9 @@ def _create_engine(executor_config): executor_config=executor_config, managed_weights=engine.managed_weights) - if not hasattr(executor_config, "backend"): - return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, - executor_config) - args = { - "executor_config": executor_config, - "checkpoint_dir": executor_config.hf_model_dir, - } - if executor_config.backend == "_autodeploy": - from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \ - create_autodeploy_executor - create_executor = create_autodeploy_executor - else: - raise ValueError( - f"Unsupported backend config: {executor_config.backend}") - return create_executor(**args) + assert not hasattr(executor_config, "backend") + return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY, + executor_config) self.engine = _create_py_executor( executor_config) if llm_args is not None else _create_engine( diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 2aa047b40f5..17ab42e5428 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -966,7 +966,6 @@ def _build_model(self): self.tokenizer) self._tokenizer = self.input_processor.tokenizer - assert isinstance(self.args, TorchLlmArgs) # Update the tokenizer in TorchLlmArgs, so it can be used in GenerationExecutorWorker to init executor_config self.args.set_tokenizer(self.tokenizer) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index bc02400c0d9..8e446389e6e 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1840,6 +1840,77 @@ def _load_config_from_ckpt(self, ckpt_dir: Path): def set_tokenizer(self, tokenizer): self.tokenizer = tokenizer + def get_executor_config(self, + _hf_model_dir: Optional[Path] = None + ) -> _ExecutorConfig: + executor_config = _ExecutorConfig( + max_beam_width=self.max_beam_width, + scheduler_config=PybindMirror.maybe_to_pybind( + self.scheduler_config), + max_batch_size=self.max_batch_size, + max_num_tokens=self.max_num_tokens, + gather_generation_logits=self.gather_generation_logits, + fail_fast_on_attention_window_too_large=getattr( + self, 'fail_fast_on_attention_window_too_large', False), + ) + + if self.kv_cache_config is not None: + executor_config.kv_cache_config = PybindMirror.maybe_to_pybind( + self.kv_cache_config) + if os.getenv("FORCE_DETERMINISTIC", "0") == "1": + # Disable KV cache reuse for deterministic mode + executor_config.kv_cache_config.enable_block_reuse = False + executor_config.kv_cache_config.enable_partial_reuse = False + if self.peft_cache_config is not None: + executor_config.peft_cache_config = PybindMirror.maybe_to_pybind( + self.peft_cache_config) + if self.decoding_config is not None: + executor_config.decoding_config = self.decoding_config + if self.guided_decoding_backend == 'xgrammar': + executor_config.guided_decoding_config = _GuidedDecodingConfig( + backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR, + **_xgrammar_tokenizer_info(self.tokenizer)) + elif self.guided_decoding_backend == 'llguidance': + executor_config.guided_decoding_config = _GuidedDecodingConfig( + backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE, + **_llguidance_tokenizer_info(self.tokenizer)) + elif self.guided_decoding_backend is not None: + raise ValueError( + f"Unsupported guided decoding backend {self.guided_decoding_backend}" + ) + + executor_config.enable_chunked_context = self.enable_chunked_prefill + executor_config.max_beam_width = self.max_beam_width + if self.cache_transceiver_config is not None: + executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind( + self.cache_transceiver_config) + + from tensorrt_llm._torch.pyexecutor.config import update_executor_config + + spec_config = self.speculative_config + max_batch_size = executor_config.max_batch_size + + if spec_config is not None and spec_config.decoding_type == "AUTO": + from tensorrt_llm._torch.speculative import suggest_spec_config + spec_config = suggest_spec_config(max_batch_size) + + update_executor_config( + executor_config, + backend=self.backend, + pytorch_backend_config=self.get_pytorch_backend_config() + if self.backend in ["pytorch", "_autodeploy"] else None, + mapping=self.parallel_config.to_mapping(), + speculative_config=spec_config, + hf_model_dir=_hf_model_dir, + max_input_len=self.max_input_len, + max_seq_len=self.max_seq_len, + checkpoint_format=None + if self.backend == "_autodeploy" else self.checkpoint_format, + checkpoint_loader=None + if self.backend == "_autodeploy" else self.checkpoint_loader) + + return executor_config + class TrtLlmArgs(BaseLlmArgs): @@ -2392,73 +2463,8 @@ def set_mm_encoder_only(self, mm_encoder_only): def get_executor_config(self, _hf_model_dir: Optional[Path] = None ) -> _ExecutorConfig: - executor_config = _ExecutorConfig( - max_beam_width=self.max_beam_width, - scheduler_config=PybindMirror.maybe_to_pybind( - self.scheduler_config), - max_batch_size=self.max_batch_size, - max_num_tokens=self.max_num_tokens, - gather_generation_logits=self.gather_generation_logits, - fail_fast_on_attention_window_too_large=getattr( - self, 'fail_fast_on_attention_window_too_large', False), - ) - - if self.kv_cache_config is not None: - executor_config.kv_cache_config = PybindMirror.maybe_to_pybind( - self.kv_cache_config) - if os.getenv("FORCE_DETERMINISTIC", "0") == "1": - # Disable KV cache reuse for deterministic mode - executor_config.kv_cache_config.enable_block_reuse = False - executor_config.kv_cache_config.enable_partial_reuse = False - if self.peft_cache_config is not None: - executor_config.peft_cache_config = PybindMirror.maybe_to_pybind( - self.peft_cache_config) - if self.decoding_config is not None: - executor_config.decoding_config = self.decoding_config - if self.guided_decoding_backend == 'xgrammar': - executor_config.guided_decoding_config = _GuidedDecodingConfig( - backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR, - **_xgrammar_tokenizer_info(self.tokenizer)) - elif self.guided_decoding_backend == 'llguidance': - executor_config.guided_decoding_config = _GuidedDecodingConfig( - backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE, - **_llguidance_tokenizer_info(self.tokenizer)) - elif self.guided_decoding_backend is not None: - raise ValueError( - f"Unsupported guided decoding backend {self.guided_decoding_backend}" - ) - - executor_config.enable_chunked_context = self.enable_chunked_prefill - executor_config.max_beam_width = self.max_beam_width - if self.cache_transceiver_config is not None: - executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind( - self.cache_transceiver_config) - - from tensorrt_llm._torch.pyexecutor.config import update_executor_config - - spec_config = self.speculative_config - max_batch_size = executor_config.max_batch_size - - if spec_config is not None and spec_config.decoding_type == "AUTO": - from tensorrt_llm._torch.speculative import suggest_spec_config - spec_config = suggest_spec_config(max_batch_size) - - update_executor_config( - executor_config, - backend=self.backend, - pytorch_backend_config=self.get_pytorch_backend_config() - if self.backend in ["pytorch", "_autodeploy"] else None, - mapping=self.parallel_config.to_mapping(), - speculative_config=spec_config, - hf_model_dir=_hf_model_dir, - max_input_len=self.max_input_len, - max_seq_len=self.max_seq_len, - checkpoint_format=None - if self.backend == "_autodeploy" else self.checkpoint_format, - checkpoint_loader=None - if self.backend == "_autodeploy" else self.checkpoint_loader, - mm_encoder_only=self.mm_encoder_only) - + executor_config = super().get_executor_config(_hf_model_dir) + executor_config.mm_encoder_only = self.mm_encoder_only return executor_config # TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig From aaacfad55628c267ebe7a23a3ec269a1a3ccf384 Mon Sep 17 00:00:00 2001 From: leslie-fang25 Date: Thu, 21 Aug 2025 18:42:23 -0700 Subject: [PATCH 4/8] fix the lora_config Signed-off-by: leslie-fang25 --- tensorrt_llm/executor/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 60df2c0e1ae..58dcb9a9956 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -179,7 +179,7 @@ def _create_engine(executor_config): if engine_config.build_config.max_prompt_embedding_table_size > 0: self._prompt_adapter_manager = PromptAdapterManager() - if getattr(executor_config, "backend", + if getattr(self._executor_config, "backend", "") == "pytorch" and lora_config is not None: from tensorrt_llm._torch.pyexecutor.resource_manager import \ ResourceManagerType From bb27058ffdec9ba27060fd786333e66e75b60109 Mon Sep 17 00:00:00 2001 From: leslie-fang25 Date: Thu, 21 Aug 2025 20:38:11 -0700 Subject: [PATCH 5/8] fix test_llm_args test_runtime_sizes Signed-off-by: leslie-fang25 --- tests/unittest/llmapi/test_llm_args.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 64282a00229..807ad03d676 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -438,11 +438,6 @@ def test_runtime_sizes(self): assert llm.args.max_seq_len == 128 assert llm.args.max_batch_size == 8 - assert llm._executor_config.max_beam_width == 1 - assert llm._executor_config.max_num_tokens == 256 - assert llm._executor_config.max_seq_len == 128 - assert llm._executor_config.max_batch_size == 8 - def test_dynamic_setattr(self): with pytest.raises(pydantic_core._pydantic_core.ValidationError): args = TorchLlmArgs(model=llama_model_path, invalid_arg=1) From 2bc8794b02961c18d18d6cadfe43a9fbceade2ce Mon Sep 17 00:00:00 2001 From: leslie-fang25 Date: Thu, 21 Aug 2025 23:16:54 -0700 Subject: [PATCH 6/8] fix api stability Signed-off-by: leslie-fang25 --- tensorrt_llm/llmapi/llm_args.py | 4 ++-- tests/unittest/api_stability/references/llm.yaml | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 8e446389e6e..58abafef362 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2259,10 +2259,10 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) - mm_encoder_only: Optional[bool] = Field( + mm_encoder_only: bool = Field( default=False, description= - "Only load/execute the vision encoder part of the full model.", + "Only load/execute the vision encoder part of the full model. Defaults to False.", status="prototype", ) diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index d9dcd0f83d2..91511263db1 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -95,6 +95,10 @@ methods: annotation: Optional[str] default: null status: prototype + mm_encoder_only: + annotation: bool + default: False + status: prototype disable_overlap_scheduler: annotation: bool default: False From 82e9ca8c8402bad242f6b136f690b628f59bd412 Mon Sep 17 00:00:00 2001 From: leslie-fang25 Date: Sat, 23 Aug 2025 20:10:14 -0700 Subject: [PATCH 7/8] fix autodeploy tokenizer Signed-off-by: leslie-fang25 --- tensorrt_llm/executor/executor.py | 3 +++ tensorrt_llm/executor/worker.py | 7 ++++++- tensorrt_llm/llmapi/llm.py | 4 +--- tensorrt_llm/llmapi/llm_args.py | 27 +++++++++++++++------------ tensorrt_llm/llmapi/mm_encoder.py | 3 +-- 5 files changed, 26 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 1b8bea357eb..d85f94c3426 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -25,6 +25,7 @@ from ..llmapi.llm_utils import KvCacheRetentionConfig from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available, need_spawn_mpi_workers) +from ..llmapi.tokenizer import TokenizerBase from ..llmapi.utils import (AsyncQueue, enable_llm_debug, enable_worker_single_process_for_tp1, print_colored, print_colored_debug) @@ -356,6 +357,7 @@ def create( is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[TorchLlmArgs] = None, ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]: # local imports to avoid cyclic importing @@ -384,6 +386,7 @@ def create( "executor_config": executor_config, "batched_logits_processor": batched_logits_processor, "hf_model_dir": hf_model_dir, + "tokenizer": tokenizer, "llm_args": llm_args, } diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 58dcb9a9956..493d055370d 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -20,6 +20,7 @@ from ..builder import ConfigEncoder, Engine, EngineConfig from ..llmapi.llm_args import PybindMirror, TorchLlmArgs from ..llmapi.mpi_session import set_mpi_session_cpp +from ..llmapi.tokenizer import TokenizerBase from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, clear_sched_affinity, print_colored_debug, @@ -61,6 +62,7 @@ def __init__( is_llm_executor: Optional[bool] = None, lora_config: Optional[LoraConfig] = None, hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[TorchLlmArgs] = None, ) -> None: postproc_config = postproc_worker_config or PostprocWorkerConfig() @@ -102,7 +104,8 @@ def _get_comm_ranks_device_id(): def _create_py_executor(executor_config): assert executor_config is None, "expect an empty executor_config is _create_py_executor" - executor_config = llm_args.get_executor_config(hf_model_dir) + executor_config = llm_args.get_executor_config( + hf_model_dir, tokenizer) # Persist so downstream code (e.g., default max_tokens deduction) has access self._executor_config = executor_config executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig( @@ -662,6 +665,7 @@ def worker_main( bool] = True, # whether it's the main executor instance lora_config: Optional[LoraConfig] = None, hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, llm_args: Optional[TorchLlmArgs] = None, ) -> None: mpi_comm().barrier() @@ -790,6 +794,7 @@ def notify_proxy_threads_to_quit(): is_llm_executor=is_llm_executor, lora_config=lora_config, hf_model_dir=hf_model_dir, + tokenizer=tokenizer, llm_args=llm_args) except Exception as e: logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}") diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 17ab42e5428..b95e41f57a2 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -966,9 +966,6 @@ def _build_model(self): self.tokenizer) self._tokenizer = self.input_processor.tokenizer - # Update the tokenizer in TorchLlmArgs, so it can be used in GenerationExecutorWorker to init executor_config - self.args.set_tokenizer(self.tokenizer) - # TODO: revisit gather_context_logits return_logits = self.args.gather_generation_logits self._executor = self._executor_cls.create( @@ -987,6 +984,7 @@ def _build_model(self): is_llm_executor=True, lora_config=self.args.lora_config, hf_model_dir=self._hf_model_dir, + tokenizer=self.tokenizer, llm_args=self.args) def _validate_args_for_torch_backend(self, kwargs: dict) -> None: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 58abafef362..3ca3ede2427 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -1837,12 +1837,11 @@ def _load_config_from_ckpt(self, ckpt_dir: Path): moe_tp_size=moe_tp_size, moe_ep_size=moe_ep_size) - def set_tokenizer(self, tokenizer): - self.tokenizer = tokenizer - - def get_executor_config(self, - _hf_model_dir: Optional[Path] = None - ) -> _ExecutorConfig: + def get_executor_config( + self, + _hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + ) -> _ExecutorConfig: executor_config = _ExecutorConfig( max_beam_width=self.max_beam_width, scheduler_config=PybindMirror.maybe_to_pybind( @@ -1867,13 +1866,15 @@ def get_executor_config(self, if self.decoding_config is not None: executor_config.decoding_config = self.decoding_config if self.guided_decoding_backend == 'xgrammar': + assert tokenizer is not None executor_config.guided_decoding_config = _GuidedDecodingConfig( backend=_GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR, - **_xgrammar_tokenizer_info(self.tokenizer)) + **_xgrammar_tokenizer_info(tokenizer)) elif self.guided_decoding_backend == 'llguidance': + assert tokenizer is not None executor_config.guided_decoding_config = _GuidedDecodingConfig( backend=_GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE, - **_llguidance_tokenizer_info(self.tokenizer)) + **_llguidance_tokenizer_info(tokenizer)) elif self.guided_decoding_backend is not None: raise ValueError( f"Unsupported guided decoding backend {self.guided_decoding_backend}" @@ -2460,10 +2461,12 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs': def set_mm_encoder_only(self, mm_encoder_only): self.mm_encoder_only = mm_encoder_only - def get_executor_config(self, - _hf_model_dir: Optional[Path] = None - ) -> _ExecutorConfig: - executor_config = super().get_executor_config(_hf_model_dir) + def get_executor_config( + self, + _hf_model_dir: Optional[Path] = None, + tokenizer: Optional[TokenizerBase] = None, + ) -> _ExecutorConfig: + executor_config = super().get_executor_config(_hf_model_dir, tokenizer) executor_config.mm_encoder_only = self.mm_encoder_only return executor_config diff --git a/tensorrt_llm/llmapi/mm_encoder.py b/tensorrt_llm/llmapi/mm_encoder.py index b3333924967..b30dca16ca6 100644 --- a/tensorrt_llm/llmapi/mm_encoder.py +++ b/tensorrt_llm/llmapi/mm_encoder.py @@ -56,8 +56,6 @@ def _build_model(self): self._tokenizer = self.input_processor.tokenizer assert isinstance(self.args, TorchLlmArgs) - # Update the tokenizer in TorchLlmArgs, so it can be used in GenerationExecutorWorker to init executor_config - self.args.set_tokenizer(self.tokenizer) self.args.set_mm_encoder_only(True) self._executor = self._executor_cls.create( @@ -69,6 +67,7 @@ def _build_model(self): self.args.parallel_config.world_size), is_llm_executor=True, # TODO: check if this is correct or needed hf_model_dir=self._hf_model_dir, + tokenizer=self.tokenizer, llm_args=self.args) def _validate_mm_args_for_torch_backend(self, kwargs: dict) -> None: From e9d9d9921dfc19664ade739e71da3c63a5d32f9f Mon Sep 17 00:00:00 2001 From: leslie-fang25 Date: Sun, 24 Aug 2025 18:19:37 -0700 Subject: [PATCH 8/8] address comment Signed-off-by: leslie-fang25 --- tensorrt_llm/llmapi/llm_args.py | 3 --- tensorrt_llm/llmapi/mm_encoder.py | 2 +- tests/unittest/llmapi/test_llm_args.py | 7 +++++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 3ca3ede2427..ff9c8c40af1 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2458,9 +2458,6 @@ def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs': raise ValueError("batch_wait_timeout_ms must be greater than 0") return self - def set_mm_encoder_only(self, mm_encoder_only): - self.mm_encoder_only = mm_encoder_only - def get_executor_config( self, _hf_model_dir: Optional[Path] = None, diff --git a/tensorrt_llm/llmapi/mm_encoder.py b/tensorrt_llm/llmapi/mm_encoder.py index b30dca16ca6..af0f031fc02 100644 --- a/tensorrt_llm/llmapi/mm_encoder.py +++ b/tensorrt_llm/llmapi/mm_encoder.py @@ -56,7 +56,7 @@ def _build_model(self): self._tokenizer = self.input_processor.tokenizer assert isinstance(self.args, TorchLlmArgs) - self.args.set_mm_encoder_only(True) + self.args.mm_encoder_only = True self._executor = self._executor_cls.create( self._engine_dir, diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 807ad03d676..a01d7f591f3 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -438,6 +438,13 @@ def test_runtime_sizes(self): assert llm.args.max_seq_len == 128 assert llm.args.max_batch_size == 8 + executor_config = llm.args.get_executor_config( + llm._hf_model_dir, llm.tokenizer) + assert executor_config.max_beam_width == 1 + assert executor_config.max_num_tokens == 256 + assert executor_config.max_seq_len == 128 + assert executor_config.max_batch_size == 8 + def test_dynamic_setattr(self): with pytest.raises(pydantic_core._pydantic_core.ValidationError): args = TorchLlmArgs(model=llama_model_path, invalid_arg=1)