 import time
 from typing import List
-import asyncio
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
-from transformers import AutoTokenizer
-_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 class DynamicBatchManager:
     def __init__(
@@ -57,20 +54,6 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques
         self.req_queue.append(req)
         return
 
-    def add_input(self, request_id, sampling_params, input_ids):
-        """
-        Encode and add new input to the req queue. Supports one sequence input for now.
-        """
-        prompt_ids = self.tokenizer.encode(input_ids)
-        prompt_len = len(prompt_ids)
-        if prompt_len > self.engine.max_input_len:
-            raise ValueError(
-                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
-            )
-        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id)
-        return
-
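With add_input removed, prompt encoding now has to happen before a request reaches the manager. A minimal caller-side sketch, assuming a Hugging Face tokenizer and the SamplingParams type from the deleted signature; the import path, the batch_manager instance, the request id, and the 512-token cap are illustrative stand-ins, not part of this diff:

from transformers import AutoTokenizer
from .dynamic_batching.sampling_params import SamplingParams  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
prompt_ids = tokenizer.encode("hello world")
if len(prompt_ids) > 512:  # stand-in for engine.max_input_len
    raise ValueError(f"the input prompt token len {len(prompt_ids)} is too long")
sampling_params = SamplingParams()
sampling_params.stop_sentences_to_token_ids(tokenizer)  # same call the old add_input made
batch_manager.add_req(prompt_ids, sampling_params, 0)  # batch_manager: an existing DynamicBatchManager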
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
@@ -83,15 +66,13 @@ def abort(self, request_id):
                 req.aborted = True
         return
 
-    async def loop_for_fwd(self):
+    def loop_for_fwd(self):
         """
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        # self.running_batch is not None or self.req_queue.waiting_req_list
-        while True:
-            async for item in self._step():
-                yield item
+        while self.running_batch is not None or self.req_queue.waiting_req_list:
+            self._step()
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
@@ -106,26 +87,6 @@ async def loop_for_fwd(self):
             if self.running_batch is None:
                 time.sleep(0.1)  # 100ms
 
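Because loop_for_fwd is now a plain blocking method instead of an async generator, the simplest calling pattern is to seed the waiting queue first and let the loop run until both the running batch and the queue drain. A sketch of that pattern; batch_manager and prepared_requests are assumed to exist and are hypothetical names:

# Seed requests first; loop_for_fwd returns once running_batch is None
# and req_queue.waiting_req_list is empty.
for prompt_ids, params, req_id in prepared_requests:  # prepared_requests: hypothetical iterable
    batch_manager.add_req(prompt_ids, params, req_id)
batch_manager.loop_for_fwd()  # blocks until all seeded requests finish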
-    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True):
-        if tokenizer is not None:
-            self.tokenizer = tokenizer
-        else:
-            if "llama" in tokenizer_name.lower() and use_fast == True:
-                print(
-                    "For some LLaMA-based models, initializing the fast tokenizer may "
-                    "take a long time. To eliminate the initialization time, consider "
-                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai.")
-
-                tokenizer_name = _FAST_LLAMA_TOKENIZER
-
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
-            except TypeError as e:
-                use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False, trust_remote_code=trust_remote_code)
-
-
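The tokenizer setup that _set_tokenizer bundled into the manager can live in caller code instead. A sketch reproducing the deleted fallback logic, using only the calls and the _FAST_LLAMA_TOKENIZER constant that appeared above; build_tokenizer is a hypothetical helper name:

from transformers import AutoTokenizer

_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

def build_tokenizer(tokenizer_name: str, trust_remote_code: bool = False, use_fast: bool = True):
    # Substitute the community fast tokenizer for LLaMA models, as the
    # removed _set_tokenizer did, to avoid its slow initialization.
    if "llama" in tokenizer_name.lower() and use_fast:
        tokenizer_name = _FAST_LLAMA_TOKENIZER
    try:
        return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
    except TypeError:
        # No fast implementation available; fall back to the slow tokenizer.
        return AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False, trust_remote_code=trust_remote_code)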
     def _step(self):
         """
         Logic for handling requests
@@ -136,33 +97,32 @@ def _step(self):
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                yield from self._prefill_batch(self.running_batch)
+                self._prefill_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                yield from self._prefill_batch(new_mini_batch)
+                self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                     self.has_wait_tokens = 0
-
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                yield from self._decode_batch(self.running_batch)
+                self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
-
+
         return
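_step encodes the whole scheduling policy: prefill when idle, decode for up to max_wait_tokens iterations, then try to admit a new mini-batch and merge it before decoding again. A stripped-down, runnable restatement of that policy; ToyBatch, ToyQueue, and step are hypothetical stand-ins, not part of this codebase:

class ToyBatch:
    def merge(self, other):  # stand-in for Batch.merge
        pass

class ToyQueue:
    def __init__(self, n):  # n pending prompt batches
        self.pending = n
    def generate_new_batch(self, running):
        if self.pending == 0:
            return None
        self.pending -= 1
        return ToyBatch()

def step(state, queue, max_wait_tokens=10):
    # state: dict with keys "running_batch" and "has_wait_tokens"
    if state["running_batch"] is None:
        state["running_batch"] = queue.generate_new_batch(None)  # prefill path
        state["has_wait_tokens"] = 0
        return "prefill" if state["running_batch"] else "idle"
    if state["has_wait_tokens"] < max_wait_tokens:
        state["has_wait_tokens"] += 1  # decode path
        return "decode"
    mini = queue.generate_new_batch(state["running_batch"])
    if mini is not None:
        state["running_batch"].merge(mini)  # admit new prompts mid-stream
        state["has_wait_tokens"] = 0
        return "prefill+merge"
    state["has_wait_tokens"] += 1
    return "decode"

state = {"running_batch": None, "has_wait_tokens": 0}
q = ToyQueue(2)
print([step(state, q, max_wait_tokens=3) for _ in range(8)])
# ['prefill', 'decode', 'decode', 'decode', 'prefill+merge', 'decode', 'decode', 'decode']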
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -198,8 +158,7 @@ def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
-
+        self._handle_finish_req(batch, has_new_finished_req)
         # delete finished reqs
 
     def _decode_batch(self, batch: Batch):
@@ -210,7 +169,7 @@ def _decode_batch(self, batch: Batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        self._handle_finish_req(batch, has_new_finished_req)
 
     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -242,13 +201,11 @@ def _remove_batch(self, batch):
 
     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            finished_reqs = batch.filter_finished()
+            batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
-
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -261,47 +218,26 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
             req.output_metadata_list.append(new_gen_metadata)
         return
 
-    async def _output_process(self, finished_reqs: List[Req]):
-        """
-        Process the output of a batch.
-        """
-        for req in finished_reqs:
-            output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
-
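With _output_process gone, _handle_finish_req no longer yields decoded text, so detokenization becomes the caller's responsibility. A sketch mirroring the deleted generator as a plain function the caller owns, reading the same Req fields (output_ids, request_id, output_metadata_list); collect_outputs is a hypothetical name:

def collect_outputs(finished_reqs, tokenizer):
    # Same logic as the removed _output_process, run outside the manager.
    for req in finished_reqs:
        output = tokenizer.decode(req.output_ids)
        yield output, req.request_id, req.output_metadata_list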
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
-    async def generate(self, request_id, prompt_id, sampling_params):
-        """
-        Generate the output of a request.
-        """
-        self.add_input(request_id, prompt_id, sampling_params)
-
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
-    try:
-        batch_manager = DynamicBatchManager(
-            tp_engine=tp_engine,
-            max_total_token_num=args.max_total_token_num,
-            batch_max_tokens=args.batch_max_tokens,
-            eos_id=args.eos_id,
-            log_stats=not args.disable_log_stats,
-            log_stats_interval=args.log_stats_interval,
-            waiting_req_list=waiting_req_list,
-        )
-
-    except Exception:
-        batch_manager.clean_up()
-        raise
-
-    batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
-    prod_task = asyncio.create_task(batch_manager.add_input(4, sampling_params=SamplingParams(), input_ids="hello world"))
-
-    asyncio.run(prod_task)
-
-    for item in batch_manager.loop_for_fwd():
-        print(item)
-
-    return batch_manager
+    # try:
+    batch_manager = DynamicBatchManager(
+        tp_engine=tp_engine,
+        max_total_token_num=args.max_total_token_num,
+        batch_max_tokens=args.batch_max_tokens,
+        eos_id=args.eos_id,
+        log_stats=not args.disable_log_stats,
+        log_stats_interval=args.log_stats_interval,
+        waiting_req_list=waiting_req_list,
+    )
+
+    # except Exception:
+    #     batch_manager.clean_up()
+    #     raise
+
+    batch_manager.loop_for_fwd()
+    return
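A minimal invocation sketch for the new entry point, assuming an argparse-style args object carrying the fields read above; all values are placeholders, not defaults from this codebase:

from types import SimpleNamespace

args = SimpleNamespace(
    max_total_token_num=4096,   # placeholder capacity
    batch_max_tokens=2048,      # placeholder per-batch token cap
    eos_id=2,                   # placeholder end-of-sequence id
    disable_log_stats=False,
    log_stats_interval=10,
)
# tp_engine: an initialized TPInferEngine; waiting_req_list: a pre-built list of Req objects.
start_dynamic_batching(args, tp_engine=tp_engine, waiting_req_list=waiting_req_list)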