diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index 47d7a7ffcb5f..dd53f3cb6b81 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -526,7 +526,27 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
     return optimizer, lr_scheduler
 
 
-def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
+def convert_zero_checkpoint_to_universal_checkpoint(input_path, output_path, num_workers):
+    import argparse
+
+    from deepspeed.checkpoint.ds_to_universal import main as ds_to_universal_main
+
+    param_dict = {
+        "input_folder": input_path,
+        "output_folder": output_path,
+        "num_extract_workers": num_workers,
+        "num_merge_workers": num_workers // 2,
+        "keep_temp_folder": False,
+        "strict": True,
+        "inject_missing_state": True,
+    }
+    args = argparse.Namespace(**param_dict)
+    ds_to_universal_main(args)
+
+
+def deepspeed_load_checkpoint(
+    deepspeed_engine, checkpoint_path, load_module_strict=True, convert_deepspeed_universal_checkpoint=False
+):
     # it's possible that the user is trying to resume from model_path, which doesn't necessarily
     # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
     # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
@@ -537,6 +557,37 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_str
 
     if len(deepspeed_checkpoint_dirs) > 0:
         logger.info(f"Attempting to resume from {checkpoint_path}")
+
+        if convert_deepspeed_universal_checkpoint:
+            assert len(deepspeed_checkpoint_dirs) == 1
+            import os
+
+            deepspeed_engine._config.load_universal_checkpoint = True
+            ckpt_list = deepspeed_engine._get_all_ckpt_names(
+                checkpoint_path, os.path.basename(deepspeed_checkpoint_dirs[0])
+            )
+            # We can get loaded_checkpoint_dp_world_size from any model file.
+            sd = deepspeed_engine.checkpoint_engine.load(ckpt_list[0], map_location="cpu")
+            loaded_checkpoint_dp_world_size = sd["dp_world_size"]
+
+            if loaded_checkpoint_dp_world_size != deepspeed_engine.dp_world_size:
+                deepspeed_engine._config.load_universal_checkpoint = True
+                if deepspeed_engine.global_rank == 0:
+                    convert_zero_checkpoint_to_universal_checkpoint(
+                        deepspeed_checkpoint_dirs[0],
+                        os.path.join(checkpoint_path, "universal_" + os.path.basename(deepspeed_checkpoint_dirs[0])),
+                        loaded_checkpoint_dp_world_size,
+                    )
+                    logger.info(
+                        f"Converted deepspeed checkpoint at {checkpoint_path} to universal format for "
+                        f"current world size {deepspeed_engine.dp_world_size}"
+                    )
+                from deepspeed import comm as dist
+
+                dist.barrier()
+            else:
+                deepspeed_engine._config.load_universal_checkpoint = False
+
         # this magically updates self.optimizer and self.lr_scheduler
         load_path, _ = deepspeed_engine.load_checkpoint(
             checkpoint_path,
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 0cd8fcf8cd14..1ac65e18ebee 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2501,7 +2501,10 @@ def _inner_training_loop(
         if resume_from_checkpoint is not None:
             if self.is_deepspeed_enabled:
                 deepspeed_load_checkpoint(
-                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
+                    self.model_wrapped,
+                    resume_from_checkpoint,
+                    load_module_strict=not _is_peft_model(self.model),
+                    convert_deepspeed_universal_checkpoint=args.convert_deepspeed_universal_checkpoint,
                 )
             elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                 self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
@@ -3050,6 +3053,7 @@ def _load_best_model(self):
                     self.model_wrapped,
                     self.state.best_model_checkpoint,
                     load_module_strict=not _is_peft_model(self.model),
+                    convert_deepspeed_universal_checkpoint=self.args.convert_deepspeed_universal_checkpoint,
                 )
             elif self.is_fsdp_enabled:
                 load_result = load_fsdp_model(
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 5e71f2a30a6d..15902127e04b 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1469,6 +1469,16 @@ class TrainingArguments:
         },
     )
 
+    convert_deepspeed_universal_checkpoint: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to convert deepspeed zero checkpoint to universal checkpoint when "
+                "loaded world size is changed."
+            )
+        },
+    )
+
     def __post_init__(self):
         # Set default output_dir if not provided
        if self.output_dir is None:
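
For reference, a minimal usage sketch of the new flag, assuming this patch is applied and the script is launched under DeepSpeed with a different GPU count than the one the checkpoint was saved with. The model name, dataset, `ds_config.json` path, and checkpoint path below are illustrative placeholders; only `convert_deepspeed_universal_checkpoint` and `resume_from_checkpoint` exercise the code added in this diff.

```python
# Sketch: resume a DeepSpeed ZeRO run after the data-parallel world size changed.
# Launch with e.g. `deepspeed --num_gpus=4 resume.py` for a run originally trained on 8 GPUs.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    deepspeed="ds_config.json",                    # placeholder ZeRO config
    convert_deepspeed_universal_checkpoint=True,   # new flag added in this diff
)

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)  # train_dataset: your tokenized dataset

# deepspeed_load_checkpoint() compares the checkpoint's saved dp_world_size with the current
# one; if they differ, rank 0 converts the global_step* folder to a universal checkpoint
# (written next to it as universal_global_step*) and load_universal_checkpoint is enabled
# before DeepSpeed loads the state.
trainer.train(resume_from_checkpoint="out/checkpoint-500")
```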