10 changes: 8 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1114,13 +1114,18 @@ def _prepare_tp_inputs(
             new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]
             next_draft_tokens_device = new_tensors_device.next_draft_tokens  # [batch, draft_len]

-        # Requests with draft tokens are treated like extend requests.
+        # Requests with draft tokens are treated like extend requests. CUDA graph dummy extend
+        # requests should be at the end of extend_requests.
         extend_requests = []
+        extend_cuda_graph_dummy_requests = []
         generation_requests = []
         for request in scheduled_requests.generation_requests:
             if len(request.py_draft_tokens
                    ) > 0 or next_draft_tokens_device is not None:
-                extend_requests.append(request)
+                if request.is_cuda_graph_dummy:
+                    extend_cuda_graph_dummy_requests.append(request)
+                else:
+                    extend_requests.append(request)
             else:
                 generation_requests.append(request)
@@ -1130,6 +1135,7 @@ def _prepare_tp_inputs(
                     torch.tensor([mrope_position_deltas],
                                  dtype=torch.int32).to('cuda',
                                                        non_blocking=True))
+        extend_requests = extend_cuda_graph_dummy_requests + extend_requests

         if not self._disable_overlap_scheduler and self.is_spec_decode:
             spec_dec_mode = self.spec_config.spec_dec_mode
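Note: the change above partitions speculative ("extend") generation requests into real requests and CUDA graph dummy (padding) requests, then concatenates the two lists so the dummies occupy a known position in the batch. A minimal, self-contained sketch of that partition-and-concatenate step follows; the `Request` dataclass and the helper name are illustrative stand-ins for the executor's request objects, and only the `py_draft_tokens` and `is_cuda_graph_dummy` attributes are taken from the diff.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Request:
    # Mirrors the two attributes the hunk reads; everything else is omitted.
    py_draft_tokens: List[int] = field(default_factory=list)
    is_cuda_graph_dummy: bool = False


def split_extend_requests(generation_requests, have_device_draft_tokens=False):
    """Partition generation requests the way the hunk does: requests carrying
    draft tokens become 'extend' requests, and CUDA graph dummy (padding)
    requests are kept separate so they can be concatenated back at a fixed
    position afterwards."""
    extend_requests, extend_cuda_graph_dummy_requests, generation_only = [], [], []
    for request in generation_requests:
        if len(request.py_draft_tokens) > 0 or have_device_draft_tokens:
            if request.is_cuda_graph_dummy:
                extend_cuda_graph_dummy_requests.append(request)
            else:
                extend_requests.append(request)
        else:
            generation_only.append(request)
    # Same recombination as the added line in the second hunk.
    return extend_cuda_graph_dummy_requests + extend_requests, generation_only


if __name__ == "__main__":
    batch = [Request(py_draft_tokens=[1, 2]),
             Request(py_draft_tokens=[0], is_cuda_graph_dummy=True),
             Request()]
    extend, generation = split_extend_requests(batch)
    print([r.is_cuda_graph_dummy for r in extend])  # [True, False]
```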
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -263,6 +263,12 @@ def update_requests(self, state: SampleStateMTP) -> None:
             request.py_decoding_iter += 1
             idx += 1

+        # skip the results of cuda graph dummy requests
+        if idx == 0:
+            num_cuda_graph_dummy_requests = len(new_tokens_list) - len(
+                state.scheduled_requests.generation_requests)
+            idx += num_cuda_graph_dummy_requests
+
         for request in state.scheduled_requests.generation_requests:
             assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
             assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
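Note: this is the sampler-side counterpart of the padding change. When the batch is padded for CUDA graphs, the sampled-token list has more entries than there are real generation requests, and the added block jumps the read index past the leading entries that belong to dummy requests. A hedged sketch of that bookkeeping with plain Python lists (the function name and list layout below are illustrative assumptions, not the library's API):

```python
def read_generation_results(new_tokens_list, generation_requests, idx=0):
    """Skip the padded rows produced by CUDA graph dummy requests, then read
    one result row per real generation request. `idx` is the position already
    consumed by context-request results, as in the diff (0 means none)."""
    if idx == 0:
        # Extra rows in the padded output belong to dummy requests; jump past them.
        num_cuda_graph_dummy_requests = len(new_tokens_list) - len(generation_requests)
        idx += num_cuda_graph_dummy_requests
    results = []
    for request in generation_requests:
        results.append((request, new_tokens_list[idx]))
        idx += 1
    return results


# Example: two padded (dummy) rows followed by two real rows.
tokens = [[0], [0], [11, 12], [21, 22]]
print(read_generation_results(tokens, ["req_a", "req_b"]))
# [('req_a', [11, 12]), ('req_b', [21, 22])]
```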
43 changes: 40 additions & 3 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -568,22 +568,59 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
             task.evaluate(llm)

     @pytest.mark.skip_device_not_contain(["H100"])
-    def test_fp8_block_scales_cuda_graph_padding(self):
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         # OOM on H100 with default free_gpu_memory_fraction=0.9
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
         pytorch_config = PyTorchConfig(disable_overlap_scheduler=False,
                                        use_cuda_graph=True,
                                        cuda_graph_max_batch_size=512,
                                        cuda_graph_padding_enabled=True)
         llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                   kv_cache_config=kv_cache_config,
-                  pytorch_backend_config=pytorch_config)
+                  pytorch_backend_config=pytorch_config,
+                  speculative_config=mtp_config)
         assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
         with llm:
-            task = CnnDailymail(self.MODEL_NAME)
+            task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)

+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("attention_dp", [False, True])
+    def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
+                                                       attention_dp):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        pytorch_config = PyTorchConfig(
+            disable_overlap_scheduler=False,
+            use_cuda_graph=True,
+            cuda_graph_padding_enabled=True,
+        )
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
+                  tensor_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  pytorch_backend_config=pytorch_config,
+                  quant_config=quant_config,
+                  enable_attention_dp=attention_dp,
+                  speculative_config=mtp_config)
+        assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
     @pytest.mark.skip_device_not_contain(["H100", "H200"])
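Note: the new tests exercise CUDA graph padding together with MTP speculative decoding through the LLM API. Outside the test harness, the equivalent configuration is roughly the sketch below; the constructor keyword names are taken from the diff, while the import paths and the model path are assumptions that may differ between TensorRT-LLM versions.

```python
# Import paths assumed; the tests resolve these from the package under test.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig  # module path assumed

pytorch_config = PyTorchConfig(
    disable_overlap_scheduler=False,
    use_cuda_graph=True,
    cuda_graph_max_batch_size=512,    # pad captured batches up to this size
    cuda_graph_padding_enabled=True,  # enables the padding path exercised by this PR
)

llm = LLM(
    "<path-to>/DeepSeek-V3-Lite/fp8",  # placeholder model path
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.8),
    pytorch_backend_config=pytorch_config,
    speculative_config=MTPDecodingConfig(num_nextn_predict_layers=2),
)
```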
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -101,6 +101,8 @@ l0_dgx_h100:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=True]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=0]
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=2]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8]
3 changes: 2 additions & 1 deletion tests/integration/test_lists/test-db/l0_h100.yml
@@ -51,7 +51,8 @@ l0_h100:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
-- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
 - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
 - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]