
Commit 75bdb6c

Merge branch 'main' into add_decoding_case
2 parents 219ddd2 + 8cf3faa commit 75bdb6c

25 files changed: +406 -183 lines changed


cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp

Lines changed: 10 additions & 5 deletions
@@ -477,7 +477,7 @@ void initConfigBindings(nb::module_& m)
            c.getExtendedRuntimePerfKnobConfig(), c.getDebugConfig(), c.getRecvPollPeriodMs(),
            c.getMaxSeqIdleMicroseconds(), c.getSpecDecConfig(), c.getGuidedDecodingConfig(),
            c.getAdditionalModelOutputs(), c.getCacheTransceiverConfig(), c.getGatherGenerationLogits(),
-           c.getPromptTableOffloading(), c.getEnableTrtOverlap());
+           c.getPromptTableOffloading(), c.getEnableTrtOverlap(), c.getFailFastOnAttentionWindowTooLarge());
        auto pickle_tuple = nb::make_tuple(cpp_states, nb::getattr(self, "__dict__"));
        return pickle_tuple;
    };
@@ -490,7 +490,7 @@ void initConfigBindings(nb::module_& m)
        }

        auto cpp_states = nb::cast<nb::tuple>(state[0]);
-       if (cpp_states.size() != 28)
+       if (cpp_states.size() != 29)
        {
            throw std::runtime_error("Invalid cpp_states!");
        }
@@ -525,7 +525,8 @@ void initConfigBindings(nb::module_& m)
            nb::cast<std::optional<tle::CacheTransceiverConfig>>(cpp_states[24]), // CacheTransceiverConfig
            nb::cast<bool>(cpp_states[25]), // GatherGenerationLogits
            nb::cast<bool>(cpp_states[26]), // PromptTableOffloading
-           nb::cast<bool>(cpp_states[27]) // EnableTrtOverlap
+           nb::cast<bool>(cpp_states[27]), // EnableTrtOverlap
+           nb::cast<bool>(cpp_states[28]) // FailFastOnAttentionWindowTooLarge
        );

        // Restore Python data
@@ -564,7 +565,8 @@ void initConfigBindings(nb::module_& m)
            std::optional<tle::CacheTransceiverConfig>, // CacheTransceiverConfig
            bool, // GatherGenerationLogits
            bool, // PromptTableOffloading
-           bool // EnableTrtOverlap
+           bool, // EnableTrtOverlap
+           bool // FailFastOnAttentionWindowTooLarge
        >(),
        nb::arg("max_beam_width") = 1, nb::arg("scheduler_config") = tle::SchedulerConfig(),
        nb::arg("kv_cache_config") = tle::KvCacheConfig(), nb::arg("enable_chunked_context") = false,
@@ -582,7 +584,7 @@ void initConfigBindings(nb::module_& m)
        nb::arg("spec_dec_config") = nb::none(), nb::arg("guided_decoding_config") = nb::none(),
        nb::arg("additional_model_outputs") = nb::none(), nb::arg("cache_transceiver_config") = nb::none(),
        nb::arg("gather_generation_logits") = false, nb::arg("mm_embedding_offloading") = false,
-       nb::arg("enable_trt_overlap") = false)
+       nb::arg("enable_trt_overlap") = false, nb::arg("fail_fast_on_attention_window_too_large") = false)
        .def_prop_rw("max_beam_width", &tle::ExecutorConfig::getMaxBeamWidth, &tle::ExecutorConfig::setMaxBeamWidth)
        .def_prop_rw("max_batch_size", &tle::ExecutorConfig::getMaxBatchSize, &tle::ExecutorConfig::setMaxBatchSize)
        .def_prop_rw("max_num_tokens", &tle::ExecutorConfig::getMaxNumTokens, &tle::ExecutorConfig::setMaxNumTokens)
@@ -632,6 +634,9 @@ void initConfigBindings(nb::module_& m)
            &tle::ExecutorConfig::setPromptTableOffloading)
        .def_prop_rw(
            "enable_trt_overlap", &tle::ExecutorConfig::getEnableTrtOverlap, &tle::ExecutorConfig::setEnableTrtOverlap)
+       .def_prop_rw("fail_fast_on_attention_window_too_large",
+           &tle::ExecutorConfig::getFailFastOnAttentionWindowTooLarge,
+           &tle::ExecutorConfig::setFailFastOnAttentionWindowTooLarge)
        .def("__getstate__", executorConfigGetState)
        .def("__setstate__", executorConfigSetState);
}
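With the binding rebuilt, the new flag is reachable from Python both as a constructor keyword and as a read/write property, and the 29-element pickle state keeps it across serialization. A minimal sketch, assuming the nanobind module is importable as tensorrt_llm.bindings.executor and that the remaining constructor arguments can stay at their defaults:

import pickle

from tensorrt_llm.bindings import executor as trtllm_executor

# Set the flag at construction time via the keyword argument registered above.
config = trtllm_executor.ExecutorConfig(fail_fast_on_attention_window_too_large=True)

# The .def_prop_rw binding also allows reading and toggling it after construction.
config.fail_fast_on_attention_window_too_large = False

# __getstate__/__setstate__ now round-trip 29 cpp_states entries, so the flag survives pickling.
restored = pickle.loads(pickle.dumps(config))
print(restored.fail_fast_on_attention_window_too_large)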

examples/llm-api/quickstart_advanced.py

Lines changed: 6 additions & 4 deletions
@@ -1,10 +1,10 @@
 import argparse

 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi import (CudaGraphConfig, DraftTargetDecodingConfig,
-                                 EagleDecodingConfig, KvCacheConfig, MoeConfig,
-                                 MTPDecodingConfig, NGramDecodingConfig,
-                                 TorchCompileConfig)
+from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
+                                 DraftTargetDecodingConfig, EagleDecodingConfig,
+                                 KvCacheConfig, MoeConfig, MTPDecodingConfig,
+                                 NGramDecodingConfig, TorchCompileConfig)

 example_prompts = [
     "Hello, my name is",
@@ -181,6 +181,8 @@ def setup_llm(args, **kwargs):
             is_use_oldest=True,
             is_public_pool=True,
         )
+    elif spec_decode_algo == "AUTO":
+        spec_config = AutoDecodingConfig()
     else:
         spec_config = None
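The new AUTO branch simply constructs an AutoDecodingConfig and lets the script hand it to the LLM the same way as the other speculative-decoding configs. A minimal sketch of that flow, assuming the config is forwarded through the LLM's speculative_config argument; the model path is a placeholder for whatever --model_dir would supply:

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import AutoDecodingConfig

# Placeholder model; quickstart_advanced.py resolves this from its CLI arguments.
llm = LLM(model="<path-or-hf-id>", speculative_config=AutoDecodingConfig())

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)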

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -61,3 +61,4 @@ etcd3
 blake3
 llguidance==0.7.29
 soundfile
+triton==3.3.1

tensorrt_llm/_torch/model_config.py

Lines changed: 10 additions & 43 deletions
@@ -299,48 +299,6 @@ def get_bindings_model_config(self,
         num_heads = self.pretrained_config.num_attention_heads // (
             self.mapping.tp_size * self.mapping.cp_size)

-        # Handle both uniform and per-layer KV heads
-        num_kv_heads_per_layer = getattr(self.pretrained_config,
-                                         'num_kv_heads_per_layer', None)
-        if num_kv_heads_per_layer is not None:
-            # For models with per-layer KV heads, like nemotron-nas
-            kv_heads_per_layer_raw = num_kv_heads_per_layer
-            use_per_layer_kv_heads = True
-        else:
-            # Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
-            num_kv_heads_raw = getattr(self.pretrained_config,
-                                       'num_key_value_heads', None)
-
-            if num_kv_heads_raw is not None and isinstance(
-                    num_kv_heads_raw, list):
-                # num_key_value_heads is a list - treat as per-layer KV heads
-                kv_heads_per_layer_raw = num_kv_heads_raw
-                use_per_layer_kv_heads = True
-            else:
-                # num_key_value_heads is scalar or None - treat as uniform KV heads
-                if num_kv_heads_raw is None:
-                    # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
-                    num_kv_heads_raw = getattr(
-                        self.pretrained_config, 'num_query_groups',
-                        self.pretrained_config.num_attention_heads)
-
-                num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
-                                                    self.mapping.cp_size)
-                use_per_layer_kv_heads = False
-
-        if use_per_layer_kv_heads:
-            # TRT-LLM LoRA requires uniform KV heads across layers
-            if self.lora_config is not None and len(
-                    set(kv_heads_per_layer_raw)) > 1:
-                raise ValueError(
-                    f"TRT-LLM LoRA requires uniform KV heads across layers, "
-                    f"got: {kv_heads_per_layer_raw}")
-            # Apply TP/CP scaling to each layer
-            num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
-                for kv_heads in kv_heads_per_layer_raw
-            ]
-
         hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size

         model_config_cpp = ModelConfigCpp(
@@ -361,9 +319,18 @@ def get_bindings_model_config(self,
         else:
             model_config_cpp.tokens_per_block = tokens_per_block

-        if use_per_layer_kv_heads:
+        num_key_value_heads = getattr(self.pretrained_config,
+                                      "num_key_value_heads", num_heads)
+        if isinstance(num_key_value_heads, (list, tuple)):
+            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
+            num_kv_heads_per_layer = [
+                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                for kv_heads in num_key_value_heads
+            ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
+            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
+                                                   self.mapping.cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)

         mlp_hidden_size = None
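The replacement logic keys everything off num_key_value_heads: a list or tuple is treated as per-layer KV heads, a scalar as uniform heads, and both are divided by tp_size * cp_size. A toy illustration of that scaling with invented numbers (not taken from the diff):

# Hypothetical variable-GQA config: four layers with different KV head counts, TP=2, CP=1.
num_key_value_heads = [8, 8, 4, 4]
tp_size, cp_size = 2, 1

# List case: each layer is scaled and assigned to model_config_cpp.num_kv_heads_per_layer.
num_kv_heads_per_layer = [kv // (tp_size * cp_size) for kv in num_key_value_heads]
assert num_kv_heads_per_layer == [4, 4, 2, 2]

# Scalar case: a plain int takes the else branch and feeds set_num_kv_heads().
num_kv_heads = 8 // (tp_size * cp_size)
assert num_kv_heads == 4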

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 8 additions & 10 deletions
@@ -451,18 +451,16 @@ def create_py_executor_instance(

        num_experts = _try_infer_num_experts(model_engine.model.model_config)

-       num_attn_layers = model_binding_config.num_attention_layers()
-       per_layer_kv_heads = [
-           model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
-       ]
-       num_kv_attention_heads = max(per_layer_kv_heads)
-       if len(set(per_layer_kv_heads)) > 1:
-           # NOTE: This code-path is currently untested and not validated. Can fail!
-           # This support is tracked in TRTLLM-6561
+       num_kv_attention_heads_per_layer = model_binding_config.num_kv_heads_per_layer
+       if max(num_kv_attention_heads_per_layer) != min(
+               num_kv_attention_heads_per_layer):
            logger.warning(
-               f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
-               "This code-path is currently untested and not validated. May fail!"
+               "Defining LORA with per-layer KV heads is not supported for LORA, using the max number of KV heads per layer"
            )
+           num_kv_attention_heads = max(num_kv_attention_heads_per_layer)
+       else:
+           # all layers have the same number of KV heads
+           num_kv_attention_heads = num_kv_attention_heads_per_layer[0]

        lora_modules = LoraModule.create_lora_modules(
            lora_module_names=lora_config.lora_target_modules,

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 14 additions & 0 deletions
@@ -477,3 +477,17 @@ def executor_request_to_llm_request(
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None))
     return llm_request
+
+
+def get_draft_token_length(request: LlmRequest) -> int:
+    """Get the length of draft tokens for a given request.
+
+    Args:
+        request: The LlmRequest to get draft token length for
+
+    Returns:
+        The number of draft tokens, or 0 if no draft tokens exist
+    """
+    if request.py_draft_tokens is not None:
+        return len(request.py_draft_tokens)
+    return 0
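The helper is a small convenience around py_draft_tokens for callers that previously had to guard against None themselves. A hypothetical call site (the tokens_needed_this_step wrapper and its use of the count are illustrative, not part of this commit):

from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest, get_draft_token_length


def tokens_needed_this_step(request: LlmRequest) -> int:
    # One slot for the regular next token plus one per speculative draft token, if any.
    return 1 + get_draft_token_length(request)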
