diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py
index 0d1b8f9c36ab..3931d92684f3 100644
--- a/cacheflow/master/scheduler.py
+++ b/cacheflow/master/scheduler.py
@@ -158,8 +158,8 @@ def step(self) -> None:
         # 3. Join new sequences if possible.
         # NOTE: Here we implicitly assume FCFS scheduling.
         # TODO(woosuk): Add a batching policy to control the batch size.
+        self._fetch_inputs()
         if not self.swapped:
-            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
@@ -211,12 +211,13 @@ def step(self) -> None:
             input_seq_groups.append(input_seq_group)
 
         # 5. Execute the first stage of the pipeline.
-        self.controllers[0].execute_stage(
-            input_seq_groups,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-        )
+        if (input_seq_groups or blocks_to_swap_in or blocks_to_swap_out):
+            self.controllers[0].execute_stage(
+                input_seq_groups,
+                blocks_to_swap_in,
+                blocks_to_swap_out,
+                blocks_to_copy,
+            )
 
     def post_step(
         self,
diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py
index 7c77db5a819b..fe1e194d94a8 100644
--- a/cacheflow/models/attention.py
+++ b/cacheflow/models/attention.py
@@ -12,7 +12,7 @@
 class OPTCacheFlowAttention(nn.Module):
 
     def __init__(self, scale: float) -> None:
-        super().__init__()
+        super(OPTCacheFlowAttention, self).__init__()
         self.scale = float(scale)
 
         self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -106,8 +106,8 @@ def forward(
         output = output.view(-1, num_heads, head_size)
 
         # Compute the attention op for prompts.
-        if input_metadata.num_prompts > 0:
-            num_prompt_tokens = sum(input_metadata.prompt_lens)
+        num_prompt_tokens = input_metadata.num_prompt_tokens
+        if num_prompt_tokens > 0:
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -126,10 +126,9 @@
 
         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
-            start_idx = sum(input_metadata.prompt_lens)
             self.single_query_cached_kv_attention(
-                output[start_idx:],
-                query[start_idx:],
+                output[num_prompt_tokens:],
+                query[num_prompt_tokens:],
                 key_cache,
                 value_cache,
                 input_metadata)
diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py
index 6af7b25f60b3..11668b588da5 100644
--- a/cacheflow/models/memory_analyzer.py
+++ b/cacheflow/models/memory_analyzer.py
@@ -5,7 +5,7 @@
 from cacheflow.models.utils import get_dtype_size
 from cacheflow.models.utils import get_gpu_memory
 
-_GiB = 1 << 30 
+_GiB = 1 << 30
 
 
 class CacheFlowMemoryAnalyzer:
@@ -117,9 +117,19 @@ def get_max_num_gpu_blocks(
 
     def get_max_num_cpu_blocks(
         self,
-        memory_utilization: float = 0.25,
+        swap_space: int,
     ) -> int:
+        swap_space = swap_space * _GiB
         cpu_memory = get_cpu_memory()
-        usable_memory = int(memory_utilization * cpu_memory)
-        max_num_blocks = usable_memory // self._get_cache_block_size()
+        if swap_space > 0.8 * cpu_memory:
+            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
+                             'takes more than 80% of the available memory '
+                             f'({cpu_memory / _GiB:.2f} GiB).'
+                             'Please check the swap space size.')
+        if swap_space > 0.5 * cpu_memory:
+            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
+                  'takes more than 50% of the available memory '
+                  f'({cpu_memory / _GiB:.2f} GiB).'
+                  'This may slow the system performance.')
+        max_num_blocks = swap_space // self._get_cache_block_size()
         return max_num_blocks
diff --git a/cacheflow/models/sample.py b/cacheflow/models/sample.py
index 7bdc6a42771a..a8b60c145208 100644
--- a/cacheflow/models/sample.py
+++ b/cacheflow/models/sample.py
@@ -11,7 +11,7 @@
 class Sampler(nn.Module):
 
     def __init__(self) -> None:
-        super().__init__()
+        super(Sampler, self).__init__()
 
     def forward(
         self,
diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py
index 84d8cdd5390a..3f4b451934c9 100644
--- a/cacheflow/worker/worker.py
+++ b/cacheflow/worker/worker.py
@@ -191,6 +191,13 @@ def execute_stage(
         else:
             cache_events = None
 
+        # If there is no input, we don't need to execute the model.
+        if not input_seq_groups:
+            if cache_events is not None:
+                for event in cache_events:
+                    event.wait()
+            return {}
+
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
             input_seq_groups)
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 0366b922c28d..8d69e4dc6459 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -1,5 +1,4 @@
 #include <torch/extension.h>
-
 #include <ATen/cuda/CUDAContext.h>
 #include <algorithm>
 
@@ -73,6 +72,8 @@ void copy_blocks(
   }
 }
 
+namespace cacheflow {
+
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
@@ -112,6 +113,8 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+} // namespace cacheflow
+
 void reshape_and_cache(
   torch::Tensor& key,
   torch::Tensor& value,
@@ -131,7 +134,7 @@ void reshape_and_cache(
     key.scalar_type(),
     "reshape_and_cache_kernel",
     [&] {
-      reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key.data_ptr<scalar_t>(),
         value.data_ptr<scalar_t>(),
         key_cache.data_ptr<scalar_t>(),
diff --git a/server.py b/server.py
index b740724c373f..ccd6f8f6e3f8 100644
--- a/server.py
+++ b/server.py
@@ -15,7 +15,8 @@
 parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
 parser.add_argument('--seed', type=int, default=0, help='random seed')
-parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens')
+parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
+parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
 args = parser.parse_args()
 
 
@@ -27,7 +28,8 @@ def main():
     )
     num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
         max_num_batched_tokens=args.max_batch_size)
-    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks()
+    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks(
+        swap_space=args.swap_space)
     print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')
 
     # Create a controller for each node.
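Note on the sizing arithmetic in get_max_num_cpu_blocks above: the patch converts the --swap-space budget from GiB to bytes, rejects it if it exceeds 80% of host memory, warns above 50%, and floor-divides by the per-block KV-cache size. Below is a minimal standalone sketch of that calculation; the function name, the 64 GiB host memory, and the 2 MiB cache block size are illustrative assumptions, not values taken from the patch.

_GiB = 1 << 30

def max_num_cpu_blocks(swap_space_gib: int, cpu_memory: int, cache_block_size: int) -> int:
    # Convert the GiB budget to bytes, mirroring the patch's swap_space * _GiB.
    swap_space = swap_space_gib * _GiB
    if swap_space > 0.8 * cpu_memory:
        raise ValueError('Swap space takes more than 80% of the available CPU memory.')
    if swap_space > 0.5 * cpu_memory:
        print('WARNING: swap space takes more than 50% of the available CPU memory.')
    # Only whole cache blocks count; any remainder is left unused.
    return swap_space // cache_block_size

# Example with assumed numbers: 20 GiB swap, 64 GiB host RAM, 2 MiB per block.
print(max_num_cpu_blocks(20, cpu_memory=64 * _GiB, cache_block_size=2 * (1 << 20)))  # 10240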