5 changes: 3 additions & 2 deletions tensorrt_llm/evaluate/interface.py
@@ -72,14 +72,15 @@ def evaluate(self,
outputs.append(output)
references.append(reference)
auxiliaries.append(aux)
results = []
for output in tqdm(outputs, desc="Fetching responses"):
output.result()
results.append(output.result())
profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
profiler.reset("trtllm exec")

score = self.compute_score(outputs, references, *zip(*auxiliaries))
score = self.compute_score(results, references, *zip(*auxiliaries))
return score

@staticmethod
10 changes: 6 additions & 4 deletions tensorrt_llm/evaluate/lm_eval.py
@@ -96,20 +96,22 @@ def _get_sampling_params(self, gen_kwargs: dict) -> SamplingParams:

def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
profiler.start("trtllm exec")
outputs = []
results = []
for request in tqdm(requests,
desc="Submitting requests",
disable=disable_tqdm):
prompt, gen_kwargs = request.args
sampling_params = self._get_sampling_params(gen_kwargs)
output = self.llm.generate_async(prompt,
sampling_params=sampling_params)
outputs.append(output)
results.append(output)

for output in tqdm(outputs,
outputs = []
for output in tqdm(results,
desc="Fetching responses",
disable=disable_tqdm):
output.result()
outputs.append(output.result())

profiler.stop("trtllm exec")
elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
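
Note on the pattern in both evaluate changes above (interface.py and lm_eval.py): requests are submitted asynchronously first, the returned handles are resolved in a second pass, and the resolved results (not the pending handles) are what gets scored or returned. A minimal sketch of that two-phase pattern, using Python's concurrent.futures as a stand-in for the TRT-LLM generate_async API (fake_generate and the prompts are hypothetical):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import List

def fake_generate(prompt: str) -> str:
    # Hypothetical stand-in for llm.generate_async(...).result().
    return prompt.upper()

def run_batch(prompts: List[str]) -> List[str]:
    with ThreadPoolExecutor(max_workers=4) as pool:
        # Phase 1: submit every request without blocking.
        futures: List[Future] = [pool.submit(fake_generate, p) for p in prompts]
        # Phase 2: resolve the handles and keep the results themselves,
        # so downstream scoring sees plain outputs rather than futures.
        return [f.result() for f in futures]

print(run_batch(["hello", "world"]))
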
111 changes: 92 additions & 19 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -8,6 +8,7 @@
import subprocess
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

import openai
@@ -20,7 +21,7 @@
from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams

from ..conftest import llm_models_root
from .accuracy_core import MMLU, LlmapiAccuracyTestHarness
from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness


class Result(GenerationResultBase):
@@ -41,10 +42,15 @@ def result(self):

class OpenAIServerClient:

def __init__(self, disaggregated_server_config: Dict[str, Any],
def __init__(self,
disaggregated_server_config: Dict[str, Any],
ctx_server_config: Dict[str, Any],
gen_server_config: Dict[str, Any], model_name: str):
gen_server_config: Dict[str, Any],
model_name: str,
tensor_parallel_size: int = 1):
self.thread_pool = ThreadPoolExecutor(max_workers=16)
self.temp_dir = tempfile.mkdtemp()
self.futures = []
self.disaggregated_serving_config_path = os.path.join(
self.temp_dir, "disaggregated_serving_config.yaml")
with open(self.disaggregated_serving_config_path, "w") as f:
@@ -58,18 +64,26 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
with open(gen_server_config_path, "w") as f:
yaml.dump(gen_server_config, f)

with LLM(model_name) as llm:
with LLM(model_name, tensor_parallel_size=tensor_parallel_size) as llm:
self.args = llm.args

cuda_device_idx = 0
cuda_devices = []
for i in range(tensor_parallel_size):
cuda_devices.append(f"{cuda_device_idx}")
cuda_device_idx += 1

trtllm_serve_path = "trtllm-serve"
# Common arguments for both servers
common_args = [
trtllm_serve_path, model_name, "--host", "localhost", "--backend",
"pytorch"
]
if tensor_parallel_size > 1:
common_args.append(f"--tp_size={tensor_parallel_size}")
env_ctx = os.environ.copy()
env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"

env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_devices)
# Start the context server
self._ctx_server = subprocess.Popen(common_args + [
"--port", "8001", "--extra_llm_api_options", ctx_server_config_path
@@ -78,6 +92,11 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
# Start the generation server
env_gen = os.environ.copy()
env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
cuda_devices = []
for i in range(tensor_parallel_size):
cuda_devices.append(f"{cuda_device_idx}")
cuda_device_idx += 1
env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_devices)
self._gen_server = subprocess.Popen(common_args + [
"--port", "8002", "--extra_llm_api_options", gen_server_config_path
],
@@ -86,7 +105,8 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
# Start the disaggregated server
self._disaggregated_server = subprocess.Popen([
trtllm_serve_path, "disaggregated", "-c",
self.disaggregated_serving_config_path
self.disaggregated_serving_config_path, "--server_start_timeout",
"3600"
])
self.model_name = model_name

@@ -103,10 +123,7 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
self.client = openai.OpenAI(api_key="1234567890",
base_url=f"http://localhost:8000/v1")

def generate_async(self,
prompt: str,
sampling_params: Optional[SamplingParams] = None):
# TODO: Make this async
def send_request(self, prompt: str, sampling_params: SamplingParams):
response = self.client.completions.create(
model=self.model_name,
prompt=prompt,
@@ -127,7 +144,18 @@ def generate_async(self,
setattr(requested_output, "result", result.result)
return requested_output

def __del__(self):
def generate_async(self,
prompt: str,
sampling_params: Optional[SamplingParams] = None):
future = self.thread_pool.submit(self.send_request, prompt,
sampling_params)
self.futures.append(future)
return future

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
shutil.rmtree(self.temp_dir)
self._ctx_server.terminate()
self._gen_server.terminate()
@@ -137,10 +165,14 @@ def __del__(self):
self._gen_server.wait()
self._disaggregated_server.wait()

for future in self.futures:
future.result()
self.thread_pool.shutdown(wait=True)


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"

@pytest.mark.skip_less_device_memory(32000)
@pytest.mark.skip_device_not_contain(["H100", "H200"])
@@ -169,8 +201,49 @@ def test_auto_dtype(self, disable_overlap_scheduler):
"urls": ["localhost:8002"]
}
}
client = OpenAIServerClient(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH)
task = MMLU(self.MODEL_NAME)
task.evaluate(client)
with OpenAIServerClient(disaggregated_server_config, ctx_server_config,
gen_server_config, self.MODEL_PATH) as client:
task = MMLU(self.MODEL_NAME)
task.evaluate(client)
task = GSM8K(self.MODEL_NAME)
task.evaluate(client)


class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"

@pytest.mark.parametrize("overlap_scheduler", [False, True])
def test_auto_dtype(self, overlap_scheduler):
ctx_server_config = {
"pytorch_backend_config": {
"disable_overlap_scheduler": True
}
}
gen_server_config = {
"pytorch_backend_config": {
"disable_overlap_scheduler": overlap_scheduler
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with OpenAIServerClient(disaggregated_server_config,
ctx_server_config,
gen_server_config,
self.MODEL_PATH,
tensor_parallel_size=4) as client:
task = MMLU(self.MODEL_NAME)
task.evaluate(client)
task = GSM8K(self.MODEL_NAME)
task.evaluate(client)
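
The reworked test client above submits each completion request to a ThreadPoolExecutor and hands the future back to the evaluator, so many requests can be in flight at once, and teardown (terminating the context, generation, and disaggregated servers, draining outstanding futures, shutting down the pool) moves from __del__ to __exit__. A rough, self-contained sketch of that shape with hypothetical names, assuming a blocking send path like the send_request method shown above:

from concurrent.futures import ThreadPoolExecutor

class PooledClient:
    # Hypothetical miniature of the test harness client: blocking sends run on a pool.

    def __init__(self, max_workers: int = 16):
        self.pool = ThreadPoolExecutor(max_workers=max_workers)
        self.futures = []

    def send_request(self, prompt: str) -> str:
        # Stand-in for the blocking OpenAI completions call against the proxy.
        return f"response to {prompt!r}"

    def generate_async(self, prompt: str):
        # Returns an object exposing .result(), which is what the evaluators call.
        future = self.pool.submit(self.send_request, prompt)
        self.futures.append(future)
        return future

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drain outstanding work, then release the pool; the real test also
        # terminates and waits on the server subprocesses here.
        for f in self.futures:
            f.result()
        self.pool.shutdown(wait=True)

with PooledClient() as client:
    handle = client.generate_async("2 + 2 = ?")
    print(handle.result())
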
4 changes: 2 additions & 2 deletions tests/integration/test_lists/qa/examples_test_list.txt
@@ -453,8 +453,8 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[throughput]
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[throughput_tp8]
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]

test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]
test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-]
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -37,6 +37,8 @@ l0_dgx_h100:
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
- disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
- condition:
ranges:
system_gpu_count:
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h200.yml
@@ -17,5 +17,7 @@ l0_dgx_h200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
# - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] # OOM
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
- accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
- unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-enable_graph-tp8-trtllm-scout]
- unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora
2 changes: 0 additions & 2 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -47,8 +47,6 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]
2 changes: 0 additions & 2 deletions tests/integration/test_lists/waives.txt
@@ -443,8 +443,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-disable_attention_plugin-disable_context_fmha-tp:2-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5234058)
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:2-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5234058)
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:2-pp:1-float16-RobertaForQuestionAnswering-bert/roberta-base-squad2] SKIP (https://nvbugs/5234058)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False] SKIP (https://nvbugs/5266257)
accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True] SKIP (https://nvbugs/5266257)
disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5247271)
disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugspro.nvidia.com/bug/5273945)
disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5279438)