8 changes: 6 additions & 2 deletions colossalai/inference/dynamic_batching/io_struct.py
@@ -102,17 +102,21 @@ def mark_finished_req(self, eos_id):
has_new_finish = True
return has_new_finish

def filter_finished(self):
    def filter_finished(self) -> List[Req]:
        """
        Filter finished requests out of the batch: they are removed from 'reqs' and returned to the caller.
        """
# TODO: the logic of return should be defined here.
unfinished_req = []
finished_req = []
for req in self.reqs:
if not req.has_generate_finished:
unfinished_req.append(req)
unfinished_req.append(req)
else:
finished_req.append(req)
self.reqs = unfinished_req
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return finished_req

def is_clear(self):
return len(self.reqs) == 0
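
With this change, filter_finished returns the finished requests instead of discarding them, so the caller can post-process completed sequences. A minimal usage sketch (batch, eos_id and tokenizer are assumed to be in scope; only the methods shown in the diff above are used):

# Illustrative only: partition a batch after marking EOS-terminated requests as finished.
has_new_finished = batch.mark_finished_req(eos_id)
if has_new_finished:
    finished_reqs = batch.filter_finished()  # finished requests are removed from batch.reqs and returned here
    for req in finished_reqs:
        print(req.request_id, tokenizer.decode(req.output_ids))
    if batch.is_clear():
        pass  # every request has finished; the manager would release the batch's resources here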
120 changes: 92 additions & 28 deletions colossalai/inference/manager.py
@@ -1,5 +1,6 @@
import time
from typing import List
import asyncio

from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req
@@ -8,6 +9,8 @@
from .dynamic_batching.stats import Stats
from .tensor_parallel import TPInferEngine

from transformers import AutoTokenizer
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"

class DynamicBatchManager:
def __init__(
@@ -54,6 +57,20 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques
self.req_queue.append(req)
return

def add_input(self, request_id, sampling_params, input_ids):
"""
Encode and Add new input to req queue. support one sequence input for now.
"""
prompt_ids = self.tokenizer.encode(input_ids)
prompt_len = len(prompt_ids)
if prompt_len > self.engine.max_input_len:
raise ValueError(
f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
)
sampling_params.stop_sentences_to_token_ids(self.tokenizer)
self.add_req(prompt_ids, sampling_params, request_id)
return

def abort(self, request_id):
if self.running_batch is not None:
for req in self.running_batch.reqs:
@@ -66,13 +83,15 @@ def abort(self, request_id):
req.aborted = True
return

def loop_for_fwd(self):
async def loop_for_fwd(self):
"""
The main loop for a dynamic batching process.
"""
counter_count = 0
while self.running_batch is not None or self.req_queue.waiting_req_list:
self._step()
        # previous loop condition: self.running_batch is not None or self.req_queue.waiting_req_list
while True:
            for item in self._step():  # _step is a plain (synchronous) generator, so iterate it with a regular for loop
                yield item
counter_count += 1
if self.running_batch is not None:
if counter_count % self.mem_usage_interval == 0:
@@ -87,6 +106,26 @@ def loop_for_fwd(self):
if self.running_batch is None:
                await asyncio.sleep(0.1)  # back off for 100 ms when there is no running batch

    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True):
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            if "llama" in tokenizer_name.lower() and use_fast:
                print(
                    "For some LLaMA-based models, initializing the fast tokenizer may "
                    "take a long time. To eliminate the initialization time, consider "
                    f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
                    "tokenizer. This is done automatically in Colossal-AI."
                )
                tokenizer_name = _FAST_LLAMA_TOKENIZER

            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
                )
            except TypeError:
                # Some tokenizers do not provide a fast implementation; fall back to the slow one.
                self.tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_name, use_fast=False, trust_remote_code=trust_remote_code
                )


def _step(self):
"""
Logic for handling requests
@@ -97,32 +136,33 @@ def _step(self):
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
self._prefill_batch(self.running_batch)
yield from self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0
return

if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
self._decode_batch(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
return
else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
self._prefill_batch(new_mini_batch)
yield from self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0

else:
self.stats_tool.count_output_tokens(self.running_batch)
self._decode_batch(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

return
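
The scheduling policy in _step boils down to: prefill a new batch when nothing is running, otherwise keep decoding until max_wait_tokens steps have passed, then try to prefill a waiting mini-batch and merge it into the running batch. A compact, self-contained sketch of that decision logic (function and variable names here are illustrative, not the manager's own):

def schedule_step(running_batch, waiting_queue, has_wait_tokens, max_wait_tokens):
    # Returns (action, updated_has_wait_tokens); a sketch of the branching in _step above.
    if running_batch is None:
        return "prefill_new_batch", 0
    if has_wait_tokens < max_wait_tokens:
        return "decode_running_batch", has_wait_tokens + 1
    if waiting_queue:
        return "prefill_mini_batch_and_merge", 0
    return "decode_running_batch", has_wait_tokens + 1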

def _init_batch(self, batch: Batch, dtype="fp16"):
@@ -158,7 +198,8 @@ def _prefill_batch(self, batch):
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
self._handle_finish_req(batch, has_new_finished_req)
yield from self._handle_finish_req(batch, has_new_finished_req)

# delete finished reqs

def _decode_batch(self, batch: Batch):
@@ -169,7 +210,7 @@ def _decode_batch(self, batch: Batch):
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
self._handle_finish_req(batch, has_new_finished_req)
yield from self._handle_finish_req(batch, has_new_finished_req)

def _filter_batch(self, batch: Batch):
batch_id = batch.batch_id
@@ -201,11 +242,13 @@ def _remove_batch(self, batch):

def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
batch.filter_finished()
            finished_reqs = batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
yield from self._output_process(finished_reqs)


def _filter_runing_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
@@ -218,26 +261,47 @@ def _add_token_id_to_req(self, batch: Batch, req_ans):
req.output_metadata_list.append(new_gen_metadata)
return

    def _output_process(self, finished_reqs: List[Req]):
        """
        Decode the finished requests and yield their outputs.

        This is a plain generator (not async) so that `_handle_finish_req` can delegate to it with `yield from`.
        """
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
            yield output, req.request_id, req.output_metadata_list

def clean_up(self):
# this logic should be implemented in the future.
pass
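
From the caller's side, work is submitted through add_input (or the generate wrapper below): the caller passes a request id, a SamplingParams object and the raw prompt text, and the manager handles tokenisation, length checking and queuing. A hedged usage sketch, assuming an already constructed DynamicBatchManager named manager:

# Hypothetical caller-side usage; only methods defined in this class are called.
manager._set_tokenizer(tokenizer_name="hf-internal-testing/llama-tokenizer")
manager.add_input(request_id=0, sampling_params=SamplingParams(), input_ids="What is the capital of France?")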

    async def generate(self, request_id, prompt_id, sampling_params):
        """
        Generate the output for a request.
        """
        # add_input expects (request_id, sampling_params, input_ids); keep the argument order consistent.
        self.add_input(request_id, sampling_params, prompt_id)


def start_dynamic_batching(args, tp_engine, waiting_req_list):
# try:
batch_manager = DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)

# except Exception:
# batch_manager.clean_up()
# raise

batch_manager.loop_for_fwd()
return
try:
batch_manager = DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)

except Exception:
batch_manager.clean_up()
raise

    batch_manager._set_tokenizer(tokenizer_name=tp_engine.model.__class__.__name__)
    # add_input is a plain synchronous method, so call it directly instead of wrapping it in an asyncio task.
    batch_manager.add_input(4, sampling_params=SamplingParams(), input_ids="hello world")

    async def _print_outputs():
        # loop_for_fwd is an async generator, so it has to be driven from inside an event loop.
        async for item in batch_manager.loop_for_fwd():
            print(item)

    asyncio.run(_print_outputs())

return batch_manager
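
The manager now streams results through a chain of generators: _output_process yields decoded outputs, _handle_finish_req forwards them with yield from, _prefill_batch/_decode_batch/_step pass them further up, and loop_for_fwd is the single async layer that the serving side iterates with async for. A minimal, runnable sketch of the same delegation pattern (the function names below are illustrative, not the manager's):

import asyncio

def innermost(items):
    # plays the role of _output_process: one result per finished item
    for item in items:
        yield f"decoded-{item}"

def middle(items):
    # plays the role of _handle_finish_req / _prefill_batch: forward results with `yield from`
    yield from innermost(items)

async def outer_loop():
    # plays the role of loop_for_fwd: the only async layer, draining the synchronous generators
    for result in middle([1, 2, 3]):
        yield result
        await asyncio.sleep(0)  # let other coroutines run between results

async def main():
    async for result in outer_loop():
        print(result)

asyncio.run(main())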
33 changes: 33 additions & 0 deletions colossalai/inference/test_async.py
@@ -0,0 +1,33 @@
import asyncio

shared_list = []

async def producer():
for i in range(5):
        await asyncio.sleep(1)  # simulate asynchronously fetching a piece of data
shared_list.append(i)
print(f"Produced {i}")

async def consumer():
last_index = 0
while True:
        await asyncio.sleep(0.5)  # small delay so the polling loop does not spin too tightly
if last_index < len(shared_list):
item = shared_list[last_index]
print(f"Consumed {item}")
yield item
last_index += 1

async def main():
    # Create the producer task (the consumer is driven by the `async for` loop below)
prod_task = asyncio.create_task(producer())

    # Wait for the producer task to finish
await prod_task

async for data in consumer():
print(data)
        # For the purposes of this example, only wait for a while and then stop consuming
await asyncio.sleep(5)

asyncio.run(main())
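
The demo above passes data between coroutines through a shared list that the consumer polls every 0.5 s. An alternative sketch of the same producer/consumer exchange uses asyncio.Queue, which removes the polling delay and gives the consumer a natural stop condition (this is an illustration, not part of the PR):

import asyncio

async def producer(queue: asyncio.Queue):
    for i in range(5):
        await asyncio.sleep(1)  # simulate asynchronously fetching data
        await queue.put(i)
        print(f"Produced {i}")
    await queue.put(None)  # sentinel: nothing more will be produced

async def consumer(queue: asyncio.Queue):
    while True:
        item = await queue.get()  # wakes up as soon as an item is available, no polling needed
        if item is None:
            break
        print(f"Consumed {item}")

async def main():
    queue = asyncio.Queue()
    await asyncio.gather(producer(queue), consumer(queue))

asyncio.run(main())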
10 changes: 8 additions & 2 deletions tests/test_infer/test_dynamic_batching/test_forward.py
@@ -42,15 +42,21 @@ def run():
waiting_list.append(req2)
waiting_list.append(req3)
waiting_list.append(req4)

llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()

shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)

infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
    manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
    manager._set_tokenizer(tokenizer_name=model.__class__.__name__)

    async def _consume_results():
        # loop_for_fwd is an async generator, so iterate it with `async for` (assumes `import asyncio` at the top of the file)
        async for result in manager.loop_for_fwd():
            print(result)

    asyncio.run(_consume_results())
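
Because loop_for_fwd runs a while True loop, a test that drives it directly needs a way to stop. One option, sketched under the assumption that the manager object above is available, is to bound the whole consumption with a timeout:

import asyncio

async def consume_with_timeout(manager, timeout_s: float = 30.0):
    # Collect streamed results until the overall timeout expires; the forward loop never returns on its own.
    results = []

    async def _drain():
        async for item in manager.loop_for_fwd():
            results.append(item)

    try:
        await asyncio.wait_for(_drain(), timeout=timeout_s)
    except asyncio.TimeoutError:
        pass  # expected once the timeout is reached
    return results

# Example: results = asyncio.run(consume_with_timeout(manager, timeout_s=10.0))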




def check_dynamic_forward(rank, world_size, port):