53 changes: 52 additions & 1 deletion src/transformers/integrations/deepspeed.py
@@ -526,7 +526,27 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
return optimizer, lr_scheduler


def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
def convert_zero_checkpoint_to_universal_checkpoint(input_path, output_path, num_workers):
import argparse

from deepspeed.checkpoint.ds_to_universal import main as ds_to_universal_main

param_dict = {
"input_folder": input_path,
"output_folder": output_path,
"num_extract_workers": num_workers,
"num_merge_workers": num_workers // 2,
"keep_temp_folder": False,
"strict": True,
"inject_missing_state": True,
}
args = argparse.Namespace(**param_dict)
ds_to_universal_main(args)


def deepspeed_load_checkpoint(
deepspeed_engine, checkpoint_path, load_module_strict=True, convert_deepspeed_universal_checkpoint=False
):
# it's possible that the user is trying to resume from model_path, which doesn't necessarily
# contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
# a resume from a checkpoint and not just a local pretrained weight. So we check here if the
@@ -537,6 +557,37 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):

if len(deepspeed_checkpoint_dirs) > 0:
logger.info(f"Attempting to resume from {checkpoint_path}")

if convert_deepspeed_universal_checkpoint:
assert len(deepspeed_checkpoint_dirs) == 1
import os

deepspeed_engine._config.load_universal_checkpoint = True
ckpt_list = deepspeed_engine._get_all_ckpt_names(
checkpoint_path, os.path.basename(deepspeed_checkpoint_dirs[0])
)
# We can get loaded_checkpoint_dp_world_size from any model file.
sd = deepspeed_engine.checkpoint_engine.load(ckpt_list[0], map_location="cpu")
loaded_checkpoint_dp_world_size = sd["dp_world_size"]

if loaded_checkpoint_dp_world_size != deepspeed_engine.dp_world_size:
deepspeed_engine._config.load_universal_checkpoint = True
if deepspeed_engine.global_rank == 0:
convert_zero_checkpoint_to_universal_checkpoint(
deepspeed_checkpoint_dirs[0],
os.path.join(checkpoint_path, "universal_" + os.path.basename(deepspeed_checkpoint_dirs[0])),
loaded_checkpoint_dp_world_size,
)
logger.info(
f"Converted deepspeed checkpoint at {checkpoint_path} to universal format for "
f"current world size {deepspeed_engine.dp_world_size}"
)
from deepspeed import comm as dist

dist.barrier()
else:
deepspeed_engine._config.load_universal_checkpoint = False

# this magically updates self.optimizer and self.lr_scheduler
load_path, _ = deepspeed_engine.load_checkpoint(
checkpoint_path,
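For reference, the new helper can also be invoked on its own to convert an existing ZeRO checkpoint ahead of time. A minimal sketch, assuming this branch is installed and DeepSpeed is available; the paths and worker count below are placeholders, not values from the PR:

from transformers.integrations.deepspeed import convert_zero_checkpoint_to_universal_checkpoint

# Convert a ZeRO checkpoint into DeepSpeed's universal layout so it can later be
# loaded with a different data-parallel world size (placeholder paths).
convert_zero_checkpoint_to_universal_checkpoint(
    input_path="output/checkpoint-500/global_step500",
    output_path="output/checkpoint-500/universal_global_step500",
    num_workers=4,  # num_merge_workers becomes num_workers // 2 inside the helper
)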
6 changes: 5 additions & 1 deletion src/transformers/trainer.py
@@ -2501,7 +2501,10 @@ def _inner_training_loop(
if resume_from_checkpoint is not None:
if self.is_deepspeed_enabled:
deepspeed_load_checkpoint(
self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
self.model_wrapped,
resume_from_checkpoint,
load_module_strict=not _is_peft_model(self.model),
convert_deepspeed_universal_checkpoint=args.convert_deepspeed_universal_checkpoint,
)
elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
@@ -3050,6 +3053,7 @@ def _load_best_model(self):
self.model_wrapped,
self.state.best_model_checkpoint,
load_module_strict=not _is_peft_model(self.model),
convert_deepspeed_universal_checkpoint=self.args.convert_deepspeed_universal_checkpoint,
)
elif self.is_fsdp_enabled:
load_result = load_fsdp_model(
10 changes: 10 additions & 0 deletions src/transformers/training_args.py
@@ -1469,6 +1469,16 @@ class TrainingArguments:
},
)

convert_deepspeed_universal_checkpoint: Optional[bool] = field(
default=False,
metadata={
"help": (
"Whether or not to convert deepspeed zero checkpoint to universal checkpoint when "
"loaded world size is changed."
)
},
)

def __post_init__(self):
# Set default output_dir if not provided
if self.output_dir is None:
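Both trainer.py call sites forward the new flag, so the conversion applies when resuming training and when reloading the best checkpoint with load_best_model_at_end=True. A minimal end-to-end sketch, assuming this PR is installed; model, train_dataset, and ds_config.json are placeholders, and the flag can equally be passed on the command line as --convert_deepspeed_universal_checkpoint in HfArgumentParser-based scripts:

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output",
    deepspeed="ds_config.json",  # DeepSpeed ZeRO config (placeholder path)
    convert_deepspeed_universal_checkpoint=True,  # flag added in this PR
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

# If the saved ZeRO checkpoint was written with a different data-parallel world size,
# rank 0 converts it to a universal checkpoint before DeepSpeed loads it.
trainer.train(resume_from_checkpoint=True)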