27 changes: 17 additions & 10 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -869,8 +869,9 @@ def _executor_loop(self):

self._pad_attention_dp_dummy_request()

if self.draft_model_engine is not None or is_ngram:
self._prepare_draft_requests()
if self.draft_model_engine is not None or is_ngram or hasattr(
self, 'drafter') and self.drafter is not None:
self._prepare_draft_requests(self.active_requests)

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
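
Note on the new condition: Python's `and` binds tighter than `or`, so `hasattr(self, 'drafter') and self.drafter is not None` is grouped before the `or` chain and the check behaves as intended. A minimal sketch of an equivalent, more explicit form (the helper name is hypothetical, not part of this PR):

def needs_draft_preparation(executor, is_ngram: bool) -> bool:
    # getattr collapses the hasattr + "is not None" pair into a single lookup.
    has_drafter = getattr(executor, "drafter", None) is not None
    return executor.draft_model_engine is not None or is_ngram or has_drafter
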
@@ -966,13 +967,13 @@ def _executor_loop(self):
iter_stats=iter_stats,
iter_start_time=iter_start_time))

def _prepare_draft_requests(self):
def _prepare_draft_requests(self, requests):
try:
# Set draft tokens here to make the KV cache manager
# and scheduler aware of them.
for req in self.active_requests:
# TODO: enable draft tokens in context phase
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
for req in requests:
if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
LlmRequestState.DISAGG_GENERATION_INIT):
continue
req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_tokens
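
The reworked `_prepare_draft_requests` now takes the request list explicitly and also lets disaggregated-generation-init requests through. A small sketch of the same filter pulled out as a helper (helper and constant names are hypothetical; `LlmRequestState` is the enum already imported by `py_executor.py`):

DRAFT_ELIGIBLE_STATES = (LlmRequestState.GENERATION_IN_PROGRESS,
                         LlmRequestState.DISAGG_GENERATION_INIT)

def draft_eligible_requests(requests):
    # Only requests that are actively generating, or that were just set up on
    # the disaggregated generation side, should have draft tokens attached.
    return [req for req in requests if req.state in DRAFT_ELIGIBLE_STATES]
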
@@ -1528,9 +1529,16 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
disagg_gen_init_to_prepare.generation_requests = []
disagg_gen_init_to_prepare.paused_requests = []

self.resource_manager.resource_managers[
ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
disagg_gen_init_to_prepare)
for resource_mgr_type in (
ResourceManagerType.KV_CACHE_MANAGER,
ResourceManagerType.SEQ_SLOT_MANAGER,
ResourceManagerType.SPEC_RESOURCE_MANAGER,
ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
if resource_mgr_type in self.resource_manager.resource_managers and self.resource_manager.resource_managers[
resource_mgr_type] is not None:
self.resource_manager.resource_managers[
resource_mgr_type].prepare_resources(
disagg_gen_init_to_prepare)

# Trigger KV cache exchange for new disagg_gen_init_requests
self._recv_disagg_gen_cache(fitting_disagg_gen_init_requests)
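
A sketch of an equivalent, slightly tighter form of the loop above, using `dict.get()` in place of the membership test plus `None` check; it assumes `resource_managers` behaves like a plain dict, and `ResourceManagerType` is the enum already used in `py_executor.py`:

def prepare_disagg_gen_init_resources(resource_manager, batch):
    for mgr_type in (ResourceManagerType.KV_CACHE_MANAGER,
                     ResourceManagerType.SEQ_SLOT_MANAGER,
                     ResourceManagerType.SPEC_RESOURCE_MANAGER,
                     ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
        mgr = resource_manager.resource_managers.get(mgr_type)
        if mgr is not None:
            mgr.prepare_resources(batch)
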
@@ -1790,7 +1798,6 @@ def _prepare_draft_batch(
# This is the first time the draft model is seeing this request.
# Prepare a context request. We discard the first token and take
# the newly decoded one - this is the convention for EAGLE 2 and 3.
assert num_draft_tokens == 0
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
46 changes: 23 additions & 23 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -307,30 +307,30 @@ def handle_logits(request: LlmRequest, tokens: list[int], count=1):
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[token_idx]
num_tokens = request.add_new_token(new_token, beam_idx)
if self._handle_stop_criteria(request, new_token, num_tokens,
beam_idx):
continue

# Accept draft tokens (if we have any) if and only if they match the new
# token exactly.
num_accepted = 0
new_tokens = [new_token]
for draft_token in request.py_draft_tokens:
if draft_token != new_token:
# Reject.
break
num_accepted += 1
new_token = new_tokens_list[token_idx + num_accepted]
num_tokens = request.add_new_token(new_token, beam_idx)
new_tokens.append(num_tokens) # `num_tokens`->`new_token`

if self._handle_stop_criteria(request, new_token,
if not self._handle_stop_criteria(request, new_token,
num_tokens, beam_idx):
break
handle_logits(request, new_tokens, num_accepted)
request.py_decoding_iter += 1
request.py_num_accepted_draft_tokens = num_accepted
request.py_rewind_len = request.py_draft_pages_allocated - num_accepted

# Accept draft tokens (if we have any) if and only if they match the new
# token exactly.
num_accepted = 0
new_tokens = [new_token]
for draft_token in request.py_draft_tokens:
if draft_token != new_token:
# Reject.
break
num_accepted += 1
new_token = new_tokens_list[token_idx + num_accepted]
num_tokens = request.add_new_token(new_token, beam_idx)
new_tokens.append(
num_tokens) # `num_tokens`->`new_token`

if self._handle_stop_criteria(request, new_token,
num_tokens, beam_idx):
break
handle_logits(request, new_tokens, num_accepted)
request.py_decoding_iter += 1
request.py_num_accepted_draft_tokens = num_accepted
request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
advance_idx(len(request.py_draft_tokens) + 1)

for request in generation_requests:
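
The sampler change replaces the early `continue` on the stop-criteria check with an `if not self._handle_stop_criteria(...)` guard around the acceptance loop, so the trailing `advance_idx(len(request.py_draft_tokens) + 1)` call is no longer skipped when a request stops on its first new token. The acceptance rule itself is unchanged; a standalone sketch of it (function name and argument layout are illustrative, not the PR's code):

def count_accepted_draft_tokens(draft_tokens, target_tokens):
    # Greedy left-to-right acceptance: keep a draft token only while it matches
    # the token the target model actually sampled at the same position.
    num_accepted = 0
    for draft_token, target_token in zip(draft_tokens, target_tokens):
        if draft_token != target_token:
            break
        num_accepted += 1
    return num_accepted

# Example: the first two drafts match the target samples, the third does not.
assert count_accepted_draft_tokens([5, 7, 9], [5, 7, 2, 4]) == 2
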
75 changes: 62 additions & 13 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -92,6 +92,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
trtllm_serve_path, model_name, "--host", "localhost", "--backend",
"pytorch"
]

if tensor_parallel_size > 1:
common_args.append(f"--tp_size={tensor_parallel_size}")

@@ -104,18 +105,22 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))

with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
popen(common_args + [
"--port", "8001", "--extra_llm_api_options",
ctx_server_config_path
],
env=env_ctx) as ctx_server,
popen(common_args + [
"--port", "8002", "--extra_llm_api_options",
gen_server_config_path
],
env=env_gen) as gen_server,
ctx_server_args = common_args + [
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path
]
gen_server_args = common_args + [
"--port", "8002", "--extra_llm_api_options", gen_server_config_path
]
if "max_num_tokens" in ctx_server_config:
ctx_server_args.append(
f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
if "max_num_tokens" in gen_server_config:
gen_server_args.append(
f"--max_num_tokens={gen_server_config['max_num_tokens']}")

with (MyThreadPoolExecutor(max_workers=16) as
thread_pool, temp_dir, popen(ctx_server_args, env=env_ctx) as
ctx_server, popen(gen_server_args, env=env_gen) as gen_server,
popen([
trtllm_serve_path, "disaggregated", "-c",
disaggregated_serving_config_path, "--server_start_timeout",
@@ -209,9 +214,53 @@ def test_auto_dtype(self, disable_overlap_scheduler):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.parametrize("overlap_scheduler", [False])
def test_eagle3(self, overlap_scheduler):
speculative_decoding_config = {
"decoding_type": "Eagle",
"max_draft_len": 4,
"pytorch_weights_path":
f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
"eagle3_one_model": False
}
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False
}
ctx_server_config = {
"disable_overlap_scheduler": True,
"speculative_config": speculative_decoding_config,
"kv_cache_config": kv_cache_config,
"max_num_tokens": 13393 * 2
}
gen_server_config = {
"disable_overlap_scheduler": not overlap_scheduler,
"speculative_config": speculative_decoding_config,
"kv_cache_config": kv_cache_config,
"max_num_tokens": 13393 * 2
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
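
The new `test_eagle3` builds plain dicts for the context and generation servers; the `launch_disaggregated_llm` fixture presumably writes them out as the YAML files handed to `--extra_llm_api_options` (that wiring is not shown in this hunk). A hedged sketch of such a serialization step, with an illustrative helper name:

import tempfile

import yaml

def write_extra_llm_api_options(server_config: dict) -> str:
    # Dump the server config (speculative_config, kv_cache_config, ...) to a
    # YAML file whose path can be passed to trtllm-serve.
    f = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
    yaml.safe_dump(server_config, f)
    f.close()
    return f.name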


@pytest.mark.timeout(3600)
@pytest.mark.skip_less_device_memory(140000)
@pytest.mark.timeout(3600)
class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -39,6 +39,7 @@ l0_dgx_h100:
- disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False]
- test_e2e.py::test_ptp_quickstart_advanced_bs1
- condition:
ranges: