Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@ def create_connector_v1(

KVConnectorFactory.register_connector(
"MooncakeStoreConnector",
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
"MooncakeStoreConnector")
"vllm.distributed.kv_transfer.kv_connector.kv_store_connector",
"KVStoreConnector")

KVConnectorFactory.register_connector(
"FileStoreConnector",
"vllm.distributed.kv_transfer.kv_connector.kv_store_connector",
"KVStoreConnector")

KVConnectorFactory.register_connector(
"SharedStorageConnector",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""
MooncakeStore Connector for Distributed Machine Learning Inference
The MooncakeStoreConnector transfers KV caches between prefill vLLM workers
KVStore Connector for Distributed Machine Learning Inference
The KVStoreConnector transfers KV caches between prefill vLLM workers
(KV cache producer) and decode vLLM workers (KV cache consumer) using a
database-style KVStore.
"""
Expand All @@ -14,6 +14,8 @@
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.utils import (
model_aware_kv_ops_helper as kv_helper)
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVStoreBufferBase)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

Expand All @@ -23,7 +25,7 @@
logger = init_logger(__name__)


class MooncakeStoreConnector(KVConnectorBase):
class KVStoreConnector(KVConnectorBase):

def __init__(
self,
Expand All @@ -49,9 +51,18 @@ def __init__(
from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501
MooncakeStore)
logger.info(
"Initializing KVStoreConnector under kv_transfer_config %s",
self.kv_transfer_config)
self.kv_store = MooncakeStore(config)
"Initializing MooncakeStoreConnector "
"under kv_transfer_config %s", self.kv_transfer_config)
self.kv_store: KVStoreBufferBase = MooncakeStore(config)
elif self.kv_transfer_config.kv_connector == "FileStoreConnector":
from vllm.distributed.kv_transfer.kv_lookup_buffer.file_store import ( # noqa: E501
FileStore)

# Init kv_store
self.kv_store = FileStore(config)
logger.info(
"Initializing FileStoreConnector under kv_transfer_config %s",
self.kv_transfer_config)
else:
logger.error("Can not find %s",
self.kv_transfer_config.kv_connector)
Expand Down
63 changes: 63 additions & 0 deletions vllm/distributed/kv_transfer/kv_lookup_buffer/file_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
"""Local file system based KV store implementation."""
import os
from typing import Optional

import torch
from safetensors.torch import load_file as safetensors_load
from safetensors.torch import save_file as safetensors_save

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_lookup_buffer.base import (
KVStoreBufferBase)
from vllm.logger import init_logger

logger = init_logger(__name__)


class FileStore(KVStoreBufferBase):
    """KV store implementation using local filesystem with safetensors.

    Each key maps to a single ``<key>.safetensors`` file under
    ``storage_path``. Tensors are serialized from CPU memory; the
    originating device id is stored alongside the data so ``get`` can
    move the tensor back to the device it came from.
    """

    def __init__(
        self,
        config: VllmConfig,
    ):
        """Create the backing directory if it does not already exist.

        The directory is read from the ``fs_storage_path`` entry of the
        kv_transfer extra config, defaulting to ``/tmp/vllm_kv_cache``.
        """
        self.storage_path = config.kv_transfer_config.get_from_extra_config(
            "fs_storage_path", "/tmp/vllm_kv_cache")
        os.makedirs(self.storage_path, exist_ok=True)

    def close(self):
        """No resources to clean up for file storage."""
        pass

    def put(self, key: str, value: Optional[torch.Tensor]) -> None:
        """Atomically save ``value`` under ``<key>.safetensors``.

        A ``None`` value is a no-op. The data is first written to a
        temporary file and then renamed into place, so a concurrent
        ``get`` observes either the previous complete file or the new
        complete file — never a partially written one.
        """
        if value is None:
            return

        # NOTE(review): `key` is used verbatim as a filename component;
        # a key containing path separators would escape storage_path.
        # Callers are assumed to pass flat, hash-like keys — TODO confirm.
        file_path = os.path.join(self.storage_path, f"{key}.safetensors")
        # Record the source device so get() can restore placement;
        # -1 denotes CPU.
        device_id = value.device.index if value.device.type == 'cuda' else -1
        device_tensor = torch.tensor(device_id, dtype=torch.int32)

        # Write-then-rename for atomicity: os.replace is atomic on POSIX
        # when source and destination live on the same filesystem (they
        # do — both are inside storage_path). The pid suffix keeps temp
        # files of concurrent writer processes from colliding.
        tmp_path = f"{file_path}.tmp-{os.getpid()}"
        try:
            safetensors_save({
                "tensor": value.cpu(),
                "device_id": device_tensor
            }, tmp_path)
            os.replace(tmp_path, file_path)
        except Exception:
            # Don't leave a stale temp file behind on failure; the
            # original error still propagates to the caller.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Load the tensor stored under ``key``.

        Returns ``None`` when the key is absent, or when the file cannot
        be read/deserialized (the failure is logged). On success the
        tensor is moved to the device recorded at ``put`` time.
        """
        file_path = os.path.join(self.storage_path, f"{key}.safetensors")
        # A missing key is an expected miss, not an error — no logging.
        if not os.path.exists(file_path):
            return None

        try:
            data = safetensors_load(file_path)
            tensor = data["tensor"]
            device_id = int(data["device_id"].item())

            # device_id >= 0 means the tensor originated on that CUDA
            # device; -1 means it was a CPU tensor.
            device = torch.device(
                'cuda', device_id) if device_id >= 0 else torch.device('cpu')
            return tensor.to(device)
        except Exception as e:
            # Best-effort store: a corrupt/unreadable entry degrades to
            # a cache miss rather than crashing the worker.
            logger.error("Error loading tensor %s: %s", key, str(e))
            return None