Skip to content

Commit 3e7c014

Browse files
erhoo82 and carmocca authored
Relax on_train_batch_* hook check with dataloader_iter to a warning (#16062)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent bbc52ec commit 3e7c014

File tree

3 files changed: 15 additions (+15) and 18 deletions (-18).

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7373
- `MLFlowLogger` now logs hyperparameters and metrics in batched API calls ([#15915](https://github.com/Lightning-AI/lightning/pull/15915))
7474

7575

76+
- Overriding the `on_train_batch_{start,end}` hooks in conjunction with taking a `dataloader_iter` in the `training_step` no longer errors out and instead shows a warning ([#16062](https://github.com/Lightning-AI/lightning/pull/16062))
77+
78+
7679
### Deprecated
7780

7881
- Deprecated `description`, `env_prefix` and `env_parse` parameters in `LightningCLI.__init__` in favour of giving them through `parser_kwargs` ([#15651](https://github.com/Lightning-AI/lightning/pull/15651))

src/pytorch_lightning/trainer/configuration_validator.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -169,20 +169,14 @@ def __verify_manual_optimization_support(trainer: "pl.Trainer", model: "pl.Light
169169

170170
def __check_training_step_requires_dataloader_iter(model: "pl.LightningModule") -> None:
171171
"""Check if the current `training_step` is requesting `dataloader_iter`."""
172-
training_step_fx = model.training_step
173-
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
174-
175-
if is_overridden("on_train_batch_start", model):
176-
raise MisconfigurationException(
177-
"The model hook `on_train_batch_start` is not compatible with "
178-
"taking a `dataloader_iter` argument in your `training_step`."
179-
)
180-
181-
if is_overridden("on_train_batch_end", model):
182-
raise MisconfigurationException(
183-
"The model hook `on_train_batch_end` is not compatible with "
184-
"taking a `dataloader_iter` argument in your `training_step`."
185-
)
172+
if is_param_in_hook_signature(model.training_step, "dataloader_iter", explicit=True):
173+
for hook in ("on_train_batch_start", "on_train_batch_end"):
174+
if is_overridden(hook, model):
175+
rank_zero_warn(
176+
f"The `batch_idx` argument in `{type(model).__name__}.{hook}` hook may"
177+
" not match with the actual batch index when using a `dataloader_iter`"
178+
" argument in your `training_step`."
179+
)
186180

187181
if model.truncated_bptt_steps > 0:
188182
raise MisconfigurationException(

tests/tests_pytorch/utilities/test_fetching.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -429,9 +429,9 @@ class InvalidModel(AsyncBoringModel):
429429
def on_train_batch_start(self, batch, batch_idx):
430430
pass
431431

432-
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
432+
trainer = Trainer(fast_dev_run=1, default_root_dir=tmpdir)
433433
m = InvalidModel()
434-
with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_start` is not compatible with"):
434+
with pytest.warns(match="InvalidModel.on_train_batch_start` hook may not match"):
435435
trainer.fit(m)
436436

437437

@@ -443,9 +443,9 @@ class InvalidModel(AsyncBoringModel):
443443
def on_train_batch_end(self, outputs, batch, batch_idx):
444444
pass
445445

446-
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
446+
trainer = Trainer(fast_dev_run=1, default_root_dir=tmpdir)
447447
m = InvalidModel()
448-
with pytest.raises(MisconfigurationException, match="The model hook `on_train_batch_end` is not compatible with"):
448+
with pytest.warns(match="InvalidModel.on_train_batch_end` hook may not match"):
449449
trainer.fit(m)
450450

451451

0 commit comments

Comments
 (0)