                         GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                         check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -157,9 +158,12 @@ def __init__(
         # Sampler
         self.sampler = Sampler()
 
-        # Lazy initialization
+        # Lazy initializations
         # self.model: nn.Module  # Set after load_model
+        # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
+        # self.kv_cache_config: KVCacheConfig
+
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
 
@@ -488,7 +492,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
+    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
                Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -585,20 +589,39 @@ def _prepare_inputs(
             self.positions_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
 
-        # Prepare for cascade attention if enabled & beneficial.
-        common_prefix_len = 0
-        if self.cascade_attn_enabled:
-            common_prefix_len = self._compute_cascade_attn_prefix_len(
-                num_scheduled_tokens,
-                scheduler_output.num_common_prefix_blocks,
-            )
+        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
+            self.device, non_blocking=True)
+        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
+                                                   non_blocking=True)
+        common_attn_metadata = CommonAttentionMetadata(
+            query_start_loc=query_start_loc, seq_lens=seq_lens)
+
+        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        # Prepare the attention metadata for each KV cache group and make layers
+        # in the same group share the same metadata.
+        # NOTE(Chen): there is exactly one KV cache group that contains all
+        # attention layers in the model for now, so the current logic for
+        # getting attn_metadata is not related to kv_cache_group information.
+        # Will extend this part to support multiple KV cache groups later.
+        for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+
+            # Prepare for cascade attention if enabled & beneficial.
+            common_prefix_len = 0
+            if self.cascade_attn_enabled:
+                common_prefix_len = self._compute_cascade_attn_prefix_len(
+                    num_scheduled_tokens,
+                    scheduler_output.num_common_prefix_blocks,
+                )
 
-        attn_metadata = self.attn_metadata_builder.build(
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            common_prefix_len=common_prefix_len,
-        )
+            attn_metadata_i = self.attn_metadata_builder.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata)
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_metadata[layer_name] = attn_metadata_i
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
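The loop above builds the attention metadata once per KV cache group and then fans the same object out to every layer name in that group. Below is a minimal standalone sketch of that fan-out pattern; `KVGroup`, `SimpleMetadata`, and `build_group_metadata` are simplified stand-ins for illustration, not vLLM's `KVCacheGroupSpec` or the real metadata builder.

```python
from dataclasses import dataclass, field


@dataclass
class SimpleMetadata:
    # Simplified stand-in for the per-group attention metadata.
    num_actual_tokens: int
    common_prefix_len: int


@dataclass
class KVGroup:
    # Simplified stand-in for a KV cache group spec: just the layers it covers.
    layer_names: list[str] = field(default_factory=list)


def build_group_metadata(num_actual_tokens: int,
                         common_prefix_len: int) -> SimpleMetadata:
    # Stand-in for the attn_metadata_builder.build(...) call in the hunk above.
    return SimpleMetadata(num_actual_tokens, common_prefix_len)


def fan_out(groups: list[KVGroup], num_tokens: int) -> dict[str, SimpleMetadata]:
    attn_metadata: dict[str, SimpleMetadata] = {}
    for group in groups:
        # One (potentially expensive) metadata build per group, not per layer.
        metadata = build_group_metadata(num_tokens, common_prefix_len=0)
        for layer_name in group.layer_names:
            # Every layer in the group shares the same metadata object.
            attn_metadata[layer_name] = metadata
    return attn_metadata


groups = [KVGroup(["model.layers.0.self_attn", "model.layers.1.self_attn"])]
meta = fan_out(groups, num_tokens=32)
assert meta["model.layers.0.self_attn"] is meta["model.layers.1.self_attn"]
```

Downstream code can then index the metadata by layer name while the build itself still happens only once per group.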
@@ -608,7 +631,7 @@ def _prepare_inputs(
             # from these partial requests, we do so for simplicity.
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
-            logits_indices = attn_metadata.query_start_loc[1:] - 1
+            logits_indices = query_start_loc[1:] - 1
             spec_decode_metadata = None
         else:
             # Get the number of draft tokens for each request.
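With `query_start_loc` now prepared as a local tensor in `_prepare_inputs`, the logits indices are derived from it directly rather than from a single backend's metadata. A quick worked example of the index arithmetic (the request sizes are made up):

```python
import torch

# query_start_loc holds num_reqs + 1 cumulative offsets; request i owns the
# token range [query_start_loc[i], query_start_loc[i + 1]).
query_start_loc = torch.tensor([0, 3, 5, 9])  # 3 requests with 3, 2, 4 tokens

# The last scheduled token of each request sits one position before the next
# offset, which is what the hunk above computes.
logits_indices = query_start_loc[1:] - 1
assert logits_indices.tolist() == [2, 4, 8]
```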
@@ -1230,6 +1253,7 @@ def execute_model(
             next_token_ids = torch.tensor(next_token_ids,
                                           dtype=torch.int32,
                                           device=self.device)
+            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
 
             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
@@ -1241,8 +1265,8 @@ def execute_model(
                         dim=-1)
                 else:
                     target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
+                target_slot_mapping = eagle_attn_metadata.slot_mapping
+                cu_num_tokens = eagle_attn_metadata.query_start_loc
             else:
                 # TODO(woosuk): Refactor this.
                 num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -1256,7 +1280,7 @@ def execute_model(
                     device=self.device,
                 )
                 cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
+                    eagle_attn_metadata.query_start_loc,
                     num_rejected_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
@@ -1266,7 +1290,8 @@ def execute_model(
                         [h[token_indices] for h in aux_hidden_states], dim=-1)
                 else:
                     target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                    token_indices]
 
             draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
@@ -1275,7 +1300,7 @@ def execute_model(
                 target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids,
                 cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_table,
+                block_table=eagle_attn_metadata.block_table,
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
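Because `attn_metadata` is now a dict keyed by layer name, the EAGLE drafter path first selects the entry for its own attention layer and reads the per-layer fields (`slot_mapping`, `query_start_loc`, `block_table`) from that entry. A hedged sketch of the lookup and the rejected-token gather, using a made-up layer name and a simplified stand-in for `FlashAttentionMetadata`:

```python
from dataclasses import dataclass

import torch


@dataclass
class PerLayerMeta:
    # Simplified stand-in for the per-layer attention metadata.
    slot_mapping: torch.Tensor
    query_start_loc: torch.Tensor


# Hypothetical layer name; in the diff it comes from self.drafter.attn_layer_name.
drafter_layer_name = "model.layers.31.self_attn"

attn_metadata = {
    drafter_layer_name: PerLayerMeta(
        slot_mapping=torch.arange(10),
        query_start_loc=torch.tensor([0, 4, 10]),
    ),
}

eagle_attn_metadata = attn_metadata[drafter_layer_name]

# After rejection sampling only the surviving token positions are kept, so the
# per-layer fields are gathered with those indices.
token_indices = torch.tensor([0, 1, 2, 4, 5, 6, 7, 8])
target_slot_mapping = eagle_attn_metadata.slot_mapping[token_indices]
cu_num_tokens = eagle_attn_metadata.query_start_loc

print(target_slot_mapping.tolist())  # [0, 1, 2, 4, 5, 6, 7, 8]
```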
@@ -1708,6 +1733,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             raise NotImplementedError(
                 "Hybrid models with more than one KV cache type are not "
                 "supported yet.")
+        self.kv_cache_config = kv_cache_config
 
         kv_caches: dict[str, torch.Tensor] = {}
 
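Caching `kv_cache_config` on the runner here is what lets `_prepare_inputs` iterate `kv_cache_groups` on every step. A rough sketch of that ordering constraint, with toy classes in place of vLLM's runner and config (all names below are illustrative only):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyGroup:
    layer_names: list[str]


@dataclass
class ToyKVCacheConfig:
    kv_cache_groups: list[ToyGroup] = field(default_factory=list)


class ToyRunner:
    def __init__(self) -> None:
        # Mirrors the "lazy initialization" comments above: the config only
        # exists after initialize_kv_cache has run.
        self.kv_cache_config: Optional[ToyKVCacheConfig] = None

    def initialize_kv_cache(self, kv_cache_config: ToyKVCacheConfig) -> None:
        self.kv_cache_config = kv_cache_config

    def prepare_inputs(self) -> dict[str, int]:
        assert self.kv_cache_config is not None, (
            "initialize_kv_cache must run before the first forward pass")
        return {
            name: group_id
            for group_id, group in enumerate(self.kv_cache_config.kv_cache_groups)
            for name in group.layer_names
        }


runner = ToyRunner()
runner.initialize_kv_cache(ToyKVCacheConfig([ToyGroup(["layer.0", "layer.1"])]))
print(runner.prepare_inputs())  # {'layer.0': 0, 'layer.1': 0}
```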