diff --git a/colossalai/inference/async_engine.py b/colossalai/inference/async_engine.py new file mode 100644 index 000000000000..a58dde01d250 --- /dev/null +++ b/colossalai/inference/async_engine.py @@ -0,0 +1,133 @@ +import asyncio + +from colossalai.inference.dynamic_batching.ray_dist_init import Driver + +from .dynamic_batching.io_struct import RequestOutput + +from .dynamic_batching.sampling_params import SamplingParams + + +class RequestTracker: + """ + A class that keeps track of all requests; an abstraction for async request handling + """ + + def __init__(self) -> None: + self._requests: asyncio.Queue[str] = asyncio.Queue() + self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._requests + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def add_request(self, request_id: str): + """Add a request to be sent to the engine on the next background + loop iteration.""" + self._requests.put_nowait(request_id) + self.new_requests_event.set() # NOTE: we may find a better way to clear this event + + def add_stop(self): + """ + Add a StopIteration flag to stop the async generator. + """ + self._finished_requests.put_nowait(StopIteration) + self.new_requests_event.clear() + + def process_request_output(self, request_output: RequestOutput) -> None: + """Process a request output from the engine.""" + self._finished_requests.put_nowait(request_output) + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._finished_requests.get() + # print("result of ", result) + if result is StopIteration: + raise StopAsyncIteration + return result + + +class Async_Engine: + + """ + Use an engine to launch the Ray Driver --> Ray Worker --> Async_Manager + Background loop: runs inference on the requests in the waiting list (listens for new work) + Request Tracker: manages incoming requests and stores the finished ones + Generate: the exposed method that adds new input and returns the finished outputs + """ + + def __init__( + self, + router_config, + engine_config, + start_engine_loop: bool = True, + ) -> None: + self.driver = Driver(router_config=router_config, engine_config=engine_config) + self.background_loop = None + self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() + + def _step(self): + """ + Logic for handling requests + """ + request_outputs = self.driver.step() + if request_outputs is not None: + for request_output in request_outputs: + self._request_tracker.process_request_output(request_output) + self._request_tracker.add_stop() + + def abort(self, request_id: str): + self.driver.abort(request_id) + + def _has_requests_in_progress(self): + return self.driver.is_running() + + async def run_loop_fwd(self): + has_requests_in_progress = self._has_requests_in_progress() + while True: + if not has_requests_in_progress: + await self._request_tracker.wait_for_new_requests() + self._step() + await asyncio.sleep(0) + + @property + def is_running(self): + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self): + if self.is_running: + raise RuntimeError("Background loop is already running.") + + self._request_tracker.init_event() + + self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd()) + self.background_loop = asyncio.shield(self.background_loop_unshielded) + + async def add_request(self, request_id: 
str, prompt: str, sampling_params: SamplingParams): + self.driver.add_input(request_id, prompt, sampling_params) + self._request_tracker.add_request(request_id) + + async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): + """ + The only exposed method: it adds a new request and returns an async generator that yields the finished results. + """ + try: + if not self.is_running: + self.start_background_loop() + + await self.add_request(request_id, prompt, sampling_params) + + async for request_output in self._request_tracker: + yield request_output + + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or the coroutine is cancelled, abort the request. + self.abort(request_id) + raise e diff --git a/colossalai/inference/async_manager.py b/colossalai/inference/async_manager.py new file mode 100644 index 000000000000..78d11b1caa44 --- /dev/null +++ b/colossalai/inference/async_manager.py @@ -0,0 +1,150 @@ +from typing import List + +from .dynamic_batching.io_struct import Batch, Req, RequestOutput +from .manager import DynamicBatchManager +from .tensor_parallel import TPInferEngine + + +class Async_DynamicBatchManager(DynamicBatchManager): + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num, + batch_max_tokens, + eos_id, + model, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that the dynamic batch manager holds, defined before the dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + model: the model weight dir path, the app will load config, weights and tokenizer from this dir + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + super().__init__( + tp_engine, + max_total_token_num, + batch_max_tokens, + eos_id, + model, + log_stats, + log_stats_interval, + running_batch, + waiting_req_list, + ) + + def _step(self): + """ + Logic for handling requests + """ + has_new_finished = False + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + has_new_finished, outputs = self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + + else: + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + has_new_finished, outputs = self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + + else: + self.stats_tool.count_output_tokens(self.running_batch) + has_new_finished, outputs = 
self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + if has_new_finished: + return outputs + return None + + def _prefill_batch(self, batch): + """ + For every batch, whether it is a new batch or a mini batch, we need to do prefill first. + """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + outputs = self._handle_finish_req(batch, has_new_finished_req) + return has_new_finished_req, outputs + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + finished_reqs = batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + return self._output_process(finished_reqs) + return None + + def _output_process(self, finished_reqs: List[Req]): + """ + Process the output of a batch. + """ + outputs = [] + for req in finished_reqs: + output = self.tokenizer.decode(req.output_ids) + outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output)) + return outputs + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + try: + batch_manager = Async_DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + model=args.model, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + except Exception: + batch_manager.clean_up() + raise + + return batch_manager diff --git a/colossalai/inference/dynamic_batching/get_tokenizer.py b/colossalai/inference/dynamic_batching/get_tokenizer.py index af1f26848b3a..94aa3f24393f 100644 --- a/colossalai/inference/dynamic_batching/get_tokenizer.py +++ b/colossalai/inference/dynamic_batching/get_tokenizer.py @@ -1,6 +1,12 @@ +""" +Motivated by vLLM (https://github.com/vllm-project/vllm), this module tries to resolve the tokenizer issue. + +license: MIT, see LICENSE for more details. 
+""" + from transformers import AutoTokenizer -_FAST_LLAMA_TOKENIZER = "/home/lccd/share/llama-tokenizer" +_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" def get_tokenizer( diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py index 826272db3e11..112784c15f84 100644 --- a/colossalai/inference/dynamic_batching/infer_batch.py +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -1,15 +1,16 @@ +# Adapted from https://github.com/ModelTC/lightllm + import collections from dataclasses import dataclass -from typing import Dict, List , Tuple +from typing import Dict, List, Tuple import numpy as np import torch from colossalai.inference.tensor_parallel import MemoryManager -# make batch infer state an attr of InferBatch - +# make batch infer state an attr of InferBatch class InferSamplingParams: def __init__( self, @@ -65,7 +66,7 @@ def init_batch( cache_manager: MemoryManager, vocab_size: int, max_total_len: int, - ) -> 'InferBatch': + ) -> "InferBatch": input_lengths = [] all_input_ids = [] requests_idx_mapping = {} @@ -76,7 +77,7 @@ def init_batch( nopad_total_token_num = 0 nopad_max_len_in_batch = 0 nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") - # to avoid memory leak , we pre-allocate 12 more space for each batch. + # to avoid memory leak , we pre-allocate 12 more space for each batch. nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") for i, r in enumerate(requests): # request id -> idx in list mapping @@ -142,10 +143,9 @@ def free_self(self) -> None: ) remove_index = torch.cat(remove_index, dim=-1) self.cache_manager.free(remove_index) - @torch.no_grad() - def filter(self, request_ids: List[int]) -> 'InferBatch': + def filter(self, request_ids: List[int]) -> "InferBatch": """ Filter finished batch and return a new InferBatch with left ones. 
""" @@ -226,7 +226,7 @@ def filter(self, request_ids: List[int]) -> 'InferBatch': @classmethod @torch.no_grad() - def merge(cls, batch1, batch2) -> 'InferBatch': + def merge(cls, batch1, batch2) -> "InferBatch": """ Return megerd new InferBatch """ diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py index 9faaad6f111e..a75eb8007a02 100644 --- a/colossalai/inference/dynamic_batching/io_struct.py +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -1,10 +1,12 @@ +# Adapted from https://github.com/ModelTC/lightllm + from typing import Dict, List, Tuple from .sampling_params import SamplingParams class Req: - def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str): + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""): self.request_id = request_id self.prompt_ids = prompt_ids self.input_len = len(prompt_ids) @@ -49,26 +51,6 @@ def __repr__(self): return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " -class ReqDetokenizationState: - def __init__( - self, - request_id: str, - prompt_ids: List[int], - max_output_len: int, - ignore_eos: bool, - ) -> None: - self.request_id = request_id - self.prompt_ids = prompt_ids - self.output_ids = [] - self.output_tokens = [] - self.output_str = "" - self.sub_texts = [] - self.current_sub_text = [] - self.max_output_len = max_output_len - self.ignore_eos = ignore_eos - self.gen_metadata = {} - - class Batch: def __init__(self, batch_id, reqs: List[Req]): self.batch_id = batch_id @@ -156,3 +138,34 @@ def __init__(self): class AbortReq: def __init__(self, req_id): self.req_id = req_id + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + outputs: The output sequences of the request. 
+ """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + outputs, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.outputs = outputs + + def __repr__(self) -> str: + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"outputs={self.outputs}, " + ) diff --git a/colossalai/inference/dynamic_batching/ray_dist_init.py b/colossalai/inference/dynamic_batching/ray_dist_init.py index 70cc21436456..7639633eaa79 100644 --- a/colossalai/inference/dynamic_batching/ray_dist_init.py +++ b/colossalai/inference/dynamic_batching/ray_dist_init.py @@ -1,4 +1,3 @@ -import asyncio import logging import os from typing import List @@ -9,10 +8,11 @@ from transformers import AutoModelForCausalLM import colossalai +from colossalai.inference.async_manager import start_dynamic_batching from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer +from colossalai.inference.dynamic_batching.io_struct import RequestOutput from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass from colossalai.inference.dynamic_batching.sampling_params import SamplingParams -from colossalai.inference.manager import start_dynamic_batching from colossalai.inference.tensor_parallel.engine import TPInferEngine from colossalai.shardformer import ShardConfig from colossalai.testing import free_port @@ -76,31 +76,25 @@ def setup(self, world_size, rank, port): return True - def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> str: - ray_serve_logger.info(f"text: {prompt}") + # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]: + # ray_serve_logger.info(f"text: {prompt}") - results_generator = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) + # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id) - final_output = None - for request_output in results_generator: - final_output = request_output - - assert final_output is not None - ray_serve_logger.info(f"Generated text: {final_output}") - return final_output + # return final_outputs def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - self.start_dynamic_batching.add_input(request_id, sampling_params, prompt) + self.start_dynamic_batching.add_input(request_id, prompt, sampling_params) def abort(self, request_id: str): self.start_dynamic_batching.abort(request_id) - def step(self): - self.start_dynamic_batching._step() + def step(self) -> List[RequestOutput]: + return self.start_dynamic_batching._step() def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt) - + def is_running(self): return self.start_dynamic_batching.is_running() @@ -140,32 +134,20 @@ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClas collective.create_collective_group(self.workers, **_options) _ = ray.get(init_rets) - # set batch wait delay in seconds and maximum number of sequences in a batch - def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - results = ray.get([w.generate.remote(request_id, prompt, sampling_params) for w in self.workers]) - text_res = results[0] # get any one of the copies - return text_res - - 
async def async_generate(self, request_id: str, prompt: str, sampling_params: SamplingParams): - all_outputs = [] - for worker in self.workers: - all_outputs.append(worker.generate.remote(request_id, prompt, sampling_params)) - all_outputs = await asyncio.gather(*all_outputs) - text_res = all_outputs[0] # get any one of the copies - return text_res - def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams): - ray.get([w.add_input.remote(request_id, sampling_params, prompt) for w in self.workers]) + ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers]) def abort(self, request_id: str): ray.get([w.abort.remote(request_id) for w in self.workers]) def step(self): - ray.get([w._step.remote() for w in self.workers]) + results = ray.get([w.step.remote() for w in self.workers]) + outputs = results[0] # get any one of the copies + return outputs def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str): ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers]) - + def is_running(self): results = ray.get([w.is_running.remote() for w in self.workers]) return any(results) diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py index d9e9b6269cc4..0de43bd1a21f 100644 --- a/colossalai/inference/dynamic_batching/req_queue.py +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import uuid from typing import List @@ -41,7 +43,7 @@ def _can_add_new_req(self, req): need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() # NOTE: change here < to <= return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size - + def generate_new_batch(self, current_batch: Batch = None): if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: return None diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py index 9a0ace4111dd..2028da907259 100644 --- a/colossalai/inference/dynamic_batching/sampling_params.py +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + """Sampling parameters for text generation.""" from typing import List, Optional, Union @@ -5,7 +7,6 @@ class SamplingParams: - def __init__( self, do_sample: bool = False, @@ -13,10 +14,10 @@ def __init__( frequency_penalty: float = 0.0, temperature: float = 1.0, top_p: float = 1.0, - top_k: int = -1, # -1 is for all + top_k: int = -1, # -1 is for all ignore_eos: bool = False, max_new_tokens: int = 16, - stop_sequences: Optional[Union[str, List[str]]] = None # conditions to stop generation + stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation ) -> None: self.do_sample = do_sample self.presence_penalty = presence_penalty @@ -31,11 +32,13 @@ def __init__( self.temperature = 1.0 self.top_p = 1.0 self.top_k = 1 - if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS: # temperature is too slow, change to greedy search + if ( + self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS + ): # temperature is too low, change to greedy search self.temperature = 1.0 self.top_k = 1 return - + def verify(self): if self.presence_penalty < 0.0: raise ValueError(f"presence_penalty must 
>= 0.0, got {self.presence_penalty}") @@ -60,13 +63,13 @@ def stop_sentences_to_token_ids(self, tokenizer): new_stop_sequences = [] for stop_str in self.stop_sequences: stop_str_ids = tokenizer.encode(stop_str) - if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id stop_str_ids = stop_str_ids[1:] if len(stop_str_ids) > 0: new_stop_sequences.append(stop_str_ids) self.stop_sequences = new_stop_sequences return - + def to_dict(self): ret = {} ret["do_sample"] = self.do_sample diff --git a/colossalai/inference/dynamic_batching/stats.py b/colossalai/inference/dynamic_batching/stats.py index 6d34183f47c4..524072861a3f 100644 --- a/colossalai/inference/dynamic_batching/stats.py +++ b/colossalai/inference/dynamic_batching/stats.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import time diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py index bd33837dc451..42ff8bf1e9ef 100644 --- a/colossalai/inference/manager.py +++ b/colossalai/inference/manager.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/ModelTC/lightllm + import time from typing import List @@ -51,7 +53,7 @@ def __init__( self.mem_usage_interval = log_stats_interval * 2 self.tokenizer = get_tokenizer(tokenizer_name=self.model) - def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str): + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompts: str = ""): """ Add new request to req queue, during initialization all requests are held in waiting list. """ @@ -59,7 +61,7 @@ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, reques self.req_queue.append(req) return - def add_input(self, request_id, sampling_params, prompts): + def add_input(self, request_id, prompts, sampling_params): """ Encode and Add new input to req queue. support one sequence input for now. 
""" @@ -257,9 +259,10 @@ def generate(self, prompts, sampling_params, request_id): """ self.add_input(request_id, sampling_params, prompts) return self.loop_for_fwd() - + def is_running(self): - return self.running_batch is not None or self.req_queue.waiting_req_list + return self.running_batch is not None or self.req_queue.waiting_req_list + def start_dynamic_batching(args, tp_engine, waiting_req_list): try: diff --git a/tests/test_infer/test_dynamic_batching/test_async_engine.py b/tests/test_infer/test_dynamic_batching/test_async_engine.py new file mode 100644 index 000000000000..148d325a1d9a --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_async_engine.py @@ -0,0 +1,60 @@ +import asyncio +import os +import uuid + +import pytest + +import colossalai +from colossalai.inference.async_engine import Async_Engine +from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +PATH = "config.yaml" + + +def run_async_engine(path: str): + if not os.path.exists(path): + return + + config = RayInitConfig.from_yaml_path(path) + engine_config = config.engine_config_data + model = engine_config.model + if model is None or not os.path.exists(model): + return + + prompt = "Introduce some landmarks in Beijing" + sampling_params = SamplingParams() + asyncio.run(asy_for_loop_test(config, prompt, sampling_params)) + + +async def get_result(engine, prompt, sampling_params): + request_id = str(uuid.uuid4().hex) + results = engine.generate(request_id, prompt, sampling_params) + async for result in results: + assert result is not None + + +async def asy_for_loop_test(config, prompt, sampling_params): + router_config = config.router_config_data + engine_config = config.engine_config_data + engine = Async_Engine(router_config=router_config, engine_config=engine_config) + for i in range(10): + print("in for loop", i) + await get_result(engine, prompt, sampling_params) + + +def check_async_engine(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run_async_engine(PATH) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_async_engine(): + spawn(check_async_engine, 1) + + +if __name__ == "__main__": + test_async_engine() diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py index 124f1f478b00..588922b5a58f 100644 --- a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -45,6 +45,7 @@ def run(): log_stats=False, log_stats_interval=10, waiting_req_list=waiting_list, + model="llama", ) before_add = len(dynamic_batch_manager.req_queue) diff --git a/tests/test_infer/test_dynamic_batching/test_ray_dist.py b/tests/test_infer/test_dynamic_batching/test_ray_dist.py index 0eea9ef16345..5c84b39d8f8e 100644 --- a/tests/test_infer/test_dynamic_batching/test_ray_dist.py +++ b/tests/test_infer/test_dynamic_batching/test_ray_dist.py @@ -2,19 +2,21 @@ import os import uuid +import pytest + +import colossalai from colossalai.inference.dynamic_batching.ray_dist_init import Driver from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig from 
colossalai.inference.dynamic_batching.sampling_params import SamplingParams -import colossalai -import pytest from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn PATH = "config.yaml" + def run_ray_dist(path: str): print(f"Using yaml file {path}") if not os.path.exists(path): - raise FileNotFoundError(f"Invalid yaml file path {path}") + return config = RayInitConfig.from_yaml_path(path) router_config = config.router_config_data engine_config = config.engine_config_data @@ -25,8 +27,8 @@ def run_ray_dist(path: str): prompt = "Introduce some landmarks in Beijing" request_id = str(uuid.uuid4().hex) - sampling_params = SamplingParams() + print("sampling_params: ", sampling_params) async def get_result(request_id, prompt, sampling_params): return await driver.async_generate(request_id, prompt, sampling_params) @@ -35,19 +37,20 @@ async def get_result(request_id, prompt, sampling_params): if test_async: print("test_async: ", test_async) result = asyncio.run(get_result(request_id, prompt, sampling_params)) - assert result is not None + assert result is not None print("result: ", result) else: print("test_async: ", test_async) result = driver.generate(request_id, prompt, sampling_params) assert result is not None print("result: ", result) - + is_running = None is_running = driver.is_running() - assert is_running is not None + assert is_running is not None print("is_running: ", is_running) + def check_ray_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_ray_dist(PATH) @@ -59,5 +62,6 @@ def check_ray_dist(rank, world_size, port): def test_ray_dist(): spawn(check_ray_dist, 1) + if __name__ == "__main__": test_ray_dist()
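Usage note: the new Async_Engine is exercised end to end by tests/test_infer/test_dynamic_batching/test_async_engine.py. The sketch below condenses that flow into a standalone script as an illustration of the added API only; the config.yaml path and whatever model/router settings it holds are assumptions, and the distributed launch scaffolding (colossalai.launch via spawn) used by the test is omitted.

# Minimal usage sketch (not part of this PR); mirrors the flow in
# tests/test_infer/test_dynamic_batching/test_async_engine.py.
# "config.yaml" and its model/router settings are assumed to exist.
import asyncio
import uuid

from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def main():
    config = RayInitConfig.from_yaml_path("config.yaml")  # assumed config file
    engine = Async_Engine(
        router_config=config.router_config_data,
        engine_config=config.engine_config_data,
    )
    sampling_params = SamplingParams(max_new_tokens=32)
    request_id = str(uuid.uuid4().hex)

    # generate() starts the background loop on first use and returns an async
    # generator that yields RequestOutput objects as results become available.
    async for request_output in engine.generate(request_id, "Introduce some landmarks in Beijing", sampling_params):
        print(request_output.outputs)


if __name__ == "__main__":
    asyncio.run(main())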