@@ -3023,40 +3023,40 @@ def _reshape_kv_cache_tensors(
                     raise NotImplementedError
 
         if has_attn and has_mamba:
-            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
-                                                       kv_cache_raw_tensors)
+            self._update_hybrid_attention_mamba_layout(kv_caches)
 
         return kv_caches
 
-    def _verify_hybrid_attention_mamba_layout(
-            self, kv_cache_config: KVCacheConfig,
-            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
+    def _update_hybrid_attention_mamba_layout(
+            self, kv_caches: dict[str, torch.Tensor]) -> None:
         """
-        Verify that the KV cache memory layout is compatible for
-        models with both attention and mamba KV cache groups.
+        Update the layout of attention layers from (2, num_blocks, ...) to
+        (num_blocks, 2, ...).
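+        Hybrid models with both attention and mamba KV cache groups
+        require this layout.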
 
         Args:
-            kv_cache_config: The KV cache config
-            kv_cache_raw_tensors: The KV cache buffer of each layer.
+            kv_caches: The KV cache buffer of each layer.
         """
 
         for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
             for layer_name in group.layer_names:
-                raw_tensor = kv_cache_raw_tensors[layer_name]
-                num_blocks = (raw_tensor.numel() //
-                              kv_cache_spec.page_size_bytes)
-                if isinstance(kv_cache_spec, AttentionSpec):
-
-                    kv_cache_shape = group.backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                    if kv_cache_shape[0] != num_blocks or kv_cache_shape[
-                            1] != 2:
-                        raise ValueError(
-                            "Hybrid models in V1 require an attention "
-                            "backend with kv_cache_shape="
-                            "(num_blocks, 2, ...). Please try setting "
-                            "VLLM_ATTENTION_BACKEND=FLASHINFER")
+                kv_cache = kv_caches[layer_name]
+                if (isinstance(kv_cache_spec, AttentionSpec)
+                        and kv_cache.shape[0] == 2):
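+                    # A leading dim of 2 indicates the (2, num_blocks, ...)
+                    # layout; it is ambiguous only when num_blocks is also 2.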
+                    assert kv_cache.shape[1] != 2, \
+                        "Failed to determine whether the layout is " \
+                        "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
+                        f"a tensor of shape {kv_cache.shape}"
+                    hidden_size = kv_cache.shape[2:].numel()
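+                    # Keep the logical (2, num_blocks, ...) shape but remap
+                    # the strides so each block's K and V pages are adjacent
+                    # in the underlying buffer, i.e. (num_blocks, 2, ...).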
+                    kv_cache.as_strided_(size=kv_cache.shape,
+                                         stride=(hidden_size, 2 * hidden_size,
+                                                 *kv_cache.stride()[2:]))
 
     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
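
A minimal standalone sketch of the stride swap that the new
_update_hybrid_attention_mamba_layout performs (plain PyTorch; the sizes
below are toy values chosen only for illustration). Indexing through the
original (2, num_blocks, ...) shape is unchanged, while the underlying
storage now follows the (num_blocks, 2, ...) order:

import torch

# Toy sizes, assumed for illustration: 4 blocks, K/V pages of 3 elements.
num_blocks, hidden_size = 4, 3
kv_cache = torch.zeros(2, num_blocks, hidden_size)

# The same in-place stride swap as in the diff above: the logical shape
# stays (2, num_blocks, hidden_size), but the storage is now ordered as
# (num_blocks, 2, hidden_size).
kv_cache.as_strided_(size=kv_cache.shape,
                     stride=(hidden_size, 2 * hidden_size,
                             *kv_cache.stride()[2:]))

kv_cache[0, 1].fill_(1.0)  # write block 1's K page
kv_cache[1, 1].fill_(2.0)  # write block 1's V page

# Viewing the same storage as (num_blocks, 2, hidden_size) shows that each
# block now owns one contiguous chunk: its K page followed by its V page.
blocks = kv_cache.as_strided((num_blocks, 2, hidden_size),
                             (2 * hidden_size, hidden_size, 1))
assert blocks.is_contiguous()
assert torch.equal(blocks[1, 0], torch.ones(hidden_size))
assert torch.equal(blocks[1, 1], torch.full((hidden_size,), 2.0))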