
Commit fced140

[inference] Async dynamic batching (hpcaitech#4894)
* finish input and output logic
* add generate
* test forward
* 1

1 parent e0757c3 · commit fced140

File tree

4 files changed · +139 -32 lines changed


colossalai/inference/dynamic_batching/io_struct.py

Lines changed: 6 additions & 2 deletions
@@ -102,17 +102,21 @@ def mark_finished_req(self, eos_id):
                 has_new_finish = True
         return has_new_finish
 
-    def filter_finished(self):
+    def filter_finished(self) -> List[Req]:
         """
         Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
         """
         # TODO: the logic of return should be defined here.
         unfinished_req = []
+        finished_req = []
         for req in self.reqs:
             if not req.has_generate_finished:
-                unfinished_req.append(req)
+                unfinished_req.append(req)
+            else:
+                finished_req.append(req)
         self.reqs = unfinished_req
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
+        return finished_req
 
     def is_clear(self):
         return len(self.reqs) == 0
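
Note: filter_finished() now returns the finished requests instead of discarding them, so the manager can decode and stream their outputs. A minimal standalone sketch of the new contract, using simplified stand-ins for the real Req/Batch classes (not part of the commit):

# Standalone sketch of the new filter_finished() contract; FakeReq/FakeBatch
# are simplified stand-ins for the real Req/Batch classes in io_struct.py.
from dataclasses import dataclass
from typing import List

@dataclass
class FakeReq:
    request_id: int
    has_generate_finished: bool = False

class FakeBatch:
    def __init__(self, reqs: List[FakeReq]):
        self.reqs = reqs
        self.id_to_reqs = {req.request_id: req for req in reqs}

    def filter_finished(self) -> List[FakeReq]:
        # Keep unfinished requests in the batch, hand finished ones back to the caller.
        finished_req = [req for req in self.reqs if req.has_generate_finished]
        self.reqs = [req for req in self.reqs if not req.has_generate_finished]
        self.id_to_reqs = {req.request_id: req for req in self.reqs}
        return finished_req

batch = FakeBatch([FakeReq(0, has_generate_finished=True), FakeReq(1)])
finished = batch.filter_finished()
assert [req.request_id for req in finished] == [0]
assert [req.request_id for req in batch.reqs] == [1]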

colossalai/inference/manager.py

Lines changed: 92 additions & 28 deletions
@@ -1,5 +1,6 @@
 import time
 from typing import List
+import asyncio
 
 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
@@ -8,6 +9,8 @@
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine
 
+from transformers import AutoTokenizer
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
 class DynamicBatchManager:
     def __init__(
@@ -54,6 +57,20 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques
         self.req_queue.append(req)
         return
 
+    def add_input(self, request_id, sampling_params, input_ids):
+        """
+        Encode a prompt and add it to the request queue. Only single-sequence input is supported for now.
+        """
+        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_len = len(prompt_ids)
+        if prompt_len > self.engine.max_input_len:
+            raise ValueError(
+                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
+            )
+        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
+        self.add_req(prompt_ids, sampling_params, request_id)
+        return
+
     def abort(self, request_id):
         if self.running_batch is not None:
             for req in self.running_batch.reqs:
@@ -66,13 +83,15 @@ def abort(self, request_id):
                 req.aborted = True
         return
 
-    def loop_for_fwd(self):
+    async def loop_for_fwd(self):
         """
         The main loop for a dynamic batching process.
         """
         counter_count = 0
-        while self.running_batch is not None or self.req_queue.waiting_req_list:
-            self._step()
+        # self.running_batch is not None or self.req_queue.waiting_req_list
+        while True:
+            async for item in self._step():
+                yield item
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
@@ -87,6 +106,26 @@ def loop_for_fwd(self):
             if self.running_batch is None:
                 time.sleep(0.1)  # 10ms
 
+    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True):
+        if tokenizer is not None:
+            self.tokenizer = tokenizer
+        else:
+            if "llama" in tokenizer_name.lower() and use_fast == True:
+                print(
+                    "For some LLaMA-based models, initializing the fast tokenizer may "
+                    "take a long time. To eliminate the initialization time, consider "
+                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+                    "tokenizer. This is done automatically in Colossalai.")
+
+                tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+            try:
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+            except TypeError as e:
+                use_fast = False
+                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code)
+
+
     def _step(self):
         """
         Logic for handling requests
@@ -97,32 +136,33 @@ def _step(self):
         if new_batch is not None:
             self.stats_tool.count_prompt_tokens(new_batch)
             self.running_batch = new_batch
-            self._prefill_batch(self.running_batch)
+            yield from self._prefill_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens = 0
             return
 
         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            self._decode_batch(self.running_batch)
+            yield from self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
         else:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
-                self._prefill_batch(new_mini_batch)
+                yield from self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
                     self.has_wait_tokens = 0
+
             else:
                 self.stats_tool.count_output_tokens(self.running_batch)
-                self._decode_batch(self.running_batch)
+                yield from self._decode_batch(self.running_batch)
                 self._filter_runing_batch()
                 self.has_wait_tokens += 1
-
+
         return
 
     def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -158,7 +198,8 @@ def _prefill_batch(self, batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
+
         # delete finished reqs
 
     def _decode_batch(self, batch: Batch):
@@ -169,7 +210,7 @@ def _decode_batch(self, batch: Batch):
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        self._handle_finish_req(batch, has_new_finished_req)
+        yield from self._handle_finish_req(batch, has_new_finished_req)
 
     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -201,11 +242,13 @@ def _remove_batch(self, batch):
 
     def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
-            batch.filter_finished()
+            finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
+            yield from self._output_process(finished_reqs)
+
 
     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
@@ -218,26 +261,47 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
             req.output_metadata_list.append(new_gen_metadata)
         return
 
+    async def _output_process(self, finished_reqs: List[Req]):
+        """
+        Process the output of a batch.
+        """
+        for req in finished_reqs:
+            output = self.tokenizer.decode(req.output_ids)
+            yield output, req.request_id, req.output_metadata_list
+
     def clean_up(self):
         # this logic should be implemented in the future.
         pass
 
+    async def generate(self, request_id, prompt_id, sampling_params):
+        """
+        Generate the output of a request.
+        """
+        self.add_input(request_id, prompt_id, sampling_params)
+
 
 def start_dynamic_batching(args, tp_engine, waiting_req_list):
-    # try:
-    batch_manager = DynamicBatchManager(
-        tp_engine=tp_engine,
-        max_total_token_num=args.max_total_token_num,
-        batch_max_tokens=args.batch_max_tokens,
-        eos_id=args.eos_id,
-        log_stats=not args.disable_log_stats,
-        log_stats_interval=args.log_stats_interval,
-        waiting_req_list=waiting_req_list,
-    )
-
-    # except Exception:
-    #     batch_manager.clean_up()
-    #     raise
-
-    batch_manager.loop_for_fwd()
-    return
+    try:
+        batch_manager = DynamicBatchManager(
+            tp_engine=tp_engine,
+            max_total_token_num=args.max_total_token_num,
+            batch_max_tokens=args.batch_max_tokens,
+            eos_id=args.eos_id,
+            log_stats=not args.disable_log_stats,
+            log_stats_interval=args.log_stats_interval,
+            waiting_req_list=waiting_req_list,
+        )
+
+    except Exception:
+        batch_manager.clean_up()
+        raise
+
+    batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
+    prod_task = asyncio.create_task(batch_manager.add_input(4, sampling_params=SamplingParams(), input_ids="hello world"))

+    asyncio.run(prod_task)
+
+    for item in batch_manager.loop_for_fwd():
+        print(item)
+
+    return batch_manager
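
Note: this change threads generator yields from _output_process up through _handle_finish_req, _prefill_batch/_decode_batch and _step into loop_for_fwd, so finished outputs can be streamed to the caller instead of being dropped. The sketch below is an illustration of that pattern only, not the committed code: every stage is written as an async generator for consistency and all names are simplified stand-ins, so the whole chain can be consumed with "async for".

# Illustration of the yield chain added in manager.py: outputs of finished
# requests bubble up from _output_process through _step into loop_for_fwd,
# where the caller streams them with "async for". All names here are
# simplified stand-ins.
import asyncio

class SketchManager:
    def __init__(self, prompts):
        self.pending = list(prompts)  # (request_id, prompt) pairs waiting to run

    async def _output_process(self, finished):
        # Stand-in for tokenizer.decode(req.output_ids) on each finished request.
        for request_id, text in finished:
            yield text, request_id

    async def _step(self):
        # One scheduling step: pretend the oldest pending request finishes.
        if self.pending:
            request_id, prompt = self.pending.pop(0)
            finished = [(request_id, prompt + " <generated>")]
            async for item in self._output_process(finished):
                yield item

    async def loop_for_fwd(self):
        while self.pending:
            async for item in self._step():
                yield item
            await asyncio.sleep(0)  # give other tasks a chance to run

async def main():
    manager = SketchManager([(0, "hello"), (1, "world")])
    async for output, request_id in manager.loop_for_fwd():
        print(request_id, output)

asyncio.run(main())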

colossalai/inference/test_async.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import asyncio
+
+shared_list = []
+
+async def producer():
+    for i in range(5):
+        await asyncio.sleep(1)  # simulate asynchronously produced data
+        shared_list.append(i)
+        print(f"Produced {i}")
+
+async def consumer():
+    last_index = 0
+    while True:
+        await asyncio.sleep(0.5)  # small delay so the polling loop does not spin too tightly
+        if last_index < len(shared_list):
+            item = shared_list[last_index]
+            print(f"Consumed {item}")
+            yield item
+            last_index += 1
+
+async def main():
+    # create the producer and consumer tasks
+    prod_task = asyncio.create_task(producer())
+
+    # wait for the producer task to finish
+    await prod_task
+
+    async for data in consumer():
+        print(data)
+    # for the purposes of this example, only wait a while and then stop the consumer
+    await asyncio.sleep(5)
+
+asyncio.run(main())
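
This prototype polls a shared list with a fixed sleep. For comparison, the same producer/consumer handoff can be written with asyncio.Queue, which wakes the consumer only when an item is available; the variant below is an illustration only, not part of the commit.

# Alternative sketch of the same handoff using asyncio.Queue instead of a
# polled shared list, which avoids the fixed sleep in the consumer.
import asyncio

async def producer(queue: asyncio.Queue):
    for i in range(5):
        await asyncio.sleep(1)  # simulate asynchronously produced data
        await queue.put(i)
        print(f"Produced {i}")
    await queue.put(None)  # sentinel: no more items

async def consumer(queue: asyncio.Queue):
    while True:
        item = await queue.get()  # wakes up only when an item arrives
        if item is None:
            break
        print(f"Consumed {item}")

async def main():
    queue = asyncio.Queue()
    await asyncio.gather(producer(queue), consumer(queue))

asyncio.run(main())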

tests/test_infer/test_dynamic_batching/test_forward.py

Lines changed: 8 additions & 2 deletions
@@ -42,15 +42,21 @@ def run():
     waiting_list.append(req2)
     waiting_list.append(req3)
     waiting_list.append(req4)
-
+
     llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
     model = LlamaForCausalLM(llama_config)
     model = model.half()
 
     shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)
 
     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-    start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
+    manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
+    manager._set_tokenizer(tokenizer_name=model.__class__.__name__)
+    result_generator = manager.loop_for_fwd()
+    for result in result_generator:
+        print(result)
+
+
 
 
 def check_dynamic_forward(rank, world_size, port):
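
Note: loop_for_fwd() is now declared async def and yields, so it returns an async generator; iterating it with a plain for loop raises TypeError. A hypothetical adapter (not part of the commit) for driving it from synchronous test code:

# Hypothetical helper for consuming an async generator such as
# loop_for_fwd() from synchronous test code.
import asyncio

def iter_sync(async_gen):
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                # __anext__() returns an awaitable producing the next item.
                yield loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()

# Usage sketch:
# for result in iter_sync(manager.loop_for_fwd()):
#     print(result)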
