Skip to content

Commit e403c35

Browse files
committed
Support arbitrary Optimizables as optimizers (#16189)
1 parent 83a2c9f commit e403c35

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added info message for Ampere CUDA GPU users to enable tf32 matmul precision ([#16037](https://github.com/Lightning-AI/lightning/pull/16037))


+- Added support for returning optimizer-like classes in `LightningModule.configure_optimizers` ([#16189](https://github.com/Lightning-AI/lightning/pull/16189))
+
+
 ### Changed

 - Drop PyTorch 1.9 support ([#15347](https://github.com/Lightning-AI/lightning/pull/15347))

src/pytorch_lightning/core/optimizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@

 from torch.optim import Optimizer

 import pytorch_lightning as pl
-from lightning_fabric.utilities.types import _Stateful, ReduceLROnPlateau
+from lightning_fabric.utilities.types import _Stateful, Optimizable, ReduceLROnPlateau
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -194,14 +194,14 @@ def _configure_optimizers(

     monitor = None

     # single output, single optimizer
-    if isinstance(optim_conf, Optimizer):
+    if isinstance(optim_conf, Optimizable):
         optimizers = [optim_conf]
     # two lists, optimizer + lr schedulers
     elif (
         isinstance(optim_conf, (list, tuple))
         and len(optim_conf) == 2
         and isinstance(optim_conf[0], list)
-        and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
+        and all(isinstance(opt, Optimizable) for opt in optim_conf[0])
     ):
         opt, sch = optim_conf
         optimizers = opt
@@ -235,7 +235,7 @@ def _configure_optimizers(

         if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
             raise ValueError("A frequency must be given to each optimizer.")
     # single list or tuple, multiple optimizer
-    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
+    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf):
         optimizers = list(optim_conf)
     # unknown configuration
     else:

0 commit comments

Comments (0)