44 changes: 30 additions & 14 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1323,7 +1323,6 @@ def previous_seq_slots_device():

num_tokens = len(input_ids)
num_draft_tokens = len(draft_tokens)
-num_requests = len(request_ids)
total_num_tokens = len(position_ids)
assert total_num_tokens <= self.max_num_tokens, (
"total_num_tokens should be less than or equal to max_num_tokens")
@@ -1340,6 +1339,10 @@ def previous_seq_slots_device():
self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens,
non_blocking=True)
if next_draft_tokens_device is not None:
+# Initialize these two values to zeros
+self.previous_pos_id_offsets_cuda *= 0
+self.previous_kv_lens_offsets_cuda *= 0
+
if previous_batch_len > 0:
previous_slots = previous_seq_slots_device()
# previous input ids
@@ -1364,24 +1367,37 @@ def previous_seq_slots_device():
pin_memory=True)
self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_(
previous_pos_indices_host, non_blocking=True)

+# The order of requests in a batch: [context requests, generation requests]
+# generation requests: ['requests that do not have previous batch', 'requests that already have previous batch', 'dummy requests']
+# 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving.
+# 2) 'requests that already have previous batch': previous iteration's requests.
+# 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp.
+# Therefore, both of self.previous_pos_id_offsets_cuda and self.previous_kv_lens_offsets_cuda are also 3 segments.
+# For 1) 'requests that do not have previous batch': disable overlap scheduler or the first step in the generation server of disaggregated serving.
+# Set these requests' previous_pos_id_offsets and previous_kv_lens_offsets to '0' to skip the value changes in _preprocess_inputs.
+# Already set to '0' during initialization.
+# For 2) 'requests that already have previous batch': enable overlap scheduler.
+# Set their previous_pos_id_offsets and previous_kv_lens_offsets according to new_tokens_lens_device and kv_len_offsets_device.
+# For 3) 'dummy requests': pad dummy requests for CUDA graph or attention dp.
+# Already set to '0' during initialization.
+
+num_extend_reqeust_wo_dummy = len(extend_requests) - len(
+extend_dummy_requests)
self.previous_pos_id_offsets_cuda[
-0:previous_batch_tokens].copy_(
+(num_extend_reqeust_wo_dummy - previous_batch_len) *
+(1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
+(1 + self.max_draft_len)].copy_(
new_tokens_lens_device[self.previous_pos_indices_cuda[
0:previous_batch_tokens]],
non_blocking=True)
-self.previous_kv_lens_offsets_cuda[0:previous_batch_len].copy_(
-kv_len_offsets_device[previous_slots], non_blocking=True)
-# for the requests that do not have previous batch, set the previous_pos_id_offsets and
-# previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs
-self.previous_pos_id_offsets_cuda[
-previous_batch_tokens:num_requests *
-(1 + self.max_draft_len)] *= 0
+
self.previous_kv_lens_offsets_cuda[
-previous_batch_len:num_requests] *= 0
-else:
-# change the data to zeros to skip the value changes in _preprocess_inputs
-self.previous_pos_id_offsets_cuda *= 0
-self.previous_kv_lens_offsets_cuda *= 0
+num_extend_reqeust_wo_dummy -
+previous_batch_len:num_extend_reqeust_wo_dummy].copy_(
+kv_len_offsets_device[previous_slots],
+non_blocking=True)
+
elif new_tokens_device is not None:
seq_slots_device = previous_seq_slots_device()
max_draft_len = max(draft_lens)
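To make the indexing in the hunk above easier to follow, here is a minimal, self-contained sketch (not part of the PR) of how the new slice bounds pick out only segment 2, the requests that already have a previous batch, while segments 1 and 3 keep the zeros written by the initialization added earlier in the branch. The buffer sizes and filler values are illustrative, and the variable names only mirror those in the diff (num_extend_reqeust_wo_dummy, previous_batch_len, max_draft_len); nothing here is taken from the actual model_engine.py beyond those names.

import torch

# Illustrative sizes: 8 extend requests in total, of which the last 2 are dummies.
max_draft_len = 3                     # draft tokens per request
tokens_per_req = 1 + max_draft_len    # accepted token + draft tokens
num_extend_reqeust_wo_dummy = 6       # extend requests excluding dummy requests
previous_batch_len = 4                # requests that already have a previous batch

# Per-token buffer: one entry per (request, token) pair, like previous_pos_id_offsets_cuda.
previous_pos_id_offsets = torch.zeros(8 * tokens_per_req, dtype=torch.int32)
# Per-request buffer, like previous_kv_lens_offsets_cuda.
previous_kv_lens_offsets = torch.zeros(8, dtype=torch.int32)

# Segment 2 ends right where the dummy requests begin, so its slice is anchored at
# num_extend_reqeust_wo_dummy and extends previous_batch_len requests backwards.
tok_start = (num_extend_reqeust_wo_dummy - previous_batch_len) * tokens_per_req
tok_end = num_extend_reqeust_wo_dummy * tokens_per_req
req_start = num_extend_reqeust_wo_dummy - previous_batch_len
req_end = num_extend_reqeust_wo_dummy

previous_pos_id_offsets[tok_start:tok_end] = 1   # stand-in for new_tokens_lens_device values
previous_kv_lens_offsets[req_start:req_end] = 1  # stand-in for kv_len_offsets_device values

# Segments 1 (no previous batch) and 3 (dummy requests) stay zero, which is what
# lets the value changes be skipped in _preprocess_inputs.
print(previous_pos_id_offsets.view(8, tokens_per_req))
print(previous_kv_lens_offsets)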
4 changes: 0 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1022,10 +1022,6 @@ def _executor_loop_overlap(self):
)

if self.kv_cache_transceiver:
-# For generation requests which have completed KV cache transfer
-self._prepare_disagg_gen_transmission_complete(
-scheduled_batch)
-
# Return the first token to the client
self._handle_first_token_response(scheduled_batch)

2 changes: 0 additions & 2 deletions tests/integration/test_lists/waives.txt
@@ -371,8 +371,6 @@ perf/test_perf.py::test_perf[bert_large-bench-float16-maxbs:32-input_len:128+512
perf/test_perf.py::test_perf[roberta_base-bench-float16-maxbs:32-input_len:128+512] SKIP (https://nvbugspro.nvidia.com/bug/5295411)
disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5328160)
stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test] SKIP (https://nvbugs/5328495)
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
-accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
full:B200/examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5292737)
full:B200/accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype SKIP (https://nvbugs/5295470)
examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-summarization_long] SKIP (https://nvbugs/5324976)