Merged

Changes from all commits
36 commits
e0757c3
[inference] Dynamic Batching for Single and Multiple GPUs (#4831)
CjhHa1 Oct 11, 2023
fced140
[inference] Async dynamic batching (#4894)
CjhHa1 Oct 12, 2023
fbf3c09
[inference]Re push async dynamic batching (#4901)
CjhHa1 Oct 13, 2023
d509e79
Revert "[inference]Re push async dynamic batching (#4901)" (#4905)
CjhHa1 Oct 13, 2023
ec004fe
Revert "[inference] Async dynamic batching (#4894)"
isky-cd Oct 14, 2023
78cd937
Revert "[inference] Async dynamic batching (#4894)" (#4909)
tiandiao123 Oct 14, 2023
d97290a
Add Ray Distributed Environment Init Scripts
isky-cd Oct 14, 2023
8483393
fix conflict
isky-cd Oct 14, 2023
f589e97
support DynamicBatchManager base function
isky-cd Oct 14, 2023
c070050
revert _set_tokenizer version
isky-cd Oct 16, 2023
5deb95c
add driver async generate
isky-cd Oct 16, 2023
306ef77
add async test
isky-cd Oct 16, 2023
632f0e1
fix bugs in test_ray_dist.py
isky-cd Oct 16, 2023
0b2fe51
add get_tokenizer.py
isky-cd Oct 16, 2023
cd843ac
fix code style
isky-cd Oct 16, 2023
8c9ad51
fix bugs about No module named 'pydantic' in ci test
isky-cd Oct 16, 2023
8d0cc6b
fix bugs in ci test
isky-cd Oct 16, 2023
acdd751
fix bugs in ci test
isky-cd Oct 16, 2023
8a761bd
fix bugs in ci test
isky-cd Oct 16, 2023
56f75c4
[infer]Add Ray Distributed Environment Init Scripts (#4911)
isky-cd Oct 16, 2023
c76fd68
support dynamic batch for bloom model and is_running function
isky-cd Oct 17, 2023
f41ccdd
fix conflict
isky-cd Oct 17, 2023
fca12b8
Merge pull request #4933 from yuehuayingxueluo/ray_dist_init_branch
isky-cd Oct 17, 2023
4ea9fbe
[Inference]Test for new Async engine (#4935)
CjhHa1 Oct 19, 2023
3f6af12
add assertion for config (#4947)
CjhHa1 Oct 19, 2023
4867561
[Inference] Finish dynamic batching offline test (#4948)
CjhHa1 Oct 19, 2023
285fc30
fix bugs
CjhHa1 Oct 20, 2023
d5d2c94
fix quant
CjhHa1 Oct 20, 2023
ed86584
add default
CjhHa1 Oct 20, 2023
4bffb8b
fix
CjhHa1 Oct 20, 2023
77adc2e
fix some bugs
CjhHa1 Oct 20, 2023
dcb51b4
fix some bugs
CjhHa1 Oct 20, 2023
afae53b
fix
Oct 23, 2023
c477266
fix bug
Oct 23, 2023
f99eba2
fix bugs
Oct 23, 2023
4c3ea40
reset param
Oct 27, 2023
133 changes: 133 additions & 0 deletions colossalai/inference/async_engine.py
@@ -0,0 +1,133 @@
import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
"""
    A class for tracking all requests; an abstraction layer for the async path.
"""

def __init__(self) -> None:
self._requests: asyncio.Queue[str] = asyncio.Queue()
self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
self.new_requests_event = None

    def __contains__(self, item):
        # asyncio.Queue has no membership test of its own, so peek at its
        # underlying deque (a private attribute, but stable in CPython).
        return item in self._requests._queue

def init_event(self):
self.new_requests_event = asyncio.Event()

def add_request(self, request_id: str):
"""Add a request to be sent to the engine on the next background
loop iteration."""
self._requests.put_nowait(request_id)
self.new_requests_event.set() # NOTE: we may find a better way to clear this event

def add_stop(self):
"""
        Add a StopIteration sentinel to stop the async generator.
"""
self._finished_requests.put_nowait(StopIteration)
self.new_requests_event.clear()

def process_request_output(self, request_output: RequestOutput) -> None:
"""Process a request output from the engine."""
self._finished_requests.put_nowait(request_output)

async def wait_for_new_requests(self):
await self.new_requests_event.wait()

def __aiter__(self):
return self

async def __anext__(self) -> RequestOutput:
result = await self._finished_requests.get()
# print("result of ", result)
if result is StopIteration:
raise StopAsyncIteration
return result


class Async_Engine:

"""
    Use the engine to launch the Ray Driver --> Ray Worker --> Async_Manager chain.
    Background loop: runs inference on the requests in the waiting list (listens for work).
    Request tracker: manages incoming requests and stores finished ones.
    Generate: the exposed entry point that adds new input and returns finished outputs.
"""

def __init__(
self,
router_config,
engine_config,
start_engine_loop: bool = True,
) -> None:
self.driver = Driver(router_config=router_config, engine_config=engine_config)
self.background_loop = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()

def _step(self):
"""
        Poll the driver for one step, push any finished outputs to the request tracker, and append a stop sentinel.
"""
request_outputs = self.driver.step()
if request_outputs is not None:
for request_output in request_outputs:
self._request_tracker.process_request_output(request_output)
self._request_tracker.add_stop()

def abort_request(self, request_id: str):
self.driver.abort(request_id)

def _has_requests_in_progress(self):
return self.driver.is_running()

async def run_loop_fwd(self):
has_requests_in_progress = self._has_requests_in_progress()
while True:
if not has_requests_in_progress:
await self._request_tracker.wait_for_new_requests()
self._step()
await asyncio.sleep(0)

@property
def is_running(self):
return self.background_loop is not None and not self.background_loop.done()

def start_background_loop(self):
if self.is_running:
raise RuntimeError("Background loop is already running.")

self._request_tracker.init_event()

self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
self.background_loop = asyncio.shield(self.background_loop_unshielded)

async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
self.driver.add_input(request_id, prompt, sampling_params)
self._request_tracker.add_request(request_id)

async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
"""
        The only exposed method: add a new request and return an async generator that yields results as they become available.
"""
try:
if not self.is_running:
self.start_background_loop()

await self.add_request(request_id, prompt, sampling_params)

async for request_output in self._request_tracker:
yield request_output

except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the request.
self.abort_request(request_id)
raise e
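To show how the pieces above fit together, here is a minimal usage sketch. It assumes `SamplingParams` can be constructed with its defaults and that `router_config` / `engine_config` are prepared by the caller; neither is defined in this diff.

```python
# A minimal usage sketch, not part of the diff. It assumes SamplingParams is
# default-constructible and that router_config / engine_config are prepared
# by the caller (their fields are defined elsewhere in the repo).
import asyncio

from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def run_one(router_config, engine_config):
    engine = Async_Engine(router_config=router_config, engine_config=engine_config)
    sampling_params = SamplingParams()  # assumed default-constructible

    # generate() lazily starts the background loop and yields RequestOutput
    # objects as finished requests are drained from the tracker.
    async for request_output in engine.generate("req-0", "Hello, my name is", sampling_params):
        print(request_output)


# asyncio.run(run_one(router_config, engine_config))
```

Note that on an exception or cancellation inside the caller, `generate` aborts the request via the driver before re-raising.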
151 changes: 151 additions & 0 deletions colossalai/inference/async_manager.py
@@ -0,0 +1,151 @@
from typing import List

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num: int,
batch_max_tokens: int,
model: str,
tokenizer=None,
eos_id=None,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
waiting_req_list: List = [],
):
"""
Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager
max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len)
batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
eos_id : The end token of a seq
model: the model weight dir path, the app will load config, weights and tokenizer from this dir
log_stats : whether to log stats
log_stats_interval : log stats interval
running_batch : running batch
waiting_req_list : list of waiting requests, initialized before dynamic batch manager
"""
super().__init__(
tp_engine,
max_total_token_num,
batch_max_tokens,
model,
tokenizer,
eos_id,
log_stats,
log_stats_interval,
running_batch,
waiting_req_list,
)

def _step(self):
"""
        Run one scheduling step: prefill a new batch, continue decoding the running batch, or prefill and merge a new mini batch; return outputs of newly finished requests, or None.
"""
has_new_finished = False
if self.running_batch is None:
new_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
has_new_finished, outputs = self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0

else:
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
has_new_finished, outputs = self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0

else:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

if has_new_finished:
return outputs
return None

def _prefill_batch(self, batch):
"""
        Every batch, whether a new batch or a mini batch, must be prefilled first.
"""
self._init_batch(batch)

# TODO: figure out if cache and batch id is needed
ans = self.engine._prefill_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs
# delete finished reqs

def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs

def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
finished_reqs = batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
return self._output_process(finished_reqs)
return None

def _output_process(self, finished_reqs: List[Req]):
"""
Process the output of a batch.
"""
outputs = []
for req in finished_reqs:
output = self.tokenizer.decode(req.output_ids)
outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
try:
batch_manager = Async_DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
model=args.model,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)

except Exception:
        raise  # re-raise the original exception rather than masking it with a bare Exception

return batch_manager
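For reference, here is a minimal polling sketch of how the returned manager could be driven outside the async engine; `args`, `tp_engine` and `waiting_req_list` are assumed to come from the Ray worker code elsewhere in this PR.

```python
# Hypothetical polling loop, not part of the diff: _step() returns a list of
# RequestOutput objects whenever requests finish and None otherwise, so a
# synchronous caller can simply poll it. `args`, `tp_engine` and
# `waiting_req_list` are assumed to be prepared by the surrounding Ray worker.
import time

from colossalai.inference.async_manager import start_dynamic_batching


def poll_forever(args, tp_engine, waiting_req_list):
    batch_manager = start_dynamic_batching(args, tp_engine, waiting_req_list)
    while True:
        outputs = batch_manager._step()
        if outputs is None:
            time.sleep(0.01)  # nothing finished this step; back off briefly
            continue
        for request_output in outputs:
            # each RequestOutput was built from request_id, prompts, prompt_ids
            # and the decoded text in _output_process above
            print(request_output)
```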
Empty file.
40 changes: 40 additions & 0 deletions colossalai/inference/dynamic_batching/get_tokenizer.py
@@ -0,0 +1,40 @@
"""
Motivated by vLLM (https://github.com/vllm-project/vllm), this module resolves tokenizer loading issues.

license: MIT, see LICENSE for more details.
"""

from transformers import AutoTokenizer

_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
tokenizer=None,
tokenizer_name: str = "",
trust_remote_code: bool = False,
use_fast: bool = True,
):
if tokenizer is not None:
        pass  # a tokenizer instance was supplied by the caller; use it as-is
else:
if "llama" in tokenizer_name.lower() and use_fast == True:
print(
"For some LLaMA-based models, initializing the fast tokenizer may "
"take a long time. To eliminate the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer. This is done automatically in Colossalai."
)

tokenizer_name = _FAST_LLAMA_TOKENIZER

try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
except TypeError:
use_fast = False
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
return tokenizer
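A short usage sketch follows; the tokenizer name is the module's own fallback constant, and loading it assumes access to the Hugging Face Hub.

```python
# Illustrative usage, not part of the diff; downloading the tokenizer requires
# access to the Hugging Face Hub. The name below is the fallback constant
# already defined in this module.
from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer(tokenizer_name="hf-internal-testing/llama-tokenizer", use_fast=True)
print(tokenizer("Hello, world!").input_ids)
```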