diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 57a2b0393ba4..ec46d4045447 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -2,13 +2,14 @@ import json import os -import pickle from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Optional, Union import torch import zmq +from safetensors.torch import load as safetensors_load +from safetensors.torch import save as safetensors_save from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase @@ -237,14 +238,13 @@ def tensor_hash(self, tensor: torch.Tensor) -> int: return hash(tensor.data_ptr()) def _send_impl(self, tensor: torch.Tensor) -> None: - """Implement the tensor sending logic.""" - value_bytes = pickle.dumps(tensor) - self.transfer_engine.send_bytes(value_bytes) + """Implement the tensor sending logic using safetensors.""" + self.transfer_engine.send_bytes(safetensors_save({"tensor": tensor})) def _recv_impl(self) -> torch.Tensor: - """Implement the tensor receiving logic.""" + """Implement the tensor receiving logic using safetensors.""" data = self.transfer_engine.recv_bytes() - return pickle.loads(data) + return safetensors_load(data)["tensor"].to(self.device) def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: """Send tensor to the target process."""