
Commit 78cd937

Revert "[inference] Async dynamic batching (#4894)" (#4909)
This reverts commit fced140.
1 parent d509e79 commit 78cd937

4 files changed: +32 -139 lines changed

colossalai/inference/dynamic_batching/io_struct.py

Lines changed: 2 additions & 6 deletions
@@ -102,21 +102,17 @@ def mark_finished_req(self, eos_id):
                 has_new_finish = True
         return has_new_finish
 
-    def filter_finished(self)->List[Req]:
+    def filter_finished(self):
         """
         Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
         """
         # TODO: the logic of return should be defined here.
         unfinished_req = []
-        finished_req = []
         for req in self.reqs:
             if not req.has_generate_finished:
-                unfinished_req.append(req)
-            else:
-                finished_req.append(req)
+                unfinished_req.append(req)
         self.reqs = unfinished_req
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
-        return finished_req
 
     def is_clear(self):
         return len(self.reqs) == 0
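
Note on the reverted io_struct.py change: filter_finished() now prunes completed requests in place and returns nothing, so callers can no longer collect the finished requests from its return value. A minimal runnable sketch of that in-place behaviour, using toy Req/Batch stand-ins rather than the real ColossalAI classes:

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Req:
    # Toy stand-in: only the fields filter_finished touches.
    request_id: int
    has_generate_finished: bool = False

@dataclass
class Batch:
    reqs: List[Req] = field(default_factory=list)
    id_to_reqs: Dict[int, Req] = field(default_factory=dict)

    def filter_finished(self) -> None:
        # Reverted behaviour: keep only unfinished requests, return nothing.
        unfinished_req = []
        for req in self.reqs:
            if not req.has_generate_finished:
                unfinished_req.append(req)
        self.reqs = unfinished_req
        self.id_to_reqs = {req.request_id: req for req in self.reqs}

    def is_clear(self) -> bool:
        return len(self.reqs) == 0

batch = Batch(reqs=[Req(0, has_generate_finished=True), Req(1)])
batch.filter_finished()
assert [r.request_id for r in batch.reqs] == [1] and not batch.is_clear()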

colossalai/inference/manager.py

Lines changed: 28 additions & 92 deletions
@@ -1,6 +1,5 @@
 import time
 from typing import List
-import asyncio
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
@@ -9,8 +8,6 @@
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
-from transformers import AutoTokenizer
-_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 class DynamicBatchManager:
     def __init__(
@@ -57,20 +54,6 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques
         self.req_queue.append(req)
         return
 
-    def add_input(self, request_id, sampling_params, input_ids):
-        """
-        Encode and Add new input to req queue. support one sequence input for now.
-        """
-        prompt_ids = self.tokenizer.encode(input_ids)
-        prompt_len = len(prompt_ids)
-        if prompt_len > self.engine.max_input_len:
-            raise ValueError(
-                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
-            )
-        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id)
-        return
-
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
@@ -83,15 +66,13 @@ def abort(self, request_id):
                 req.aborted = True
         return
 
-    async def loop_for_fwd(self):
+    def loop_for_fwd(self):
         """
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        #self.running_batch is not None or self.req_queue.waiting_req_list
-        while True:
-            async for item in self._step():
-                yield item
+        while self.running_batch is not None or self.req_queue.waiting_req_list:
+            self._step()
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
@@ -106,26 +87,6 @@ async def loop_for_fwd(self):
             if self.running_batch is None:
                 time.sleep(0.1)  # 10ms
 
-    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
-        if tokenizer is not None:
-            self.tokenizer = tokenizer
-        else:
-            if "llama" in tokenizer_name.lower() and use_fast == True:
-                print(
-                    "For some LLaMA-based models, initializing the fast tokenizer may "
-                    "take a long time. To eliminate the initialization time, consider "
-                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai.")
-
-                tokenizer_name = _FAST_LLAMA_TOKENIZER
-
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-            except TypeError as e:
-                use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-
-
     def _step(self):
         """
         Logic for handling requests
@@ -136,33 +97,32 @@ def _step(self):
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                yield from self._prefill_batch(self.running_batch)
+                self._prefill_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                yield from self._prefill_batch(new_mini_batch)
+                self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                 self.has_wait_tokens = 0
-
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                yield from self._decode_batch(self.running_batch)
+                self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
-
+
         return
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -198,8 +158,7 @@ def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
-
+        self._handle_finish_req(batch, has_new_finished_req)
         # delete finished reqs
 
     def _decode_batch(self, batch: Batch):
@@ -210,7 +169,7 @@ def _decode_batch(self, batch: Batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        self._handle_finish_req(batch, has_new_finished_req)
 
     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -242,13 +201,11 @@ def _remove_batch(self, batch):
 
     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            finished_reqs=batch.filter_finished()
+            batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
-
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -261,47 +218,26 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
             req.output_metadata_list.append(new_gen_metadata)
         return
 
-    async def _output_process(self, finished_reqs: List[Req]):
-        """
-        Process the output of a batch.
-        """
-        for req in finished_reqs:
-            output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
-
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
-    async def generate(self,request_id,prompt_id,sampling_params):
-        """
-        Generate the output of a request.
-        """
-        self.add_input(request_id,prompt_id,sampling_params)
-
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
-    try:
-        batch_manager = DynamicBatchManager(
-            tp_engine=tp_engine,
-            max_total_token_num=args.max_total_token_num,
-            batch_max_tokens=args.batch_max_tokens,
-            eos_id=args.eos_id,
-            log_stats=not args.disable_log_stats,
-            log_stats_interval=args.log_stats_interval,
-            waiting_req_list=waiting_req_list,
-        )
-
-    except Exception:
-        batch_manager.clean_up()
-        raise
-
-    batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__)
-    prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))
-
-    asyncio.run(prod_task)
-
-    for item in batch_manager.loop_for_fwd():
-        print(item)
-
-    return batch_manager
+    # try:
+    batch_manager = DynamicBatchManager(
+        tp_engine=tp_engine,
+        max_total_token_num=args.max_total_token_num,
+        batch_max_tokens=args.batch_max_tokens,
+        eos_id=args.eos_id,
+        log_stats=not args.disable_log_stats,
+        log_stats_interval=args.log_stats_interval,
+        waiting_req_list=waiting_req_list,
+    )
+
+    # except Exception:
+    #     batch_manager.clean_up()
+    #     raise
+
+    batch_manager.loop_for_fwd()
+    return
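
Note on the reverted manager.py change: loop_for_fwd() goes back to being a plain blocking loop that calls _step() until both the running batch and the waiting queue are drained, and the yield from chain through _prefill_batch / _decode_batch / _handle_finish_req is dropped, so no decoded outputs are streamed back to the caller. A rough, self-contained mirror of that synchronous control flow, using toy stand-ins (ToyManager, string requests) rather than the real ColossalAI API:

import time
from collections import deque

class ToyManager:
    # Toy mirror of the reverted synchronous control flow, not the real engine.
    def __init__(self, waiting_req_list):
        self.waiting = deque(waiting_req_list)  # stands in for req_queue.waiting_req_list
        self.running_batch = None               # stands in for the current running Batch

    def loop_for_fwd(self):
        # Blocking loop: run until the waiting queue and the running batch drain.
        while self.running_batch is not None or self.waiting:
            self._step()
            if self.running_batch is None:
                time.sleep(0.1)  # back off briefly when there is nothing to decode

    def _step(self):
        if self.running_batch is None and self.waiting:
            # "Prefill": promote waiting requests into a running batch.
            self.running_batch = [self.waiting.popleft()]
            return
        # "Decode": pretend one step finishes the whole batch and handle it here,
        # instead of yielding results to the caller as the async version did.
        print("finished:", self.running_batch)
        self.running_batch = None

ToyManager(["req-1", "req-2"]).loop_for_fwd()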

colossalai/inference/test_async.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

tests/test_infer/test_dynamic_batching/test_forward.py

Lines changed: 2 additions & 8 deletions
@@ -42,21 +42,15 @@ def run():
     waiting_list.append(req2)
     waiting_list.append(req3)
     waiting_list.append(req4)
-
+
     llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
     model = LlamaForCausalLM(llama_config)
     model = model.half()
 
     shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)
 
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
-    manager._set_tokenizer(tokenizer_name = model.__class__.__name__)
-    result_generator = manager.loop_for_fwd()
-    for result in result_generator:
-        print(result)
-
-
+    start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
 
 
 def check_dynamic_forward(rank, world_size, port):
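
Note on the reverted test: run() no longer pulls results out of manager.loop_for_fwd(); it just calls start_dynamic_batching(...), which blocks until the waiting list is processed and returns nothing. A small sketch of the two calling conventions, with hypothetical helper names that are not part of ColossalAI:

from typing import Iterator, List

def generate_streaming(prompts: List[str]) -> Iterator[str]:
    # Pre-revert shape: the test iterated over a generator of outputs.
    for prompt in prompts:
        yield f"output for {prompt}"

def run_blocking(prompts: List[str]) -> None:
    # Post-revert shape: the call processes everything itself and returns nothing,
    # so any inspection of outputs has to happen inside the batching loop.
    for prompt in prompts:
        print(f"output for {prompt}")

for result in generate_streaming(["hello", "world"]):  # old test pattern
    print(result)
run_blocking(["hello", "world"])  # new test pattern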
