@@ -255,17 +255,22 @@ def term_handler(self, signum, frame):
255255 # --------------------
256256 # MODEL SAVE CHECKPOINT
257257 # --------------------
258+ def _atomic_save (self , checkpoint , filepath ):
259+ tmp_path = str (filepath ) + ".part"
260+ torch .save (checkpoint , tmp_path )
261+ os .replace (tmp_path , filepath )
262+
def save_checkpoint(self, filepath):
    """Dump the current checkpoint and save it to ``filepath``.

    The checkpoint comes from ``self.dump_checkpoint()`` and is written with
    ``self._atomic_save``. If saving raises ``AttributeError`` and the
    checkpoint has an ``'hparams'`` entry, that entry is dropped and the
    save is attempted once more (presumably because ``hparams`` could not
    be serialized - confirm against ``dump_checkpoint``).

    Args:
        filepath: destination path, passed through to ``_atomic_save``.

    Raises:
        AttributeError: if the save fails and there is no ``'hparams'``
            entry to drop, or if the retry without it fails as well.
    """
    checkpoint = self.dump_checkpoint()

    # do the actual save
    try:
        self._atomic_save(checkpoint, filepath)
    except AttributeError:
        # Without 'hparams' there is nothing to strip, so a retry would just
        # repeat the same failing save; surface the original error instead.
        if 'hparams' not in checkpoint:
            raise
        del checkpoint['hparams']

        self._atomic_save(checkpoint, filepath)
269274
270275 def restore (self , checkpoint_path , on_gpu ):
271276 # if on_gpu:
@@ -412,12 +417,12 @@ def hpc_save(self, folderpath, logger):
412417 # do the actual save
413418 # TODO: fix for anything with multiprocess DP, DDP, DDP2
414419 try :
415- torch . save (checkpoint , filepath )
420+ self . _atomic_save (checkpoint , filepath )
416421 except AttributeError :
417422 if 'hparams' in checkpoint :
418423 del checkpoint ['hparams' ]
419424
420- torch . save (checkpoint , filepath )
425+ self . _atomic_save (checkpoint , filepath )
421426
422427 return filepath
423428
0 commit comments