 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs
 
-from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper
+from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
+                        skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
+from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
+                            get_accuracy_task)
 
 
 class Result(GenerationResultBase):
@@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     temp_dir = tempfile.TemporaryDirectory()
     disaggregated_serving_config_path = os.path.join(
         temp_dir.name, "disaggregated_serving_config.yaml")
+
+    if tensor_parallel_size > 1:
+        print(
+            "Using the unified tp parameter for testing is not recommended. "
+            "Please set parallelism in the server configs instead.")
+
     with open(disaggregated_serving_config_path, "w") as f:
         yaml.dump(disaggregated_server_config, f)
     ctx_server_config_path = os.path.join(temp_dir.name,
@@ -88,27 +96,40 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     trtllm_serve_path = "trtllm-serve"
     # Common arguments for both servers
     common_args = [
-        trtllm_serve_path, model_name, "--host", "localhost", "--backend",
-        "pytorch"
+        trtllm_serve_path,
+        model_name,
+        "--host",
+        "localhost",
+        "--backend",
+        "pytorch",
     ]
-
-    if tensor_parallel_size > 1:
-        common_args.append(f"--tp_size={tensor_parallel_size}")
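+    # Per-server parallelism comes from each server config; the unified
+    # tensor_parallel_size argument is only used as a TP fallback.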
+    gen_tp = gen_server_config.get("tensor_parallel_size",
+                                   tensor_parallel_size)
+    gen_pp = gen_server_config.get("pipeline_parallel_size", 1)
+    ctx_tp = ctx_server_config.get("tensor_parallel_size",
+                                   tensor_parallel_size)
+    ctx_pp = ctx_server_config.get("pipeline_parallel_size", 1)
+
+    ctx_total_gpus = ctx_tp * ctx_pp
+    gen_total_gpus = gen_tp * gen_pp
 
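+    # Give the context server the first ctx_total_gpus devices and the
+    # generation server the next gen_total_gpus devices.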
     env_ctx = os.environ.copy()
     env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
-    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(tensor_parallel_size)))
+    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))
 
     env_gen = os.environ.copy()
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
+        map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
     ctx_server_args = common_args + [
-        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
+        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
     ]
     gen_server_args = common_args + [
-        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
+        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
     ]
     if "max_num_tokens" in ctx_server_config:
         ctx_server_args.append(
@@ -182,6 +203,56 @@ def generate_async(prompt: str,
         disaggregated_server.wait()
 
 
+def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
+                      ctx_tp: int, gen_pp: int, gen_tp: int,
+                      test_set: LlmapiAccuracyTestHarness):
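+    """Launch a disaggregated context/generation server pair with the given
+    TP/PP mapping and evaluate the accuracy task against it."""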
+    if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
+        pytest.fail(
+            f"Not enough devices for ctx_pp={ctx_pp}, ctx_tp={ctx_tp}, "
+            f"gen_pp={gen_pp}, gen_tp={gen_tp}")
+
+    kv_cache_config = {
+        "free_gpu_memory_fraction": 0.5,
+        "enable_block_reuse": False
+    }
+    ctx_server_config = {
+        "pipeline_parallel_size": ctx_pp,
+        "tensor_parallel_size": ctx_tp,
+        "disable_overlap_scheduler": True,
+        "kv_cache_config": kv_cache_config,
+        "cache_transceiver_config": {
+            "backend": "default"
+        }
+    }
+    gen_server_config = {
+        "tensor_parallel_size": gen_tp,
+        "pipeline_parallel_size": gen_pp,
+        "disable_overlap_scheduler": True,
+        "kv_cache_config": kv_cache_config,
+        "cache_transceiver_config": {
+            "backend": "default"
+        }
+    }
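+    # The disaggregated endpoint on :8000 routes requests between a single
+    # context server (:8001) and a single generation server (:8002).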
+    disaggregated_server_config = {
+        "hostname": "localhost",
+        "port": 8000,
+        "backend": "pytorch",
+        "context_servers": {
+            "num_instances": 1,
+            "urls": ["localhost:8001"]
+        },
+        "generation_servers": {
+            "num_instances": 1,
+            "urls": ["localhost:8002"]
+        }
+    }
+    with launch_disaggregated_llm(disaggregated_server_config,
+                                  ctx_server_config, gen_server_config,
+                                  model_path) as llm:
+        task = test_set(model_name)
+        task.evaluate(llm)
+
+
 @pytest.mark.timeout(3600)
 class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
@@ -315,6 +386,20 @@ def test_eagle3(self, overlap_scheduler):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
+    @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
+                             ids=["tp1pp2", "tp2pp1", "tp2pp2"])
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_tp_pp_symmetric(self, tp, pp, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
+                                 tp, get_accuracy_task(testset))
+
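+    # Asymmetric mapping: pipeline-parallel context servers paired with a
+    # tensor-parallel generation server.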
+    @parametrize_with_ids("ctx_pp", [2, 4])
+    @parametrize_with_ids("gen_tp", [1, 2])
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
+                                 gen_tp, get_accuracy_task(testset))
+
 
 @pytest.mark.skip_less_device_memory(140000)
 @pytest.mark.timeout(3600)