225 changes: 225 additions & 0 deletions deepspeed/checkpoint/hf_to_universal.py
@@ -0,0 +1,225 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import os
import shutil
import logging
from concurrent.futures import ProcessPoolExecutor
from deepspeed.accelerator import get_accelerator
from tqdm import tqdm
from typing import List

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hard-coded constants for parameter patterns
VOCAB_PARAMETER_PATTERNS = [
'word_embeddings',
'embed_tokens',
'embedding',
'wte', # GPT style embeddings
'lm_head' # Language model head, often tied with embeddings
]


def get_parameter_type(name: str) -> dict:
"""Determine parameter type and required fields based on name."""
param_info = {
'cat_dim': 0 # Default concatenation dimension
}

# Check for vocabulary tensors (embeddings, etc.)
if any(pattern in name.lower() for pattern in VOCAB_PARAMETER_PATTERNS):
param_info['vocab_tensor'] = True

# TODO: figure out if we need to check for row-parallel parameters
return param_info


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint to Universal Checkpoint format')
parser.add_argument('--hf_checkpoint_dir',
type=str,
required=True,
help='Path to the HuggingFace checkpoint directory')
parser.add_argument('--safe_serialization',
action='store_true',
default=False,
help='Use safetensors for serialization')
parser.add_argument('--num_workers', type=int, default=4, help='Number of workers to use for saving checkpoints')
parser.add_argument('--save_dir', type=str, required=True, help='Directory to save checkpoints')
args = parser.parse_args()

# Create a temporary directory for atomic operations
temp_save_dir = args.save_dir + '.tmp'

def save_parameter(name: str, param: torch.Tensor, save_dir: str):
"""Save a parameter and its optimizer states in universal format."""
# Create parameter directory under zero/
param_dir = os.path.join(save_dir, name)
os.makedirs(param_dir, exist_ok=True)

# Get parameter type and required fields
param_info = get_parameter_type(name)

# Save parameter in fp32 with proper dictionary structure
param_path = os.path.join(param_dir, "fp32.pt")
param_dict = {
'param': param.to(torch.float32), # Main tensor goes in 'param' field
**param_info # Include all determined parameter info
}
torch.save(param_dict, param_path)

# Since HuggingFace checkpoints do not have optimizer states,
# we initialize them with zeros
for state in ("exp_avg", "exp_avg_sq"):
state_path = os.path.join(param_dir, f"{state}.pt")
state_dict = {
'param': torch.zeros_like(param, dtype=torch.float32),
**param_info # Include same parameter info in optimizer states
}
torch.save(state_dict, state_path)

def process_shard(shard_file, checkpoint_dir, save_dir, safe_serialization):
"""Process a single shard file."""
try:
shard_path = os.path.join(checkpoint_dir, shard_file)
logger.info(f"Loading shard from: {shard_path}")

if safe_serialization:
from safetensors.torch import load_file
shard_dict = load_file(shard_path)
else:
shard_dict = torch.load(shard_path, map_location='cpu')

# Create progress bar for parameters within this shard
pbar = tqdm(total=len(shard_dict),
desc=f"Processing {os.path.basename(shard_file)}",
position=1,
leave=False)

for key, param in shard_dict.items():
save_parameter(key, param, save_dir)
del param
pbar.update(1)
pbar.set_postfix({'key': key[:20] + '...' if len(key) > 20 else key})

pbar.close()
del shard_dict
get_accelerator().empty_cache()
logger.info(f"Completed processing shard: {shard_file}")

except Exception as e:
logger.error(f"Error processing shard {shard_file}: {str(e)}")
raise

def get_shard_list(checkpoint_dir):
"""Get list of shards from index file."""
if args.safe_serialization:
index_file = os.path.join(checkpoint_dir, "model.safetensors.index.json")
else:
index_file = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")

if os.path.exists(index_file):
import json
with open(index_file, 'r') as f:
index = json.load(f)
return list(set(index['weight_map'].values()))
else:
# Handle single file case
if args.safe_serialization and os.path.exists(os.path.join(checkpoint_dir, "model.safetensors")):
return ["model.safetensors"]
elif os.path.exists(os.path.join(checkpoint_dir, "pytorch_model.bin")):
return ["pytorch_model.bin"]
else:
raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}")

def process_shard_batch(shard_files: List[str], checkpoint_dir: str, save_dir: str, safe_serialization: bool):
"""Process a batch of shards in parallel."""
with ProcessPoolExecutor(max_workers=args.num_workers) as executor:
futures = []
for shard_file in shard_files:
future = executor.submit(process_shard, shard_file, checkpoint_dir, save_dir, safe_serialization)
futures.append((shard_file, future))

# Create progress bar for this batch
batch_pbar = tqdm(total=len(futures), desc=f"Processing shard batch", position=0, leave=True)

# Wait for all futures to complete
for shard_file, future in futures:
try:
future.result() # This will raise any exceptions that occurred
batch_pbar.update(1)
batch_pbar.set_postfix({'last_completed': os.path.basename(shard_file)})
except Exception as e:
logger.error(f"Failed processing shard {shard_file}: {str(e)}")
raise

batch_pbar.close()

try:
# Create zero subdirectory in temp directory
temp_zero_dir = os.path.join(temp_save_dir, 'zero')
if os.path.exists(temp_zero_dir):
logger.info(f"Removing existing temp directory: {temp_zero_dir}")
shutil.rmtree(temp_zero_dir)

shard_files = get_shard_list(args.hf_checkpoint_dir)
total_shards = len(shard_files)
logger.info(f"Found {total_shards} shards to process")
# Process shards in batches equal to the number of workers
batch_size = args.num_workers
for i in range(0, total_shards, batch_size):
batch_shards = shard_files[i:i + batch_size]
logger.info(
f"Processing batch of {len(batch_shards)} shards ({i+1}-{min(i+batch_size, total_shards)} of {total_shards})"
)
process_shard_batch(
batch_shards,
args.hf_checkpoint_dir,
temp_zero_dir, # Changed from temp_save_dir to temp_zero_dir
args.safe_serialization)

# Clear CUDA cache after each batch to free up memory
get_accelerator().empty_cache()

logger.info("All shard batches processed successfully")

final_save_dir = os.path.join(args.save_dir, 'zero')
if os.path.exists(final_save_dir):
shutil.rmtree(final_save_dir)

# Create the parent directory if it doesn't exist
os.makedirs(os.path.dirname(final_save_dir), exist_ok=True)
# Move the zero directory to its final location
os.rename(temp_zero_dir, final_save_dir)

# Clean up the temporary directory
if os.path.exists(temp_save_dir):
shutil.rmtree(temp_save_dir)

# Write identifier file
with open(os.path.join(args.save_dir, 'source.txt'), 'w') as f:
f.write("Huggingface checkpoint")

logger.info(f"Successfully saved checkpoint to {final_save_dir}")

# Update latest file
checkpoint_root_folder = os.path.dirname(args.save_dir)
step_folder = os.path.basename(args.save_dir)
latest_file = os.path.join(checkpoint_root_folder, 'latest_universal')
with open(latest_file, 'w') as f:
f.write(step_folder)

logger.info(f"Checkpoint conversion completed successfully. Latest file updated at {latest_file}")

except Exception as e:
logger.error(f"Failed to process checkpoint: {str(e)}")
if os.path.exists(temp_save_dir):
shutil.rmtree(temp_save_dir)
raise
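For reference, a minimal sketch (hypothetical save directory and parameter name) of the layout this converter produces and how to inspect one saved parameter:

# Produced by something like (hypothetical paths):
#   python deepspeed/checkpoint/hf_to_universal.py --hf_checkpoint_dir <hf_dir> --save_dir universal_ckpt
# The script writes universal_ckpt/zero/<param_name>/{fp32.pt, exp_avg.pt, exp_avg_sq.pt},
# plus universal_ckpt/source.txt and a 'latest_universal' file in the parent of the save directory.
import torch

saved = torch.load("universal_ckpt/zero/model.embed_tokens.weight/fp32.pt",
                   map_location="cpu", weights_only=False)
print(saved["param"].dtype)       # torch.float32
print(saved.get("cat_dim"))       # 0, the default concatenation dimension
print(saved.get("vocab_tensor"))  # True for embedding-style parameters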
19 changes: 12 additions & 7 deletions deepspeed/runtime/base_optimizer.py
@@ -21,11 +21,14 @@ class ZeROOptimizer(DeepSpeedOptimizer):
def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        assert os.path.isfile(
-            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-        optim_sd = torch.load(optim_state_path, weights_only=False)
-
-        self._load_global_state(optim_sd)
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
+            optim_sd = torch.load(optim_state_path, weights_only=False)
+            self._load_global_state(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
+            optim_sd = {}

tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
if self.mpu is None:
@@ -35,8 +38,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()

-        for i, (param_group,
-                loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
+        for i, param_group in enumerate(self.optimizer.param_groups):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()
steps = []
Expand All @@ -58,6 +60,9 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec

map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

+            if ignore_missing_optim_state:
+                continue
+            loaded_param_group = optim_sd['param_groups'][i]
for key, value in loaded_param_group.items():
if key == 'params':
continue
15 changes: 12 additions & 3 deletions deepspeed/runtime/engine.py
@@ -2977,7 +2977,7 @@ def _get_all_ckpt_names(self, checkpoints_path, tag):

ckpt_files = glob.glob(ckpt_file_pattern)
ckpt_files.sort()
-        return ckpt_files
+        return ckpt_files, ckpt_file_pattern

def load_checkpoint(self,
load_dir,
@@ -3001,7 +3001,7 @@ def load_checkpoint(self,

Returns:
A tuple of ``load_path`` and ``client_state``.
-            *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed.
+            *``load_path``: Path of the loaded checkpoint. ``None`` if loading the checkpoint failed or when loading an HF-based UCP.
*``client_state``: State dictionary used for loading required training states in the client code.

Important: under ZeRO3, one cannot load checkpoint with ``engine.load_checkpoint()`` right
@@ -3040,6 +3040,11 @@ def load_checkpoint(self,
custom_load_fn=custom_load_fn)

load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
+        if self.load_universal_checkpoint():
+            ucp_ckpt_folder = os.path.join(load_dir, tag)
+            # UCP load can ignore '*mp' files or '*model_states.pt' but ucp_ckpt_folder must exist
+            load_zero_checkpoint = os.path.isdir(ucp_ckpt_folder)
+
if load_zero_checkpoint:
if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
@@ -3080,7 +3085,11 @@ def _load_checkpoint(self,

from deepspeed.runtime.state_dict_factory import SDLoaderFactory

-        ckpt_list = self._get_all_ckpt_names(load_dir, tag)
+        ckpt_list, ckpt_file_pattern = self._get_all_ckpt_names(load_dir, tag)
+        if self.load_universal_checkpoint() and len(ckpt_list) == 0:
+            logger.warning(f"Unable to find {ckpt_file_pattern} files in UCP folder {load_dir}")
+            return None, {}
+
sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)

is_pipe_parallel = isinstance(self.module, PipelineModule)
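A small standalone sketch (simplified glob pattern, hypothetical paths) of the new `_get_all_ckpt_names` contract that the UCP code path relies on: the helper now returns the search pattern alongside the matched files, so `_load_checkpoint` can warn and return `(None, {})` when an HF-derived UCP folder contains no `*model_states.pt` files:

import glob
import os

def get_all_ckpt_names(checkpoints_path: str, tag: str):
    # Simplified stand-in for DeepSpeedEngine._get_all_ckpt_names: return both
    # the matched checkpoint files and the pattern used to search for them.
    ckpt_file_pattern = os.path.join(checkpoints_path, str(tag), "*model_states.pt")
    ckpt_files = sorted(glob.glob(ckpt_file_pattern))
    return ckpt_files, ckpt_file_pattern

ckpt_list, ckpt_file_pattern = get_all_ckpt_names("checkpoints", "global_step100")
if len(ckpt_list) == 0:
    # Mirrors the engine's behavior when loading an HF-based UCP: no module
    # state files exist, so the caller sees load_path=None rather than an error.
    print(f"Unable to find {ckpt_file_pattern} files in UCP folder checkpoints")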
30 changes: 19 additions & 11 deletions deepspeed/runtime/zero/stage3.py
@@ -2865,11 +2865,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
""" Load optimizer and model states from the checkpoint directory. """
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        assert os.path.isfile(
-            optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-
-        optim_sd = torch.load(optim_state_path, weights_only=False)
-        self._load_global_state_stage3(optim_sd)
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
+            optim_sd = torch.load(optim_state_path, weights_only=False)
+            self._load_global_state_stage3(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True

key_list = ["fp32", "exp_avg", "exp_avg_sq"]

@@ -2881,14 +2883,13 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
if key == "fp32":
self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor)
self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0])
-            else:
+            elif not ignore_missing_optim_state:
optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor

if self.swap_optimizer:
# Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
self.optimizer_swapper.purge_state()

-        if self.swap_optimizer:
# Touch all parameters to synchronize all buffers
timer_names = set()
self._partition_all_parameters()
@@ -2898,9 +2899,10 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
self._release_sub_group(sub_group_id, timer_names)
self._post_step(timer_names)

-        self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
-        for param_group in self.optimizer.param_groups:
-            param_group['params'] = []
+        if not ignore_missing_optim_state:
+            self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
+            for param_group in self.optimizer.param_groups:
+                param_group['params'] = []

for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
@@ -2924,7 +2926,13 @@ def load_hp_checkpoint_state(self, folder, key):
local_rank = dist.get_local_rank()

# Load tensors from files and reshape them to flat vectors
-        loaded_checkpoint_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False).view(-1)
+        loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False)
+        if isinstance(loaded_state, dict):
+            loaded_checkpoint_state = loaded_state['param'].view(-1)
+        elif isinstance(loaded_state, torch.Tensor):
+            loaded_checkpoint_state = loaded_state.view(-1)
+        else:
+            raise ValueError(f"Unknown type {type(loaded_state)} for loaded state")

# Partition the loaded data according to the local rank
world_size = dist.get_world_size(group=self.dp_process_group)
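A minimal sketch (assumed file names, written to the working directory) of the two on-disk formats `load_hp_checkpoint_state` now accepts for each `<key>.pt` file: the legacy bare tensor and the dict produced by `hf_to_universal.py`:

import torch

torch.save(torch.rand(8), "fp32_legacy.pt")                           # legacy format: bare tensor
torch.save({"param": torch.rand(8), "cat_dim": 0}, "fp32_dict.pt")    # new format from hf_to_universal.py

for path in ("fp32_legacy.pt", "fp32_dict.pt"):
    loaded = torch.load(path, map_location="cpu", weights_only=False)
    # Same branching as the patched loader: unwrap the 'param' field if present.
    flat = loaded["param"].view(-1) if isinstance(loaded, dict) else loaded.view(-1)
    print(path, tuple(flat.shape))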