from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
from tensorrt_llm._utils import set_mpi_comm
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MpiCommSession
+from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig

cloudpickle.register_pickle_by_value(sys.modules[__name__])
MPI.pickle.__init__(
@@ -33,6 +34,11 @@ def model_path(model_name):
    elif 'TinyLlama-1.1B-Chat-v1.0' in model_name:
        return os.path.join(llm_models_root, 'llama-models-v2',
                            'TinyLlama-1.1B-Chat-v1.0')
+    elif 'Llama-3.1-8B-Instruct' in model_name:
+        return os.path.join(llm_models_root, 'llama-3.1-model',
+                            'Llama-3.1-8B-Instruct/')
+    elif 'EAGLE3-LLaMA3.1-Instruct-8B' in model_name:
+        return os.path.join(llm_models_root, 'EAGLE3-LLaMA3.1-Instruct-8B')
    else:
        raise ValueError(f"Unknown model: {model_name}")
@@ -317,5 +323,106 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
    print("All workers terminated.")


+@pytest.mark.parametrize("model", ["Llama-3.1-8B-Instruct"])
+@pytest.mark.parametrize("spec_dec_model_path", ["EAGLE3-LLaMA3.1-Instruct-8B"])
+@pytest.mark.parametrize("generation_overlap", [False])
+def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
+                                                 generation_overlap):
+    # Test whether the batch slots are properly released when using speculative
+    # decoding with disaggregated serving.
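+    # EAGLE3 in two-model mode (eagle3_one_model=False): a separate draft model
+    # proposes up to max_draft_len tokens per step for the target to verify.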
+    spec_dec_config = EagleDecodingConfig(
+        speculative_model_dir=model_path(spec_dec_model_path),
+        eagle3_one_model=False,
+        max_draft_len=3)
+
+    worker_pytorch_configs = []
+
+    # Context worker
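+    # Runs with the overlap scheduler disabled; max_batch_size=1 leaves a
+    # single batch slot for requests to contend for.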
+    worker_pytorch_configs.append(
+        dict(disable_overlap_scheduler=True,
+             kv_cache_dtype="auto",
+             speculative_config=spec_dec_config,
+             max_batch_size=1))
+
+    # Generation worker
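+    # Overlap scheduling here follows the generation_overlap parameter (only
+    # False is exercised above).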
+    worker_pytorch_configs.append(
+        dict(disable_overlap_scheduler=not generation_overlap,
+             kv_cache_dtype="auto",
+             speculative_config=spec_dec_config,
+             max_batch_size=1))
+
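+    # A small KV cache with block reuse disabled keeps capacity tight, so a
+    # leaked batch slot or KV block would surface quickly.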
+    kv_cache_configs = [
+        KvCacheConfig(max_tokens=128, enable_block_reuse=False)
+        for _ in range(2)
+    ]
+    model_names = [model_path(model) for _ in range(2)]
+    ranks = [0, 1]
+    worker_args = list(
+        zip(kv_cache_configs, worker_pytorch_configs, model_names, ranks))
+
+    port_name = MPI.Open_port()
+    MPI.Publish_name('my_port', port_name)
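+    # Publish the port so the workers can connect back; the intercommunicator
+    # accepted below carries requests and results.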
+
+    prompt = "What is the capital of Germany?"
+
+    with MPIPoolExecutor(max_workers=2,
+                         env={"TRTLLM_USE_MPI_KVCACHE": "1"}) as executor:
+        futures = []
+        try:
+            for worker_arg in worker_args:
+                future = executor.submit(worker_entry_point, *worker_arg)
+                futures.append(future)
+        except Exception as e:
+            print(f"Error in worker {worker_arg}: {e}")
+            raise e
+
+        try:
+            print("Launched all the workers.")
+            intercomm = MPI.COMM_SELF.Accept(port_name)
+
+            for _ in range(2):
+                intercomm.recv(tag=MPI_READY)
+                print("Received ready signal.")
+            max_tokens = 25
+
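+            # Ten context-only requests funnel through the single batch slot
+            # on worker 0; each must release its slot on completion or the
+            # later requests would stall.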
+            requests = []
+            for _ in range(10):
+                requests.append(
+                    (prompt, SamplingParams(max_tokens=1, ignore_eos=True),
+                     DisaggregatedParams(request_type="context_only")))
+
+            intercomm.send(requests, dest=0, tag=MPI_REQUEST)
+
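+            # Each context-only result should carry the disaggregated handoff
+            # state plus exactly one token (the first token).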
+            for _ in range(len(requests)):
+                output = intercomm.recv(source=0, tag=MPI_RESULT)
+                assert output[0].disaggregated_params is not None
+                assert output[0].disaggregated_params.request_type == \
+                    "context_only"
+                assert len(output[0].token_ids) == 1
+
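+            # Reuse the params from the last context response to continue the
+            # same request as generation-only on worker 1.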
+            generation_request_disagg_params = output[0].disaggregated_params
+            generation_request_disagg_params.request_type = "generation_only"
+            requests = []
+            requests.append((prompt,
+                             SamplingParams(max_tokens=max_tokens,
+                                            ignore_eos=True),
+                             generation_request_disagg_params))
+
+            intercomm.send(requests, dest=1, tag=MPI_REQUEST)
+            output = intercomm.recv(source=1, tag=MPI_RESULT)
+
+        finally:
+            # Send termination requests
+            intercomm.send(None, dest=0, tag=MPI_REQUEST)
+            intercomm.send(None, dest=1, tag=MPI_REQUEST)
+            print("Sent termination requests to the workers.")
+
+            # Wait for all futures to complete
+            for future in futures:
+                future.result()
+            print("All workers terminated.")
+
+
if __name__ == "__main__":
    pytest.main()