38 changes: 19 additions & 19 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -152,21 +152,21 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
def extract_zero_shards_stage3(optim_files, param_shapes, dp_degree, temp_dir, dp_index):
state_dict = torch.load(optim_files[dp_index], map_location='cpu', weights_only=False)

flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg"],
exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][0]["exp_avg_sq"],
fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][0],
)

offset = 0
for name, shape in param_shapes.items():
unpartitioned_numel = shape.numel()
partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
for state_key in flat_state.keys():
dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
padding_free_numel)
offset += partitioned_numel
for idx, sub_group_shape in enumerate(param_shapes):
flat_state = dict(
exp_avg=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg"],
exp_avg_sq=state_dict[OPTIMIZER_STATE_DICT]['optimizer_state_dict']['state'][idx]["exp_avg_sq"],
fp32=state_dict[OPTIMIZER_STATE_DICT]['fp32_flat_groups'][idx],
)
offset = 0
for name, shape in sub_group_shape.items():
unpartitioned_numel = shape.numel()
partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, dp_degree)
padding_free_numel = min(partitioned_numel, abs(unpartitioned_numel - dp_index * partitioned_numel))
for state_key in flat_state.keys():
dump_param_fragment(temp_dir, 0, dp_index, state_key, flat_state[state_key], name, offset,
padding_free_numel)
offset += partitioned_numel


cnt = 0
@@ -390,10 +390,10 @@ def _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir):
print(f'Warning: Unused patterns={unmatched_patterns} while merging tp slices')


def _merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir):
def _merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir):
zero_output_folder = os.path.join(args.output_folder, "zero")
do_work = partial(merge_zero3_slices, dp_degree, zero_output_folder, temp_dir)
_do_parallel_work(do_work, param_shapes.keys(), args.num_merge_workers)
_do_parallel_work(do_work, param_keys, args.num_merge_workers)


def _zero_partitioned_param_info(unpartitioned_numel, world_size):
@@ -514,7 +514,6 @@ def main(args):
else:
model_files = _get_model_state_files(args.input_folder)
param_shapes = _parse_model_states_stage3(model_files)
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
dp_degree = len(model_files)

temp_dir = os.path.join(args.output_folder, 'tmp')
@@ -523,7 +522,8 @@
_extract_zero_shard_files_stage3(args, optim_files, param_shapes, dp_degree, temp_dir)

print('*** 2. Merging slices .....')
_merge_zero3_slice_files(args, param_shapes, dp_degree, temp_dir)
param_keys = {key for sub_group_shapes in param_shapes for key in sub_group_shapes.keys()}
_merge_zero3_slice_files(args, param_keys, dp_degree, temp_dir)

print('*** 3. Saving common optimizer states')
_save_optimizer_state_stage3(args, optim_files)
12 changes: 7 additions & 5 deletions deepspeed/runtime/engine.py
@@ -1838,6 +1838,7 @@ def _configure_zero_optimizer(self, optimizer):
optimizer = Stage3ZeroOptimizer(
self.module,
optimizer,
self.param_names,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
@@ -1886,6 +1887,7 @@ def _return_mics_optimizer(self, basic_optimizer, timers):
model_dtype, gradient_accumulation_dtype = self.get_data_types()
optimizer = MiCS_Optimizer(self.module,
basic_optimizer,
self.param_names,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
@@ -3010,10 +3012,7 @@ def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
mp_rank_str = f"{mp_rank:02d}"

if self.zero_optimization_partition_weights():
if self.load_universal_checkpoint():
filename = "zero_pp_rank_0"
else:
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
@@ -3053,7 +3052,10 @@ def _get_all_ckpt_names(self, checkpoints_path, tag):

ckpt_files = glob.glob(ckpt_file_pattern)
ckpt_files.sort()
return ckpt_files
if self.load_universal_checkpoint():
return [ckpt_files[0]]
else:
return ckpt_files

def load_checkpoint(self,
load_dir,
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/mics.py
@@ -367,6 +367,7 @@ class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3):
def __init__(self,
module,
init_optimizer,
param_names,
timers,
ds_config,
static_loss_scale=1,
@@ -398,7 +399,7 @@ def __init__(self,
aio_config=None):

log_dist("Init MiCS optimizer", ranks=[0])
super().__init__(module, init_optimizer, timers, ds_config, static_loss_scale, dynamic_loss_scale,
super().__init__(module, init_optimizer, param_names, timers, ds_config, static_loss_scale, dynamic_loss_scale,
dynamic_loss_args, verbose, contiguous_gradients, reduce_bucket_size, prefetch_bucket_size,
max_reuse_distance, max_live_parameters, param_persistence_threshold,
model_persistence_threshold, dp_process_group, reduce_scatter, overlap_comm,
54 changes: 33 additions & 21 deletions deepspeed/runtime/zero/stage3.py
@@ -141,6 +141,7 @@ def __init__(
self,
module,
init_optimizer,
param_names,
Contributor:

def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
param_shapes):
self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)
def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):

In earlier versions we passed param_shapes to supply parameter names for each subgroup. Could we reuse param_shapes here instead of introducing a new param_names argument, which seems redundant? Or is there a strong reason to replace param_shapes with param_names?

Contributor Author:

This is a critical question!

No, we cannot reuse param_shapes, because param_shapes comes from the checkpoint, i.e. from the previous distributed run, so it describes the sub groups (fp16_groups) of that previous run.

If the world size changes, the sub groups (fp16_groups) change as well, and here we need to set the optimizer state for the sub groups of the current run.

In fact, I did use param_shapes at first, but I hit this problem when my new unit test (which sets sub_group_size to 100) failed.
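For illustration only, a minimal hand-written sketch (not DeepSpeed's actual grouping code; build_sub_groups is a hypothetical helper) of why the grouping recorded in the checkpoint may not match the fp16_groups of the run that loads it: how parameters are packed into flat sub groups depends on sub_group_size and on the current run.

# Hypothetical sketch: greedily pack parameters into sub groups of at most sub_group_size elements.
def build_sub_groups(param_numels, sub_group_size):
    groups, current, current_numel = [], [], 0
    for name, numel in param_numels.items():
        # Start a new sub group once the running total would exceed sub_group_size.
        if current and current_numel + numel > sub_group_size:
            groups.append(current)
            current, current_numel = [], 0
        current.append(name)
        current_numel += numel
    if current:
        groups.append(current)
    return groups

params = {"linears.0.weight": 100, "linears.0.bias": 10, "linears.1.weight": 100, "linears.1.bias": 10}
print(build_sub_groups(params, sub_group_size=1000))  # one sub group holding all four parameters
print(build_sub_groups(params, sub_group_size=100))   # four sub groups, one parameter each

Because the universal checkpoint stores one folder per parameter name, indexing the saved fragments by name (via param_names from the current engine) stays valid regardless of how the current run regroups its parameters.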

Contributor:

@zhengchenyu Makes sense to me. Could you remove param_shapes to clean up the codebase a bit? Since we're using param_names, param_shapes is no longer needed.

Contributor Author:

OK, I have removed param_shapes.

timers,
ds_config,
static_loss_scale=1.0,
@@ -200,6 +201,7 @@ def __init__(
raise SystemError("Cannot use fp16 without accelerator.")

self.optimizer = init_optimizer
self.param_names = param_names

# Use torch (un)flatten ops
self.flatten = _flatten_dense_tensors
@@ -2806,8 +2808,7 @@ def load_state_dict(self,
raise NotImplementedError("ZeRO-3 does not yet support elastic checkpointing, please disable for now.")

if checkpoint_folder:
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
param_shapes)
self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
else:
self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)],
load_optimizer_states=load_optimizer_states)
@@ -2828,11 +2829,10 @@
self.persistent_parameters[0].partition(self.persistent_parameters)
# self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather

def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
param_shapes):
self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)
def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder)

def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):
def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir):
""" Load optimizer and model states from the checkpoint directory. """
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
@@ -2842,18 +2842,34 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
optim_sd = torch.load(optim_state_path, weights_only=False)
self._load_global_state_stage3(optim_sd)

key_list = ["fp32", "exp_avg", "exp_avg_sq"]
# Generally the step is the same across all optimizer states, so we can obtain it from any parameter.
state_step = optim_sd[OPTIMIZER_STATE_DICT]['state'][0]['step']
for key in ["fp32", "exp_avg", "exp_avg_sq"]:
for sub_group_id, fp16_group in enumerate(self.fp16_groups):
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
key_tensor = torch.zeros_like(fp32_param)
offset = 0
for param in fp16_group:
if param not in self.param_names:
raise ValueError(f"failed to find optimizer param in named params")
param_name = self.param_names[param]
key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, param_name),
key)
key_tensor.narrow(0, offset, key_layer_state_partition.numel()).copy_(key_layer_state_partition)
offset += key_layer_state_partition.numel()
if key == "fp32":
self.fp32_partitioned_groups_flat[sub_group_id].data.copy_(key_tensor)
self.optimizer.state[fp32_param]['step'] = state_step
else:
self.optimizer.state[fp32_param][key] = key_tensor

for key in key_list:
key_tensor = torch.empty(0)
for layer in param_shapes[0].keys():
key_layer_state_partition = self.load_hp_checkpoint_state(os.path.join(checkpoint_dir, layer), key)
key_tensor = torch.cat((key_tensor, key_layer_state_partition))
if key == "fp32":
self.fp32_partitioned_groups_flat[0].data.copy_(key_tensor)
self.optimizer.param_groups[0]['params'].append(self.fp32_partitioned_groups_flat[0])
else:
optim_sd[OPTIMIZER_STATE_DICT]['state'][0][key] = key_tensor
for param_group in self.optimizer.param_groups:
# Generally, the hyperparameters are the same for every param group, so we can obtain them from any one of them.
for key, value in optim_sd[OPTIMIZER_STATE_DICT]["param_groups"][0].items():
if key == 'params':
param_group['params'] = []
else:
param_group[key] = value

if self.swap_optimizer:
# Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
@@ -2869,10 +2885,6 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
self._release_sub_group(sub_group_id, timer_names)
self._post_step(timer_names)

self.optimizer.load_state_dict(optim_sd[OPTIMIZER_STATE_DICT])
for param_group in self.optimizer.param_groups:
param_group['params'] = []

for sub_group_id in range(len(self.fp32_partitioned_groups_flat)):
fp32_param = self.fp32_partitioned_groups_flat[sub_group_id]
if sum(fp32_param.size()) > 0:
67 changes: 56 additions & 11 deletions tests/unit/checkpoint/test_universal_checkpoint.py
@@ -3,6 +3,9 @@

# DeepSpeed Team

import os
import math

import deepspeed
from types import SimpleNamespace
from torch.utils._pytree import tree_map
@@ -72,13 +75,13 @@ def init_ds_engine(model, ds_config, use_torch_adam):
return model


def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir):
def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, world_size):
if dtype == torch.bfloat16 and not bf16_required_version_check():
return

test_step = 8

model = SimpleModel(hidden_dim)
model = SimpleModel(hidden_dim, nlayers=2)
model = init_ds_engine(model, ds_config, use_torch_adam)
data_loader = random_dataloader(model=model,
total_samples=test_step,
@@ -124,6 +127,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype,
model.optimizer._set_fp32_optimizer_param_groups()
optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())
model.optimizer._clear_fp32_optimizer_param_groups()
update_gathered_stage3_optimizer(optimizer_state, model._get_zero_param_shapes(), world_size)
else:
optimizer_state = gather_opt_state(model.optimizer.optimizer.state_dict())

@@ -135,7 +139,7 @@ def train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype,


@pytest.fixture
def ds_config(zero_stage, dtype):
def ds_config(zero_stage, dtype, sub_group_size):
ds_config = {
"train_batch_size": 8,
"optimizer": {
@@ -149,6 +153,8 @@ def ds_config(zero_stage, dtype):
ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}
elif dtype == torch.bfloat16:
ds_config["bf16"] = {"enabled": True}
if sub_group_size > 0:
ds_config["zero_optimization"]["sub_group_size"] = sub_group_size
return ds_config


@@ -157,7 +163,7 @@ class _baseline(DistributedFixture):

def run(self, tmpdir, ds_config, zero_stage, dtype, load_optim, use_torch_adam):
hidden_dim = 10
train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir)
train_save_convert(ds_config, hidden_dim, load_optim, use_torch_adam, dtype, tmpdir, self.world_size)


class baseline_ws2(_baseline):
@@ -168,13 +174,46 @@ class baseline_ws4(_baseline):
world_size = 4


# Stage 3 shards parameters, so the gathered optimizer state needs to be reorganized per parameter.
def update_gathered_stage3_optimizer(optimizer_state, param_shapes, world_size):
for sub_group_id, group in enumerate(optimizer_state["param_groups"]):
group["params"] = None

new_state = {}
for sub_group_id, sub_group_param_shape in enumerate(param_shapes):
total_numel = optimizer_state['state'][sub_group_id]['exp_avg'].numel()
assert total_numel % world_size == 0
numel_per_rank = total_numel // world_size
param_offset_in_current_rank = 0
for param_name, param_shape in sub_group_param_shape.items():
param_numel = param_shape.numel()
param_partition_numel = math.ceil(param_numel / world_size)
param_optimizer_tensor = {
"exp_avg": torch.zeros(param_numel),
"exp_avg_sq": torch.zeros(param_numel),
"step": optimizer_state['state'][sub_group_id]['step'],
}
for key in ["exp_avg", "exp_avg_sq"]:
write_offset = 0
for rank in range(world_size):
offset = param_offset_in_current_rank + rank * numel_per_rank
length = min(param_partition_numel, param_numel - rank * param_partition_numel)
tmp = optimizer_state['state'][sub_group_id][key].narrow(0, offset, length)
param_optimizer_tensor[key].narrow(0, write_offset, length).copy_(tmp)
write_offset += length
param_offset_in_current_rank += param_partition_numel
new_state[param_name] = param_optimizer_tensor
optimizer_state["state"] = new_state


@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("zero_stage", [1, 3])
@pytest.mark.parametrize("use_torch_adam", [False, True])
@pytest.mark.parametrize("load_optim", [False, True])
@pytest.mark.parametrize("sub_group_size", [-1, 100])
class TestZeROUniversalCheckpointDP(DistributedTest):

def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam, world_size):
if dtype == torch.bfloat16 and not bf16_required_version_check():
pytest.skip(
" DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly"
Expand All @@ -184,15 +223,21 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
loaded_model_state, loaded_optimizer_state = torch.load(f"{tmpdir}/baseline_state.pt", weights_only=False)

ds_config["checkpoint"] = {"load_universal": True}
univ_model = SimpleModel(hidden_dim)
univ_model = SimpleModel(hidden_dim, nlayers=2)
univ_model = init_ds_engine(univ_model, ds_config, use_torch_adam)
univ_model.load_checkpoint(tmpdir, tag=f"{CP_TAG}_universal", load_optimizer_states=load_optim)

model_state = univ_model.state_dict()
compare_state_dicts(model_state, loaded_model_state)

if load_optim and ds_config["zero_optimization"]["stage"] != 3:
optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
if load_optim:
if ds_config["zero_optimization"]["stage"] == 3:
univ_model.optimizer._set_fp32_optimizer_param_groups()
optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
univ_model.optimizer._clear_fp32_optimizer_param_groups()
update_gathered_stage3_optimizer(optimizer_state, univ_model._get_zero_param_shapes(), world_size)
else:
optimizer_state = gather_opt_state(univ_model.optimizer.optimizer.state_dict())
# padding sizes may differ when dp sizes are different
param_count = sum(p.numel() for p in univ_model.parameters())
optimizer_state = remove_pad_in_opt_state(optimizer_state, param_count)
Expand All @@ -216,12 +261,12 @@ def _run_test(self, tmpdir, dtype, ds_config, load_optim, use_torch_adam):

@pytest.mark.world_size(2)
def test_dp_world_size_2to2(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)

@pytest.mark.world_size(2)
def test_dp_world_size_4to2(self, baseline_ws4, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 2)

@pytest.mark.world_size(4)
def test_dp_world_size_2to4(self, baseline_ws2, tmpdir, dtype, ds_config, load_optim, use_torch_adam):
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam)
self._run_test(tmpdir, dtype, ds_config, load_optim, use_torch_adam, 4)