98 changes: 83 additions & 15 deletions checkpoint_engine/ps.py
@@ -454,6 +454,7 @@ def _register_checkpoint(
files: list[str],
named_tensors: dict[str, torch.Tensor],
rank: int | None = None,
shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
logger.info(
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
@@ -483,16 +484,33 @@ class MemoryBucket(BaseModel):
for bucket in buckets
]

def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
return idx, buffer
def register_pin_memory(
idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
) -> tuple[int, torch.Tensor]:
if shared_pin_memory:
# Reusing pin memory only supports checkpoints with a fixed shape, which is registered the first time
assert idx < len(shared_pin_memory), (
f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
)
assert shared_pin_memory[idx].size == size, (
f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
)
return idx, shared_pin_memory[idx].buffer
else:
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
return idx, buffer

def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)

with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
futures = [
executor.submit(register_pin_memory, idx, bucket.size)
executor.submit(
register_pin_memory,
idx,
bucket.size,
shared_pin_memory,
)
for idx, bucket in enumerate(buckets)
]
new_futures = []

Collaborator: I think that when shared_pin_memory is provided, submitting this work to the executor may be unnecessary.
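Below is a minimal sketch of the reviewer's point (hypothetical code, not part of the PR; pinned allocation requires a CUDA-enabled PyTorch build): once shared_pin_memory is populated, register_pin_memory is just an index lookup, so the thread-pool round trip could be skipped. The allocate_or_reuse helper and its signature are invented for illustration.

import concurrent.futures

import torch


def allocate_or_reuse(sizes: list[int], shared: list[torch.Tensor] | None) -> list[torch.Tensor]:
    # Return one pinned uint8 buffer per bucket size, reusing `shared` when it is given.
    if shared:
        # Pure reuse: no allocation happens, so worker threads add overhead without benefit.
        assert len(shared) == len(sizes), "shared pool must match the bucket layout"
        return list(shared)
    # Fresh pinned allocations are slow (page-locking), so parallelizing them still makes sense.
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as ex:
        futures = [ex.submit(torch.empty, n, dtype=torch.uint8, pin_memory=True) for n in sizes]
        return [f.result() for f in futures]

In the PR itself the downstream loop consumes futures, so dropping the executor on the reuse path would also mean adjusting how results are collected.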
@@ -781,7 +799,11 @@ def __init__(
self._zmq_ctx = zmq.Context()
self._zmq_addr_counter = 0

self.shared_memory_pool_name = "__shared_memory_pool__"

Collaborator: Maybe we can place this into a class-level attribute:

class ParameterServer:
    _shared_memory_pool_name = "__shared_memory_pool__"

# stores the name of the checkpoint currently using the shared memory pool, or empty string if none
self._current_shared_memory_pool_user: str = ""
self._memory_pool: dict[str, list[MemoryBuffer]] = {}
self._memory_pool[self.shared_memory_pool_name] = []
# dict key is owner_rank, value is a bucket metas list in owner_rank
self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
try:
@@ -795,6 +817,18 @@ def __init__(
self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device

def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
if checkpoint_name == self._current_shared_memory_pool_user:
if not self._memory_pool[self.shared_memory_pool_name]:
raise RuntimeError(
f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
)

Collaborator: Maybe we can use assert?

return self._memory_pool[self.shared_memory_pool_name]
elif checkpoint_name in self._memory_pool:
return self._memory_pool[checkpoint_name]
else:
raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")

def _logger_rank0(self, msg: str):
if self._local_rank == 0:
logger.info(msg)
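For reference, a sketch of the assert-based variant of _get_memory_pool that the reviewer suggests (illustrative only, not part of the PR):

def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
    if checkpoint_name == self._current_shared_memory_pool_user:
        assert self._memory_pool[self.shared_memory_pool_name], (
            f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
        )
        return self._memory_pool[self.shared_memory_pool_name]
    assert checkpoint_name in self._memory_pool, f"checkpoint {checkpoint_name} is not registered"
    return self._memory_pool[checkpoint_name]

One trade-off: assert statements are stripped when Python runs with -O, while the RuntimeError in the PR always fires, so the exception-based version is the stricter guard for shared-pool state.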
@@ -818,6 +852,7 @@ def register_checkpoint(
*,
files: list[str] | None = None,
named_tensors: dict[str, torch.Tensor] | None = None,
use_shared_memory_pool: bool = False,
) -> None:
"""
Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
@@ -826,14 +861,34 @@ def register_checkpoint(
checkpoint_name: The name of the checkpoint.
files: The safetensors files to register.
named_tensors: The named tensors to register.
use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
cannot accommodate checkpoints with different memory requirements.
"""
try:
assert checkpoint_name not in self._memory_pool, (
f"checkpoint {checkpoint_name} already registered"
)
self._memory_pool[checkpoint_name] = _register_checkpoint(
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
)
if use_shared_memory_pool:
logger.info(
f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool"
)
assert self._current_shared_memory_pool_user == "", (
f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
f"This registration may cause unexpected conflicts."
)
self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
files=files or [],
named_tensors=named_tensors or {},
rank=self._rank,
shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
)
self._current_shared_memory_pool_user = checkpoint_name
else:
assert checkpoint_name not in self._memory_pool, (
f"checkpoint {checkpoint_name} already registered"
)
self._memory_pool[checkpoint_name] = _register_checkpoint(
files=files or [], named_tensors=named_tensors or {}, rank=self._rank
)
if self._p2p_store is not None:
self._register_parameters_to_p2p_store(checkpoint_name)
except Exception:
Comment on lines +878 to 894

Copilot AI (Nov 19, 2025): There's a resource leak when registration of a shared memory pool checkpoint fails. If an exception occurs after setting self._current_shared_memory_pool_user = checkpoint_name (line 879) but before the checkpoint is fully registered, the exception handler at line 895 calls self.unregister_checkpoint(checkpoint_name). However, for shared memory pool users, unregister_checkpoint only clears _current_shared_memory_pool_user without cleaning up the shared memory pool itself, which may have been partially modified.

Consider resetting _current_shared_memory_pool_user to an empty string in the exception handler before calling unregister_checkpoint, or handle shared memory pool cleanup differently in error cases.
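One way to act on this suggestion, sketched as hypothetical code (the full except block is not visible in this hunk, so the shape of the handler is assumed):

try:
    ...  # registration logic as in the diff above
except Exception:
    if checkpoint_name == self._current_shared_memory_pool_user:
        # Roll back the claim on the shared pool so a later registration can reuse it cleanly.
        self._current_shared_memory_pool_user = ""
    self.unregister_checkpoint(checkpoint_name)
    # ...then re-raise or log, matching whatever the existing handler does.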
@@ -850,13 +905,26 @@ def unregister_checkpoint(self, checkpoint_name: str):
Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint
from p2p store if p2p store is initialized.
"""
if checkpoint_name not in self._memory_pool:
if (
checkpoint_name not in self._memory_pool
and checkpoint_name != self._current_shared_memory_pool_user
):
logger.warning(
f"[rank{self._rank}] unregister checkpoint failed, checkpoint name {checkpoint_name} not found"
)
return
if self._p2p_store is not None:
num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
logger.info(
f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
)
if checkpoint_name == self._current_shared_memory_pool_user:
logger.info(
f"[rank{self._rank}] unregister shared memory pool from p2p store, skip unregistering from memory pool"
)
self._current_shared_memory_pool_user = ""
return
Comment on lines 855 to +926

Copilot AI (Nov 19, 2025): When unregistering a checkpoint using the shared memory pool, the p2p_store is unregistered (lines 912-915) before checking if it's the shared memory pool user (line 916). This means that _unregister_parameters_from_p2p_store is called with the checkpoint name, which internally calls _get_memory_pool(checkpoint_name).

However, _get_memory_pool will correctly return the shared memory pool for the current user. The issue is that the p2p_store keys are generated using the pattern f"memory_pool_{checkpoint_name}_{idx}", so each checkpoint using the shared pool would have different p2p_store keys, even though they share the same underlying memory buffers. This could lead to p2p_store inconsistencies where multiple sets of keys point to the same memory, or failed unregistration attempts.

Suggested change (move the shared-pool check ahead of the p2p store unregistration):

if checkpoint_name == self._current_shared_memory_pool_user:
    logger.info(
        f"[rank{self._rank}] unregister shared memory pool from p2p store, skip unregistering from memory pool"
    )
    self._current_shared_memory_pool_user = ""
    return
if self._p2p_store is not None:
    num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
    logger.info(
        f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
    )
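A hypothetical alternative to the reordering, shown only to illustrate the key-collision concern: derive the p2p_store key from the pool that actually owns the buffers, so two checkpoints that reuse the same pinned memory never register competing key sets. The _p2p_key helper name is invented and not part of the PR.

def _p2p_key(self, checkpoint_name: str, idx: int) -> str:
    owner = (
        self.shared_memory_pool_name
        if checkpoint_name == self._current_shared_memory_pool_user
        else checkpoint_name
    )
    return f"memory_pool_{owner}_{idx}"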

del self._memory_pool[checkpoint_name]
# see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
# this works by using torch>=2.5.0
@@ -879,7 +947,7 @@ def gather_metas(self, checkpoint_name: str):
ptr=x.buffer.data_ptr(),
size=x.size,
)
for x in self._memory_pool.get(checkpoint_name, [])
for x in (self._get_memory_pool(checkpoint_name) or [])
],
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
host_ip=get_ip(),
Expand Down Expand Up @@ -1083,7 +1151,7 @@ def _copy_to_buffer(
remote_ptrs.append(ptrs[b.idx][0] + b.offset)
lens.append(b.size)
else:
pool = self._memory_pool[checkpoint_name][b.idx]
pool = self._get_memory_pool(checkpoint_name)[b.idx]
buffer[offset : offset + b.size].data.copy_(
pool.buffer[b.offset : b.offset + b.size],
non_blocking=True,
Expand Down Expand Up @@ -1142,7 +1210,7 @@ def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:

def _register_parameters_to_p2p_store(self, checkpoint_name: str):
assert self._p2p_store is not None, "p2p store is not initialized"
pool = self._memory_pool[checkpoint_name]
pool = self._get_memory_pool(checkpoint_name)
if len(pool) == 0:
return
named_tensors, tensor_ptrs = {}, []
@@ -1153,7 +1221,7 @@ def _register_parameters_to_p2p_store(self, checkpoint_name: str):

def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
assert self._p2p_store is not None, "p2p store is not initialized"
pool = self._memory_pool[checkpoint_name]
pool = self._get_memory_pool(checkpoint_name)
if len(pool) == 0:
return 0
return self._p2p_store.unregister_named_tensors(
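To make the motivation for reusing pinned buffers concrete, here is a standalone timing sketch (illustrative only, not from the PR; requires a CUDA-enabled PyTorch build). Page-locking host memory is the expensive step, so allocating the pinned buffer once and copying successive checkpoints through it is much cheaper than re-allocating it per checkpoint.

import time

import torch

size = 1 << 28  # 256 MiB
src = torch.zeros(size, dtype=torch.uint8)

t0 = time.perf_counter()
buf = torch.empty(size, dtype=torch.uint8, pin_memory=True)  # pays the page-locking cost
buf.copy_(src)
t1 = time.perf_counter()
buf.copy_(src)  # reuse: only a host-to-host copy into the already-pinned buffer
t2 = time.perf_counter()
print(f"allocate+copy: {t1 - t0:.3f}s  reuse (copy only): {t2 - t1:.3f}s")

Exact numbers depend on the machine, but the reuse path skips the page-lock entirely, which is what the shared memory pool exploits across repeated registrations.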
62 changes: 62 additions & 0 deletions tests/test_pin_memory.py
@@ -0,0 +1,62 @@
import os

import pytest
import torch

from checkpoint_engine.ps import ParameterServer


def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
"""
Generate dummy checkpoint data
"""
named_tensors = {
"layer1.weight": torch.randn(1024, 1024),
"layer1.bias": torch.randn(1024),
"layer2.weight": torch.randn(2048, 1024),
"layer2.bias": torch.randn(2048),
}
return named_tensors


@pytest.mark.gpu
def test_register_pin_memory():
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
ps = ParameterServer()
checkpoint1 = generate_dummy_checkpoint()
checkpoint_shared1 = generate_dummy_checkpoint()
checkpoint2 = generate_dummy_checkpoint()
checkpoint_shared2 = generate_dummy_checkpoint()
ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
ps.unregister_checkpoint("test_checkpoint1")
assert "test_checkpoint1" not in ps._memory_pool
ps.register_checkpoint(
"test_checkpoint_shared1", named_tensors=checkpoint_shared1, use_shared_memory_pool=True
)
ps.register_checkpoint("test_checkpoint2", named_tensors=checkpoint2)
assert "test_checkpoint_shared1" not in ps._memory_pool
assert "__shared_memory_pool__" in ps._memory_pool
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
assert "test_checkpoint2" in ps._memory_pool
ps.register_checkpoint(
"test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
) # this will fail
assert "test_checkpoint_shared2" not in ps._memory_pool
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
ps.unregister_checkpoint("test_checkpoint_shared1")
assert ps._current_shared_memory_pool_user == ""
assert "__shared_memory_pool__" in ps._memory_pool
ps.register_checkpoint(
"test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
)
assert "test_checkpoint_shared2" not in ps._memory_pool
assert "__shared_memory_pool__" in ps._memory_pool
assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2"
ps.unregister_checkpoint("test_checkpoint1")
assert "test_checkpoint1" not in ps._memory_pool
ps.unregister_checkpoint("test_checkpoint2")
assert "test_checkpoint2" not in ps._memory_pool
ps.unregister_checkpoint("test_checkpoint_shared2")
assert ps._current_shared_memory_pool_user == ""
assert "__shared_memory_pool__" in ps._memory_pool