diff --git a/README.md b/README.md index 1ae1ff6aa665..7cce45b9efbc 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ## Installation ```bash -pip install cmake torch transformers +pip install psutil numpy torch transformers pip install flash-attn # This may take up to 10 mins. pip install -e . ``` diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py index 7f2ca1455fc4..0d1b8f9c36ab 100644 --- a/cacheflow/master/scheduler.py +++ b/cacheflow/master/scheduler.py @@ -9,8 +9,6 @@ from cacheflow.sequence import SequenceOutputs from cacheflow.sequence import SequenceStatus -_MAX_NUM_BATCHED_TOKENS = 2048 - class Scheduler: @@ -21,12 +19,14 @@ def __init__( block_size: int, num_gpu_blocks: int, num_cpu_blocks: int, + max_num_batched_tokens: int, ) -> None: self.frontend = frontend self.controllers = controllers self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks self.num_cpu_blocks = num_cpu_blocks + self.max_num_batched_tokens = max_num_batched_tokens # Create the block space manager. self.block_manager = BlockSpaceManager( @@ -164,7 +164,7 @@ def step(self) -> None: num_prompt_tokens = seq_group.seqs[0].get_len() if self.block_manager.can_allocate(seq_group): if (num_batched_tokens + num_prompt_tokens - <= _MAX_NUM_BATCHED_TOKENS): + <= self.max_num_batched_tokens): self._allocate(seq_group) num_batched_tokens += num_prompt_tokens continue diff --git a/cacheflow/models/__init__.py b/cacheflow/models/__init__.py index 67dbd5627cbb..cd8f134a5a74 100644 --- a/cacheflow/models/__init__.py +++ b/cacheflow/models/__init__.py @@ -1,10 +1,12 @@ from cacheflow.models.input_metadata import InputMetadata +from cacheflow.models.model_utils import get_memory_analyzer from cacheflow.models.model_utils import get_model -from cacheflow.models.model_utils import set_seed +from cacheflow.models.utils import set_seed __all__ = [ 'InputMetadata', + 'get_memory_analyzer', 'get_model', - 'set_seed' + 'set_seed', ] diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py new file mode 100644 index 000000000000..6af7b25f60b3 --- /dev/null +++ b/cacheflow/models/memory_analyzer.py @@ -0,0 +1,125 @@ +import torch +from transformers import AutoConfig + +from cacheflow.models.utils import get_cpu_memory +from cacheflow.models.utils import get_dtype_size +from cacheflow.models.utils import get_gpu_memory + +_GiB = 1 << 30 + + +class CacheFlowMemoryAnalyzer: + + def get_max_num_gpu_blocks( + self, + max_num_batched_tokens: int, + memory_utilization: float, + ) -> int: + raise NotImplementedError() + + def get_max_num_cpu_blocks( + self, + memory_utilization: float, + ) -> int: + raise NotImplementedError() + + +class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer): + + def __init__( + self, + model_name: str, + block_size: int, + dtype: torch.dtype, + ) -> None: + self.model_name = model_name + self.block_size = block_size + self.dtype = dtype + + # TODO(woosuk): Support tensor parallelism. + config = AutoConfig.from_pretrained(model_name) + self.num_layers = config.num_hidden_layers + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_size = config.hidden_size // self.num_heads + self.ffn_size = config.ffn_dim + self.embedding_size = config.word_embed_proj_dim + self.vocab_size = config.vocab_size + self.max_position = config.max_position_embeddings + + def _get_param_size(self) -> int: + # TODO(woosuk): Support tensor parallelism. + word_embedding = self.vocab_size * self.embedding_size + if self.embedding_size != self.vocab_size: + # Project in/out. + word_embedding += 2 * self.embedding_size * self.vocab_size + position_embedding = self.max_position * self.hidden_size + + ln1 = 2 * self.hidden_size + q = self.hidden_size * self.hidden_size + self.hidden_size + k = self.hidden_size * self.hidden_size + self.hidden_size + v = self.hidden_size * self.hidden_size + self.hidden_size + out = self.hidden_size * self.hidden_size + self.hidden_size + mha = ln1 + q + k + v + out + + ln2 = 2 * self.hidden_size + ffn1 = self.hidden_size * self.ffn_size + self.ffn_size + ffn2 = self.ffn_size * self.hidden_size + self.hidden_size + ffn = ln2 + ffn1 + ffn2 + + total = (word_embedding + position_embedding + + self.num_layers * (mha + ffn)) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * total + + def _get_max_act_size( + self, + max_num_batched_tokens: int, + ) -> int: + # TODO(woosuk): Support tensor parallelism. + # NOTE: We approxmiately calculate the maximum activation size by + # 1) estimating the maximum activation tensor size during inference, and + # 2) multiplying it by 4. + # Here, we assume that FlashAttention is used and + # thus the attention maps are never materialized in GPU DRAM. + qkv = 3 * (max_num_batched_tokens * self.hidden_size) + ffn = max_num_batched_tokens * self.ffn_size + max_act = 4 * max(qkv, ffn) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * max_act + + def _get_workspace_size(self) -> int: + return 1 * _GiB + + def _get_cache_block_size(self) -> int: + key_cache_block = self.block_size * self.num_heads * self.head_size + value_cache_block = self.block_size * self.num_heads * self.head_size + total = self.num_layers * (key_cache_block + value_cache_block) + dtype_size = get_dtype_size(self.dtype) + return dtype_size * total + + def get_max_num_gpu_blocks( + self, + max_num_batched_tokens: int, + memory_utilization: float = 0.95, + ) -> int: + # NOTE(woosuk): This assumes that the machine has homogeneous GPUs. + gpu_memory = get_gpu_memory() + usable_memory = int(memory_utilization * gpu_memory) + + param_size = self._get_param_size() + act_size = self._get_max_act_size(max_num_batched_tokens) + workspace_size = self._get_workspace_size() + + max_cache_size = usable_memory - (param_size + act_size + workspace_size) + max_num_blocks = max_cache_size // self._get_cache_block_size() + return max_num_blocks + + def get_max_num_cpu_blocks( + self, + memory_utilization: float = 0.25, + ) -> int: + cpu_memory = get_cpu_memory() + usable_memory = int(memory_utilization * cpu_memory) + max_num_blocks = usable_memory // self._get_cache_block_size() + return max_num_blocks diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index d26fd8c46a1d..98ff6d44ebb0 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -1,21 +1,20 @@ -import random from typing import Union -import numpy as np import torch import torch.nn as nn +from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer +from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer from cacheflow.models.opt import OPTForCausalLM +from cacheflow.models.utils import get_torch_dtype -MODEL_CLASSES = { + +_MODELS = { 'opt': OPTForCausalLM, } -STR_DTYPE_TO_TORCH_DTYPE = { - 'half': torch.half, - 'float': torch.float, - 'float16': torch.float16, - 'float32': torch.float32, +_MEMORY_ANALYZERS = { + 'opt': OPTMemoryAnalyzer, } @@ -23,20 +22,23 @@ def get_model( model_name: str, dtype: Union[torch.dtype, str], ) -> nn.Module: - if isinstance(dtype, str): - torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()] - else: - torch_dtype = dtype - for model_class, hf_model in MODEL_CLASSES.items(): + torch_dtype = get_torch_dtype(dtype) + for model_class, hf_model in _MODELS.items(): if model_class in model_name: - model = hf_model.from_pretrained(model_name, torch_dtype=torch_dtype) + model = hf_model.from_pretrained( + model_name, torch_dtype=torch_dtype) return model.eval() - raise ValueError(f'Invalid model name: {model_name}') + raise ValueError(f'Unsupported model name: {model_name}') -def set_seed(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) +def get_memory_analyzer( + model_name: str, + block_size: int, + dtype: Union[torch.dtype, str], +) -> CacheFlowMemoryAnalyzer: + torch_dtype = get_torch_dtype(dtype) + for model_class, memory_analyzer in _MEMORY_ANALYZERS.items(): + if model_class in model_name: + return memory_analyzer( + model_name, block_size, torch_dtype) + raise ValueError(f'Unsupported model name: {model_name}') diff --git a/cacheflow/models/utils.py b/cacheflow/models/utils.py new file mode 100644 index 000000000000..4b705bf7d969 --- /dev/null +++ b/cacheflow/models/utils.py @@ -0,0 +1,43 @@ +from typing import Union + +import random + +import numpy as np +import psutil +import torch + +_STR_DTYPE_TO_TORCH_DTYPE = { + 'half': torch.half, + 'float': torch.float, + 'float16': torch.float16, + 'float32': torch.float32, +} + + +def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: + if isinstance(dtype, str): + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()] + else: + torch_dtype = dtype + return torch_dtype + + +def get_dtype_size(dtype: Union[torch.dtype, str]) -> int: + torch_dtype = get_torch_dtype(dtype) + return torch.tensor([], dtype=torch_dtype).element_size() + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def get_gpu_memory(gpu: int = 0) -> int: + return torch.cuda.get_device_properties(gpu).total_memory + + +def get_cpu_memory() -> int: + return psutil.virtual_memory().total diff --git a/server.py b/server.py index c873caf15d5d..b740724c373f 100644 --- a/server.py +++ b/server.py @@ -3,6 +3,7 @@ from cacheflow.master.frontend import Frontend from cacheflow.master.scheduler import Scheduler +from cacheflow.models import get_memory_analyzer from cacheflow.worker.controller import Controller parser = argparse.ArgumentParser(description='CacheFlow server') @@ -10,17 +11,25 @@ parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes') parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node') parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size') -# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks. -parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)') -parser.add_argument('--num-cpu-blocks', type=int, default=32, help='number of CPU blocks (per GPU)') # NOTE(woosuk): If FlashAttention is used, the float data type is not supported. parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', type=int, default=0, help='random seed') +parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens') args = parser.parse_args() def main(): + memory_analyzer = get_memory_analyzer( + model_name=args.model, + block_size=args.block_size, + dtype=args.dtype, + ) + num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks( + max_num_batched_tokens=args.max_batch_size) + num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks() + print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}') + # Create a controller for each node. controllers: List[Controller] = [] for i in range(args.num_nodes): @@ -29,8 +38,8 @@ def main(): num_workers=args.num_workers, model_name=args.model, block_size=args.block_size, - num_gpu_blocks=args.num_gpu_blocks, - num_cpu_blocks=args.num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, dtype=args.dtype, seed=args.seed, ) @@ -47,8 +56,9 @@ def main(): frontend=frontend, controllers=controllers, block_size=args.block_size, - num_gpu_blocks=args.num_gpu_blocks, - num_cpu_blocks=args.num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + max_num_batched_tokens=args.max_batch_size, ) # Connect the controllers. for i in range(len(controllers) - 1):