tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (6 changes: 2 additions & 4 deletions)
@@ -9,7 +9,6 @@
 from tensorrt_llm._utils import nvtx_range

 from ...._utils import mpi_rank, mpi_world_size
-from ....bindings.executor import ExecutorConfig
 from ....bindings.internal.batch_manager import CacheType
 from ....mapping import Mapping
 from ...distributed import MPIDist
@@ -259,7 +258,7 @@ def forward(
         return {"logits": logits_flat}


-def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: str = None):
+def create_autodeploy_executor(ad_config: LlmArgs):
     """Create an AutoDeploy executor from the given configuration and checkpoint directory.

     This is the entrypoint API to the _autodeploy backend.
@@ -276,8 +275,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:

     # some config
     msg = "pytorch_backend_config must be an AD LlmArgs object"
-    assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
-    ad_config: LlmArgs = executor_config.pytorch_backend_config
+    assert isinstance(ad_config, LlmArgs), msg
     assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"

     max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
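With this change the _autodeploy entry point is driven directly by an AD `LlmArgs` object instead of an `ExecutorConfig` that wraps one. A minimal sketch of the new call pattern; the import path for `LlmArgs` and the constructor fields used here (only `max_batch_size` and `max_beam_width` appear in the diff) are assumptions, and the values are illustrative:

```python
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs  # assumed import path
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import create_autodeploy_executor

# Illustrative configuration; the executor now reads everything from ad_config,
# so there is no separate ExecutorConfig or checkpoint_dir argument.
ad_config = LlmArgs(
    model="/path/to/checkpoint",
    max_batch_size=8,
    max_beam_width=1,  # the new assert rejects beam widths > 1
)

executor = create_autodeploy_executor(ad_config)
```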
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (31 changes: 26 additions & 5 deletions)
@@ -10,8 +10,13 @@
 import tensorrt_llm
 from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._utils import get_sm_version
-from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
+from tensorrt_llm.bindings.executor import (ContextChunkingPolicy,
+                                            ExecutorConfig,
+                                            LogitsPostProcessorConfig,
+                                            ParallelConfig)
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
+from tensorrt_llm.llmapi.tokenizer import TokenizerBase
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.mapping import Mapping
@@ -203,10 +208,21 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:


 def create_py_executor(
-        executor_config: ExecutorConfig,
-        checkpoint_dir: str = None,
-        lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
+    llm_args: TorchLlmArgs,
+    checkpoint_dir: str = None,
+    tokenizer: Optional[TokenizerBase] = None,
+    lora_config: Optional[LoraConfig] = None,
+    logits_post_processor_config: Optional[LogitsPostProcessorConfig] = None,
+    parallel_config: Optional[ParallelConfig] = None,
+    kwargs_py_executor: Optional[dict] = None,
+) -> PyExecutor:
+
+    executor_config = llm_args.get_executor_config(checkpoint_dir, tokenizer)
+    executor_config.logits_post_processor_config = logits_post_processor_config
+    executor_config.parallel_config = parallel_config
+
+    garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
+
     _mangle_executor_config(executor_config)
     pytorch_backend_config = executor_config.pytorch_backend_config

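`create_py_executor` now builds the `ExecutorConfig` internally from the `TorchLlmArgs` object via `get_executor_config`, and it reads `garbage_collection_gen0_threshold` from that same object. A rough sketch of the new call pattern; only the parameter names come from the signature above, while the argument values and the `TorchLlmArgs` fields other than `garbage_collection_gen0_threshold` are assumptions:

```python
from tensorrt_llm._torch.pyexecutor.py_executor_creator import create_py_executor
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

# Assumed field names; garbage_collection_gen0_threshold is the one the new code reads.
llm_args = TorchLlmArgs(model="/path/to/hf/model",
                        garbage_collection_gen0_threshold=20000)

py_executor = create_py_executor(
    llm_args=llm_args,
    checkpoint_dir="/path/to/hf/model",
    tokenizer=None,                     # Optional[TokenizerBase]
    lora_config=None,
    logits_post_processor_config=None,  # forwarded onto the derived executor_config
    parallel_config=None,               # forwarded onto the derived executor_config
    kwargs_py_executor=None,
)
```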
@@ -294,6 +310,8 @@
         max_seq_len += spec_config.max_draft_len

     executor_config.max_seq_len = max_seq_len
+    if kwargs_py_executor and "max_seq_len" in kwargs_py_executor:
+        kwargs_py_executor["max_seq_len"] = max_seq_len
     executor_config.max_num_tokens = model_engine.max_num_tokens

     config = model_engine.model.model_config.pretrained_config
@@ -441,6 +459,9 @@
     # create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring
     # the original value before creating the final KV cache.
     executor_config.max_seq_len = max_seq_len
+    if kwargs_py_executor and "max_seq_len" in kwargs_py_executor:
+        kwargs_py_executor["max_seq_len"] = max_seq_len
+
     kv_cache_creator.build_managers(resources)

     for eng in [model_engine, draft_model_engine]:
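The two `max_seq_len` hunks add a write-back channel: when the caller pre-seeds `kwargs_py_executor` with a `max_seq_len` key, `create_py_executor` overwrites it with the value it actually settles on after the draft-token padding and the KV-cache-manager adjustment. A small sketch of that handshake, reusing the illustrative `llm_args` from the previous example; everything except the dict key and the function name is assumed:

```python
# Opt in by pre-seeding the key; the write-back only happens when the key exists.
kwargs_py_executor = {"max_seq_len": None}

py_executor = create_py_executor(llm_args=llm_args,
                                 checkpoint_dir="/path/to/hf/model",
                                 kwargs_py_executor=kwargs_py_executor)

# The dict now holds the executor's final max_seq_len, e.g. extended by
# spec_config.max_draft_len when speculative decoding is enabled.
final_max_seq_len = kwargs_py_executor["max_seq_len"]
```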
tensorrt_llm/executor/executor.py (25 changes: 12 additions & 13 deletions)
@@ -21,9 +21,11 @@
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..disaggregated_params import DisaggregatedParams
+from ..llmapi.llm_args import BaseLlmArgs
 from ..llmapi.llm_utils import KvCacheRetentionConfig
 from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
                                   need_spawn_mpi_workers)
+from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.utils import (AsyncQueue, enable_llm_debug,
                             enable_worker_single_process_for_tp1, print_colored,
                             print_colored_debug)
@@ -354,7 +356,9 @@ def create(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        tokenizer: Optional[TokenizerBase] = None,
+        llm_args: Optional[BaseLlmArgs] = None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy
@@ -381,6 +385,9 @@
             "engine": engine,
             "executor_config": executor_config,
             "batched_logits_processor": batched_logits_processor,
+            "hf_model_dir": hf_model_dir,
+            "tokenizer": tokenizer,
+            "llm_args": llm_args,
         }

         if lora_config:
@@ -398,9 +405,7 @@
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)

         # WAR: For the performance of gathering logits, we use single process worker
         # for TP1 to avoid the large overhead of IPC.
@@ -411,9 +416,7 @@
                 "Using single process worker for TP1, this may hurt streaming generation performance."
             )
             return GenerationExecutorWorker(**worker_kwargs,
-                                            is_llm_executor=is_llm_executor,
-                                            garbage_collection_gen0_threshold=
-                                            garbage_collection_gen0_threshold)
+                                            is_llm_executor=is_llm_executor)

         # For single-gpu case:
         # Partition the workload to multiple process for streaming performance.
@@ -425,9 +428,7 @@
                 model_world_size=model_world_size,
                 mpi_session=None,  # use mpi4py
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)
         else:
             ctx = multiprocessing.get_context("spawn")
             # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -438,9 +439,7 @@
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)

     def wait_first_completed(
             self, futures: List[GenerationResult]
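On the `GenerationExecutor.create` side, the per-call `garbage_collection_gen0_threshold` keyword disappears and the new `hf_model_dir`, `tokenizer`, and `llm_args` arguments flow into `worker_kwargs` instead. A hedged sketch of a call against the new signature; the argument values and the `TorchLlmArgs` constructor field are illustrative, and parameters not shown in these hunks keep their existing meaning:

```python
from pathlib import Path

from tensorrt_llm.executor import GenerationExecutor
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

llm_args = TorchLlmArgs(model="/path/to/hf/model")  # any BaseLlmArgs subclass

executor = GenerationExecutor.create(
    engine=Path("/path/to/engine"),
    hf_model_dir=Path("/path/to/hf/model"),
    tokenizer=None,       # Optional[TokenizerBase]
    llm_args=llm_args,    # Optional[BaseLlmArgs]
    is_llm_executor=True,
)

# The GC threshold is no longer threaded through create(); the proxy now pulls
# it out of llm_args (see the proxy.py change below).
```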
tensorrt_llm/executor/proxy.py (9 changes: 4 additions & 5 deletions)
@@ -45,7 +45,6 @@ def __init__(
         worker_cls: type = GenerationExecutorWorker,
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
     ) -> None:
         postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
         )
@@ -87,14 +86,14 @@

         self.model_world_size = model_world_size

-        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+        self.garbage_collection_gen0_threshold = worker_kwargs[
+            "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
+                "llm_args", None) is not None else None

         worker_kwargs = dict(**worker_kwargs,
                              worker_queues=self._setup_queues(),
                              postproc_worker_config=postproc_worker_config,
-                             is_llm_executor=False,
-                             garbage_collection_gen0_threshold=self.
-                             garbage_collection_gen0_threshold)
+                             is_llm_executor=False)

         if "log_level" not in worker_kwargs:
             worker_kwargs["log_level"] = logger.level
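The proxy no longer receives `garbage_collection_gen0_threshold` as its own keyword; it derives the value from the `llm_args` entry in `worker_kwargs`. For readability, the wrapped conditional expression added above is equivalent to the expanded form below; the helper function name is hypothetical, but the logic mirrors the diff:

```python
from typing import Any, Dict, Optional


def _gen0_threshold_from(worker_kwargs: Dict[str, Any]) -> Optional[int]:
    # Equivalent to the conditional expression in GenerationExecutorProxy.__init__:
    # take the threshold from llm_args when one was supplied, otherwise use None.
    llm_args = worker_kwargs.get("llm_args", None)
    if llm_args is None:
        return None
    return llm_args.garbage_collection_gen0_threshold
```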