53 changes: 10 additions & 43 deletions tensorrt_llm/_torch/model_config.py
@@ -299,48 +299,6 @@ def get_bindings_model_config(self,
         num_heads = self.pretrained_config.num_attention_heads // (
             self.mapping.tp_size * self.mapping.cp_size)

-        # Handle both uniform and per-layer KV heads
-        num_kv_heads_per_layer = getattr(self.pretrained_config,
-                                         'num_kv_heads_per_layer', None)
-        if num_kv_heads_per_layer is not None:
-            # For models with per-layer KV heads, like nemotron-nas
-            kv_heads_per_layer_raw = num_kv_heads_per_layer
-            use_per_layer_kv_heads = True
-        else:
-            # Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
-            num_kv_heads_raw = getattr(self.pretrained_config,
-                                       'num_key_value_heads', None)
-
-            if num_kv_heads_raw is not None and isinstance(
-                    num_kv_heads_raw, list):
-                # num_key_value_heads is a list - treat as per-layer KV heads
-                kv_heads_per_layer_raw = num_kv_heads_raw
-                use_per_layer_kv_heads = True
-            else:
-                # num_key_value_heads is scalar or None - treat as uniform KV heads
-                if num_kv_heads_raw is None:
-                    # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
-                    num_kv_heads_raw = getattr(
-                        self.pretrained_config, 'num_query_groups',
-                        self.pretrained_config.num_attention_heads)
-
-                num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
-                                                    self.mapping.cp_size)
-                use_per_layer_kv_heads = False
-
-        if use_per_layer_kv_heads:
-            # TRT-LLM LoRA requires uniform KV heads across layers
-            if self.lora_config is not None and len(
-                    set(kv_heads_per_layer_raw)) > 1:
-                raise ValueError(
-                    f"TRT-LLM LoRA requires uniform KV heads across layers, "
-                    f"got: {kv_heads_per_layer_raw}")
-            # Apply TP/CP scaling to each layer
-            num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
-                for kv_heads in kv_heads_per_layer_raw
-            ]
-

         hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size

         model_config_cpp = ModelConfigCpp(
@@ -361,9 +319,18 @@ def get_bindings_model_config(self,
         else:
             model_config_cpp.tokens_per_block = tokens_per_block

-        if use_per_layer_kv_heads:
+        num_key_value_heads = getattr(self.pretrained_config,
+                                      "num_key_value_heads", num_heads)
+        if isinstance(num_key_value_heads, (list, tuple)):
+            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
+            num_kv_heads_per_layer = [
+                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                for kv_heads in num_key_value_heads
+            ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
+            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
+                                                   self.mapping.cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)

         mlp_hidden_size = None
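Reviewer note: a minimal standalone sketch of the head-count logic this hunk introduces, with the bindings plumbing stripped out. The helper name shard_kv_heads and the example values are hypothetical; only the list-vs-scalar branch and the division by tp_size * cp_size mirror the change above.

from typing import List, Union

def shard_kv_heads(num_key_value_heads: Union[int, List[int]], tp_size: int,
                   cp_size: int) -> Union[int, List[int]]:
    """Divide KV heads by the TP*CP width, per layer when a list is given."""
    denom = tp_size * cp_size
    if isinstance(num_key_value_heads, (list, tuple)):
        # Per-layer KV heads, e.g. Nemotron-NAS style variable GQA.
        return [kv_heads // denom for kv_heads in num_key_value_heads]
    # Uniform KV heads across all layers.
    return num_key_value_heads // denom

print(shard_kv_heads([8, 8, 4, 4], tp_size=2, cp_size=1))  # [4, 4, 2, 2]
print(shard_kv_heads(8, tp_size=2, cp_size=2))             # 2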
18 changes: 8 additions & 10 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -451,18 +451,16 @@ def create_py_executor_instance(

     num_experts = _try_infer_num_experts(model_engine.model.model_config)

-    num_attn_layers = model_binding_config.num_attention_layers()
-    per_layer_kv_heads = [
-        model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
-    ]
-    num_kv_attention_heads = max(per_layer_kv_heads)
-    if len(set(per_layer_kv_heads)) > 1:
-        # NOTE: This code-path is currently untested and not validated. Can fail!
-        # This support is tracked in TRTLLM-6561
+    num_kv_attention_heads_per_layer = model_binding_config.num_kv_heads_per_layer
+    if max(num_kv_attention_heads_per_layer) != min(
+            num_kv_attention_heads_per_layer):
         logger.warning(
-            f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
-            "This code-path is currently untested and not validated. May fail!"
+            "Per-layer KV heads are not supported for LoRA; using the max KV head count across layers."
         )
+        num_kv_attention_heads = max(num_kv_attention_heads_per_layer)
+    else:
+        # all layers have the same number of KV heads
+        num_kv_attention_heads = num_kv_attention_heads_per_layer[0]

     lora_modules = LoraModule.create_lora_modules(
         lora_module_names=lora_config.lora_target_modules,
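Reviewer note: a small illustration of the fallback above. When per-layer KV head counts differ, the LoRA setup is sized for the widest layer; otherwise every layer shares one count. The list literal and the print are illustrative values, not output from the bindings.

# Illustrative per-layer KV head counts; real values come from model_binding_config.
num_kv_heads_per_layer = [8, 8, 4, 4]

if max(num_kv_heads_per_layer) != min(num_kv_heads_per_layer):
    # Non-uniform GQA: size LoRA modules for the layer with the most KV heads.
    num_kv_attention_heads = max(num_kv_heads_per_layer)
else:
    # Uniform case: every layer shares the same count.
    num_kv_attention_heads = num_kv_heads_per_layer[0]

print(num_kv_attention_heads)  # 8 for the list above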
1 change: 0 additions & 1 deletion tests/unittest/llmapi/test_llm_pytorch.py
@@ -350,7 +350,6 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():

 # TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high
 # https://jirasw.nvidia.com/browse/TRTLLM-5045
-@pytest.mark.skip(reason="https://nvbugs/5401210")
 @skip_gpu_memory_less_than_138gb
 def test_nemotron_nas_lora() -> None:
     lora_config = LoraConfig(lora_dir=[
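Reviewer note: with the skip marker removed, the test is gated only by the GPU-memory guard. A rough sketch of such a guard is below, assuming a torch-based memory query; the actual skip_gpu_memory_less_than_138gb helper in the test suite may be implemented differently, and the names here are hypothetical.

import pytest
import torch

def _total_gpu_memory_gb() -> float:
    # Report the memory of GPU 0, or 0 when no GPU is visible.
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.get_device_properties(0).total_memory / (1 << 30)

skip_if_low_gpu_memory = pytest.mark.skipif(
    _total_gpu_memory_gb() < 138,
    reason="test needs at least 138 GB of GPU memory")

@skip_if_low_gpu_memory
def test_placeholder() -> None:
    assert True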