diff --git a/cacheflow/config.py b/cacheflow/config.py
index cf779723a969..759a2cafbfeb 100644
--- a/cacheflow/config.py
+++ b/cacheflow/config.py
@@ -12,6 +12,20 @@
 
 
 class ModelConfig:
+    """Configuration for the model.
+
+    Args:
+        model: Name or path of the huggingface model to use.
+        download_dir: Directory to download and load the weights; defaults to
+            the default cache directory of huggingface.
+        use_np_weights: Save a numpy copy of model weights for faster loading.
+            This can increase the disk usage by up to 2x.
+        use_dummy_weights: Use dummy values for model weights (for profiling).
+        dtype: Data type for model weights and activations. The "auto" option
+            will use FP16 precision for FP32 and FP16 models, and BF16 precision
+            for BF16 models.
+        seed: Random seed for reproducibility.
+    """
 
     def __init__(
         self,
@@ -68,7 +82,14 @@ def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
 
 
 class CacheConfig:
-
+    """Configuration for the KV cache.
+
+    Args:
+        block_size: Size of a cache block in number of tokens.
+        gpu_memory_utilization: Fraction of GPU memory to use for the
+            CacheFlow execution.
+        swap_space: Size of the CPU swap space per GPU (in GiB).
+    """
     def __init__(
         self,
         block_size: int,
@@ -111,7 +132,15 @@ def verify_with_parallel_config(
 
 
 class ParallelConfig:
-
+    """Configuration for the distributed execution.
+
+    Args:
+        pipeline_parallel_size: Number of pipeline parallel groups.
+        tensor_parallel_size: Number of tensor parallel groups.
+        worker_use_ray: Whether to use Ray for model workers. Will be set to
+            True if either pipeline_parallel_size or tensor_parallel_size is
+            greater than 1.
+    """
     def __init__(
         self,
         pipeline_parallel_size: int,
@@ -134,7 +163,14 @@ def _verify_args(self) -> None:
 
 
 class SchedulerConfig:
-
+    """Scheduler configuration.
+
+    Args:
+        max_num_batched_tokens: Maximum number of tokens to be processed in
+            a single iteration.
+        max_num_seqs: Maximum number of sequences to be processed in a single
+            iteration.
+    """
     def __init__(
         self,
         max_num_batched_tokens: int,
diff --git a/cacheflow/entrypoints/openai/openai_frontend.py b/cacheflow/entrypoints/openai/openai_frontend.py
index 8f00db863e9b..acf09b210c1e 100644
--- a/cacheflow/entrypoints/openai/openai_frontend.py
+++ b/cacheflow/entrypoints/openai/openai_frontend.py
@@ -96,6 +96,18 @@ def create_logprobs(token_ids: List[int],
 
 
 @app.post("/v1/completions")
 async def create_completion(raw_request: Request):
+    """Completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/completions/create
+    for the API specification. This API mimics the OpenAI Completion API.
+
+    NOTE: Currently we do not support the following features:
+        - echo (since the cacheflow server does not currently support
+          getting the logprobs of prompt tokens)
+        - suffix (the language models we currently support do not support
+          suffix)
+        - logit_bias (to be supported in cacheflow server)
+    """
     request = CompletionRequest(**await raw_request.json())
     logger.info(f"Received completion request: {request}")
diff --git a/cacheflow/entrypoints/simple_fastapi_frontend.py b/cacheflow/entrypoints/simple_fastapi_frontend.py
index 1fce4cf3bf8e..1438851c5cbd 100644
--- a/cacheflow/entrypoints/simple_fastapi_frontend.py
+++ b/cacheflow/entrypoints/simple_fastapi_frontend.py
@@ -18,6 +18,12 @@
 
 @app.post("/generate")
 async def generate_stream(request: Request) -> StreamingResponse:
+    """Stream the results of the generation request.
+
+    The request should be a JSON object with the following fields:
+    - prompt: the prompt to use for the generation.
+    - other fields: the sampling parameters (see `SamplingParams` for details).
+    """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
     sampling_params = SamplingParams(**request_dict)
diff --git a/cacheflow/server/arg_utils.py b/cacheflow/server/arg_utils.py
index 63f32c80fab9..5c8dc9f154ea 100644
--- a/cacheflow/server/arg_utils.py
+++ b/cacheflow/server/arg_utils.py
@@ -9,6 +9,7 @@
 
 @dataclass
 class ServerArgs:
+    """Arguments for CacheFlow servers."""
     model: str
     download_dir: Optional[str] = None
     use_np_weights: bool = False
@@ -117,6 +118,7 @@ def create_server_configs(
 
 @dataclass
 class AsyncServerArgs(ServerArgs):
+    """Arguments for asynchronous CacheFlow servers."""
     server_use_ray: bool = False
 
     @staticmethod
diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/server/async_llm_server.py
index 409af2f240ed..e4467e9bfe9e 100644
--- a/cacheflow/server/async_llm_server.py
+++ b/cacheflow/server/async_llm_server.py
@@ -1,6 +1,6 @@
 import asyncio
 import time
-from typing import Dict, Optional
+from typing import Dict, List, Optional
 
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
@@ -15,7 +15,25 @@
 
 
 class AsyncLLMServer:
-
+    """An asynchronous wrapper for LLMServer.
+
+    This class wraps the LLMServer class to make it asynchronous. It uses
+    asyncio to create a background loop that keeps processing incoming
+    requests. The LLMServer is kicked by the generate method when there
+    are requests in the waiting queue. The generate method yields the outputs
+    from the LLMServer to the caller.
+
+    NOTE: For the comprehensive list of arguments, see `LLMServer`.
+
+    Args:
+        worker_use_ray: Whether to use Ray for model workers. Required for
+            distributed execution. Should be the same as
+            `parallel_config.worker_use_ray`.
+        server_use_ray: Whether to make LLMServer a Ray actor. If so, the
+            async frontend will be executed in a separate process from the
+            model workers.
+        *args, **kwargs: Arguments for LLMServer.
+    """
     def __init__(self, worker_use_ray: bool, server_use_ray: bool,
                  *args, **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
@@ -35,6 +53,7 @@ def __init__(self, worker_use_ray: bool, server_use_ray: bool,
         self.kicking_request_id: Optional[str] = None
 
     async def server_step(self, kicking_request_id: Optional[str] = None):
+        """Kick the server to process the waiting requests."""
         self.is_server_running = True
         self.kicking_request_id = kicking_request_id
         if self.server_use_ray:
@@ -54,8 +73,31 @@ async def server_step(self, kicking_request_id: Optional[str] = None):
                self.request_outputs[request_id] = request_output
                self.request_events[request_id].set()
 
-    async def generate(self, prompt: str, sampling_params: SamplingParams,
-                       request_id: str) -> RequestOutput:
+    async def generate(
+        self,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        request_id: str,
+        prompt_token_ids: Optional[List[int]] = None
+    ) -> RequestOutput:
+        """Generate outputs for a request.
+
+        This method is a coroutine. It adds the request into the waiting
+        queue of the LLMServer and streams the outputs from the LLMServer
+        to the caller.
+
+        Args:
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters of the request.
+            request_id: The unique id of the request.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompt to token IDs.
+
+        Yields:
+            The output `RequestOutput` objects from the LLMServer for the
+            request.
+        """
         # Preprocess the request.
         arrival_time = time.time()
@@ -66,20 +108,29 @@ async def generate(self, prompt: str, sampling_params: SamplingParams,
 
         logger.info(f"Received request {request_id}: "
                     f"prompt: {prompt!r}, "
-                    f"sampling params: {sampling_params}.")
+                    f"sampling params: {sampling_params}, "
+                    f"prompt token ids: {prompt_token_ids}.")
 
         # Add the request into the cacheflow server's waiting queue.
         if self.server_use_ray:
             await self.server.add_request.remote(
-                request_id, prompt, sampling_params, arrival_time=arrival_time)
+                request_id, prompt, sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)
         else:
             self.server.add_request(
-                request_id, prompt, sampling_params, arrival_time=arrival_time)
+                request_id, prompt, sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time)
 
         # The cacheflow server does not have a background loop that keeps
         # processing incoming requests. Therefore, we need to keep kicking
         # the server to process the requests.
         while True:
+            if request_id not in self.request_events:
+                # The request has been aborted.
+                return
+
             # Kick the server if the server is not running.
             if not self.is_server_running:
                 await self.server_step(request_id)
@@ -113,6 +164,14 @@ async def generate(self, prompt: str, sampling_params: SamplingParams,
                break
 
     async def abort(self, request_id: str) -> None:
+        """Abort a request.
+
+        Abort a submitted request. If the request is finished or not found,
+        this method will be a no-op.
+
+        Args:
+            request_id: The unique id of the request.
+        """
         if request_id not in self.request_events:
             # The request has already finished or been aborted.
             return
@@ -137,6 +196,7 @@ async def abort(self, request_id: str) -> None:
 
     @classmethod
     def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMServer":
+        """Creates an async LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index 54ab622359b7..6a9107f82b3a 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -8,7 +8,7 @@
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.ray_utils import ray, initialize_cluster
+from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray
 from cacheflow.server.tokenizer_utils import (get_tokenizer,
                                               detokenize_incrementally)
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
@@ -19,6 +19,33 @@
 
 
 class LLMServer:
+    """An LLM server that receives requests and generates texts.
+
+    This is the main class for the CacheFlow LLM server. It receives requests
+    from clients and generates texts from the LLM. It includes a tokenizer, a
+    language model (possibly distributed across multiple GPUs), and GPU memory
+    space allocated for intermediate states (aka KV cache). This class utilizes
+    iteration-level scheduling and efficient memory management to maximize the
+    serving throughput.
+
+    The `LLM` class wraps this class for offline batched inference and the
+    `AsyncLLMServer` class wraps this class for online serving.
+
+    NOTE: The config arguments are derived from the `ServerArgs` class. For the
+    comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model_config: The configuration related to the LLM model.
+        cache_config: The configuration related to the KV cache memory
+            management.
+        parallel_config: The configuration related to distributed execution.
+        scheduler_config: The configuration related to the request scheduler.
+        distributed_init_method: The initialization method for distributed
+            execution. See `torch.distributed.init_process_group` for details.
+        stage_devices: The list of devices for each stage. Each stage is a list
+            of (rank, node_resource, device) tuples.
+        log_stats: Whether to log statistics.
+    """
 
     def __init__(
         self,
@@ -27,7 +54,7 @@ def __init__(
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
-        stage_devices: List[List[Any]],
+        stage_devices: List[List[DeviceID]],
         log_stats: bool,
     ) -> None:
         logger.info(
@@ -83,6 +110,7 @@ def _verify_args(self) -> None:
         self.cache_config.verify_with_parallel_config(self.parallel_config)
 
     def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache."""
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
@@ -108,6 +136,7 @@ def _init_cache(self) -> None:
 
     @classmethod
     def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        """Creates an LLM server from the server arguments."""
         # Create the server configs.
         server_configs = server_args.create_server_configs()
         parallel_config = server_configs[2]
@@ -126,6 +155,22 @@ def add_request(
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
     ) -> None:
+        """Add a request to the server's request pool.
+
+        The request is added to the request pool and will be processed by the
+        scheduler as `server.step()` is called. The exact scheduling policy is
+        determined by the scheduler.
+
+        Args:
+            request_id: The unique ID of the request.
+            prompt: The prompt string. Can be None if prompt_token_ids is
+                provided.
+            sampling_params: The sampling parameters for text generation.
+            prompt_token_ids: The token IDs of the prompt. If None, we
+                use the tokenizer to convert the prompt to token IDs.
+            arrival_time: The arrival time of the request. If None, we use
+                the current time.
+        """
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:
@@ -148,15 +193,30 @@ def add_request(
         self.scheduler.add_seq_group(seq_group)
 
     def abort_request(self, request_id: str) -> None:
+        """Aborts a request with the given ID.
+
+        Args:
+            request_id: The ID of the request to abort.
+        """
         self.scheduler.abort_seq_group(request_id)
 
     def get_num_unfinished_requests(self) -> int:
+        """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()
 
     def has_unfinished_requests(self) -> bool:
+        """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()
 
     def step(self) -> List[RequestOutput]:
+        """Performs one decoding iteration and returns newly generated results.
+
+        This function performs one decoding iteration for the server. It first
+        schedules the sequences to be executed in the next iteration and the
+        token blocks to be swapped in/out/copied. Then, it executes the model
+        and updates the scheduler with the model outputs. Finally, it decodes
+        the sequences and returns the newly generated results.
+        """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
         if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
             # Nothing to do.
@@ -188,7 +248,7 @@ def step(self) -> List[RequestOutput]:
         return request_outputs
 
     def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Decode the sequence outputs.
+        """Decodes the sequence outputs."""
         for seq_group in seq_groups:
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 new_token, new_output_text = detokenize_incrementally(
@@ -201,7 +261,7 @@ def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
             seq.output_text = new_output_text
 
     def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        # Stop the sequences.
+        """Stop the finished sequences."""
         for seq_group in seq_groups:
             sampling_params = seq_group.sampling_params
             for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
@@ -238,6 +298,7 @@ def _run_workers(
         *args,
         **kwargs,
     ) -> Any:
+        """Runs the given method on all workers."""
         all_outputs = []
         for worker in self.workers:
             executor = getattr(worker, method)
diff --git a/cacheflow/server/ray_utils.py b/cacheflow/server/ray_utils.py
index 4d533bddee0b..e701d00fd9a6 100644
--- a/cacheflow/server/ray_utils.py
+++ b/cacheflow/server/ray_utils.py
@@ -14,15 +14,30 @@ def initialize_cluster(
     parallel_config: ParallelConfig,
     server_use_ray: bool = False,
-    address: Optional[str] = None,
+    ray_server_address: Optional[str] = None,
 ) -> Tuple[str, List[List[DeviceID]]]:
+    """Initialize the distributed cluster, using Ray if required.
+
+    Args:
+        parallel_config: The configurations for parallel execution.
+        server_use_ray: Whether to use Ray for the async server.
+        ray_server_address: The address of the Ray cluster. If None, uses
+            the default Ray cluster address.
+
+    Returns:
+        A tuple of (`distributed_init_method`, `all_stage_devices`). The
+        `distributed_init_method` is the address for initializing the
+        distributed backend. `all_stage_devices` includes device IDs for
+        each worker in each pipeline stage. Each device ID is a tuple of
+        (rank, node resource, device id).
+    """
     if parallel_config.worker_use_ray or server_use_ray:
         if ray is None:
             raise ImportError(
                 "Ray is not installed. Please install Ray to use distributed "
                 "serving.")
         # Connect to a ray cluster.
-        ray.init(address=address)
+        ray.init(address=ray_server_address)
 
     if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/server/tokenizer_utils.py
index 8aede295d245..3443412c4021 100644
--- a/cacheflow/server/tokenizer_utils.py
+++ b/cacheflow/server/tokenizer_utils.py
@@ -15,6 +15,7 @@ def get_tokenizer(
     *args,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    """Gets a tokenizer for the given model name via Huggingface."""
     config = AutoConfig.from_pretrained(model_name)
     if config.model_type == "llama" and getattr(kwargs, "use_fast", True):
         # LLaMA fast tokenizer causes protobuf errors in some environments.
diff --git a/examples/openai_client.py b/examples/openai_client.py
index 9e711a8a0899..3994b8e4ebae 100644
--- a/examples/openai_client.py
+++ b/examples/openai_client.py
@@ -1,14 +1,15 @@
 import openai
+
+# Modify OpenAI's API key and API base to use CacheFlow's API server.
 openai.api_key = "EMPTY"
 openai.api_base = "http://localhost:8000/v1"
 
 model = "facebook/opt-125m"
 
-# list models
+# Test list models API
 models = openai.Model.list()
-print(models)
-
-# create a completion
+print("Models:", models)
+# Test completion API
 stream = True
 completion = openai.Completion.create(
     model=model, prompt="A robot may not injure a human being", echo=False, n=2,
@@ -19,4 +20,4 @@
     for c in completion:
         print(c)
 else:
-    print("completion:", completion)
+    print("Completion result:", completion)
diff --git a/examples/simple_server.py b/examples/simple_server.py
index afca0e3ca7d2..d43e1bc85e90 100644
--- a/examples/simple_server.py
+++ b/examples/simple_server.py
@@ -19,7 +19,7 @@ def main(args: argparse.Namespace):
             SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]
 
-    # Run the server.
+    # Run the server by calling `server.step()` manually.
     request_id = 0
     while True:
         # To test iteration-level scheduling, we add one request at each step.