 import time
 from typing import List
+import asyncio

 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
 from .dynamic_batching.req_queue import ReqQueue
 from .dynamic_batching.sampling_params import SamplingParams
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine

+from transformers import AutoTokenizer
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

 class DynamicBatchManager:
     def __init__(
@@ -54,6 +57,20 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques
         self.req_queue.append(req)
         return

+    def add_input(self, request_id, sampling_params, input_ids):
+        """
+        Encode a new input and add it to the request queue.
+        Only single-sequence input is supported for now.
+        """
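+        # NOTE: despite its name, `input_ids` is the raw prompt text; it is tokenized below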
+        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_len = len(prompt_ids)
+        if prompt_len > self.engine.max_input_len:
+            raise ValueError(
+                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
+            )
+        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
+        self.add_req(prompt_ids, sampling_params, request_id)
+        return
+
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
@@ -66,13 +83,15 @@ def abort(self, request_id):
                 req.aborted = True
         return

-    def loop_for_fwd(self):
+    async def loop_for_fwd(self):
         """
         The main loop for a dynamic batching process.
         """
         counter_count = 0
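+        # outputs are streamed to the caller as soon as a request finishes,
+        # rather than being collected after the loop exits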
-        while self.running_batch is not None or self.req_queue.waiting_req_list:
-            self._step()
+        # previously the loop exited once the queue drained:
+        # self.running_batch is not None or self.req_queue.waiting_req_list
+        while True:
+            # _step is a plain generator, so iterate it synchronously
+            for item in self._step():
+                yield item
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
@@ -87,6 +106,26 @@ def loop_for_fwd(self):
             if self.running_batch is None:
-                time.sleep(0.1)  # 10ms
+                await asyncio.sleep(0.1)  # idle for 100 ms without blocking the event loop

+    def _set_tokenizer(
+        self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
+    ):
+        if tokenizer is not None:
+            self.tokenizer = tokenizer
+        else:
+            if "llama" in tokenizer_name.lower() and use_fast:
+                print(
+                    "For some LLaMA-based models, initializing the fast tokenizer may "
+                    "take a long time. To eliminate the initialization time, consider "
+                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                    "tokenizer. This is done automatically in Colossalai."
+                )
+                tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )
+            except TypeError:
+                # fall back to the slow tokenizer if loading the fast one fails
+                use_fast = False
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )
+
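+    # The methods below form a generator pipeline: _prefill_batch/_decode_batch
+    # yield from _handle_finish_req, which yields decoded outputs produced by
+    # _output_process, and _step forwards them up to loop_for_fwd.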
     def _step(self):
         """
         Logic for handling requests
@@ -97,32 +136,33 @@ def _step(self):
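+            # no batch is running: build one from the wait queue and prefill it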
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                self._prefill_batch(self.running_batch)
+                yield from self._prefill_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return

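+        # a batch is already running: keep decoding for up to max_wait_tokens
+        # steps, then try to prefill a new mini-batch and merge it in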
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            self._decode_batch(self.running_batch)
+            yield from self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                self._prefill_batch(new_mini_batch)
+                yield from self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                 self.has_wait_tokens = 0
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                self._decode_batch(self.running_batch)
+                yield from self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1

         return
127167
128168    def  _init_batch (self , batch : Batch , dtype = "fp16" ):
@@ -158,7 +198,8 @@ def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
         # delete finished reqs

164205    def  _decode_batch (self , batch : Batch ):
@@ -169,7 +210,7 @@ def _decode_batch(self, batch: Batch):
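+        # same tail as _prefill_batch: record the new token for each request,
+        # mark finished ones, and stream their outputs upward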
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)

     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -201,11 +242,13 @@ def _remove_batch(self, batch):

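+    # filter_finished() returns the newly finished requests so their decoded
+    # outputs can be yielded to the caller instead of being discarded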
     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            batch.filter_finished()
+            finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
+            yield from self._output_process(finished_reqs)

     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -218,26 +261,47 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
             req.output_metadata_list.append(new_gen_metadata)
         return

+    def _output_process(self, finished_reqs: List[Req]):
+        """
+        Decode each finished request and yield (output, request_id, output_metadata_list).
+        """
+        # a plain (sync) generator so _handle_finish_req can `yield from` it
+        for req in finished_reqs:
+            output = self.tokenizer.decode(req.output_ids)
+            yield output, req.request_id, req.output_metadata_list
+
     def clean_up(self):
         # this logic should be implemented in the future.
         pass

+    async def generate(self, request_id, prompt_id, sampling_params):
+        """
+        Enqueue a request for generation; its output is streamed by loop_for_fwd.
+        """
+        # add_input expects (request_id, sampling_params, input_ids)
+        self.add_input(request_id, sampling_params, prompt_id)
+

 def start_dynamic_batching(args, tp_engine, waiting_req_list):
-    # try:
-    batch_manager = DynamicBatchManager(
-        tp_engine=tp_engine,
-        max_total_token_num=args.max_total_token_num,
-        batch_max_tokens=args.batch_max_tokens,
-        eos_id=args.eos_id,
-        log_stats=not args.disable_log_stats,
-        log_stats_interval=args.log_stats_interval,
-        waiting_req_list=waiting_req_list,
-    )
-
-    # except Exception:
-    #     batch_manager.clean_up()
-    #     raise
-
-    batch_manager.loop_for_fwd()
-    return
+    batch_manager = DynamicBatchManager(
+        tp_engine=tp_engine,
+        max_total_token_num=args.max_total_token_num,
+        batch_max_tokens=args.batch_max_tokens,
+        eos_id=args.eos_id,
+        log_stats=not args.disable_log_stats,
+        log_stats_interval=args.log_stats_interval,
+        waiting_req_list=waiting_req_list,
+    )
+    # the model class name (e.g. "LlamaForCausalLM") routes llama models to the
+    # fast-tokenizer shortcut inside _set_tokenizer
+    batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
+
+    async def _print_outputs():
+        # drain the forward loop, printing each (output, request_id, metadata) tuple
+        async for item in batch_manager.loop_for_fwd():
+            print(item)
+
+    try:
+        # add_input is synchronous; no task scheduling is needed to enqueue
+        batch_manager.add_input(4, sampling_params=SamplingParams(), input_ids="hello world")
+        asyncio.run(_print_outputs())
+    except Exception:
+        batch_manager.clean_up()
+        raise
+
+    return batch_manager