Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3004,15 +3004,20 @@ def _get_zero_ckpt_name(self, checkpoints_path, tag):
bf16_mode = self.bfloat16_enabled()
return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode)

def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None):
def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None, pp_placeholder=None):
if mp_placeholder is not None:
mp_rank_str = mp_placeholder
else:
mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()
mp_rank_str = f"{mp_rank:02d}"

if self.zero_optimization_partition_weights():
filename = "zero_pp_rank_{}".format(dist.get_rank(group=self.optimizer.dp_process_group))
if pp_placeholder is not None:
pp_rank = pp_placeholder
else:
pp_rank = dist.get_rank(group=self.optimizer.dp_process_group)

filename = "zero_pp_rank_{}".format(pp_rank)
ckpt_name = os.path.join(
checkpoints_path,
str(tag),
Expand Down Expand Up @@ -3047,15 +3052,15 @@ def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None):

def _get_all_ckpt_names(self, checkpoints_path, tag):
# It is required that (checkpoints_path, tag) are consistent among all ranks.
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, tag, mp_placeholder="*")
ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
tag,
mp_placeholder="*",
pp_placeholder="0" if self.load_universal_checkpoint() else None)
import glob

ckpt_files = glob.glob(ckpt_file_pattern)
ckpt_files.sort()
if self.load_universal_checkpoint():
return [ckpt_files[0]]
else:
return ckpt_files
return ckpt_files

def load_checkpoint(self,
load_dir,
Expand Down