|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | | -from argparse import Namespace |
15 | | -from copy import deepcopy |
16 | 14 | import math |
17 | 15 | import os |
18 | | -from pathlib import Path |
19 | 16 | import pickle |
20 | 17 | import sys |
| 18 | +from argparse import Namespace |
| 19 | +from copy import deepcopy |
| 20 | +from pathlib import Path |
21 | 21 | from unittest.mock import ANY, call, patch |
22 | 22 |
|
23 | 23 | import cloudpickle |
24 | | -from omegaconf import OmegaConf |
25 | 24 | import pytest |
26 | 25 | import torch |
| 26 | +from omegaconf import OmegaConf |
| 27 | +from torch.utils.data import DataLoader |
27 | 28 |
|
| 29 | +import tests.base.develop_utils as tutils |
28 | 30 | from pytorch_lightning import Callback, LightningModule, Trainer |
29 | 31 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint |
30 | 32 | from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv |
|
35 | 37 | from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE |
36 | 38 | from pytorch_lightning.utilities.cloud_io import load as pl_load |
37 | 39 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
38 | | -from tests.base import BoringModel, EvalModelTemplate |
39 | | -import tests.base.develop_utils as tutils |
| 40 | +from tests.base import BoringModel, EvalModelTemplate, RandomDataset |
40 | 41 |
|
41 | 42 |
|
42 | 43 | @pytest.mark.parametrize("url_ckpt", [True, False]) |
@@ -1444,3 +1445,58 @@ def test_trainer_profiler_incorrect_arg_type(profiler): |
1444 | 1445 | match=r"Only None, bool, str and subclasses of `BaseProfiler`" |
1445 | 1446 | r" are valid values for `Trainer`'s `profiler` parameter. *"): |
1446 | 1447 | Trainer(profiler=profiler) |
| 1448 | + |
| 1449 | + |
| 1450 | +@pytest.mark.parametrize( |
| 1451 | + ["limit_train_batches", "global_step", "num_training_batches", "current_epoch", "should_train"], |
| 1452 | + [(0.2, 0, 0, 0, False), (0.5, 10, 2, 4, True)], |
| 1453 | +) |
| 1454 | +def test_disabled_training_for_insufficient_limit_train_batches(tmpdir, limit_train_batches, global_step, |
| 1455 | + num_training_batches, current_epoch, should_train): |
| 1456 | + """ |
| 1457 | + Verify when `limit_train_batches` is float & between [0.0, 1.0] and |
| 1458 | + `int(self.num_training_batches * self.limit_train_batches) == 0`, the training loop is disabled. |
| 1459 | + """ |
| 1460 | + class CurrentModel(BoringModel): |
| 1461 | + |
| 1462 | + training_step_invoked = False |
| 1463 | + training_epoch_end_invoked = False |
| 1464 | + |
| 1465 | + def training_step(self, *args, **kwargs): |
| 1466 | + self.training_step_invoked = True |
| 1467 | + return super().training_step(*args, **kwargs) |
| 1468 | + |
| 1469 | + def training_epoch_end(self, *args, **kwargs): |
| 1470 | + self.training_epoch_end_invoked = True |
| 1471 | + return super().training_epoch_end(*args, **kwargs) |
| 1472 | + |
| 1473 | + dataset_len = 100 |
| 1474 | + batch_size = 25 |
| 1475 | + |
| 1476 | + train = RandomDataset(32, length=dataset_len) |
| 1477 | + train_loader = DataLoader(train, batch_size=batch_size) |
| 1478 | + |
| 1479 | + model = CurrentModel() |
| 1480 | + |
| 1481 | + trainer = Trainer( |
| 1482 | + default_root_dir=tmpdir, |
| 1483 | + max_epochs=5, |
| 1484 | + limit_train_batches=limit_train_batches, |
| 1485 | + ) |
| 1486 | + result = trainer.fit(model, train_loader) |
| 1487 | + |
| 1488 | + params_string = f"""`limit_train_batches={limit_train_batches}`, `dataset_len={dataset_len}` |
| 1489 | + & `batch_size={batch_size}` as |
| 1490 | + `num_training_batches={num_training_batches}`""" |
| 1491 | + if should_train: |
| 1492 | + error_string = f"should run with {params_string}" |
| 1493 | + else: |
| 1494 | + error_string = f"should not run with {params_string}" |
| 1495 | + |
| 1496 | + assert result == 1, "training failed to complete" |
| 1497 | + assert trainer.state == TrainerState.FINISHED |
| 1498 | + assert trainer.global_step == global_step |
| 1499 | + assert trainer.num_training_batches == num_training_batches |
| 1500 | + assert trainer.current_epoch == current_epoch |
| 1501 | + assert model.training_step_invoked == should_train, f"`training_step` {error_string}" |
| 1502 | + assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}" |
0 commit comments