@@ -8,6 +8,7 @@
 import subprocess
 import tempfile
 import time
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any, Dict, List, Optional
 
 import openai
@@ -20,7 +21,7 @@
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 
 from ..conftest import llm_models_root
-from .accuracy_core import MMLU, LlmapiAccuracyTestHarness
+from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
 
 
 class Result(GenerationResultBase):
@@ -41,10 +42,15 @@ def result(self):
 
 class OpenAIServerClient:
 
-    def __init__(self, disaggregated_server_config: Dict[str, Any],
+    def __init__(self,
+                 disaggregated_server_config: Dict[str, Any],
                  ctx_server_config: Dict[str, Any],
-                 gen_server_config: Dict[str, Any], model_name: str):
+                 gen_server_config: Dict[str, Any],
+                 model_name: str,
+                 tensor_parallel_size: int = 1):
+        self.thread_pool = ThreadPoolExecutor(max_workers=16)
         self.temp_dir = tempfile.mkdtemp()
+        self.futures = []
         self.disaggregated_serving_config_path = os.path.join(
             self.temp_dir, "disaggregated_serving_config.yaml")
         with open(self.disaggregated_serving_config_path, "w") as f:
@@ -58,18 +64,26 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
         with open(gen_server_config_path, "w") as f:
             yaml.dump(gen_server_config, f)
 
-        with LLM(model_name) as llm:
+        with LLM(model_name, tensor_parallel_size=tensor_parallel_size) as llm:
             self.args = llm.args
 
+        cuda_device_idx = 0
+        cuda_devices = []
+        for i in range(tensor_parallel_size):
+            cuda_devices.append(f"{cuda_device_idx}")
+            cuda_device_idx += 1
+
         trtllm_serve_path = "trtllm-serve"
         # Common arguments for both servers
         common_args = [
             trtllm_serve_path, model_name, "--host", "localhost", "--backend",
             "pytorch"
         ]
+        if tensor_parallel_size > 1:
+            common_args.append(f"--tp_size={tensor_parallel_size}")
         env_ctx = os.environ.copy()
         env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
-
+        env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_devices)
         # Start the context server
         self._ctx_server = subprocess.Popen(common_args + [
             "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
@@ -78,6 +92,11 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
         # Start the generation server
         env_gen = os.environ.copy()
         env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
+        cuda_devices = []
+        for i in range(tensor_parallel_size):
+            cuda_devices.append(f"{cuda_device_idx}")
+            cuda_device_idx += 1
+        env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_devices)
         self._gen_server = subprocess.Popen(common_args + [
             "--port", "8002", "--extra_llm_api_options", gen_server_config_path
         ],
@@ -86,7 +105,8 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
         # Start the disaggregated server
         self._disaggregated_server = subprocess.Popen([
             trtllm_serve_path, "disaggregated", "-c",
-            self.disaggregated_serving_config_path
+            self.disaggregated_serving_config_path, "--server_start_timeout",
+            "3600"
         ])
         self.model_name = model_name
 
@@ -103,10 +123,7 @@ def __init__(self, disaggregated_server_config: Dict[str, Any],
         self.client = openai.OpenAI(api_key="1234567890",
                                     base_url=f"http://localhost:8000/v1")
 
-    def generate_async(self,
-                       prompt: str,
-                       sampling_params: Optional[SamplingParams] = None):
-        # TODO: Make this async
+    def send_request(self, prompt: str, sampling_params: SamplingParams):
         response = self.client.completions.create(
             model=self.model_name,
             prompt=prompt,
@@ -127,7 +144,18 @@ def generate_async(self,
         setattr(requested_output, "result", result.result)
         return requested_output
 
-    def __del__(self):
+    def generate_async(self,
+                       prompt: str,
+                       sampling_params: Optional[SamplingParams] = None):
+        future = self.thread_pool.submit(self.send_request, prompt,
+                                         sampling_params)
+        self.futures.append(future)
+        return future
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
         shutil.rmtree(self.temp_dir)
         self._ctx_server.terminate()
         self._gen_server.terminate()
@@ -137,10 +165,14 @@ def __del__(self):
         self._gen_server.wait()
         self._disaggregated_server.wait()
 
+        for future in self.futures:
+            future.result()
+        self.thread_pool.shutdown(wait=True)
+
 
-class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
-    MODEL_NAME = "meta-llama/Llama-3.1-8B"
-    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"
+class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
 
     @pytest.mark.skip_less_device_memory(32000)
     @pytest.mark.skip_device_not_contain(["H100", "H200"])
@@ -169,8 +201,49 @@ def test_auto_dtype(self, disable_overlap_scheduler):
                 "urls": ["localhost:8002"]
             }
         }
-        client = OpenAIServerClient(disaggregated_server_config,
-                                    ctx_server_config, gen_server_config,
-                                    self.MODEL_PATH)
-        task = MMLU(self.MODEL_NAME)
-        task.evaluate(client)
+        with OpenAIServerClient(disaggregated_server_config, ctx_server_config,
+                                gen_server_config, self.MODEL_PATH) as client:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(client)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(client)
+
+
+class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
+    MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+    MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"
+
+    @pytest.mark.parametrize("overlap_scheduler", [False, True])
+    def test_auto_dtype(self, overlap_scheduler):
+        ctx_server_config = {
+            "pytorch_backend_config": {
+                "disable_overlap_scheduler": True
+            }
+        }
+        gen_server_config = {
+            "pytorch_backend_config": {
+ "disable_overlap_scheduler" : overlap_scheduler
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with OpenAIServerClient(disaggregated_server_config,
+                                ctx_server_config,
+                                gen_server_config,
+                                self.MODEL_PATH,
+                                tensor_parallel_size=4) as client:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(client)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(client)
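
Note on the concurrency change in this diff: generate_async no longer blocks. It submits send_request to a ThreadPoolExecutor and returns a future, and the context manager drains any outstanding futures before shutting the pool down. Below is a minimal, self-contained sketch of that pattern; the class and function names (AsyncClientSketch, fake echo in send_request) are illustrative stand-ins, not part of the repository.

from concurrent.futures import Future, ThreadPoolExecutor
from typing import List


class AsyncClientSketch:
    """Toy stand-in for OpenAIServerClient's thread-pool pattern."""

    def __init__(self, max_workers: int = 16):
        self.thread_pool = ThreadPoolExecutor(max_workers=max_workers)
        self.futures: List[Future] = []

    def send_request(self, prompt: str) -> str:
        # Placeholder for the blocking completions call to the server.
        return f"echo: {prompt}"

    def generate_async(self, prompt: str) -> Future:
        # Submit the blocking call to the pool and remember the future.
        future = self.thread_pool.submit(self.send_request, prompt)
        self.futures.append(future)
        return future

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drain any requests the caller never awaited, then shut down.
        for future in self.futures:
            future.result()
        self.thread_pool.shutdown(wait=True)


if __name__ == "__main__":
    with AsyncClientSketch() as client:
        pending = [client.generate_async(p) for p in ("a", "b", "c")]
        print([f.result() for f in pending])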