
Commit da59cf7

Minor edit

Signed-off-by: Iman Tabrizian <[email protected]>

Fix accuracy for disagg + eagle

Signed-off-by: Iman Tabrizian <[email protected]>

1 parent 18335dd

4 files changed, +90 -46 lines

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 10 deletions
@@ -869,19 +869,13 @@ def _executor_loop(self):
 
                 self._pad_attention_dp_dummy_request()
 
-                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
+                if self.draft_model_engine is not None or is_ngram or hasattr(
+                        self, 'drafter') and self.drafter is not None:
                     self._prepare_draft_requests(self.active_requests)
 
                 scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
                 )
 
-                if self.draft_model_engine is not None or is_ngram or self.drafter is not None:
-                    # REVIEW: This might need to be changed. The reason we call prepare_draft_requests
-                    # on all active requests before scheduling is to make the scheduler aware of KV pages used
-                    # by draft tokens.
-                    self._prepare_draft_requests(
-                        fitting_disagg_gen_init_requests)
-
                 if self.kv_cache_transceiver:
                     # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
                     self._prepare_disagg_gen_init(

@@ -978,7 +972,6 @@ def _prepare_draft_requests(self, requests):
         # Set draft tokens here to make the KV cache manager
         # and scheduler aware of them.
         for req in requests:
-            # TODO: enable draft tokens in context phase
             if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
                                  LlmRequestState.DISAGG_GENERATION_INIT):
                 continue

@@ -1541,7 +1534,8 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
                 ResourceManagerType.SEQ_SLOT_MANAGER,
                 ResourceManagerType.SPEC_RESOURCE_MANAGER,
                 ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
-            if resource_mgr_type in self.resource_manager.resource_managers:
+            if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
+                    resource_mgr_type] is not None:
                 self.resource_manager.resource_managers[
                     resource_mgr_type].prepare_resources(
                         disagg_gen_init_to_prepare)
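
Note on the new guard in _executor_loop: because "and" binds tighter than "or", the added clause groups as hasattr(self, 'drafter') and self.drafter is not None, so executors that never set a drafter attribute simply skip draft preparation instead of raising AttributeError. A standalone sketch of that behaviour (the _Executor stub is hypothetical, not the real PyExecutor):

# Standalone illustration of why the new guard is safe on executors that never
# set a drafter attribute; _Executor is a hypothetical stand-in, not PyExecutor.
class _Executor:
    draft_model_engine = None


ex = _Executor()  # note: no 'drafter' attribute at all
is_ngram = False

# 'and' binds tighter than 'or', so the expression groups as:
#   (draft_model_engine is not None) or is_ngram
#       or (hasattr(ex, 'drafter') and ex.drafter is not None)
# hasattr() short-circuits the 'and', so ex.drafter is never read when the
# attribute is missing and no AttributeError is raised.
needs_draft_prep = ex.draft_model_engine is not None or is_ngram or hasattr(
    ex, 'drafter') and ex.drafter is not None

print(needs_draft_prep)  # False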

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 23 additions & 23 deletions
@@ -307,30 +307,30 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1):
                 if request.state != LlmRequestState.GENERATION_COMPLETE:
                     new_token = new_tokens_list[token_idx]
                     num_tokens = request.add_new_token(new_token, beam_idx)
-                    if self._handle_stop_criteria(request, new_token, num_tokens,
-                                                  beam_idx):
-                        continue
-
-                    # Accept draft tokens (if we have any) if and only if they match the new
-                    # token exactly.
-                    num_accepted = 0
-                    new_tokens = [new_token]
-                    for draft_token in request.py_draft_tokens:
-                        if draft_token != new_token:
-                            # Reject.
-                            break
-                        num_accepted += 1
-                        new_token = new_tokens_list[token_idx + num_accepted]
-                        num_tokens = request.add_new_token(new_token, beam_idx)
-                        new_tokens.append(num_tokens)  # `num_tokens`->`new_token`
-
-                        if self._handle_stop_criteria(request, new_token,
+                    if not self._handle_stop_criteria(request, new_token,
                                                       num_tokens, beam_idx):
-                            break
-                    handle_logits(request, new_tokens, num_accepted)
-                    request.py_decoding_iter += 1
-                    request.py_num_accepted_draft_tokens = num_accepted
-                    request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
+
+                        # Accept draft tokens (if we have any) if and only if they match the new
+                        # token exactly.
+                        num_accepted = 0
+                        new_tokens = [new_token]
+                        for draft_token in request.py_draft_tokens:
+                            if draft_token != new_token:
+                                # Reject.
+                                break
+                            num_accepted += 1
+                            new_token = new_tokens_list[token_idx + num_accepted]
+                            num_tokens = request.add_new_token(new_token, beam_idx)
+                            new_tokens.append(
+                                num_tokens)  # `num_tokens`->`new_token`
+
+                            if self._handle_stop_criteria(request, new_token,
+                                                          num_tokens, beam_idx):
+                                break
+                        handle_logits(request, new_tokens, num_accepted)
+                        request.py_decoding_iter += 1
+                        request.py_num_accepted_draft_tokens = num_accepted
+                        request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
                     advance_idx(len(request.py_draft_tokens) + 1)
 
             for request in generation_requests:
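
Note on the sampler change: the early "continue" is replaced by an "if not self._handle_stop_criteria(...)" guard, so draft-token acceptance and the bookkeeping fields (py_decoding_iter, py_num_accepted_draft_tokens, py_rewind_len) are only updated when the request has not already hit a stop condition, and stop criteria are now also checked after each accepted draft token. The acceptance rule itself is unchanged: draft tokens are accepted greedily while each one matches the token the target model actually sampled. A minimal, self-contained sketch of just that rule follows (accept_draft_tokens is an illustrative helper, not the sampler's API; it appends the token itself, which is what the inline "num_tokens -> new_token" note in the diff points at):

# Illustrative, self-contained sketch of the greedy acceptance rule above;
# accept_draft_tokens is a hypothetical helper, not part of the sampler API.
def accept_draft_tokens(draft_tokens: list[int],
                        target_tokens: list[int]) -> tuple[list[int], int]:
    """Accept draft tokens while they match the target model's sampled tokens.

    target_tokens[0] is the token sampled at the current position and
    target_tokens[i] the token sampled after the first i draft tokens,
    mirroring new_tokens_list[token_idx + i] in the diff above. Returns the
    emitted tokens and the number of accepted draft tokens.
    """
    emitted = [target_tokens[0]]
    num_accepted = 0
    for draft_token in draft_tokens:
        if draft_token != emitted[-1]:
            # First mismatch rejects this and every remaining draft token.
            break
        num_accepted += 1
        emitted.append(target_tokens[num_accepted])
    return emitted, num_accepted


# Example: four draft tokens, the third one mismatches, so two are accepted
# and three tokens are emitted this step (2 accepted drafts + 1 correction).
tokens, accepted = accept_draft_tokens(draft_tokens=[11, 22, 34, 45],
                                       target_tokens=[11, 22, 33, 44, 55])
print(tokens, accepted)  # [11, 22, 33] 2

The py_rewind_len = py_draft_pages_allocated - num_accepted assignment in the diff then lets the executor rewind the KV cache pages that were reserved for the rejected draft tokens.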

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 62 additions & 13 deletions
@@ -92,6 +92,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         trtllm_serve_path, model_name, "--host", "localhost", "--backend",
         "pytorch"
     ]
+
     if tensor_parallel_size > 1:
         common_args.append(f"--tp_size={tensor_parallel_size}")
 

@@ -104,18 +105,22 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
-
-    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
-          popen(common_args + [
-              "--port", "8001", "--extra_llm_api_options",
-              ctx_server_config_path
-          ],
-                env=env_ctx) as ctx_server,
-          popen(common_args + [
-              "--port", "8002", "--extra_llm_api_options",
-              gen_server_config_path
-          ],
-                env=env_gen) as gen_server,
+    ctx_server_args = common_args + [
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+    ]
+    gen_server_args = common_args + [
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+    ]
+    if "max_num_tokens" in ctx_server_config:
+        ctx_server_args.append(
+            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+    if "max_num_tokens" in gen_server_config:
+        gen_server_args.append(
+            f"--max_num_tokens={gen_server_config['max_num_tokens']}")
+
+    with (MyThreadPoolExecutor(max_workers=16) as
+          thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
+          ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
           popen([
               trtllm_serve_path, "disaggregated", "-c",
               disaggregated_serving_config_path, "--server_start_timeout",
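
Note on the harness change above: the server argument lists are now built before the "with" block so that a per-server max_num_tokens value, when present in the test config, is forwarded to trtllm-serve as --max_num_tokens. A condensed sketch of that forwarding pattern (build_server_args is a hypothetical helper, not part of the test code):

# Condensed sketch of the flag-forwarding pattern added above;
# build_server_args is a hypothetical helper, not part of the test harness.
from typing import Any, Dict, List


def build_server_args(common_args: List[str], port: int,
                      llm_api_options_path: str,
                      server_config: Dict[str, Any]) -> List[str]:
    args = common_args + [
        "--port", str(port), "--extra_llm_api_options", llm_api_options_path
    ]
    # Forward max_num_tokens to the CLI only when the test config sets it,
    # mirroring the '"max_num_tokens" in ctx_server_config' checks above.
    if "max_num_tokens" in server_config:
        args.append(f"--max_num_tokens={server_config['max_num_tokens']}")
    return args


ctx_args = build_server_args(
    ["trtllm-serve", "model", "--host", "localhost", "--backend", "pytorch"],
    8001, "ctx_config.yaml", {"max_num_tokens": 13393 * 2})
print(ctx_args[-1])  # --max_num_tokens=26786
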
@@ -209,9 +214,53 @@ def test_auto_dtype(self, disable_overlap_scheduler):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.parametrize("overlap_scheduler", [False])
+    def test_eagle3(self, overlap_scheduler):
+        sepculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 4,
+            "pytorch_weights_path":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False
+        }
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": sepculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": not overlap_scheduler,
+            "speculative_config": sepculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
-@pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.timeout(3600)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ l0_h100:
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0]
   - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_llama_context_capacity[False-False-DeepSeek-V3-Lite-fp8/fp8]
+  - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]
   - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
   - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-non-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
   - condition:
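
The new l0_h100.yml entry adds the EAGLE3 disaggregated accuracy test to the H100 pre-merge list. One way to run just this test locally (assuming the usual integration-test environment, with model weights resolvable via llm_models_root()) is:

# Hypothetical local invocation of the newly registered test; the node id
# matches the l0_h100.yml entry above, but GPU, model-weight and UCX setup
# are assumed and not covered by this sketch.
import pytest

pytest.main([
    "tests/integration/defs/accuracy/test_disaggregated_serving.py"
    "::TestLlama3_1_8BInstruct::test_eagle3[False]",
    "-v",
])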
