44import time
55from concurrent .futures import ProcessPoolExecutor
66
7+ from test_worker_base import TestWorkerBase
8+
79from tensorrt_llm .executor .request import GenerationRequest
810from tensorrt_llm .executor .rpc import RPCClient
911from tensorrt_llm .executor .rpc_proxy import GenerationExecutorRpcProxy
2022
2123class TestRpcWorker :
2224
    def __init__(self):
        """Build a fake executor config shared by the worker tests.

        NOTE(review): pytest does not collect test classes that define
        ``__init__``; if this class is meant to be collected by pytest,
        this setup probably belongs in ``setup_method`` — confirm.
        """
        # `model_path` is presumably a module-level constant defined above
        # this chunk — TODO confirm (it is not defined in this scope).
        self.executor_config = TestWorkerBase.create_fake_executor_config(
            model_path)
2329 def create_tp1_worker_process (self ):
2430 addr = GenerationExecutorRpcProxy .gen_uniq_rpc_addr ()
2531 # Use spawn method instead of fork
2632 mp_context = multiprocessing .get_context ('spawn' )
2733 pool = ProcessPoolExecutor (max_workers = 1 , mp_context = mp_context )
28- pool .submit (RpcWorker .main_task , engine = model_path , rpc_addr = addr )
34+ pool .submit (RpcWorker .main_task ,
35+ engine = model_path ,
36+ rpc_addr = addr ,
37+ executor_config = self .executor_config )
2938 return pool , addr
3039
3140 def create_rpc_client (self , addr : str ):
@@ -35,15 +44,53 @@ def create_rpc_client(self, addr: str):
3544 def test_main (self ):
3645 pool , addr = self .create_tp1_worker_process ()
3746 client = self .create_rpc_client (addr )
38- client .setup_engine (engine = model_path )
47+ print ("call setup_engine" )
48+ client .setup_engine (engine = model_path ,
49+ executor_config = self .executor_config ,
50+ __rpc_timeout = 120 )
51+ print ("call submit" )
3952 time .sleep (1 )
40- client .submit (
41- GenerationRequest (prompt_token_ids = [3 , 4 , 5 ],
42- sampling_params = SamplingParams (max_tokens = 10 )))
43- responses = client .fetch_responses ()
44- assert responses
4553
46- client .shutdown ()
54+ def process_request ():
55+ ret = client .submit (GenerationRequest (
56+ prompt_token_ids = [3 , 4 , 5 ],
57+ sampling_params = SamplingParams (max_tokens = 10 )),
58+ __rpc_need_response = False )
59+ assert ret is None
60+
61+ print (f"submit result: { ret } " )
62+ print ("call fetch_responses" )
63+ # NOTE: known issue, the responses should be fetched before shutdown,
64+ # or the shutdown will hang.
65+ results = []
66+ for i in range (3 ):
67+ time .sleep (3 )
68+ results .extend (client .fetch_responses ())
69+ print (f"fetch_responses result: { results } " )
70+ assert len (results ) == 1
71+
72+ def process_request_streaming ():
73+ ret = client .submit (prompt_token_ids = [3 , 4 , 5 ],
74+ sampling_params = SamplingParams (max_tokens = 10 ),
75+ streaming = True ,
76+ __rpc_need_response = False )
77+ assert ret is None
78+
79+ print ("call fetch_responses" )
80+ # NOTE: known issue, the responses should be fetched before shutdown,
81+ # or the shutdown will hang.
82+ results = []
83+ for i in range (3 ):
84+ time .sleep (3 )
85+ results .extend (client .fetch_responses ())
86+ print (f"fetch_responses result: { results } " )
87+ print (f"generate_async result: { results } " )
88+
89+ process_request ()
90+ process_request_streaming ()
91+
92+ print ("call shutdown" )
93+ client .shutdown (__rpc_timeout = 10 )
4794 pool .shutdown ()
4895
4996
0 commit comments