20 changes: 13 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -974,7 +974,8 @@ def _prepare_draft_requests(self):
# and scheduler aware of them.
for req in self.active_requests:
# TODO: enable draft tokens in context phase
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
if req.state not in (LlmRequestState.GENERATION_IN_PROGRESS,
LlmRequestState.DISAGG_GENERATION_INIT):
continue
req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_tokens
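The widened check above also admits requests still in the disaggregated-generation init state, so draft tokens can be prepared for a generation server's newly received requests before their first forward step. A minimal sketch of the same filter, using a hypothetical, simplified stand-in for the real enum and request objects:

```python
# Sketch only: hypothetical stand-ins for LlmRequestState and the request filter.
from enum import Enum, auto


class LlmRequestState(Enum):
    CONTEXT_INIT = auto()
    GENERATION_IN_PROGRESS = auto()
    DISAGG_GENERATION_INIT = auto()


DRAFTABLE_STATES = (LlmRequestState.GENERATION_IN_PROGRESS,
                    LlmRequestState.DISAGG_GENERATION_INIT)


def draftable_requests(active_requests):
    """Yield only the requests whose state allows draft-token preparation."""
    return (req for req in active_requests if req.state in DRAFTABLE_STATES)
```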
@@ -1533,12 +1534,17 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
disagg_gen_init_to_prepare.generation_requests = []
disagg_gen_init_to_prepare.paused_requests = []

self.resource_manager.resource_managers[
ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
disagg_gen_init_to_prepare)
self.resource_manager.resource_managers[
ResourceManagerType.SEQ_SLOT_MANAGER].prepare_resources(
disagg_gen_init_to_prepare)
for resource_mgr_type in (
ResourceManagerType.KV_CACHE_MANAGER,
ResourceManagerType.SEQ_SLOT_MANAGER,
ResourceManagerType.SPEC_RESOURCE_MANAGER,
ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
if (resource_mgr_type in self.resource_manager.resource_managers
and self.resource_manager.
resource_managers[resource_mgr_type] is not None):
self.resource_manager.resource_managers[
resource_mgr_type].prepare_resources(
disagg_gen_init_to_prepare)

# Trigger KV cache exchange for new disagg_gen_init_requests
self._recv_disagg_gen_cache(fitting_disagg_gen_init_requests)
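The refactor above replaces two hard-coded prepare_resources calls with a loop over an ordered tuple of manager types, skipping any manager that is not registered or is None (the spec and draft-KV managers may be absent when speculative decoding is not configured). A minimal standalone sketch of the same pattern, with hypothetical names:

```python
# Sketch only: hypothetical enum and manager objects standing in for the real
# ResourceManagerType and resource manager classes.
from enum import Enum, auto
from typing import Dict, Optional


class ResourceManagerType(Enum):
    KV_CACHE_MANAGER = auto()
    SEQ_SLOT_MANAGER = auto()
    SPEC_RESOURCE_MANAGER = auto()
    DRAFT_KV_CACHE_MANAGER = auto()


def prepare_optional_managers(
        managers: Dict[ResourceManagerType, Optional[object]], batch) -> None:
    """Call prepare_resources() on each manager that is present and non-None."""
    for mgr_type in (ResourceManagerType.KV_CACHE_MANAGER,
                     ResourceManagerType.SEQ_SLOT_MANAGER,
                     ResourceManagerType.SPEC_RESOURCE_MANAGER,
                     ResourceManagerType.DRAFT_KV_CACHE_MANAGER):
        mgr = managers.get(mgr_type)
        if mgr is not None:
            mgr.prepare_resources(batch)
```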
15 changes: 10 additions & 5 deletions tensorrt_llm/_torch/speculative/ngram.py
@@ -137,7 +137,7 @@ def get_draft_tokens(
pattern = tuple(sequence[l:l + size])
new_match = tuple(sequence[l + size:r])
if pattern not in pool or \
(not self.is_keep_all and len(match) > pool[pattern][0]):
(not self.is_keep_all and len(new_match) > len(pool[pattern][0])):
# Replace the match if
# 1. the pattern does not exist in the pool
# 2. only one match is kept, and the new match is longer (MRU)
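The corrected condition compares the length of the new match against the length of the stored match, rather than comparing a length to the stored tuple itself. A minimal sketch of the replacement rule with simplified, hypothetical names:

```python
# Sketch only: `pool` maps an n-gram pattern to a tuple of stored matches.
from typing import Dict, Tuple

Pattern = Tuple[int, ...]
Match = Tuple[int, ...]


def should_replace(pool: Dict[Pattern, Tuple[Match, ...]], pattern: Pattern,
                   new_match: Match, is_keep_all: bool) -> bool:
    """True when `new_match` should overwrite the pool entry for `pattern`."""
    # Replace if the pattern is new, or if only a single match is kept per
    # pattern and the candidate is strictly longer than the stored one (MRU).
    return pattern not in pool or (
        not is_keep_all and len(new_match) > len(pool[pattern][0]))
```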
@@ -202,10 +202,15 @@ def prepare_draft_tokens(
self,
scheduled_requests: ScheduledRequests,
) -> None:

for request in sorted(scheduled_requests.generation_requests,
key=lambda r: r.py_batch_idx):
# Add new token to a copy of the generated tokens to find new daft tokens
# Sort by request_id when py_batch_idx is None as a fallback.
# This happens in the disagg case: for a set of new requests, we draft
# before forward_step, so py_batch_idx is not assigned.
for request in sorted(
scheduled_requests.generation_requests,
key=lambda r:
(r.py_batch_idx is None, r.py_batch_idx or r.request_id),
):
# Add new token to a copy of the generated tokens to find new draft tokens
prefix = list(request.get_tokens()[0]) # Get a copy

# Generate draft tokens
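The new sort key preserves the previous ordering for requests that already have a py_batch_idx and falls back to request_id for those that do not (per the comment above, in the disaggregated case drafting runs before forward_step, so py_batch_idx is not yet assigned). A small self-contained illustration using a hypothetical request class:

```python
# Sketch only: hypothetical Request stand-in for LlmRequest.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Request:
    request_id: int
    py_batch_idx: Optional[int] = None


requests = [
    Request(9),
    Request(3, py_batch_idx=2),
    Request(7),
    Request(5, py_batch_idx=1),
]
ordered = sorted(
    requests,
    key=lambda r: (r.py_batch_idx is None, r.py_batch_idx or r.request_id))
# Requests with a batch index sort first by that index; the rest by request_id.
print([r.request_id for r in ordered])  # -> [5, 3, 7, 9]
```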
2 changes: 2 additions & 0 deletions tests/integration/defs/accuracy/references/gsm8k.yaml
@@ -1,5 +1,7 @@
meta-llama/Llama-3.1-8B-Instruct:
- accuracy: 74.20
- spec_dec_algo: NGRAM
accuracy: 74.20
- quant_algo: FP8
accuracy: 74.30
- quant_algo: FP8
41 changes: 41 additions & 0 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -211,6 +211,47 @@ def test_auto_dtype(self, disable_overlap_scheduler):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

def test_ngram(self):
speculative_decoding_config = {
"decoding_type": "NGram",
"prompt_lookup_num_tokens": 4,
"max_matching_ngram_size": 4,
"is_keep_all": True,
"is_use_oldest": True,
"is_public_pool": True
}
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False
}
ctx_server_config = {
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
}
gen_server_config = {
"disable_overlap_scheduler": True,
"speculative_config": speculative_decoding_config,
"kv_cache_config": kv_cache_config,
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)


@pytest.mark.timeout(3600)
@pytest.mark.skip_less_device_memory(140000)
2 changes: 2 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -277,6 +277,8 @@ def test_ngram(self):
with llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
def test_guided_decoding(self, backend: str, mocker):
25 changes: 25 additions & 0 deletions disagg_config_ngram.yaml (new file)
@@ -0,0 +1,25 @@
hostname: localhost
port: 8000
model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
free_gpu_memory_fraction: 0.1
backend: pytorch
disable_overlap_scheduler: True
context_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 1
urls:
- "localhost:8001"
generation_servers:
num_instances: 1
tensor_parallel_size: 1
pipeline_parallel_size: 1
urls:
- "localhost:8002"
speculative_config:
decoding_type: NGram
prompt_lookup_num_tokens: 4
max_matching_ngram_size: 4
is_keep_all: True
is_use_oldest: True
is_public_pool: True
19 changes: 19 additions & 0 deletions tests/integration/defs/disaggregated/test_disaggregated.py
@@ -58,6 +58,7 @@ def get_test_config(test_desc, example_dir, test_root):
(4, f"{test_configs_root}/disagg_config_cache_aware_balance.yaml"),
"conditional": (2,
f"{test_configs_root}/disagg_config_conditional.yaml"),
"ngram": (2, f"{test_configs_root}/disagg_config_ngram.yaml"),
"deepseek_v3_lite_fp8":
(4,
f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml"
@@ -502,6 +503,24 @@ def test_disaggregated_conditional(disaggregated_test_root, llm_venv,
cwd=llm_venv.get_working_directory())


@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'],
indirect=True)
def test_disaggregated_ngram(disaggregated_test_root, llm_venv,
disaggregated_example_root, llama_model_root):
src_dst_dict = {
llama_model_root:
f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for src, dst in src_dst_dict.items():
if not os.path.islink(dst):
os.makedirs(os.path.dirname(dst), exist_ok=True)
os.symlink(src, dst, target_is_directory=True)
run_disaggregated_test(disaggregated_example_root,
"ngram",
env=llm_venv._new_env,
cwd=llm_venv.get_working_directory())


@skip_no_hopper
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -34,6 +34,7 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
- test_e2e.py::test_ptp_quickstart_advanced_bs1
- condition:
ranges:
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -59,6 +59,7 @@ l0_h100:
- disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_conditional[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_ngram[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_workers.py::test_workers_conditional_disaggregation[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_workers.py::test_workers_kv_cache_events[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]