
Commit f5abc63

Minor edit

Signed-off-by: Iman Tabrizian <[email protected]>

Fix accuracy for disagg + eagle

Signed-off-by: Iman Tabrizian <[email protected]>

1 parent 18335dd   commit f5abc63

File tree: 4 files changed, +72 -34 lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 9 deletions
@@ -875,13 +875,6 @@ def _executor_loop(self):
             scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
             )
 
-            if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
-                # REVIEW: This might need to be changed. The reason we call prepare_draft_requests
-                # on all active requests before scheduling is to make the scheduler aware of KV pages used
-                # by draft tokens.
-                self._prepare_draft_requests(
-                    fitting_disagg_gen_init_requests)
-
             if self.kv_cache_transceiver:
                 # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
                 self._prepare_disagg_gen_init(
@@ -978,7 +971,6 @@ def _prepare_draft_requests(self, requests):
         # Set draft tokens here to make the KV cache manager
         # and scheduler aware of them.
         for req in requests:
-            # TODO: enable draft tokens in context phase
             if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
                                  LlmRequestState.DISAGG_GENERATION_INIT):
                 continue
@@ -1541,7 +1533,8 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
                 ResourceManagerType.SEQ_SLOT_MANAGER,
                 ResourceManagerType.SPEC_RESOURCE_MANAGER,
                 ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
-            if resource_mgr_type in self.resource_manager.resource_managers:
+            if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
+                    resource_mgr_type] is not None:
                 self.resource_manager.resource_managers[
                     resource_mgr_type].prepare_resources(
                         disagg_gen_init_to_prepare)
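
For context on the third hunk: a resource-manager key can be present in self.resource_manager.resource_managers but map to None (for instance, the draft KV cache manager slot when no draft model is configured), so prepare_resources must only be called on managers that actually exist. A minimal sketch of the same guard, using dict.get purely for illustration; the helper name below is hypothetical and not part of the commit:

def prepare_disagg_gen_resources(resource_managers, manager_types, requests):
    """Call prepare_resources only on managers that are registered and not None."""
    for mgr_type in manager_types:
        manager = resource_managers.get(mgr_type)  # None if absent or explicitly unset
        if manager is not None:
            manager.prepare_resources(requests)

The membership test plus explicit `is not None` check used in the commit is behaviorally equivalent; dict.get simply folds the lookup and the presence test into one step.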

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 23 additions & 23 deletions
@@ -307,30 +307,30 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1):
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 new_token = new_tokens_list[token_idx]
                 num_tokens = request.add_new_token(new_token, beam_idx)
-                if self._handle_stop_criteria(request, new_token, num_tokens,
-                                              beam_idx):
-                    continue
-
-                # Accept draft tokens (if we have any) if and only if they match the new
-                # token exactly.
-                num_accepted = 0
-                new_tokens = [new_token]
-                for draft_token in request.py_draft_tokens:
-                    if draft_token != new_token:
-                        # Reject.
-                        break
-                    num_accepted += 1
-                    new_token = new_tokens_list[token_idx + num_accepted]
-                    num_tokens = request.add_new_token(new_token, beam_idx)
-                    new_tokens.append(num_tokens)  # `num_tokens`->`new_token`
-
-                    if self._handle_stop_criteria(request, new_token,
+                if not self._handle_stop_criteria(request, new_token,
                                                   num_tokens, beam_idx):
-                        break
-                handle_logits(request, new_tokens, num_accepted)
-                request.py_decoding_iter += 1
-                request.py_num_accepted_draft_tokens = num_accepted
-                request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
+
+                    # Accept draft tokens (if we have any) if and only if they match the new
+                    # token exactly.
+                    num_accepted = 0
+                    new_tokens = [new_token]
+                    for draft_token in request.py_draft_tokens:
+                        if draft_token != new_token:
+                            # Reject.
+                            break
+                        num_accepted += 1
+                        new_token = new_tokens_list[token_idx + num_accepted]
+                        num_tokens = request.add_new_token(new_token, beam_idx)
+                        new_tokens.append(
+                            num_tokens)  # `num_tokens`->`new_token`
+
+                        if self._handle_stop_criteria(request, new_token,
+                                                      num_tokens, beam_idx):
+                            break
+                    handle_logits(request, new_tokens, num_accepted)
+                    request.py_decoding_iter += 1
+                    request.py_num_accepted_draft_tokens = num_accepted
+                    request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
                 advance_idx(len(request.py_draft_tokens) + 1)
 
             for request in generation_requests:
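
The key behavioral change in this hunk is that draft-token acceptance and the per-request bookkeeping are now guarded by `if not self._handle_stop_criteria(...)` rather than skipped with `continue`, so `advance_idx(len(request.py_draft_tokens) + 1)` runs even when a stop criterion fires and the read position into new_tokens_list stays aligned for the remaining requests. A standalone sketch of that control flow with simplified names, an illustration of the logic only, not the sampler's actual interface:

def accept_draft_tokens(new_tokens_list, token_idx, draft_tokens, stopped):
    """Greedily accept draft tokens that match the sampled tokens.

    Returns (accepted_tokens, num_accepted, next_token_idx); `stopped` is a
    callable standing in for _handle_stop_criteria.
    """
    new_token = new_tokens_list[token_idx]
    accepted = [new_token]
    num_accepted = 0
    if not stopped(new_token):
        for draft_token in draft_tokens:
            if draft_token != new_token:
                break  # the first mismatch rejects this and all later drafts
            num_accepted += 1
            new_token = new_tokens_list[token_idx + num_accepted]
            accepted.append(new_token)
            if stopped(new_token):
                break
    # The index always advances past the whole draft window, mirroring the fix:
    # advance_idx(...) now sits outside the stop-criteria branch.
    return accepted, num_accepted, token_idx + len(draft_tokens) + 1

For example, with sampled tokens [7, 8, ...] and draft tokens [7, 9], the first draft token is accepted, two tokens are emitted, and next_token_idx still advances by 3 (one target token plus two draft slots).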

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 46 additions & 2 deletions
@@ -90,7 +90,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     # Common arguments for both servers
     common_args = [
         trtllm_serve_path, model_name, "--host", "localhost", "--backend",
-        "pytorch"
+        "pytorch", "--max_num_tokens", f"{13393*2}"
     ]
     if tensor_parallel_size > 1:
         common_args.append(f"--tp_size={tensor_parallel_size}")
@@ -209,9 +209,53 @@ def test_auto_dtype(self, disable_overlap_scheduler):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.parametrize("overlap_scheduler", [False])
+    def test_eagle3(self, overlap_scheduler):
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 4,
+            "pytorch_weights_path":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False
+        }
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": not overlap_scheduler,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
-@pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.timeout(3600)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
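
The ctx_server_config and gen_server_config dicts above are per-server overrides that launch_disaggregated_llm is expected to hand to the two trtllm-serve instances alongside common_args. A rough sketch of how that wiring could look, assuming the overrides are serialized to YAML and passed via --extra_llm_api_options; the helper below and the port assignments are illustrative, not the test harness's actual code:

import tempfile

import yaml


def build_server_cmd(common_args, server_config, port):
    # Persist the per-server overrides (speculative_config, kv_cache_config,
    # max_num_tokens, ...) to a file the server reads at startup.
    cfg = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
    yaml.dump(server_config, cfg)
    cfg.close()
    return common_args + [
        "--port", str(port), "--extra_llm_api_options", cfg.name
    ]


# Context server on 8001 and generation server on 8002, matching the URLs in
# disaggregated_server_config above:
# ctx_cmd = build_server_cmd(common_args, ctx_server_config, 8001)
# gen_cmd = build_server_cmd(common_args, gen_server_config, 8002)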

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ l0_h100:
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_llama_context_capacity[False-False-DeepSeek-V3-Lite-fp8/fp8]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]
   - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
   - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-non-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
   - condition:
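
The entry added above is a pytest node ID relative to tests/integration/defs, so the same case can be selected directly outside the CI test list, for instance via pytest.main (an illustrative invocation, not part of the commit):

import pytest

# Run the new disaggregated Eagle3 accuracy case by its node ID.
pytest.main([
    "accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]"
])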
