
Commit f88c974

heheda12345 authored and zhewenl committed
Support FlashAttention Backend for Hybrid SSM Models (vllm-project#23299)
Signed-off-by: Chen Zhang <[email protected]>
1 parent f664cd9 commit f88c974

2 files changed (+17, -27 lines)

tests/models/language/generation/test_hybrid.py

Lines changed: 0 additions & 3 deletions
@@ -110,9 +110,6 @@ def test_models(
     if model in V1_SUPPORTED_MODELS:
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
-            if model in HYBRID_MODELS:
-                # required due to reorder_batch behaviour
-                m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
                              enable_prefix_caching=False) as vllm_model:

vllm/v1/worker/gpu_model_runner.py

Lines changed: 17 additions & 24 deletions
@@ -3023,40 +3023,33 @@ def _reshape_kv_cache_tensors(
                     raise NotImplementedError

         if has_attn and has_mamba:
-            self._verify_hybrid_attention_mamba_layout(kv_cache_config,
-                                                       kv_cache_raw_tensors)
+            self._update_hybrid_attention_mamba_layout(kv_caches)

         return kv_caches

-    def _verify_hybrid_attention_mamba_layout(
-            self, kv_cache_config: KVCacheConfig,
-            kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
+    def _update_hybrid_attention_mamba_layout(
+            self, kv_caches: dict[str, torch.Tensor]) -> None:
         """
-        Verify that the KV cache memory layout is compatible for
-        models with both attention and mamba KV cache groups.
+        Update the layout of attention layers from (2, num_blocks, ...) to
+        (num_blocks, 2, ...).

         Args:
-            kv_cache_config: The KV cache config
-            kv_cache_raw_tensors: The KV cache buffer of each layer.
+            kv_caches: The KV cache buffer of each layer.
         """

         for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
             for layer_name in group.layer_names:
-                raw_tensor = kv_cache_raw_tensors[layer_name]
-                num_blocks = (raw_tensor.numel() //
-                              kv_cache_spec.page_size_bytes)
-                if isinstance(kv_cache_spec, AttentionSpec):
-
-                    kv_cache_shape = group.backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                    if kv_cache_shape[0] != num_blocks or kv_cache_shape[
-                            1] != 2:
-                        raise ValueError(
-                            "Hybrid models in V1 require an attention "
-                            "backend with kv_cache_shape="
-                            "(num_blocks, 2, ...). Please try setting "
-                            "VLLM_ATTENTION_BACKEND=FLASHINFER")
+                kv_cache = kv_caches[layer_name]
+                if (isinstance(kv_cache_spec, AttentionSpec)
+                        and kv_cache.shape[0] == 2):
+                    assert kv_cache.shape[1] != 2, \
+                        "Fail to determine whether the layout is " \
+                        "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
+                        f"a tensor of shape {kv_cache.shape}"
+                    hidden_size = kv_cache.shape[2:].numel()
+                    kv_cache.as_strided_(size=kv_cache.shape,
+                                         stride=(hidden_size, 2 * hidden_size,
+                                                 *kv_cache.stride()[2:]))

     def initialize_kv_cache_tensors(
             self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
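To make the new _update_hybrid_attention_mamba_layout easier to follow, here is a minimal, self-contained sketch (not part of the commit; the tensor sizes are invented for illustration) of the same as_strided_ re-striding trick: the attention KV cache keeps the logical (2, num_blocks, ...) shape that the FlashAttention backend expects, while its memory is reinterpreted block-major so that each block's K and V pages sit next to each other, matching the per-block layout shared with the mamba cache.

import torch

# Illustrative sizes only (not vLLM defaults).
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
hidden_size = block_size * num_kv_heads * head_size

# Raw buffer viewed contiguously as (2, num_blocks, block_size, heads, head),
# i.e. the shape the attention backend reports for its KV cache.
raw = torch.arange(2 * num_blocks * hidden_size, dtype=torch.float32)
kv_cache = raw.view(2, num_blocks, block_size, num_kv_heads, head_size)

# Re-stride in place: dim 0 (K vs. V) now advances by hidden_size elements and
# dim 1 (block index) by 2 * hidden_size, so memory is ordered as
# [K_block0, V_block0, K_block1, V_block1, ...] while the shape is unchanged.
kv_cache.as_strided_(size=kv_cache.shape,
                     stride=(hidden_size, 2 * hidden_size,
                             *kv_cache.stride()[2:]))

# K of block 1 now starts 2 * hidden_size elements into the buffer, and
# V of block 1 follows immediately after it: the block is contiguous.
assert kv_cache[0, 1].flatten()[0].item() == 2 * hidden_size
assert kv_cache[1, 1].flatten()[0].item() == 3 * hidden_size

This sketch also shows why the commit asserts kv_cache.shape[1] != 2: if num_blocks itself were 2, a (2, 2, ...) tensor could not be distinguished from an already block-major layout by shape alone.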
