[Inference] Dynamic Batching Inference, online and offline #4953
Merged
Commits (36):
- e0757c3: [inference] Dynamic Batching for Single and Multiple GPUs (#4831)
- CjhHa1 fced140: [inference] Async dynamic batching (#4894)
- CjhHa1 fbf3c09: [inference]Re push async dynamic batching (#4901)
- CjhHa1 d509e79: Revert "[inference]Re push async dynamic batching (#4901)" (#4905)
- CjhHa1 ec004fe: Revert "[inference] Async dynamic batching (#4894)"
- isky-cd 78cd937: Revert "[inference] Async dynamic batching (#4894)" (#4909)
- tiandiao123 d97290a: Add Ray Distributed Environment Init Scripts
- isky-cd 8483393: fix conflict
- isky-cd f589e97: support DynamicBatchManager base function
- isky-cd c070050: revert _set_tokenizer version
- isky-cd 5deb95c: add driver async generate
- isky-cd 306ef77: add async test
- isky-cd 632f0e1: fix bugs in test_ray_dist.py
- isky-cd 0b2fe51: add get_tokenizer.py
- isky-cd cd843ac: fix code style
- isky-cd 8c9ad51: fix bugs about No module named 'pydantic' in ci test
- isky-cd 8d0cc6b: fix bugs in ci test
- isky-cd acdd751: fix bugs in ci test
- isky-cd 8a761bd: fix bugs in ci test
- isky-cd 56f75c4: [infer]Add Ray Distributed Environment Init Scripts (#4911)
- isky-cd c76fd68: support dynamic batch for bloom model and is_running function
- isky-cd f41ccdd: fix conflict
- isky-cd fca12b8: Merge pull request #4933 from yuehuayingxueluo/ray_dist_init_branch
- isky-cd 4ea9fbe: [Inference]Test for new Async engine (#4935)
- CjhHa1 3f6af12: add assertion for config (#4947)
- CjhHa1 4867561: [Inference] Finish dynamic batching offline test (#4948)
- CjhHa1 285fc30: fix bugs
- CjhHa1 d5d2c94: fix quant
- CjhHa1 ed86584: add default
- CjhHa1 4bffb8b: fix
- CjhHa1 77adc2e: fix some bugs
- CjhHa1 dcb51b4: fix some bugs
- CjhHa1 afae53b: fix
- c477266: fix bug
- f99eba2: fix bugs
- 4c3ea40: reset param
New file (+133 lines):

```python
import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
    """
    Tracks all in-flight requests; an abstraction used by the async engine.
    """

    def __init__(self) -> None:
        self._requests: asyncio.Queue[str] = asyncio.Queue()
        self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        # asyncio.Queue does not support membership tests directly;
        # check the underlying deque instead.
        return item in self._requests._queue

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: str):
        """Add a request to be sent to the engine on the next background
        loop iteration."""
        self._requests.put_nowait(request_id)
        self.new_requests_event.set()  # NOTE: we may find a better way to clear this event

    def add_stop(self):
        """
        Add a StopIteration flag to stop the async generator.
        """
        self._finished_requests.put_nowait(StopIteration)
        self.new_requests_event.clear()

    def process_request_output(self, request_output: RequestOutput) -> None:
        """Process a request output from the engine."""
        self._finished_requests.put_nowait(request_output)

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._finished_requests.get()
        if result is StopIteration:
            raise StopAsyncIteration
        return result


class Async_Engine:
    """
    Use an engine to launch a Ray Driver --> Ray Worker --> Async_Manager.
    Background loop: runs inference on the requests in the waiting list (listen).
    Request tracker: manages incoming requests and stores finished ones.
    generate: the exposed function for adding new inputs and returning finished outputs.
    """

    def __init__(
        self,
        router_config,
        engine_config,
        start_engine_loop: bool = True,
    ) -> None:
        self.driver = Driver(router_config=router_config, engine_config=engine_config)
        self.background_loop = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    def _step(self):
        """
        Run one engine step and forward any finished outputs to the tracker.
        """
        request_outputs = self.driver.step()
        if request_outputs is not None:
            for request_output in request_outputs:
                self._request_tracker.process_request_output(request_output)
            self._request_tracker.add_stop()

    def abort_request(self, request_id: str):
        self.driver.abort(request_id)

    def _has_requests_in_progress(self):
        return self.driver.is_running()

    async def run_loop_fwd(self):
        while True:
            # Re-check progress on every iteration so the loop goes back to
            # sleep once all requests have drained (the original checked only
            # once, before entering the loop).
            if not self._has_requests_in_progress():
                await self._request_tracker.wait_for_new_requests()
            self._step()
            await asyncio.sleep(0)

    @property
    def is_running(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        self._request_tracker.init_event()

        self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
        self.background_loop = asyncio.shield(self.background_loop_unshielded)

    async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        self.driver.add_input(request_id, prompt, sampling_params)
        self._request_tracker.add_request(request_id)

    async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        """
        The only exposed function: add a new request and return an async
        generator that yields results as they finish.
        """
        try:
            if not self.is_running:
                self.start_background_loop()

            await self.add_request(request_id, prompt, sampling_params)

            async for request_output in self._request_tracker:
                yield request_output

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or the coroutine is cancelled, abort the request.
            self.abort_request(request_id)
            raise e
```
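As a usage illustration, here is a minimal sketch of driving `Async_Engine.generate` from client code. It assumes an `Async_Engine` has already been constructed with valid `router_config` and `engine_config` (which come from the Ray init scripts in this PR), and that `SamplingParams` accepts default arguments; the absolute import path is inferred from the relative imports above, so treat everything except `generate()`'s shape as an assumption.

```python
# Hypothetical driver for Async_Engine.generate. The engine construction and
# SamplingParams defaults are assumptions; only generate()'s signature is
# taken from the code above.
import asyncio

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def run_one_request(engine, request_id: str, prompt: str):
    # generate() lazily starts the background loop, enqueues the request,
    # then yields finished RequestOutput objects until add_stop() fires.
    async for request_output in engine.generate(request_id, prompt, SamplingParams()):
        print(request_output)


# Inside a Ray-initialized process with a constructed engine:
# asyncio.run(run_one_request(engine, "req-0", "Hello, world"))
```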
  
    
New file (+151 lines):

```python
from typing import List

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num: int,
        batch_max_tokens: int,
        model: str,
        tokenizer=None,
        eos_id=None,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Batch = None,
        waiting_req_list: List = [],
    ):
        """
        Args:
            tp_engine: the TP engine held by the dynamic batch manager, created before the manager
            max_total_token_num: max_total_token_num for the memory manager, defaults to max batch size * (max input len + max output len)
            batch_max_tokens: max tokens of one batch, defaults to (max input len + max output len) * num_requests
            running_max_req_size: max request size of the running batch, equal to MAX_BATCH_SIZE of the TP engine
            eos_id: the end token of a sequence
            model: the model weight dir path; the app will load config, weights and tokenizer from this dir
            log_stats: whether to log stats
            log_stats_interval: log stats interval
            running_batch: the running batch
            waiting_req_list: list of waiting requests, initialized before the dynamic batch manager
        """
        super().__init__(
            tp_engine,
            max_total_token_num,
            batch_max_tokens,
            model,
            tokenizer,
            eos_id,
            log_stats,
            log_stats_interval,
            running_batch,
            waiting_req_list,
        )

    def _step(self):
        """
        Run one scheduling step: prefill a new batch, decode the running batch,
        or prefill and merge in a new mini-batch.
        """
        has_new_finished = False
        if self.running_batch is None:
            new_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
                has_new_finished, outputs = self._prefill_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens = 0

        else:
            if self.has_wait_tokens < self.max_wait_tokens:
                self.stats_tool.count_output_tokens(self.running_batch)
                has_new_finished, outputs = self._decode_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens += 1

            else:
                new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
                if new_mini_batch is not None:
                    self.stats_tool.count_prompt_tokens(new_mini_batch)
                    has_new_finished, outputs = self._prefill_batch(new_mini_batch)
                    if not new_mini_batch.is_clear():
                        self._merge_batch(self.running_batch, new_mini_batch)
                        self.running_batch.merge(new_mini_batch)
                    self.has_wait_tokens = 0

                else:
                    self.stats_tool.count_output_tokens(self.running_batch)
                    has_new_finished, outputs = self._decode_batch(self.running_batch)
                    self._filter_runing_batch()
                    self.has_wait_tokens += 1

        if has_new_finished:
            return outputs
        return None

    def _prefill_batch(self, batch):
        """
        Every batch, whether a brand-new batch or a mini-batch, is prefilled first.
        """
        self._init_batch(batch)

        # TODO: figure out if cache and batch id are needed
        req_to_out_token_id = self.engine._prefill_batch(batch.batch_id)
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        outputs = self._handle_finish_req(batch, has_new_finished_req)  # delete finished reqs
        return has_new_finished_req, outputs

    def _decode_batch(self, batch: Batch):
        """
        Run one decoding step for the batch.
        """
        req_to_out_token_id = self.engine._decode_batch(batch.batch_id)
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
            finished_reqs = batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)
            return self._output_process(finished_reqs)
        return None

    def _output_process(self, finished_reqs: List[Req]):
        """
        Decode finished requests into RequestOutput objects.
        """
        outputs = []
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
            outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
        return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
    try:
        batch_manager = Async_DynamicBatchManager(
            tp_engine=tp_engine,
            max_total_token_num=args.max_total_token_num,
            batch_max_tokens=args.batch_max_tokens,
            eos_id=args.eos_id,
            model=args.model,
            log_stats=not args.disable_log_stats,
            log_stats_interval=args.log_stats_interval,
            waiting_req_list=waiting_req_list,
        )
    except Exception:
        # Re-raise the original exception instead of a bare `raise Exception`,
        # which would discard the traceback.
        raise

    return batch_manager
```
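For the offline path, a hedged sketch of how `start_dynamic_batching` and `_step` might be polled. The `args` namespace carries only the attributes the factory above actually reads; all values are placeholders, and the stop condition assumes the manager exposes its queue as `req_queue` with a `waiting_req_list` attribute (consistent with `generate_new_batch` above, but not confirmed here).

```python
# Hypothetical offline polling loop. Numeric values and the model path are
# placeholders; manager.req_queue.waiting_req_list is an assumed attribute.
from types import SimpleNamespace


def run_offline(tp_engine, waiting_req_list):
    args = SimpleNamespace(
        max_total_token_num=4096,   # placeholder memory-manager capacity
        batch_max_tokens=2048,      # placeholder per-batch token cap
        eos_id=2,                   # tokenizer-dependent end-of-sequence id
        model="/path/to/model",     # placeholder weight dir
        disable_log_stats=False,
        log_stats_interval=10,
    )
    manager = start_dynamic_batching(args, tp_engine, waiting_req_list)

    finished = []
    # Step until there is neither a running batch nor anything left to schedule.
    while manager.running_batch is not None or manager.req_queue.waiting_req_list:
        outputs = manager._step()   # finished RequestOutputs for this step, or None
        if outputs:
            finished.extend(outputs)
    return finished
```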
Empty file.

New file (+40 lines):

```python
| """ | ||
| Motivated by VllM (https://github.com/vllm-project/vllm), This module is trying to resolve the tokenizer issue. | ||
|  | ||
| license: MIT, see LICENSE for more details. | ||
| """ | ||
|  | ||
| from transformers import AutoTokenizer | ||
|  | ||
| _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" | ||
|  | ||
|  | ||
| def get_tokenizer( | ||
| tokenizer=None, | ||
| tokenizer_name: str = "", | ||
| trust_remote_code: bool = False, | ||
| use_fast: bool = True, | ||
| ): | ||
| if tokenizer is not None: | ||
| tokenizer = tokenizer | ||
| else: | ||
| if "llama" in tokenizer_name.lower() and use_fast == True: | ||
| print( | ||
| "For some LLaMA-based models, initializing the fast tokenizer may " | ||
| "take a long time. To eliminate the initialization time, consider " | ||
| f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " | ||
| "tokenizer. This is done automatically in Colossalai." | ||
| ) | ||
|  | ||
| tokenizer_name = _FAST_LLAMA_TOKENIZER | ||
|  | ||
| try: | ||
| tokenizer = AutoTokenizer.from_pretrained( | ||
| tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code | ||
| ) | ||
| except TypeError: | ||
| use_fast = False | ||
| tokenizer = AutoTokenizer.from_pretrained( | ||
| tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code | ||
| ) | ||
| return tokenizer | 
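A short usage sketch for `get_tokenizer`. The tokenizer name simply reuses the fast LLaMA tokenizer constant defined above, and the `eos_id` hand-off to the batch manager is illustrative only.

```python
# Hypothetical usage of get_tokenizer; the tokenizer name reuses the constant
# defined above and the eos_id hand-off is illustrative only.
tokenizer = get_tokenizer(tokenizer_name="hf-internal-testing/llama-tokenizer", use_fast=True)
print(tokenizer.eos_token_id)  # could be passed as eos_id to the batch manager

# Passing an existing tokenizer instance returns it unchanged:
assert get_tokenizer(tokenizer=tokenizer) is tokenizer
```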
      