6 changes: 2 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -9,7 +9,6 @@
from tensorrt_llm._utils import nvtx_range

from ...._utils import mpi_rank, mpi_world_size
from ....bindings.executor import ExecutorConfig
from ....bindings.internal.batch_manager import CacheType
from ....mapping import Mapping
from ...distributed import MPIDist
@@ -259,7 +258,7 @@ def forward(
return {"logits": logits_flat}


def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: str = None):
def create_autodeploy_executor(ad_config: LlmArgs):
"""Create an AutoDeploy executor from the given configuration and checkpoint directory.

This is the entrypoint API to the _autodeploy backend.
@@ -276,8 +275,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:

# some config
msg = "pytorch_backend_config must be an AD LlmArgs object"
assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
ad_config: LlmArgs = executor_config.pytorch_backend_config
assert isinstance(ad_config, LlmArgs), msg
assert ad_config.max_beam_width <= 1, "_autodeploy + beam_search is not supported"

max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
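
With this change the AutoDeploy entrypoint takes the AD LlmArgs object directly instead of an ExecutorConfig wrapper. A minimal sketch of the new calling convention, mirroring the worker.py change later in this diff; the checkpoint path and the model= constructor field are illustrative assumptions, not part of the PR:

# Illustrative sketch, not part of the diff: the AutoDeploy LlmArgs is passed in
# directly, as the worker.py change below does. Constructor arguments are placeholders.
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs as ADLlmArgs
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import create_autodeploy_executor

llm_args = ADLlmArgs(model="/path/to/hf_checkpoint")  # "model" field assumed
ad_config = llm_args.get_pytorch_backend_config()     # as done in worker.py below
assert isinstance(ad_config, ADLlmArgs)
executor = create_autodeploy_executor(ad_config=ad_config)
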
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -510,7 +510,8 @@ def create_py_executor_instance(
guided_decoder: Optional[GuidedDecoder] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
kv_connector_manager: Optional[KvCacheConnectorManager] = None
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None,
) -> PyExecutor:
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

@@ -659,7 +660,8 @@ def create_py_executor_instance(
guided_decoder=guided_decoder,
start_worker=start_worker,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager)
kv_connector_manager=kv_connector_manager,
max_seq_len=max_seq_len)


def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
39 changes: 20 additions & 19 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -139,25 +139,25 @@ class BatchStatePP(BatchState):

class PyExecutor:

def __init__(
self,
resource_manager,
scheduler: RequestScheduler,
model_engine: ModelEngine,
sampler: Sampler,
dist: Distributed,
max_num_sequences: int,
drafter: Optional[Drafter] = None,
disable_overlap_scheduler: bool = False,
max_input_len: int = 2048,
max_batch_size: int = 8,
max_beam_width: int = 1,
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True,
kv_connector_manager: Optional[KvCacheConnectorManager] = None):
def __init__(self,
resource_manager,
scheduler: RequestScheduler,
model_engine: ModelEngine,
sampler: Sampler,
dist: Distributed,
max_num_sequences: int,
drafter: Optional[Drafter] = None,
disable_overlap_scheduler: bool = False,
max_input_len: int = 2048,
max_batch_size: int = 8,
max_beam_width: int = 1,
max_draft_len: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
start_worker: bool = True,
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
max_seq_len: Optional[int] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = global_mpi_rank()
@@ -271,6 +271,7 @@ def __init__(
)
self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences)
self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
self.max_seq_len = max_seq_len

self.worker_started = False
self.worker_lock = threading.Lock()
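
The new max_seq_len argument is stored on the executor (and simply forwarded by create_py_executor_instance above) so the worker can read back a value that create_py_executor may have adjusted, rather than reaching into an ExecutorConfig. A self-contained schematic of that read-back pattern, using a stand-in class rather than the real PyExecutor:

# Schematic sketch, not runnable against the real classes: a stand-in type that models
# only the max_seq_len attribute added in this PR, plus the read-back done in worker.py.
from typing import Optional

class _ExecutorSketch:
    def __init__(self, max_seq_len: Optional[int] = None):
        # PyExecutor now persists the (possibly engine-adjusted) value.
        self.max_seq_len = max_seq_len

def resolve_max_seq_len(configured: Optional[int],
                        executor: _ExecutorSketch) -> Optional[int]:
    # Mirrors worker.py below: prefer the executor-owned value when it is set,
    # otherwise keep the value coming from the LLM arguments.
    return executor.max_seq_len if executor.max_seq_len is not None else configured

print(resolve_max_seq_len(4096, _ExecutorSketch(max_seq_len=8192)))  # -> 8192
print(resolve_max_seq_len(4096, _ExecutorSketch()))                  # -> 4096
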
28 changes: 21 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -14,9 +14,12 @@
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.executor import (CapacitySchedulerPolicy,
ContextChunkingPolicy,
ExecutorConfig)
ExecutorConfig,
LogitsPostProcessorConfig,
ParallelConfig)
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import LoraConfig
from tensorrt_llm.mapping import Mapping
@@ -209,12 +212,21 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:


def create_py_executor(
executor_config: ExecutorConfig,
checkpoint_dir: str = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None
llm_args: TorchLlmArgs,
checkpoint_dir: str = None,
tokenizer: Optional[TokenizerBase] = None,
lora_config: Optional[LoraConfig] = None,
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
logits_post_processor_config: Optional[LogitsPostProcessorConfig] = None,
parallel_config: Optional[ParallelConfig] = None,
) -> PyExecutor:

executor_config = llm_args.get_executor_config(checkpoint_dir, tokenizer)
executor_config.logits_post_processor_config = logits_post_processor_config
executor_config.parallel_config = parallel_config

garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold

_mangle_executor_config(executor_config)
pytorch_backend_config = executor_config.pytorch_backend_config

@@ -484,6 +496,7 @@ def create_py_executor(
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
max_seq_len=executor_config.max_seq_len,
)

if estimating_kv_cache:
Expand Down Expand Up @@ -528,6 +541,7 @@ def create_py_executor(
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold,
kv_connector_manager=kv_connector_manager,
max_seq_len=executor_config.max_seq_len,
)

_adjust_torch_mem_fraction(executor_config.pytorch_backend_config)
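
create_py_executor now builds the ExecutorConfig internally from llm_args; callers pass the TorchLlmArgs plus the logits post-processor and parallel configs that were previously attached to the ExecutorConfig in the worker. A hedged sketch of the new call shape, modeled on the worker.py changes later in this diff; build_pytorch_executor and its parameter list are hypothetical, only the create_py_executor keywords come from the PR:

# Hypothetical wrapper illustrating the new call site; mirrors worker.py in this PR.
# lora_config / kv_connector_config are omitted and fall back to their defaults.
from typing import List, Optional

from tensorrt_llm.bindings import executor as tllm
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
from tensorrt_llm.llmapi.tokenizer import TokenizerBase


def build_pytorch_executor(llm_args: TorchLlmArgs,
                           checkpoint_dir: str,
                           tokenizer: Optional[TokenizerBase],
                           comm_ranks: List[int],
                           device_ids: List[int],
                           batched_logits_processor=None):
    # Local import, matching the worker's lazy import of the creator module.
    from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
        create_py_executor
    return create_py_executor(
        llm_args=llm_args,
        checkpoint_dir=checkpoint_dir,
        tokenizer=tokenizer,
        logits_post_processor_config=tllm.LogitsPostProcessorConfig(
            processor_batched=batched_logits_processor, replicate=False),
        parallel_config=tllm.ParallelConfig(participant_ids=comm_ranks,
                                            device_ids=device_ids),
    )
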
4 changes: 2 additions & 2 deletions tensorrt_llm/executor/executor.py
@@ -21,7 +21,7 @@
from ..bindings import executor as tllm
from ..builder import Engine
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig
from ..llmapi.llm_utils import KvCacheRetentionConfig
from ..llmapi.mpi_session import (MpiSession, external_mpi_comm_available,
need_spawn_mpi_workers)
@@ -359,7 +359,7 @@ def create(
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
# local imports to avoid cyclic importing
from .proxy import GenerationExecutorProxy
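
Widening the annotation from TorchLlmArgs to BaseLlmArgs lets both the PyTorch and the AutoDeploy argument classes flow through GenerationExecutor.create. A small sketch of the dispatch the worker performs on the widened type, assuming (as the annotation implies) that both classes derive from BaseLlmArgs:

# Sketch under the assumption that both argument classes derive from BaseLlmArgs;
# the getattr guard mirrors the one used in worker.py below.
from typing import Optional

from tensorrt_llm.llmapi.llm_args import BaseLlmArgs


def backend_of(llm_args: Optional[BaseLlmArgs]) -> str:
    # Empty string means the TRT-engine flow (no llm_args / no backend attribute).
    return getattr(llm_args, "backend", "") if llm_args is not None else ""
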
128 changes: 85 additions & 43 deletions tensorrt_llm/executor/worker.py
@@ -18,7 +18,8 @@
mpi_comm, mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import KvCacheConnectorConfig, PybindMirror, TorchLlmArgs
from ..llmapi.llm_args import (BaseLlmArgs, KvCacheConnectorConfig,
PybindMirror, TorchLlmArgs)
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
@@ -64,7 +65,7 @@ def __init__(
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> None:
postproc_config = postproc_worker_config or PostprocWorkerConfig()
super().__init__(
@@ -107,40 +108,55 @@ def _get_comm_ranks_device_id():
device_ids = mpi_comm().allgather(device_id)
return comm_ranks, device_ids

def _create_py_executor(executor_config):
assert executor_config is None, "expect an empty executor_config is _create_py_executor"
executor_config = llm_args.get_executor_config(
hf_model_dir, tokenizer)
# Persist so downstream code (e.g., default max_tokens deduction) has access
self._executor_config = executor_config
executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)
comm_ranks, device_ids = _get_comm_ranks_device_id()
executor_config.parallel_config = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)
args = {
"executor_config": executor_config,
"checkpoint_dir": executor_config.hf_model_dir,
}
def _create_py_executor():
args = {}
assert hasattr(
executor_config, "backend"
), "executor_config should be with backend in _create_py_executor"
if executor_config.backend == "pytorch":
self.llm_args, "backend"
), "llm_args should be with backend in _create_py_executor"
if self.llm_args.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
create_executor = create_py_executor
args["llm_args"] = self.llm_args
args["checkpoint_dir"] = hf_model_dir
args["tokenizer"] = tokenizer
args["lora_config"] = lora_config
args[
"garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
args["kv_connector_config"] = kv_connector_config
elif executor_config.backend == "_autodeploy":
args[
"logits_post_processor_config"] = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor,
replicate=False)
comm_ranks, device_ids = _get_comm_ranks_device_id()
args["parallel_config"] = tllm.ParallelConfig(
participant_ids=comm_ranks, device_ids=device_ids)
elif self.llm_args.backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy.llm_args import \
LlmArgs as ADLlmArgs
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
create_autodeploy_executor
create_executor = create_autodeploy_executor
assert isinstance(self.llm_args, ADLlmArgs)
args["ad_config"] = self.llm_args.get_pytorch_backend_config()
else:
raise ValueError(
f"Unsupported backend config: {executor_config.backend}")
return create_executor(**args)
f"Unsupported backend config: {self.llm_args.backend}")

# Define additional attributes that can be used later, such as in _deduce_max_tokens
self.mapping = self.llm_args.parallel_config.to_mapping()
self.checkpoint_loader = None
if self.llm_args.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.config import \
_construct_checkpoint_loader
self.checkpoint_loader = _construct_checkpoint_loader(
self.llm_args.backend, self.llm_args.checkpoint_loader,
self.llm_args.checkpoint_format)

_executor = create_executor(**args)
self.max_seq_len = self.llm_args.max_seq_len
if _executor.max_seq_len is not None:
# max_seq_len might be updated by model engine as in create_py_executor
self.max_seq_len = _executor.max_seq_len
return _executor

def _create_engine(executor_config):
if executor_config is None:
@@ -164,8 +180,7 @@ def _create_engine(executor_config):
executor_config)

self.engine = _create_py_executor(
executor_config) if llm_args is not None else _create_engine(
executor_config)
) if self.llm_args is not None else _create_engine(executor_config)

self._lora_manager: Optional[LoraManager] = None
self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -188,8 +203,9 @@ def _create_engine(executor_config):
if engine_config.build_config.max_prompt_embedding_table_size > 0:
self._prompt_adapter_manager = PromptAdapterManager()

if getattr(self._executor_config, "backend",
"") == "pytorch" and lora_config is not None:
if self.llm_args and getattr(
self.llm_args, "backend",
"") == "pytorch" and lora_config is not None:
from tensorrt_llm._torch.pyexecutor.resource_manager import \
ResourceManagerType
peft_cache_manager = self.engine.resource_manager.resource_managers.get(
@@ -471,26 +487,43 @@ def _enqueue_request(self, request: GenerationRequest) -> int:
assert request.id is not None

def _deduce_max_tokens(request: GenerationRequest,
executor_config: tllm.ExecutorConfig) -> int:
executor_config: tllm.ExecutorConfig,
llm_args: Optional[BaseLlmArgs] = None) -> int:
# deduce max_tokens when it's not set by user
max_tokens = request.sampling_params.max_tokens
query_token_len = len(
request.query_token_ids) if request.query_token_ids else 0
cp_size = 1 if (not hasattr(executor_config, "mapping")
or executor_config.mapping.cp_size
is None) else executor_config.mapping.cp_size
if not hasattr(executor_config, "max_seq_len"):

cp_size = 1
max_seq_len = None
if llm_args is not None:
# deduce max_tokens by llm args
assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
if hasattr(self,
"mapping") and self.mapping.cp_size is not None:
cp_size = self.mapping.cp_size
max_seq_len = getattr(self, "max_seq_len", None)
else:
# deduce max_tokens by executor config
if hasattr(executor_config, "mapping"
) and executor_config.mapping.cp_size is not None:
cp_size = executor_config.mapping.cp_size
max_seq_len = getattr(executor_config, "max_seq_len", None)
if max_seq_len is None:
logger.warning("`default_max_tokens` cannot be deduced")
if max_tokens is None:
raise ValueError(
"`max_tokens` must be set when `default_max_tokens` cannot be deduced"
)
else:
# use max_tokens if can't deduce default_max_tokens
return max_tokens
splited_prompt_len = int(len(prompt_token_ids) / cp_size)
default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
if default_max_tokens <= 0:
logger.warning(
f"`default_max_tokens` ({default_max_tokens}) should be greater than 0, "
f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({executor_config.max_seq_len})"
f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
)
if max_tokens is None:
@@ -512,7 +545,8 @@ def _deduce_max_tokens(request: GenerationRequest,
executor_request = tllm.Request(
client_id=request.id,
input_token_ids=prompt_token_ids,
max_tokens=_deduce_max_tokens(request, self._executor_config),
max_tokens=_deduce_max_tokens(request, self._executor_config,
self.llm_args),
streaming=request.streaming,
sampling_config=request.sampling_params._get_sampling_config(),
end_id=-1 if request.sampling_params.ignore_eos else
@@ -638,11 +672,19 @@ def shutdown(self):
self.engine.shutdown()
self.engine = None

if hasattr(
self._executor_config, "checkpoint_loader"
) and self._executor_config.checkpoint_loader is not None:
self._executor_config.checkpoint_loader.cleanup()
self._executor_config.checkpoint_loader = None
if self.llm_args is not None:
assert self._executor_config is None, "An empty executor_config is expected in shutdown when LLM arguments are defined."
if (self.llm_args.backend == "pytorch"
and hasattr(self, "checkpoint_loader")
and self.checkpoint_loader is not None):
self.checkpoint_loader.cleanup()
self.checkpoint_loader = None
else:
if hasattr(
self._executor_config, "checkpoint_loader"
) and self._executor_config.checkpoint_loader is not None:
self._executor_config.checkpoint_loader.cleanup()
self._executor_config.checkpoint_loader = None

# Check if there are any errors from the threads before shutdown.
self._handle_background_error()
@@ -689,7 +731,7 @@ def worker_main(
kv_connector_config: Optional[KvCacheConnectorConfig] = None,
hf_model_dir: Optional[Path] = None,
tokenizer: Optional[TokenizerBase] = None,
llm_args: Optional[TorchLlmArgs] = None,
llm_args: Optional[BaseLlmArgs] = None,
) -> None:
mpi_comm().barrier()
print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
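
For reference, the deduction in _deduce_max_tokens is unchanged arithmetically; only the source of cp_size and max_seq_len moves from the ExecutorConfig to self.mapping and self.max_seq_len when llm_args is set. A standalone recap of that arithmetic with illustrative numbers:

# Standalone recap of the deduction above; the numbers are illustrative only.
from typing import Optional


def deduce_default_max_tokens(prompt_len: int,
                              query_token_len: int,
                              max_seq_len: Optional[int],
                              cp_size: int = 1) -> Optional[int]:
    if max_seq_len is None:
        return None  # caller must set max_tokens explicitly in this case
    splited_prompt_len = int(prompt_len / cp_size)  # prompt is split across CP ranks
    default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
    return default_max_tokens if default_max_tokens > 0 else None


# e.g. max_seq_len=8192, a 1000-token prompt, no query tokens, cp_size=2:
# 8192 - 500 - 0 = 7692 tokens left for generation.
print(deduce_default_max_tokens(1000, 0, 8192, cp_size=2))  # -> 7692
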