-
Notifications
You must be signed in to change notification settings - Fork 67
feat: reuse pin_memory when registering checkpoint #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -454,6 +454,7 @@ def _register_checkpoint( | |||||||||||||||||||||||||||||||||||||||||||||
| files: list[str], | ||||||||||||||||||||||||||||||||||||||||||||||
| named_tensors: dict[str, torch.Tensor], | ||||||||||||||||||||||||||||||||||||||||||||||
| rank: int | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| shared_pin_memory: list[MemoryBuffer] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> list[MemoryBuffer]: | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -483,16 +484,33 @@ class MemoryBucket(BaseModel): | |||||||||||||||||||||||||||||||||||||||||||||
| for bucket in buckets | ||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]: | ||||||||||||||||||||||||||||||||||||||||||||||
| buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) | ||||||||||||||||||||||||||||||||||||||||||||||
| return idx, buffer | ||||||||||||||||||||||||||||||||||||||||||||||
| def register_pin_memory( | ||||||||||||||||||||||||||||||||||||||||||||||
| idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[int, torch.Tensor]: | ||||||||||||||||||||||||||||||||||||||||||||||
| if shared_pin_memory: | ||||||||||||||||||||||||||||||||||||||||||||||
| # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time | ||||||||||||||||||||||||||||||||||||||||||||||
| assert idx < len(shared_pin_memory), ( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| assert shared_pin_memory[idx].size == size, ( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| return idx, shared_pin_memory[idx].buffer | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True) | ||||||||||||||||||||||||||||||||||||||||||||||
| return idx, buffer | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||||||||
| buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: | ||||||||||||||||||||||||||||||||||||||||||||||
| futures = [ | ||||||||||||||||||||||||||||||||||||||||||||||
| executor.submit(register_pin_memory, idx, bucket.size) | ||||||||||||||||||||||||||||||||||||||||||||||
| executor.submit( | ||||||||||||||||||||||||||||||||||||||||||||||
| register_pin_memory, | ||||||||||||||||||||||||||||||||||||||||||||||
| idx, | ||||||||||||||||||||||||||||||||||||||||||||||
| bucket.size, | ||||||||||||||||||||||||||||||||||||||||||||||
| shared_pin_memory, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| for idx, bucket in enumerate(buckets) | ||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||
| new_futures = [] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -781,7 +799,11 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||
| self._zmq_ctx = zmq.Context() | ||||||||||||||||||||||||||||||||||||||||||||||
| self._zmq_addr_counter = 0 | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.shared_memory_pool_name = "__shared_memory_pool__" | ||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can place this into class ParameterServer:
_shared_memory_pool_name = "__shared_memory_pool__" |
||||||||||||||||||||||||||||||||||||||||||||||
| # stores the name of the checkpoint currently using the shared memory pool, or empty string if none | ||||||||||||||||||||||||||||||||||||||||||||||
| self._current_shared_memory_pool_user: str = "" | ||||||||||||||||||||||||||||||||||||||||||||||
| self._memory_pool: dict[str, list[MemoryBuffer]] = {} | ||||||||||||||||||||||||||||||||||||||||||||||
| self._memory_pool[self.shared_memory_pool_name] = [] | ||||||||||||||||||||||||||||||||||||||||||||||
| # dict key is owner_rank, value is a bucket metas list in owner_rank | ||||||||||||||||||||||||||||||||||||||||||||||
| self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {} | ||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -795,6 +817,18 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||
| self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index) | ||||||||||||||||||||||||||||||||||||||||||||||
| self._rdma_device = None if self._p2p_store is None else self._p2p_store.device | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _get_memory_pool(self, checkpoint_name: str) -> list[MemoryBuffer]: | ||||||||||||||||||||||||||||||||||||||||||||||
| if checkpoint_name == self._current_shared_memory_pool_user: | ||||||||||||||||||||||||||||||||||||||||||||||
| if not self._memory_pool[self.shared_memory_pool_name]: | ||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can use assert? |
||||||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"shared memory pool is not initialized, but checkpoint {checkpoint_name} is using it" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| return self._memory_pool[self.shared_memory_pool_name] | ||||||||||||||||||||||||||||||||||||||||||||||
| elif checkpoint_name in self._memory_pool: | ||||||||||||||||||||||||||||||||||||||||||||||
| return self._memory_pool[checkpoint_name] | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise RuntimeError(f"checkpoint {checkpoint_name} is not registered") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _logger_rank0(self, msg: str): | ||||||||||||||||||||||||||||||||||||||||||||||
| if self._local_rank == 0: | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.info(msg) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -818,6 +852,7 @@ def register_checkpoint( | |||||||||||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||||||||||||
| files: list[str] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| named_tensors: dict[str, torch.Tensor] | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||
| use_shared_memory_pool: bool = False, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| Register a checkpoint to the parameter server. Both files and named_tensors will be registered together. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -826,14 +861,34 @@ def register_checkpoint( | |||||||||||||||||||||||||||||||||||||||||||||
| checkpoint_name: The name of the checkpoint. | ||||||||||||||||||||||||||||||||||||||||||||||
| files: The safetensors files to register. | ||||||||||||||||||||||||||||||||||||||||||||||
| named_tensors: The named tensors to register. | ||||||||||||||||||||||||||||||||||||||||||||||
specture724 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||
| use_shared_memory_pool: If True, uses a reusable shared pin memory pool instead of allocating new memory. | ||||||||||||||||||||||||||||||||||||||||||||||
| Only one checkpoint can use the shared pool at a time. The pool's shape is fixed on first use and | ||||||||||||||||||||||||||||||||||||||||||||||
| cannot accommodate checkpoints with different memory requirements. | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||
| assert checkpoint_name not in self._memory_pool, ( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"checkpoint {checkpoint_name} already registered" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self._memory_pool[checkpoint_name] = _register_checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||
| files=files or [], named_tensors=named_tensors or {}, rank=self._rank | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| if use_shared_memory_pool: | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"[rank{self._rank}] checkpoint {checkpoint_name} use shared memory pool" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| assert self._current_shared_memory_pool_user == "", ( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"cannot register checkpoint {checkpoint_name} to shared memory pool, " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"since checkpoint {self._current_shared_memory_pool_user} is already using shared memory pool. " | ||||||||||||||||||||||||||||||||||||||||||||||
| f"This registration may cause unexpected conflicts." | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self._memory_pool[self.shared_memory_pool_name] = _register_checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||
| files=files or [], | ||||||||||||||||||||||||||||||||||||||||||||||
| named_tensors=named_tensors or {}, | ||||||||||||||||||||||||||||||||||||||||||||||
| rank=self._rank, | ||||||||||||||||||||||||||||||||||||||||||||||
| shared_pin_memory=self._memory_pool[self.shared_memory_pool_name], | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self._current_shared_memory_pool_user = checkpoint_name | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| assert checkpoint_name not in self._memory_pool, ( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"checkpoint {checkpoint_name} already registered" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self._memory_pool[checkpoint_name] = _register_checkpoint( | ||||||||||||||||||||||||||||||||||||||||||||||
| files=files or [], named_tensors=named_tensors or {}, rank=self._rank | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| if self._p2p_store is not None: | ||||||||||||||||||||||||||||||||||||||||||||||
| self._register_parameters_to_p2p_store(checkpoint_name) | ||||||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+878
to
894
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -850,13 +905,26 @@ def unregister_checkpoint(self, checkpoint_name: str): | |||||||||||||||||||||||||||||||||||||||||||||
| Unregister a checkpoint from the parameter server. This function will also unregister the checkpoint | ||||||||||||||||||||||||||||||||||||||||||||||
| from p2p store if p2p store is initialized. | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| if checkpoint_name not in self._memory_pool: | ||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||
| checkpoint_name not in self._memory_pool | ||||||||||||||||||||||||||||||||||||||||||||||
| and checkpoint_name != self._current_shared_memory_pool_user | ||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"[rank{self._rank}] unregister checkpoint failed, checkpoint name {checkpoint_name} not found" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||
| if self._p2p_store is not None: | ||||||||||||||||||||||||||||||||||||||||||||||
| num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name) | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| if checkpoint_name == self._current_shared_memory_pool_user: | ||||||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"[rank{self._rank}] unregister shared memory pool from p2p store, skip unregistering from memory pool" | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self._current_shared_memory_pool_user = "" | ||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
855
to
+926
|
||||||||||||||||||||||||||||||||||||||||||||||
| if self._p2p_store is not None: | |
| num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name) | |
| logger.info( | |
| f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}" | |
| ) | |
| if checkpoint_name == self._current_shared_memory_pool_user: | |
| logger.info( | |
| f"[rank{self._rank}] unregister shared memory pool from p2p store, skip unregistering from memory pool" | |
| ) | |
| self._current_shared_memory_pool_user = "" | |
| return | |
| if checkpoint_name == self._current_shared_memory_pool_user: | |
| logger.info( | |
| f"[rank{self._rank}] unregister shared memory pool from p2p store, skip unregistering from memory pool" | |
| ) | |
| self._current_shared_memory_pool_user = "" | |
| return | |
| if self._p2p_store is not None: | |
| num_unregistered = self._unregister_parameters_from_p2p_store(checkpoint_name) | |
| logger.info( | |
| f"[rank{self._rank}] unregister {num_unregistered} parameters from p2p store for checkpoint {checkpoint_name}" | |
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from checkpoint_engine.ps import ParameterServer | ||
|
|
||
|
|
||
| def generate_dummy_checkpoint() -> dict[str, torch.Tensor]: | ||
| """ | ||
| Generate dummy checkpoint data | ||
| """ | ||
| named_tensors = { | ||
| "layer1.weight": torch.randn(1024, 1024), | ||
| "layer1.bias": torch.randn(1024), | ||
| "layer2.weight": torch.randn(2048, 1024), | ||
| "layer2.bias": torch.randn(2048), | ||
| } | ||
| return named_tensors | ||
|
|
||
|
|
||
| @pytest.mark.gpu | ||
| def test_register_pin_memory(): | ||
| os.environ["RANK"] = "0" | ||
| os.environ["WORLD_SIZE"] = "1" | ||
| ps = ParameterServer() | ||
| checkpoint1 = generate_dummy_checkpoint() | ||
| checkpoint_shared1 = generate_dummy_checkpoint() | ||
| checkpoint2 = generate_dummy_checkpoint() | ||
| checkpoint_shared2 = generate_dummy_checkpoint() | ||
| ps.register_checkpoint("test_checkpoint1", named_tensors=checkpoint1) | ||
| ps.unregister_checkpoint("test_checkpoint1") | ||
| assert "test_checkpoint1" not in ps._memory_pool | ||
| ps.register_checkpoint( | ||
| "test_checkpoint_shared1", named_tensors=checkpoint_shared1, use_shared_memory_pool=True | ||
| ) | ||
| ps.register_checkpoint("test_checkpoint2", named_tensors=checkpoint2) | ||
| assert "test_checkpoint_shared1" not in ps._memory_pool | ||
| assert "__shared_memory_pool__" in ps._memory_pool | ||
| assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1" | ||
| assert "test_checkpoint2" in ps._memory_pool | ||
| ps.register_checkpoint( | ||
| "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True | ||
| ) # this will fail | ||
| assert "test_checkpoint_shared2" not in ps._memory_pool | ||
| assert ps._current_shared_memory_pool_user == "test_checkpoint_shared1" | ||
| ps.unregister_checkpoint("test_checkpoint_shared1") | ||
| assert ps._current_shared_memory_pool_user == "" | ||
| assert "__shared_memory_pool__" in ps._memory_pool | ||
| ps.register_checkpoint( | ||
| "test_checkpoint_shared2", named_tensors=checkpoint_shared2, use_shared_memory_pool=True | ||
| ) | ||
| assert "test_checkpoint_shared2" not in ps._memory_pool | ||
| assert "__shared_memory_pool__" in ps._memory_pool | ||
| assert ps._current_shared_memory_pool_user == "test_checkpoint_shared2" | ||
| ps.unregister_checkpoint("test_checkpoint1") | ||
| assert "test_checkpoint1" not in ps._memory_pool | ||
| ps.unregister_checkpoint("test_checkpoint2") | ||
| assert "test_checkpoint2" not in ps._memory_pool | ||
| ps.unregister_checkpoint("test_checkpoint_shared2") | ||
| assert ps._current_shared_memory_pool_user == "" | ||
| assert "__shared_memory_pool__" in ps._memory_pool |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think when
if shared_pin_memory, this executor thread may be unnecessary to submit