
Commit c6074c4

Add llama4 disagg accuracy tests (#4336)
* Add llama4 disagg accuracy tests

  Signed-off-by: Iman Tabrizian <[email protected]>

* Make it async and add GSM8K benchmark

  Signed-off-by: Iman Tabrizian <[email protected]>

---------

Signed-off-by: Iman Tabrizian <[email protected]>
1 parent 001704c commit c6074c4

File tree: 8 files changed, +107 −31 lines changed

tensorrt_llm/evaluate/interface.py

Lines changed: 3 additions & 2 deletions
@@ -72,14 +72,15 @@ def evaluate(self,
             outputs.append(output)
             references.append(reference)
             auxiliaries.append(aux)
+        results = []
         for output in tqdm(outputs, desc="Fetching responses"):
-            output.result()
+            results.append(output.result())
         profiler.stop("trtllm exec")
         elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
         logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
         profiler.reset("trtllm exec")
 
-        score = self.compute_score(outputs, references, *zip(*auxiliaries))
+        score = self.compute_score(results, references, *zip(*auxiliaries))
         return score
 
     @staticmethod
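Why this change matters: `generate_async` returns a future-like `RequestOutput`, so the previous code handed unresolved handles straight to `compute_score`. The new code blocks on each future and scores the finished results. A minimal sketch of the submit-then-collect pattern, assuming an `llm` object and an `evaluator` with the `compute_score` signature shown above (those surrounding names are illustrative, not part of this diff):

# Sketch only: submit every prompt first, then block on the futures in order.
# "llm", "evaluator", "prompts", "references", and "auxiliaries" are placeholder names.
futures = []
for prompt in prompts:
    futures.append(llm.generate_async(prompt, sampling_params=sampling_params))

results = [future.result() for future in futures]  # resolved RequestOutput objects
score = evaluator.compute_score(results, references, *zip(*auxiliaries))

The same fix is applied to the lm_eval path below, where requests are likewise submitted in one loop and resolved in a second.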

tensorrt_llm/evaluate/lm_eval.py

Lines changed: 6 additions & 4 deletions
@@ -96,20 +96,22 @@ def _get_sampling_params(self, gen_kwargs: dict) -> SamplingParams:
 
     def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
         profiler.start("trtllm exec")
-        outputs = []
+        results = []
         for request in tqdm(requests,
                             desc="Submitting requests",
                             disable=disable_tqdm):
             prompt, gen_kwargs = request.args
             sampling_params = self._get_sampling_params(gen_kwargs)
             output = self.llm.generate_async(prompt,
                                              sampling_params=sampling_params)
-            outputs.append(output)
+            results.append(output)
 
-        for output in tqdm(outputs,
+        outputs = []
+        for output in tqdm(results,
                            desc="Fetching responses",
                            disable=disable_tqdm):
-            output.result()
+            outputs.append(output.result())
+
         profiler.stop("trtllm exec")
         elapsed_time = profiler.elapsed_time_in_sec("trtllm exec")
         logger.info(f"TRTLLM execution time: {elapsed_time:.3f} seconds.")
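For context, `generate_until` is the lm-eval entry point: each request carries its prompt and generation kwargs in `request.args`, which the method converts into `SamplingParams` before submitting, and it now collects the resolved outputs instead of discarding them. A rough, hypothetical illustration of the call shape (the `SimpleNamespace` requests and the `"until"`/`"max_gen_toks"` keys are stand-ins for real lm-eval request objects, not shown in this diff):

from types import SimpleNamespace

# Hypothetical stand-ins for lm-eval requests; only the .args layout mirrors the code above.
requests = [
    SimpleNamespace(args=("Q: 12 + 7 = ?\nA:", {"until": ["\n"], "max_gen_toks": 32})),
    SimpleNamespace(args=("Q: 9 * 9 = ?\nA:", {"until": ["\n"], "max_gen_toks": 32})),
]

# "evaluator" is assumed to be the wrapper that defines generate_until above.
outputs = evaluator.generate_until(requests, disable_tqdm=True)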

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 92 additions & 19 deletions
@@ -8,6 +8,7 @@
 import subprocess
 import tempfile
 import time
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional
 
 import openai
@@ -20,7 +21,7 @@
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 
 from ..conftest import llm_models_root
-from .accuracy_core import MMLU, LlmapiAccuracyTestHarness
+from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
 
 
 class Result(GenerationResultBase):
@@ -41,10 +42,15 @@ def result(self):
 
 class OpenAIServerClient:
 
-    def __init__(self, disaggregated_server_config: Dict[str, Any],
+    def __init__(self,
+                 disaggregated_server_config: Dict[str, Any],
                  ctx_server_config: Dict[str, Any],
-                 gen_server_config: Dict[str, Any], model_name: str):
+                 gen_server_config: Dict[str, Any],
+                 model_name: str,
+                 tensor_parallel_size: int = 1):
+        self.thread_pool = ThreadPoolExecutor(max_workers=16)
         self.temp_dir = tempfile.mkdtemp()
+        self.futures = []
         self.disaggregated_serving_config_path = os.path.join(
             self.temp_dir, "disaggregated_serving_config.yaml")
         with open(self.disaggregated_serving_config_path, "w") as f:
@@ -58,18 +64,26 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
         with open(gen_server_config_path, "w") as f:
             yaml.dump(gen_server_config, f)
 
-        with LLM(model_name) as llm:
+        with LLM(model_name, tensor_parallel_size=tensor_parallel_size) as llm:
             self.args = llm.args
 
+        cuda_device_idx = 0
+        cuda_devices = []
+        for i in range(tensor_parallel_size):
+            cuda_devices.append(f"{cuda_device_idx}")
+            cuda_device_idx += 1
+
         trtllm_serve_path = "trtllm-serve"
         # Common arguments for both servers
         common_args = [
             trtllm_serve_path, model_name, "--host", "localhost", "--backend",
             "pytorch"
         ]
+        if tensor_parallel_size > 1:
+            common_args.append(f"--tp_size={tensor_parallel_size}")
         env_ctx = os.environ.copy()
         env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
-
+        env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_devices)
         # Start the context server
         self._ctx_server = subprocess.Popen(common_args + [
             "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
@@ -78,6 +92,11 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
         # Start the generation server
         env_gen = os.environ.copy()
         env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
+        cuda_devices = []
+        for i in range(tensor_parallel_size):
+            cuda_devices.append(f"{cuda_device_idx}")
+            cuda_device_idx += 1
+        env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_devices)
         self._gen_server = subprocess.Popen(common_args + [
             "--port", "8002", "--extra_llm_api_options", gen_server_config_path
         ],
@@ -86,7 +105,8 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
         # Start the disaggregated server
         self._disaggregated_server = subprocess.Popen([
             trtllm_serve_path, "disaggregated", "-c",
-            self.disaggregated_serving_config_path
+            self.disaggregated_serving_config_path, "--server_start_timeout",
+            "3600"
         ])
         self.model_name = model_name
 
@@ -103,10 +123,7 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
         self.client = openai.OpenAI(api_key="1234567890",
                                     base_url=f"http://localhost:8000/v1")
 
-    def generate_async(self,
-                       prompt: str,
-                       sampling_params: Optional[SamplingParams] = None):
-        # TODO: Make this async
+    def send_request(self, prompt: str, sampling_params: SamplingParams):
         response = self.client.completions.create(
             model=self.model_name,
             prompt=prompt,
@@ -127,7 +144,18 @@ def generate_async(self,
         setattr(requested_output, "result", result.result)
         return requested_output
 
-    def __del__(self):
+    def generate_async(self,
+                       prompt: str,
+                       sampling_params: Optional[SamplingParams] = None):
+        future = self.thread_pool.submit(self.send_request, prompt,
+                                         sampling_params)
+        self.futures.append(future)
+        return future
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
         shutil.rmtree(self.temp_dir)
         self._ctx_server.terminate()
         self._gen_server.terminate()
@@ -137,10 +165,14 @@ def __del__(self):
         self._gen_server.wait()
         self._disaggregated_server.wait()
 
+        for future in self.futures:
+            future.result()
+        self.thread_pool.shutdown(wait=True)
+
 
-class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
-    MODEL_NAME = "meta-llama/Llama-3.1-8B"
-    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"
+class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
 
     @pytest.mark.skip_less_device_memory(32000)
     @pytest.mark.skip_device_not_contain(["H100", "H200"])
@@ -169,8 +201,49 @@ def test_auto_dtype(self, disable_overlap_scheduler):
                 "urls": ["localhost:8002"]
             }
         }
-        client = OpenAIServerClient(disaggregated_server_config,
-                                    ctx_server_config, gen_server_config,
-                                    self.MODEL_PATH)
-        task = MMLU(self.MODEL_NAME)
-        task.evaluate(client)
+        with OpenAIServerClient(disaggregated_server_config, ctx_server_config,
+                                gen_server_config, self.MODEL_PATH) as client:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(client)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(client)
+
+
+class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
+
+    @pytest.mark.parametrize("overlap_scheduler", [False, True])
+    def test_auto_dtype(self, overlap_scheduler):
+        ctx_server_config = {
+            "pytorch_backend_config": {
+                "disable_overlap_scheduler": True
+            }
+        }
+        gen_server_config = {
+            "pytorch_backend_config": {
+                "disable_overlap_scheduler": overlap_scheduler
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with OpenAIServerClient(disaggregated_server_config,
                                ctx_server_config,
                                gen_server_config,
                                self.MODEL_PATH,
                                tensor_parallel_size=4) as client:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(client)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(client)
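To summarize what the harness above launches: OpenAIServerClient writes disaggregated_server_config to a YAML file, then spawns a context server, a generation server, and the disaggregated router via trtllm-serve, giving the two workers disjoint CUDA_VISIBLE_DEVICES ranges. A hedged sketch of the resulting config and commands for the Llama4 test (tensor_parallel_size=4; the model path and the two extra-options filenames are placeholders, and yaml.dump may order keys differently):

# disaggregated_serving_config.yaml, approximately as dumped from the dict above
hostname: localhost
port: 8000
backend: pytorch
context_servers:
  num_instances: 1
  urls:
  - localhost:8001
generation_servers:
  num_instances: 1
  urls:
  - localhost:8002

# Processes the client spawns (illustrative; flags taken from the code above)
# CUDA_VISIBLE_DEVICES=0,1,2,3 TRTLLM_USE_UCX_KVCACHE=1 \
#   trtllm-serve <model> --host localhost --backend pytorch --tp_size=4 --port 8001 --extra_llm_api_options <ctx_config.yaml>
# CUDA_VISIBLE_DEVICES=4,5,6,7 TRTLLM_USE_UCX_KVCACHE=1 \
#   trtllm-serve <model> --host localhost --backend pytorch --tp_size=4 --port 8002 --extra_llm_api_options <gen_config.yaml>
# trtllm-serve disaggregated -c disaggregated_serving_config.yaml --server_start_timeout 3600

The MMLU and GSM8K tasks then run against the OpenAI-compatible endpoint the router exposes on port 8000, with requests fanned out through the ThreadPoolExecutor-backed generate_async.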

tests/integration/test_lists/qa/examples_test_list.txt

Lines changed: 2 additions & 2 deletions
@@ -453,8 +453,8 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[throughput]
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_8gpus[throughput_tp8]
 accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
 accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False]
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
+accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
 
 test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]
 test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,8 @@ l0_dgx_h100:
       - disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
       - disaggregated/test_disaggregated.py::test_disaggregated_mixed[TinyLlama-1.1B-Chat-v1.0]
       - disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
+      - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False]
+      - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True]
   - condition:
       ranges:
         system_gpu_count:

tests/integration/test_lists/test-db/l0_dgx_h200.yml

Lines changed: 2 additions & 0 deletions
@@ -17,5 +17,7 @@ l0_dgx_h200:
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
       # - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] # OOM
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] # 1h
+      - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True]
+      - accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False]
       - unittest/_torch/multi_gpu_modeling/test_llama4.py::test_llama4[pp1-ep1-enable_graph-tp8-trtllm-scout]
       - unittest/llmapi/test_llm_pytorch.py::test_nemotron_nas_lora

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 0 additions & 2 deletions
@@ -47,8 +47,6 @@ l0_h100:
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True]
       - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency]
       - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency]
-      - accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False]
-      - accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True]
       - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding
       - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
       - test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-instruct-hf-fp8-True-True]

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 2 deletions
@@ -443,8 +443,6 @@ examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padd
 examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-disable_attention_plugin-disable_context_fmha-tp:2-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5234058)
 examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:2-pp:1-float16-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity] SKIP (https://nvbugs/5234058)
 examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-use_attention_plugin-enable_context_fmha-tp:2-pp:1-float16-RobertaForQuestionAnswering-bert/roberta-base-squad2] SKIP (https://nvbugs/5234058)
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[False] SKIP (https://nvbugs/5266257)
-accuracy/test_disaggregated_serving.py::TestLlama3_1_8B::test_auto_dtype[True] SKIP (https://nvbugs/5266257)
 disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5247271)
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_attention_dp_overlap_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugspro.nvidia.com/bug/5273945)
 disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5279438)
