
Commit 401b0d3

Add tests
Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 8fc44aa commit 401b0d3

File tree

2 files changed: +108, -0 lines changed


tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py

Lines changed: 107 additions & 0 deletions
@@ -12,6 +12,7 @@
 from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
 from tensorrt_llm._utils import set_mpi_comm
 from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MpiCommSession
+from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig
 
 cloudpickle.register_pickle_by_value(sys.modules[__name__])
 MPI.pickle.__init__(
@@ -33,6 +34,11 @@ def model_path(model_name):
     elif 'TinyLlama-1.1B-Chat-v1.0' in model_name:
         return os.path.join(llm_models_root, 'llama-models-v2',
                             'TinyLlama-1.1B-Chat-v1.0')
+    elif 'Llama-3.1-8B-Instruct' in model_name:
+        return os.path.join(llm_models_root, 'llama-3.1-model',
+                            'Llama-3.1-8B-Instruct/')
+    elif 'EAGLE3-LLaMA3.1-Instruct-8B' in model_name:
+        return os.path.join(llm_models_root, 'EAGLE3-LLaMA3.1-Instruct-8B')
     else:
         raise ValueError(f"Unknown model: {model_name}")
 
@@ -317,5 +323,106 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
     print("All workers terminated.")
 
 
+@pytest.mark.parametrize("model", ["Llama-3.1-8B-Instruct"])
+@pytest.mark.parametrize("spec_dec_model_path", ["EAGLE3-LLaMA3.1-Instruct-8B"])
+@pytest.mark.parametrize("generation_overlap", [False])
+def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
+                                                 generation_overlap):
+    # Test whether the batch slots are properly released when using speculative decoding
+    # with disaggregated serving.
+    spec_dec_config = EagleDecodingConfig(
+        speculative_model_dir=model_path(spec_dec_model_path),
+        eagle3_one_model=False,
+        max_draft_len=3)
+
+    worker_pytorch_configs = []
+
+    # Context worker
+    worker_pytorch_configs.append(
+        dict(disable_overlap_scheduler=True,
+             kv_cache_dtype="auto",
+             speculative_config=spec_dec_config,
+             max_batch_size=1))
+
+    # Generation worker
+    worker_pytorch_configs.append(
+        dict(disable_overlap_scheduler=not generation_overlap,
+             kv_cache_dtype="auto",
+             speculative_config=spec_dec_config,
+             max_batch_size=1))
+
+    kv_cache_configs = [
+        KvCacheConfig(max_tokens=128, enable_block_reuse=False)
+        for _ in range(2)
+    ]
+    model_names = [model_path(model) for _ in range(2)]
+    ranks = [0, 1]
+    worker_args = list(
+        zip(kv_cache_configs, worker_pytorch_configs, model_names, ranks))
+
+    port_name = MPI.Open_port()
+    MPI.Publish_name('my_port', port_name)
+
+    prompt = "What is the capital of Germany?"
+
+    with MPIPoolExecutor(max_workers=2, env={"TRTLLM_USE_MPI_KVCACHE":
+                                             "1"}) as executor:
+        futures = []
+        try:
+            for worker_arg in worker_args:
+                future = executor.submit(worker_entry_point, *worker_arg)
+                futures.append(future)
+        except Exception as e:
+            print(f"Error in worker {worker_arg}: {e}")
+            raise e
+
+        try:
+            print("Launched all the workers.")
+            intercomm = MPI.COMM_SELF.Accept(port_name)
+
+            for _ in range(2):
+                intercomm.recv(tag=MPI_READY)
+                print("Received ready signal.")
+            max_tokens = 25
+
+            requests = []
+            for _ in range(10):
+                requests.append(
+                    (prompt, SamplingParams(max_tokens=1, ignore_eos=True),
+                     DisaggregatedParams(request_type="context_only")))
+
+            intercomm.send(requests, dest=0, tag=MPI_REQUEST)
+
+            for _ in range(len(requests)):
+                output = intercomm.recv(source=0, tag=MPI_RESULT)
+                assert output[0].disaggregated_params is not None
+                assert output[
+                    0].disaggregated_params.request_type == "context_only"
+                assert len(output[0].token_ids) == 1
+
+            generation_request_disagg_params = output[
+                0].disaggregated_params
+            generation_request_disagg_params.request_type = "generation_only"
+            requests = []
+            requests.append((prompt,
+                             SamplingParams(max_tokens=max_tokens,
+                                            ignore_eos=True),
+                             generation_request_disagg_params))
+
+            intercomm.send(requests, dest=1, tag=MPI_REQUEST)
+            output = intercomm.recv(source=1, tag=MPI_RESULT)
+
+        finally:
+            # Send termination requests
+            intercomm.send(None, dest=0, tag=MPI_REQUEST)
+            intercomm.send(None, dest=1, tag=MPI_REQUEST)
+            print("Sent termination requests to the workers.")
+
+            # Wait for all futures to complete
+            for future in futures:
+                future.result()
+            print("All workers terminated.")
+
+
 if __name__ == "__main__":
     pytest.main()
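
Note (not part of the diff): the new test drives the two-phase disaggregated flow by hand. A batch of context_only requests goes to the context worker (rank 0), and the disaggregated_params returned with each context result are reused, with the request type flipped to generation_only, to continue decoding on the generation worker (rank 1). A minimal sketch of that request construction, using only the types imported in this file; the MPI worker plumbing is omitted, and ctx_output is a hypothetical placeholder for one result object received from the context worker:

# Sketch only, not part of this commit: shows how the request tuples above are
# shaped. `ctx_output` is a hypothetical stand-in for a result object that the
# context worker sends back over MPI in the test.
from tensorrt_llm import DisaggregatedParams, SamplingParams

prompt = "What is the capital of Germany?"

# Phase 1: context-only request; the context worker prefills and returns a
# single token plus the disaggregated params needed for the handoff.
context_request = (prompt,
                   SamplingParams(max_tokens=1, ignore_eos=True),
                   DisaggregatedParams(request_type="context_only"))


def build_generation_request(ctx_output, max_tokens=25):
    # Phase 2: reuse the params from the context result and switch the request
    # type so the generation worker continues decoding from the transferred
    # KV-cache state.
    gen_params = ctx_output.disaggregated_params
    gen_params.request_type = "generation_only"
    return (prompt,
            SamplingParams(max_tokens=max_tokens, ignore_eos=True),
            gen_params)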

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ l0_h100:
 - disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0]
 - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_eviction[TinyLlama-1.1B-Chat-v1.0]
 - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_llama_context_capacity[False-False-DeepSeek-V3-Lite-fp8/fp8]
+- disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_spec_dec_batch_slot_limit[False-EAGLE3-LLaMA3.1-Instruct-8B-Llama-3.1-8B-Instruct]
 - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
 - test_e2e.py::test_trtllm_bench_iteration_log[PyTorch-non-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]
 - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-]
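
Note (not part of the diff): the bracketed suffix in the new l0_h100 entry is the pytest parametrization ID. Stacked parametrize decorators contribute their values bottom-up, so generation_overlap=False appears first, followed by the EAGLE3 draft model and then the target model. A small self-contained sketch, not part of this commit, that collects to the same ID shape:

# Sketch only: demonstrates how pytest composes the node ID referenced in
# l0_h100.yml from stacked parametrize decorators (bottom-most value first).
import pytest


@pytest.mark.parametrize("model", ["Llama-3.1-8B-Instruct"])
@pytest.mark.parametrize("spec_dec_model_path", ["EAGLE3-LLaMA3.1-Instruct-8B"])
@pytest.mark.parametrize("generation_overlap", [False])
def test_id_demo(model, spec_dec_model_path, generation_overlap):
    # Collected as:
    # test_id_demo[False-EAGLE3-LLaMA3.1-Instruct-8B-Llama-3.1-8B-Instruct]
    assert generation_overlap is False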
