diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py
index e3f9c56..4da6436 100644
--- a/checkpoint_engine/ps.py
+++ b/checkpoint_engine/ps.py
@@ -454,6 +454,7 @@ def _register_checkpoint(
     files: list[str],
     named_tensors: dict[str, torch.Tensor],
     rank: int | None = None,
+    shared_pin_memory: list[MemoryBuffer] | None = None,
 ) -> list[MemoryBuffer]:
     logger.info(
         f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
     )
@@ -483,16 +484,34 @@ class MemoryBucket(BaseModel):
         for bucket in buckets
     ]
 
-    def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
-        buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
-        return idx, buffer
+    def register_pin_memory(
+        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
+    ) -> tuple[int, torch.Tensor]:
+        if shared_pin_memory:
+            # If shared_pin_memory is provided, reuse its pinned buffers instead of allocating new ones.
+            # Reusing pinned memory only supports a fixed checkpoint shape, which is registered the first time.
+            assert idx < len(shared_pin_memory), (
+                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
+            )
+            assert shared_pin_memory[idx].size == size, (
+                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
+            )
+            return idx, shared_pin_memory[idx].buffer
+        else:
+            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
+            return idx, buffer
 
     def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
         buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
         futures = [
-            executor.submit(register_pin_memory, idx, bucket.size)
+            executor.submit(
+                register_pin_memory,
+                idx,
+                bucket.size,
+                shared_pin_memory,
+            )
             for idx, bucket in enumerate(buckets)
         ]
         new_futures = []
@@ -738,6 +757,8 @@ def batch_transfer_sync_read(
 
 
 class ParameterServer:
+    shared_memory_pool_name = "__shared_memory_pool__"
+
     def __init__(
         self,
         *,
@@ -781,7 +802,10 @@ def __init__(
         self._zmq_ctx = zmq.Context()
         self._zmq_addr_counter = 0
 
+        # stores the name of the checkpoint currently using the shared memory pool, or empty string if none
+        self._current_shared_memory_pool_user: str = ""
         self._memory_pool: dict[str, list[MemoryBuffer]] = {}
+        self._memory_pool[self.shared_memory_pool_name] = []
         # dict key is owner_rank, value is a bucket metas list in owner_rank
         self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
         try:
@@ -795,6 +819,17 @@ def __init__(
         self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
         self._rdma_device = None if self._p2p_store is None else self._p2p_store.device
 
+    def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]:
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            assert self._memory_pool[self.shared_memory_pool_name], (
+                f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it"
+            )
+            return self._memory_pool[self.shared_memory_pool_name]
+        elif checkpoint_name in self._memory_pool:
+            return self._memory_pool[checkpoint_name]
+        else:
+            raise RuntimeError(f"checkpoint {checkpoint_name} is not registered")
+
     def _logger_rank0(self, msg: str):
         if self._local_rank == 0:
             logger.info(msg)
@@ -818,6 +853,7 @@ def register_checkpoint(
         *,
         files: list[str] | None = None,
         named_tensors: dict[str, torch.Tensor] | None = None,
+        use_shared_memory_pool: bool = False,
     ) -> None:
         """
         Register a checkpoint to the parameter server. Both files and named_tensors will be registered together.
@@ -826,21 +862,46 @@ def register_checkpoint(
             checkpoint_name: The name of the checkpoint.
            files: The safetensors files to register.
            named_tensors: The named tensors to register.
+            use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory.
+                Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and
+                cannot accommodate checkpoints with different memory requirements.
         """
         try:
-            assert checkpoint_name not in self._memory_pool, (
-                f"checkpoint {checkpoint_name} already registered"
-            )
-            self._memory_pool[checkpoint_name] = _register_checkpoint(
-                files=files or [], named_tensors=named_tensors or {}, rank=self._rank
-            )
-            if self._p2p_store is not None:
-                self._register_parameters_to_p2p_store(checkpoint_name)
+            if use_shared_memory_pool:
+                logger.info(
+                    f"[rank{self._rank}] checkpoint {checkpoint_name} uses shared memory pool"
+                )
+                assert self._current_shared_memory_pool_user == "", (
+                    f"cannot register checkpoint {checkpoint_name} to shared memory pool, "
+                    f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. "
+                    f"This registration may cause unexpected conflicts."
+                )
+                # Since the uninitialized shared memory pool is set to an empty list,
+                # we can check whether this is the first time the shared memory pool is used.
+                _is_first_time = not self._memory_pool[self.shared_memory_pool_name]
+                self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint(
+                    files=files or [],
+                    named_tensors=named_tensors or {},
+                    rank=self._rank,
+                    shared_pin_memory=self._memory_pool[self.shared_memory_pool_name],
+                )
+                self._current_shared_memory_pool_user = checkpoint_name
+                if self._p2p_store is not None and _is_first_time:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
+            else:
+                assert checkpoint_name not in self._memory_pool, (
+                    f"checkpoint {checkpoint_name} already registered"
+                )
+                self._memory_pool[checkpoint_name] = _register_checkpoint(
+                    files=files or [], named_tensors=named_tensors or {}, rank=self._rank
+                )
+                if self._p2p_store is not None:
+                    self._register_parameters_to_p2p_store(checkpoint_name)
         except Exception:
             logger.exception(
                 f"[rank{self._rank}] fail to register checkpoint {checkpoint_name} with files {files}"
             )
-            if self._p2p_store is not None:
+            if self._p2p_store is not None and not use_shared_memory_pool:
                 self._unregister_parameters_from_p2p_store(checkpoint_name)
             self.unregister_checkpoint(checkpoint_name)
             raise
@@ -850,13 +911,25 @@ def unregister_checkpoint(self, checkpoint_name: str):
         Unregister a checkpoint from the parameter server.
         This function will also unregister the checkpoint from p2p store if p2p store is initialized.
         """
-        if checkpoint_name not in self._memory_pool:
+        if (
+            checkpoint_name not in self._memory_pool
+            and checkpoint_name != self._current_shared_memory_pool_user
+        ):
+            logger.warning(
+                f"[rank{self._rank}] unregister checkpoint name {checkpoint_name} not found"
+            )
+            return
+
+        if checkpoint_name == self._current_shared_memory_pool_user:
+            self._current_shared_memory_pool_user = ""
             return
+
         if self._p2p_store is not None:
             num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name)
             logger.info(
                 f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}"
             )
+
         del self._memory_pool[checkpoint_name]
         # see https://github.com/pytorch/pytorch/blob/31d5c675394705f8a6bc767f80ae14bf4f01246b/torch/csrc/cuda/Module.cpp#L2018
         # this works by using torch>=2.5.0
@@ -879,7 +952,7 @@ def gather_metas(self, checkpoint_name: str):
                     ptr=x.buffer.data_ptr(),
                     size=x.size,
                 )
-                for x in self._memory_pool.get(checkpoint_name, [])
+                for x in (self._get_memory_pool(checkpoint_name) or [])
             ],
             p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
             host_ip=get_ip(),
@@ -1083,7 +1156,7 @@ def _copy_to_buffer(
                 remote_ptrs.append(ptrs[b.idx][0] + b.offset)
                 lens.append(b.size)
             else:
-                pool = self._memory_pool[checkpoint_name][b.idx]
+                pool = self._get_memory_pool(checkpoint_name)[b.idx]
                 buffer[offset : offset + b.size].data.copy_(
                     pool.buffer[b.offset : b.offset + b.size],
                     non_blocking=True,
@@ -1142,7 +1215,7 @@ def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]:
 
     def _register_parameters_to_p2p_store(self, checkpoint_name: str):
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self._memory_pool[checkpoint_name]
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return
         named_tensors, tensor_ptrs = {}, []
@@ -1153,7 +1226,7 @@ def _register_parameters_to_p2p_store(self, checkpoint_name: str):
 
     def _unregister_parameters_from_p2p_store(self, checkpoint_name: str) -> int:
         assert self._p2p_store is not None, "p2p store is not initialized"
-        pool = self._memory_pool[checkpoint_name]
+        pool = self._get_memory_pool(checkpoint_name)
         if len(pool) == 0:
             return 0
         return self._p2p_store.unregister_named_tensors(
diff --git a/tests/test_pin_memory.py b/tests/test_pin_memory.py
new file mode 100644
index 0000000..8daa4ac
--- /dev/null
+++ b/tests/test_pin_memory.py
@@ -0,0 +1,65 @@
+import os
+
+import pytest
+import torch
+
+from checkpoint_engine.ps import ParameterServer
+
+
+def generate_dummy_checkpoint() -> dict[str, torch.Tensor]:
+    """
+    Generate dummy checkpoint data.
+    """
+    named_tensors = {
+        "layer1.weight": torch.randn(1024, 1024),
+        "layer1.bias": torch.randn(1024),
+        "layer2.weight": torch.randn(2048, 1024),
+        "layer2.bias": torch.randn(2048),
+    }
+    return named_tensors
+
+
+@pytest.mark.gpu
+def test_register_pin_memory():
+    os.environ["RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    ps = ParameterServer()
+    checkpoint1 = generate_dummy_checkpoint()
+    checkpoint_shared1 = generate_dummy_checkpoint()
+    checkpoint2 = generate_dummy_checkpoint()
+    checkpoint_shared2 = generate_dummy_checkpoint()
+    ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1)
+    ps.unregister_checkpoint("test_checkpoint1")
+    assert "test_checkpoint1" not in ps._memory_pool
+    ps.register_checkpoint(
+        "test_checkpoint_shared1", named_tensors=checkpoint_shared1, use_shared_memory_pool=True
+    )
+    ps.register_checkpoint("test_checkpoint2", named_tensors=checkpoint2)
+    assert "test_checkpoint_shared1" not in ps._memory_pool
+    assert "__shared_memory_pool__" in ps._memory_pool
+    assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
+    assert "test_checkpoint2" in ps._memory_pool
+    try:
+        ps.register_checkpoint(
+            "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
+        )  # this will fail
+    except AssertionError:
+        print("Caught expected AssertionError when registering second shared memory pool user")
+    assert "test_checkpoint_shared2" not in ps._memory_pool
+    assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1"
+    ps.unregister_checkpoint("test_checkpoint_shared1")
+    assert ps._current_shared_memory_pool_user == ""
+    assert "__shared_memory_pool__" in ps._memory_pool
+    ps.register_checkpoint(
+        "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True
+    )
+    assert "test_checkpoint_shared2" not in ps._memory_pool
+    assert "__shared_memory_pool__" in ps._memory_pool
+    assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2"
+    ps.unregister_checkpoint("test_checkpoint1")  # this will trigger a warning
+    assert "test_checkpoint1" not in ps._memory_pool
+    ps.unregister_checkpoint("test_checkpoint2")
+    assert "test_checkpoint2" not in ps._memory_pool
+    ps.unregister_checkpoint("test_checkpoint_shared2")
+    assert ps._current_shared_memory_pool_user == ""
+    assert "__shared_memory_pool__" in ps._memory_pool