From d35146f678765a0054c9facc214be79d4d79467c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 07:33:19 -0700 Subject: [PATCH 01/15] remove num_input_tokens from attn_metadata Signed-off-by: Chen Zhang --- vllm/forward_context.py | 14 ++++++-------- vllm/v1/attention/backends/flash_attn.py | 3 --- vllm/v1/attention/backends/flashinfer.py | 3 --- vllm/v1/attention/backends/mla/common.py | 3 --- vllm/v1/worker/gpu_model_runner.py | 5 +++-- vllm/v1/worker/tpu_model_runner.py | 5 ++++- 6 files changed, 13 insertions(+), 20 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 06790d8ee2f8..d7c43b56827d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -74,15 +74,13 @@ def set_forward_context(attn_metadata: Any, if vllm_config.parallel_config.data_parallel_size > 1: dp_size = vllm_config.parallel_config.data_parallel_size dp_rank = vllm_config.parallel_config.data_parallel_rank - if attn_metadata is not None: - if hasattr(attn_metadata, "num_prefill_tokens"): - # for v0 attention backends - batchsize = attn_metadata.num_prefill_tokens + \ - attn_metadata.num_decode_tokens - else: - # for v1 attention backends - batchsize = attn_metadata.num_input_tokens + if attn_metadata is not None and hasattr(attn_metadata, + "num_prefill_tokens"): + # for v0 attention backends + batchsize = attn_metadata.num_prefill_tokens + \ + attn_metadata.num_decode_tokens else: + # for v1 attention backends or no attn_metadata batchsize = num_tokens num_tokens_across_dp = [0] * dp_size num_tokens_across_dp[dp_rank] = batchsize diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 51ae386d3389..7c20f94b915b 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -94,9 +94,6 @@ class FlashAttentionMetadata: scheduler_metadata: Optional[torch.Tensor] = None prefix_scheduler_metadata: Optional[torch.Tensor] = None - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. - # for local attention @dataclass class LocalAttentionMetadata: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 17341ecfa4fe..aae170984ab2 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -184,9 +184,6 @@ class FlashInferMetadata: decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None cascade_wrapper: Optional[MultiLevelCascadeAttentionWrapper] = None - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. - @property def query_start_loc(self): # The GPUModelRunner expects to be able to access this property. diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f826f8a21789..75a11bd46920 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -312,9 +312,6 @@ class MLACommonMetadata(Generic[D]): num_decode_tokens: int num_prefills: int - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. - # The dimension of the attention heads head_dim: Optional[int] = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 86f6a301fbb6..83c6aaa9168d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1026,7 +1026,6 @@ def execute_model( else: # Eager mode. 
num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1078,7 +1077,9 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context(attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens): hidden_states = self.model( input_ids=input_ids, positions=positions, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index e9cb0dbe8b5e..0728efb168dc 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -771,7 +771,10 @@ def execute_model( xm.mark_step() num_reqs = self.input_batch.num_reqs # Run the decoder - with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=scheduler_output.total_num_scheduled_tokens): hidden_states = self.model( input_ids=input_ids, positions=self.position_ids, From 20d930be75bcfee9b1fe1f4d37af78f57732f864 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 25 Apr 2025 07:42:33 -0700 Subject: [PATCH 02/15] fix Signed-off-by: Chen Zhang --- vllm/forward_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index d7c43b56827d..c75d8f088c5b 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -122,7 +122,7 @@ def set_forward_context(attn_metadata: Any, attn_metadata.num_decode_tokens else: # for v1 attention backends - batchsize = attn_metadata.num_input_tokens + batchsize = num_tokens # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch From f0636dfca47c357a3f4eff2fc208f30bb3d7c837 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Apr 2025 07:17:30 -0700 Subject: [PATCH 03/15] per_layer_attn_metadata Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 15 +++++- vllm/forward_context.py | 11 +++-- vllm/v1/attention/backends/flash_attn.py | 11 ++--- vllm/v1/attention/backends/flashinfer.py | 10 ++-- vllm/v1/attention/backends/mla/common.py | 10 ++-- vllm/v1/attention/backends/utils.py | 13 ++++++ vllm/v1/spec_decode/eagle.py | 11 ++++- vllm/v1/worker/gpu_model_runner.py | 59 ++++++++++++++++-------- 8 files changed, 100 insertions(+), 40 deletions(-) create mode 100644 vllm/v1/attention/backends/utils.py diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index aa218cc37af9..358c055e78d9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -210,6 +210,8 @@ def forward( if self.use_direct_call: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, @@ -226,6 +228,8 @@ def forward( if self.use_direct_call: forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.layer_name] self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, self_kv_cache, attn_metadata) @@ -340,7 +344,9 @@ def wait_for_kv_layer_from_connector(layer_name: str): connector = get_kv_transfer_group() forward_context: 
ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata + + assert isinstance(forward_context.attn_metadata, dict) + attn_metadata = forward_context.attn_metadata[layer_name] if attn_metadata is None: return @@ -357,7 +363,8 @@ def maybe_save_kv_layer_to_connector( connector = get_kv_transfer_group() forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata + assert isinstance(forward_context.attn_metadata, dict) + attn_metadata = forward_context.attn_metadata[layer_name] if attn_metadata is None: return @@ -374,6 +381,8 @@ def unified_attention( forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] output = self.impl.forward(self, query, key, value, kv_cache, @@ -411,6 +420,8 @@ def unified_attention_with_output( wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 06790d8ee2f8..a7ab1272f187 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -4,7 +4,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union import torch import torch.distributed as dist @@ -38,8 +38,13 @@ class DPMetadata: class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context no_compile_layers: dict[str, Any] - # TODO: extend to support per-layer dynamic forward context - attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + """ + Type AttentionMetadata for v0, + Type Dict[str, AttentionMetadata] for v1, map from layer_name of each + attention layer to its attention metadata + set dynamically for each forward pass + """ + attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]] # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 41bb9aba2995..e1b452f85c8a 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -16,6 +16,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -295,13 +296,11 @@ def reorder_batch(self, input_batch: "InputBatch", return False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): max_seq_len = self.runner.seq_lens_np[:num_reqs].max() - query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1] - query_start_loc = query_start_loc_cpu.to(self.runner.device, - 
non_blocking=True) - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index bce446bd2b82..04317925320e 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -18,6 +18,7 @@ get_layers_from_vllm_config) from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention +from vllm.v1.attention.backends.utils import CommonAttentionMetadata if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -397,16 +398,15 @@ def _plan(self, attn_metadata: FlashInferMetadata): ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): assert self._num_decodes + self._num_prefills == num_reqs assert (self._num_decode_tokens + self._num_prefill_tokens == num_actual_tokens) page_size = self.runner.block_size device = self.runner.device - qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - self.runner.device, non_blocking=True) - seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device, - non_blocking=True) + qo_indptr = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index b032006d1ad1..fa25ba0b3bc6 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -205,6 +205,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down +from vllm.v1.attention.backends.utils import CommonAttentionMetadata try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -452,7 +453,8 @@ def _build_decode(self, input_positions: torch.Tensor, ) def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int) -> M: + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata) -> M: assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this @@ -461,15 +463,13 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, device = self.runner.device block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( - device, non_blocking=True) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long() input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() - seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs] - seq_lens = seq_lens_cpu.to(device, non_blocking=True) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens prefill_metadata = None if self._num_prefills > 0: diff --git 
a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py new file mode 100644 index 000000000000..b59062132810 --- /dev/null +++ b/vllm/v1/attention/backends/utils.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +import torch + + +@dataclass +class CommonAttentionMetadata: + """ + Attention Metadata that are same for different layer types. + """ + query_start_loc: torch.Tensor + seq_lens: torch.Tensor diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 8c45ca9a319f..bc7034ba1b9e 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -4,7 +4,9 @@ import triton import triton.language as tl -from vllm.config import VllmConfig, set_current_vllm_config +from vllm.attention.layer import Attention +from vllm.config import (VllmConfig, get_layers_from_vllm_config, + set_current_vllm_config) from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader.loader import get_model_loader @@ -214,6 +216,8 @@ def load_model(self, target_model: nn.Module) -> None: loader = get_model_loader(self.vllm_config.load_config) target_layer_num = self.vllm_config.model_config.get_num_layers( self.vllm_config.parallel_config) + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -230,6 +234,11 @@ def load_model(self, target_model: nn.Module) -> None: model_config=draft_model_config, start_layer_id=target_layer_num).to(target_device) + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + assert len(draft_attn_layer_names) == 1 + self.attn_layer_name = iter(draft_attn_layer_names).__next__() loaded_weights = self.model.load_weights( loader.get_all_weights( self.vllm_config.speculative_config.draft_model_config, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e3d8b94fe9d7..a573db570895 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,7 +3,7 @@ import gc import time import weakref -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Union, cast import numpy as np import torch @@ -30,6 +30,7 @@ GiB_bytes, LayerBlockType, LazyLoader, cdiv, check_use_alibi, is_pin_memory_available) from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, @@ -157,9 +158,12 @@ def __init__( # Sampler self.sampler = Sampler() - # Lazy initialization + # Lazy initializations # self.model: nn.Module # Set after load_model + # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] + self.kv_cache_config = cast(KVCacheConfig, None) + # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -488,7 +492,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[FlashAttentionMetadata, torch.Tensor, + ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor, Optional[SpecDecodeMetadata]]: total_num_scheduled_tokens = 
scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -585,20 +589,39 @@ def _prepare_inputs( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - # Prepare for cascade attention if enabled & beneficial. - common_prefix_len = 0 - if self.cascade_attn_enabled: - common_prefix_len = self._compute_cascade_attn_prefix_len( - num_scheduled_tokens, - scheduler_output.num_common_prefix_blocks, - ) + query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( + self.device, non_blocking=True) + seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, + non_blocking=True) + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens) + + attn_metadata: dict[str, FlashAttentionMetadata] = {} + # Prepare the attention metadata for each KV cache group and make layers + # in the same group share the same metadata. + # NOTE(Chen): there is exactly one KV cache group that contains all + # attention layers in the model for now, so the current logic for + # getting attn_metadata is not related to kv_cache_group information. + # Will extend this part to support multiple KV cache groups later. + for kv_cache_group_id, kv_cache_group_spec in enumerate( + self.kv_cache_config.kv_cache_groups): + + # Prepare for cascade attention if enabled & beneficial. + common_prefix_len = 0 + if self.cascade_attn_enabled: + common_prefix_len = self._compute_cascade_attn_prefix_len( + num_scheduled_tokens, + scheduler_output.num_common_prefix_blocks, + ) - attn_metadata = self.attn_metadata_builder.build( - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - common_prefix_len=common_prefix_len, - ) + attn_metadata_i = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -608,7 +631,7 @@ def _prepare_inputs( # from these partial requests, we do so for simplicity. # We will ignore the sampled tokens from the partial requests. # TODO: Support prompt logprobs. - logits_indices = attn_metadata.query_start_loc[1:] - 1 + logits_indices = query_start_loc[1:] - 1 spec_decode_metadata = None else: # Get the number of draft tokens for each request.
@@ -1036,7 +1059,6 @@ def execute_model( num_input_tokens = round_up(num_scheduled_tokens, tp_size) else: num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order @@ -1697,6 +1719,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") + self.kv_cache_config = kv_cache_config kv_caches: dict[str, torch.Tensor] = {} From dd08b5be05b742470894177d812461ba7bd60ba8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Apr 2025 07:25:19 -0700 Subject: [PATCH 04/15] update comment Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b59062132810..a71af00ee891 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -7,7 +7,8 @@ @dataclass class CommonAttentionMetadata: """ - Attention Metadata that are same for different layer types. + Attention metadata attributes that can be shared by layers in different KV + cache groups and thus having different block table. """ query_start_loc: torch.Tensor seq_lens: torch.Tensor From ab4389e0ade8bd2a7de29fd82e42cf79366d84ce Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Apr 2025 07:45:28 -0700 Subject: [PATCH 05/15] update tpu code Signed-off-by: Chen Zhang --- vllm/v1/worker/tpu_model_runner.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index d716542f7898..692156d1e5ea 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -168,6 +168,8 @@ def __init__( # Lazy initialization # self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] + self.kv_cache_config = cast(KVCacheConfig, None) + # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -588,7 +590,13 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) - return attn_metadata, logits_indices, padded_num_reqs + + per_layer_attn_metadata = { + layer_name: attn_metadata + for layer_name in + self.kv_cache_config.kv_cache_groups[0].layer_names + } + return per_layer_attn_metadata, logits_indices, padded_num_reqs def _scatter_placeholders( self, @@ -1202,6 +1210,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") + self.kv_cache_config = kv_cache_config kv_caches: dict[str, torch.Tensor] = {} From 20a1d2279555c9f53850dc36e437812a82a7282b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 29 Apr 2025 08:11:59 -0700 Subject: [PATCH 06/15] fix kv connector Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 358c055e78d9..07bb113b3b30 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -345,12 +345,12 @@ def wait_for_kv_layer_from_connector(layer_name: str): forward_context: ForwardContext = get_forward_context() - assert isinstance(forward_context.attn_metadata, dict) - attn_metadata = forward_context.attn_metadata[layer_name] + attn_metadata = forward_context.attn_metadata if attn_metadata is None: return + assert isinstance(attn_metadata, dict) - connector.wait_for_layer_load(layer_name) + connector.wait_for_layer_load(attn_metadata[layer_name]) def maybe_save_kv_layer_to_connector( @@ -363,12 +363,13 @@ def maybe_save_kv_layer_to_connector( connector = get_kv_transfer_group() forward_context: ForwardContext = get_forward_context() - assert isinstance(forward_context.attn_metadata, dict) - attn_metadata = forward_context.attn_metadata[layer_name] + attn_metadata = forward_context.attn_metadata if attn_metadata is None: return + assert isinstance(attn_metadata, dict) - connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata) + connector.save_kv_layer(layer_name, kv_cache_layer, + attn_metadata[layer_name]) def unified_attention( From 4679b4cefa9a60fbc2ee67976fde766c101e4ec4 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 30 Apr 2025 04:41:45 -0700 Subject: [PATCH 07/15] fix eagle Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_model_runner.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9938c8af8014..19ccb899b360 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1253,6 +1253,7 @@ def execute_model( next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) + eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name] if spec_decode_metadata is None: # input_ids can be None for multimodal models. @@ -1264,8 +1265,8 @@ def execute_model( dim=-1) else: target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = attn_metadata.slot_mapping - cu_num_tokens = attn_metadata.query_start_loc + target_slot_mapping = eagle_attn_metadata.slot_mapping + cu_num_tokens = eagle_attn_metadata.query_start_loc else: # TODO(woosuk): Refactor this. 
num_draft_tokens = spec_decode_metadata.num_draft_tokens @@ -1279,7 +1280,7 @@ def execute_model( device=self.device, ) cu_num_tokens, token_indices = self.drafter.prepare_inputs( - attn_metadata.query_start_loc, + eagle_attn_metadata.query_start_loc, num_rejected_tokens, ) target_token_ids = self.input_ids[token_indices] @@ -1289,7 +1290,8 @@ def execute_model( [h[token_indices] for h in aux_hidden_states], dim=-1) else: target_hidden_states = hidden_states[token_indices] - target_slot_mapping = attn_metadata.slot_mapping[token_indices] + target_slot_mapping = eagle_attn_metadata.slot_mapping[ + token_indices] draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, @@ -1298,7 +1300,7 @@ def execute_model( target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, cu_num_tokens=cu_num_tokens, - block_table=attn_metadata.block_table, + block_table=eagle_attn_metadata.block_table, sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() From 1fbb06aeb459757327b4441b67a0705d7b816d06 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 1 May 2025 04:56:28 -0700 Subject: [PATCH 08/15] fix bug Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 07bb113b3b30..66edf9f7af5a 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -344,13 +344,11 @@ def wait_for_kv_layer_from_connector(layer_name: str): connector = get_kv_transfer_group() forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata if attn_metadata is None: return assert isinstance(attn_metadata, dict) - - connector.wait_for_layer_load(attn_metadata[layer_name]) + connector.wait_for_layer_load(layer_name) def maybe_save_kv_layer_to_connector( From bb68034b996be7d3453c98f5fd767e5a8cd07ad0 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 1 May 2025 05:00:09 -0700 Subject: [PATCH 09/15] fix Signed-off-by: Chen Zhang --- vllm/attention/layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 66edf9f7af5a..9e4fbe0b4c6c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -365,7 +365,6 @@ def maybe_save_kv_layer_to_connector( if attn_metadata is None: return assert isinstance(attn_metadata, dict) - connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata[layer_name]) From b68952379d8b749c948ae7a0d341fbe05e2dfd7a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 5 May 2025 23:09:07 -0700 Subject: [PATCH 10/15] address review comments Signed-off-by: Chen Zhang --- vllm/v1/spec_decode/eagle.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3e93f79da0c4..31c363762216 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -304,7 +304,7 @@ def load_model(self, target_model: nn.Module) -> None: get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) assert len(draft_attn_layer_names) == 1 - self.attn_layer_name = iter(draft_attn_layer_names).__next__() + self.attn_layer_name = next(iter(draft_attn_layer_names)) loaded_weights = self.model.load_weights( loader.get_all_weights(draft_model_config, self.model)) if self.vllm_config.speculative_config.method == "eagle3": diff --git a/vllm/v1/worker/gpu_model_runner.py 
b/vllm/v1/worker/gpu_model_runner.py index 19ccb899b360..3a9b0a9816a8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3,7 +3,7 @@ import gc import time import weakref -from typing import TYPE_CHECKING, Optional, Union, cast +from typing import TYPE_CHECKING, Optional, Union import numpy as np import torch @@ -162,7 +162,7 @@ def __init__( # self.model: nn.Module # Set after load_model # Initialize in initialize_kv_cache self.kv_caches: list[torch.Tensor] = [] - self.kv_cache_config = cast(KVCacheConfig, None) + # self.kv_cache_config: KVCacheConfig # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} From e3021c63011ee7b3917acc255e16e60faf4935ac Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 5 May 2025 23:24:50 -0700 Subject: [PATCH 11/15] add docstring to CommonAttentionMetadata Signed-off-by: Chen Zhang --- vllm/v1/attention/backends/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index a71af00ee891..10a771e830b6 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -10,5 +10,9 @@ class CommonAttentionMetadata: Attention metadata attributes that can be shared by layers in different KV cache groups and thus having different block table. """ + query_start_loc: torch.Tensor + """(batch_size + 1,), the start location of each request in query Tensor""" seq_lens: torch.Tensor + """(batch_size,), the length of each request including both computed tokens + and newly scheduled tokens""" From 5b55ca216f6edc8178958e70bfa4ac6967891d37 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 5 May 2025 23:29:01 -0700 Subject: [PATCH 12/15] fix tpu Signed-off-by: Chen Zhang --- vllm/v1/worker/tpu_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 692156d1e5ea..6f959a561892 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -168,7 +168,7 @@ def __init__( # Lazy initialization # self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] - self.kv_cache_config = cast(KVCacheConfig, None) + # self.kv_cache_config: KVCacheConfig # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} From 4acfaff3618f1a241818652e49cca7818844e1eb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 6 May 2025 07:12:21 +0000 Subject: [PATCH 13/15] fix tpu Signed-off-by: Chen Zhang --- vllm/v1/worker/tpu_model_runner.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 6f959a561892..68dd628ab4d4 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -591,10 +591,12 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) + layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + ) per_layer_attn_metadata = { layer_name: attn_metadata - for layer_name in - self.kv_cache_config.kv_cache_groups[0].layer_names + for layer_name in layer_names } return per_layer_attn_metadata, logits_indices, padded_num_reqs @@ -954,7 +956,15 @@ def _dummy_run(self, num_tokens: int) -> None: 
torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - with set_forward_context(attn_metadata, self.vllm_config, 0): + layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + ) + per_layer_attn_metadata = { + layer_name: attn_metadata + for layer_name in layer_names + } + + with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0): out = self.model(input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds) From dd1ec7dba2b8a65b78b68c8289796f52e4f11521 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 6 May 2025 07:16:36 +0000 Subject: [PATCH 14/15] fix tpu Signed-off-by: Chen Zhang --- vllm/v1/worker/tpu_model_runner.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 123c20b2f5a3..59aa8a8745ff 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -168,8 +168,6 @@ def __init__( # Lazy initialization # self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] - # self.kv_cache_config: KVCacheConfig - # req_id -> (input_id -> encoder_output) self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} @@ -1247,7 +1245,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: raise NotImplementedError( "Hybrid models with more than one KV cache type are not " "supported yet.") - self.kv_cache_config = kv_cache_config kv_caches: dict[str, torch.Tensor] = {} From 55692ac5633d41d52ffd7ab6b008db771987af3f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 6 May 2025 02:59:55 -0700 Subject: [PATCH 15/15] fix precommit Signed-off-by: Chen Zhang --- vllm/v1/worker/tpu_model_runner.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 123c20b2f5a3..33a2116c2b49 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -591,9 +591,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1 logits_indices = logits_indices.to(self.device) - layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - ) + layer_names = get_layers_from_vllm_config(self.vllm_config, + Attention).keys() per_layer_attn_metadata = { layer_name: attn_metadata for layer_name in layer_names @@ -966,9 +965,8 @@ def _dummy_run(self, num_tokens: int) -> None: torch._dynamo.mark_dynamic(position_ids, 0) torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - ) + layer_names = get_layers_from_vllm_config(self.vllm_config, + Attention).keys() per_layer_attn_metadata = { layer_name: attn_metadata for layer_name in layer_names