diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index ad87ba174e1..63e194f0e4d 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -1,7 +1,6 @@ -import math import random from collections.abc import Iterable -from typing import Optional +from typing import Dict, List, Optional import torch @@ -21,7 +20,8 @@ from ..speculative import get_spec_decoder from .config_utils import is_mla, is_nemotron_hybrid from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver -from .model_engine import KV_CACHE_MANAGER_KEY, PyTorchModelEngine +from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY, + PyTorchModelEngine) from .py_executor import PyExecutor from .resource_manager import (KVCacheManager, MambaHybridCacheManager, PeftCacheManager, ResourceManager) @@ -34,281 +34,331 @@ GB = 1 << 30 -def get_cache_size_per_token(model_config, mapping): - mem_per_token = 2 - quant_config = model_config.quant_config - if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(): - mem_per_token = 1 - - config = model_config.pretrained_config - - num_key_value_heads = getattr(config, 'num_key_value_heads', - config.num_attention_heads) - if isinstance(num_key_value_heads, Iterable): - num_key_value_heads = sum(num_key_value_heads) / len( - num_key_value_heads) - - mla = is_mla(config) - tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size - - kv_factor = 2 - if mla: - # MLA has kv_lora_rank and qk_rope_head_dim - head_dim = config.kv_lora_rank + config.qk_rope_head_dim - kv_factor = 1 - else: - head_dim = getattr( - config, - "head_dim", - config.hidden_size // config.num_attention_heads, - ) * num_key_value_heads // tp_size - - # provide at least 1 layer to prevent division by zero cache size - num_hidden_layers = max(len(mapping.pp_layers(config.num_hidden_layers)), 1) - mem_per_token *= num_hidden_layers * head_dim - # K and V - mem_per_token *= kv_factor - return mem_per_token - - -def get_fraction_from_executor_config(executor_config): - fraction = executor_config.kv_cache_config.free_gpu_memory_fraction - if fraction is None: - fraction = 0.9 - return fraction - - -def cal_max_tokens(peak_memory, total_gpu_memory, fraction, model_config, - draft_model_config, mapping: Mapping, - alloc_kv_tokens: int) -> int: - model_kv_size_per_token = get_cache_size_per_token(model_config, mapping) - draft_kv_size_per_token = get_cache_size_per_token( - draft_model_config, mapping) if draft_model_config is not None else 0 - kv_size_per_token = model_kv_size_per_token + draft_kv_size_per_token - - available_kv_mem = (total_gpu_memory - peak_memory + - alloc_kv_tokens * kv_size_per_token) * fraction - logger.info( - f"Peak memory during memory usage profiling (torch + non-torch): {peak_memory / (GB):.2f} GiB, " - f"available KV cache memory when calculating max tokens: {available_kv_mem / (GB):.2f} GiB, " - f"fraction is set {fraction}, kv size is {kv_size_per_token}") - max_tokens = int((available_kv_mem) // kv_size_per_token) - max_tokens = max(max_tokens, 0) - return max_tokens - - -def create_dummy_context_requests(max_num_tokens: int, max_seq_len: int, - vocab_size: int): - # NB: The requests constructed here need to be compatible with - # get_token_num_for_estimation. 
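Note on the removed helper: the per-token arithmetic in `get_cache_size_per_token` (kept essentially verbatim as the static `_get_cache_size_per_token` on the new `KvCacheCreator` further down) boils down to dtype size × pipeline-local layers × per-layer KV width × K/V factor. A minimal, self-contained sketch with made-up model dimensions:

```python
# Minimal sketch of the per-token KV cache size computed by the helper above.
# All dimensions below are illustrative; the real code reads them from
# model_config.pretrained_config and the Mapping.
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       tp_size: int = 1, fp8_kv: bool = False,
                       mla: bool = False) -> int:
    dtype_bytes = 1 if fp8_kv else 2      # fp8 KV cache vs. fp16/bf16
    kv_factor = 1 if mla else 2           # MLA stores one compressed tensor, not separate K and V
    if mla:
        per_layer_width = head_dim        # i.e. kv_lora_rank + qk_rope_head_dim
    else:
        per_layer_width = head_dim * num_kv_heads // tp_size
    return dtype_bytes * num_layers * per_layer_width * kv_factor

# e.g. a 32-layer model with 8 KV heads of dim 128, bf16 cache, no TP:
print(kv_bytes_per_token(32, 8, 128))     # 131072 bytes = 128 KiB per token
```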
- requests = [] - max_seq_len = min(max_num_tokens, max_seq_len) - remaining_tokens = max_num_tokens - while remaining_tokens > 0: - input_len = min(max_seq_len, remaining_tokens) - input_tokens = [ - random.randint(0, vocab_size - 1) for _ in range(input_len) - ] - request = trtllm.Request(input_tokens, - max_tokens=1, - streaming=False, - sampling_config=trtllm.SamplingConfig(), - output_config=trtllm.OutputConfig(), - end_id=-1) - requests.append(request) - remaining_tokens -= input_len - return requests - - -def get_token_num_for_estimation(executor_config: ExecutorConfig, - model_config: ModelConfig) -> int: - """Compute KV cache capacity required for estimate_max_kv_cache_tokens to succeed.""" - mapping = executor_config.mapping - if 'cp_type' in mapping.cp_config: - raise ValueError( - "KV cache size estimation not supported with context parallelism.") - # When reusing KV cache blocks, we need to add extra tokens to account for partially filled blocks - # that cannot be reused. Each sequence used during estimation (cf. estimate_max_kv_cache_tokens) - # has at most max_seq_len (or max_seq_len - 1) tokens. In total, the sequences used for estimation - # have max_num_tokens tokens (cf. create_dummy_context_requests). For each sequence, we may need up to one extra - # block (tokens_per_block tokens) if the sequence length is not perfectly divisible by tokens_per_block. - # So we add math.ceil(max_num_tokens/(max_seq_len-1)) * tokens_per_block extra tokens. - return (executor_config.max_num_tokens + - math.ceil(executor_config.max_num_tokens / - (executor_config.max_seq_len - 1)) * - executor_config.tokens_per_block) - - -def estimate_max_kv_cache_tokens(py_executor: PyExecutor, - model_engine: PyTorchModelEngine, - executor_config: ExecutorConfig, - mapping: Mapping, origin_seq_len: int, - ctx_chunk_config, - draft_model_engine: PyTorchModelEngine) -> int: - # TODO: support CP by generating dummy requests for it. - if 'cp_type' in mapping.cp_config: - # This is called from create_py_executor, which ensures that - # executor_config.max_num_tokens is set. - assert executor_config.max_num_tokens is not None - return executor_config.max_num_tokens - - vocab_size = model_engine.model.model_config.pretrained_config.vocab_size - max_num_tokens = executor_config.max_num_tokens - fraction = get_fraction_from_executor_config(executor_config) - kv_cache_max_tokens_in = executor_config.kv_cache_config.max_tokens - - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - model_bytes = torch.cuda.memory_stats()["allocated_bytes.all.current"] - logger.info( - f"Memory used after loading model weights (inside torch) in memory usage profiling: {model_bytes / (GB):.2f} GiB" - ) - - py_executor.set_gather_responses(True) - origin_iter_stats = py_executor.enable_iter_perf_stats - py_executor.enable_iter_perf_stats = False - req_ids = [] - if py_executor.dist.mapping.rank == 0: - # NOTE: TRTLLMSampler requires origin_seq_len - 1 for requests. - # Spec decoders with overlap require origin_seq_len. 
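The extra-token term removed here (and reworked as `_get_token_num_for_estimation` later in this hunk) exists because each dummy sequence may end on a partially filled block that cannot be shared between requests. A small sketch of the per-request rounding the new version performs, with made-up sizes:

```python
# Sketch of the block-granularity rounding used when sizing the estimation-time KV cache.
# tokens_per_block and the request lengths are invented for illustration.
def estimation_tokens(dummy_req_lens, tokens_per_block, extra_per_seq=1):
    num_blocks = 0
    for n in dummy_req_lens:
        n += extra_per_seq                       # one generated token (+ draft tokens when speculating)
        num_blocks += -(-n // tokens_per_block)  # ceil: requests cannot share partially filled blocks
    return num_blocks * tokens_per_block

# Three dummy context requests with 32-token blocks:
print(estimation_tokens([2040, 2040, 100], 32))  # (64 + 64 + 4) blocks -> 4224 tokens
```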
- seq_len = origin_seq_len - 1 if type( - py_executor.sampler) == TRTLLMSampler else origin_seq_len - req = create_dummy_context_requests(max_num_tokens, seq_len, vocab_size) - req_ids = py_executor.enqueue_requests(req) - req_ids = py_executor.dist.broadcast(req_ids, root=0) - py_executor.is_warmup = True - py_executor.start_worker() - py_executor.await_responses(req_ids) - - torch_peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] - - # Clear the caching allocator before measuring the current memory usage - torch.cuda.empty_cache() - end, total_gpu_memory = torch.cuda.mem_get_info() - torch_used_bytes = torch.cuda.memory_stats()["allocated_bytes.all.current"] - total_used_bytes = total_gpu_memory - end - activation_bytes = torch_peak_memory - model_bytes - extra_cost = max(total_used_bytes - torch_used_bytes, 0) - peak_memory = torch_peak_memory + extra_cost - logger.info( - f"Memory dynamically allocated during inference (inside torch) in memory usage profiling: {activation_bytes / (GB):.2f} GiB" - ) - logger.info( - f"Memory used outside torch (e.g., NCCL and CUDA graphs) in memory usage profiling: {extra_cost / (GB):.2f} GiB" - ) - kv_stats = py_executor.resource_manager.resource_managers.get( - "kv_cache_manager").get_kv_cache_stats() - - draft_model_config = draft_model_engine.model.model_config if draft_model_engine is not None else None - kv_cache_max_tokens = cal_max_tokens( - peak_memory, total_gpu_memory, fraction, - model_engine.model.model_config, draft_model_config, mapping, - kv_stats.max_num_blocks * kv_stats.tokens_per_block) - - if kv_cache_max_tokens_in is not None: - kv_cache_max_tokens = min(kv_cache_max_tokens, kv_cache_max_tokens_in) +class KvCacheCreator: + """Groups together logic related to KV cache construction.""" + + def __init__(self, *, executor_config: ExecutorConfig, + model_engine: PyTorchModelEngine, + draft_model_engine: Optional[PyTorchModelEngine], + mapping: Mapping, net_max_seq_len: int): + self._executor_config = executor_config + self._model_engine = model_engine + self._draft_model_engine = draft_model_engine + self._mapping = mapping + self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens + self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len - + 1) + + @staticmethod + def _get_cache_size_per_token(model_config: ModelConfig, + mapping: Mapping) -> int: + mem_per_token = 2 + quant_config = model_config.quant_config + if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache( + ): + mem_per_token = 1 + + config = model_config.pretrained_config + + num_key_value_heads = getattr(config, 'num_key_value_heads', + config.num_attention_heads) + if isinstance(num_key_value_heads, Iterable): + num_key_value_heads = sum(num_key_value_heads) / len( + num_key_value_heads) + + mla = is_mla(config) + tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size + + kv_factor = 2 + if mla: + # MLA has kv_lora_rank and qk_rope_head_dim + head_dim = config.kv_lora_rank + config.qk_rope_head_dim + kv_factor = 1 + else: + head_dim = getattr( + config, + "head_dim", + config.hidden_size // config.num_attention_heads, + ) * num_key_value_heads // tp_size + + # provide at least 1 layer to prevent division by zero cache size + num_hidden_layers = max( + len(mapping.pp_layers(config.num_hidden_layers)), 1) + mem_per_token *= num_hidden_layers * head_dim + # K and V + mem_per_token *= kv_factor + return mem_per_token + + def _get_free_gpu_memory_fraction(self) -> float: + fraction = 
self._executor_config.kv_cache_config.free_gpu_memory_fraction + if fraction is None: + fraction = 0.9 + return fraction + + def _cal_max_tokens(self, peak_memory, total_gpu_memory, fraction, + alloc_kv_tokens: int) -> int: + model_config = self._model_engine.model.model_config + mapping = self._mapping + kv_size_per_token = self._get_cache_size_per_token( + model_config, mapping) + if self._draft_model_engine is not None: + draft_model_config = self._draft_model_engine.model.model_config + kv_size_per_token += self._get_cache_size_per_token( + draft_model_config, mapping) + + available_kv_mem = (total_gpu_memory - peak_memory + + alloc_kv_tokens * kv_size_per_token) * fraction + logger.info( + f"Peak memory during memory usage profiling (torch + non-torch): {peak_memory / (GB):.2f} GiB, " + f"available KV cache memory when calculating max tokens: {available_kv_mem / (GB):.2f} GiB, " + f"fraction is set {fraction}, kv size is {kv_size_per_token}") + max_tokens = int((available_kv_mem) // kv_size_per_token) + max_tokens = max(max_tokens, 0) + return max_tokens + + def _create_dummy_context_requests( + self, input_seq_len: int) -> List[trtllm.Request]: + vocab_size = self._model_engine.model.model_config.pretrained_config.vocab_size + max_num_tokens = self._executor_config.max_num_tokens + + requests = [] + input_seq_len = min(max_num_tokens, input_seq_len) + remaining_tokens = max_num_tokens + while remaining_tokens > 0: + input_seq_len = min(input_seq_len, remaining_tokens) + input_tokens = [ + random.randint(0, vocab_size - 1) for _ in range(input_seq_len) + ] + request = trtllm.Request(input_tokens, + max_tokens=1, + streaming=False, + sampling_config=trtllm.SamplingConfig(), + output_config=trtllm.OutputConfig(), + end_id=-1) + requests.append(request) + remaining_tokens -= input_seq_len + return requests + + def _get_token_num_for_estimation(self) -> int: + """Compute KV cache capacity required for estimate_max_kv_cache_tokens to succeed.""" + executor_config = self._executor_config + if 'cp_type' in self._mapping.cp_config: + raise ValueError( + "KV cache size estimation not supported with context parallelism." + ) + # estimate_max_kv_cache_tokens submits self._dummy_reqs + num_cache_blocks = 0 + num_extra_tokens_per_seq = 1 # account for generated tokens + spec_cfg = executor_config.speculative_config + if spec_cfg is not None: + num_extra_tokens_per_seq += spec_cfg.max_draft_tokens + num_extra_tokens_per_seq += spec_cfg.num_extra_kv_tokens + for req in self._dummy_reqs: + num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq + # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size. + num_cache_blocks += (num_req_tokens + + executor_config.tokens_per_block - + 1) // executor_config.tokens_per_block + return num_cache_blocks * executor_config.tokens_per_block + + def try_prepare_estimation(self) -> bool: + """Prepare for possible KV cache capacity estimation. + + This updates `kv_cache_config` and returns a boolean indicating whether KV cache + estimation is to be performend. + """ + estimating_kv_cache = False + if 'cp_type' not in self._mapping.cp_config: + estimating_kv_cache = True + self._executor_config.kv_cache_config.max_tokens = self._get_token_num_for_estimation( + ) + return estimating_kv_cache + + def estimate_max_tokens(self, py_executor: PyExecutor) -> None: + """Perform KV cache capacity estimation. + + This updates `kv_cache_config`. 
+ """ + executor_config = self._executor_config + mapping = self._mapping + + # TODO: support CP by generating dummy requests for it. + assert 'cp_type' not in mapping.cp_config + + fraction = self._get_free_gpu_memory_fraction() + + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + model_bytes = torch.cuda.memory_stats()["allocated_bytes.all.current"] + logger.info( + f"Memory used after loading model weights (inside torch) in memory usage profiling: {model_bytes / (GB):.2f} GiB" + ) - logger.info(f"Estimated max tokens in KV cache : {kv_cache_max_tokens}") + py_executor.set_gather_responses(True) + origin_iter_stats = py_executor.enable_iter_perf_stats + py_executor.enable_iter_perf_stats = False + req_ids = [] + if py_executor.dist.mapping.rank == 0: + req_ids = py_executor.enqueue_requests(self._dummy_reqs) + req_ids = py_executor.dist.broadcast(req_ids, root=0) + py_executor.is_warmup = True + py_executor.start_worker() + py_executor.await_responses(req_ids) + + torch_peak_memory = torch.cuda.memory_stats( + )["allocated_bytes.all.peak"] + + # Clear the caching allocator before measuring the current memory usage + torch.cuda.empty_cache() + end, total_gpu_memory = torch.cuda.mem_get_info() + torch_used_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + total_used_bytes = total_gpu_memory - end + activation_bytes = torch_peak_memory - model_bytes + extra_cost = max(total_used_bytes - torch_used_bytes, 0) + peak_memory = torch_peak_memory + extra_cost + logger.info( + f"Memory dynamically allocated during inference (inside torch) in memory usage profiling: {activation_bytes / (GB):.2f} GiB" + ) + logger.info( + f"Memory used outside torch (e.g., NCCL and CUDA graphs) in memory usage profiling: {extra_cost / (GB):.2f} GiB" + ) + kv_stats = py_executor.resource_manager.resource_managers.get( + "kv_cache_manager").get_kv_cache_stats() - py_executor.resource_manager.resource_managers.get( - "kv_cache_manager").shutdown() + kv_cache_max_tokens = self._cal_max_tokens( + peak_memory, total_gpu_memory, fraction, + kv_stats.max_num_blocks * kv_stats.tokens_per_block) - py_executor.shutdown() - py_executor.is_warmup = False - py_executor.set_gather_responses(False) - py_executor.enable_iter_perf_stats = origin_iter_stats + if self._max_kv_tokens_in is not None: + kv_cache_max_tokens = min(kv_cache_max_tokens, + self._max_kv_tokens_in) - return kv_cache_max_tokens + logger.info(f"Estimated max tokens in KV cache : {kv_cache_max_tokens}") + py_executor.resource_manager.resource_managers.get( + "kv_cache_manager").shutdown() -def create_kv_cache_manager(model_engine: PyTorchModelEngine, mapping: Mapping, - executor_config: ExecutorConfig) -> KVCacheManager: - assert executor_config.pytorch_backend_config.use_kv_cache, "Only construct KV cache when it is needed." 
+ py_executor.shutdown() + py_executor.is_warmup = False + py_executor.set_gather_responses(False) + py_executor.enable_iter_perf_stats = origin_iter_stats - config = model_engine.model.model_config.pretrained_config - quant_config = model_engine.model.model_config.quant_config - spec_config = executor_config.speculative_config + executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens - hidden_size = config.hidden_size - num_attention_heads = config.num_attention_heads - num_key_value_heads = getattr(config, 'num_key_value_heads', - num_attention_heads) - head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads) + def _create_kv_cache_manager( + self, model_engine: PyTorchModelEngine) -> KVCacheManager: + executor_config = self._executor_config + mapping = self._mapping + assert executor_config.pytorch_backend_config.use_kv_cache, "Only construct KV cache when it is needed." - if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(): - kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8 - else: - kv_cache_dtype = str_dtype_to_binding( - torch_dtype_to_str(model_engine.dtype)) - - num_hidden_layers = config.num_hidden_layers - - if is_mla(config): - kv_cache_manager = KVCacheManager( - executor_config.kv_cache_config, - tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY, - num_layers=num_hidden_layers, - num_kv_heads=1, - head_dim=config.kv_lora_rank + config.qk_rope_head_dim, - tokens_per_block=executor_config.tokens_per_block, - max_seq_len=executor_config.max_seq_len, - max_batch_size=executor_config.max_batch_size, - mapping=mapping, - dtype=kv_cache_dtype, - spec_config=spec_config, - ) - elif is_nemotron_hybrid(config): config = model_engine.model.model_config.pretrained_config - num_layers = config.hybrid_override_pattern.count("*") - layer_mask = [char == "*" for char in config.hybrid_override_pattern] - mamba_num_layers = config.hybrid_override_pattern.count("M") - mamba_layer_mask = [ - char == "M" for char in config.hybrid_override_pattern - ] - kv_cache_manager = MambaHybridCacheManager( - # mamba cache parameters - config.hidden_size, - config.ssm_state_size, - config.conv_kernel, - config.expand, - config.n_groups, - config.mamba_head_dim, - mamba_num_layers, - mamba_layer_mask, - config.torch_dtype, - # kv cache parameters - executor_config.kv_cache_config, - tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, - num_layers=num_layers, - layer_mask=layer_mask, - num_kv_heads=num_key_value_heads, - head_dim=head_dim, - tokens_per_block=executor_config.tokens_per_block, - max_seq_len=executor_config.max_seq_len, - max_batch_size=executor_config.max_batch_size, - mapping=mapping, - dtype=kv_cache_dtype, - spec_config=spec_config, - ) - else: - kv_cache_manager = KVCacheManager( - executor_config.kv_cache_config, - tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, - num_layers=num_hidden_layers, - num_kv_heads=num_key_value_heads, - head_dim=head_dim, - tokens_per_block=executor_config.tokens_per_block, - max_seq_len=executor_config.max_seq_len, - max_batch_size=executor_config.max_batch_size, - mapping=mapping, - dtype=kv_cache_dtype, - spec_config=spec_config, - ) - # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config - if model_engine.kv_cache_manager_key == KV_CACHE_MANAGER_KEY: - executor_config.max_seq_len = kv_cache_manager.max_seq_len - - return kv_cache_manager + quant_config = model_engine.model.model_config.quant_config + spec_config = 
executor_config.speculative_config + + hidden_size = config.hidden_size + num_attention_heads = config.num_attention_heads + num_key_value_heads = getattr(config, 'num_key_value_heads', + num_attention_heads) + head_dim = getattr(config, "head_dim", + hidden_size // num_attention_heads) + + if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache( + ): + kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8 + else: + kv_cache_dtype = str_dtype_to_binding( + torch_dtype_to_str(model_engine.dtype)) + + num_hidden_layers = config.num_hidden_layers + + if is_mla(config): + kv_cache_manager = KVCacheManager( + executor_config.kv_cache_config, + tensorrt_llm.bindings.internal.batch_manager.CacheType. + SELFKONLY, + num_layers=num_hidden_layers, + num_kv_heads=1, + head_dim=config.kv_lora_rank + config.qk_rope_head_dim, + tokens_per_block=executor_config.tokens_per_block, + max_seq_len=executor_config.max_seq_len, + max_batch_size=executor_config.max_batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + spec_config=spec_config, + ) + elif is_nemotron_hybrid(config): + config = model_engine.model.model_config.pretrained_config + num_layers = config.hybrid_override_pattern.count("*") + layer_mask = [ + char == "*" for char in config.hybrid_override_pattern + ] + mamba_num_layers = config.hybrid_override_pattern.count("M") + mamba_layer_mask = [ + char == "M" for char in config.hybrid_override_pattern + ] + kv_cache_manager = MambaHybridCacheManager( + # mamba cache parameters + config.hidden_size, + config.ssm_state_size, + config.conv_kernel, + config.expand, + config.n_groups, + config.mamba_head_dim, + mamba_num_layers, + mamba_layer_mask, + config.torch_dtype, + # kv cache parameters + executor_config.kv_cache_config, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, + num_layers=num_layers, + layer_mask=layer_mask, + num_kv_heads=num_key_value_heads, + head_dim=head_dim, + tokens_per_block=executor_config.tokens_per_block, + max_seq_len=executor_config.max_seq_len, + max_batch_size=executor_config.max_batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + spec_config=spec_config, + ) + else: + kv_cache_manager = KVCacheManager( + executor_config.kv_cache_config, + tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, + num_layers=num_hidden_layers, + num_kv_heads=num_key_value_heads, + head_dim=head_dim, + tokens_per_block=executor_config.tokens_per_block, + max_seq_len=executor_config.max_seq_len, + max_batch_size=executor_config.max_batch_size, + mapping=mapping, + dtype=kv_cache_dtype, + spec_config=spec_config, + ) + # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config + if model_engine.kv_cache_manager_key == KV_CACHE_MANAGER_KEY: + executor_config.max_seq_len = kv_cache_manager.max_seq_len + + return kv_cache_manager + + def build_managers(self, resources: Dict) -> None: + """Construct KV caches for model and draft model (if applicable).""" + kv_cache_manager = self._create_kv_cache_manager(self._model_engine) + draft_kv_cache_manager = self._create_kv_cache_manager( + self._draft_model_engine + ) if self._draft_model_engine is not None else None + resources[KV_CACHE_MANAGER_KEY] = kv_cache_manager + resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager + + def teardown_managers(self, resources: Dict) -> None: + """Clean up KV caches for model and draft model (if applicable).""" + resources[KV_CACHE_MANAGER_KEY].shutdown() + del resources[KV_CACHE_MANAGER_KEY] + draft_kv_cache_manager = 
resources[DRAFT_KV_CACHE_MANAGER_KEY] + if draft_kv_cache_manager: + draft_kv_cache_manager.shutdown() + del resources[DRAFT_KV_CACHE_MANAGER_KEY] def create_py_executor_instance( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 77c1800b135..e543ad6ae8d 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -19,13 +19,11 @@ from ..attention_backend.interface import AttentionRuntimeFeatures from ..distributed import MPIDist from ..speculative import NGramConfig, get_spec_resource_manager -from ._util import (create_kv_cache_manager, create_py_executor_instance, - estimate_max_kv_cache_tokens, get_token_num_for_estimation, +from ._util import (KvCacheCreator, create_py_executor_instance, instantiate_sampler, is_mla) from .config import PyTorchConfig from .config_utils import is_mla -from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY, - PyTorchModelEngine) +from .model_engine import DRAFT_KV_CACHE_MANAGER_KEY, PyTorchModelEngine from .py_executor import PyExecutor @@ -138,23 +136,13 @@ def observe_creation_stage(self, current_stage: _ExecutorCreationStage): )) -def create_py_executor(executor_config: ExecutorConfig, - checkpoint_dir: str = None, - engine_dir: str = None, - lora_config: Optional[LoraConfig] = None) -> PyExecutor: +def _mangle_executor_config(executor_config: ExecutorConfig): if executor_config.pytorch_backend_config is None: executor_config.pytorch_backend_config = PyTorchConfig() - pytorch_backend_config = executor_config.pytorch_backend_config - if executor_config.mapping is None: - mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(), - tp_size=tensorrt_llm.mpi_world_size(), - gpus_per_node=tensorrt_llm.default_gpus_per_node(), - rank=tensorrt_llm.mpi_rank()) - else: - mapping = copy.deepcopy(executor_config.mapping) - mapping.rank = tensorrt_llm.mpi_rank() + if executor_config.max_num_tokens is None: + executor_config.max_num_tokens = 8192 if pytorch_backend_config.attn_backend in [ "FLASHINFER", "FLASHINFER_STAR_ATTENTION" @@ -172,8 +160,28 @@ def create_py_executor(executor_config: ExecutorConfig, ) executor_config.enable_chunked_context = False - if executor_config.max_num_tokens is None: - executor_config.max_num_tokens = 8192 + +def _get_mapping(executor_config: ExecutorConfig) -> Mapping: + if executor_config.mapping is None: + mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(), + tp_size=tensorrt_llm.mpi_world_size(), + gpus_per_node=tensorrt_llm.default_gpus_per_node(), + rank=tensorrt_llm.mpi_rank()) + else: + mapping = copy.deepcopy(executor_config.mapping) + mapping.rank = tensorrt_llm.mpi_rank() + return mapping + + +def create_py_executor(executor_config: ExecutorConfig, + checkpoint_dir: str = None, + engine_dir: str = None, + lora_config: Optional[LoraConfig] = None) -> PyExecutor: + _mangle_executor_config(executor_config) + pytorch_backend_config = executor_config.pytorch_backend_config + + mapping = _get_mapping(executor_config) + dist = MPIDist(mapping=mapping) spec_config = executor_config.speculative_config @@ -234,7 +242,7 @@ def create_py_executor(executor_config: ExecutorConfig, # PyTorchModelEngine modifies these fields, update them to executor_config max_seq_len = model_engine.max_seq_len - origin_seq_len = max_seq_len + net_max_seq_len = max_seq_len if not pytorch_backend_config.disable_overlap_scheduler: max_seq_len = model_engine.max_seq_len + 1 if 
spec_config is not None: @@ -294,26 +302,20 @@ def create_py_executor(executor_config: ExecutorConfig, sampler = instantiate_sampler(model_engine, executor_config, pytorch_backend_config, mapping) - kv_cache_manager = None - draft_kv_cache_manager = None resources = {} - origin_executor_config = copy.deepcopy(executor_config) estimating_kv_cache = False + kv_cache_creator = None if executor_config.pytorch_backend_config.use_kv_cache: - if 'cp_type' not in mapping.cp_config: - estimating_kv_cache = True - executor_config.kv_cache_config.max_tokens = get_token_num_for_estimation( - executor_config, model_engine.model.model_config) + kv_cache_creator = KvCacheCreator(executor_config=executor_config, + model_engine=model_engine, + draft_model_engine=draft_model_engine, + mapping=mapping, + net_max_seq_len=net_max_seq_len) + estimating_kv_cache = kv_cache_creator.try_prepare_estimation() with mem_monitor.observe_creation_stage( _ExecutorCreationStage.INIT_KV_CACHE if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE): - kv_cache_manager = create_kv_cache_manager(model_engine, mapping, - executor_config) - draft_kv_cache_manager = create_kv_cache_manager( - draft_model_engine, mapping, - executor_config) if draft_model_engine is not None else None - resources[KV_CACHE_MANAGER_KEY] = kv_cache_manager - resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager + kv_cache_creator.build_managers(resources) # resource managers for speculative decoding if spec_config is not None: @@ -330,15 +332,11 @@ def create_py_executor(executor_config: ExecutorConfig, ctx_chunk_config, model_engine, draft_model_engine, False, sampler, lora_config) - if executor_config.pytorch_backend_config.use_kv_cache and 'cp_type' not in mapping.cp_config: - kv_cache_max_tokens = estimate_max_kv_cache_tokens( - py_executor, model_engine, origin_executor_config, mapping, - origin_seq_len, ctx_chunk_config, draft_model_engine) + if estimating_kv_cache: + assert kv_cache_creator is not None + kv_cache_creator.estimate_max_tokens(py_executor) + kv_cache_creator.teardown_managers(resources) del py_executor # free before constructing new - del kv_cache_manager # free before constructing new - del resources[KV_CACHE_MANAGER_KEY] - - executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens with mem_monitor.observe_creation_stage( _ExecutorCreationStage.KV_CACHE): @@ -346,27 +344,15 @@ def create_py_executor(executor_config: ExecutorConfig, # create_kv_cache_manager above, which caps executor_config.max_seq_len. Restoring # the original value before creating the final KV cache. 
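Taken together, the helper class turns the previously inline estimation code into a small two-phase protocol. A simplified sketch of the flow (not the real code: `make_py_executor` is a hypothetical stand-in for `create_py_executor_instance`, and the memory-monitor stages and max_seq_len restoration are omitted):

```python
# Simplified sketch of the two-phase KV cache construction that create_py_executor now follows.
def build_kv_caches_with_estimation(executor_config, model_engine, draft_model_engine,
                                    mapping, net_max_seq_len, resources, make_py_executor):
    creator = KvCacheCreator(executor_config=executor_config,
                             model_engine=model_engine,
                             draft_model_engine=draft_model_engine,
                             mapping=mapping,
                             net_max_seq_len=net_max_seq_len)
    estimating = creator.try_prepare_estimation()  # caps kv_cache_config.max_tokens for the dummy pass
    creator.build_managers(resources)              # temporary (or final) KV cache managers
    py_executor = make_py_executor(resources)

    if estimating:
        creator.estimate_max_tokens(py_executor)   # profiles memory, sets kv_cache_config.max_tokens
        creator.teardown_managers(resources)       # shuts down and drops the temporary managers
        del py_executor                            # free before rebuilding
        creator.build_managers(resources)          # final managers sized from the estimate
        py_executor = make_py_executor(resources)
    return py_executor
```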
executor_config.max_seq_len = max_seq_len - kv_cache_manager = create_kv_cache_manager(model_engine, mapping, - executor_config) - resources[KV_CACHE_MANAGER_KEY] = kv_cache_manager - - if model_engine.attn_metadata is not None: - if pytorch_backend_config.use_cuda_graph: - model_engine._release_cuda_graphs() - del model_engine.attn_metadata - model_engine.attn_metadata = None - - if draft_model_engine is not None: - del draft_kv_cache_manager # free before constructing new - del resources[DRAFT_KV_CACHE_MANAGER_KEY] - draft_kv_cache_manager = create_kv_cache_manager( - draft_model_engine, mapping, executor_config) - resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager - if draft_model_engine.attn_metadata is not None: + kv_cache_creator.build_managers(resources) + + for eng in [model_engine, draft_model_engine]: + if eng is None: + continue + if eng.attn_metadata is not None: if pytorch_backend_config.use_cuda_graph: - draft_model_engine._release_cuda_graphs() - del draft_model_engine.attn_metadata - draft_model_engine.attn_metadata = None + eng._release_cuda_graphs() + eng.attn_metadata = None with mem_monitor.observe_creation_stage( _ExecutorCreationStage.EXTRA_RESOURCES):
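As a usage note, `build_managers` and `teardown_managers` treat the `resources` dict as the single registry for both caches. Roughly (a sketch of the contract, assuming the key constants imported from `model_engine` and an already constructed `kv_cache_creator`):

```python
# Sketch of the resources-dict contract; not runnable on its own.
resources = {}
kv_cache_creator.build_managers(resources)
# resources[KV_CACHE_MANAGER_KEY]       -> the main KVCacheManager
# resources[DRAFT_KV_CACHE_MANAGER_KEY] -> the draft manager, or None without a draft engine

kv_cache_creator.teardown_managers(resources)
# both keys removed; the main manager (and the draft one, if present) shut down
assert KV_CACHE_MANAGER_KEY not in resources and DRAFT_KV_CACHE_MANAGER_KEY not in resources
```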