@@ -92,6 +92,7 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
         trtllm_serve_path, model_name, "--host", "localhost", "--backend",
         "pytorch"
     ]
+
     if tensor_parallel_size > 1:
         common_args.append(f"--tp_size={tensor_parallel_size}")
 
@@ -104,18 +105,22 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
         map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
-
-    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
-          popen(common_args + [
-              "--port", "8001", "--extra_llm_api_options",
-              ctx_server_config_path
-          ],
-                env=env_ctx) as ctx_server,
-          popen(common_args + [
-              "--port", "8002", "--extra_llm_api_options",
-              gen_server_config_path
-          ],
-                env=env_gen) as gen_server,
+    ctx_server_args = common_args + [
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+    ]
+    gen_server_args = common_args + [
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+    ]
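+    # Forward an optional per-server "max_num_tokens" setting from the
+    # server configs to trtllm-serve as a CLI flag.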
+    if "max_num_tokens" in ctx_server_config:
+        ctx_server_args.append(
+            f"--max_num_tokens={ctx_server_config['max_num_tokens']}")
+    if "max_num_tokens" in gen_server_config:
+        gen_server_args.append(
+            f"--max_num_tokens={gen_server_config['max_num_tokens']}")
+
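+    # Start the context server (port 8001), the generation server (port
+    # 8002) and the disaggregated frontend as one context-manager stack so
+    # that all processes are torn down together on exit.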
+    with (MyThreadPoolExecutor(max_workers=16) as thread_pool, temp_dir,
+          popen(ctx_server_args, env=env_ctx) as ctx_server,
+          popen(gen_server_args, env=env_gen) as gen_server,
           popen([
               trtllm_serve_path, "disaggregated", "-c",
               disaggregated_serving_config_path, "--server_start_timeout",
@@ -209,9 +214,53 @@ def test_auto_dtype(self, disable_overlap_scheduler):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.parametrize("overlap_scheduler", [False])
+    def test_eagle3(self, overlap_scheduler):
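+        # EAGLE3 speculative decoding: a separate draft model proposes up
+        # to max_draft_len tokens per step for the target model to verify.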
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 4,
+            "pytorch_weights_path":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False
+        }
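+        # The context server always disables the overlap scheduler; the
+        # generation server toggles it via the test parameter.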
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": not overlap_scheduler,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": kv_cache_config,
+            "max_num_tokens": 13393 * 2
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
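+        # One context instance and one generation instance, fronted by the
+        # disaggregated server on localhost:8000.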
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
-@pytest.mark.timeout(3600)
 @pytest.mark.skip_less_device_memory(140000)
+@pytest.mark.timeout(3600)
 class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama4-models/Llama-4-Scout-17B-16E-Instruct"