From 0e1ea4ca46bbf3aaf992e0d6e56afb602494d3c4 Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Wed, 26 Mar 2025 01:25:53 -0500 Subject: [PATCH 01/12] add support for HF2UCP feature Signed-off-by: Schwidola0607 --- deepspeed/checkpoint/hf_to_universal.py | 227 ++++++++++++++++++++++++ deepspeed/runtime/base_optimizer.py | 23 ++- deepspeed/runtime/engine.py | 33 ++-- deepspeed/runtime/state_dict_factory.py | 1 - deepspeed/runtime/zero/config.py | 7 +- deepspeed/runtime/zero/stage3.py | 39 ++-- deepspeed/runtime/zero/stage_1_and_2.py | 9 +- 7 files changed, 297 insertions(+), 42 deletions(-) create mode 100644 deepspeed/checkpoint/hf_to_universal.py diff --git a/deepspeed/checkpoint/hf_to_universal.py b/deepspeed/checkpoint/hf_to_universal.py new file mode 100644 index 000000000000..9867bd442a02 --- /dev/null +++ b/deepspeed/checkpoint/hf_to_universal.py @@ -0,0 +1,227 @@ +import torch +import os +import shutil +import logging +from concurrent.futures import ProcessPoolExecutor, as_completed +from tqdm import tqdm +from typing import List + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Constants for parameter patterns +VOCAB_PARAMETER_PATTERNS = [ + 'word_embeddings', + 'embed_tokens', + 'embedding', + 'wte', # GPT style embeddings + 'lm_head' # Often tied with embeddings +] + +ROW_PARALLEL_PATTERNS = [ + 'dense_h_to_4h', + 'fc1', + 'k_proj', + 'v_proj', + 'q_proj', + 'gate_proj', + 'up_proj' +] + +def get_parameter_type(name: str) -> dict: + """Determine parameter type and required fields based on name.""" + param_info = { + 'cat_dim': 0 # Default concatenation dimension + } + + # Check for vocabulary tensors (embeddings, etc.) + if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS): + param_info['vocab_tensor'] = True + + # TODO: figure out if this is needed + # # Check for row-parallel parameters + # if any(pattern in name.lower() for pattern in ROW_PARALLEL_PATTERNS): + # param_info['cat_dim'] = 1 # Use dimension 1 for row-parallel parameters + + return param_info + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser(description='Load a HuggingFace model') + parser.add_argument('--hf_checkpoint_dir', type=str, help='Path to the HuggingFace checkpoint directory') + parser.add_argument('--safe_serialization', action='store_true', default=False, help='Use safetensors for serialization') + parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints') + parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints') + args = parser.parse_args() + + # Create a temporary directory for atomic operations + temp_save_dir = args.save_dir + '.tmp' + + def save_parameter(name: str, param: torch.Tensor, save_dir: str): + """Save a parameter and its optimizer states in universal format.""" + # Create parameter directory under zero/ + param_dir = os.path.join(save_dir, name) + os.makedirs(param_dir, exist_ok=True) + + # Get parameter type and required fields + param_info = get_parameter_type(name) + + # Save parameter in fp32 with proper dictionary structure + param_path = os.path.join(param_dir, "fp32.pt") + param_dict = { + 'param': param.to(torch.float32), # Main tensor goes in 'param' field + **param_info # Include all determined parameter info + } + torch.save(param_dict, param_path) + + # Initialize optimizer states with zeros + for state in ("exp_avg", "exp_avg_sq"): + state_path = os.path.join(param_dir, f"{state}.pt") + state_dict = { + 'param': torch.zeros_like(param, dtype=torch.float32), + **param_info # Include same parameter info in optimizer states + } + torch.save(state_dict, state_path) + + def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization): + """Process a single shard file.""" + try: + shard_path = os.path.join(checkpoint_dir, shard_file) + logger.info(f"Loading shard from: {shard_path}") + + if safe_serialization: + from safetensors.torch import load_file + shard_dict = load_file(shard_path) + else: + shard_dict = torch.load(shard_path, map_location='cpu') + + # Create progress bar for parameters within this shard + pbar = tqdm(total=len(shard_dict), + desc=f"Processing {os.path.basename(shard_file)}", + position=1, + leave=False) + + for key, param in shard_dict.items(): + save_parameter(key, param, save_dir) + del param + pbar.update(1) + pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key}) + + pbar.close() + del shard_dict + torch.cuda.empty_cache() + logger.info(f"Completed processing shard: {shard_file}") + + except Exception as e: + logger.error(f"Error processing shard {shard_file}: {str(e)}") + raise + + def get_shard_list(checkpoint_dir): + """Get list of shards from index file.""" + if args.safe_serialization: + index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json") + else: + index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json") + + if os.path.exists(index_file): + import json + with open(index_file, 'r') as f: + index = json.load(f) + return list(set(index['weight_map'].values())) + else: + # Handle single file case + if args.safe_serialization: + return ["model.safetensors"] + else: + return ["pytorch_model.bin"] + + def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: str, safe_serialization: bool): + """Process a batch of shards in parallel.""" + with ProcessPoolExecutor(max_workers=args.num_workers) as executor: + futures = [] + for shard_file in shard_files: + future = executor.submit(process_shard, + shard_file, + checkpoint_dir, + save_dir, + safe_serialization) + futures.append((shard_file, future)) + + # Create progress bar for this batch + batch_pbar = tqdm(total=len(futures), + desc=f"Processing shard batch", + position=0, + leave=True) + + # Wait for all futures to complete + for shard_file, future in futures: + try: + future.result() # This will raise any exceptions that occurred + batch_pbar.update(1) + batch_pbar.set_postfix({'last_completed': os.path.basename(shard_file)}) + except Exception as e: + logger.error(f"Failed processing shard {shard_file}: {str(e)}") + raise + + batch_pbar.close() + + try: + # Create zero subdirectory in temp directory + temp_zero_dir = os.path.join(temp_save_dir, 'zero') + if os.path.exists(temp_zero_dir): + logger.info(f"Removing existing temp directory: {temp_zero_dir}") + shutil.rmtree(temp_zero_dir) + + shard_files = get_shard_list(args.hf_checkpoint_dir) + total_shards = len(shard_files) + logger.info(f"Found {total_shards} shards to process") + + # Process shards in batches equal to number of workers + batch_size = args.num_workers + for i in range(0, total_shards, batch_size): + batch_shards = shard_files[i:i + batch_size] + logger.info(f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})") + process_shard_batch(batch_shards, + args.hf_checkpoint_dir, + temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir + args.safe_serialization) + + # Force garbage collection after each batch + torch.cuda.empty_cache() + + logger.info("All shard batches processed successfully") + + final_save_dir = os.path.join(args.save_dir, 'zero') + if os.path.exists(final_save_dir): + shutil.rmtree(final_save_dir) + + # Create the parent directory if it doesn't exist + os.makedirs(os.path.dirname(final_save_dir), exist_ok=True) + # Move the zero directory to its final location + os.rename(temp_zero_dir, final_save_dir) + + # Clean up the temporary directory + if os.path.exists(temp_save_dir): + shutil.rmtree(temp_save_dir) + + # Write identifier file + with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f: + f.write("Huggingface checkpoint") + + logger.info(f"Successfully saved checkpoint to {final_save_dir}") + + # Update latest file + checkpoint_root_folder = os.path.dirname(args.save_dir) + step_folder = os.path.basename(args.save_dir) + latest_file = os.path.join(checkpoint_root_folder, 'latest_universal') + with open(latest_file, 'w') as f: + f.write(step_folder) + + logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}") + + except Exception as e: + logger.error(f"Failed to process checkpoint: {str(e)}") + if os.path.exists(temp_save_dir): + shutil.rmtree(temp_save_dir) + raise diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index d2c54155da89..18295125055e 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -17,14 +17,17 @@ class DeepSpeedOptimizer(object): class ZeROOptimizer(DeepSpeedOptimizer): - def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None: + def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str, ignore_missing_optim_state: bool = False) -> None: checkpoint_dir = os.path.join(checkpoint_dir, "zero") - optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") - assert os.path.isfile( - optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' - optim_sd = torch.load(optim_state_path, weights_only=False) - - self._load_global_state(optim_sd) + optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") + if not ignore_missing_optim_state: + assert os.path.isfile( + optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' + + optim_sd = torch.load(optim_state_path, weights_only=False) + self._load_global_state(optim_sd) + else: + optim_sd = {} tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) if self.mpu is None: @@ -34,8 +37,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \ else self.mpu.get_tensor_model_parallel_world_size() - for i, (param_group, - loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])): + for i, param_group in enumerate(self.optimizer.param_groups): # We have an assumption that all params in the same param_group have the same keys opt_keys = set() steps = [] @@ -57,6 +59,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys) + if ignore_missing_optim_state: + continue + loaded_param_group = optim_sd['param_groups'][i] for key, value in loaded_param_group.items(): if key == 'params': continue diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6c9577054f5f..a2274eaaa763 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2952,16 +2952,24 @@ def load_checkpoint(self, if self._optimizer_has_ckpt_event_prologue(): # Prepare for checkpoint load by ensuring all parameters are partitioned self.optimizer.checkpoint_event_prologue() - - load_path, client_states = self._load_checkpoint(load_dir, - tag, - load_module_strict=load_module_strict, - load_optimizer_states=load_optimizer_states, - load_lr_scheduler_states=load_lr_scheduler_states, - load_module_only=load_module_only, - custom_load_fn=custom_load_fn) - - load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) + + if not self.zero_ignore_missing_optim_state(): + # Temporary skip this path for HF-based UCP + load_path, client_states = self._load_checkpoint(load_dir, + tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + load_module_only=load_module_only, + custom_load_fn=custom_load_fn) + + load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) + + else: + # What should load_path and client_states be? + load_path, client_states = None, {} + load_zero_checkpoint = (self.zero_optimization() or self.bfloat16_enabled()) + if load_zero_checkpoint: if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) @@ -3001,7 +3009,7 @@ def _load_checkpoint(self, custom_load_fn=None): from deepspeed.runtime.state_dict_factory import SDLoaderFactory - + logger.info(f"Loading checkpoint from {load_dir} with tag {tag}") ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine) @@ -3159,7 +3167,8 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): load_from_fp32_weights=self.zero_load_from_fp32_weights(), checkpoint_folder=checkpoint_folder, load_serial=load_serial, - param_shapes=param_shapes) + param_shapes=param_shapes, + ignore_missing_optim_state=self.zero_ignore_missing_optim_state()) if self.load_universal_checkpoint(): logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}') diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index c2db85d1ba58..a3d0a3f25391 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -164,7 +164,6 @@ def set_module(self, sd, module): return sd def check_ckpt_list(self): - #logger.info(f'checkpoint file list: {self.ckpt_list}') assert len(self.ckpt_list) > 0 sd = self.checkpoint_engine.load(self.ckpt_list[0], map_location=lambda storage, loc: storage) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 19ee9b51702e..74880a1d5a26 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -272,7 +272,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): ignore_unused_parameters: bool = True """ Unused parameters in modules may be unexpected in static networks, but - could be normal in dynamic networks. This controls whether or not training + could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. This is set to ``True`` by default, which means unused parameters are ignored and training continues. Now is just used in stage 2. @@ -345,6 +345,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): """ Whether to log warnings from trace cache, such as invalidation events. """ + + ignore_missing_optim_state: bool = False + """ + Ignore missing optimizer states when loading checkpoint + """ # Validators @model_validator(mode="after") diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ec0cd92b3174..3202f1a1639e 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2692,7 +2692,8 @@ def load_state_dict(self, load_from_fp32_weights=False, checkpoint_folder=None, load_serial=None, - param_shapes=None): + param_shapes=None, + ignore_missing_optim_state: bool = False): r"""Loading a ZeRO checkpoint Arguments: state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. @@ -2723,7 +2724,7 @@ def load_state_dict(self, if checkpoint_folder: self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights, - param_shapes) + param_shapes, ignore_missing_optim_state=ignore_missing_optim_state) else: self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)], load_optimizer_states=load_optimizer_states) @@ -2745,18 +2746,19 @@ def load_state_dict(self, # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights, - param_shapes): - self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes) + param_shapes, ignore_missing_optim_state): + self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes, ignore_missing_optim_state) - def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes): + def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes, ignore_missing_optim_state): """ Load optimizer and model states from the checkpoint directory. """ checkpoint_dir = os.path.join(checkpoint_dir, "zero") optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") - assert os.path.isfile( - optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' + if not ignore_missing_optim_state: + assert os.path.isfile( + optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' - optim_sd = torch.load(optim_state_path, weights_only=False) - self._load_global_state_stage3(optim_sd) + optim_sd = torch.load(optim_state_path, weights_only=False) + self._load_global_state_stage3(optim_sd) key_list = ["fp32", "exp_avg", "exp_avg_sq"] @@ -2768,14 +2770,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa if key == "fp32": self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor) self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0]) - else: + elif not ignore_missing_optim_state: optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor if self.swap_optimizer: # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint self.optimizer_swapper.purge_state() - if self.swap_optimizer: # Touch all parameters to synchronize all buffers timer_names = set() self._partition_all_parameters() @@ -2785,9 +2786,10 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa self._release_sub_group(sub_group_id, timer_names) self._post_step(timer_names) - self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT]) - for param_group in self.optimizer.param_groups: - param_group['params'] = [] + if not ignore_missing_optim_state: + self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT]) + for param_group in self.optimizer.param_groups: + param_group['params'] = [] for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] @@ -2811,7 +2813,14 @@ def load_hp_checkpoint_state(self, folder, key): local_rank = dist.get_local_rank() # Load tensors from files and reshape them to flat vectors - loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1) + + loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False) + if isinstance(loaded_state, dict): + loaded_checkpoint_state = loaded_state['param'].view(-1) + elif isinstance(loaded_state, torch.Tensor): + loaded_checkpoint_state = loaded_state.view(-1) + else: + raise ValueError(f"Unknown type {type(loaded_state)} for loaded state") # Partition the loaded data according to the local rank world_size = dist.get_world_size(group=self.dp_process_group) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 861f7d23c9c2..38f4e98bc5dc 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2309,14 +2309,15 @@ def load_state_dict(self, load_from_fp32_weights=False, checkpoint_folder=None, load_serial=None, - param_shapes=None): + param_shapes=None, + ignore_missing_optim_state: bool = False): if checkpoint_folder: - self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) + self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights, ignore_missing_optim_state=ignore_missing_optim_state) else: self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights) - def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): - self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder) + def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights, ignore_missing_optim_state: bool = False): + self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder, ignore_missing_optim_state=ignore_missing_optim_state) def _load_global_state(self, sd): self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) From 727206e1534dcdeb22244bbad01b889dcc75c985 Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Wed, 26 Mar 2025 02:42:46 -0500 Subject: [PATCH 02/12] add user guide Signed-off-by: Schwidola0607 --- docs/_tutorials/hugging-face-to-ucp.md | 40 ++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 docs/_tutorials/hugging-face-to-ucp.md diff --git a/docs/_tutorials/hugging-face-to-ucp.md b/docs/_tutorials/hugging-face-to-ucp.md new file mode 100644 index 000000000000..9c892e7e234e --- /dev/null +++ b/docs/_tutorials/hugging-face-to-ucp.md @@ -0,0 +1,40 @@ +--- +title: "Converting a Hugging Face checkpoint to Universal Checkpointing format" +tags: checkpointing, training, deepspeed, huggingface +--- + +## Introduction to Universal Checkpointing + +Universal Checkpointing in DeepSpeed abstracts away the complexities of saving and loading model states, optimizer states, and training scheduler states. This feature is designed to work out of the box with minimal configuration, supporting a wide range of model sizes and types, from small-scale models to large, distributed models with different parallelism topologies trained across multiple GPUs and other accelerators. + +See more: https://www.deepspeed.ai/tutorials/universal-checkpointing/ + +## Converting a Hugging Face checkpoint to Universal Checkpointing format + +### Step 1: Download a Hugging Face checkpoint + +You can download a Hugging Face checkpoint from the Hugging Face Hub. For example, you can download the `openai-community/gpt2` checkpoint using the following script + +```python +from huggingface_hub import snapshot_download +local_dir = snapshot_download(repo_id="openai-community/gpt2") +``` + +### Step 2: Convert Hugging Face checkpoint to Universal Checkpointing format + +To convert a Hugging Face checkpoint to Universal Checkpointing format, you can use the `hf_to_universal.py` script provided in the DeepSpeed repository. This script will take a Hugging Face checkpoint directory and convert it to a Universal Checkpointing format. + +```bash +python hf_to_universal.py --hf_checkpoint_dir /path/to/huggingface/checkpoint --save_dir /path/to/universal/checkpoint +``` + +This script will process the Hugging Face checkpoint and generate a new checkpoint in the Universal Checkpointing format. Note that `hf_to_universal.py` script supports both safetensors and pytorch.bin checkpoint format. + +### Step 3: Resume Training with Universal Checkpoint +With the Universal checkpoint ready, you can now resume training on potentially with different parallelism topologies or training configurations. To do this add `--universal-checkpoint` to your DeepSpeed config (json) file + + +## Conclusion +DeepSpeed Universal Checkpointing simplifies the management of model states, making it easier to save, load, and transfer model states across different training sessions and parallelism techniques. By converting a Hugging Face checkpoint to Universal Checkpointing format, you can load pretrained weights of any model in the Hugging Face Hub and resume training with DeepSpeed under any parallelism topologies. + +For more detailed examples and advanced configurations, please refer to the [Megatron-DeepSpeed examples](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing). From 49588a85272424bf4765cfcd06b671e0f9da4c81 Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Wed, 26 Mar 2025 15:22:06 -0500 Subject: [PATCH 03/12] edit user guide Signed-off-by: Schwidola0607 --- docs/_tutorials/hugging-face-to-ucp.md | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/docs/_tutorials/hugging-face-to-ucp.md b/docs/_tutorials/hugging-face-to-ucp.md index 9c892e7e234e..636c96d29510 100644 --- a/docs/_tutorials/hugging-face-to-ucp.md +++ b/docs/_tutorials/hugging-face-to-ucp.md @@ -9,23 +9,19 @@ Universal Checkpointing in DeepSpeed abstracts away the complexities of saving a See more: https://www.deepspeed.ai/tutorials/universal-checkpointing/ -## Converting a Hugging Face checkpoint to Universal Checkpointing format +## Converting a pretrained Hugging Face checkpoint to Universal Checkpointing format -### Step 1: Download a Hugging Face checkpoint +### Step 1: Download a pretrained Hugging Face checkpoint +Download a pretrained Hugging Face checkpoint from the Hugging Face Hub using [snapshot_download](https://huggingface.co/docs/huggingface_hub/en/guides/download) -You can download a Hugging Face checkpoint from the Hugging Face Hub. For example, you can download the `openai-community/gpt2` checkpoint using the following script - -```python -from huggingface_hub import snapshot_download -local_dir = snapshot_download(repo_id="openai-community/gpt2") -``` +Hugging Face checkpoints are one or many files in the pytorch_model.bin or safetensors format. ### Step 2: Convert Hugging Face checkpoint to Universal Checkpointing format To convert a Hugging Face checkpoint to Universal Checkpointing format, you can use the `hf_to_universal.py` script provided in the DeepSpeed repository. This script will take a Hugging Face checkpoint directory and convert it to a Universal Checkpointing format. ```bash -python hf_to_universal.py --hf_checkpoint_dir /path/to/huggingface/checkpoint --save_dir /path/to/universal/checkpoint +python deepspeed/checkpoint/hf_to_universal.py --hf_checkpoint_dir /path/to/huggingface/checkpoint --save_dir /path/to/universal/checkpoint ``` This script will process the Hugging Face checkpoint and generate a new checkpoint in the Universal Checkpointing format. Note that `hf_to_universal.py` script supports both safetensors and pytorch.bin checkpoint format. From 7bef517c5c03f5172e2913073d27ae58103403bd Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Wed, 26 Mar 2025 15:27:52 -0500 Subject: [PATCH 04/12] cleaning up Signed-off-by: Schwidola0607 --- deepspeed/checkpoint/hf_to_universal.py | 20 ++++---------------- docs/_tutorials/hugging-face-to-ucp.md | 4 ++-- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/deepspeed/checkpoint/hf_to_universal.py b/deepspeed/checkpoint/hf_to_universal.py index 9867bd442a02..7cc6e2e3ea4a 100644 --- a/deepspeed/checkpoint/hf_to_universal.py +++ b/deepspeed/checkpoint/hf_to_universal.py @@ -9,7 +9,7 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Constants for parameter patterns +# Hard-coded constants for parameter patterns VOCAB_PARAMETER_PATTERNS = [ 'word_embeddings', 'embed_tokens', @@ -18,15 +18,6 @@ 'lm_head' # Often tied with embeddings ] -ROW_PARALLEL_PATTERNS = [ - 'dense_h_to_4h', - 'fc1', - 'k_proj', - 'v_proj', - 'q_proj', - 'gate_proj', - 'up_proj' -] def get_parameter_type(name: str) -> dict: """Determine parameter type and required fields based on name.""" @@ -38,11 +29,7 @@ def get_parameter_type(name: str) -> dict: if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS): param_info['vocab_tensor'] = True - # TODO: figure out if this is needed - # # Check for row-parallel parameters - # if any(pattern in name.lower() for pattern in ROW_PARALLEL_PATTERNS): - # param_info['cat_dim'] = 1 # Use dimension 1 for row-parallel parameters - + # TODO: figure out if we need to check for row-parallel parameters return param_info if __name__ == '__main__': @@ -75,7 +62,8 @@ def save_parameter(name: str, param: torch.Tensor, save_dir: str): } torch.save(param_dict, param_path) - # Initialize optimizer states with zeros + # Since HuggingFace checkpoints do not have optimizer states, + # we initialize them with zeros for state in ("exp_avg", "exp_avg_sq"): state_path = os.path.join(param_dir, f"{state}.pt") state_dict = { diff --git a/docs/_tutorials/hugging-face-to-ucp.md b/docs/_tutorials/hugging-face-to-ucp.md index 636c96d29510..31ef903f2a89 100644 --- a/docs/_tutorials/hugging-face-to-ucp.md +++ b/docs/_tutorials/hugging-face-to-ucp.md @@ -14,7 +14,7 @@ See more: https://www.deepspeed.ai/tutorials/universal-checkpointing/ ### Step 1: Download a pretrained Hugging Face checkpoint Download a pretrained Hugging Face checkpoint from the Hugging Face Hub using [snapshot_download](https://huggingface.co/docs/huggingface_hub/en/guides/download) -Hugging Face checkpoints are one or many files in the pytorch_model.bin or safetensors format. +Hugging Face checkpoints are one or many files in the `pytorch_model.bin` or `safetensors format`. ### Step 2: Convert Hugging Face checkpoint to Universal Checkpointing format @@ -24,7 +24,7 @@ To convert a Hugging Face checkpoint to Universal Checkpointing format, you can python deepspeed/checkpoint/hf_to_universal.py --hf_checkpoint_dir /path/to/huggingface/checkpoint --save_dir /path/to/universal/checkpoint ``` -This script will process the Hugging Face checkpoint and generate a new checkpoint in the Universal Checkpointing format. Note that `hf_to_universal.py` script supports both safetensors and pytorch.bin checkpoint format. +This script will process the Hugging Face checkpoint and generate a new checkpoint in the Universal Checkpointing format. Note that `hf_to_universal.py` script supports both `.safetensors` and `pytorch.bin` checkpoint format. Use `--safe_serialization` flag to convert from `.safetensors` format. ### Step 3: Resume Training with Universal Checkpoint With the Universal checkpoint ready, you can now resume training on potentially with different parallelism topologies or training configurations. To do this add `--universal-checkpoint` to your DeepSpeed config (json) file From 2930f2a0e9c4e69e4fa6de71309f489a1b6d46de Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Tue, 1 Apr 2025 14:48:29 -0500 Subject: [PATCH 05/12] nits Signed-off-by: Schwidola0607 --- deepspeed/checkpoint/hf_to_universal.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/deepspeed/checkpoint/hf_to_universal.py b/deepspeed/checkpoint/hf_to_universal.py index 7cc6e2e3ea4a..bc1d953beb05 100644 --- a/deepspeed/checkpoint/hf_to_universal.py +++ b/deepspeed/checkpoint/hf_to_universal.py @@ -14,8 +14,8 @@ 'word_embeddings', 'embed_tokens', 'embedding', - 'wte', # GPT style embeddings - 'lm_head' # Often tied with embeddings + 'wte', # GPT style embeddings + 'lm_head' # Language model head, often tied with embeddings ] @@ -35,8 +35,8 @@ def get_parameter_type(name: str) -> dict: if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='Load a HuggingFace model') - parser.add_argument('--hf_checkpoint_dir', type=str, help='Path to the HuggingFace checkpoint directory') + parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format') + parser.add_argument('--hf_checkpoint_dir', type=str, required=True, help='Path to the HuggingFace checkpoint directory') parser.add_argument('--safe_serialization', action='store_true', default=False, help='Use safetensors for serialization') parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints') parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints') @@ -119,10 +119,12 @@ def get_shard_list(checkpoint_dir): return list(set(index['weight_map'].values())) else: # Handle single file case - if args.safe_serialization: + if args.safe_serialization and os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")): return ["model.safetensors"] - else: + elif os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")): return ["pytorch_model.bin"] + else: + raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}") def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: str, safe_serialization: bool): """Process a batch of shards in parallel.""" From 9207df99dac5efa13017bea1057fa34480cce0fc Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Sat, 5 Apr 2025 06:55:29 -0500 Subject: [PATCH 06/12] remove ignore_missing_optim config from zero ds_config Signed-off-by: Schwidola0607 --- deepspeed/runtime/base_optimizer.py | 12 ++++----- deepspeed/runtime/engine.py | 33 +++++++++---------------- deepspeed/runtime/state_dict_factory.py | 1 + deepspeed/runtime/zero/config.py | 7 +----- deepspeed/runtime/zero/stage3.py | 22 ++++++++--------- deepspeed/runtime/zero/stage_1_and_2.py | 9 +++---- 6 files changed, 35 insertions(+), 49 deletions(-) diff --git a/deepspeed/runtime/base_optimizer.py b/deepspeed/runtime/base_optimizer.py index 18295125055e..eddfae829624 100644 --- a/deepspeed/runtime/base_optimizer.py +++ b/deepspeed/runtime/base_optimizer.py @@ -17,16 +17,16 @@ class DeepSpeedOptimizer(object): class ZeROOptimizer(DeepSpeedOptimizer): - def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str, ignore_missing_optim_state: bool = False) -> None: + def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None: checkpoint_dir = os.path.join(checkpoint_dir, "zero") - optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") - if not ignore_missing_optim_state: - assert os.path.isfile( - optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' - + optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") + if os.path.isfile(optim_state_path): + ignore_missing_optim_state = False optim_sd = torch.load(optim_state_path, weights_only=False) self._load_global_state(optim_sd) else: + logger.warning(f'{optim_state_path} containing optimizer global state is missing!') + ignore_missing_optim_state = True optim_sd = {} tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index a2274eaaa763..6c9577054f5f 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2952,24 +2952,16 @@ def load_checkpoint(self, if self._optimizer_has_ckpt_event_prologue(): # Prepare for checkpoint load by ensuring all parameters are partitioned self.optimizer.checkpoint_event_prologue() - - if not self.zero_ignore_missing_optim_state(): - # Temporary skip this path for HF-based UCP - load_path, client_states = self._load_checkpoint(load_dir, - tag, - load_module_strict=load_module_strict, - load_optimizer_states=load_optimizer_states, - load_lr_scheduler_states=load_lr_scheduler_states, - load_module_only=load_module_only, - custom_load_fn=custom_load_fn) - - load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) - - else: - # What should load_path and client_states be? - load_path, client_states = None, {} - load_zero_checkpoint = (self.zero_optimization() or self.bfloat16_enabled()) - + + load_path, client_states = self._load_checkpoint(load_dir, + tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + load_module_only=load_module_only, + custom_load_fn=custom_load_fn) + + load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) if load_zero_checkpoint: if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) @@ -3009,7 +3001,7 @@ def _load_checkpoint(self, custom_load_fn=None): from deepspeed.runtime.state_dict_factory import SDLoaderFactory - logger.info(f"Loading checkpoint from {load_dir} with tag {tag}") + ckpt_list = self._get_all_ckpt_names(load_dir, tag) sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine) @@ -3167,8 +3159,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): load_from_fp32_weights=self.zero_load_from_fp32_weights(), checkpoint_folder=checkpoint_folder, load_serial=load_serial, - param_shapes=param_shapes, - ignore_missing_optim_state=self.zero_ignore_missing_optim_state()) + param_shapes=param_shapes) if self.load_universal_checkpoint(): logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}') diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py index a3d0a3f25391..c2db85d1ba58 100755 --- a/deepspeed/runtime/state_dict_factory.py +++ b/deepspeed/runtime/state_dict_factory.py @@ -164,6 +164,7 @@ def set_module(self, sd, module): return sd def check_ckpt_list(self): + #logger.info(f'checkpoint file list: {self.ckpt_list}') assert len(self.ckpt_list) > 0 sd = self.checkpoint_engine.load(self.ckpt_list[0], map_location=lambda storage, loc: storage) diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 74880a1d5a26..19ee9b51702e 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -272,7 +272,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): ignore_unused_parameters: bool = True """ Unused parameters in modules may be unexpected in static networks, but - could be normal in dynamic networks. This controls whether or not training + could be normal in dynamic networks. This controls whether or not training should terminate with an error message when unused parameters are detected. This is set to ``True`` by default, which means unused parameters are ignored and training continues. Now is just used in stage 2. @@ -345,11 +345,6 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): """ Whether to log warnings from trace cache, such as invalidation events. """ - - ignore_missing_optim_state: bool = False - """ - Ignore missing optimizer states when loading checkpoint - """ # Validators @model_validator(mode="after") diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 3202f1a1639e..a9f33c275aa6 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2692,8 +2692,7 @@ def load_state_dict(self, load_from_fp32_weights=False, checkpoint_folder=None, load_serial=None, - param_shapes=None, - ignore_missing_optim_state: bool = False): + param_shapes=None): r"""Loading a ZeRO checkpoint Arguments: state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. @@ -2724,7 +2723,7 @@ def load_state_dict(self, if checkpoint_folder: self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights, - param_shapes, ignore_missing_optim_state=ignore_missing_optim_state) + param_shapes) else: self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)], load_optimizer_states=load_optimizer_states) @@ -2746,19 +2745,20 @@ def load_state_dict(self, # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights, - param_shapes, ignore_missing_optim_state): - self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes, ignore_missing_optim_state) + param_shapes): + self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes) - def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes, ignore_missing_optim_state): + def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes): """ Load optimizer and model states from the checkpoint directory. """ checkpoint_dir = os.path.join(checkpoint_dir, "zero") optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt") - if not ignore_missing_optim_state: - assert os.path.isfile( - optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.' - + if os.path.isfile(optim_state_path): + ignore_missing_optim_state = False optim_sd = torch.load(optim_state_path, weights_only=False) self._load_global_state_stage3(optim_sd) + else: + logger.warning(f'{optim_state_path} containing optimizer global state is missing!') + ignore_missing_optim_state = True key_list = ["fp32", "exp_avg", "exp_avg_sq"] @@ -2777,6 +2777,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint self.optimizer_swapper.purge_state() + if self.swap_optimizer: # Touch all parameters to synchronize all buffers timer_names = set() self._partition_all_parameters() @@ -2813,7 +2814,6 @@ def load_hp_checkpoint_state(self, folder, key): local_rank = dist.get_local_rank() # Load tensors from files and reshape them to flat vectors - loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False) if isinstance(loaded_state, dict): loaded_checkpoint_state = loaded_state['param'].view(-1) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 38f4e98bc5dc..861f7d23c9c2 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -2309,15 +2309,14 @@ def load_state_dict(self, load_from_fp32_weights=False, checkpoint_folder=None, load_serial=None, - param_shapes=None, - ignore_missing_optim_state: bool = False): + param_shapes=None): if checkpoint_folder: - self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights, ignore_missing_optim_state=ignore_missing_optim_state) + self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) else: self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights) - def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights, ignore_missing_optim_state: bool = False): - self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder, ignore_missing_optim_state=ignore_missing_optim_state) + def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): + self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder) def _load_global_state(self, sd): self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler) From 809036991d43a974bc355e7513c26d0e448eed6d Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Thu, 10 Apr 2025 04:10:55 -0500 Subject: [PATCH 07/12] fix to make ucp load more lenient Signed-off-by: Schwidola0607 --- deepspeed/runtime/engine.py | 16 +++++++++++++--- deepspeed/runtime/zero/stage3.py | 3 +-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 6c9577054f5f..4c000df8e4b7 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2899,7 +2899,7 @@ def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_files = glob.glob(ckpt_file_pattern) ckpt_files.sort() - return ckpt_files + return ckpt_files, ckpt_file_pattern def load_checkpoint(self, load_dir, @@ -2923,7 +2923,7 @@ def load_checkpoint(self, Returns: A tuple of ``load_path`` and ``client_state``. - *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed. + *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or loading a HF based UCP *``client_state``: State dictionary used for loading required training states in the client code. Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right @@ -2962,6 +2962,12 @@ def load_checkpoint(self, custom_load_fn=custom_load_fn) load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) + # import pdb; pdb.set_trace() + if self.load_universal_checkpoint(): + ucp_ckpt_folder = os.path.join(load_dir, tag) + # UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist + load_zero_checkpoint = os.path.isdir(ucp_ckpt_folder) + if load_zero_checkpoint: if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) @@ -3002,7 +3008,11 @@ def _load_checkpoint(self, from deepspeed.runtime.state_dict_factory import SDLoaderFactory - ckpt_list = self._get_all_ckpt_names(load_dir, tag) + ckpt_list, ckpt_file_pattern = self._get_all_ckpt_names(load_dir, tag) + if self.load_universal_checkpoint() and len(ckpt_list) == 0: + logger.warning(f"Unable to find {ckpt_file_pattern} files in UCP folder {load_dir}") + return None, {} + sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine) is_pipe_parallel = isinstance(self.module, PipelineModule) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index a9f33c275aa6..b98f275b20c3 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2777,7 +2777,6 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint self.optimizer_swapper.purge_state() - if self.swap_optimizer: # Touch all parameters to synchronize all buffers timer_names = set() self._partition_all_parameters() @@ -2812,7 +2811,7 @@ def _load_global_state_stage3(self, sd): def load_hp_checkpoint_state(self, folder, key): local_rank = dist.get_local_rank() - + # Load tensors from files and reshape them to flat vectors loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False) if isinstance(loaded_state, dict): From 7b8962a6144a1d033899aaa4f966dc88e59a35e2 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Sat, 12 Apr 2025 22:26:13 -0500 Subject: [PATCH 08/12] nits Signed-off-by: Schwidola0607 --- deepspeed/checkpoint/hf_to_universal.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/deepspeed/checkpoint/hf_to_universal.py b/deepspeed/checkpoint/hf_to_universal.py index bc1d953beb05..8de5ba793a8a 100644 --- a/deepspeed/checkpoint/hf_to_universal.py +++ b/deepspeed/checkpoint/hf_to_universal.py @@ -166,18 +166,17 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s shard_files = get_shard_list(args.hf_checkpoint_dir) total_shards = len(shard_files) logger.info(f"Found {total_shards} shards to process") - - # Process shards in batches equal to number of workers + # Process shards in batches equal to the number of workers batch_size = args.num_workers for i in range(0, total_shards, batch_size): batch_shards = shard_files[i:i + batch_size] logger.info(f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})") process_shard_batch(batch_shards, - args.hf_checkpoint_dir, - temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir - args.safe_serialization) + args.hf_checkpoint_dir, + temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir + args.safe_serialization) - # Force garbage collection after each batch + # Clear CUDA cache after each batch to free up memory torch.cuda.empty_cache() logger.info("All shard batches processed successfully") From 23895674c662b97fc6e4d1e4a79e5bb104ce32be Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Sat, 12 Apr 2025 22:49:19 -0500 Subject: [PATCH 09/12] nits Signed-off-by: Schwidola0607 --- deepspeed/runtime/engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 4c000df8e4b7..e7a91e9b90ea 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2962,7 +2962,6 @@ def load_checkpoint(self, custom_load_fn=custom_load_fn) load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled()) - # import pdb; pdb.set_trace() if self.load_universal_checkpoint(): ucp_ckpt_folder = os.path.join(load_dir, tag) # UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist From 2fa08897b81fd4140b52291b3046725bac0306f0 Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Thu, 1 May 2025 18:30:58 -0500 Subject: [PATCH 10/12] formatting and license Signed-off-by: Schwidola0607 --- deepspeed/checkpoint/hf_to_universal.py | 107 +++++++++++++----------- deepspeed/runtime/engine.py | 2 +- deepspeed/runtime/zero/stage3.py | 2 +- 3 files changed, 60 insertions(+), 51 deletions(-) diff --git a/deepspeed/checkpoint/hf_to_universal.py b/deepspeed/checkpoint/hf_to_universal.py index 8de5ba793a8a..e68be7d8780e 100644 --- a/deepspeed/checkpoint/hf_to_universal.py +++ b/deepspeed/checkpoint/hf_to_universal.py @@ -1,8 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + import torch import os import shutil import logging -from concurrent.futures import ProcessPoolExecutor, as_completed +from concurrent.futures import ProcessPoolExecutor +from deepspeed.accelerator import get_accelerator from tqdm import tqdm from typing import List @@ -14,8 +20,8 @@ 'word_embeddings', 'embed_tokens', 'embedding', - 'wte', # GPT style embeddings - 'lm_head' # Language model head, often tied with embeddings + 'wte', # GPT style embeddings + 'lm_head' # Language model head, often tied with embeddings ] @@ -24,20 +30,27 @@ def get_parameter_type(name: str) -> dict: param_info = { 'cat_dim': 0 # Default concatenation dimension } - + # Check for vocabulary tensors (embeddings, etc.) if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS): param_info['vocab_tensor'] = True - + # TODO: figure out if we need to check for row-parallel parameters return param_info + if __name__ == '__main__': import argparse - + parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format') - parser.add_argument('--hf_checkpoint_dir', type=str, required=True, help='Path to the HuggingFace checkpoint directory') - parser.add_argument('--safe_serialization', action='store_true', default=False, help='Use safetensors for serialization') + parser.add_argument('--hf_checkpoint_dir', + type=str, + required=True, + help='Path to the HuggingFace checkpoint directory') + parser.add_argument('--safe_serialization', + action='store_true', + default=False, + help='Use safetensors for serialization') parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints') parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints') args = parser.parse_args() @@ -50,10 +63,10 @@ def save_parameter(name: str, param: torch.Tensor, save_dir: str): # Create parameter directory under zero/ param_dir = os.path.join(save_dir, name) os.makedirs(param_dir, exist_ok=True) - + # Get parameter type and required fields param_info = get_parameter_type(name) - + # Save parameter in fp32 with proper dictionary structure param_path = os.path.join(param_dir, "fp32.pt") param_dict = { @@ -61,8 +74,8 @@ def save_parameter(name: str, param: torch.Tensor, save_dir: str): **param_info # Include all determined parameter info } torch.save(param_dict, param_path) - - # Since HuggingFace checkpoints do not have optimizer states, + + # Since HuggingFace checkpoints do not have optimizer states, # we initialize them with zeros for state in ("exp_avg", "exp_avg_sq"): state_path = os.path.join(param_dir, f"{state}.pt") @@ -77,30 +90,30 @@ def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization): try: shard_path = os.path.join(checkpoint_dir, shard_file) logger.info(f"Loading shard from: {shard_path}") - + if safe_serialization: from safetensors.torch import load_file shard_dict = load_file(shard_path) else: shard_dict = torch.load(shard_path, map_location='cpu') - + # Create progress bar for parameters within this shard - pbar = tqdm(total=len(shard_dict), - desc=f"Processing {os.path.basename(shard_file)}", - position=1, - leave=False) - + pbar = tqdm(total=len(shard_dict), + desc=f"Processing {os.path.basename(shard_file)}", + position=1, + leave=False) + for key, param in shard_dict.items(): save_parameter(key, param, save_dir) del param pbar.update(1) pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key}) - + pbar.close() del shard_dict - torch.cuda.empty_cache() + get_accelerator().empty_cache() logger.info(f"Completed processing shard: {shard_file}") - + except Exception as e: logger.error(f"Error processing shard {shard_file}: {str(e)}") raise @@ -111,7 +124,7 @@ def get_shard_list(checkpoint_dir): index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json") else: index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json") - + if os.path.exists(index_file): import json with open(index_file, 'r') as f: @@ -131,18 +144,11 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s with ProcessPoolExecutor(max_workers=args.num_workers) as executor: futures = [] for shard_file in shard_files: - future = executor.submit(process_shard, - shard_file, - checkpoint_dir, - save_dir, - safe_serialization) + future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization) futures.append((shard_file, future)) - + # Create progress bar for this batch - batch_pbar = tqdm(total=len(futures), - desc=f"Processing shard batch", - position=0, - leave=True) + batch_pbar = tqdm(total=len(futures), desc=f"Processing shard batch", position=0, leave=True) # Wait for all futures to complete for shard_file, future in futures: @@ -153,7 +159,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s except Exception as e: logger.error(f"Failed processing shard {shard_file}: {str(e)}") raise - + batch_pbar.close() try: @@ -162,7 +168,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s if os.path.exists(temp_zero_dir): logger.info(f"Removing existing temp directory: {temp_zero_dir}") shutil.rmtree(temp_zero_dir) - + shard_files = get_shard_list(args.hf_checkpoint_dir) total_shards = len(shard_files) logger.info(f"Found {total_shards} shards to process") @@ -170,34 +176,37 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s batch_size = args.num_workers for i in range(0, total_shards, batch_size): batch_shards = shard_files[i:i + batch_size] - logger.info(f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})") - process_shard_batch(batch_shards, - args.hf_checkpoint_dir, - temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir - args.safe_serialization) - + logger.info( + f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})" + ) + process_shard_batch( + batch_shards, + args.hf_checkpoint_dir, + temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir + args.safe_serialization) + # Clear CUDA cache after each batch to free up memory - torch.cuda.empty_cache() - + get_accelerator().empty_cache() + logger.info("All shard batches processed successfully") - + final_save_dir = os.path.join(args.save_dir, 'zero') if os.path.exists(final_save_dir): shutil.rmtree(final_save_dir) - + # Create the parent directory if it doesn't exist os.makedirs(os.path.dirname(final_save_dir), exist_ok=True) # Move the zero directory to its final location os.rename(temp_zero_dir, final_save_dir) - + # Clean up the temporary directory if os.path.exists(temp_save_dir): shutil.rmtree(temp_save_dir) - + # Write identifier file with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f: f.write("Huggingface checkpoint") - + logger.info(f"Successfully saved checkpoint to {final_save_dir}") # Update latest file @@ -206,7 +215,7 @@ def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: s latest_file = os.path.join(checkpoint_root_folder, 'latest_universal') with open(latest_file, 'w') as f: f.write(step_folder) - + logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}") except Exception as e: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index dc387d6a9cff..cf38b601c915 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3035,7 +3035,7 @@ def _load_checkpoint(self, if self.load_universal_checkpoint() and len(ckpt_list) == 0: logger.warning(f"Unable to find {ckpt_file_pattern} files in UCP folder {load_dir}") return None, {} - + sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine) is_pipe_parallel = isinstance(self.module, PipelineModule) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index dcfd9523a5ed..5662ff056ca1 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2812,7 +2812,7 @@ def _load_global_state_stage3(self, sd): def load_hp_checkpoint_state(self, folder, key): local_rank = dist.get_local_rank() - + # Load tensors from files and reshape them to flat vectors loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False) if isinstance(loaded_state, dict): From 724a48066beec875f596398d1ab7dade593c9744 Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Fri, 27 Jun 2025 01:30:21 -0500 Subject: [PATCH 11/12] Minor comment fix Signed-off-by: Schwidola0607 --- deepspeed/runtime/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index cf38b601c915..a0b212f524ea 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2947,7 +2947,7 @@ def load_checkpoint(self, Returns: A tuple of ``load_path`` and ``client_state``. - *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or loading a HF based UCP + *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or loading a HF based UCP. *``client_state``: State dictionary used for loading required training states in the client code. Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right From 64f5bc3f5a184871684a4a2fe175d96c5475c187 Mon Sep 17 00:00:00 2001 From: Schwidola0607 Date: Fri, 27 Jun 2025 01:32:36 -0500 Subject: [PATCH 12/12] remove document Signed-off-by: Schwidola0607 --- docs/_tutorials/hugging-face-to-ucp.md | 36 -------------------------- 1 file changed, 36 deletions(-) delete mode 100644 docs/_tutorials/hugging-face-to-ucp.md diff --git a/docs/_tutorials/hugging-face-to-ucp.md b/docs/_tutorials/hugging-face-to-ucp.md deleted file mode 100644 index 31ef903f2a89..000000000000 --- a/docs/_tutorials/hugging-face-to-ucp.md +++ /dev/null @@ -1,36 +0,0 @@ ---- -title: "Converting a Hugging Face checkpoint to Universal Checkpointing format" -tags: checkpointing, training, deepspeed, huggingface ---- - -## Introduction to Universal Checkpointing - -Universal Checkpointing in DeepSpeed abstracts away the complexities of saving and loading model states, optimizer states, and training scheduler states. This feature is designed to work out of the box with minimal configuration, supporting a wide range of model sizes and types, from small-scale models to large, distributed models with different parallelism topologies trained across multiple GPUs and other accelerators. - -See more: https://www.deepspeed.ai/tutorials/universal-checkpointing/ - -## Converting a pretrained Hugging Face checkpoint to Universal Checkpointing format - -### Step 1: Download a pretrained Hugging Face checkpoint -Download a pretrained Hugging Face checkpoint from the Hugging Face Hub using [snapshot_download](https://huggingface.co/docs/huggingface_hub/en/guides/download) - -Hugging Face checkpoints are one or many files in the `pytorch_model.bin` or `safetensors format`. - -### Step 2: Convert Hugging Face checkpoint to Universal Checkpointing format - -To convert a Hugging Face checkpoint to Universal Checkpointing format, you can use the `hf_to_universal.py` script provided in the DeepSpeed repository. This script will take a Hugging Face checkpoint directory and convert it to a Universal Checkpointing format. - -```bash -python deepspeed/checkpoint/hf_to_universal.py --hf_checkpoint_dir /path/to/huggingface/checkpoint --save_dir /path/to/universal/checkpoint -``` - -This script will process the Hugging Face checkpoint and generate a new checkpoint in the Universal Checkpointing format. Note that `hf_to_universal.py` script supports both `.safetensors` and `pytorch.bin` checkpoint format. Use `--safe_serialization` flag to convert from `.safetensors` format. - -### Step 3: Resume Training with Universal Checkpoint -With the Universal checkpoint ready, you can now resume training on potentially with different parallelism topologies or training configurations. To do this add `--universal-checkpoint` to your DeepSpeed config (json) file - - -## Conclusion -DeepSpeed Universal Checkpointing simplifies the management of model states, making it easier to save, load, and transfer model states across different training sessions and parallelism techniques. By converting a Hugging Face checkpoint to Universal Checkpointing format, you can load pretrained weights of any model in the Hugging Face Hub and resume training with DeepSpeed under any parallelism topologies. - -For more detailed examples and advanced configurations, please refer to the [Megatron-DeepSpeed examples](https://github.com/deepspeedai/Megatron-DeepSpeed/tree/main/examples_deepspeed/universal_checkpointing).