
Commit c083f6f

fix nemotron-nas lora test

Signed-off-by: Shahar Mor <[email protected]>
1 parent: fa9d22d

File tree: 3 files changed (+11, -59 lines)


tensorrt_llm/_torch/model_config.py

Lines changed: 10 additions & 44 deletions
@@ -298,49 +298,6 @@ def get_bindings_model_config(self,
         num_heads = self.pretrained_config.num_attention_heads // (
             self.mapping.tp_size * self.mapping.cp_size)
 
-        print("SMOR, in get_bindings_model_config")
-        from IPython import embed
-        embed()
-        # Handle both uniform and per-layer KV heads
-        num_kv_heads_per_layer = getattr(self.pretrained_config,
-                                         'num_kv_heads_per_layer', None)
-        if num_kv_heads_per_layer is not None:
-            kv_heads_per_layer_raw = num_kv_heads_per_layer
-            use_per_layer_kv_heads = True
-        else:
-            # Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
-            num_kv_heads_raw = getattr(self.pretrained_config,
-                                       'num_key_value_heads', None)
-
-            if num_kv_heads_raw is not None and isinstance(
-                    num_kv_heads_raw, list):
-                kv_heads_per_layer_raw = num_kv_heads_raw
-                use_per_layer_kv_heads = True
-            else:
-                # num_key_value_heads is scalar or None - treat as uniform KV heads
-                if num_kv_heads_raw is None:
-                    # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
-                    num_kv_heads_raw = getattr(
-                        self.pretrained_config, 'num_query_groups',
-                        self.pretrained_config.num_attention_heads)
-
-                num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
-                                                    self.mapping.cp_size)
-                use_per_layer_kv_heads = False
-
-        if use_per_layer_kv_heads:
-            # TRT-LLM LoRA requires uniform KV heads across layers
-            if self.lora_config is not None and len(
-                    set(kv_heads_per_layer_raw)) > 1:
-                raise ValueError(
-                    f"TRT-LLM LoRA requires uniform KV heads across layers, "
-                    f"got: {kv_heads_per_layer_raw}")
-            # Apply TP/CP scaling to each layer
-            num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
-                for kv_heads in kv_heads_per_layer_raw
-            ]
-
         hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
 
         model_config_cpp = ModelConfigCpp(
@@ -361,9 +318,18 @@ def get_bindings_model_config(self,
         else:
             model_config_cpp.tokens_per_block = tokens_per_block
 
-        if use_per_layer_kv_heads:
+        num_key_value_heads = getattr(self.pretrained_config,
+                                      "num_key_value_heads", num_heads)
+        if isinstance(num_key_value_heads, (list, tuple)):
+            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
+            num_kv_heads_per_layer = [
+                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                for kv_heads in num_key_value_heads
+            ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
+            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
+                                                   self.mapping.cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)
 
         mlp_hidden_size = None
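
Note on this hunk: the resolution it adds can be reduced to a small standalone sketch, shown below. DummyPretrainedConfig, DummyMapping, and resolve_kv_heads are illustrative stand-ins rather than TRT-LLM APIs; the point is only the branching, where a list-valued num_key_value_heads is scaled per layer and a scalar (or missing) value collapses to a single uniform KV-head count divided by tp_size * cp_size.

from dataclasses import dataclass
from typing import List, Union


@dataclass
class DummyMapping:
    tp_size: int = 2
    cp_size: int = 1


@dataclass
class DummyPretrainedConfig:
    num_attention_heads: int = 32
    # Scalar for uniform GQA, or a per-layer list for variable GQA (Nemotron-NAS style).
    num_key_value_heads: Union[int, List[int], None] = None


def resolve_kv_heads(cfg: DummyPretrainedConfig, mapping: DummyMapping):
    """Mirror the new hunk's branching: list -> per-layer, scalar -> uniform."""
    parallel = mapping.tp_size * mapping.cp_size
    kv = cfg.num_key_value_heads
    if kv is None:
        # No KV-head attribute: fall back to MHA, i.e. as many KV heads as attention heads.
        kv = cfg.num_attention_heads
    if isinstance(kv, (list, tuple)):
        # Per-layer KV heads: scale each layer's count by the TP*CP degree.
        return [k // parallel for k in kv]
    # Uniform KV heads: one scaled value used for every layer.
    return kv // parallel


print(resolve_kv_heads(DummyPretrainedConfig(num_key_value_heads=[8, 8, 4, 2]),
                       DummyMapping()))  # -> [4, 4, 2, 1]
print(resolve_kv_heads(DummyPretrainedConfig(num_key_value_heads=8),
                       DummyMapping()))  # -> 4

With these made-up numbers, a Nemotron-NAS-like config with num_key_value_heads=[8, 8, 4, 2] and tp_size=2 yields [4, 4, 2, 1], while a uniform model with num_key_value_heads=8 yields 4.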

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 1 addition & 14 deletions
@@ -451,25 +451,12 @@ def create_py_executor_instance(
 
     num_experts = _try_infer_num_experts(model_engine.model.model_config)
 
-    num_attn_layers = model_binding_config.num_attention_layers()
-    per_layer_kv_heads = [
-        model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
-    ]
-    num_kv_attention_heads = max(per_layer_kv_heads)
-    if len(set(per_layer_kv_heads)) > 1:
-        # NOTE: This code-path is currently untested and not validated. Can fail!
-        # This support is tracked in TRTLLM-6561
-        logger.warning(
-            f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
-            "This code-path is currently untested and not validated. May fail!"
-        )
-
     lora_modules = LoraModule.create_lora_modules(
         lora_module_names=lora_config.lora_target_modules,
         hidden_size=model_binding_config.hidden_size,
         mlp_hidden_size=model_binding_config.mlp_hidden_size,
         num_attention_heads=model_binding_config.num_heads,
-        num_kv_attention_heads=num_kv_attention_heads,
+        num_kv_attention_heads=model_binding_config.num_heads,
         attention_head_size=model_binding_config.head_size,
         tp_size=mapping.tp_size,
         num_experts=num_experts)
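
A hedged sketch of the behavioral difference in this file: before this commit, num_kv_attention_heads was the maximum KV-head count scanned over all attention layers (with a warning when layers differed); afterwards the call simply passes model_binding_config.num_heads through. FakeBindingConfig and its numbers below are invented stand-ins, not the real bindings object.

from dataclasses import dataclass
from typing import List


@dataclass
class FakeBindingConfig:
    # Stand-in for the C++ model binding config; values are invented for the demo.
    num_heads: int
    kv_heads_per_layer: List[int]

    def num_attention_layers(self) -> int:
        return len(self.kv_heads_per_layer)

    def num_kv_heads(self, layer_idx: int) -> int:
        return self.kv_heads_per_layer[layer_idx]


cfg = FakeBindingConfig(num_heads=32, kv_heads_per_layer=[8, 8, 4, 2])

# Removed behavior: per-layer scan, take the max, warn if the counts differ.
per_layer = [cfg.num_kv_heads(i) for i in range(cfg.num_attention_layers())]
old_num_kv_attention_heads = max(per_layer)  # 8

# New behavior: reuse the binding config's attention-head count directly.
new_num_kv_attention_heads = cfg.num_heads  # 32

print(old_num_kv_attention_heads, new_num_kv_attention_heads)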

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 0 additions & 1 deletion
@@ -290,7 +290,6 @@ def _check_contains_expected_message(stdout: str, stderr: str):
 
 # TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high
 # https://jirasw.nvidia.com/browse/TRTLLM-5045
-# @pytest.mark.skip(reason="https://nvbugs/5401210")
 @skip_gpu_memory_less_than_138gb
 def test_nemotron_nas_lora() -> None:
     lora_config = LoraConfig(lora_dir=[
