
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
-                                        KVCacheTensor)
+from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
+                                        KVCacheSpec, KVCacheTensor)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.request import Request

@@ -449,15 +449,15 @@ def hash_request_tokens(block_size: int,


 def check_enough_kv_cache_memory(vllm_config: VllmConfig,
-                                 kv_cache_spec: KVCacheSpec,
+                                 kv_cache_spec: dict[str, KVCacheSpec],
                                  available_memory: int):
     """
     Checks whether `available_memory` is enough for the KV cache to hold at
     least one request with the model's max_model_len.

     Args:
         vllm_config: The global VllmConfig
-        kv_cache_spec: The kv cache spec of the model
+        kv_cache_spec: The kv cache spec of each attention layer in the model
         available_memory: Memory available for KV cache in bytes.

     Raises:
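The body of the check is outside this hunk, but conceptually it reduces to the comparison sketched below. This is a rough sketch only: `bytes_per_token` is a hypothetical stand-in for whatever per-layer size the KVCacheSpec reports, and the real implementation raises a much more detailed error message (see the next hunk).

    # Rough sketch of what check_enough_kv_cache_memory verifies
    # (hypothetical attribute name `bytes_per_token`):
    max_model_len = vllm_config.model_config.max_model_len
    needed_memory = sum(spec.bytes_per_token * max_model_len
                        for spec in kv_cache_spec.values())
    if needed_memory > available_memory:
        raise ValueError("Not enough KV cache memory to serve one request "
                         "with max_model_len tokens.")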
@@ -484,12 +484,43 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
             f"`max_model_len` when initializing the engine.")


-def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
+def create_kv_cache_group_specs(
+        kv_cache_spec: dict[str, KVCacheSpec],
+        grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
+    """
+    Create a KVCacheGroupSpec object for each KV cache group.
+    The layers in the same group must share the same
+    KVCacheSpec.
+
+    Args:
+        kv_cache_spec:
+            A mapping from each layer name to its corresponding KVCacheSpec.
+        grouped_layer_names:
+            A list of kv cache groups, where each element is a list of layer
+            names that belong to the same group and should share the same
+            KVCacheSpec.
+    Returns:
+        A list of KVCacheGroupSpec objects, one for each group.
+    """
+    kv_cache_groups = []
+    for layer_names_one_group in grouped_layer_names:
+        layer_spec = kv_cache_spec[layer_names_one_group[0]]
+        assert all(
+            kv_cache_spec[layer_name] == layer_spec
+            for layer_name in layer_names_one_group[1:]), (
+                "All layers in the same KV cache group must share the same "
+                "KVCacheSpec.")
+        kv_cache_groups.append(
+            KVCacheGroupSpec(layer_names_one_group, layer_spec))
+    return kv_cache_groups
+
+
+def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
     """
     Whether all layers in the given KVCacheSpec have the same type of KV cache.

     Args:
-        kv_cache_spec: The KVCacheSpec of the model
+        kv_cache_spec: The kv cache spec of each attention layer in the model

     Returns:
         True if all layers have the same type, False otherwise.
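For illustration, here is a minimal usage sketch of the new helper. The layer names are made up, and `spec` stands in for a concrete KVCacheSpec instance (e.g. a full-attention spec); its construction is omitted here.

    # Hypothetical example: two attention layers sharing one KV cache spec.
    spec = ...  # some concrete KVCacheSpec; construction omitted
    kv_cache_spec = {
        "model.layers.0.attn": spec,
        "model.layers.1.attn": spec,
    }
    groups = create_kv_cache_group_specs(
        kv_cache_spec,
        grouped_layer_names=[["model.layers.0.attn", "model.layers.1.attn"]])
    # -> one KVCacheGroupSpec covering both layers, carrying the shared spec.
    # Since every layer maps to the same spec, is_kv_cache_type_uniform(...)
    # would also return True for this kv_cache_spec.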
@@ -500,18 +531,16 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:


 def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
-                                      kv_cache_spec: KVCacheSpec,
-                                      available_memory: int,
-                                      num_layers: int) -> KVCacheConfig:
+                                      kv_cache_spec: dict[str, KVCacheSpec],
+                                      available_memory: int) -> KVCacheConfig:
     """
     Generates the KV cache configuration for a model with one type of KV cache.
     Divide the available memory equally among all layers.

     Args:
         vllm_config: The global VllmConfig
-        kv_cache_spec: The kv cache spec of the model
+        kv_cache_spec: The kv cache spec of each attention layer in the model
         available_memory: Memory available for KV cache in bytes.
-        num_layers: The number of layers in the model.

     Returns:
         The generated KVCacheConfig
@@ -521,7 +550,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     assert len(page_sizes) == 1
     page_size = page_sizes.pop()

-    num_blocks = int(available_memory // page_size // num_layers)
+    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
     num_blocks = max(num_blocks, 0)

     if vllm_config.cache_config.num_gpu_blocks_override is not None:
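To make the division above concrete, here is a worked example with hypothetical numbers:

    available_memory = 20 * 1024**3   # 20 GiB set aside for KV cache
    page_size = 2 * 1024**2           # bytes per block, per layer
    num_layers = 32                   # == len(kv_cache_spec)
    num_blocks = available_memory // page_size // num_layers   # -> 320
    # Each of the 32 layers gets 320 blocks of 2 MiB, using the full 20 GiB.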
@@ -541,48 +570,79 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                 max_model_len_str, max_concurrency)

     per_layer_size = page_size * num_blocks
+    # All layers have the same KV cache spec, so we create one kv cache group
+    # for all layers.
+    grouped_layer_names = [list(kv_cache_spec.keys())]

     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,
         tensors={
             layer_name: KVCacheTensor(size=per_layer_size)
             for layer_name in kv_cache_spec
         },
-        groups=[[layer_name for layer_name in kv_cache_spec]],
-        kv_cache_spec=kv_cache_spec)
+        kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
+                                                    grouped_layer_names),
+    )
     return kv_cache_config


-def get_kv_cache_configs(vllm_config: VllmConfig,
-                         kv_cache_specs: list[KVCacheSpec],
-                         available_memory: int) -> list[KVCacheConfig]:
+def get_kv_cache_config(vllm_config: VllmConfig,
+                        kv_cache_spec: dict[str, KVCacheSpec],
+                        available_memory: int) -> KVCacheConfig:
     """
     Generates the KV cache configuration for a model
     TODO: support hybrid models with more than one type of KV cache.

     Args:
         vllm_config: The global VllmConfig
-        kv_cache_specs: The kv cache specs of the model
+        kv_cache_spec: The kv cache spec of each attention layer in the model
         available_memory: Memory available for KV cache in bytes.

     Returns:
         The generated KVCacheConfigs
     """
-    # Use the max number of layers to conservatively determine
-    # the number of blocks.
-    num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
-    kv_cache_configs = []
-    for kv_cache_spec in kv_cache_specs:
-        check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
-                                     available_memory)
-        if is_kv_cache_type_uniform(kv_cache_spec):
-            # KV cache of all layers are the same, which is true for
-            # most models. Allocate the same amount of memory for
-            # each layer.
-            kv_cache_configs.append(
-                _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
-                                                  available_memory,
-                                                  num_layers))
-        else:
-            raise NotImplementedError
+    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
+    if is_kv_cache_type_uniform(kv_cache_spec):
+        # KV cache of all layers are the same, which is true for
+        # most models. Allocate the same amount of memory for
+        # each layer.
+        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
+                                                 available_memory)
+
+    raise NotImplementedError
+
+
+def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
+    """
+    Make the KV cache configurations for each worker consistent, so that all
+    workers can be controlled by the same KVCacheManager.
+    This function verifies that the KV cache groups of all workers are the same,
+    and changes the num_blocks of each worker to the smallest among all workers.
+
+    Args:
+        kv_cache_configs: The KV cache configurations for each worker. Will be
+            in-place modified to make them consistent.
+    """
+
+    # Sort the kv cache groups by the type_id of their KV cache spec.
+    # This can avoid the inconsistency caused by the order of groups.
+    for kv_cache_config in kv_cache_configs:
+        kv_cache_config.kv_cache_groups.sort(
+            key=lambda x: x.kv_cache_spec.type_id)
+
+    # Verify that the groups of each rank are the same.
+    for kv_cache_config in kv_cache_configs[1:]:
+        for group_rank_0, group_rank_i in zip(
+                kv_cache_configs[0].kv_cache_groups,
+                kv_cache_config.kv_cache_groups):
+            assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec
+
+    # Change the num_blocks of each rank to the smallest among all ranks. We
+    # do not need to shrink the tensor size because it is valid to only use the
+    # first `num_blocks` blocks of the tensor.
+    min_num_blocks = min(kv_cache_config.num_blocks
+                         for kv_cache_config in kv_cache_configs)
+    for kv_cache_config in kv_cache_configs:
+        kv_cache_config.num_blocks = min_num_blocks
+
     return kv_cache_configs
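A small sketch of the intended flow across workers follows. Here `cfg_rank0` and `cfg_rank1` stand for KVCacheConfig objects produced by get_kv_cache_config on two different workers, and the block counts are hypothetical:

    # Hypothetical: rank 1 computed slightly fewer blocks (e.g. less free VRAM).
    # cfg_rank0.num_blocks == 12000
    # cfg_rank1.num_blocks == 11500
    kv_cache_configs = [cfg_rank0, cfg_rank1]
    unify_kv_cache_configs(kv_cache_configs)
    # Both configs now agree on the smaller value, so a single KVCacheManager
    # can drive both workers:
    assert cfg_rank0.num_blocks == cfg_rank1.num_blocks == 11500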