Commit 3544b10

lfr-0531k-l-lambda authored and committed
[https://nvbugs/5277592][fix] fix cuda graph padding for spec decoding (only for 0.20) (NVIDIA#5058)
Signed-off-by: Fanrong Li <[email protected]>
1 parent 9935911 commit 3544b10

File tree

5 files changed: +58 -6 lines changed


tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 8 additions & 2 deletions

@@ -1113,13 +1113,18 @@ def _prepare_tp_inputs(
         new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
         next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]
 
-        # Requests with draft tokens are treated like extend requests.
+        # Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
+        # requests should be at the end of extend_requests.
         extend_requests = []
+        extend_cuda_graph_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                extend_requests.append(request)
+                if request.is_cuda_graph_dummy:
+                    extend_cuda_graph_dummy_requests.append(request)
+                else:
+                    extend_requests.append(request)
             else:
                 generation_requests.append(request)
 
@@ -1129,6 +1134,7 @@ def _prepare_tp_inputs(
                     torch.tensor([mrope_position_deltas],
                                  dtype=torch.int32).to('cuda',
                                                        non_blocking=True))
+        extend_requests = extend_cuda_graph_dummy_requests + extend_requests
 
         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
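To see the effect of this hunk in isolation, here is a minimal sketch of the new request partitioning in plain Python. It does not use the actual TensorRT-LLM request classes: SimpleNamespace stands in for the real request objects, and passing a non-None next_draft_tokens_device mimics the overlap-scheduler path. The point it illustrates is that dummy CUDA graph requests are split out and then placed in front of the real extend requests, so their outputs occupy one contiguous block that the sampler-side change below can skip.

from types import SimpleNamespace


def split_generation_requests(generation_requests, next_draft_tokens_device):
    # Mirrors the hunk above: requests with draft tokens (or every request when
    # next_draft_tokens_device is set) are treated as "extend" requests, and
    # CUDA graph dummy requests are collected separately.
    extend_requests = []
    extend_cuda_graph_dummy_requests = []
    plain_generation_requests = []
    for request in generation_requests:
        if len(request.py_draft_tokens) > 0 or next_draft_tokens_device is not None:
            if request.is_cuda_graph_dummy:
                extend_cuda_graph_dummy_requests.append(request)
            else:
                extend_requests.append(request)
        else:
            plain_generation_requests.append(request)
    # As in the second hunk: the dummy padding requests are grouped together
    # ahead of the real extend requests.
    return extend_cuda_graph_dummy_requests + extend_requests, plain_generation_requests


requests = [
    SimpleNamespace(py_draft_tokens=[1, 2], is_cuda_graph_dummy=False),
    SimpleNamespace(py_draft_tokens=[], is_cuda_graph_dummy=True),  # padding slot
    SimpleNamespace(py_draft_tokens=[3], is_cuda_graph_dummy=False),
]
extend, generation = split_generation_requests(requests, next_draft_tokens_device=object())
assert [r.is_cuda_graph_dummy for r in extend] == [True, False, False]
assert generation == []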

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 6 additions & 0 deletions

@@ -281,6 +281,12 @@ def update_requests(self, state: SampleStateMTP) -> None:
                 request.py_decoding_iter += 1
             idx += 1
 
+        # skip the results of cuda graph dummy requests
+        if idx == 0:
+            num_cuda_graph_dummy_requests = len(new_tokens_list) - len(
+                state.scheduled_requests.generation_requests)
+            idx += num_cuda_graph_dummy_requests
+
         for request in state.scheduled_requests.generation_requests:
             assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
             assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
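The arithmetic behind this skip is easier to see outside the sampler. The sketch below is a simplification under the assumption that, for a padded CUDA graph batch with no context requests (the idx == 0 case), new_tokens_list carries one entry per padded slot with the dummy slots first, while scheduled_requests.generation_requests lists only the real requests; the helper name is made up for illustration.

def first_real_result_index(new_tokens_list, generation_requests):
    # Hypothetical helper mirroring the hunk above: every entry of
    # new_tokens_list beyond the real generation requests belongs to a CUDA
    # graph dummy request, so the real results start right after that block.
    num_cuda_graph_dummy_requests = len(new_tokens_list) - len(generation_requests)
    return num_cuda_graph_dummy_requests


# Batch padded from 2 real requests up to 4 CUDA graph slots; the first two
# result rows belong to the dummy requests and must be ignored.
new_tokens_list = [[0], [0], [42, 43], [44, 45]]
real_generation_requests = ["req_a", "req_b"]
assert first_real_result_index(new_tokens_list, real_generation_requests) == 2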

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 40 additions & 3 deletions

@@ -576,22 +576,59 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
             task.evaluate(llm)
 
     @pytest.mark.skip_device_not_contain(["H100"])
-    def test_fp8_block_scales_cuda_graph_padding(self):
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         # OOM on H100 with default free_gpu_memory_fraction=0.9
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
         pytorch_config = PyTorchConfig(disable_overlap_scheduler=False,
                                        use_cuda_graph=True,
                                        cuda_graph_max_batch_size=512,
                                        cuda_graph_padding_enabled=True)
         llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                   kv_cache_config=kv_cache_config,
-                  pytorch_backend_config=pytorch_config)
+                  pytorch_backend_config=pytorch_config,
+                  speculative_config=mtp_config)
         assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
         with llm:
-            task = CnnDailymail(self.MODEL_NAME)
+            task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("attention_dp", [False, True])
+    def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
+                                                       attention_dp):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        pytorch_config = PyTorchConfig(
+            disable_overlap_scheduler=False,
+            use_cuda_graph=True,
+            cuda_graph_padding_enabled=True,
+        )
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+                  tensor_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  pytorch_backend_config=pytorch_config,
+                  quant_config=quant_config,
+                  enable_attention_dp=attention_dp,
+                  speculative_config=mtp_config)
+        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+        with llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
 
     @pytest.mark.skip_less_device(4)
     @pytest.mark.skip_device_not_contain(["H100", "H200"])

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions

@@ -101,6 +101,8 @@ l0_dgx_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=0]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=2]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8]

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 2 additions & 1 deletion

@@ -51,7 +51,8 @@ l0_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
   - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
-  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
   - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]
