Commit d0ad8d0

[v1] Refactor KVCacheConfig (vllm-project#14079)

heheda12345 authored and lulmer committed

Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>

1 parent b23436d, commit d0ad8d0

10 files changed: +320 additions, -112 deletions

tests/v1/core/test_kv_cache_utils.py

Lines changed: 109 additions & 1 deletion
@@ -1,14 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
+import torch
 
 from vllm.multimodal.inputs import MultiModalKwargs
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                          KVCacheBlock, PrefixCachingMetrics,
                                          generate_block_hash_extra_keys,
                                          hash_block_tokens,
-                                         hash_request_tokens)
+                                         hash_request_tokens,
+                                         unify_kv_cache_configs)
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request
 
@@ -314,3 +318,107 @@ def stats(requests, queries, hits):
     assert metrics.aggregated_query_total == 0
     assert metrics.aggregated_query_hit == 0
     assert not metrics.query_queue
+
+
+def test_unify_kv_cache_configs():
+
+    def new_kv_cache_spec(block_size=16,
+                          num_kv_heads=2,
+                          head_size=64,
+                          dtype=torch.float32,
+                          use_mla=False):
+        return FullAttentionSpec(block_size=block_size,
+                                 num_kv_heads=num_kv_heads,
+                                 head_size=head_size,
+                                 dtype=dtype,
+                                 use_mla=use_mla)
+
+    same_kv_cache_config = [
+        KVCacheConfig(
+            num_blocks=10,
+            tensors={
+                "layer1": KVCacheTensor(100),
+                "layer2": KVCacheTensor(100),
+            },
+            kv_cache_groups=[
+                KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
+                KVCacheGroupSpec(["layer2"],
+                                 new_kv_cache_spec(num_kv_heads=4)),
+            ],
+        ),
+        KVCacheConfig(
+            num_blocks=20,
+            tensors={
+                "layer1": KVCacheTensor(100),
+                "layer2": KVCacheTensor(100),
+            },
+            kv_cache_groups=[
+                KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
+                KVCacheGroupSpec(["layer2"],
+                                 new_kv_cache_spec(num_kv_heads=4)),
+            ],
+        ),
+    ]
+    unify_kv_cache_configs(same_kv_cache_config)
+    assert same_kv_cache_config[0].num_blocks == 10
+    assert same_kv_cache_config[1].num_blocks == 10
+
+    need_sort_kv_cache_config = [
+        KVCacheConfig(
+            num_blocks=10,
+            tensors={
+                "layer1": KVCacheTensor(100),
+                "layer2": KVCacheTensor(100),
+            },
+            kv_cache_groups=[
+                KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
+                KVCacheGroupSpec(["layer2"],
+                                 new_kv_cache_spec(num_kv_heads=4)),
+            ],
+        ),
+        KVCacheConfig(
+            num_blocks=20,
+            tensors={
+                "layer1": KVCacheTensor(100),
+                "layer2": KVCacheTensor(100),
+            },
+            kv_cache_groups=[
+                KVCacheGroupSpec(["layer2"],
+                                 new_kv_cache_spec(num_kv_heads=4)),
+                KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
+            ],
+        ),
+    ]
+
+    unify_kv_cache_configs(need_sort_kv_cache_config)
+    assert need_sort_kv_cache_config[0].num_blocks == 10
+    assert need_sort_kv_cache_config[1].num_blocks == 10
+
+    diff_kv_cache_config = [
+        KVCacheConfig(
+            num_blocks=10,
+            tensors={
+                "layer1": KVCacheTensor(100),
+                "layer2": KVCacheTensor(100),
+            },
+            kv_cache_groups=[
+                KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
+                KVCacheGroupSpec(["layer2"],
+                                 new_kv_cache_spec(num_kv_heads=4)),
+            ],
+        ),
+        KVCacheConfig(
+            num_blocks=20,
+            tensors={
+                "layer1": KVCacheTensor(100),
+                "layer2": KVCacheTensor(100),
+            },
+            kv_cache_groups=[
+                KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
+                KVCacheGroupSpec(["layer2"],
+                                 new_kv_cache_spec(num_kv_heads=8)),
+            ],
+        ),
+    ]
+    with pytest.raises(AssertionError):
+        unify_kv_cache_configs(diff_kv_cache_config)

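The new test above exercises `unify_kv_cache_configs` end to end through the real vLLM classes. For readers who only want the unification rule itself, here is a minimal standalone sketch of the same behavior; the `GroupSpec`/`Config` dataclasses and the `"full.16"` type id below are illustrative stand-ins, not the real interfaces in `vllm.v1.kv_cache_interface`.

```python
from dataclasses import dataclass, field


@dataclass
class GroupSpec:
    layer_names: list[str]
    type_id: str  # stand-in for kv_cache_spec.type_id


@dataclass
class Config:
    num_blocks: int
    groups: list[GroupSpec] = field(default_factory=list)


def unify(configs: list[Config]) -> None:
    # Sort groups so that ordering differences between workers do not matter.
    for cfg in configs:
        cfg.groups.sort(key=lambda g: g.type_id)
    # Every worker must expose the same sequence of group types.
    for cfg in configs[1:]:
        assert [g.type_id for g in cfg.groups] == \
            [g.type_id for g in configs[0].groups]
    # All workers are clamped to the smallest block count.
    min_blocks = min(cfg.num_blocks for cfg in configs)
    for cfg in configs:
        cfg.num_blocks = min_blocks


configs = [
    Config(10, [GroupSpec(["layer1"], "full.16")]),
    Config(20, [GroupSpec(["layer1"], "full.16")]),
]
unify(configs)
assert all(cfg.num_blocks == 10 for cfg in configs)
```
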
vllm/v1/core/kv_cache_utils.py

Lines changed: 95 additions & 35 deletions
@@ -7,8 +7,8 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
-                                        KVCacheTensor)
+from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
+                                        KVCacheSpec, KVCacheTensor)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request
 
@@ -449,15 +449,15 @@ def hash_request_tokens(block_size: int,
 
 
 def check_enough_kv_cache_memory(vllm_config: VllmConfig,
-                                 kv_cache_spec: KVCacheSpec,
+                                 kv_cache_spec: dict[str, KVCacheSpec],
                                  available_memory: int):
     """
     Checks whether `available_memory` is enough for the KV cache to hold at
     least one request with the model's max_model_len.
 
     Args:
         vllm_config: The global VllmConfig
-        kv_cache_spec: The kv cache spec of the model
+        kv_cache_spec: The kv cache spec of each attention layer in the model
         available_memory: Memory available for KV cache in bytes.
 
     Raises:
@@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
         f"`max_model_len` when initializing the engine.")
 
 
-def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
+def create_kv_cache_group_specs(
+        kv_cache_spec: dict[str, KVCacheSpec],
+        grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
+    """
+    Create KVCacheGroupSpec object for each kv cache group layer.
+    The layers in the same group should share the same
+    KVCacheSpec.
+
+    Args:
+        kv_cache_spec:
+            A mapping from each layer name to its corresponding KVCacheSpec.
+        grouped_layer_names:
+            A list of kv cache groups, where each element is a list of layer
+            names that belong to the same group and should share the same
+            KVCacheSpec.
+    Returns:
+        A list of KVCacheGroupSpec objects, one for each group.
+    """
+    kv_cache_groups = []
+    for layer_names_one_group in grouped_layer_names:
+        layer_spec = kv_cache_spec[layer_names_one_group[0]]
+        assert all(
+            kv_cache_spec[layer_name] == layer_spec
+            for layer_name in layer_names_one_group[1:]), (
+                "All layers in the same KV cache group must share the same "
+                "KVCacheSpec.")
+        kv_cache_groups.append(
+            KVCacheGroupSpec(layer_names_one_group, layer_spec))
+    return kv_cache_groups
+
+
+def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     """
     Whether all layers in the given KVCacheSpec have the same type of KV cache.
 
     Args:
-        kv_cache_spec: The KVCacheSpec of the model
+        kv_cache_spec: The kv cache spec of each attention layer in the model
 
     Returns:
         True if all layers have the same type, False otherwise.
@@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
 
 
 def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
-                                      kv_cache_spec: KVCacheSpec,
-                                      available_memory: int,
-                                      num_layers: int) -> KVCacheConfig:
+                                      kv_cache_spec: dict[str, KVCacheSpec],
+                                      available_memory: int) -> KVCacheConfig:
     """
     Generates the KV cache configuration for a model with one type of KV cache.
     Divide the available memory equally among all layers.
 
     Args:
         vllm_config: The global VllmConfig
-        kv_cache_spec: The kv cache spec of the model
+        kv_cache_spec: The kv cache spec of each attention layer in the model
         available_memory: Memory available for KV cache in bytes.
-        num_layers: The number of layers in the model.
 
     Returns:
         The generated KVCacheConfig
@@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     assert len(page_sizes) == 1
     page_size = page_sizes.pop()
 
-    num_blocks = int(available_memory // page_size // num_layers)
+    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
     num_blocks = max(num_blocks, 0)
 
     if vllm_config.cache_config.num_gpu_blocks_override is not None:
@@ -541,48 +570,79 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                     max_model_len_str, max_concurrency)
 
     per_layer_size = page_size * num_blocks
+    # All layers have the same KV cache spec, so we create one kv cache group
+    # for all layers.
+    grouped_layer_names = [list(kv_cache_spec.keys())]
 
     kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        tensors={
            layer_name: KVCacheTensor(size=per_layer_size)
            for layer_name in kv_cache_spec
        },
-       groups=[[layer_name for layer_name in kv_cache_spec]],
-       kv_cache_spec=kv_cache_spec)
+       kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
+                                                   grouped_layer_names),
+    )
     return kv_cache_config
 
 
-def get_kv_cache_configs(vllm_config: VllmConfig,
-                         kv_cache_specs: list[KVCacheSpec],
-                         available_memory: int) -> list[KVCacheConfig]:
+def get_kv_cache_config(vllm_config: VllmConfig,
+                        kv_cache_spec: dict[str, KVCacheSpec],
+                        available_memory: int) -> KVCacheConfig:
     """
     Generates the KV cache configuration for a model
     TODO: support hybrid models with more than one type of KV cache.
 
     Args:
         vllm_config: The global VllmConfig
-        kv_cache_specs: The kv cache specs of the model
+        kv_cache_spec: The kv cache spec of each attention layer in the model
         available_memory: Memory available for KV cache in bytes.
 
     Returns:
         The generated KVCacheConfigs
     """
-    # Use the max number of layers to conservatively determine
-    # the number of blocks.
-    num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
-    kv_cache_configs = []
-    for kv_cache_spec in kv_cache_specs:
-        check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
-                                     available_memory)
-        if is_kv_cache_type_uniform(kv_cache_spec):
-            # KV cache of all layers are the same, which is true for
-            # most models. Allocate the same amount of memory for
-            # each layer.
-            kv_cache_configs.append(
-                _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
-                                                  available_memory,
-                                                  num_layers))
-        else:
-            raise NotImplementedError
+    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
+    if is_kv_cache_type_uniform(kv_cache_spec):
+        # KV cache of all layers are the same, which is true for
+        # most models. Allocate the same amount of memory for
+        # each layer.
+        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
+                                                 available_memory)
+
+    raise NotImplementedError
+
+
+def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
+    """
+    Make the KV cache configurations for each worker consistent, so that all
+    workers can be controlled by the same KVCacheManager.
+    This function verifies that the layer group of each worker are the same,
+    and changes the num_blocks of each worker to the smallest among all workers.
+
+    Args:
+        kv_cache_configs: The KV cache configurations for each worker. Will be
+            in-place modified to make them consistent.
+    """
+
+    # Sort the kv cache groups by the type_id of their KV cache spec.
+    # This can avoid the inconsistency caused by the order of groups.
+    for kv_cache_config in kv_cache_configs:
+        kv_cache_config.kv_cache_groups.sort(
+            key=lambda x: x.kv_cache_spec.type_id)
+
+    # Verify that the groups of each rank are the same.
+    for kv_cache_config in kv_cache_configs[1:]:
+        for group_rank_0, group_rank_i in zip(
+                kv_cache_configs[0].kv_cache_groups,
+                kv_cache_config.kv_cache_groups):
+            assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec
+
+    # Change the num_blocks of each rank to the smallest among all ranks. We
+    # do not need to shrink the tensor size because it is valid to only use the
+    # first `num_blocks` blocks of the tensor.
+    min_num_blocks = min(kv_cache_config.num_blocks
+                         for kv_cache_config in kv_cache_configs)
+    for kv_cache_config in kv_cache_configs:
+        kv_cache_config.num_blocks = min_num_blocks
 
     return kv_cache_configs

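The sizing logic in the refactored `_get_kv_cache_config_uniform_type` divides the available memory by the per-block page size and by the number of layers, then gives every layer a tensor of `page_size * num_blocks` bytes. A small worked example with assumed shapes follows; the page-size formula used here (two planes, K and V, per block) is the usual full-attention estimate and is an assumption, not something this diff states.

```python
# Assumed shapes for the example (not taken from this diff):
block_size = 16           # tokens per KV cache block
num_kv_heads = 8
head_size = 128
dtype_size = 2            # bytes per element for fp16/bf16
num_layers = 32           # len(kv_cache_spec), one entry per attention layer
available_memory = 20 * 1024**3   # 20 GiB reserved for the KV cache

# Per-block page size for full attention: K and V planes for one block.
page_size = 2 * block_size * num_kv_heads * head_size * dtype_size   # 65536 B
num_blocks = int(available_memory // page_size // num_layers)        # 10240
per_layer_size = page_size * num_blocks                              # 640 MiB

print(page_size, num_blocks, per_layer_size)
# 65536 10240 671088640 -> each layer gets a 640 MiB tensor, and the cache can
# hold num_blocks * block_size = 163840 tokens across all sequences.
```
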
vllm/v1/engine/core.py

Lines changed: 22 additions & 9 deletions
@@ -21,7 +21,8 @@
     maybe_register_config_serialize_by_value)
 from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
                         zmq_socket_ctx)
-from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
+from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
+                                         unify_kv_cache_configs)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
@@ -120,15 +121,27 @@ def _initialize_kv_caches(self,
         # memory can be allocated for kv cache.
         available_gpu_memory = self.model_executor.determine_available_memory()
 
+        assert len(kv_cache_specs) == len(available_gpu_memory)
         # Get the kv cache tensor size
-        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
-                                                available_gpu_memory)
-        num_gpu_blocks_set = set(config.num_blocks
-                                 for config in kv_cache_configs)
-        assert len(num_gpu_blocks_set) == 1, (
-            f"num_gpu_blocks need to be the same across workers, "
-            f"but they are different: {num_gpu_blocks_set}")
-        num_gpu_blocks = num_gpu_blocks_set.pop()
+        kv_cache_configs = [
+            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
+                                available_gpu_memory_one_worker)
+            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
+            zip(kv_cache_specs, available_gpu_memory)
+        ]
+
+        # Since we use a shared centralized controller, we need the
+        # `kv_cache_config` to be consistent across all workers to make sure
+        # all the memory operators can be applied to all workers.
+        unify_kv_cache_configs(kv_cache_configs)
+
+        # All workers have the same kv_cache_config except layer names, so use
+        # an arbitrary one to get the number of blocks.
+        assert all([
+            cfg.num_blocks == kv_cache_configs[0].num_blocks
+            for cfg in kv_cache_configs
+        ])
+        num_gpu_blocks = kv_cache_configs[0].num_blocks
         num_cpu_blocks = 0
 
         # Initialize kv cache and warmup the execution