Skip to content
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The `ModelCheckpoint.save_on_train_epoch_end` attribute is now computed dynamically every epoch, accounting for changes to the validation dataloaders ([#15300](https://github.com/Lightning-AI/lightning/pull/15300))

- The Trainer now raises an error if it is given multiple stateful callbacks of the same type with colliding state keys ([#15634](https://github.com/Lightning-AI/lightning/pull/15634))


### Fixed

- Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))
Expand Down
17 changes: 17 additions & 0 deletions src/pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_info

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -82,6 +83,7 @@ def on_trainer_init(
self._configure_fault_tolerance_callbacks()

self.trainer.callbacks.extend(_configure_external_callbacks())
_validate_callbacks_list(self.trainer.callbacks)

# push all model checkpoint callbacks to the end
# it is important that these are the last callbacks to run
Expand Down Expand Up @@ -290,3 +292,18 @@ def _configure_external_callbacks() -> List[Callback]:
)
external_callbacks.extend(callbacks_list)
return external_callbacks


def _validate_callbacks_list(callbacks: List[Callback]) -> None:
    """Ensure every stateful callback in the list carries a unique ``state_key``.

    Only callbacks that override ``state_dict`` are considered stateful; two such callbacks sharing a
    ``state_key`` would silently overwrite each other's entry in a checkpoint, so we fail loudly instead.

    Raises:
        RuntimeError: If two stateful callbacks have colliding ``state_key`` values.
    """
    seen_keys = set()
    for cb in callbacks:
        # Callbacks that don't override `state_dict` have nothing to save, so collisions are harmless.
        if not is_overridden("state_dict", instance=cb):
            continue
        key = cb.state_key
        if key in seen_keys:
            raise RuntimeError(
                f"Found more than one stateful callback of type `{type(cb).__name__}`. In the current"
                " configuration, this callback does not support being saved alongside other instances of the same type."
                f" Please consult the documentation of `{type(cb).__name__}` regarding valid settings for"
                " the callback state to be checkpointable."
                " HINT: The `callback.state_key` must be unique among all callbacks in the Trainer."
            )
        seen_keys.add(key)
4 changes: 2 additions & 2 deletions tests/tests_pytorch/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,8 +981,8 @@ def assert_checkpoint_log_dir(idx):
def test_configure_model_checkpoint(tmpdir):
"""Test all valid and invalid ways a checkpoint callback can be passed to the Trainer."""
kwargs = dict(default_root_dir=tmpdir)
callback1 = ModelCheckpoint()
callback2 = ModelCheckpoint()
callback1 = ModelCheckpoint(monitor="foo")
callback2 = ModelCheckpoint(monitor="bar")

# no callbacks
trainer = Trainer(enable_checkpointing=False, callbacks=[], **kwargs)
Expand Down
50 changes: 34 additions & 16 deletions tests/tests_pytorch/trainer/connectors/test_callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from unittest import mock
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning import Callback, LightningModule, Trainer
Expand All @@ -36,8 +37,8 @@

def test_checkpoint_callbacks_are_last(tmpdir):
"""Test that checkpoint callbacks always get moved to the end of the list, with preserved order."""
checkpoint1 = ModelCheckpoint(tmpdir)
checkpoint2 = ModelCheckpoint(tmpdir)
checkpoint1 = ModelCheckpoint(tmpdir, monitor="foo")
checkpoint2 = ModelCheckpoint(tmpdir, monitor="bar")
model_summary = ModelSummary()
early_stopping = EarlyStopping(monitor="foo")
lr_monitor = LearningRateMonitor()
Expand Down Expand Up @@ -179,7 +180,8 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
cb_connector._attach_model_callbacks()
return trainer

early_stopping = EarlyStopping(monitor="foo")
early_stopping1 = EarlyStopping(monitor="red")
early_stopping2 = EarlyStopping(monitor="blue")
progress_bar = TQDMProgressBar()
lr_monitor = LearningRateMonitor()
grad_accumulation = GradientAccumulationScheduler({1: 1})
Expand All @@ -189,40 +191,40 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
assert trainer.callbacks == [trainer.accumulation_scheduler]

# callbacks of different types
trainer = _attach_callbacks(trainer_callbacks=[early_stopping], model_callbacks=[progress_bar])
assert trainer.callbacks == [early_stopping, trainer.accumulation_scheduler, progress_bar]
trainer = _attach_callbacks(trainer_callbacks=[early_stopping1], model_callbacks=[progress_bar])
assert trainer.callbacks == [early_stopping1, trainer.accumulation_scheduler, progress_bar]

# same callback type twice, different instance
trainer = _attach_callbacks(
trainer_callbacks=[progress_bar, EarlyStopping(monitor="foo")],
model_callbacks=[early_stopping],
trainer_callbacks=[progress_bar, EarlyStopping(monitor="red")],
model_callbacks=[early_stopping1],
)
assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]
assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping1]

# multiple callbacks of the same type in trainer
trainer = _attach_callbacks(
trainer_callbacks=[
LearningRateMonitor(),
EarlyStopping(monitor="foo"),
EarlyStopping(monitor="yellow"),
LearningRateMonitor(),
EarlyStopping(monitor="foo"),
EarlyStopping(monitor="black"),
],
model_callbacks=[early_stopping, lr_monitor],
model_callbacks=[early_stopping1, lr_monitor],
)
assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]
assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping1, lr_monitor]

# multiple callbacks of the same type, in both trainer and model
trainer = _attach_callbacks(
trainer_callbacks=[
LearningRateMonitor(),
progress_bar,
EarlyStopping(monitor="foo"),
EarlyStopping(monitor="yellow"),
LearningRateMonitor(),
EarlyStopping(monitor="foo"),
EarlyStopping(monitor="black"),
],
model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
model_callbacks=[early_stopping1, lr_monitor, grad_accumulation, early_stopping2],
)
assert trainer.callbacks == [progress_bar, early_stopping, lr_monitor, grad_accumulation, early_stopping]
assert trainer.callbacks == [progress_bar, early_stopping1, lr_monitor, grad_accumulation, early_stopping2]


def test_attach_model_callbacks_override_info(caplog):
Expand Down Expand Up @@ -296,3 +298,19 @@ def _make_entry_point_query_mock(callback_factory):
import_path = "pkg_resources.iter_entry_points"
with mock.patch(import_path, query_mock):
yield


def test_validate_unique_callback_state_key():
    """Two stateful callbacks with colliding ``state_key``s must raise at Trainer construction time."""

    class MockCallback(Callback):
        @property
        def state_key(self):
            # Both instances report the same key, forcing a collision.
            return "same_key"

        def state_dict(self):
            # Overriding `state_dict` is what marks this callback as stateful.
            return {"state": 1}

    colliding_callbacks = [MockCallback(), MockCallback()]
    with pytest.raises(RuntimeError, match="Found more than one stateful callback of type `MockCallback`"):
        Trainer(callbacks=colliding_callbacks)
4 changes: 2 additions & 2 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,9 +924,9 @@ def test_best_ckpt_evaluate_raises_warning_with_multiple_ckpt_callbacks():
"""Test that a warning is raised if best ckpt callback is used for evaluation configured with multiple
checkpoints."""

ckpt_callback1 = ModelCheckpoint()
ckpt_callback1 = ModelCheckpoint(monitor="foo")
ckpt_callback1.best_model_path = "foo_best_model.ckpt"
ckpt_callback2 = ModelCheckpoint()
ckpt_callback2 = ModelCheckpoint(monitor="bar")
ckpt_callback2.best_model_path = "bar_best_model.ckpt"
trainer = Trainer(callbacks=[ckpt_callback1, ckpt_callback2])
trainer.state.fn = TrainerFn.TESTING
Expand Down