@@ -19,6 +19,7 @@ def __init__(
         max_total_token_num,
         batch_max_tokens,
         eos_id,
+        model,
         log_stats=True,
         log_stats_interval=10,
         running_batch: Batch = None,
@@ -30,6 +31,7 @@ def __init__(
             batch_max_tokens : max tokens of one batch, defaults to (max input + output len) * num_requests
             running_max_req_size : max request size of running batch, equals MAX_BATCH_SIZE of the tp engine
             eos_id : the end token of a sequence
+            model : the model weight dir path; the app will load the config, weights and tokenizer from this dir
             log_stats : whether to log stats
             log_stats_interval : log stats interval
             running_batch : running batch
@@ -45,31 +47,32 @@ def __init__(
         self.eos_id = eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = 10
+        self.model = model
 
         self.stats_tool = Stats(log_stats, log_stats_interval)
         self.mem_usage_interval = log_stats_interval * 2
+        self._set_tokenizer(tokenizer_name=self.model)
 
-    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
+    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str):
         """
         Add new request to req queue; during initialization all requests are held in the waiting list.
         """
-        req = Req(request_id, prompt_ids, sampling_params)
+        req = Req(request_id, prompt_ids, sampling_params, prompts)
         self.req_queue.append(req)
         return
 
-    def add_input(self, request_id, sampling_params, input_ids):
+    def add_input(self, request_id, sampling_params, prompts):
         """
         Encode and add a new input to the req queue. Supports one sequence input for now.
         """
-        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_ids = self.tokenizer.encode(prompts)
         prompt_len = len(prompt_ids)
-        print(prompt_ids)
         if prompt_len > self.engine.max_input_len:
             raise ValueError(
                 f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
             )
         sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id)
+        self.add_req(prompt_ids, sampling_params, request_id, prompts)
         return
 
     def abort(self, request_id):
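
The constructor now also resolves a tokenizer from the `model` weight dir via `self._set_tokenizer(...)`. The helper itself is not shown in this diff, so the following is only a minimal sketch of what it might do, assuming a Hugging Face `AutoTokenizer` backend; the method name matches the call site above, but the body is an assumption.

```python
from transformers import AutoTokenizer

class DynamicBatchManager:
    def _set_tokenizer(self, tokenizer_name: str):
        # Sketch only: the real helper is not part of this diff.
        # `tokenizer_name` is the model weight dir passed in as `model`;
        # AutoTokenizer resolves the tokenizer files from that path.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
```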
@@ -90,7 +93,7 @@ def loop_for_fwd(self):
         """
         counter_count = 0
         # self.running_batch is not None or self.req_queue.waiting_req_list
-        while True:
+        while self.running_batch is not None or self.req_queue.waiting_req_list:
             yield from self._step()
             counter_count += 1
             if self.running_batch is not None:
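
Replacing `while True` with an explicit guard makes `loop_for_fwd` a finite generator: it keeps stepping while there is a running batch or requests in the waiting list, and exhausts once both are empty. A hedged consumption sketch (the `manager` and `sampling_params` setup is illustrative, not from this diff):

```python
# Assumes `manager` is an initialized DynamicBatchManager and
# `sampling_params` a valid SamplingParams instance (both illustrative).
manager.add_input(request_id="0", sampling_params=sampling_params, prompts="Hello")

# The generator now terminates on its own once the running batch is done
# and the waiting list is drained, so a plain for-loop suffices.
for text in manager.loop_for_fwd():
    print(text)
```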
@@ -267,17 +270,17 @@ def _output_process(self, finished_reqs: List[Req]):
         """
         for req in finished_reqs:
             output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
+            yield req.prompts + output
 
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
-    def generate(self, request_id, prompt_id, sampling_params):
+    def generate(self, prompts, sampling_params, request_id):
         """
         Generate the output of a request.
         """
-        self.add_input(request_id, prompt_id, sampling_params)
+        self.add_input(request_id, sampling_params, prompts)
         return self.loop_for_fwd()
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
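
`generate` now takes the raw prompt string and threads it through `add_input`, and `_output_process` yields the prompt concatenated with the decoded completion rather than an `(output, request_id, metadata)` tuple. A usage sketch under those assumptions (`manager` and `sampling_params` are assumed to exist):

```python
# Illustrative call with the new signature.
outputs = manager.generate(
    prompts="What is the capital of France?",
    sampling_params=sampling_params,
    request_id="0",
)
for text in outputs:
    # Each finished request yields the original prompt with the
    # decoded completion appended.
    print(text)
```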
@@ -287,6 +290,7 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
             max_total_token_num=args.max_total_token_num,
             batch_max_tokens=args.batch_max_tokens,
             eos_id=args.eos_id,
+            model=args.model,
             log_stats=not args.disable_log_stats,
             log_stats_interval=args.log_stats_interval,
             waiting_req_list=waiting_req_list,
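
For reference, the `args` object consumed here must carry the fields read in this hunk, with the new `model` entry pointing at the weight dir. A minimal sketch of such a namespace; the concrete values below are placeholders, not defaults from the codebase:

```python
from argparse import Namespace

# Placeholder values for illustration only; `tp_engine` and the
# waiting list must come from the caller's setup.
args = Namespace(
    model="my_model_dir",          # dir with config, weights and tokenizer
    max_total_token_num=4096,
    batch_max_tokens=2048,
    eos_id=2,
    disable_log_stats=False,
    log_stats_interval=10,
)
start_dynamic_batching(args, tp_engine=tp_engine, waiting_req_list=[])
```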