diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index ae5ae7cb4834..5afe91783c55 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -15,7 +15,8 @@ def main(): # Create an LLM. - llm = LLM(model="facebook/opt-125m") + # llm = LLM(model="facebook/opt-125m") + llm = LLM(model="google/gemma-3-1b-it", enforce_eager=True) # Generate texts from the prompts. # The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. diff --git a/tests/v1/e2e/test_correctness_sliding_window.py b/tests/v1/e2e/test_correctness_sliding_window.py index a125d3fb7975..a2ec9793d2f8 100644 --- a/tests/v1/e2e/test_correctness_sliding_window.py +++ b/tests/v1/e2e/test_correctness_sliding_window.py @@ -17,15 +17,15 @@ class TestConfig: model_config = { "bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)), - "google/gemma-2-2b-it": TestConfig(4096, (400, 800)), + "google/gemma-3-1b-it": TestConfig(4096, (400, 800)), # TODO: swa 1024 } @pytest.mark.parametrize( "model", [ - "bigcode/starcoder2-3b", # sliding window only - "google/gemma-2-2b-it", # sliding window + full attention + # "bigcode/starcoder2-3b", # sliding window only + "google/gemma-3-1b-it", # sliding window + full attention ]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 718b15e58785..1d78295a9781 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -37,7 +37,7 @@ is_block_tables_empty) from vllm.attention.layer import Attention from vllm.attention.ops.paged_attn import PagedAttention -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -128,12 +128,10 @@ def get_per_layer_parameters( to use during `plan`. 
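# Illustrative sketch (not part of the diff): get_layers_from_vllm_config() is
# the new helper added to vllm/config.py later in this PR. It filters
# static_forward_context by layer type, so a backend can fetch just its
# Attention layers without asserting on every entry, as done just below:
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config

def collect_attention_layers(vllm_config: VllmConfig) -> dict[str, Attention]:
    # Same result as iterating static_forward_context and keeping only the
    # entries that are Attention instances.
    return get_layers_from_vllm_config(vllm_config, Attention)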
""" - layers = vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(vllm_config, Attention) per_layer_params: Dict[str, PerLayerParameters] = {} for key, layer in layers.items(): - assert isinstance(layer, Attention) - impl = layer.impl assert isinstance(impl, FlashInferImpl) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index aa218cc37af9..5fdeb1709b18 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -210,6 +210,8 @@ def forward( if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, @@ -226,6 +228,8 @@ def forward( if self.use_direct_call: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, self_kv_cache, attn_metadata) @@ -374,6 +378,8 @@ def unified_attention( forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] output = self.impl.forward(self, query, key, value, kv_cache, @@ -411,6 +417,8 @@ def unified_attention_with_output( wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, diff --git a/vllm/config.py b/vllm/config.py index 0ac3cc46b063..7c30a8267ecb 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1320,6 +1320,8 @@ class CacheConfig: """The number of blocks to allocate for GPU memory.""" num_cpu_blocks: Optional[int] = field(default=None, init=False) """The number of blocks to allocate for CPU memory.""" + disable_hybrid_allocator: bool = False + """Whether to disable the hybrid allocator (Only affects v1).""" def compute_hash(self) -> str: """ @@ -4075,3 +4077,16 @@ def assert_hashable(text): f"vLLM tried to hash some configs that may have Python objects ids " f"in them. This is a bug, please file an issue. " f"Text being hashed: {text}") + + +T = TypeVar("T") + + +def get_layers_from_vllm_config(vllm_config: VllmConfig, + layer_type: type[T]) -> dict[str, T]: + return { + layer_name: layer + for layer_name, layer in + vllm_config.compilation_config.static_forward_context.items() + if isinstance(layer, layer_type) + } diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d5b87a2ce2aa..a5408872a753 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -254,6 +254,7 @@ class EngineArgs: model_impl: str = "auto" calculate_kv_scales: bool = CacheConfig.calculate_kv_scales + disable_hybrid_allocator: bool = False additional_config: Optional[Dict[str, Any]] = None enable_reasoning: Optional[bool] = None @@ -948,6 +949,12 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: help="Enable sleep mode for the engine. 
" "(only cuda platform is supported)") + parser.add_argument( + "--disable-hybrid-allocator", + action="store_true", + default=False, + help="Disable the hybrid allocator. This only affects v1.") + parser.add_argument( "--additional-config", type=json.loads, @@ -1148,6 +1155,7 @@ def create_engine_config( prefix_caching_hash_algo=self.prefix_caching_hash_algo, cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, + disable_hybrid_allocator=self.disable_hybrid_allocator, ) # Get the current placement group if Ray is initialized and diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 06790d8ee2f8..1bd2db8fcb4f 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed as dist @@ -38,8 +38,13 @@ class DPMetadata: class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] - # TODO: extend to support per-layer dynamic forward context - attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + """ + Type AttentionMetadata for v0, + Type Dict[str, AttentionMetadata] for v1, mapping from layer_name to + AttentionMetadata of that layer + set dynamically for each forward pass + """ + attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 51ae386d3389..718fb3b06662 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -14,6 +14,9 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) @@ -278,7 +281,8 @@ def make_local_attention_virtual_batches( class FlashAttentionMetadataBuilder: - def __init__(self, runner: "GPUModelRunner"): + def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, + block_table: BlockTable): model_config = runner.model_config self.runner = runner @@ -288,23 +292,23 @@ def __init__(self, runner: "GPUModelRunner"): self.num_heads_kv = model_config.get_num_kv_heads( runner.parallel_config) self.headdim = model_config.get_head_size() - self.page_size = self.runner.block_size + self.page_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, - non_blocking=True) - seq_lens_cpu = 
self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() def schedule(batch_size, cu_query_lens, max_query_len, seqlens, @@ -328,12 +332,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata = None if self.runner.attention_chunk_size is not None: seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ - virt_block_table = make_local_attention_virtual_batches( + virt_block_table_tensor = make_local_attention_virtual_batches( self.runner.attention_chunk_size, self.runner.query_start_loc_np[:num_reqs + 1], self.runner.seq_lens_np[:num_reqs], - block_table, - self.runner.block_size, + block_table_tensor, + self.kv_cache_spec.block_size, ) local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( self.runner.device, non_blocking=True) @@ -352,7 +356,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( local_query_start_loc=local_query_start_loc, local_seqused_k=local_seqused_k, - local_block_table=virt_block_table, + local_block_table=virt_block_table_tensor, local_max_query_len=local_max_query_len, local_max_seq_len=local_max_seq_len, local_scheduler_metadata=local_scheduler_metadata, @@ -403,7 +407,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_lens=seq_lens, - block_table=block_table, + block_table=block_table_tensor, slot_mapping=slot_mapping, use_cascade=use_cascade, common_prefix_len=common_prefix_len, diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 17341ecfa4fe..9f24317371af 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -14,9 +14,12 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -81,12 +84,10 @@ def get_per_layer_parameters( to use during `plan`. 
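# Illustrative sketch (not part of the diff): CommonAttentionMetadata is the
# small dataclass this PR adds in vllm/v1/attention/backends/utils.py. The
# runner (its construction site is not shown in this excerpt) is assumed to
# build it once per step and hand the same object to every per-group metadata
# builder, so builders stop reading query_start_loc/seq_lens off the runner:
import torch
from vllm.v1.attention.backends.utils import CommonAttentionMetadata

common_attn_metadata = CommonAttentionMetadata(
    query_start_loc=torch.tensor([0, 3, 5], dtype=torch.int32),  # toy values
    seq_lens=torch.tensor([3, 2], dtype=torch.int32),            # toy values
)
# Each builder receives it via build(..., common_attn_metadata=...), and the
# resulting per-layer AttentionMetadata ends up in forward_context.attn_metadata
# keyed by layer_name, which is why layer.py now indexes the dict by layer name.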
""" - layers = vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(vllm_config, Attention) per_layer_params: dict[str, PerLayerParameters] = {} for key, layer in layers.items(): - assert isinstance(layer, Attention) - impl = layer.impl assert isinstance(impl, FlashInferImpl) @@ -205,7 +206,8 @@ def __post_init__(self): class FlashInferMetadataBuilder: - def __init__(self, runner: GPUModelRunner): + def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): self.runner = runner self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append @@ -215,7 +217,9 @@ def __init__(self, runner: GPUModelRunner): # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = get_current_vllm_config() + self.vllm_config = runner.vllm_config + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: @@ -398,19 +402,17 @@ def _plan(self, attn_metadata: FlashInferMetadata): ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): assert self._num_decodes + self._num_prefills == num_reqs assert (self._num_decode_tokens + self._num_prefill_tokens == num_actual_tokens) - page_size = self.runner.block_size + page_size = self.kv_cache_spec.block_size device = self.runner.device - qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - self.runner.device, non_blocking=True) - seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, - non_blocking=True) - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( + qo_indptr = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table_tensor = (self.block_table.get_device_tensor()[:num_reqs]) + slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( self.runner.device, non_blocking=True).long() block_table_bounds = (seq_lens + page_size - 1) // page_size @@ -426,12 +428,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks], dtype=torch.int32, device=device) - shared_kv_page_indices = block_table[0, :num_common_kv_blocks] + shared_kv_page_indices = block_table_tensor[ + 0, :num_common_kv_blocks] shared_kv_last_page_len = torch.tensor([page_size], dtype=torch.int32, device=device) # Remove the blocks of the shared prefix from all requests. 
- block_table = block_table[:, num_common_kv_blocks:] + block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] block_table_bounds -= num_common_kv_blocks else: shared_qo_indptr = None @@ -439,11 +442,11 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, shared_kv_page_indices = None shared_kv_last_page_len = None - mask = (torch.arange(block_table.size(1), - dtype=block_table.dtype, - device=block_table.device).unsqueeze(0) + mask = (torch.arange(block_table_tensor.size(1), + dtype=block_table_tensor.dtype, + device=block_table_tensor.device).unsqueeze(0) < block_table_bounds.unsqueeze(1)) - paged_kv_indices = block_table[mask] + paged_kv_indices = block_table_tensor[mask] paged_kv_indptr = torch.cat([ torch.zeros(1, @@ -462,9 +465,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, - num_qo_heads=self.runner.num_query_heads, - num_kv_heads=self.runner.num_kv_heads, - head_dim=self.runner.head_size, + num_qo_heads=self.kv_cache_spec.num_query_heads, + num_kv_heads=self.kv_cache_spec.num_kv_heads, + head_dim=self.kv_cache_spec.head_size, page_size=page_size, data_type=self.runner.kv_cache_dtype, q_data_type=self.runner.dtype, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f826f8a21789..90464b9073a3 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -204,6 +204,9 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version try: @@ -341,6 +344,8 @@ class MLACommonMetadataBuilder(Generic[M]): def __init__(self, runner: "GPUModelRunner", + kv_cache_spec: AttentionSpec, + block_table: BlockTable, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata @@ -353,10 +358,11 @@ def __init__(self, runner.parallel_config) self.mla_dims = get_mla_dims(model_config) self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) + self.kv_cache_spec = kv_cache_spec # Dont try to access the runner on AMD if self.aot_schedule: - self.page_size = self.runner.block_size + self.page_size = self.kv_cache_spec.block_size if self.chunked_prefill_enabled: self.chunked_prefill_workspace_size = min( @@ -382,6 +388,8 @@ def __init__(self, dtype=model_config.dtype, device=runner.device, ) + self.page_size = kv_cache_spec.block_size + self.block_table = block_table def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -444,32 +452,32 @@ def reorder_batch(self, input_batch: "InputBatch", return modified_batch def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, seq_lens: torch.Tensor): + block_table_tensor: torch.Tensor, + seq_lens: torch.Tensor): return MLACommonDecodeMetadata( input_positions=input_positions, - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int) -> M: + common_prefix_len: int, + common_attn_metadata: 
CommonAttentionMetadata) -> M: assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. device = self.runner.device - block_table = ( - self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - device, non_blocking=True) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() + block_table_tensor = (self.block_table.get_device_tensor()[:num_reqs]) + query_start_loc = common_attn_metadata.query_start_loc + slot_mapping = ( + self.block_table.slot_mapping_cpu[:num_actual_tokens].to( + device, non_blocking=True).long()) input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(device, non_blocking=True) + seq_lens = common_attn_metadata.seq_lens prefill_metadata = None if self._num_prefills > 0: @@ -543,7 +551,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, prefill_metadata = MLACommonPrefillMetadata( input_positions=input_positions[tokens_start:], - block_table=block_table[reqs_start:, ...], + block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, @@ -553,7 +561,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, if self._num_decodes > 0: decode_metadata = self._build_decode( input_positions=input_positions[:self._num_decode_tokens], - block_table=block_table[:self._num_decodes, ...], + block_table_tensor=block_table_tensor[:self._num_decodes, ...], seq_lens=seq_lens[:self._num_decodes], ) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 143bfe35bb5e..e072f74b4978 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 - from dataclasses import dataclass from typing import Any, Optional @@ -16,6 +15,8 @@ MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) @@ -52,14 +53,15 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): - def __init__(self, runner): - super().__init__(runner) + def __init__(self, runner, kv_cache_spec: AttentionSpec, + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, + block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( @@ -70,7 +72,7 @@ def _build_decode(self, input_positions: torch.Tensor, return FlashMLADecodeMetadata( input_positions=input_positions, - block_table=block_table, + block_table=block_table_tensor, seq_lens=seq_lens, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, diff --git a/vllm/v1/attention/backends/utils.py 
b/vllm/v1/attention/backends/utils.py new file mode 100644 index 000000000000..b9d153da0a86 --- /dev/null +++ b/vllm/v1/attention/backends/utils.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import torch + + +@dataclass +class CommonAttentionMetadata: + """ + Metadata that are same for different layer types. + """ + query_start_loc: torch.Tensor + seq_lens: torch.Tensor diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 74f3f7852c9a..796ff98554c0 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -5,7 +5,7 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue, - KVCacheBlock, + GroupedKVCacheBlock, KVCacheBlock, generate_block_hash_extra_keys, hash_block_tokens) from vllm.v1.request import Request @@ -26,10 +26,12 @@ class BlockPool: enable_caching: Whether to enable prefix caching. """ - def __init__(self, num_gpu_blocks: int, enable_caching: bool): + def __init__(self, num_gpu_blocks: int, enable_caching: bool, + num_specialized_managers: int, caching_hash_fn: Callable): assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 self.num_gpu_blocks = num_gpu_blocks self.enable_caching = enable_caching + self.caching_hash_fn = caching_hash_fn # All kv-cache blocks. self.blocks: list[KVCacheBlock] = [ KVCacheBlock(idx) for idx in range(num_gpu_blocks) @@ -39,7 +41,7 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool): # enabled). self.free_block_queue = FreeKVCacheBlockQueue(self.blocks) - # {block_hash: {block ID: block}}. A cached block is + # {manager_id: {block_hash: {block ID: GroupedKVCacheBlock}}}. A cached block is # a full block with a block hash that can be used for prefix caching. # The cached block may be used by running requests or in the # free_block_queue that could potentially be evicted. @@ -48,16 +50,23 @@ def __init__(self, num_gpu_blocks: int, enable_caching: bool): # if there is already an identical block in the cache. This is because # we want to make sure the allocated block IDs won't change so that # block tables are append-only. - self.cached_block_hash_to_block: dict[BlockHashType, dict[ - int, KVCacheBlock]] = defaultdict(dict) + self.cached_block_hash_to_block: list[dict[BlockHashType, dict[ + int, GroupedKVCacheBlock]]] = [ + defaultdict(dict) for _ in range(num_specialized_managers) + ] # To represent a placeholder block with block_id=0. # The ref_cnt of null_block is not maintained, needs special care to # avoid freeing it. self.null_block = self.free_block_queue.popleft() - def get_cached_block(self, - block_hash: BlockHashType) -> Optional[KVCacheBlock]: + self.num_specialized_managers = num_specialized_managers + + def get_cached_block( + self, + block_hash: BlockHashType, + manager_id: int, + ) -> Optional[GroupedKVCacheBlock]: """Get a cached block by the block hash, or None if cache miss. If there are duplicated blocks, we return the first block in the cache. @@ -67,7 +76,8 @@ def get_cached_block(self, Returns: The cached block if it exists, or None. 
""" - cached_blocks = self.cached_block_hash_to_block.get(block_hash) + cached_blocks = self.cached_block_hash_to_block[manager_id].get( + block_hash) if not cached_blocks: return None first_block_id = next(iter(cached_blocks)) @@ -76,12 +86,12 @@ def get_cached_block(self, def cache_full_blocks( self, request: Request, - blocks: list[KVCacheBlock], + blocks: list[GroupedKVCacheBlock], block_hashes: list[BlockHashType], num_cached_blocks: int, num_full_blocks: int, block_size: int, - hash_fn: Callable, + manager_id: int, ) -> None: """Cache a list of full blocks for prefix caching. This function takes a list of blocks that will have their block hash @@ -100,7 +110,7 @@ def cache_full_blocks( num_full_blocks: The number of blocks that are full and should be cached after this function. block_size: Number of tokens in each block. - hash_fn: The hash function to use for block hashes. + manager_id: The id of the kv cache manager. """ if num_cached_blocks == num_full_blocks: return @@ -117,11 +127,13 @@ def cache_full_blocks( prev_block_hash_value = prev_block.block_hash.hash_value for i, blk in enumerate(new_full_blocks): + assert all(b.block_hash is None for b in blk.blocks) assert blk.block_hash is None if i < len(new_block_hashes): # The block hash may already be computed in - # "get_computed_blocks" if the tokens are not generated by + # "get_computed_blocks" or other groups with the same block_size + # if the tokens are not generated by # this request (either the prompt tokens or the previously # generated tokens with preemption). In this case we simply # reuse the block hash. @@ -146,13 +158,18 @@ def cache_full_blocks( request, start_token_idx, end_token_idx, -1) # Compute the hash of the current block. - block_hash = hash_block_tokens(hash_fn, prev_block_hash_value, + block_hash = hash_block_tokens(self.caching_hash_fn, + prev_block_hash_value, block_tokens, extra_keys) block_hashes.append(block_hash) # Update and added the full block to the cache. + for b in blk.blocks: + b.block_hash = block_hash + b.manager_id = manager_id blk.block_hash = block_hash - self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + self.cached_block_hash_to_block[manager_id][block_hash][ + blk.block_id] = blk prev_block_hash_value = block_hash.hash_value def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: @@ -199,17 +216,20 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: True if the block is evicted, False otherwise. """ block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: - block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] - - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] - + manager_id = block.manager_id + if block_hash and block_hash in self.cached_block_hash_to_block[ + manager_id]: + cached_blocks = ( + self.cached_block_hash_to_block[manager_id][block_hash]) + assert block.block_id in cached_blocks + cached_blocks[block.block_id].reset_hash() + del cached_blocks[block.block_id] + if len(cached_blocks) == 0: + del self.cached_block_hash_to_block[manager_id][block_hash] return True return False - def touch(self, blocks: list[KVCacheBlock]) -> None: + def touch(self, blocks: list[list[GroupedKVCacheBlock]]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. 
@@ -217,14 +237,17 @@ def touch(self, blocks: list[KVCacheBlock]) -> None: Args: blocks: A list of blocks to touch. """ - for block in blocks: - # ref_cnt=0 means this block is in the free list (i.e. eviction - # candidate), so remove it. - if block.ref_cnt == 0 and block != self.null_block: - self.free_block_queue.remove(block) - block.incr_ref() - - def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: + for block_one_layer in blocks: + for block_two_layer in block_one_layer: + for block in block_two_layer.blocks: + # ref_cnt=0 means this block is in the free list (i.e. eviction + # candidate), so remove it. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.remove(block) + block.incr_ref() + + def free_blocks(self, + ordered_blocks: Iterable[GroupedKVCacheBlock]) -> None: """Free a list of blocks. The blocks should be ordered by their eviction priority, where the first block will be evicted first. @@ -232,11 +255,13 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: ordered_blocks: A list of blocks to free ordered by their eviction priority. """ - for block in ordered_blocks: - block.decr_ref() - # null_block should not be added to the free list. - if block.ref_cnt == 0 and block != self.null_block: - self.free_block_queue.append(block) + # TODO: make sure blocks in the first group are evicted first + for blk in ordered_blocks: + for block in blk.blocks: + block.decr_ref() + # null_block should not be added to the free list. + if block.ref_cnt == 0 and block != self.null_block: + self.free_block_queue.append(block) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -255,7 +280,9 @@ def reset_prefix_cache(self) -> bool: return False # Remove all hashes so that no new blocks will hit. - self.cached_block_hash_to_block = defaultdict(dict) + self.cached_block_hash_to_block = [ + defaultdict(dict) for _ in range(self.num_specialized_managers) + ] # Remove all hashes from all blocks. 
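# Structural sketch (not part of the diff): touch() and free_blocks() above now
# walk nested structures -- one list per specialized manager, then one
# GroupedKVCacheBlock per logical block, then the physical KVCacheBlocks that
# group holds (one per kv cache group handled by that manager):
def iter_physical_blocks(blocks):
    # blocks: list[list[GroupedKVCacheBlock]], as passed to BlockPool.touch()
    for blocks_one_manager in blocks:              # per specialized manager
        for grouped_block in blocks_one_manager:   # per logical block
            for block in grouped_block.blocks:     # per physical KVCacheBlock
                yield block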
for block in self.blocks: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 354300d3c2fe..a7842a556a6c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,23 +1,40 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from collections.abc import Iterable +from dataclasses import dataclass from typing import Optional from vllm.logger import init_logger from vllm.utils import cdiv, sha256 from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, - hash_request_tokens) -from vllm.v1.core.specialized_manager import get_specialized_manager -from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.core.kv_cache_utils import ( + BlockHashType, GroupedKVCacheBlock, KVCacheBlock, hash_request_tokens, + remove_last_block_hash_for_divisible_prompt_length) +from vllm.v1.core.specialized_manager import SpecializedManager, get_specialized_manager +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus logger = init_logger(__name__) +@dataclass +class KVCacheBlocks: + blocks: list[list[GroupedKVCacheBlock]] + + def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": + return KVCacheBlocks([ + self_blocks_i + other_blocks_i + for self_blocks_i, other_blocks_i in zip(self.blocks, other.blocks) + ]) + + class KVCacheManager: + """ + The KVCacheManager for models with multiple KV cache types + (e.g., Gemma-2) and thus multiple kv cache groups (Refer to class + `KVCacheConfig` for the meaning of kv cache groups). + """ def __init__( self, @@ -27,44 +44,50 @@ def __init__( caching_hash_algo: str = "builtin", log_stats: bool = False, ) -> None: - assert len(kv_cache_config.kv_cache_groups) == 1, ( - "KVCacheManager does not support hybrid models with more than 1 " - "kv cache group") - kv_cache_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec - self.block_size = kv_cache_spec.block_size + self.kv_cache_config = kv_cache_config self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) - + self.max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in kv_cache_config.kv_cache_groups + ] self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats self.prefix_cache_stats = PrefixCacheStats() if log_stats else None - self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching) - self.specialized_manager = get_specialized_manager( - kv_cache_spec=kv_cache_spec, - block_pool=self.block_pool, - ) - - # Mapping from request ID to blocks to track the blocks allocated - # for each request, so that we can free the blocks when the request - # is finished. 
- self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) - + self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) + + # the kv cache groups managed by the each manager + # manager_id -> list[kv_cache_group_id] + self.manager_to_group = self.generate_group_manager_map() + self.num_specialized_managers = len(self.manager_to_group) + + self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, + self.num_specialized_managers, + self.caching_hash_fn) + + self.specialized_managers: list[SpecializedManager] = [] + for i in range(len(self.manager_to_group)): + group_ids = self.manager_to_group[i] + kv_cache_spec = kv_cache_config.kv_cache_groups[ + group_ids[0]].kv_cache_spec + self.specialized_managers.append( + get_specialized_manager(kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + kv_cache_manager_id=i, + num_kv_cache_groups=len(group_ids))) # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. - self.req_to_block_hashes: defaultdict[ - str, list[BlockHashType]] = defaultdict(list) + # block_size -> list[BlockHashType]; TODO update comment + self.req_to_block_hashes: defaultdict[str, dict[ + int, list[BlockHashType]]] = defaultdict(dict) - # {req_id: The number of cached blocks for this given request} - # This is used to track the number of cached blocks for each request. - # This is only used to track the RUNNING requests, we do not track the - # data for reempted ones. - self.num_cached_block: dict[str, int] = {} + self.all_block_sizes = set( + g.kv_cache_spec.block_size + for g in self.kv_cache_config.kv_cache_groups) @property def usage(self) -> float: @@ -87,8 +110,12 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks( - self, request: Request) -> tuple[list[KVCacheBlock], int]: + def empty_kv_cache_blocks(self) -> KVCacheBlocks: + return KVCacheBlocks([[] + for _ in range(self.num_specialized_managers)]) + + def get_computed_blocks(self, + request: Request) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. @@ -97,19 +124,22 @@ def get_computed_blocks( Returns: A tuple containing: - - A list of blocks that are computed for the request. + - A list of blocks that are computed for each kv cache group. - The number of computed tokens. """ if not self.enable_caching: # Prefix caching is disabled. - return [], 0 + return self.empty_kv_cache_blocks(), 0 # The block hashes for the request may already be computed # if the scheduler has tried to schedule the request before. block_hashes = self.req_to_block_hashes[request.request_id] - if not block_hashes: - block_hashes = hash_request_tokens(self.caching_hash_fn, - self.block_size, request) + if len(block_hashes) == 0: + block_hashes = { + block_size: + hash_request_tokens(self.caching_hash_fn, block_size, request) + for block_size in self.all_block_sizes + } self.req_to_block_hashes[request.request_id] = block_hashes if self.log_stats: @@ -117,47 +147,26 @@ def get_computed_blocks( self.prefix_cache_stats.requests += 1 # When the request requires prompt logprobs, we skip prefix caching. 
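# Toy illustration (not part of the diff): KVCacheBlocks.__add__ defined above
# concatenates the per-manager block lists element-wise, so computed blocks and
# newly allocated blocks can be combined per specialized manager. Plain ints
# stand in for GroupedKVCacheBlock here:
def add_kv_cache_blocks(a, b):
    # Mirrors KVCacheBlocks.__add__ for blocks of type list[list[...]].
    return [a_i + b_i for a_i, b_i in zip(a, b)]

combined = add_kv_cache_blocks([[1, 2], [3]], [[4], [5, 6]])
assert combined == [[1, 2, 4], [3, 5, 6]]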
if request.sampling_params.prompt_logprobs is not None: - return [], 0 - - if len(block_hashes) * self.block_size == request.num_tokens: - # When prompt length is divisible by the block size and all - # blocks are cached, we need to recompute the last token. This - # have to be achieved by re-computing an entire block because - # allocate_slots() assumes num_computed_tokens is always a - # multiple of the block size. To achieve this, remove the last - # block hash from the block_hashes for find_longest_cache_hit - # This limitation can potentially be removed in the future to - # slightly improve the performance. - last_block_hash = block_hashes.pop() - else: - last_block_hash = None + return self.empty_kv_cache_blocks(), 0 + + computed_blocks, num_computed_tokens = self.find_longest_cache_hit( + request, block_hashes) - computed_blocks = ( - self.specialized_manager.find_longest_cache_hit(block_hashes)) if self.log_stats: assert self.prefix_cache_stats is not None + self.prefix_cache_stats.queries += len(block_hashes) self.prefix_cache_stats.hits += len(computed_blocks) - - if last_block_hash is not None: - # Add back the last block hash if it was removed. - # NOTE: Because block_hashes is cached in req_to_block_hashes, - # we shouldn't modify it directly. - block_hashes.append(last_block_hash) - - # NOTE(woosuk): Since incomplete blocks are not eligible for - # sharing, `num_computed_tokens` is always a multiple of - # `block_size`. - num_computed_tokens = len(computed_blocks) * self.block_size - return computed_blocks, num_computed_tokens + return KVCacheBlocks(computed_blocks), num_computed_tokens def allocate_slots( self, request: Request, num_tokens: int, - new_computed_blocks: Optional[list[KVCacheBlock]] = None, + wrapped_new_computed_blocks: Optional[KVCacheBlocks] = None, + num_new_computed_tokens: int = 0, num_lookahead_tokens: int = 0, - ) -> Optional[list[KVCacheBlock]]: + ) -> Optional[KVCacheBlocks]: """Add slots for a request with new tokens to append. Args: @@ -170,6 +179,8 @@ def allocate_slots( num_lookahead_tokens: The number of speculative tokens to allocate. This is used by spec decode proposers with kv-cache such as eagle. + num_new_computed_tokens: The number of new computed tokens in the + new_computed_blocks. Blocks layout: ----------------------------------------------------------------------- @@ -189,99 +200,69 @@ def allocate_slots( if num_tokens == 0: raise ValueError("num_tokens must be greater than 0") - new_computed_blocks = new_computed_blocks or [] - - req_blocks = self.req_to_blocks[request.request_id] - + if wrapped_new_computed_blocks is not None: + new_computed_blocks = wrapped_new_computed_blocks.blocks + else: + new_computed_blocks = [ + [] for _ in range(self.num_specialized_managers) + ] # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). # We can do this even if we cannot schedule this request due to # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. 
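# Toy numbers (not from the diff) for the admission check used a bit further
# down in allocate_slots(): each specialized manager reports how many new
# blocks it would need for num_tokens_need_slot tokens, and the request is
# admitted only when the sum fits in the shared pool. A simplified form of
# get_num_needed_blocks is assumed here, since its body is not shown above.
from math import ceil

num_tokens_need_slot = 300          # computed + new + lookahead tokens
block_sizes = [16, 16]              # one entry per specialized manager
blocks_already_held = [10, 10]      # blocks each manager holds for the request
num_needed_blocks = [
    max(0, ceil(num_tokens_need_slot / bs) - held)
    for bs, held in zip(block_sizes, blocks_already_held)
]
assert num_needed_blocks == [9, 9]
# allocate_slots() returns None when sum(num_needed_blocks) exceeds
# block_pool.get_num_free_blocks().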
- removed_blocks = self.specialized_manager.remove_skipped_blocks( - req_blocks, request.num_computed_tokens) - self.block_pool.free_blocks(removed_blocks) + for i, manager in enumerate(self.specialized_managers): + manager.remove_skipped_blocks(request.request_id, + request.num_computed_tokens) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + - len(new_computed_blocks) * self.block_size) - num_required_blocks = cdiv( - num_computed_tokens + num_tokens + num_lookahead_tokens, - self.block_size) - num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_blocks)) - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = sum(1 for blk in new_computed_blocks - if blk.ref_cnt == 0) - if (num_new_blocks > self.block_pool.get_num_free_blocks() - - num_evictable_computed_blocks): - # Cannot allocate new blocks + num_new_computed_tokens) + num_tokens_need_slot = (num_computed_tokens + num_tokens + + num_lookahead_tokens) + + num_needed_blocks: list[int] = [ + manager.get_num_needed_blocks(request.request_id, + num_tokens_need_slot, + new_computed_blocks[i]) + for manager in self.specialized_managers + ] + if (sum(num_needed_blocks) > self.block_pool.get_num_free_blocks()): return None # Touch the computed blocks to make sure they won't be evicted. if self.enable_caching: self.block_pool.touch(new_computed_blocks) else: - assert not new_computed_blocks, ( + assert all(len(blks) == 0 for blks in new_computed_blocks), ( "Computed blocks should be empty when " "prefix caching is disabled") # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_blocks) + for i in range(self.num_specialized_managers): + self.specialized_managers[i].req_to_blocks[ + request.request_id].extend(new_computed_blocks[i]) + new_blocks: list[list[GroupedKVCacheBlock]] = [] # Start to handle new blocks - - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool. - num_new_blocks = min( - num_new_blocks, - self.block_pool.get_num_free_blocks(), - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) + for i in range(self.num_specialized_managers): + new_blocks_i = self.specialized_managers[i].allocate_new_blocks( + request.request_id, num_tokens_need_slot) + new_blocks.append(new_blocks_i) if not self.enable_caching: - return new_blocks - - # Use `new_computed_blocks` for a new request, and `num_cached_block` - # for a running request. - num_cached_blocks = self.num_cached_block.get(request.request_id, - len(new_computed_blocks)) - # Speculated tokens might be rejected in the future, so we does - # not cache any speculated tokens. We only cache blocks with - # generated (accepted) tokens. 
- num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( - request.spec_token_ids)) // self.block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=req_blocks, - block_hashes=self.req_to_block_hashes[request.request_id], - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks_after_append, - block_size=self.block_size, - hash_fn=self.caching_hash_fn, - ) - - self.num_cached_block[ - request.request_id] = num_full_blocks_after_append - return new_blocks + return KVCacheBlocks(new_blocks) + + for i, manager in enumerate(self.specialized_managers): + manager.cache_blocks( + request, new_computed_blocks[i], self.req_to_block_hashes[ + request.request_id][manager.block_size], + num_computed_tokens, num_tokens) + + return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -291,16 +272,8 @@ def free(self, request: Request) -> None: Args: request: The request to free the blocks. """ - # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, []) - ordered_blocks: Iterable[KVCacheBlock] = blocks - if self.enable_caching: - # Free blocks in reverse order so that the tail blocks are - # freed first. - ordered_blocks = reversed(blocks) - - self.block_pool.free_blocks(ordered_blocks) - self.num_cached_block.pop(request.request_id, None) + for manager in self.specialized_managers: + manager.free(request.request_id) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -322,7 +295,7 @@ def get_num_common_prefix_blocks( self, request: Request, num_running_requests: int, - ) -> int: + ) -> list[int]: """Calculate the number of common prefix blocks shared by all requests in the RUNNING state. @@ -354,17 +327,23 @@ def get_num_common_prefix_blocks( requests in the current step. Returns: - int: The number of common prefix blocks. + list[int]: The number of common prefix blocks for each kv cache + group. """ - assert request.status == RequestStatus.RUNNING - blocks = self.req_to_blocks[request.request_id] - num_common_blocks = 0 - for block in blocks: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break - return num_common_blocks + # TODO: implement this + return [0] * self.num_kv_cache_groups + # assert request.status == RequestStatus.RUNNING + # blocks = self.req_to_blocks[request.request_id] + # num_common_blocks = [] + # for i in range(self.num_kv_cache_groups): + # num_common_blocks_i = 0 + # for block in blocks[i]: + # if block.ref_cnt == num_running_requests: + # num_common_blocks_i += 1 + # else: + # break + # num_common_blocks.append(num_common_blocks_i) + # return num_common_blocks def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. @@ -373,3 +352,99 @@ def free_block_hashes(self, request: Request) -> None: is finished, not when it is preempted. """ self.req_to_block_hashes.pop(request.request_id, None) + + def find_longest_cache_hit( + self, request: Request, block_hashes_dict: dict[int, + list[BlockHashType]] + ) -> tuple[list[list[GroupedKVCacheBlock]], int]: + """Find the longest cache hit for each kv cache group. + TODO: add more notes + """ + + # When prompt length is divisible by the block size and all + # blocks are cached, we need to recompute the last token. 
This + # have to be achieved by re-computing an entire block because + # allocate_slots() assumes num_computed_tokens is always a + # multiple of the block size. To achieve this, remove the last + # block hash from the block_hashes for find_longest_cache_hit + # This limitation can potentially be removed in the future to + # slightly improve the performance. + with remove_last_block_hash_for_divisible_prompt_length( + block_hashes_dict, request.num_tokens): + if self.num_specialized_managers == 1: + block_size = self.kv_cache_config.kv_cache_groups[ + 0].kv_cache_spec.block_size + hit_blocks = self.specialized_managers[ + 0].find_longest_cache_hit(block_hashes_dict[block_size]) + + return [hit_blocks], len(hit_blocks) * block_size + + # TODO: add note for the two magic number + num_computed_tokens = [self.max_model_len + 100] * len( + self.specialized_managers) + min_computed_tokens = self.max_model_len + + # Use copy to avoid modifying the original block_hashes + block_hashes = [ + block_hashes_dict[g.kv_cache_spec.block_size].copy() + for g in self.kv_cache_config.kv_cache_groups + ] + + computed_blocks: list[Optional[list[GroupedKVCacheBlock]]] = [ + None for _ in range(self.num_specialized_managers) + ] + + def shrink_length(block_hashes, length): + del block_hashes[length:] + + while max(num_computed_tokens) != min_computed_tokens: + for i, manager in enumerate(self.specialized_managers): + if num_computed_tokens[i] > min_computed_tokens: + shrink_length( + block_hashes[i], + min_computed_tokens // manager.block_size) + computed_blocks_i = (manager.find_longest_cache_hit( + block_hashes[i], computed_blocks[i])) + + num_computed_tokens[i] = len(computed_blocks_i) * \ + manager.block_size + min_computed_tokens = min(min_computed_tokens, + num_computed_tokens[i]) + computed_blocks[i] = computed_blocks_i + shrink_length( + block_hashes[i], + num_computed_tokens[i] // manager.block_size) + + assert all(block is not None and len(block) * + manager.block_size == min_computed_tokens + for block, manager in zip(computed_blocks, + self.specialized_managers)) + return computed_blocks, min_computed_tokens + + def generate_group_manager_map(self) -> list[list[int]]: + type_ids = [ + g.kv_cache_spec.type_id + for g in self.kv_cache_config.kv_cache_groups + ] + assert sorted(type_ids) == type_ids, "type_ids must be sorted" + manager_to_group: list[list[int]] = [] + for i, type_id in enumerate(type_ids): + if i == 0: + manager_to_group.append([i]) + else: + if type_id == type_ids[i - 1]: + manager_to_group[-1].append(i) + else: + manager_to_group.append([i]) + print("manager_to_group", manager_to_group) + return manager_to_group + + def to_block_ids(self, kv_cache_blocks: KVCacheBlocks) -> list[list[int]]: + block_ids: list[list[int]] = [[] + for _ in range(self.num_kv_cache_groups)] + for blocks_one_manager, group_ids in zip(kv_cache_blocks.blocks, + self.manager_to_group): + for blocks in blocks_one_manager: + for blk, group_id in zip(blocks.blocks, group_ids): + block_ids[group_id].append(blk.block_id) + return block_ids diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index bd0e01d045d1..4254f66945b3 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """KV-Cache Utilities.""" import os -from collections import deque +from collections import defaultdict, deque from collections.abc import Sequence +from contextlib import contextmanager from dataclasses import dataclass from typing import 
Any, Callable, NamedTuple, Optional @@ -10,8 +11,9 @@ from vllm.logger import init_logger from vllm.utils import GiB_bytes, sha256 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec) + KVCacheGroupSpec, KVCacheNewTensor, + KVCacheReuseTensor, KVCacheSpec, + SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -124,6 +126,8 @@ class KVCacheBlock: prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None + manager_id: int = -1 + def incr_ref(self): self.ref_cnt += 1 @@ -143,6 +147,7 @@ def block_hash(self, block_hash: BlockHashType): def reset_hash(self): """Reset the block hash when the block is evicted.""" self._block_hash = None + self.manager_id = -1 def __repr__(self) -> str: # Use block_id instead of KVCacheBlock object to avoid calling __repr__ @@ -372,7 +377,6 @@ def generate_block_hash_extra_keys( start_token_idx: The start token index of the block. end_token_idx: The end token index of the block. start_mm_idx: The start multi-modal index of the block. - Returns: A tuple of extra keys and the next multi-modal index. """ @@ -400,11 +404,11 @@ def hash_block_tokens( hash values for the same block contents. Args: + hash_function: The function used for hash parent_block_hash: The hash of the parent block. None if this is the first block. curr_block_token_ids: A list of token ids in the current block. The current block is assumed to be full. - extra_keys: Extra keys for the block. Returns: The hash value of the block and the token ids in the block. @@ -552,6 +556,26 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, f"`max_model_len` when initializing the engine.") +def merge_layer_specs(layer_specs: list[KVCacheSpec]) -> KVCacheSpec: + """ + Merge a list of KVCacheSpec objects into a single KVCacheSpec object. 
+ """ + assert all(layer_spec.type_id == layer_specs[0].type_id + for layer_spec in layer_specs[1:]), ( + "All layers in the same KV cache group must share the same " + "KVCacheSpec.") + layer_spec = layer_specs[0] + if isinstance(layer_spec, FullAttentionSpec): + for spec in layer_specs[1:]: + assert isinstance(spec, FullAttentionSpec) + if spec.sliding_window is not None: + if layer_spec.sliding_window is None: + layer_spec.sliding_window = spec.sliding_window + else: + assert layer_spec.sliding_window == spec.sliding_window + return layer_spec + + def create_kv_cache_group_specs( kv_cache_spec: dict[str, KVCacheSpec], grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]: @@ -572,12 +596,9 @@ def create_kv_cache_group_specs( """ kv_cache_groups = [] for layer_names_one_group in grouped_layer_names: - layer_spec = kv_cache_spec[layer_names_one_group[0]] - assert all( - kv_cache_spec[layer_name] == layer_spec - for layer_name in layer_names_one_group[1:]), ( - "All layers in the same KV cache group must share the same " - "KVCacheSpec.") + layer_spec = merge_layer_specs([ + kv_cache_spec[layer_name] for layer_name in layer_names_one_group + ]) kv_cache_groups.append( KVCacheGroupSpec(layer_names_one_group, layer_spec)) return kv_cache_groups @@ -645,7 +666,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, kv_cache_config = KVCacheConfig( num_blocks=num_blocks, tensors={ - layer_name: KVCacheTensor(size=per_layer_size) + layer_name: KVCacheNewTensor(size=per_layer_size) for layer_name in kv_cache_spec }, kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec, @@ -654,6 +675,87 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig, return kv_cache_config +def is_kv_cache_page_size_uniform( + kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + """ + Whether all layers in the given KVCacheSpec have the same page size. + Args: + kv_cache_spec: The KVCacheSpec of each attention layer in the model + + Returns: + True if all layers have the same page size, False otherwise. + """ + + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + return len(page_sizes) == 1 + + +def _get_kv_cache_config_uniform_page_size( + vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], + available_memory: int) -> KVCacheConfig: + """ + Generates the KV cache configuration for a model with one page size. + + Args: + vllm_config: The global VllmConfig + kv_cache_spec: The KVCacheSpec of each attention layer in the model + available_memory: Memory available for KV cache in bytes. + + Returns: + The generated KVCacheConfig + """ + # Group all layers by type_id. + # E.g., 2 full attention layers and 3 sliding window attention layers, + # -> (full.0, full.1), (sw.0, sw.1, sw.2). + same_type_layers: dict[str, list[str]] = defaultdict(list) + for layer_name, layer_spec in kv_cache_spec.items(): + same_type_layers[layer_spec.type_id].append(layer_name) + + # Split each group into smaller groups, to make the number of layers in each + # group identical. Add padding to the last group of each type if necessary. + # E.g., (full.0, full.1), (sw.0, sw.1, sw.2) + # split to 3 groups with 2 layers each: + # (full.0, full.1), (sw.0, sw.1), (sw.2, padding). 
+ group_size = min([len(layers) for layers in same_type_layers.values()]) + grouped_layers = [] + for layers in same_type_layers.values(): + num_padding_layers = len(layers) % group_size + if num_padding_layers > 0: + logger.warning( + "Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa + num_padding_layers, + num_padding_layers / len(layers) * 100) + for i in range(0, len(layers), group_size): + grouped_layers.append(layers[i:i + group_size]) + + # Divide the available memory equally among all layers in the first group. + # The memory layout in the example will be: + # full.0: Tensor with size=available_memory//2 + # full.1: Tensor with size=available_memory//2 + kv_cache_spec_first_group = { + layer_name: kv_cache_spec[layer_name] + for layer_name in grouped_layers[0] + } + kv_cache_config = _get_kv_cache_config_uniform_type( + vllm_config, kv_cache_spec_first_group, available_memory) + + # Reuse the KV cache tensors of the first group for the other groups. + # The memory layout in the example will be: + # full.0, sw.0, sw.2: share a Tensor with size=available_memory//2 + # full.1, sw.1: share another Tensor with size=available_memory//2 + # Layers of different groups have different block table, so they will + # use different parts of the shared Tensor. + for layers in grouped_layers[1:]: + for layer_name, layer_name_first_group in zip( + layers, grouped_layers[0][:len(layers)]): + kv_cache_config.tensors[layer_name] = KVCacheReuseTensor( + reused_layer_name=layer_name_first_group) + + kv_cache_config.kv_cache_groups = create_kv_cache_group_specs( + kv_cache_spec, grouped_layers) + return kv_cache_config + + def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ Only models with one type of KV cache are supported yet. This function tries @@ -674,10 +776,12 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): if isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( block_size=spec.block_size, + num_query_heads=spec.num_query_heads, num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, use_mla=spec.use_mla, + sliding_window=spec.sliding_window, ) @@ -697,14 +801,20 @@ def get_kv_cache_config(vllm_config: VllmConfig, The generated KVCacheConfigs """ check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory) - unify_hybrid_kv_cache_specs(kv_cache_spec) + if (vllm_config.cache_config.disable_hybrid_allocator + or vllm_config.device_config.device.type != "cuda"): + unify_hybrid_kv_cache_specs(kv_cache_spec) if is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec, available_memory) - + elif is_kv_cache_page_size_uniform(kv_cache_spec): + # KV cache of all layers have the same page size. TODO more notes + return _get_kv_cache_config_uniform_page_size(vllm_config, + kv_cache_spec, + available_memory) raise NotImplementedError @@ -742,3 +852,39 @@ def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]): kv_cache_config.num_blocks = min_num_blocks return kv_cache_configs + + +@contextmanager +def remove_last_block_hash_for_divisible_prompt_length( + block_hashes: dict[int, list[BlockHashType]], num_tokens: int): + """ + Remove the last block hash for the case where the prompt length is divisible + by the block size and all blocks are cached. 
+ """ + last_block_hashs: dict[int, BlockHashType] = {} + for block_size in block_hashes: + if len(block_hashes[block_size]) * block_size == num_tokens: + last_block_hashs[block_size] = block_hashes[block_size].pop() + yield + for block_size, block_hash in last_block_hashs.items(): + block_hashes[block_size].append(block_hash) + + +# KVCacheBlocks for the same block of all kv cache groups with the same kv cache +# spec (and belongs to the same manager) +@dataclass +class GroupedKVCacheBlock: + blocks: tuple[KVCacheBlock, ...] + block_hash: Optional[BlockHashType] = None + block_id: int = -1 + + @staticmethod + def from_kv_cache_blocks(blocks: tuple[KVCacheBlock, ...]): + return GroupedKVCacheBlock(blocks=blocks, + block_hash=blocks[0].block_hash, + block_id=blocks[0].block_id) + + def reset_hash(self): + for block in self.blocks: + block.reset_hash() + self.block_hash = None diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 1d3f1f41f8fb..6d7a7f0e0b0a 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -27,7 +27,7 @@ class NewRequestData: mm_hashes: list[str] mm_positions: list[PlaceholderRange] sampling_params: SamplingParams - block_ids: list[int] + block_ids: list[list[int]] num_computed_tokens: int lora_request: Optional[LoRARequest] @@ -35,7 +35,7 @@ class NewRequestData: def from_request( cls, request: Request, - block_ids: list[int], + block_ids: list[list[int]], ) -> NewRequestData: return cls( req_id=request.request_id, @@ -60,7 +60,7 @@ class CachedRequestData: # request's block IDs instead of appending to the existing block IDs. resumed_from_preemption: bool new_token_ids: list[int] - new_block_ids: list[int] + new_block_ids: list[list[int]] num_computed_tokens: int @classmethod @@ -69,7 +69,7 @@ def from_request( request: Request, resumed_from_preemption: bool, new_token_ids: list[int], - new_block_ids: list[int], + new_block_ids: list[list[int]], ) -> CachedRequestData: return cls( req_id=request.request_id, @@ -108,7 +108,7 @@ class SchedulerOutput: scheduled_encoder_inputs: dict[str, list[int]] # Number of common prefix blocks for all requests. # This can be used for cascade attention. - num_common_prefix_blocks: int + num_common_prefix_blocks: list[int] # Request IDs that are finished in between the previous and the current # steps. This is used to notify the workers about the finished requests diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index adec4462963c..486b359708cb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -154,7 +154,7 @@ def schedule(self) -> SchedulerOutput: # uses structured decoding. structured_output_request_ids: dict[str, int] = {} - req_to_new_block_ids: dict[str, list[int]] = {} + req_to_new_block_ids: dict[str, list[list[int]]] = {} num_scheduled_tokens: dict[str, int] = {} token_budget = self.max_num_scheduled_tokens # Encoder-related. @@ -246,9 +246,8 @@ def schedule(self) -> SchedulerOutput: # Therefore, we might introduce some additional # cycle to fill in the bitmask, which could be a big no-op. structured_output_request_ids[request.request_id] = req_index - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in new_blocks - ] + req_to_new_block_ids[request.request_id] = ( + self.kv_cache_manager.to_block_ids(new_blocks)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens req_index += 1 @@ -318,6 +317,7 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. 
computed_blocks, num_computed_tokens = \ self.kv_cache_manager.get_computed_blocks(request) + print("num_computed_tokens", num_computed_tokens) # Get externally-cached tokens if using a KVConnector. num_external_tokens = ( @@ -357,6 +357,7 @@ def schedule(self) -> SchedulerOutput: request, num_new_tokens + num_external_tokens, computed_blocks, + num_computed_tokens, num_lookahead_tokens=self.num_lookahead_tokens, ) if new_blocks is None: @@ -391,9 +392,9 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in computed_blocks + new_blocks - ] + req_to_new_block_ids[request.request_id] = ( + self.kv_cache_manager.to_block_ids(computed_blocks + + new_blocks)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -425,7 +426,9 @@ def schedule(self) -> SchedulerOutput: # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = 0 + num_common_prefix_blocks: list[int] = [0] * len( + self.kv_cache_config.kv_cache_groups) + if self.running: any_request = self.running[0] num_common_prefix_blocks = ( @@ -507,7 +510,7 @@ def _make_cached_request_data( request: Request, num_scheduled_tokens: int, num_scheduled_spec_tokens: int, - new_block_ids: list[int], + new_block_ids: list[list[int]], resumed_from_preemption: bool, ) -> CachedRequestData: # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index 7a8a98361c7e..e5e180b3a77f 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -1,11 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Optional from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool -from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock +from vllm.v1.core.kv_cache_utils import BlockHashType, GroupedKVCacheBlock, KVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) +from vllm.v1.request import Request +from vllm.v1.utils import ConstantList class SpecializedManager(ABC): @@ -18,27 +22,119 @@ def __init__( self, kv_cache_spec: KVCacheSpec, block_pool: BlockPool, + kv_cache_manager_id: int, + num_kv_cache_groups: int, ) -> None: """ Initializes the SpecializedManager. Args: kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. + kv_cache_manager_id: The id of the kv cache manager. """ self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.block_pool = block_pool + self.kv_cache_manager_id = kv_cache_manager_id + self.num_kv_cache_groups = num_kv_cache_groups + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: defaultdict[ + str, list[GroupedKVCacheBlock]] = defaultdict(list) + + # {req_id: The number of cached blocks for each kv cache group} + # This is used to track the number of cached blocks for each request. + # This is only used to track the RUNNING requests, we do not track the + # data for reempted ones. 
+ self.num_cached_block: dict[str, int] = {} + + def get_num_needed_blocks( + self, request_id: str, num_tokens: int, + new_computed_block_list: list[GroupedKVCacheBlock]) -> int: + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = max( + num_required_blocks - len(new_computed_block_list) - + len(self.req_to_blocks[request_id]), 0) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. # TODO: update comment + num_evictable_computed_blocks = sum( + blks.blocks[0].ref_cnt == 0 for blks in new_computed_block_list) + return ((num_new_blocks + num_evictable_computed_blocks) * + self.num_kv_cache_groups) + + def allocate_new_blocks(self, request_id: str, + num_tokens: int) -> list[GroupedKVCacheBlock]: + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = max( + num_required_blocks - len(self.req_to_blocks[request_id]), 0) + if num_new_blocks <= 0: + return [] + else: + flat_new_blocks = self.block_pool.get_new_blocks( + num_new_blocks * self.num_kv_cache_groups) + new_blocks = [] + for i in range(num_new_blocks): + blocks = flat_new_blocks[i * self.num_kv_cache_groups:(i + 1) * + self.num_kv_cache_groups] + grouped_block = GroupedKVCacheBlock.from_kv_cache_blocks( + tuple(blocks)) + new_blocks.append(grouped_block) + self.req_to_blocks[request_id].extend(new_blocks) + return new_blocks + + def cache_blocks(self, request: Request, + new_computed_blocks: list[GroupedKVCacheBlock], + block_hashes: list[BlockHashType], + num_computed_tokens: int, num_tokens: int) -> None: + # Use `new_computed_blocks` for a new request, and + # `num_cached_block` for a running request. + num_cached_blocks = self.num_cached_block.get(request.request_id, + len(new_computed_blocks)) + # Speculated tokens might be rejected in the future, so we does + # not cache any speculated tokens. We only cache blocks with + # generated (accepted) tokens. + num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( + request.spec_token_ids)) // self.block_size + + self.block_pool.cache_full_blocks( + request=request, + blocks=self.req_to_blocks[request.request_id], + block_hashes=block_hashes, + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks_after_append, + block_size=self.block_size, + manager_id=self.kv_cache_manager_id, + ) + + self.num_cached_block[ + request.request_id] = num_full_blocks_after_append + + def free(self, request_id: str) -> None: + # Default to [] in case a request is freed (aborted) before alloc. + blocks = self.req_to_blocks.pop(request_id, None) + if blocks is not None: + self.block_pool.free_blocks(reversed(blocks)) + + self.num_cached_block.pop(request_id, None) @abstractmethod def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + self, + block_hashes: list[BlockHashType], + computed_blocks: Optional[list[GroupedKVCacheBlock]] = None, + ) -> list[GroupedKVCacheBlock]: """ Get the longest cache hit prefix of the blocks. If no cache hit is - found, return an empty list. + found, return an empty list. # TODO: add notes for computed_blocks + will not be longer than block_hashes. Args: block_hashes: The block hashes of the request. + computed_blocks: The cached blocks for the request returned from + the previous call of this function. Returns: A list of cached blocks with skipped blocks replaced by null block. 
For example, sliding window manager should return a list like @@ -49,8 +145,8 @@ def find_longest_cache_hit( raise NotImplementedError @abstractmethod - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: """ Remove the blocks that are no longer needed from `blocks`. The removed blocks should be replaced by null_block. Return the removed blocks in @@ -69,29 +165,38 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], class FullAttentionManager(SpecializedManager): def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: - computed_blocks: list[KVCacheBlock] = [] - for block_hash in block_hashes: - # block_hashes is a chain of block hashes. If a block hash is not - # in the cached_block_hash_to_id, the following block hashes are - # not computed yet for sure. - if cached_block := self.block_pool.get_cached_block(block_hash): - computed_blocks.append(cached_block) - else: - break + self, + block_hashes: list[BlockHashType], + computed_blocks: Optional[list[GroupedKVCacheBlock]] = None + ) -> list[GroupedKVCacheBlock]: + if computed_blocks is None: + computed_blocks = [] + for block_hash in block_hashes: + # block_hashes is a chain of block hashes. If a block hash is + # not in the cached_block_hash_to_id, the following block hashes + # are not computed yet for sure. + if cached_block := self.block_pool.get_cached_block( + block_hash, self.kv_cache_manager_id): + computed_blocks.append(cached_block) + else: + break + else: + assert len(computed_blocks) >= len(block_hashes) + del computed_blocks[len(block_hashes):] return computed_blocks - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: # No need to remove blocks for full attention. - return [] + pass class SlidingWindowManager(SpecializedManager): - def __init__(self, kv_cache_spec: SlidingWindowSpec, - block_pool: BlockPool): - super().__init__(kv_cache_spec, block_pool) + def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, + kv_cache_manager_id: int, num_kv_cache_groups: int): + super().__init__(kv_cache_spec, block_pool, kv_cache_manager_id, + num_kv_cache_groups) self.sliding_window = kv_cache_spec.sliding_window # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -100,19 +205,36 @@ def __init__(self, kv_cache_spec: SlidingWindowSpec, self._null_block = block_pool.null_block def find_longest_cache_hit( - self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: + self, + block_hashes: list[BlockHashType], + computed_blocks: Optional[list[GroupedKVCacheBlock]] = None + ) -> list[GroupedKVCacheBlock]: # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to # optimize the time complexity from O(len(block_hashes)) to # O(len(block_hashes) / sliding_window_contiguous_blocks + # sliding_window_contiguous_blocks), # which is good for low cache hit rate scenarios. 
- computed_blocks = [self._null_block] * len(block_hashes) - num_contiguous_blocks = 0 + if computed_blocks is None: + num_contiguous_blocks = 0 + computed_blocks = [ + GroupedKVCacheBlock.from_kv_cache_blocks( + tuple([self._null_block] * self.num_kv_cache_groups)) + for _ in range(len(block_hashes)) + ] + else: + if len(computed_blocks) == len(block_hashes): + return computed_blocks + # We are sure the last num_contiguous_blocks are not NULL and do + # not need to check again. + num_contiguous_blocks = max( + self.sliding_window_contiguous_blocks - + (len(computed_blocks) - len(block_hashes)), 0) + del computed_blocks[len(block_hashes):] # Search from right to left and early stop when a match is found. - for i in range(len(block_hashes) - 1, -1, -1): + for i in range(len(block_hashes) - num_contiguous_blocks - 1, -1, -1): if cached_block := self.block_pool.get_cached_block( - block_hashes[i]): + block_hashes[i], self.kv_cache_manager_id): computed_blocks[i] = cached_block num_contiguous_blocks += 1 if (num_contiguous_blocks @@ -129,23 +251,25 @@ def find_longest_cache_hit( del computed_blocks[num_contiguous_blocks:] return computed_blocks - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. last_useful_token = num_computed_tokens - self.sliding_window + 1 last_useful_block = last_useful_token // self.block_size + blocks = self.req_to_blocks[request_id] - removed_blocks: list[KVCacheBlock] = [] + removed_blocks: list[GroupedKVCacheBlock] = [] for i in range(last_useful_block - 1, -1, -1): - if blocks[i] == self._null_block: + if blocks[i].blocks[0] == self._null_block: # If the block is already a null block, the blocks before it # should also have been set to null blocks by the previous calls # to this function. 
break removed_blocks.append(blocks[i]) - blocks[i] = self._null_block - return removed_blocks + blocks[i] = GroupedKVCacheBlock.from_kv_cache_blocks( + tuple([self._null_block] * self.num_kv_cache_groups)) + self.block_pool.free_blocks(removed_blocks) spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = { @@ -154,8 +278,10 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], } -def get_specialized_manager(kv_cache_spec: KVCacheSpec, - block_pool: BlockPool) -> SpecializedManager: +def get_specialized_manager(kv_cache_spec: KVCacheSpec, block_pool: BlockPool, + kv_cache_manager_id: int, + num_kv_cache_groups: int) -> SpecializedManager: manager_class = spec_manager_map[type(kv_cache_spec)] - manager = manager_class(kv_cache_spec, block_pool) + manager = manager_class(kv_cache_spec, block_pool, kv_cache_manager_id, + num_kv_cache_groups) return manager diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9590a9aadbec..31af54a47b05 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -150,6 +150,7 @@ def _initialize_kv_caches( num_gpu_blocks = kv_cache_configs[0].num_blocks num_cpu_blocks = 0 scheduler_kv_cache_config = kv_cache_configs[0] + print("kv_cache_config", scheduler_kv_cache_config) # Initialize kv cache and warmup the execution self.model_executor.initialize_from_config(kv_cache_configs) diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 4fc0844cd1f4..449afdacde00 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass +from typing import Optional import torch @@ -56,8 +57,9 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass class AttentionSpec(KVCacheSpec): - num_kv_heads: int head_size: int + num_query_heads: int + num_kv_heads: int dtype: torch.dtype use_mla: bool @@ -71,6 +73,11 @@ def page_size_bytes(self) -> int: @dataclass class FullAttentionSpec(AttentionSpec): + # Some layers may be regarded as full attention layers in KV cache manager ( + # blocks are allocated for all tokens), while computed as sliding window + # attention. In this case, we use FullAttentionSpec and record the + # sliding window size. + sliding_window: Optional[int] = None @property def type_id(self) -> str: @@ -112,15 +119,30 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: @dataclass -class KVCacheTensor: +class KVCacheTensorBase: """ A dataclass for specifying how the workers should initialize the KV cache - for a layer. Only contains the size of KV cache for that layer for now. Will - be extended to support multiple layers sharing the same memory pool. + for a layer. + """ + pass + + +@dataclass +class KVCacheNewTensor(KVCacheTensorBase): + """ + Initialize the KV cache with a tensor of `size` bytes. """ size: int # The size of KV cache Tensor in bytes +@dataclass +class KVCacheReuseTensor(KVCacheTensorBase): + """ + Reuse the KV cache tensor of `layer_name` for the current layer. + """ + reused_layer_name: str + + @dataclass class KVCacheGroupSpec: """ @@ -141,7 +163,7 @@ class KVCacheConfig: """The number of KV cache blocks""" num_blocks: int """layer_name -> how to initialize KV cache for that layer""" - tensors: dict[str, KVCacheTensor] + tensors: dict[str, KVCacheTensorBase] """ The kv cache groups of the model. 
The layers in the models are repeated with some patterns, e.g., a model diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1de14584d396..a84627b847e4 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -4,7 +4,9 @@ import triton import triton.language as tl -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.attention.layer import Attention +from vllm.config import (VllmConfig, get_layers_from_vllm_config, + set_current_vllm_config) from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader.loader import get_model_loader @@ -215,6 +217,8 @@ def load_model(self, target_model: nn.Module) -> None: loader = get_model_loader(self.vllm_config.load_config) target_layer_num = self.vllm_config.model_config.get_num_layers( self.vllm_config.parallel_config) + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -235,6 +239,12 @@ def load_model(self, target_model: nn.Module) -> None: model_config=draft_model_config, start_layer_id=target_layer_num).to(target_device) + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = iter(draft_attn_layer_names).__next__() + loaded_weights = self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 7d4082b73992..bbd536f15a3d 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -1,9 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 +from typing import Callable, Concatenate, ParamSpec + import numpy as np import torch from vllm.logger import init_logger +from vllm.utils import cdiv +from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -14,11 +18,13 @@ def __init__( self, max_num_reqs: int, max_num_blocks_per_req: int, + max_num_batched_tokens: int, pin_memory: bool, device: torch.device, ): self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req + self.max_num_batched_tokens = max_num_batched_tokens self.pin_memory = pin_memory self.device = device @@ -36,6 +42,12 @@ def __init__( self.block_table_np = self.block_table_cpu.numpy() self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32) + self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.slot_mapping_np = self.slot_mapping_cpu.numpy() + def append_row( self, block_ids: list[int], @@ -85,3 +97,57 @@ def get_cpu_tensor(self) -> torch.Tensor: def get_numpy_array(self) -> np.ndarray: """Returns the numpy array of the block table.""" return self.block_table_np + + +P = ParamSpec("P") + + +class MultiGroupBlockTable: + move_row: Callable[P, None] + swap_row: Callable[P, None] + commit: Callable[P, None] + clear: Callable[P, None] + + append_row: Callable[Concatenate[list[list[int]], P], None] + add_row: Callable[Concatenate[list[list[int]], P], None] + + def __init__(self, max_num_reqs: int, max_model_len: int, + max_num_batched_tokens: int, pin_memory: bool, + device: torch.device, kv_cache_config: KVCacheConfig) -> None: + max_num_blocks_per_req = [ + cdiv(max_model_len, g.kv_cache_spec.block_size) + for g in 
kv_cache_config.kv_cache_groups + ] + self.block_tables = [ + BlockTable(max_num_reqs, max_num_blocks_per_req[i], + max_num_batched_tokens, pin_memory, device) + for i in range(len(kv_cache_config.kv_cache_groups)) + ] + # For methods that just pass the arguments to each BlockTable. + for f_name in ("move_row", "swap_row", "commit", "clear"): + setattr(self, f_name, self._make_broadcast_func(f_name)) + # For methods that require a block_ids as the first argument. + for f_name in ("append_row", "add_row"): + setattr(self, f_name, + self._make_broadcast_func_with_block_ids(f_name)) + + def _make_broadcast_func(self, f_name: str) -> Callable[P, None]: + + def broadcast_func(*args: P.args, **kwargs: P.kwargs) -> None: + for block_table in self.block_tables: + getattr(block_table, f_name)(*args, **kwargs) + + return broadcast_func + + def _make_broadcast_func_with_block_ids( + self, f_name: str) -> Callable[Concatenate[list[int], P], None]: + + def broadcast_func(block_ids: list[int], *args: P.args, + **kwargs: P.kwargs) -> None: + for i, block_table in enumerate(self.block_tables): + getattr(block_table, f_name)(block_ids[i], *args, **kwargs) + + return broadcast_func + + def __getitem__(self, idx: int) -> "BlockTable": + return self.block_tables[idx] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index a64cb97e0123..b706a47f6ad4 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -11,10 +11,11 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import swap_dict_values +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import LogprobsTensors from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.utils import copy_slice -from vllm.v1.worker.block_table import BlockTable +from vllm.v1.worker.block_table import MultiGroupBlockTable _SAMPLING_EPS = 1e-5 @@ -30,7 +31,7 @@ class CachedRequestState: sampling_params: SamplingParams generator: Optional[torch.Generator] - block_ids: list[int] + block_ids: list[list[int]] num_computed_tokens: int output_token_ids: list[int] @@ -59,14 +60,14 @@ def __init__( self, max_num_reqs: int, max_model_len: int, - max_num_blocks_per_req: int, + max_num_batched_tokens: int, device: torch.device, pin_memory: bool, vocab_size: int, + kv_cache_config: KVCacheConfig, ): self.max_num_reqs = max_num_reqs self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size @@ -98,11 +99,13 @@ def __init__( self.num_computed_tokens_cpu_tensor.numpy() # Block table. - self.block_table = BlockTable( + self.block_table = MultiGroupBlockTable( max_num_reqs=max_num_reqs, - max_num_blocks_per_req=max_num_blocks_per_req, + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, pin_memory=pin_memory, device=device, + kv_cache_config=kv_cache_config, ) # Sampling-related. 
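Note on the `MultiGroupBlockTable` / `gpu_input_batch.py` changes above: per-request block IDs become `list[list[int]]`, indexed by KV cache group, and the wrapper keeps one `BlockTable` per group and fans every call out to them. The following minimal, self-contained sketch shows only that dispatch pattern; the `_SimpleBlockTable` and `_MultiGroupTable` classes and their methods are invented for illustration and are not vLLM APIs.

```python
# Hypothetical sketch of the per-group block-table fan-out (not vLLM code).
from typing import Callable


class _SimpleBlockTable:
    """Stand-in for a per-group BlockTable: block IDs per request row."""

    def __init__(self) -> None:
        self.rows: dict[int, list[int]] = {}

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        self.rows[row_idx] = list(block_ids)

    def append_row(self, block_ids: list[int], row_idx: int) -> None:
        self.rows.setdefault(row_idx, []).extend(block_ids)

    def clear(self) -> None:
        self.rows.clear()


class _MultiGroupTable:
    """Keeps one _SimpleBlockTable per KV cache group and fans calls out."""

    def __init__(self, num_groups: int) -> None:
        self.tables = [_SimpleBlockTable() for _ in range(num_groups)]
        # Methods that are simply broadcast to every per-group table.
        for f_name in ("clear",):
            setattr(self, f_name, self._make_broadcast_func(f_name))
        # Methods whose first argument is list[list[int]] (outer index =
        # KV cache group): table i receives block_ids[i].
        for f_name in ("add_row", "append_row"):
            setattr(self, f_name,
                    self._make_broadcast_func_with_block_ids(f_name))

    def _make_broadcast_func(self, f_name: str) -> Callable[..., None]:
        def broadcast_func(*args, **kwargs) -> None:
            for group_table in self.tables:
                getattr(group_table, f_name)(*args, **kwargs)
        return broadcast_func

    def _make_broadcast_func_with_block_ids(
            self, f_name: str) -> Callable[..., None]:
        def broadcast_func(block_ids: list[list[int]], *args,
                           **kwargs) -> None:
            for i, group_table in enumerate(self.tables):
                getattr(group_table, f_name)(block_ids[i], *args, **kwargs)
        return broadcast_func


if __name__ == "__main__":
    table = _MultiGroupTable(num_groups=2)
    # Request row 0: blocks [1, 2] for group 0 and [7] for group 1.
    table.add_row([[1, 2], [7]], row_idx=0)
    table.append_row([[3], [8]], row_idx=0)
    print(table.tables[0].rows)  # {0: [1, 2, 3]}
    print(table.tables[1].rows)  # {0: [7, 8]}
```

The diff builds its broadcast functions the same way (via `setattr` in `__init__`), so callers keep the old single-table method names while each KV cache group keeps its own block size and row layout.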
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7910481762ef..6bfc1d1edc3a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,7 +3,7 @@ import gc import time import weakref -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union, cast import numpy as np import torch @@ -11,14 +11,16 @@ import torch.nn as nn from vllm.attention import AttentionType, get_attn_backend +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadataBuilder) from vllm.attention.layer import Attention -from vllm.config import CompilationLevel, VllmConfig +from vllm.config import (CompilationLevel, VllmConfig, + get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.forward_context import set_forward_context from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import get_model from vllm.multimodal import MULTIMODAL_REGISTRY @@ -27,12 +29,14 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LayerBlockType, LazyLoader, cdiv, - check_use_alibi, is_pin_memory_available) + GiB_bytes, LazyLoader, cdiv, check_use_alibi, + is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheSpec, + KVCacheConfig, KVCacheNewTensor, + KVCacheReuseTensor, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -44,6 +48,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.utils import is_spec_decode_supported from vllm.v1.utils import bind_kv_cache +from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin @@ -85,7 +90,7 @@ def __init__( model_config = self.model_config cache_config = self.cache_config scheduler_config = self.scheduler_config - parallel_config = self.parallel_config + self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype @@ -95,51 +100,15 @@ def __init__( self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] - # NOTE(woosuk): sliding_window is None for models with interleaved - # attention. Use interleaved_sliding_window instead. 
- self.sliding_window = model_config.get_sliding_window() - self.interleaved_sliding_window = getattr( - model_config.hf_text_config, "interleaved_sliding_window", None) - self.window_size = (self.sliding_window - or self.interleaved_sliding_window) - self.is_multimodal_model = model_config.is_multimodal_model - self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size - self.attn_backend = get_attn_backend( - self.head_size, - self.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) - if self.attn_backend is None: - error_msg = ( - f"Error with get_att_backend: {self.head_size=}, " - f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, " - f"{self.model_config.is_attention_free=}, " - f"{self.model_config.use_mla=}") - logger.error(error_msg) - raise NotImplementedError( - "Non-Attention backend is not supported by V1 GPUModelRunner.") - - self.attn_metadata_builder = self.attn_backend.get_builder_cls()( - weakref.proxy(self)) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn # Multi-modal data support @@ -157,9 +126,15 @@ def __init__( # Sampler self.sampler = Sampler() - # Lazy initialization + # Lazy initializations # self.model: nn.Module # Set after load_model - self.kv_caches: list[torch.Tensor] = [] + # Initialized in initialize_kv_cache + self.kv_cache_config = cast(KVCacheConfig, None) + self.attn_backends: list[type[AttentionBackend]] = [] + self.attn_metadata_builders: list[type[AttentionMetadataBuilder]] = [] + # Persistent batch + self.input_batch = cast(InputBatch, None) + # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -184,15 +159,6 @@ def __init__( # Request states. self.requests: dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=model_config.get_vocab_size(), - ) self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE @@ -267,11 +233,6 @@ def __init__( device="cpu", pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() - self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - self.slot_mapping_np = self.slot_mapping_cpu.numpy() self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, dtype=torch.int32, device="cpu", @@ -283,6 +244,32 @@ def __init__( pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool: + """ + Update the order of requests in the batch based on the attention + backend's needs. 
For example, some attention backends (namely MLA) may + want to separate requests based on if the attention computation will be + compute-bound or memory-bound. + + Args: + scheduler_output: The scheduler output. + + Returns: + True if the batch was reordered, False otherwise. + """ + batch_reordered = self.attn_metadata_builders[0].reorder_batch( + self.input_batch, scheduler_output) + + # For models with multiple KV cache groups, the groups should agree on + # the same order of requests. We ensure this by only allowing the first + # group to reorder the batch. + for kv_cache_group_id in range( + 1, len(self.kv_cache_config.kv_cache_groups)): + assert not self.attn_metadata_builders[ + kv_cache_group_id].reorder_batch(self.input_batch, + scheduler_output) + return batch_reordered + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler output. @@ -420,7 +407,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the block IDs. if not req_data.resumed_from_preemption: # Append the new blocks to the existing block IDs. - req_state.block_ids.extend(req_data.new_block_ids) + for i in range(len(self.kv_cache_config.kv_cache_groups)): + req_state.block_ids[i].extend(req_data.new_block_ids[i]) else: # The request is resumed from preemption. # Replace the existing block IDs with the new ones. @@ -478,11 +466,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: if removed_req_indices: self.input_batch.condense(removed_req_indices) - # Some attention backends (namely MLA) may want to separate requests - # based on if the attention computation will be compute-bound or - # memory-bound. This gives them a hook to do that. - batch_reordered = self.attn_metadata_builder.reorder_batch( - self.input_batch, scheduler_output) + batch_reordered = self._may_reorder_batch(scheduler_output) if batch_changed or batch_reordered: self.input_batch.refresh_sampling_metadata() @@ -490,7 +474,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[FlashAttentionMetadata, torch.Tensor, + ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor, Optional[SpecDecodeMetadata]]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -550,20 +534,33 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` here - # because M (max_model_len) is not necessarily divisible by block_size. - block_table_indices = (req_indices * self.max_num_blocks_per_req + - positions_np // self.block_size) - block_table_cpu = self.input_batch.block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() - block_offsets = positions_np % self.block_size - np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + # Calculate the slot mapping for each KV cache group. 
+ for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + block_size = kv_cache_group_spec.kv_cache_spec.block_size + block_table: BlockTable = self.input_batch.block_table[ + kv_cache_group_id] + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = ( + req_indices * block_table.max_num_blocks_per_req + + positions_np // block_size) + # NOTE(woosuk): We use torch.index_select instead of np.take here + # because torch.index_select is much faster than np.take for large + # tensors. + block_table_cpu = block_table.get_cpu_tensor() + block_numbers = ( + block_table_cpu.flatten()[block_table_indices].numpy()) + block_offsets = positions_np % block_size + np.add( + block_numbers * block_size, + block_offsets, + out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -576,6 +573,12 @@ def _prepare_inputs( # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True) + query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( + self.device, non_blocking=True) + seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, + non_blocking=True) + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens) if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions[:, :total_num_scheduled_tokens].copy_( @@ -587,20 +590,33 @@ def _prepare_inputs( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - # Prepare for cascade attention if enabled & beneficial. - common_prefix_len = 0 - if self.cascade_attn_enabled: - common_prefix_len = self._compute_cascade_attn_prefix_len( - num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, - ) + attn_metadata: dict[str, FlashAttentionMetadata] = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output. 
+ num_common_prefix_blocks[kv_cache_group_id], + kv_cache_group_spec.kv_cache_spec, + self.attn_metadata_builders[kv_cache_group_id], + ) - attn_metadata = self.attn_metadata_builder.build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - ) + block_table = self.input_batch.block_table[kv_cache_group_id] + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id].build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata)) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -610,7 +626,7 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = attn_metadata.query_start_loc[1:] - 1 + logits_indices = query_start_loc[1:] - 1 spec_decode_metadata = None else: # Get the number of draft tokens for each request. @@ -636,6 +652,8 @@ def _compute_cascade_attn_prefix_len( self, num_scheduled_tokens: np.ndarray, num_common_prefix_blocks: int, + kv_cache_spec: KVCacheSpec, + attn_metadata_builder: AttentionMetadataBuilder, ) -> int: """Compute the length of the common prefix for cascade attention. @@ -654,7 +672,7 @@ def _compute_cascade_attn_prefix_len( Returns: int: Length of common prefix in tokens. """ - common_prefix_len = num_common_prefix_blocks * self.block_size + common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size if common_prefix_len == 0: # Common case. return 0 @@ -703,15 +721,19 @@ def _compute_cascade_attn_prefix_len( common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. - common_prefix_len = (common_prefix_len // self.block_size * - self.block_size) - use_cascade = self.attn_metadata_builder.use_cascade_attention( + common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * + kv_cache_spec.block_size) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or + (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None)) + assert isinstance(kv_cache_spec, AttentionSpec) + use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, query_lens=num_scheduled_tokens, - num_query_heads=self.num_query_heads, - num_kv_heads=self.num_kv_heads, + num_query_heads=kv_cache_spec.num_query_heads, + num_kv_heads=kv_cache_spec.num_kv_heads, use_alibi=self.use_alibi, - use_sliding_window=self.window_size is not None, + use_sliding_window=use_sliding_window, num_sms=self.num_sms, ) return common_prefix_len if use_cascade else 0 @@ -1030,7 +1052,11 @@ def execute_model( else: # Eager mode. 
num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens + + for kv_cache_group_spec in self.kv_cache_config.kv_cache_groups: + # TODO: merge https://github.com/vllm-project/vllm/pull/17193 + layer_name = kv_cache_group_spec.layer_names[0] + attn_metadata[layer_name].num_input_tokens = num_input_tokens # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1213,7 +1239,7 @@ def execute_model( next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) - + eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] if spec_decode_metadata is None: # input_ids can be None for multimodal models. # We need to slice token_ids, positions, and hidden_states @@ -1227,8 +1253,8 @@ def execute_model( ] else: target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = attn_metadata.slot_mapping - cu_num_tokens = attn_metadata.query_start_loc + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc else: # TODO(woosuk): Refactor this. num_draft_tokens = spec_decode_metadata.num_draft_tokens @@ -1242,7 +1268,7 @@ def execute_model( device=self.device, ) cu_num_tokens, token_indices = self.drafter.prepare_inputs( - attn_metadata.query_start_loc, + eagle_attn_metadata.query_start_loc, num_rejected_tokens, ) target_token_ids = self.input_ids[token_indices] @@ -1253,7 +1279,8 @@ def execute_model( ] else: target_hidden_states = hidden_states[token_indices] - target_slot_mapping = attn_metadata.slot_mapping[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat(target_hidden_states, dim=-1) @@ -1264,7 +1291,7 @@ def execute_model( target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, - block_table=attn_metadata.block_table, + block_table=eagle_attn_metadata.block_table, sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() @@ -1681,51 +1708,138 @@ def capture_model(self) -> None: logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", elapsed_time, cuda_graph_size / (1 << 30)) - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + def _initialize_kv_cache_buffer( + self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: """ - Initialize KV cache based on `kv_cache_config`. + Initializes the KV cache buffer with the correct size. The buffer needs + to be reshaped to the desired shape before being used by the models. Args: - kv_cache_config: Configuration for the KV cache, including the KV - cache size of each layer + kv_cache_config: The KV cache config + Returns: + dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. 
+ """ + kv_cache_raw_tensors: dict[str, torch.Tensor] = {} + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheNewTensor): + # A new tensor with `tensor_config.size` bytes + kv_cache_raw_tensors[layer_name] = torch.zeros( + tensor_config.size, dtype=torch.int8, device=self.device) + for layer_name, tensor_config in kv_cache_config.tensors.items(): + if isinstance(tensor_config, KVCacheReuseTensor): + # Reuse a tensor from `kv_cache_raw_tensors` + kv_cache_raw_tensors[layer_name] = kv_cache_raw_tensors[ + tensor_config.reused_layer_name] + assert len(kv_cache_raw_tensors) == len( + kv_cache_config.tensors), "Some layers are not initialized" + return kv_cache_raw_tensors + + def _setup_kv_cache_shapes( + self, + kv_cache_config: KVCacheConfig, + kv_cache_raw_tensors: dict[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """ + Reshape the KV cache tensors to the desired shape. + Args: + kv_cache_config: The KV cache config + kv_cache_raw_tensors: The KV cache buffer of each layer, with + correct size but uninitialized shape. + Returns: + Dict[str, torch.Tensor]: A map between layer names to their + corresponding memory buffer for KV cache. """ - if len(kv_cache_config.kv_cache_groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not " - "supported yet.") - kv_caches: dict[str, torch.Tensor] = {} - - for kv_cache_group in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: - tensor_config = kv_cache_config.tensors[layer_name] - assert tensor_config.size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_config.size // kv_cache_spec.page_size_bytes - # `num_blocks` is the number of blocks the model runner can use. - # `kv_cache_config.num_blocks` is the number of blocks that - # KVCacheManager may allocate. - # Since different GPUs may have different number of layers and - # different memory capacities, `num_blocks` can be different on - # different GPUs, and `kv_cache_config.num_blocks` is set to - # the min of all `num_blocks`. Verify it here. - assert num_blocks >= kv_cache_config.num_blocks + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + for layer_name in kv_cache_group_spec.layer_names: + raw_tensor = kv_cache_raw_tensors[layer_name] + assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): - kv_cache_shape = self.attn_backend.get_kv_cache_shape( + kv_cache_shape = self.attn_backends[i].get_kv_cache_shape( num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype - kv_caches[layer_name] = torch.zeros(kv_cache_shape, - dtype=dtype, - device=self.device) + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape) else: - # TODO: add new branches when introducing more types of - # KV cache specs. 
- raise ValueError("Unknown KV cache spec type.") - + raise NotImplementedError + return kv_caches + + def initialize_kv_cache_tensors(self, + kv_cache_config: KVCacheConfig) -> None: + # TODO: docstring + # Initialize the memory buffer for KV cache + kv_cache_raw_tensors = self._initialize_kv_cache_buffer( + kv_cache_config) + # Change the memory buffer to the desired shape + kv_caches = self._setup_kv_cache_shapes(kv_cache_config, + kv_cache_raw_tensors) bind_kv_cache( kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + self.vllm_config.compilation_config.static_forward_context, []) + + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + # TODO: docstring + assert len(self.attn_backends) == 0 and len( + self.attn_metadata_builders) == 0, "already initialized" + for i, kv_cache_group_spec in enumerate( + kv_cache_config.kv_cache_groups): + kv_cache_spec = kv_cache_group_spec.kv_cache_spec + if not isinstance(kv_cache_spec, AttentionSpec): + raise NotImplementedError( + "Only AttentionSpec is supported for now.") + attn_backend_i = get_attn_backend( + kv_cache_spec.head_size, + self.dtype, + kv_cache_spec.dtype, + kv_cache_spec.block_size, + self.model_config.is_attention_free, + use_mla=kv_cache_spec.use_mla, + ) + if attn_backend_i is None: + error_msg = ( + f"Error with get_attn_backend: {kv_cache_spec.head_size=}, " + f"{self.dtype=}, {kv_cache_spec.dtype=}, " + f"{kv_cache_spec.block_size=}, " + f"{self.model_config.is_attention_free=}, " + f"{kv_cache_spec.use_mla=}") + logger.error(error_msg) + raise NotImplementedError( + "Non-Attention backend is not supported by V1 " + "GPUModelRunner.") + block_table_i = self.input_batch.block_table[i] + attn_metadata_builder_i = attn_backend_i.get_builder_cls()( + weakref.proxy(self), kv_cache_spec, block_table_i) + self.attn_backends.append(attn_backend_i) + self.attn_metadata_builders.append(attn_metadata_builder_i) + + assert all(builder.reorder_batch.__func__ is + self.attn_metadata_builders[0].reorder_batch.__func__ + for builder in self.attn_metadata_builders), "TODO" + + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + """ + Initialize KV cache based on `kv_cache_config`. + Args: + kv_cache_config: Configuration for the KV cache, including the KV + cache size of each layer + """ + self.kv_cache_config = kv_cache_config + self.input_batch = InputBatch( + max_num_reqs=self.max_num_reqs, + max_model_len=self.max_model_len, + max_num_batched_tokens=self.max_num_tokens, + device=self.device, + pin_memory=self.pin_memory, + vocab_size=self.model_config.get_vocab_size(), + kv_cache_config=kv_cache_config, + ) + self.initialize_attn_backend(kv_cache_config) + self.initialize_kv_cache_tensors(kv_cache_config) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -1736,31 +1850,28 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. 
""" - forward_ctx = self.vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - if isinstance(attn_module, FusedMoE): - continue - - # TODO: Support other attention modules, e.g., sliding window, - # cross-attention - assert isinstance(attn_module, Attention) + for layer_name, attn_module in layers.items(): + # TODO: Support other attention modules, e.g., cross-attention if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, + num_query_heads=attn_module.num_heads, + num_kv_heads=attn_module.num_kv_heads, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=use_mla) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, + num_query_heads=attn_module.num_heads, + num_kv_heads=attn_module.num_kv_heads, dtype=self.kv_cache_dtype, use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e9cb0dbe8b5e..3fb8563f8c9b 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# type: ignore import bisect import gc import time @@ -17,7 +18,7 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -430,11 +431,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: format. Layers that do not need KV cache are not included. """ - forward_ctx = self.vllm_config.compilation_config.static_forward_context + layers = get_layers_from_vllm_config(self.vllm_config, Attention) block_size = self.vllm_config.cache_config.block_size kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in forward_ctx.items(): - assert isinstance(attn_module, Attention) + for layer_name, attn_module in layers.items(): if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index de676541effa..728a8f1b9051 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -22,7 +22,7 @@ KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import bind_kv_cache, report_usage_stats -from vllm.v1.worker.tpu_model_runner import TPUModelRunner +from vllm.v1.worker.tpu_model_runner import TPUModelRunner # type: ignore logger = init_logger(__name__)