Introduction and Background
In AIBrix, we have implemented a Distributed KV Cache to support high-capacity, cross-engine KV reuse. Our integration with vLLM aligns with the KV transfer connector framework introduced by PR #10502. Given these similarities and the recent vLLM v1 refactors, we have started refactoring our Distributed KV Cache feature to contribute common functionalities back to the upstream vLLM project. Our goal is to facilitate seamless support for KV cache offloading and cross-engine KV reuse across different cache backends.
This RFC aims to enhance vLLM's capability to offload KV cache to external KV cache services by extending the KV transfer connector framework, enabling more memory-efficient and scalable inference workloads.
Motivation
Although the current KV transfer connector framework provides a general way to offload KV cache to different cache backends, several common functionalities needed for cross-engine KV reuse are not yet covered:
Tensor Parallelism Aware Management: When vLLM uses tensor parallelism, each participating vLLM instance fetches KV tensors independently from the cache backend. In case of cache misses, before proceeding with prefill computation, participants must align the potentially different numbers of KV tensors fetched from the external KV cache service to ensure a consistent view.
Embedded Cache w/ CPU Memory: To meet performance requirements, it's common to have a small CPU memory-based cache embedded in the engine to avoid frequently accessing remote cache backends.
Selective KV Cache Offloading: Enables fine-grained control over offloading strategies, which is crucial for optimizing performance across diverse deployment environments:
- Many cloud providers and companies deploy lower-end GPU instances without high-speed interconnects like RDMA, suited for tasks related to 7B/8B models running on 24/32GiB GPU cards. In these setups, GPUs within the same instance (typically 8-16 GPUs) share a single VPC NIC, leading to significant network bandwidth contention. Selective KV cache offloading (e.g., only offloading KV tensors identified by the employed eviction policy as hot rather than offloading all KV tensors) helps mitigate this issue by reducing unnecessary data transfers and conserving limited network bandwidth.
- Even in high-performance environments with RDMA-equipped GPUs, selective KV cache offloading can enhance efficiency by limiting the PCIe bandwidth consumed by remote data movement. While RDMA enables low-latency, high-bandwidth communication, remote data access still incurs higher latency than local memory access. By leveraging selective KV offloading, the framework reduces the frequency of remote data transfers, preserving PCIe bandwidth and ensuring that local memory access remains the preferred data pathway.
To achieve selective KV cache offloading, we introduce an eviction policy layer that can be extended and customized with advanced offloading strategies to determine which KV tensors should be offloaded. Within this layer, multiple callbacks are available to support different offloading modes, including offloading all KV tensors, only hot KV tensors, or only cold KV tensors, with the definition of "hot" and "cold" being determined by the specific eviction policy in use. In this initial PR, the framework will provide built-in support for LRU, FIFO, and S3FIFO eviction policies.
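For illustration, the offloading modes could be wired to an eviction policy via its callbacks roughly as follows (a sketch only; IngestionMode and wire_offloading are illustrative names, while the set_on_*_callback methods match the BaseEvictionPolicy API proposed below):
# Illustrative sketch: IngestionMode and wire_offloading are assumed names;
# the set_on_*_callback methods are those of the BaseEvictionPolicy API below.
from enum import Enum, auto


class IngestionMode(Enum):
    ALL = auto()      # offload every KV block written to L1
    HOT = auto()      # offload only blocks the policy marks as hot
    EVICTED = auto()  # offload only blocks evicted from L1


def wire_offloading(policy, l2_put, mode: IngestionMode) -> None:
    """Attach the L2 offloading function to the matching policy callback.

    policy: an eviction policy exposing the set_on_*_callback methods.
    l2_put: a callable (key, value) -> None that writes one block to L2Cache.
    """
    if mode == IngestionMode.ALL:
        policy.set_on_put_callback(l2_put)
    elif mode == IngestionMode.HOT:
        policy.set_on_hot_access_callback(l2_put)
    else:  # IngestionMode.EVICTED
        policy.set_on_evict_callback(l2_put)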
Implementation Details
KV Transfer Module Enhancement
- We will enhance the KV transfer module by introducing KVTransferMetadata, which contains context_tokens, prompt_len, and recv_len for each sequence, enabling the OffloadingConnector (described below) to build keys used by the cache backend based on context_tokens and input_tokens (from model_input) and to offload newly computed tokens based on recv_len (a sketch of this metadata follows this list).
- A new KV Transfer Connector named OffloadingConnector will be introduced to interact with external cache backends. This connector will:
  - Use KVTransferMetadata to build cache keys.
  - Perform KV cache store and load by reusing the send_kv_caches_and_hidden_states and recv_kv_caches_and_hidden_states APIs of the KV transfer connector.
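For illustration, KVTransferMetadata might look roughly like the following (a sketch; field types and defaults are assumptions based on the description above):
from dataclasses import dataclass, field
from typing import List


@dataclass
class KVTransferMetadata:
    """Per-sequence metadata consumed by the OffloadingConnector (sketch).

    context_tokens: tokens whose KV is already cached for this sequence.
    prompt_len: total prompt length of the sequence.
    recv_len: number of tokens whose KV was fetched from the external cache
        service; tokens beyond recv_len are newly computed and are candidates
        for offloading.
    """
    context_tokens: List[int] = field(default_factory=list)
    prompt_len: int = 0
    recv_len: int = 0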
KV Cache Offloading Module
The KV cache offloading module consists of three key components:
GroupAwareKVCacheManager
- Coordinates tensor parallelism participants so that they always fetch the same number of KV tensors (see the alignment sketch after this list)
- Manages both L1Cache (i.e., the local cache embedded in the engine) and L2Cache (i.e., the remote cache backend). L1Cache and L2Cache can be enabled independently; users can choose to use one or both based on their resource availability and performance requirements.
- Supports configurable ingestion modes:
  - ALL: Store all KV tensors to L2Cache.
  - HOT: Store only hot KV tensors to L2Cache.
  - EVICTED: Store only evicted (i.e., cold) KV tensors to L2Cache.
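A minimal sketch of the alignment step performed by GroupAwareKVCacheManager (the helper name and signature are assumptions): each tensor-parallel rank reports how many tokens it fetched, and the group agrees on the minimum so that every rank resumes prefill from the same position.
import torch
import torch.distributed as dist


def align_fetched_tokens(num_fetched_local: int,
                         process_group: dist.ProcessGroup) -> int:
    """Agree on a common number of fetched tokens across TP ranks.

    Each rank may have hit a different number of tokens in the external
    cache; taking the group-wide minimum guarantees a consistent view
    before prefill proceeds.
    """
    t = torch.tensor([num_fetched_local], dtype=torch.int64)
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=process_group)
    return int(t.item())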
L1Cache
- Has an allocator that supports GPU and CPU memory allocations
- Supports several eviction policies, especially scan-resistant ones that can improve memory space efficiency and reduce contention on network bandwidth. LRU, FIFO, and S3FIFO (scan-resistant) will be supported in the initial PR.
L2Cache
- Connector: API abstraction for external cache backends.
  - A RocksDB connector implementation already exists; it is mainly used in unit tests right now and can also be used in scenarios that only involve local CPU memory and SSDs.
- Marshaller: API abstraction for serializers and compressors (a minimal sketch follows below).
  - Could be extended to support LLM workload-oriented cache placement.
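For illustration, a Marshaller could chain a serializer and a compressor roughly as follows (a sketch; ZlibTensorMarshaller and its method names are assumptions, not the proposed implementation):
import io
import zlib

import torch


class ZlibTensorMarshaller:
    """Sketch: serialize tensors with torch.save and compress with zlib
    before handing the payload to a Connector."""

    def marshal(self, tensor: torch.Tensor) -> bytes:
        buf = io.BytesIO()
        torch.save(tensor, buf)
        return zlib.compress(buf.getvalue())

    def unmarshal(self, data: bytes) -> torch.Tensor:
        return torch.load(io.BytesIO(zlib.decompress(data)))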
Proposed APIs
Spec
@dataclass
class KVCacheLayerSpec:
"""The specification of the kv cache tensor for each layer.
Args:
size: The per-token size of the kv cache layer in number of elements.
For FullAttention, size = num_heads * head_dim
"""
size: int
@dataclass
class KVCacheTensorSpec:
"""The specification of the kv cache tensor.
Args:
heads: head ids. To support tensor parallelism.
layers: layer ids. To support pipeline parallelism.
layer_specs: layer specs.
"""
heads: List[int]
layers: List[int]
layer_specs: List[KVCacheLayerSpec]
class KVCacheBlockLayout(enum.Enum):
"""The layout of the kv cache block.
Args:
NCLD:
This layout signifies that the shape would be [num_tokens, 2 (k & v), num_layers, layer_dim].
|<-------------------------------- Block i ---------------------------------->|
|...|<------------------------------ Token j ---------------------------->|...|
|...|<---------------- K ------------->|<---------------- V ------------->|...|
|...|<-- Layer 0 ->||<-- Layer 1 ->|...|<-- Layer 0 ->||<-- Layer 1 ->|...|...|
For a heterogeneous tensor, its shape will be [num_tokens, 2, num_layers, [layer0_dim,
layer1_dim, ...]].
LCND:
This layout signifies that the shape would be [num_layers, 2 (k & v), num_tokens, layer_dim].
|<-------------------------------- Block i ---------------------------------->|
|...|<------------------------------ Layer j ---------------------------->|...|
|...|<---------------- K ------------->|<---------------- V ------------->|...|
|...|<-- Token 0 ->||<-- Token 1 ->|...|<-- Token 0 ->||<-- Token 1 ->|...|...|
For a heterogeneous tensor, its shape will be [num_layers, 2, num_tokens, [layer0_dim,
layer1_dim, ...]].
"""
NCLD = enum.auto()
LCND = enum.auto()
@dataclass
class KVCacheBlockSpec:
"""The specification of the kv cache block.
Args:
block_ntokens: The number of tokens in each block.
block_dtype: The dtype of the kv cache block.
block_layout: The layout of the kv cache block.
tensor_spec: The specification of the kv cache tensor.
"""
block_ntokens: int
block_dtype: torch.dtype
block_layout: KVCacheBlockLayout
tensor_spec: KVCacheTensorSpec
def __post_init__(self):
if self.block_ntokens <= 0:
raise ValueError("block_ntokens must be greater than 0.")
self.block_nbytes: int = (2 * self.block_ntokens *
self.block_dtype.itemsize *
sum(s.size
for s in self.tensor_spec.layer_specs))
self.block_shape: Tuple[int, ...] = self._get_block_shape()
self.block_shape_token_dim: int = 0 if self.block_layout == KVCacheBlockLayout.NCLD else 2
self.is_homogeneous: Callable[
[], bool] = lambda: isinstance(self.block_shape[-1], int)
KVCacheManager
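Before the interface definition, a brief usage sketch (the helper below is illustrative only; it assumes a concrete KVCacheManager instance and follows the get/put semantics defined afterwards):
import torch


def lookup_then_offload(cache, prefix, tokens, computed_kv: torch.Tensor):
    """Sketch of a typical get/put round trip against a KVCacheManager.

    cache: any concrete KVCacheManager.
    computed_kv: KV tensors produced by prefill for `tokens`; its layout
        must match the layout used by the kv cache service.
    """
    status = cache.get(prefix, tokens)
    if status.is_ok():
        # Only the first num_fetched tokens are covered; the caller computes
        # KV for the remaining tokens and may offload them afterwards.
        num_fetched, kv_tensors = status.value
        return num_fetched, kv_tensors
    # Full miss: compute KV for all tokens, then offload the new tensors.
    cache.put(prefix, tokens, computed_kv)
    return 0, None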
@dataclass
class KVCacheFeature:
"""The features of the kv cache.
Args:
zero_copy: Whether the kv cache supports zero-copy.
non_blocking_put: Whether the kv cache uses non-blocking put.
"""
zero_copy: bool = False
non_blocking_put: bool = False
class KVCacheHandle(ABC):
"""Cache handle to support zero-copy APIs.
"""
@abstractmethod
def to_tensors(self) -> Iterable[torch.Tensor]:
raise NotImplementedError
@abstractmethod
def release(self) -> None:
raise NotImplementedError
@abstractmethod
def __len__(self) -> int:
raise NotImplementedError
class KVCacheManager(ABC):
"""The KV cache manager.
Args:
config: The KV cache manager configuration.
"""
def __init__(self, config: KVCacheConfig) -> None:
...
@property
@abstractmethod
def feature(self) -> KVCacheFeature:
"""Get the feature of the kv cache.
Returns:
The feature of the kv cache.
"""
raise NotImplementedError
@property
@abstractmethod
def chunk_size(self) -> int:
"""Get the chunk size of the kv cache.
Returns:
The chunk size of the kv cache.
"""
raise NotImplementedError
def prefetch(self, prefix: Iterable[int] | None,
tokens: Iterable[int]) -> None:
"""(Optional) Prefetch the kv cache for the given prefix and tokens.
Args:
prefix: The prefix of the kv cache. E.g., [1, 2, 3]
tokens: The tokens of the kv cache. E.g., [4, 5, 6, 7]
"""
pass
def allocate(
self,
nblocks: int,
) -> Status[KVCacheHandle]:
"""(Optional) Allocate a cache handle that points to buffers owned by the kv
cache service.
Only the kv cache services supporting zero-copy need to implement this method.
Args:
nblocks: The number of blocks to allocate.
Returns:
The cache handle.
"""
raise NotImplementedError
def acquire(
self,
prefix: Iterable[int] | None,
tokens: Iterable[int],
) -> Status[Tuple[int, Iterable[KVCacheHandle]]]:
"""(Optional) Acquire cache handle of the kv tensors for the given prefix
and tokens. Only the kv cache services supporting zero-copy need to implement
this method.
The returned cache handle points to buffers owned by the kv cache service.
We can use "KVCacheHandle.to_tensors()" to get tensors sharing the same storage.
After the kv tensors are used, we need to explicitly `release()` the cache handle
to let the kv cache service know that these buffers are not referenced anymore.
Args:
prefix: The prefix of the kv cache. E.g., [1, 2, 3]
tokens: The tokens of the kv cache. E.g., [4, 5, 6, 7]
Returns:
Number of tokens that have been fetched from the kv cache service.
The cache handles corresponding to the given tokens.
"""
raise NotImplementedError
@abstractmethod
def get(
self,
prefix: Iterable[int] | None,
tokens: Iterable[int],
) -> Status[Tuple[int, torch.Tensor]]:
"""Get kv tensors from the kv cache service.
Args:
prefix: The prefix of the kv cache. E.g., [1, 2, 3]
tokens: The tokens of the kv cache. E.g., [4, 5, 6, 7]
Returns:
Number of tokens that have been fetched from the kv cache service.
The kv tensors corresponding to the tokens. Their layout matches the
layout of the kv cache service. For example, if the layout is NCLD, then
the k, v tensors for the i-th token at the j-th layer are kv_tensors[i][0][j]
and kv_tensors[i][1][j], respectively.
"""
raise NotImplementedError
@abstractmethod
def put(
self,
prefix: Iterable[int] | None,
tokens: Iterable[int],
kv_tensors: torch.Tensor | KVCacheHandle,
) -> Status[int]:
"""Put kv tensors to the kv cache service.
Args:
prefix: The prefix of the kv cache. E.g., [1, 2, 3]
tokens: The tokens of the kv cache. E.g., [4, 5, 6, 7]
kv_tensors:
The kv tensors to put into the kv cache.
The layout of kv_tensors must match the layout of the kv cache service.
For example, if the layout is NCLD, then:
The k, v tensors for the i-th token at the j-th layer are kv_tensors[i][0][j]
and kv_tensors[i][1][j], respectively.
Returns:
The status of the put operation and the number of tokens that have been
put or scheduled to be put into the kv cache service.
"""
raise NotImplementedError
@abstractmethod
def delete(
self,
prefix: Iterable[int] | None,
tokens: Iterable[int],
) -> Status:
"""Delete kv tensors from the kv cache service.
Args:
prefix: The prefix of the kv cache. E.g., [1, 2, 3]
tokens: The tokens of the kv cache. E.g., [4, 5, 6, 7]
Returns:
The status of the delete operation.
"""
raise NotImplementedError
def flush(self) -> Status:
"""Flush the kv cache service.
Returns:
The status of the flush operation.
"""
return Status(StatusCodes.OK)
@abstractmethod
def cache_chunk_keys(
self, prefix: Iterable[int] | None, tokens: Iterable[int]
) -> Iterator[Iterator[Tuple[Iterable[int], Iterable[int],
Iterable[int]]]]:
"""Get the cache chunk keys.
Args:
prefix (Iterable[int] | None): The prefix tokens of the kv tensors.
tokens (Iterable[int]): The tokens of the kv tensors.
Returns:
chunk prefix tokens, chunk tokens, next chunk tokens
"""
raise NotImplementedError
@abstractmethod
def close(self) -> None:
"""Close the kv cache service."""
raise NotImplementedError
BaseEvictionPolicy (in L1Cache)
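A short usage sketch of the interface defined below (the concrete policy class and the capacity numbers are assumptions; the constructor arguments follow the signature proposed here):
def build_l1_policy(policy_cls, offload_to_l2):
    """Sketch: create an eviction policy whose evicted blocks go to L2Cache.

    policy_cls: an assumed concrete subclass of BaseEvictionPolicy
        (e.g., an LRU, FIFO, or S3FIFO implementation).
    offload_to_l2: a callable (key, value) -> None that pushes a block to
        the remote backend (the EVICTED ingestion mode).
    """
    policy = policy_cls(
        name="l1",
        capacity=4096,   # number of cached blocks
        evict_size=16,   # evict in small batches
        on_evict=offload_to_l2,
    )
    # Dict-style access is available through __getitem__/__setitem__:
    #   policy[key] = block
    #   block = policy[key]
    return policy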
class BaseEvictionPolicy(Generic[N, V]):
"""Base class for eviction policies."""
def __init__(
self,
name: str,
capacity: int,
evict_size: int = 1,
on_put: Functor | None = None,
on_evict: Functor | None = None,
on_hot_access: Functor | None = None,
) -> None:
"""Initialize the eviction policy.
Args:
name (str): The name of the eviction policy.
capacity (int): The capacity of the eviction policy in terms of number of items.
evict_size (int, optional): The number of items to evict at a time. Defaults to 1.
on_put (Functor, optional): The put function to call when putting new items. Defaults to None.
on_evict (Functor, optional): The evict function to call when evicting items. Defaults to None.
on_hot_access (Functor, optional): The callback function to call when a cache item becomes hot.
Defaults to None.
"""
self._name: str = name
self._capacity: int = capacity
self._evict_size: int = evict_size
self._on_put: Functor = on_put
self._on_evict: Functor = on_evict
self._on_hot_access: Functor = on_hot_access
self._hashmap: Dict[Hashable, N[V]] = {}
@property
def name(self) -> str:
"""Return the name of the eviction policy."""
return self._name
@property
def evict_size(self) -> int:
"""Return the number of items to evict at a time."""
return self._evict_size
@property
def capacity(self) -> int:
"""Return the capacity of the eviction policy in terms of number of items."""
return self._capacity
def __len__(self) -> int:
"""Return the number of items in the eviction policy."""
return len(self._hashmap)
def __contains__(self, key: Hashable) -> bool:
"""Return True if the key is in the eviction policy."""
return key in self._hashmap
def __getitem__(self, key: Hashable) -> V:
"""Return the value of the key."""
status = self.get(key)
if not status.is_ok():
    raise KeyError(key)
return status.value
def __setitem__(self, key: Hashable, value: V) -> None:
"""Set the value of the key."""
self.put(key, value)
def __delitem__(self, key: Hashable) -> None:
"""Delete the key."""
self.delete(key)
def __iter__(self) -> Iterator[Hashable]:
"""Return an iterator over the keys in the eviction policy."""
return iter(self._hashmap.keys())
def set_on_put_callback(self, functor: Functor) -> None:
"""Set the callback function to call when putting new items."""
self._on_put = functor
def set_on_evict_callback(self, functor: Functor) -> None:
"""Set the callback function to call when evicting items."""
self._on_evict = functor
def set_on_hot_access_callback(self, functor: Functor) -> None:
"""Set the callback function to call when a cache item becomes hot."""
self._on_hot_access = functor
def items(self) -> Iterator[Tuple[Hashable, V]]:
"""Return an iterator over the key-value pairs in the eviction policy."""
return iter((key, node.value) for key, node in self._hashmap.items())
def keys(self) -> Iterator[Hashable]:
"""Return an iterator over the keys in the eviction policy."""
return iter(self._hashmap.keys())
def values(self) -> Iterator[V]:
"""Return an iterator over the values in the eviction policy."""
return iter(node.value for node in self._hashmap.values())
def __repr__(self) -> str:
return f"{self._name}(capacity={self._capacity}, size={len(self)})"
def __str__(self) -> str:
return self.__repr__()
@abstractmethod
def put(self, key: Hashable, value: V) -> Status:
"""Put a key into the eviction policy.
Args:
key (Hashable): The key of the item.
value: The value of the item.
Returns:
Status: The status of the operation.
"""
raise NotImplementedError
@abstractmethod
def get(
self,
key: Hashable,
) -> Status[V]:
"""Get the value of key from the eviction policy.
Args:
key (Hashable): The key of the item.
Returns:
Status: The status of the operation.
"""
raise NotImplementedError
@abstractmethod
def peek(
self,
key: Hashable,
) -> Status[V]:
"""Peak the value of key from the eviction policy. Peak does not update the eviction policy.
Args:
key (Hashable): The key of the item.
Returns:
Status: The status of the operation.
"""
if key in self._hashmap:
node = self._hashmap[key]
return Status(value=node.value)
return Status(StatusCodes.NOT_FOUND)
@abstractmethod
def delete(self, key: Hashable) -> Status:
"""Delete a key-value pair from the eviction policy.
Args:
key (Hashable): The key of the item.
Returns:
Status: The status of the operation.
"""
raise NotImplementedError
@abstractmethod
def evict(self, size: int = 1) -> Status:
"""Evict a key-value pair from the eviction policy.
Args:
size (int, optional): The number of items to evict. Defaults to 1.
"""
raise NotImplementedError
@abstractmethod
def assert_consistency(self) -> None:
"""Check the consistency of the eviction policy. Only for test purpose."""
raise NotImplementedError
Connector (in L2Cache)
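For illustration, a minimal in-memory connector implementing the interface defined below might look like this (a sketch only, assuming the Connector, ConnectorFeature, Status, and StatusCodes types from this proposal are importable; the RocksDB connector remains the reference implementation):
from typing import Dict


class InMemoryConnector(Connector[bytes, bytes]):
    """Sketch: a dict-backed connector, useful mainly for tests."""

    def __init__(self) -> None:
        self._store: Dict[bytes, bytes] = {}

    @classmethod
    def from_envs(cls, conn_id: str) -> "InMemoryConnector":
        return cls()

    @property
    def name(self) -> str:
        return "InMemory"

    @property
    def feature(self) -> ConnectorFeature:
        # No mput/mget, prefetch, or zero-copy support.
        return ConnectorFeature()

    def open(self) -> Status:
        return Status(StatusCodes.OK)

    def close(self) -> Status:
        return Status(StatusCodes.OK)

    async def get(self, key: bytes) -> Status[bytes]:
        if key in self._store:
            return Status(value=self._store[key])
        return Status(StatusCodes.NOT_FOUND)

    async def put(self, key: bytes, value: bytes) -> Status:
        self._store[key] = value
        return Status(StatusCodes.OK)

    async def delete(self, key: bytes) -> Status:
        self._store.pop(key, None)
        return Status(StatusCodes.OK)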
@dataclass
class ConnectorFeature:
"""The features of the kv cache connector.
Args:
mput_mget: Whether the kv cache connector supports mput/mget
prefetch: Whether the kv cache connector supports prefetch.
zero_copy: Whether the kv cache connector supports zero-copy.
"""
mput_mget: bool = False
prefetch: bool = False
zero_copy: bool = False
class Connector(Generic[K, V]):
"""Connector interface."""
@classmethod
@abstractmethod
def from_envs(cls, conn_id: str):
"""Create a connector from environment variables."""
raise NotImplementedError
@property
@abstractmethod
def name(self) -> str:
raise NotImplementedError
@property
@abstractmethod
def feature(self) -> ConnectorFeature:
"""Get the feature of the connector.
Returns:
The feature of the kv cache service.
"""
raise NotImplementedError
@abstractmethod
def open(self) -> Status:
"""Open a connection."""
raise NotImplementedError
@abstractmethod
def close(self) -> Status:
"""Close a connection."""
raise NotImplementedError
async def prefetch(self, keys: Iterable[K]) -> None:
"""Prefetch a list of keys.
Args:
keys: The keys of the kv tensors.
"""
pass
@abstractmethod
async def get(self, key: K) -> Status[V]:
"""Get a value.
Args:
key: The key of the kv tensor.
Returns:
The value of the kv tensor.
"""
raise NotImplementedError
@abstractmethod
async def put(self, key: K, value: V) -> Status:
"""Put a key value pair.
Args:
key: The key of the kv cache.
value: The value of the kv cache.
Returns:
The status of the put operation.
"""
raise NotImplementedError
async def mget(self, keys: Iterable[K]) -> Iterable[Status[V]]:
"""MGet a list of values. This function is optional and only connectors
have mput_mget feature enabled can implement this function.
Args:
keys: The keys of the kv tensors.
Returns:
List of values.
"""
raise NotImplementedError
async def mput(self, keys: Iterable[K],
values: Iterable[V]) -> Iterable[Status]:
"""MPut a list of key value pairs. This function is optional and only connectors
have mput_mget feature enabled can implement this function.
Args:
keys: The keys of the kv tensors.
values: The values of the kv tensors.
Returns:
List of statuses.
"""
raise NotImplementedError
async def acquire(self, key: K) -> Status[KVCacheHandle]:
"""Acquire a kv cache handle pointing to the kv tensors. This function is
optional and only connectors have zero_copy feature enabled can implement
this function.
Args:
key: The key of the kv cache.
Returns:
The kv cache handle.
"""
raise NotImplementedError
@abstractmethod
async def delete(self, key: K) -> Status:
"""Delete a key.
Args:
key: The key of the kv cache.
Returns:
The status of the delete operation.
"""
raise NotImplementedError
Integration
class OffloadingConnector(KVConnectorBase):
"""OffloadingConnector is a KVConnector that offloads KV caches and hidden
states to the kv cache offloading service.
"""
def __init__(
self,
rank: int,
local_rank: int,
config: 'VllmConfig',
):
...
# init block spec
block_spec = KVCacheBlockSpec(
block_ntokens=block_ntokens,
block_dtype=block_dtype,
# NCLD layout is used for now
block_layout=KVCacheBlockLayout.NCLD,
tensor_spec=KVCacheTensorSpec(
heads=kv_head_ids,
layers=layer_ids,
# use the same size for all layers
layer_specs=[
KVCacheLayerSpec(size=num_kv_heads * head_size) for _ in layer_ids
],
),
)
config = KVCacheConfig(block_spec=block_spec)
# init cache
if parallel_config.tensor_parallel_size == 1:
self.cache = BaseKVCacheManager(config=config)
else:
pg = get_tp_group().cpu_group
self.cache = GroupAwareKVCacheManager(config=config, process_group=pg)
def close(self) -> None:
if self.cache:
self.cache.close()
self.cache = None
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: 'ModelInputForGPUWithSamplingMetadata',
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
'IntermediateTensors'],
) -> None:
...
# query_lens contains new KV caches that need to be offloaded
for seq_idx, query_len in enumerate(query_lens):
# prepare prefix and tokens
prefix = ...
tokens = ...
for chunk_prefix, chunk_tokens, _ in self.cache.cache_chunk_keys(prefix, tokens):
# prepare kv tensors
chunk_kv_tensors = ...
# use allocate if cache supports zero copy
if self.cache_feature.zero_copy:
status = self.cache.allocate(len(chunk_tokens) // self.block_ntokens)
handles = status.value
tensors = handles.to_tensors()
# copy chunk_kv_tensors to tensors
...
status = self.cache.put(chunk_prefix, chunk_tokens, handles)
else:
# if non_blocking_put is True, we need to make a copy of chunk_kv_tensors
# to ensure it can be reused after the put operation
if self.cache_feature.non_blocking_put:
chunk_kv_tensors = make_copy(chunk_kv_tensors)
# put KV caches to offloading service
status = self.cache.put(chunk_prefix, chunk_tokens, chunk_kv_tensors)
...
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: 'ModelInputForGPUWithSamplingMetadata',
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
'ModelInputForGPUWithSamplingMetadata']:
...
# query_lens contains new KV caches to be received
for seq_idx, query_len in enumerate(query_lens):
# prepare prefix and tokens
prefix = ...
tokens = ...
for chunk_prefix, chunk_tokens, next_tokens in self.cache.cache_chunk_keys(prefix, tokens):
if next_tokens and len(next_tokens) > 0:
    # prefetch
    self.cache.prefetch(chunk_prefix + chunk_tokens, next_tokens)
# get KV caches from offloading service
status = self.cache.acquire(chunk_prefix, chunk_tokens)
if not status.is_ok():
# error handling
num_fetched_tokens, handle = status.value
if num_fetched_tokens == 0:
handle.release()
return
tensors = handle.to_tensors()
# load tensors to kv cache
...
Feedback Period.
Feedback can be provided directly on the PR. Based on the comments, we can update the RFC to elaborate.
CC List.
cc: @KuntaiDu @youkaichao @robertgshaw2-redhat @simon-mo @Jeffwan
Any Other Things.
Initial PR: [TBD]

