59 changes: 29 additions & 30 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1144,13 +1144,18 @@ def _prepare_tp_inputs(
new_tokens_lens_device = new_tensors_device.new_tokens_lens # [batch]
next_draft_tokens_device = new_tensors_device.next_draft_tokens # [batch, draft_len]

# Requests with draft tokens are treated like extend requests.
# Requests with draft tokens are treated like extend requests. Dummy extend requests should be
# at the end of extend_requests.
extend_requests = []
extend_dummy_requests = []
generation_requests = []
for request in scheduled_requests.generation_requests:
if len(request.py_draft_tokens
) > 0 or next_draft_tokens_device is not None:
extend_requests.append(request)
if request.is_dummy:
extend_dummy_requests.append(request)
else:
extend_requests.append(request)
else:
generation_requests.append(request)

@@ -1160,6 +1165,7 @@ def _prepare_tp_inputs(
torch.tensor([mrope_position_deltas],
dtype=torch.int32).to('cuda',
non_blocking=True))
extend_requests += extend_dummy_requests

if not self._disable_overlap_scheduler and self.is_spec_decode:
spec_dec_mode = self.spec_config.spec_dec_mode
@@ -1169,18 +1175,18 @@ def _prepare_tp_inputs(
# will contain previous batch indices of generation requests
previous_batch_indices = []
previous_pos_indices = []
request_ids_with_previous_batch = []
num_extend_reqs_wo_previous_batch = 0
for request in extend_requests:
# the request has no previous tensor:
# (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
# (2) a dummy request; or
# (3) the first step in the generation server of disaggregated serving
if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
# get token ids, including input token ids and draft token ids
input_ids.append(request.get_last_tokens(0))
input_ids.extend(request.py_draft_tokens)
draft_tokens.extend(request.py_draft_tokens)
# get token ids, including input token ids and draft token ids. For these dummy requests,
# no need to copy the token ids.
if not request.is_dummy:
input_ids.append(request.get_last_tokens(0))
input_ids.extend(request.py_draft_tokens)
draft_tokens.extend(request.py_draft_tokens)
# get other ids and lengths
num_draft_tokens = len(request.py_draft_tokens)
past_seen_token_num = request.max_beam_num_tokens - 1
@@ -1200,7 +1206,6 @@
# update batch index
request.py_batch_idx = batch_idx
batch_idx += 1
num_extend_reqs_wo_previous_batch += 1
else:
# update batch index
previous_batch_idx = request.py_batch_idx
@@ -1227,10 +1232,7 @@
num_cached_tokens_per_seq.append(past_seen_token_num +
self.max_draft_len + 1)
prompt_lengths.append(request.py_prompt_len)
request_ids_with_previous_batch.append(request.py_request_id)

# move requests with previous batch to the end of the list
request_ids.extend(request_ids_with_previous_batch)
request_ids.append(request.py_request_id)

sequence_lengths.extend([1] * len(generation_requests))
gather_ids.extend(
@@ -1265,6 +1267,7 @@
num_tokens = len(input_ids)
num_draft_tokens = len(draft_tokens)
previous_batchs = len(previous_batch_indices)
num_requests = len(request_ids)
# if exist requests that do not have previous batch, copy input_ids and draft_tokens
if num_tokens > 0:
input_ids = torch.tensor(input_ids,
@@ -1303,31 +1306,27 @@
non_blocking=True)
# prepare data for the preprocess inputs
kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
pre_tokens_start_idx = num_extend_reqs_wo_previous_batch * (
1 + self.max_draft_len)
pre_tokens_end_idx = pre_tokens_start_idx + previous_batch_tokens
pre_batch_start_idx = num_extend_reqs_wo_previous_batch
pre_batch_end_idx = pre_batch_start_idx + previous_batchs
previous_pos_indices = torch.tensor(previous_pos_indices,
dtype=torch.int,
pin_memory=True)
self.previous_pos_indices_cuda[
pre_tokens_start_idx:pre_tokens_end_idx].copy_(
previous_pos_indices, non_blocking=True)
self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
previous_pos_indices, non_blocking=True)
self.previous_pos_id_offsets_cuda[
pre_tokens_start_idx:pre_tokens_end_idx].copy_(
0:previous_batch_tokens].copy_(
new_tokens_lens_device[self.previous_pos_indices_cuda[
pre_tokens_start_idx:pre_tokens_end_idx]],
non_blocking=True)
self.previous_kv_lens_offsets_cuda[
pre_batch_start_idx:pre_batch_end_idx].copy_(
kv_len_offsets_device[
self.previous_batch_indices_cuda[:previous_batchs]],
0:previous_batch_tokens]],
non_blocking=True)
self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_(
kv_len_offsets_device[
self.previous_batch_indices_cuda[:previous_batchs]],
non_blocking=True)
# for the requests that do not have previous batch, set the previous_pos_id_offsets and
# previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
self.previous_pos_id_offsets_cuda[:pre_tokens_start_idx] *= 0
self.previous_kv_lens_offsets_cuda[:pre_batch_start_idx] *= 0
self.previous_pos_id_offsets_cuda[
previous_batch_tokens:num_requests *
(1 + self.max_draft_len)] *= 0
self.previous_kv_lens_offsets_cuda[
previous_batchs:num_requests] *= 0
else:
# change the data to zeros to skip the value changes in _preprocess_inputs
self.previous_pos_id_offsets_cuda *= 0
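The model_engine.py hunks above change how generation requests are bucketed in `_prepare_tp_inputs`: requests with draft tokens are still treated as extend requests, but dummy extend requests are now collected separately and appended at the end of `extend_requests`. With dummy requests moved to the tail, the later hunks can copy into `[0:previous_batch_tokens]` and zero the remainder up to `num_requests * (1 + self.max_draft_len)`, rather than computing prefix offsets from `num_extend_reqs_wo_previous_batch`. Below is a minimal sketch of the bucketing step, assuming a simplified `Request` stand-in and omitting the `next_draft_tokens_device` condition used by the real code.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Request:
    """Illustrative stand-in for a scheduled generation request; the real
    code uses LlmRequest objects carrying much more state."""
    request_id: int
    py_draft_tokens: List[int] = field(default_factory=list)
    is_dummy: bool = False


def bucket_generation_requests(
        requests: List[Request]) -> Tuple[List[Request], List[Request]]:
    """Split generation requests the way the patched loop does: requests
    with draft tokens become extend requests, dummy extend requests are
    kept apart and appended last, the rest stay generation requests."""
    extend_requests: List[Request] = []
    extend_dummy_requests: List[Request] = []
    generation_requests: List[Request] = []
    for request in requests:
        if len(request.py_draft_tokens) > 0:
            if request.is_dummy:
                extend_dummy_requests.append(request)
            else:
                extend_requests.append(request)
        else:
            generation_requests.append(request)
    # Dummy extend requests go to the end of extend_requests.
    extend_requests += extend_dummy_requests
    return extend_requests, generation_requests


if __name__ == "__main__":
    batch = [
        Request(0, py_draft_tokens=[11, 12]),
        Request(1),  # plain generation request, no draft tokens
        Request(2, py_draft_tokens=[13, 14], is_dummy=True),
        Request(3, py_draft_tokens=[15, 16]),
    ]
    extend, generation = bucket_generation_requests(batch)
    print([r.request_id for r in extend])      # [0, 3, 2]
    print([r.request_id for r in generation])  # [1]
```
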
41 changes: 39 additions & 2 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -588,8 +588,12 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
task.evaluate(llm)

@pytest.mark.skip_device_not_contain(["H100"])
def test_fp8_block_scales_cuda_graph_padding(self):
@parametrize_with_ids("mtp_nextn", [0, 2])
def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
mtp_config = None
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
pytorch_config = dict(
disable_overlap_scheduler=False,
use_cuda_graph=True,
@@ -598,7 +602,40 @@ def test_fp8_block_scales_cuda_graph_padding(self):
)
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
kv_cache_config=kv_cache_config,
**pytorch_config)
**pytorch_config,
speculative_config=mtp_config)
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
with llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["H100", "H200"])
@parametrize_with_ids("mtp_nextn", [0, 2])
@parametrize_with_ids("attention_dp", [False, True])
def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
attention_dp):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
mtp_config = None
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
pytorch_config = dict(
disable_overlap_scheduler=False,
use_cuda_graph=True,
cuda_graph_padding_enabled=True,
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES

llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
tensor_parallel_size=4,
kv_cache_config=kv_cache_config,
**pytorch_config,
quant_config=quant_config,
enable_attention_dp=attention_dp,
speculative_config=mtp_config)
assert llm.args.quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES
with llm:
task = MMLU(self.MODEL_NAME)
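The yml entries below pick up the new tests by their parametrized IDs, e.g. `test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]` and `test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=2]`. Assuming `parametrize_with_ids` is a thin wrapper around `pytest.mark.parametrize` that renders each ID as `name=value` (an assumption about this test helper, not something shown in the diff), the ID scheme can be reproduced roughly like this:

```python
import pytest


def parametrize_with_ids(argname, values):
    """Hypothetical minimal re-implementation: parametrize a test while
    formatting each ID as 'argname=value', matching entries such as
    'mtp_nextn=0' in the test-db yml files."""
    return pytest.mark.parametrize(argname, values,
                                   ids=[f"{argname}={v}" for v in values])


@parametrize_with_ids("mtp_nextn", [0, 2])
def test_fp8_block_scales_cuda_graph_padding(mtp_nextn):
    # Collected as test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
    # and test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2].
    assert mtp_nextn in (0, 2)
```

Stacking two such decorators, as the 4-GPU test does, makes pytest combine the IDs into forms like `attention_dp=True-mtp_nextn=0`, which is the form referenced by the l0_dgx_h100.yml entries below.
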
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -100,6 +100,8 @@ l0_dgx_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=0]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=2]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
3 changes: 2 additions & 1 deletion tests/integration/test_lists/test-db/l0_h100.yml
@@ -56,7 +56,8 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=0]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding[mtp_nextn=2]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu[DeepSeek-V3-Lite-fp8]