18 changes: 9 additions & 9 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -848,15 +848,15 @@ def _executor_loop(self):
                     finished_requests = []
 
                 if scheduled_batch.batch_size > 0:
-                    self.resource_manager.prepare_resources(scheduled_batch)
-                    if self.draft_model_engine is not None:
-                        self._prepare_draft_tokens(scheduled_batch)
-
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
                         self._prepare_disagg_gen_transmission_complete(
                             scheduled_batch)
 
+                    self.resource_manager.prepare_resources(scheduled_batch)
+                    if self.draft_model_engine is not None:
+                        self._prepare_draft_tokens(scheduled_batch)
+
                     batch_outputs = self._forward_step(scheduled_batch)
 
                     sample_state = self._sample_async(scheduled_batch,
@@ -980,6 +980,11 @@ def _executor_loop_overlap(self):
                     self._pause_requests(scheduled_batch.paused_requests)
 
                 if scheduled_batch.batch_size > 0:
+                    if self.kv_cache_transceiver:
+                        # For generation requests which have completed KV cache transfer
+                        self._prepare_disagg_gen_transmission_complete(
+                            scheduled_batch)
+
                     self.resource_manager.prepare_resources(scheduled_batch)
 
                     generation_requests = scheduled_batch.generation_requests
@@ -999,11 +1004,6 @@
                             new_generation_requests.append(req)
                         scheduled_batch.generation_requests = new_generation_requests
 
-                    if self.kv_cache_transceiver:
-                        # For generation requests which have completed KV cache transfer
-                        self._prepare_disagg_gen_transmission_complete(
-                            scheduled_batch)
-
                     previous_tensors_device = self.previous_batch and self.previous_batch.sample_state.device
 
                     batch_outputs = self._forward_step(scheduled_batch,
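In both executor loops the change is the same: `_prepare_disagg_gen_transmission_complete` now runs before `self.resource_manager.prepare_resources` (and, in `_executor_loop`, before draft-token preparation), presumably so that generation requests whose KV-cache transfer has completed are folded into the batch before resources are allocated for it. A minimal sketch of the resulting ordering; `executor_iteration` is a hypothetical condensation, not the real loops, which live in `PyExecutor._executor_loop` and `_executor_loop_overlap`:

```python
# Sketch only: hypothetical condensation of the two executor loops,
# using the method names from the diff.
def executor_iteration(executor, scheduled_batch):
    if scheduled_batch.batch_size > 0:
        # 1) First, promote disaggregated-generation requests whose
        #    KV-cache transfer has completed into the batch...
        if executor.kv_cache_transceiver:
            executor._prepare_disagg_gen_transmission_complete(scheduled_batch)

        # 2) ...then prepare resources for the (now complete) batch.
        executor.resource_manager.prepare_resources(scheduled_batch)

        # 3) Draft-token preparation also follows the promotion
        #    (this step exists only in the non-overlap loop).
        if executor.draft_model_engine is not None:
            executor._prepare_draft_tokens(scheduled_batch)

        return executor._forward_step(scheduled_batch)
```

Under the old order, resources were prepared against a batch that did not yet include the transferred requests.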
6 changes: 2 additions & 4 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -259,8 +259,7 @@ def update_requests(self, state: SampleStateMTP) -> None:
             if self._draft_meet_max_token_stop_criteria(
                     request, num_tokens, beam_idx):
                 should_stop = True
-            if not should_stop:
-                request.py_draft_tokens = next_draft_tokens_list[idx]
+            request.py_draft_tokens = next_draft_tokens_list[idx]
             request.py_decoding_iter += 1
             idx += 1

@@ -282,8 +281,7 @@ def update_requests(self, state: SampleStateMTP) -> None:
             if self._draft_meet_max_token_stop_criteria(
                     request, num_tokens, beam_idx):
                 should_stop = True
-            if not should_stop:
-                request.py_draft_tokens = next_draft_tokens_list[idx]
+            request.py_draft_tokens = next_draft_tokens_list[idx]
             request.py_rewind_len = self.draft_len - (num_new_tokens - 1)
             request.py_decoding_iter += 1
             idx += 1
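The behavioral change here: `request.py_draft_tokens` is now assigned unconditionally rather than only when the max-token stop criterion has not fired, so a request that stops this step still carries its final draft tokens through the remaining bookkeeping. The rewind formula visible in the second hunk can be checked in isolation; reading `num_new_tokens` as one base token plus the accepted draft tokens is my assumption, not something the diff states:

```python
# Toy illustration of the rewind formula from the second hunk:
#     request.py_rewind_len = self.draft_len - (num_new_tokens - 1)
# Assumption (not from the diff): num_new_tokens = 1 base token plus the
# number of draft tokens accepted this step.
def rewind_len(draft_len: int, num_new_tokens: int) -> int:
    accepted_drafts = num_new_tokens - 1
    # Speculatively advanced positions that were not accepted get rewound.
    return draft_len - accepted_drafts

assert rewind_len(draft_len=2, num_new_tokens=1) == 2  # no drafts accepted
assert rewind_len(draft_len=2, num_new_tokens=3) == 0  # all drafts accepted
```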
53 changes: 52 additions & 1 deletion tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -20,7 +20,7 @@
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs
 
-from ..conftest import llm_models_root
+from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper
 from ..trt_test_alternative import popen
 from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness

@@ -252,3 +252,54 @@ def test_auto_dtype(self, overlap_scheduler):
             task.evaluate(llm)
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
+
+
+class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "deepseek-ai/DeepSeek-V3-Lite"
+    MODEL_PATH = f"{llm_models_root()}/DeepSeek-V3-Lite/bf16"
+
+    @parametrize_with_ids("overlap_scheduler", [True, False])
+    @parametrize_with_ids("mtp_nextn",
+                          [0, pytest.param(2, marks=skip_pre_hopper)])
+    def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
+        ctx_server_config = {
+            "pytorch_backend_config": {
+                "disable_overlap_scheduler": True
+            }
+        }
+        gen_server_config = {
+            "pytorch_backend_config": {
+                "disable_overlap_scheduler": not overlap_scheduler
+            }
+        }
+        if mtp_nextn > 0:
+            ctx_server_config["speculative_config"] = {
+                "decoding_type": "MTP",
+                "num_nextn_predict_layers": mtp_nextn
+            }
+            gen_server_config["speculative_config"] = {
+                "decoding_type": "MTP",
+                "num_nextn_predict_layers": mtp_nextn
+            }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config,
+                                      gen_server_config,
+                                      self.MODEL_PATH,
+                                      tensor_parallel_size=4) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
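`parametrize_with_ids` comes from the suite's conftest; judging from the test IDs registered in the YAML list below, it renders each parameter as `name=value` in the generated test ID. A hedged stand-in using plain pytest shows how such IDs arise (the helper body here is my approximation, not the project's implementation):

```python
import pytest

# Hypothetical approximation of the conftest helper: behaves like
# pytest.mark.parametrize, but with self-describing "name=value" IDs.
def parametrize_with_ids(name, values):
    return pytest.mark.parametrize(name, values, ids=lambda v: f"{name}={v}")

@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("mtp_nextn", [0, 2])
def test_auto_dtype(overlap_scheduler, mtp_nextn):
    # Collected as e.g. test_auto_dtype[mtp_nextn=2-overlap_scheduler=True]:
    # the decorator closest to the function contributes the leftmost ID part.
    pass
```

The generated IDs must match the test-db entries verbatim, which is what the next file updates.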
4 changes: 4 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h200.yml
@@ -19,6 +19,10 @@ l0_dgx_h200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
   - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
   - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=True]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True]
+  - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=False]
   - unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-disable_adp-enable_graph-tp8-trtllm-scout]
   - unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep4-enable_adp-enable_graph-tp8-trtllm-scout]
   - unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora