From 102b5804c623b0e9b0d852a082d77d9c353523b1 Mon Sep 17 00:00:00 2001 From: Ryan Smith Date: Mon, 11 Dec 2023 00:33:07 +0000 Subject: [PATCH 1/3] trainer flag should_stop resets when fit is called --- src/lightning/pytorch/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index ae0fc7756fbf2..93bbd265936d5 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -540,6 +540,7 @@ def fit( self.state.fn = TrainerFn.FITTING self.state.status = TrainerStatus.RUNNING self.training = True + self.should_stop = False call._call_and_handle_interrupt( self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) From 82edacb53d5ae70eea84ef5014a6b237fbd2dba3 Mon Sep 17 00:00:00 2001 From: Ryan Smith Date: Mon, 18 Dec 2023 16:57:12 +0000 Subject: [PATCH 2/3] introduce early stopping condition to trigger should_stop --- .../loops/test_training_epoch_loop.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index 06f27ab322530..cb48c56b9e195 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -17,7 +17,7 @@ import pytest from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.trainer import Trainer - +from lightning.pytorch.callbacks import EarlyStopping def test_no_val_on_train_epoch_loop_restart(tmpdir): """Test that training validation loop doesn't get triggered at the beginning of a restart.""" @@ -86,7 +86,17 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, (min_epochs/steps is satisfied). """ - model = BoringModel() + class NewBoring(BoringModel): + def training_step(self, batch, batch_idx): + self.log('loss', self.step(batch)) + return {"loss": self.step(batch)} + + model = NewBoring() + # create a stopping condition with a high threshold so it triggers immediately + # check the condition before validation so the count is unaffected + stopping = EarlyStopping(monitor='loss', + check_on_train_epoch_end=True, + stopping_threshold=100) trainer = Trainer( default_root_dir=tmp_path, num_sanity_val_steps=0, @@ -97,8 +107,8 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, min_steps=min_steps, enable_model_summary=False, enable_checkpointing=False, + callbacks=[stopping] ) - trainer.should_stop = True # Request to stop before min_epochs/min_steps are reached trainer.fit_loop.epoch_loop.val_loop.run = Mock() trainer.fit(model) assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count From c85b8fde117a9cf4e22a9abb8cd4e776be201ca3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Dec 2023 16:58:36 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tests_pytorch/loops/test_training_epoch_loop.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index cb48c56b9e195..f40cf75076d36 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -15,9 +15,10 @@ from unittest.mock import Mock, patch import pytest +from lightning.pytorch.callbacks import EarlyStopping from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.trainer import Trainer -from lightning.pytorch.callbacks import EarlyStopping + def test_no_val_on_train_epoch_loop_restart(tmpdir): """Test that training validation loop doesn't get triggered at the beginning of a restart.""" @@ -86,17 +87,16 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, (min_epochs/steps is satisfied). """ + class NewBoring(BoringModel): def training_step(self, batch, batch_idx): - self.log('loss', self.step(batch)) + self.log("loss", self.step(batch)) return {"loss": self.step(batch)} model = NewBoring() # create a stopping condition with a high threshold so it triggers immediately # check the condition before validation so the count is unaffected - stopping = EarlyStopping(monitor='loss', - check_on_train_epoch_end=True, - stopping_threshold=100) + stopping = EarlyStopping(monitor="loss", check_on_train_epoch_end=True, stopping_threshold=100) trainer = Trainer( default_root_dir=tmp_path, num_sanity_val_steps=0, @@ -107,7 +107,7 @@ def training_step(self, batch, batch_idx): min_steps=min_steps, enable_model_summary=False, enable_checkpointing=False, - callbacks=[stopping] + callbacks=[stopping], ) trainer.fit_loop.epoch_loop.val_loop.run = Mock() trainer.fit(model)