
Commit 39c6ec9

awaelchli and rohitgr7 authored
Only load global step when fitting (#15532)
Co-authored-by: Rohit Gupta <[email protected]>
1 parent f392180 commit 39c6ec9

3 files changed: 19 additions, 16 deletions

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - From now on, Lightning Trainer and `LightningModule.load_from_checkpoint` automatically upgrade the loaded checkpoint if it was produced in an old version of Lightning ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
 
--
+- `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` and `trainer.current_epoch` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532))
 
 -
 
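
For context, here is a minimal, self-contained sketch (not part of the commit) of the user-visible behaviour described by this changelog entry. The `ToyModel` module, the dataloaders, and the trainer arguments are illustrative placeholders; only the `ckpt_path` semantics come from the PR.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def make_loader():
    return DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=4)


# Train briefly so the saved checkpoint stores a non-zero global step.
trainer = pl.Trainer(max_epochs=1, enable_progress_bar=False, logger=False)
trainer.fit(ToyModel(), make_loader(), make_loader())
ckpt_path = trainer.checkpoint_callback.best_model_path

# Evaluation entry points no longer restore the loop counters from the checkpoint ...
eval_trainer = pl.Trainer(enable_progress_bar=False, logger=False)
eval_trainer.validate(ToyModel(), make_loader(), ckpt_path=ckpt_path)
assert eval_trainer.global_step == 0  # before #15532 this reflected the checkpoint value

# ... while `Trainer.fit(ckpt_path=...)` still restores them and resumes training.
resume_trainer = pl.Trainer(max_epochs=2, enable_progress_bar=False, logger=False)
resume_trainer.fit(ToyModel(), make_loader(), make_loader(), ckpt_path=ckpt_path)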

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 14 additions & 13 deletions
@@ -341,19 +341,20 @@ def restore_loops(self) -> None:
         pl_module = self.trainer.lightning_module
         assert pl_module is not None
 
-        # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        batch_loop = fit_loop.epoch_loop.batch_loop
-        if pl_module.automatic_optimization:
-            batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
-                "global_step"
-            ]
-        else:
-            batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
-
-        # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
-        # it will be overwritten by the loop's state if it was also saved
-        fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
+        if self.trainer.state.fn == TrainerFn.FITTING:
+            # set the `global_step` value for checkpoints before v1.6 without the progress tracking state.
+            # it will be overwritten by the loop's state if it was also saved
+            batch_loop = fit_loop.epoch_loop.batch_loop
+            if pl_module.automatic_optimization:
+                batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[
+                    "global_step"
+                ]
+            else:
+                batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"]
+
+            # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state.
+            # it will be overwritten by the loop's state if it was also saved
+            fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
 
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
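
To make the control flow above easier to follow, here is a simplified standalone sketch of the guarded restore. It is not the Lightning implementation; the function name `restore_pre_1_6_counters` is hypothetical and the checkpoint is modelled as a plain dict, but the branching mirrors the hunk above: the pre-v1.6 counters are only written back when the trainer state is `TrainerFn.FITTING`.

from pytorch_lightning.trainer.states import TrainerFn


def restore_pre_1_6_counters(trainer, fit_loop, checkpoint: dict) -> None:
    """Restore `global_step` / `epoch` from pre-v1.6 checkpoints, but only while fitting."""
    if trainer.state.fn != TrainerFn.FITTING:
        # validate/test/predict: leave the loop counters at their defaults
        return

    batch_loop = fit_loop.epoch_loop.batch_loop
    if trainer.lightning_module.automatic_optimization:
        step_progress = batch_loop.optimizer_loop.optim_progress.optimizer.step.total
    else:
        step_progress = batch_loop.manual_loop.optim_step_progress.total
    # overwritten later by the loop's own state if the checkpoint also contains it
    step_progress.completed = checkpoint["global_step"]

    fit_loop.epoch_progress.current.completed = checkpoint["epoch"]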

tests/tests_pytorch/models/test_restore.py

Lines changed: 4 additions & 2 deletions
@@ -187,7 +187,7 @@ def _check_model_state_dict(self):
 
        def _test_on_val_test_predict_start(self):
            assert self.trainer.current_epoch == state_dict["epoch"]
-            assert self.trainer.global_step == state_dict["global_step"]
+            assert self.trainer.global_step == 0
            assert self._check_model_state_dict()
 
        def on_train_start(self):
@@ -626,8 +626,10 @@ def __init__(self):
            super().__init__()
            self.on_train_start_called = False
 
-        def on_validation_start(self):
+        def on_train_start(self):
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0
+
+        def on_validation_start(self):
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)
 
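
Below is a hedged, standalone sketch of how the new behaviour could be exercised outside this test module, modelled on the assertions above. It assumes the `BoringModel` demo class shipped with recent Lightning versions and pytest's `tmp_path` fixture; the test name and trainer flags are illustrative, not taken from the commit.

import pytorch_lightning as pl
from pytorch_lightning.demos.boring_classes import BoringModel


def test_eval_entry_points_do_not_restore_global_step(tmp_path):
    # Produce a checkpoint whose stored global step is non-zero.
    trainer = pl.Trainer(default_root_dir=tmp_path, max_epochs=1, limit_train_batches=2,
                         limit_val_batches=2, enable_progress_bar=False, logger=False)
    trainer.fit(BoringModel())
    assert trainer.global_step > 0
    ckpt_path = trainer.checkpoint_callback.best_model_path

    # After #15532 only `Trainer.fit` restores the counter; evaluation keeps it at 0.
    eval_trainer = pl.Trainer(default_root_dir=tmp_path, limit_val_batches=2,
                              enable_progress_bar=False, logger=False)
    eval_trainer.validate(BoringModel(), ckpt_path=ckpt_path)
    assert eval_trainer.global_step == 0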
