Commit b52365a

Author: Frederik Diehl (committed)

Added atomic checkpoint creation

1 parent: de2ccc0

File tree: 1 file changed (+9 −4 lines)

pytorch_lightning/trainer/training_io.py

Lines changed: 9 additions & 4 deletions

@@ -255,17 +255,22 @@ def term_handler(self, signum, frame):
     # --------------------
     # MODEL SAVE CHECKPOINT
     # --------------------
+    def _atomic_save(self, checkpoint, filepath):
+        tmp_path = str(filepath) + ".part"
+        torch.save(checkpoint, tmp_path)
+        os.replace(tmp_path, filepath)
+
     def save_checkpoint(self, filepath):
         checkpoint = self.dump_checkpoint()
 
         # do the actual save
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
         except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
     def restore(self, checkpoint_path, on_gpu):
         # if on_gpu:

@@ -412,12 +417,12 @@ def hpc_save(self, folderpath, logger):
         # do the actual save
         # TODO: fix for anything with multiprocess DP, DDP, DDP2
         try:
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
        except AttributeError:
             if 'hparams' in checkpoint:
                 del checkpoint['hparams']
 
-            torch.save(checkpoint, filepath)
+            self._atomic_save(checkpoint, filepath)
 
         return filepath
 

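For reference, below is a minimal standalone sketch of the write-then-replace pattern this commit introduces. The free function name atomic_save and the usage line are illustrative only, not part of the PyTorch Lightning API; the point is that serializing to a temporary sibling file and then calling os.replace() means an interrupted save can never leave a truncated checkpoint at the final path.

# Minimal sketch (not Lightning API): atomic checkpoint writing via a
# temporary ".part" file followed by os.replace().
import os
import torch


def atomic_save(obj, filepath):
    """Save `obj` so that readers never observe a half-written file."""
    # Write to a sibling path first; a crash here only leaves a stray .part file.
    tmp_path = str(filepath) + ".part"
    torch.save(obj, tmp_path)
    # os.replace() moves the finished file into place, overwriting any existing
    # checkpoint; on POSIX this rename is atomic because both paths share a directory.
    os.replace(tmp_path, filepath)


# Hypothetical usage:
# atomic_save(torch.nn.Linear(4, 2).state_dict(), "model.ckpt")

Because the temporary file lives next to the destination, the final rename stays on the same filesystem, which is what makes the swap a single-step operation rather than a copy.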