
Commit 69846c6

[https://nvbugs/5427801][fix] Torch compile support for Llama4 and Ea… (#6978)
Signed-off-by: Jin Li <[email protected]>
1 parent df00c81 commit 69846c6

File tree

5 files changed: +74, -63 lines


tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 15 deletions
@@ -165,22 +165,9 @@ def _forward_nope(
         q, k, v = self.split_qkv(q, k, v)
         q = self._attention_scaling(q, position_ids)

-        out_scale = None
-        out_scale_sf = None
-        if self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4 or self.o_proj.has_fp8_block_scales:
-            out_scale = self.o_proj.inv_input_scale
-        if self.o_proj.has_nvfp4 and self.support_nvfp4_output:
-            out_scale_sf = self.o_proj.input_scale
-
         q, k, v = self.convert_qkv(q, k, v)
-        attn_output = self.attn.forward(q,
-                                        k,
-                                        v,
-                                        attn_metadata,
-                                        out_scale=out_scale,
-                                        out_scale_sf=out_scale_sf,
-                                        attention_mask=attention_mask,
-                                        mrope_config=mrope_config)
+        attn_output = self.forward_impl(q, k, v, attn_metadata, attention_mask,
+                                        None, None, mrope_config)

         if isinstance(attn_output, tuple):
             attn_output = Fp4QuantizedTensor(attn_output[0], attn_output[1])
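
Note on this hunk: the per-call-site out_scale / out_scale_sf selection and the direct self.attn.forward call in _forward_nope are replaced by a call to the new shared Attention.forward_impl helper (added in tensorrt_llm/_torch/modules/attention.py below), so the Llama4 NoPE path defers quantized-output handling to the common attention path instead of recomputing it locally. A minimal standalone sketch of this "pull duplicated scale plumbing into one helper" refactor, using toy placeholder classes rather than the TensorRT-LLM API:

# Hypothetical stand-ins (ToyProjection, ToyAttention) to illustrate the refactor;
# these are not TensorRT-LLM classes.
from typing import Optional


class ToyProjection:
    """Placeholder for o_proj: carries a quantization flag and its input scale."""

    def __init__(self, has_fp8: bool = False, inv_input_scale: Optional[float] = None):
        self.has_fp8 = has_fp8
        self.inv_input_scale = inv_input_scale


class ToyAttention:
    def __init__(self, o_proj: ToyProjection):
        self.o_proj = o_proj

    def _attn_kernel(self, q, out_scale: Optional[float]):
        # Stand-in for the backend attention call.
        return [x * (out_scale if out_scale is not None else 1.0) for x in q]

    def forward_impl(self, q):
        # After the refactor, output-scale selection lives in exactly one place.
        out_scale = self.o_proj.inv_input_scale if self.o_proj.has_fp8 else None
        return self._attn_kernel(q, out_scale)

    def forward_nope(self, q):
        # Call sites such as the NoPE path no longer duplicate the scale logic.
        return self.forward_impl(q)


attn = ToyAttention(ToyProjection(has_fp8=True, inv_input_scale=0.5))
print(attn.forward_nope([2.0, 4.0]))  # [1.0, 2.0]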

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 6 additions & 0 deletions
@@ -362,6 +362,12 @@ def __init__(self, model: TModel, model_config: ModelConfig[TConfig]):
             model_config,
             model_config.mapping)

+        if draft_config is not None:
+            for key, value in draft_config.extra_attrs.items():
+                assert key in ('attn_layers', 'mla_layers')
+                assert key in model_config.extra_attrs
+                model_config.extra_attrs[key].update(value)
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
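
This new block merges the draft model's per-layer attention registries into the target model's extra_attrs, so the combined target-plus-draft (Eagle3) model resolves layer lookups against a single mapping. A hedged, self-contained sketch of that merge-with-validation idiom, using plain dicts in place of ModelConfig.extra_attrs:

# Hedged sketch: plain dicts stand in for ModelConfig.extra_attrs; the keys mirror
# the ones asserted in the diff above.
target_extra_attrs = {"attn_layers": {0: "target_attn_layer_0"}, "mla_layers": {}}
draft_extra_attrs = {"attn_layers": {100: "draft_attn_layer_0"}}

for key, value in draft_extra_attrs.items():
    # Only the known per-layer registries may be merged, and the target config
    # must already contain a registry under that key.
    assert key in ("attn_layers", "mla_layers")
    assert key in target_extra_attrs
    target_extra_attrs[key].update(value)

print(target_extra_attrs["attn_layers"])  # target and draft layers in one mapping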

tensorrt_llm/_torch/modules/attention.py

Lines changed: 61 additions & 45 deletions
@@ -92,7 +92,7 @@ def attn_custom_op_inplace(
         mrope_position_deltas,
         attention_window_size,
         attention_mask_data,
-        False,
+        enable_attn_nvfp4_output=False,
         output=output)

@@ -372,6 +372,58 @@ def _attn_impl(
             return attn_output[0], attn_output[1]
         return attn_output, None

+    def forward_impl(
+        self,
+        q: torch.Tensor,
+        k: Optional[torch.Tensor],
+        v: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        attention_mask: AttentionMask,
+        attention_window_size: Optional[int],
+        attention_mask_data: Optional[torch.Tensor],
+        mrope_config: Optional[dict],
+    ):
+        mrope_rotary_cos_sin = None
+        mrope_position_deltas = None
+        if mrope_config is not None:
+            if "mrope_rotary_cos_sin" in mrope_config:
+                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
+            if "mrope_position_deltas" in mrope_config:
+                mrope_position_deltas = mrope_config["mrope_position_deltas"]
+
+        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
+        # Only enable custom inplace op when torch compiling.
+        use_custom_inplace_op = (self.register_to_config
+                                 and (self.attn_backend == "TRTLLM"
+                                      or self.attn_backend == "FLASHINFER")
+                                 and is_torch_compiling())
+
+        if use_custom_inplace_op:
+            output = self.create_output(q)
+            attn_custom_op_inplace(
+                q,
+                k,
+                v,
+                attention_mask,
+                mrope_rotary_cos_sin,
+                mrope_position_deltas,
+                attention_window_size,
+                attention_mask_data,
+                self.layer_idx_str,
+                output,
+            )
+        else:
+            output, output_sf = self._attn_impl(q, k, v, attn_metadata,
+                                                attention_mask,
+                                                mrope_rotary_cos_sin,
+                                                mrope_position_deltas,
+                                                attention_window_size,
+                                                attention_mask_data)
+            if output_sf is not None:
+                output = Fp4QuantizedTensor(output, output_sf)
+
+        return output
+
     def forward(
         self,
         position_ids: Optional[torch.IntTensor],

@@ -414,54 +466,18 @@ def forward(
         if qkv_lora is not None:
             qkv = qkv + qkv_lora

-        mrope_rotary_cos_sin = None
-        mrope_position_deltas = None
-        if mrope_config is not None:
-            if "mrope_rotary_cos_sin" in mrope_config:
-                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
-            if "mrope_position_deltas" in mrope_config:
-                mrope_position_deltas = mrope_config["mrope_position_deltas"]
-
-        output = None
-
         q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)

-        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
-        # Only enable custom inplace op when torch compiling.
-        use_custom_inplace_op = (self.register_to_config
-                                 and (self.attn_backend == "TRTLLM"
-                                      or self.attn_backend == "FLASHINFER")
-                                 and is_torch_compiling())
-        if use_custom_inplace_op:
-            output = self.create_output(q)
-            attn_custom_op_inplace(
-                q,
-                k,
-                v,
-                attention_mask,
-                mrope_rotary_cos_sin,
-                mrope_position_deltas,
-                attention_window_size,
-                attention_mask_data,
-                self.layer_idx_str,
-                output=output,
-            )
-        else:
-            output, output_sf = self._attn_impl(
-                q,
-                k,
-                v,
-                attn_metadata,
-                attention_mask,
-                mrope_rotary_cos_sin,
-                mrope_position_deltas,
-                attention_window_size,
-                attention_mask_data,
-            )
-            if output_sf is not None:
-                output = Fp4QuantizedTensor(output, output_sf)
+        output = self.forward_impl(q,
+                                   k,
+                                   v,
+                                   attn_metadata,
+                                   attention_mask,
+                                   attention_window_size,
+                                   attention_mask_data,
+                                   mrope_config=mrope_config)

         attn_output = self.o_proj(output,
                                   all_reduce_params=all_reduce_params,
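
The bulk of this file's change is mechanical: the mrope unpacking and the "use the registered in-place custom op only for TRTLLM/FLASHINFER while torch compiling" dispatch move out of forward into the new forward_impl, which both forward and the Llama4 NoPE path now share. Below is a self-contained sketch of that compile-aware dispatch pattern; the function and op names are placeholders, and torch.compiler.is_compiling() stands in for the repo's is_torch_compiling() helper (an assumption, not the repo's code):

# Sketch of the compile-aware dispatch that forward_impl centralizes.
# Assumptions: a recent PyTorch where torch.compile and torch.compiler.is_compiling()
# are available; fused_inplace_attention / eager_attention are placeholders, not the
# registered TensorRT-LLM custom op or _attn_impl.
import torch


def fused_inplace_attention(q: torch.Tensor, out: torch.Tensor) -> None:
    # Placeholder for the registered in-place custom op: writes into a preallocated buffer.
    out.copy_(torch.softmax(q, dim=-1))


def eager_attention(q: torch.Tensor) -> torch.Tensor:
    # Placeholder for the regular _attn_impl path, which allocates its own output.
    return torch.softmax(q, dim=-1)


def forward_impl(q: torch.Tensor, backend: str = "TRTLLM") -> torch.Tensor:
    # Take the in-place custom-op route only for compile-compatible backends and
    # only while a torch.compile trace is active.
    use_custom_inplace_op = (backend in ("TRTLLM", "FLASHINFER")
                             and torch.compiler.is_compiling())
    if use_custom_inplace_op:
        out = torch.empty_like(q)
        fused_inplace_attention(q, out)
        return out
    return eager_attention(q)


if __name__ == "__main__":
    x = torch.randn(2, 4)
    eager_out = forward_impl(x)                    # eager: regular _attn_impl-style path
    compiled_out = torch.compile(forward_impl)(x)  # compiled: in-place custom-op path
    assert torch.allclose(eager_out, compiled_out)

Keeping the in-place branch behind the compiling check leaves eager execution unchanged, while the compiled graph routes through the registered custom op with a caller-provided output buffer.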

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 5 additions & 2 deletions
@@ -550,11 +550,14 @@ def test_fp8_eagle3(self, tp_size, pp_size, ep_size, torch_compile):
                                           speculative_model_dir=eagle_model_dir)
         kv_cache_config = KvCacheConfig(enable_block_reuse=False,
                                         free_gpu_memory_fraction=0.75)
+        torch_compile_config = TorchCompileConfig(
+            enable_fullgraph=True,
+            enable_piecewise_cuda_graph=True,
+            max_num_streams=3) if torch_compile else None
         pytorch_config = dict(
             cuda_graph_config=CudaGraphConfig(max_batch_size=8),
             enable_attention_dp=False,
-            torch_compile_config=TorchCompileConfig(
-                enable_fullgraph=torch_compile))
+            torch_compile_config=torch_compile_config)
         with LLM(model_path,
                  kv_cache_config=kv_cache_config,
                  tensor_parallel_size=tp_size,
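
The test now builds a full TorchCompileConfig (full-graph compilation, piecewise CUDA graphs, multiple streams) only when the torch_compile parameter is set, and passes None otherwise, instead of always constructing a config with enable_fullgraph=torch_compile. A small sketch of the conditional-config idiom, using a placeholder dataclass rather than tensorrt_llm's TorchCompileConfig:

# Placeholder dataclass; tensorrt_llm's TorchCompileConfig has additional fields.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyTorchCompileConfig:
    enable_fullgraph: bool = False
    enable_piecewise_cuda_graph: bool = False
    max_num_streams: int = 1


def build_pytorch_config(torch_compile: bool) -> dict:
    # Build the compile config only when the parametrization enables it; otherwise
    # pass None so the runtime stays on the plain eager path.
    torch_compile_config: Optional[ToyTorchCompileConfig] = (
        ToyTorchCompileConfig(enable_fullgraph=True,
                              enable_piecewise_cuda_graph=True,
                              max_num_streams=3) if torch_compile else None)
    return dict(enable_attention_dp=False,
                torch_compile_config=torch_compile_config)


print(build_pytorch_config(True))
print(build_pytorch_config(False))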

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -260,7 +260,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5430124)
 examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5431132)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=1-ctx_pp=4] SKIP (https://nvbugs/5434320)
-accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5427801)
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
 accuracy/test_llm_api.py::TestLlama3_2_1B::test_int4_awq_int8_kv_cache SKIP (https://nvbugs/5433541)
 accuracy/test_llm_api.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugs/5433541)
