diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 2694cff70afe..1c8ca74a6f09 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -93,6 +93,7 @@
     nested_numpify,
     nested_xla_mesh_reduce,
     reissue_pt_warnings,
+    remove_dummy_checkpoint,
 )
 from .trainer_utils import (
     PREFIX_CHECKPOINT_DIR,
@@ -2780,12 +2781,8 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
         if self.is_fsdp_enabled:
-            # remove the dummy state_dict saved above
-            if self.args.should_save:
-                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
-                    file = os.path.join(output_dir, filename)
-                    if os.path.isfile(file):
-                        os.remove(file)
+            # remove the dummy state_dict
+            remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
             save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)
 
         elif self.is_deepspeed_enabled:
@@ -2801,6 +2798,9 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                     " zero_to_fp32.py to recover weights"
                 )
+                self._save(output_dir, state_dict={})
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                 self.model_wrapped.save_checkpoint(output_dir)
 
         elif self.args.should_save:
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 88e27e3c4dc7..b8c4080c2d54 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -1089,6 +1089,14 @@ def get_module_class_from_name(module, name):
                 return module_class
 
 
+def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
+    if is_main_process:
+        for filename in filenames:
+            file = os.path.join(output_dir, filename)
+            if os.path.isfile(file):
+                os.remove(file)
+
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
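For reference, a minimal standalone sketch of the new `remove_dummy_checkpoint` helper and the behavior the diff relies on: only the main process deletes the placeholder checkpoint files before the real FSDP/DeepSpeed checkpoint is written. The file names in the usage snippet are illustrative stand-ins for `WEIGHTS_NAME` / `SAFE_WEIGHTS_NAME`, not values taken from this diff.

```python
import os
import tempfile

# Sketch of the helper added to trainer_pt_utils.py: delete any dummy
# checkpoint files on the main process only; other ranks do nothing.
def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
    if is_main_process:
        for filename in filenames:
            file = os.path.join(output_dir, filename)
            if os.path.isfile(file):
                os.remove(file)

# Usage sketch with placeholder file names (assumed stand-ins for the
# WEIGHTS_NAME and SAFE_WEIGHTS_NAME constants used in the diff).
with tempfile.TemporaryDirectory() as output_dir:
    open(os.path.join(output_dir, "pytorch_model.bin"), "w").close()
    remove_dummy_checkpoint(True, output_dir, ["pytorch_model.bin", "model.safetensors"])
    assert not os.path.exists(os.path.join(output_dir, "pytorch_model.bin"))
```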