212 changes: 103 additions & 109 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -17,6 +17,7 @@
load_torch_hf_lora)
from tensorrt_llm.mapping import Mapping

from ..model_config import ModelConfig
from ..speculative import get_spec_decoder
from .config_utils import is_mla, is_nemotron_hybrid
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
@@ -98,6 +99,8 @@ def cal_max_tokens(peak_memory, total_gpu_memory, fraction, model_config,

def create_dummy_context_requests(max_num_tokens: int, max_seq_len: int,
vocab_size: int):
# NB: The requests constructed here need to be compatible with
# get_token_num_for_estimation.
requests = []
max_seq_len = min(max_num_tokens, max_seq_len)
remaining_tokens = max_num_tokens
@@ -117,26 +120,23 @@ def create_dummy_context_requests(max_num_tokens: int, max_seq_len: int,
return requests


def get_token_num_for_estimation(executor_config, model_config):
def get_token_num_for_estimation(executor_config: ExecutorConfig,
model_config: ModelConfig) -> int:
"""Compute KV cache capacity required for estimate_max_kv_cache_tokens to succeed."""
mapping = executor_config.mapping
if 'cp_type' not in mapping.cp_config:
end, _ = torch.cuda.mem_get_info()
fraction = get_fraction_from_executor_config(executor_config)
kv_size_per_token = get_cache_size_per_token(model_config, mapping)
max_tokens_limit = int(end * fraction // kv_size_per_token)
# When reusing KV cache blocks, we need to add extra tokens to account for partially filled blocks
# that cannot be reused. For each sequence of max_num_tokens length, we may need up to one extra
# block (tokens_per_block tokens) if the sequence length is not perfectly divisible by tokens_per_block.
# So we add math.ceil(max_num_tokens/max_seq_len) * tokens_per_block extra tokens.
return min(
max(
executor_config.max_batch_size, executor_config.max_num_tokens +
math.ceil(executor_config.max_num_tokens /
executor_config.max_seq_len) *
executor_config.tokens_per_block, executor_config.max_seq_len),
max_tokens_limit)
else:
return None
if 'cp_type' in mapping.cp_config:
raise ValueError(
"KV cache size estimation not supported with context parallelism.")
# When reusing KV cache blocks, we need to add extra tokens to account for partially filled blocks
# that cannot be reused. Each sequence used during estimation (cf. estimate_max_kv_cache_tokens)
# has at most max_seq_len (or max_seq_len - 1) tokens. In total, the sequences used for estimation
# have max_num_tokens tokens (cf. create_dummy_context_requests). For each sequence, we may need up to one extra
# block (tokens_per_block tokens) if the sequence length is not perfectly divisible by tokens_per_block.
# So we add math.ceil(max_num_tokens/(max_seq_len-1)) * tokens_per_block extra tokens.
return (executor_config.max_num_tokens +
math.ceil(executor_config.max_num_tokens /
(executor_config.max_seq_len - 1)) *
executor_config.tokens_per_block)
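
As a concrete check of the formula in the comment above, a minimal sketch with assumed values (max_num_tokens=8192, max_seq_len=4096, tokens_per_block=64; these numbers are illustrative, not taken from this PR), assuming each dummy estimation sequence holds at most max_seq_len - 1 tokens as the comment describes:

import math

# Illustrative values only; not from this PR or any particular config.
max_num_tokens = 8192
max_seq_len = 4096
tokens_per_block = 64

# At most ceil(8192 / 4095) = 3 dummy sequences are needed to cover max_num_tokens,
# and each may waste up to one partially filled block when block reuse is enabled.
extra_blocks = math.ceil(max_num_tokens / (max_seq_len - 1))
estimated_tokens = max_num_tokens + extra_blocks * tokens_per_block
print(estimated_tokens)  # 8192 + 3 * 64 = 8384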


def estimate_max_kv_cache_tokens(py_executor: PyExecutor,
@@ -223,97 +223,91 @@ def estimate_max_kv_cache_tokens(py_executor: PyExecutor,

def create_kv_cache_manager(model_engine: PyTorchModelEngine, mapping: Mapping,
executor_config: ExecutorConfig) -> KVCacheManager:
kv_cache_manager = None
if executor_config.pytorch_backend_config.use_kv_cache:
assert executor_config.pytorch_backend_config.use_kv_cache, "Only construct KV cache when it is needed."

config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
spec_config = executor_config.speculative_config

hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
num_key_value_heads = getattr(config, 'num_key_value_heads',
num_attention_heads)
head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads)

if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache():
kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
else:
kv_cache_dtype = str_dtype_to_binding(
torch_dtype_to_str(model_engine.dtype))

num_hidden_layers = config.num_hidden_layers

if is_mla(config):
kv_cache_manager = KVCacheManager(
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY,
num_layers=num_hidden_layers,
num_kv_heads=1,
head_dim=config.kv_lora_rank + config.qk_rope_head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
spec_config=spec_config,
)
elif is_nemotron_hybrid(config):
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
spec_config = executor_config.speculative_config

hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
num_key_value_heads = getattr(config, 'num_key_value_heads',
num_attention_heads)
head_dim = getattr(config, "head_dim",
hidden_size // num_attention_heads)

if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache(
):
kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
else:
kv_cache_dtype = str_dtype_to_binding(
torch_dtype_to_str(model_engine.dtype))

num_hidden_layers = config.num_hidden_layers

if is_mla(config):
kv_cache_manager = KVCacheManager(
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.
SELFKONLY,
num_layers=num_hidden_layers,
num_kv_heads=1,
head_dim=config.kv_lora_rank + config.qk_rope_head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
spec_config=spec_config,
)
elif is_nemotron_hybrid(config):
config = model_engine.model.model_config.pretrained_config
num_layers = config.hybrid_override_pattern.count("*")
layer_mask = [
char == "*" for char in config.hybrid_override_pattern
]
mamba_num_layers = config.hybrid_override_pattern.count("M")
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern
]
kv_cache_manager = MambaHybridCacheManager(
# mamba cache parameters
config.hidden_size,
config.ssm_state_size,
config.conv_kernel,
config.expand,
config.n_groups,
config.mamba_head_dim,
mamba_num_layers,
mamba_layer_mask,
config.torch_dtype,
# kv cache parameters
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
spec_config=spec_config,
)
else:
kv_cache_manager = KVCacheManager(
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_hidden_layers,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
spec_config=spec_config,
)
# KVCacheManager (non-draft) modifies the max_seq_len field; propagate the updated value back to executor_config
if model_engine.kv_cache_manager_key == KV_CACHE_MANAGER_KEY:
executor_config.max_seq_len = kv_cache_manager.max_seq_len

assert kv_cache_manager is not None
num_layers = config.hybrid_override_pattern.count("*")
layer_mask = [char == "*" for char in config.hybrid_override_pattern]
mamba_num_layers = config.hybrid_override_pattern.count("M")
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern
]
kv_cache_manager = MambaHybridCacheManager(
# mamba cache parameters
config.hidden_size,
config.ssm_state_size,
config.conv_kernel,
config.expand,
config.n_groups,
config.mamba_head_dim,
mamba_num_layers,
mamba_layer_mask,
config.torch_dtype,
# kv cache parameters
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
spec_config=spec_config,
)
else:
kv_cache_manager = KVCacheManager(
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_hidden_layers,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
spec_config=spec_config,
)
# KVCacheManager (non-draft) modifies the max_seq_len field; propagate the updated value back to executor_config
if model_engine.kv_cache_manager_key == KV_CACHE_MANAGER_KEY:
executor_config.max_seq_len = kv_cache_manager.max_seq_len

return kv_cache_manager
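
For the hybrid branch above, the attention/Mamba layer split is derived directly from hybrid_override_pattern; a small illustration with an assumed pattern (the string below is made up for the example, not taken from any real model config):

# Assumed example pattern: '*' marks attention layers, 'M' marks Mamba layers.
hybrid_override_pattern = "M*MM*M"

num_layers = hybrid_override_pattern.count("*")               # 2 attention layers get KV cache
layer_mask = [c == "*" for c in hybrid_override_pattern]      # [False, True, False, False, True, False]
mamba_num_layers = hybrid_override_pattern.count("M")         # 4 Mamba layers get state cache
mamba_layer_mask = [c == "M" for c in hybrid_override_pattern]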


6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -333,11 +333,16 @@ def create_py_executor(executor_config: ExecutorConfig,
origin_seq_len, ctx_chunk_config, draft_model_engine)
del py_executor # free before constructing new
del kv_cache_manager # free before constructing new
del resources[KV_CACHE_MANAGER_KEY]

executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens

with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.KV_CACHE):
# Before estimating the KV cache size, a minimal KV cache was allocated via
# create_kv_cache_manager above, which caps executor_config.max_seq_len. Restore
# the original value before creating the final KV cache.
executor_config.max_seq_len = max_seq_len
kv_cache_manager = create_kv_cache_manager(model_engine, mapping,
executor_config)
resources[KV_CACHE_MANAGER_KEY] = kv_cache_manager
@@ -350,6 +355,7 @@

if draft_model_engine is not None:
del draft_kv_cache_manager # free before constructing new
del resources[DRAFT_KV_CACHE_MANAGER_KEY]
draft_kv_cache_manager = create_kv_cache_manager(
draft_model_engine, mapping, executor_config)
resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager
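
Taken together with the _util.py changes, the intended flow is: build a minimal KV cache, estimate the real capacity with dummy requests, then drop the temporary manager and rebuild. A hypothetical helper sketching that rebuild step (names mirror the diff, but this function itself is not part of the PR):

def rebuild_kv_cache_after_estimation(model_engine, mapping, executor_config,
                                      resources, kv_cache_max_tokens,
                                      origin_max_seq_len):
    # Sketch only; assumes create_kv_cache_manager and KV_CACHE_MANAGER_KEY are
    # imported from the modules shown in this diff.
    del resources[KV_CACHE_MANAGER_KEY]  # release the temporary manager first
    executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens
    # The minimal cache creation capped executor_config.max_seq_len; restore the
    # originally requested value before building the final cache.
    executor_config.max_seq_len = origin_max_seq_len
    kv_cache_manager = create_kv_cache_manager(model_engine, mapping, executor_config)
    resources[KV_CACHE_MANAGER_KEY] = kv_cache_manager
    return kv_cache_manager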