@@ -21,7 +21,7 @@
 from torch.optim import Optimizer

 import pytorch_lightning as pl
-from lightning_fabric.utilities.types import _Stateful, ReduceLROnPlateau
+from lightning_fabric.utilities.types import _Stateful, Optimizable, ReduceLROnPlateau
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
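The substantive change in this hunk is importing Optimizable alongside the existing names. To the best of my reading, Optimizable in lightning_fabric/utilities/types.py is a runtime-checkable Protocol roughly shaped like the sketch below; this is an approximation for context, not a verbatim copy, and the module itself is authoritative:

from typing import Any, Callable, Dict, List, Optional, Protocol, runtime_checkable


@runtime_checkable
class Steppable(Protocol):
    # structurally types `optimizer.step()`
    def step(self, closure: Optional[Callable[[], Any]] = None) -> Any:
        ...


@runtime_checkable
class Optimizable(Steppable, Protocol):
    # structurally types an optimizer: state attributes plus checkpoint hooks
    param_groups: List[Dict[Any, Any]]
    defaults: Dict[Any, Any]

    def state_dict(self) -> Dict[str, Dict[Any, Any]]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Dict[Any, Any]]) -> None:
        ...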
@@ -194,14 +194,14 @@ def _configure_optimizers(
     monitor = None

     # single output, single optimizer
-    if isinstance(optim_conf, Optimizer):
+    if isinstance(optim_conf, Optimizable):
         optimizers = [optim_conf]
     # two lists, optimizer + lr schedulers
     elif (
         isinstance(optim_conf, (list, tuple))
         and len(optim_conf) == 2
         and isinstance(optim_conf[0], list)
-        and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
+        and all(isinstance(opt, Optimizable) for opt in optim_conf[0])
     ):
         opt, sch = optim_conf
         optimizers = opt
@@ -235,7 +235,7 @@ def _configure_optimizers(
         if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
             raise ValueError("A frequency must be given to each optimizer.")
     # single list or tuple, multiple optimizer
-    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
+    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf):
         optimizers = list(optim_conf)
     # unknown configuration
     else:
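
For context, here is a minimal sketch of what the widened checks accept, assuming Optimizable is runtime-checkable as noted above. DuckTypedOptimizer is a hypothetical name used only for illustration:

from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional

from torch.optim import Optimizer

from lightning_fabric.utilities.types import Optimizable


class DuckTypedOptimizer:
    """Optimizer-shaped object that never subclasses torch.optim.Optimizer,
    e.g. the kind of wrapper some third-party integrations return."""

    def __init__(self) -> None:
        self.param_groups: List[Dict[Any, Any]] = []
        self.defaults: Dict[Any, Any] = {}
        self.state: Dict[Any, Any] = defaultdict(dict)

    def step(self, closure: Optional[Callable[[], Any]] = None) -> Any:
        # a real wrapper would delegate to an inner optimizer here
        return closure() if closure is not None else None

    def state_dict(self) -> Dict[str, Any]:
        return {"state": dict(self.state), "param_groups": self.param_groups}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.state.update(state_dict.get("state", {}))


opt = DuckTypedOptimizer()
print(isinstance(opt, Optimizer))     # False: not a subclass
print(isinstance(opt, Optimizable))   # True: structural match

Because the match is structural rather than nominal, the branches in this diff also classify wrapped or third-party optimizers correctly, as long as they expose the optimizer surface the protocol asks for.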