
Commit ba04bb3

Disable training with zero num_training_batches when insufficient limit_train_batches (#5703)
* disable training when zero num_train_batches with limit_train_batches
* refactor train skip condition
* fix formatting issues
* ref: test error msg
* fix tests for data loader calls
* fix train dataloader condition
* update limit_train_batches upper range in test comment
* remove model state check test

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 793fe73 commit ba04bb3
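
For context: with a float `limit_train_batches`, Lightning computes the effective number of training batches by truncation (see the docstring of the new test below), so a small fraction applied to a small dataloader can round down to zero. A minimal sketch of the arithmetic, using the values from the new test; the variable names here are illustrative, not Lightning internals:

    # 100 samples with batch_size 25 -> 4 batches in the train dataloader
    num_batches = 100 // 25
    limit_train_batches = 0.2  # float fraction of batches to keep

    # the effective count is truncated to an int
    num_training_batches = int(num_batches * limit_train_batches)
    assert num_training_batches == 0  # int(4 * 0.2) == int(0.8) == 0

    # with this commit, should_skip_training() treats num_training_batches == 0
    # as "nothing to train" and skips the loop cleanly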

2 files changed: +65, -15 lines

pytorch_lightning/trainer/training_loop.py

Lines changed: 3 additions & 9 deletions
@@ -91,13 +91,7 @@ def num_optimizers(self):
         return num_optimizers
 
     def should_skip_training(self):
-        if self.trainer.current_epoch >= self.trainer.max_epochs:
-            return True
-
-        if self.trainer.limit_train_batches == 0:
-            return True
-
-        return False
+        return self.trainer.current_epoch >= self.trainer.max_epochs or self.trainer.num_training_batches == 0
 
     def on_train_start(self):
         # clear cache before training

@@ -203,7 +197,7 @@ def on_train_epoch_start(self, epoch):
         model = self.trainer.get_model()
 
         # reset train dataloader
-        if self.trainer.reload_dataloaders_every_epoch:
+        if epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)
 
         # set seed for distributed sampler (enables shuffling for each epoch)

@@ -238,7 +232,7 @@ def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx,
         self.trainer.logger_connector.on_train_batch_end()
 
     def reset_train_val_dataloaders(self, model):
-        if not self.trainer.reload_dataloaders_every_epoch:
+        if self.trainer.train_dataloader is None or not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_train_dataloader(model)
 
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
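
Taken together, the two dataloader guards avoid a redundant reload: `reset_train_val_dataloaders` now resets the train dataloader whenever none exists yet, and `on_train_epoch_start` skips the per-epoch reload at epoch 0 because that initial reset just ran. A standalone sketch of the resulting control flow, with a stubbed trainer (`StubTrainer` is hypothetical, for illustration only):

    class StubTrainer:
        # stub mirroring only the two reload guards changed in this commit
        def __init__(self, reload_dataloaders_every_epoch):
            self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
            self.train_dataloader = None

        def reset_train_dataloader(self):
            print("resetting train dataloader")
            self.train_dataloader = object()  # stand-in for a real DataLoader

        def reset_train_val_dataloaders(self):
            # new `is None` guard: always reset on first use, even when
            # reload_dataloaders_every_epoch=True, since nothing exists yet
            if self.train_dataloader is None or not self.reload_dataloaders_every_epoch:
                self.reset_train_dataloader()

        def on_train_epoch_start(self, epoch):
            # new `epoch != 0` guard: skip the redundant reload at epoch 0,
            # because reset_train_val_dataloaders() ran just before training
            if epoch != 0 and self.reload_dataloaders_every_epoch:
                self.reset_train_dataloader()

    trainer = StubTrainer(reload_dataloaders_every_epoch=True)
    trainer.reset_train_val_dataloaders()    # one reset before training starts
    for epoch in range(3):
        trainer.on_train_epoch_start(epoch)  # resets only for epochs 1 and 2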

tests/trainer/test_trainer.py

Lines changed: 62 additions & 6 deletions
@@ -11,20 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from argparse import Namespace
-from copy import deepcopy
 import math
 import os
-from pathlib import Path
 import pickle
 import sys
+from argparse import Namespace
+from copy import deepcopy
+from pathlib import Path
 from unittest.mock import ANY, call, patch
 
 import cloudpickle
-from omegaconf import OmegaConf
 import pytest
 import torch
+from omegaconf import OmegaConf
+from torch.utils.data import DataLoader
 
+import tests.base.develop_utils as tutils
 from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
@@ -35,8 +37,7 @@
 from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.base import BoringModel, EvalModelTemplate
-import tests.base.develop_utils as tutils
+from tests.base import BoringModel, EvalModelTemplate, RandomDataset
 
 
 @pytest.mark.parametrize("url_ckpt", [True, False])
@@ -1444,3 +1445,58 @@ def test_trainer_profiler_incorrect_arg_type(profiler):
         match=r"Only None, bool, str and subclasses of `BaseProfiler`"
         r" are valid values for `Trainer`'s `profiler` parameter. *"):
         Trainer(profiler=profiler)
+
+
+@pytest.mark.parametrize(
+    ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"],
+    [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)],
+)
+def test_disabled_training_for_insufficient_limit_train_batches(tmpdir, limit_train_batches, global_step,
+                                                                num_training_batches, current_epoch, should_train):
+    """
+    Verify when `limit_train_batches` is float & between [0.0, 1.0] and
+    `int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled.
+    """
+    class CurrentModel(BoringModel):
+
+        training_step_invoked = False
+        training_epoch_end_invoked = False
+
+        def training_step(self, *args, **kwargs):
+            self.training_step_invoked = True
+            return super().training_step(*args, **kwargs)
+
+        def training_epoch_end(self, *args, **kwargs):
+            self.training_epoch_end_invoked = True
+            return super().training_epoch_end(*args, **kwargs)
+
+    dataset_len = 100
+    batch_size = 25
+
+    train = RandomDataset(32, length=dataset_len)
+    train_loader = DataLoader(train, batch_size=batch_size)
+
+    model = CurrentModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        limit_train_batches=limit_train_batches,
+    )
+    result = trainer.fit(model, train_loader)
+
+    params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}`
+                        & `batch_size={batch_size}` as
+                        `num_training_batches={num_training_batches}`"""
+    if should_train:
+        error_string = f"should run with {params_string}"
+    else:
+        error_string = f"should not run with {params_string}"
+
+    assert result == 1, "training failed to complete"
+    assert trainer.state == TrainerState.FINISHED
+    assert trainer.global_step == global_step
+    assert trainer.num_training_batches == num_training_batches
+    assert trainer.current_epoch == current_epoch
+    assert model.training_step_invoked == should_train, f"`training_step` {error_string}"
+    assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}"
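
From the user's perspective, the first parametrized case above can be reproduced in a few lines. A hedged sketch using the same test helpers (assumes a source checkout of pytorch-lightning at this commit, so `tests.base` is importable):

    from torch.utils.data import DataLoader
    from pytorch_lightning import Trainer
    from tests.base import BoringModel, RandomDataset

    # 100 samples / batch_size 25 -> 4 batches; int(4 * 0.2) == 0
    train_loader = DataLoader(RandomDataset(32, length=100), batch_size=25)
    trainer = Trainer(max_epochs=5, limit_train_batches=0.2)
    trainer.fit(BoringModel(), train_loader)

    # training is skipped entirely: zero batches run, no optimizer steps
    assert trainer.num_training_batches == 0
    assert trainer.global_step == 0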
