133 changes: 133 additions & 0 deletions colossalai/inference/async_engine.py
@@ -0,0 +1,133 @@
import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
"""
Tracks all in-flight requests; an abstraction used by the async engine.
"""

def __init__(self) -> None:
self._requests: asyncio.Queue[str] = asyncio.Queue()
self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
self.new_requests_event = None

def __contains__(self, item):
# asyncio.Queue does not support membership tests; check its underlying deque
return item in self._requests._queue

def init_event(self):
self.new_requests_event = asyncio.Event()

def add_request(self, request_id: str):
"""Add a request to be sent to the engine on the next background
loop iteration."""
self._requests.put_nowait(request_id)
self.new_requests_event.set() # NOTE: we may find a better way to clear this event

def add_stop(self):
"""
Add a StopIteration sentinel to stop the async generator.
"""
self._finished_requests.put_nowait(StopIteration)
self.new_requests_event.clear()

def process_request_output(self, request_output: RequestOutput) -> None:
"""Process a request output from the engine."""
self._finished_requests.put_nowait(request_output)

async def wait_for_new_requests(self):
await self.new_requests_event.wait()

def __aiter__(self):
return self

async def __anext__(self) -> RequestOutput:
result = await self._finished_requests.get()
if result is StopIteration:
raise StopAsyncIteration
return result


class Async_Engine:

"""
Use an engine to launch RAY Driver --> RAY Worker --> Async_Manager
Background loop: inference reqs in waiting list (Listen)
Request Tracker: manage incoming requests and restore finished ones
Generate: exposed func for add new input and return finished ones
"""

def __init__(
self,
router_config,
engine_config,
start_engine_loop: bool = True,
) -> None:
self.driver = Driver(router_config=router_config, engine_config=engine_config)
self.background_loop = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()

def _step(self):
"""
Logic for handling requests
"""
request_outputs = self.driver.step()
if request_outputs is not None:
for request_output in request_outputs:
self._request_tracker.process_request_output(request_output)
self._request_tracker.add_stop()

def abort(self, request_id: str):
self.driver.abort(request_id)

def _has_requests_in_progress(self):
return self.driver.is_running()

async def run_loop_fwd(self):
    while True:
        # Re-check the driver on every iteration so the loop parks when it goes idle
        if not self._has_requests_in_progress():
            await self._request_tracker.wait_for_new_requests()
        self._step()
        await asyncio.sleep(0)

@property
def is_running(self):
return self.background_loop is not None and not self.background_loop.done()

def start_background_loop(self):
if self.is_running:
raise RuntimeError("Background loop is already running.")

self._request_tracker.init_event()

self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
self.background_loop = asyncio.shield(self.background_loop_unshielded)

async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
self.driver.add_input(request_id, prompt, sampling_params)
self._request_tracker.add_request(request_id)

async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
"""
The only exposed method: adds a new request and returns an async generator that yields the finished outputs.
"""
try:
if not self.is_running:
self.start_background_loop()

await self.add_request(request_id, prompt, sampling_params)

async for request_output in self._request_tracker:
yield request_output

except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the request.
self.abort(request_id)
raise e
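
A minimal usage sketch for the engine above, assuming router_config and engine_config objects are built elsewhere in this PR; the SamplingParams arguments and request id are illustrative placeholders rather than a confirmed API.

import asyncio

from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def run_once(router_config, engine_config):
    # The background loop is started lazily on the first generate() call.
    engine = Async_Engine(router_config=router_config, engine_config=engine_config)
    sampling_params = SamplingParams()  # illustrative; real fields depend on the class definition

    # generate() is an async generator: it yields RequestOutput objects for
    # finished requests until the tracker's StopIteration sentinel is reached.
    async for output in engine.generate("req-0", "Hello, world!", sampling_params):
        print(output)


# asyncio.run(run_once(router_config, engine_config))  # configs assumed to exist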
150 changes: 150 additions & 0 deletions colossalai/inference/async_manager.py
@@ -0,0 +1,150 @@
from typing import List

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num,
batch_max_tokens,
eos_id,
model,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
waiting_req_list: List = [],
):
"""
Args:
    tp_engine : the TP engine held by the dynamic batch manager, created before the manager itself
    max_total_token_num : token budget for the memory manager, defaults to max batch size * (max input len + max output len)
    batch_max_tokens : maximum number of tokens in one batch, defaults to (max input len + max output len) * num_requests
    running_max_req_size : maximum number of requests in the running batch, equal to MAX_BATCH_SIZE of the TP engine
    eos_id : the end-of-sequence token id
    model : path to the model weight directory; config, weights and tokenizer are loaded from it
    log_stats : whether to log statistics
    log_stats_interval : how often (in steps) stats are logged
    running_batch : the currently running batch
    waiting_req_list : list of waiting requests, created before the dynamic batch manager
"""
super().__init__(
tp_engine,
max_total_token_num,
batch_max_tokens,
eos_id,
model,
log_stats,
log_stats_interval,
running_batch,
waiting_req_list,
)

def _step(self):
"""
Logic for handling requests
"""
has_new_finished = False
if self.running_batch is None:
new_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
has_new_finished, outputs = self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0

else:
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
has_new_finished, outputs = self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0

else:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1

if has_new_finished:
return outputs
return None

def _prefill_batch(self, batch):
"""
Every batch, whether a newly created batch or a mini batch, must be prefilled before decoding.
"""
self._init_batch(batch)

# TODO: figure out if cache and batch id is needed
ans = self.engine._prefill_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs
# delete finished reqs

def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs

def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
finished_reqs = batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
return self._output_process(finished_reqs)
return None

def _output_process(self, finished_reqs: List[Req]):
"""
Process the output of a batch.
"""
outputs = []
for req in finished_reqs:
output = self.tokenizer.decode(req.output_ids)
outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
try:
batch_manager = Async_DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
model=args.model,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)

except Exception:
batch_manager.clean_up()
raise

return batch_manager
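
A rough sketch of wiring up the manager above, assuming a TPInferEngine already exists; the argument namespace only mirrors the fields that start_dynamic_batching reads, and every value shown is a placeholder.

from types import SimpleNamespace
from typing import List

from colossalai.inference.async_manager import start_dynamic_batching


def build_manager(tp_engine, waiting_req_list: List):
    # Placeholder values; field names mirror what start_dynamic_batching reads from args.
    args = SimpleNamespace(
        max_total_token_num=4096,
        batch_max_tokens=2048,
        eos_id=2,
        model="/path/to/model",  # placeholder weight directory
        disable_log_stats=False,
        log_stats_interval=10,
    )
    return start_dynamic_batching(args, tp_engine=tp_engine, waiting_req_list=waiting_req_list)


# manager = build_manager(tp_engine, waiting_req_list)  # engine and requests assumed to exist
# outputs = manager._step()  # returns a list of RequestOutput for newly finished requests, or None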
8 changes: 7 additions & 1 deletion colossalai/inference/dynamic_batching/get_tokenizer.py
@@ -1,6 +1,12 @@
"""
Motivated by vLLM (https://github.com/vllm-project/vllm), this module resolves tokenizer loading issues.

license: MIT, see LICENSE for more details.
"""

from transformers import AutoTokenizer

_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer"
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
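
For reference, a hedged call to the updated helper; only the positional tokenizer-name argument is assumed here, since the rest of the signature is collapsed in this diff.

from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer

# The hf-internal-testing/llama-tokenizer id replaces the old hard-coded local path,
# so the default fallback now resolves through the Hugging Face Hub.
tokenizer = get_tokenizer("hf-internal-testing/llama-tokenizer")
print(tokenizer.eos_token_id)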
16 changes: 8 additions & 8 deletions colossalai/inference/dynamic_batching/infer_batch.py
@@ -1,15 +1,16 @@
# Adapted from https://github.com/ModelTC/lightllm

import collections
from dataclasses import dataclass
from typing import Dict, List , Tuple
from typing import Dict, List, Tuple

import numpy as np
import torch

from colossalai.inference.tensor_parallel import MemoryManager



# make batch infer state an attr of InferBatch
class InferSamplingParams:
def __init__(
self,
Expand Down Expand Up @@ -65,7 +66,7 @@ def init_batch(
cache_manager: MemoryManager,
vocab_size: int,
max_total_len: int,
) -> 'InferBatch':
) -> "InferBatch":
input_lengths = []
all_input_ids = []
requests_idx_mapping = {}
Expand All @@ -76,7 +77,7 @@ def init_batch(
nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
# To avoid memory leaks, we pre-allocate 12 extra slots for each batch.
nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
for i, r in enumerate(requests):
# request id -> idx in list mapping
@@ -142,10 +143,9 @@ def free_self(self) -> None:
)
remove_index = torch.cat(remove_index, dim=-1)
self.cache_manager.free(remove_index)


@torch.no_grad()
def filter(self, request_ids: List[int]) -> 'InferBatch':
def filter(self, request_ids: List[int]) -> "InferBatch":
"""
Filter finished batch and return a new InferBatch with left ones.
"""
@@ -226,7 +226,7 @@ def filter(self, request_ids: List[int]) -> 'InferBatch':

@classmethod
@torch.no_grad()
def merge(cls, batch1, batch2) -> 'InferBatch':
def merge(cls, batch1, batch2) -> "InferBatch":
"""
Return the merged InferBatch.
"""