Commit 0b7f5a8

Authored by ananthsub, tchaton, rohitgr7 and mergify[bot]
Fix toggle optimizer (#5775)
* Update lightning.py
* update changelog
* add a 3 optimizer test
* resolve flake8
* remove extra code
* typo
* resolve typo
* Apply suggestions from code review

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent e8c1755 commit 0b7f5a8

File tree: 3 files changed, +174 -15 lines changed


CHANGELOG.md

Lines changed: 14 additions & 1 deletion
@@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
+## [unreleased] - YYYY-MM-DD
+
+### Added
+
+### Changed
+
+### Deprecated
+
+### Removed
+
+### Fixed
+
+- Fixed `toggle_optimizers` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))
 
 ## [1.1.7] - 2021-02-03
 
@@ -32,7 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed FileNotFoundError for best checkpoint when using DDP with Hydra ([#5629](https://github.com/PyTorchLightning/pytorch-lightning/pull/5629))
 - Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
 - Fixed `Metric`'s `state_dict` not included when child modules ([#5614](https://github.com/PyTorchLightning/pytorch-lightning/pull/5614))
-- Fixed Neptune logger creating multiple experiments when GPUs > 1 ([#3256](https://github.com/PyTorchLightning/pytorch-lightning/pull/3256))
+- Fixed Neptune logger creating multiple experiments when GPUs > 1 ([#3256](https://github.com/PyTorchLightning/pytorch-lightning/pull/3256))
 - Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509))
 - Fixed tensor printing in `trainer.test()` ([#5138](https://github.com/PyTorchLightning/pytorch-lightning/pull/5138))
 - Fixed not using dataloader when `hparams` present ([#4559](https://github.com/PyTorchLightning/pytorch-lightning/pull/4559))

pytorch_lightning/core/lightning.py

Lines changed: 15 additions & 13 deletions
@@ -1176,22 +1176,24 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
             optimizer: Current optimizer used in training_loop
             optimizer_idx: Current optimizer idx in training_loop
         """
+
+        # Iterate over all optimizer parameters to preserve their `requires_grad` information
+        # in case these are pre-defined during `configure_optimizers`
         param_requires_grad_state = {}
-        # make sure current optimizer is latest to be iterated over.
-        optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
-        num_optimizers = len(optimizers) - 1
-        for opt_idx, opt in enumerate(optimizers):
+        for opt in self.optimizers(use_pl_optimizer=False):
             for group in opt.param_groups:
                 for param in group['params']:
-                    if num_optimizers == opt_idx:
-                        # If a param appears in 2 optimizers, revert `requires_grad` to before toggle.
-                        if param in param_requires_grad_state:
-                            param.requires_grad = param_requires_grad_state[param]
-                    else:
-                        # save requires_grad for later restoration
-                        param_requires_grad_state[param] = param.requires_grad
-                        param.requires_grad = False
-
+                    # If a param already appear in param_requires_grad_state, continue
+                    if param in param_requires_grad_state:
+                        continue
+                    param_requires_grad_state[param] = param.requires_grad
+                    param.requires_grad = False
+
+        # Then iterate over the current optimizer's parameters and set its `requires_grad`
+        # properties accordingly
+        for group in optimizer.param_groups:
+            for param in group['params']:
+                param.requires_grad = param_requires_grad_state[param]
         self._param_requires_grad_state = param_requires_grad_state
 
     def untoggle_optimizer(self, optimizer_idx: int):
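For readers skimming the diff, here is a minimal sketch of the two-pass logic the patched `toggle_optimizer` now performs, written against plain `torch.optim` objects rather than a `LightningModule` (the `layer_a`/`layer_b`/`shared` modules and the `toggle` helper are illustrative names, not Lightning API): first freeze every parameter of every optimizer while recording its original `requires_grad` flag, then restore only the current optimizer's parameters, so a flag pre-set in `configure_optimizers` survives even when the parameter is shared between optimizers.

```python
from torch import nn
from torch.optim import SGD

# Two small modules plus one module shared by both optimizers; one weight is
# frozen up front to mimic a flag set inside `configure_optimizers`.
layer_a = nn.Linear(4, 4)
layer_b = nn.Linear(4, 4)
shared = nn.Linear(4, 4)
shared.weight.requires_grad = False

opt_a = SGD(list(layer_a.parameters()) + list(shared.parameters()), lr=0.1)
opt_b = SGD(list(layer_b.parameters()) + list(shared.parameters()), lr=0.1)


def toggle(current, optimizers):
    """Freeze every optimizer parameter once, remembering its original flag,
    then restore the flags of the current optimizer's parameters only."""
    state = {}
    for opt in optimizers:
        for group in opt.param_groups:
            for param in group["params"]:
                if param in state:  # shared param already recorded
                    continue
                state[param] = param.requires_grad
                param.requires_grad = False
    for group in current.param_groups:
        for param in group["params"]:
            param.requires_grad = state[param]
    return state


toggle(opt_a, [opt_a, opt_b])
assert layer_a.weight.requires_grad is True    # owned by the current optimizer
assert layer_b.weight.requires_grad is False   # owned only by the other optimizer
assert shared.weight.requires_grad is False    # pre-frozen flag preserved, not overwritten
```

With the patch, every optimizer parameter is recorded in `param_requires_grad_state` exactly once, rather than only the parameters of the non-current optimizers.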

tests/core/test_lightning_module.py

Lines changed: 145 additions & 1 deletion
@@ -142,7 +142,7 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
     trainer.fit(model)
 
 
-def test_toggle_untoggle(tmpdir):
+def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):
 
     class TestModel(BoringModel):
 
@@ -198,8 +198,152 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
                 assert self.layer_2[1].weight.requires_grad is False
                 assert self.layer_2[3].weight.requires_grad is False
                 assert self.layer_2[5].weight.requires_grad is True
+
+            optimizer.step(closure=closure)
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        limit_train_batches=8,
+        accumulate_grad_batches=1,
+        limit_val_batches=0,
+    )
+
+    results = trainer.fit(model)
+    assert results
+
+
+def test_toggle_untoggle_3_optimizers_shared_parameters(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Sequential(
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+            )
+
+            self.layer_2 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            self.layer_3 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            # set some weights to False to check untoggle works as expected.
+            self.layer_1[2].weight.requires_grad = False
+            self.layer_1[4].weight.requires_grad = False
+
+            self.layer_2[1].weight.requires_grad = False
+            self.layer_2[3].weight.requires_grad = False
+
+            self.layer_3[1].weight.requires_grad = False
+            self.layer_3[5].weight.requires_grad = False
+
+        def optimizer_step(
+            self,
+            current_epoch,
+            batch_nb,
+            optimizer,
+            optimizer_idx,
+            closure,
+            on_tpu=False,
+            using_native_amp=False,
+            using_lbfgs=False
+        ):
+            if optimizer_idx == 0:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+
+                assert self.layer_3[1].weight.requires_grad is False
+                assert self.layer_3[3].weight.requires_grad is False
+                assert self.layer_3[5].weight.requires_grad is False
+
+            if optimizer_idx == 1:
+                assert self.layer_1[0].weight.requires_grad is False
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+
+                assert self.layer_3[1].weight.requires_grad is False
+                assert self.layer_3[3].weight.requires_grad is True
+                assert self.layer_3[5].weight.requires_grad is False
+
+            if optimizer_idx == 2:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is False
+
+                assert self.layer_3[1].weight.requires_grad is False
+                assert self.layer_3[3].weight.requires_grad is True
+                assert self.layer_3[5].weight.requires_grad is False
+
             optimizer.step(closure=closure)
 
+        def training_step(self, batch, batch_idx, optimizer_idx=None):
+            return super().training_step(batch, batch_idx)
+
+        @staticmethod
+        def combine_generators(gen_1, gen_2):
+            for p in gen_1:
+                yield p
+            for p in gen_2:
+                yield p
+
+        def configure_optimizers(self):
+            optimizer_1 = SGD(
+                self.combine_generators(
+                    self.layer_1.parameters(),
+                    self.layer_2.parameters()
+                ),
+                lr=0.1
+            )
+            optimizer_2 = Adam(
+                self.combine_generators(
+                    self.layer_2.parameters(),
+                    self.layer_3.parameters()
+                ),
+                lr=0.1
+            )
+            optimizer_3 = SGD(
+                self.combine_generators(
+                    self.layer_3.parameters(),
+                    self.layer_1.parameters()
+                ),
+                lr=0.1
+            )
+            return [optimizer_1, optimizer_2, optimizer_3]
+
     model = TestModel()
     model.training_epoch_end = None
 
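To make the expected flags in the new three-optimizer test easier to follow, here is a hedged, trainer-free sketch of the `optimizer_idx == 0` case using the same layer shapes, pre-frozen weights, and optimizer pairing as the test (the hand-rolled toggle loop below mirrors the patched `toggle_optimizer` by hand; it illustrates the expected state, not Lightning's actual code path): parameters owned by `optimizer_1` keep their pre-set flags, while `layer_3`, which belongs only to the other two optimizers, ends up fully frozen.

```python
from torch import nn
from torch.optim import SGD, Adam

# Same three blocks and the same pre-frozen weights as the test above.
layer_1 = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
layer_2 = nn.Sequential(nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
layer_3 = nn.Sequential(nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
for layer, idx in ((layer_1, 2), (layer_1, 4), (layer_2, 1), (layer_2, 3), (layer_3, 1), (layer_3, 5)):
    layer[idx].weight.requires_grad = False

opt_1 = SGD(list(layer_1.parameters()) + list(layer_2.parameters()), lr=0.1)
opt_2 = Adam(list(layer_2.parameters()) + list(layer_3.parameters()), lr=0.1)
opt_3 = SGD(list(layer_3.parameters()) + list(layer_1.parameters()), lr=0.1)

# Toggle optimizer 0 by hand: freeze everything once while remembering the
# original flags, then restore only opt_1's parameters -- the two passes the
# patch introduces in `toggle_optimizer`.
state = {}
for opt in (opt_1, opt_2, opt_3):
    for group in opt.param_groups:
        for param in group["params"]:
            state.setdefault(param, param.requires_grad)
            param.requires_grad = False
for group in opt_1.param_groups:
    for param in group["params"]:
        param.requires_grad = state[param]

# These mirror the `optimizer_idx == 0` assertions in the test.
assert layer_1[0].weight.requires_grad is True
assert layer_1[2].weight.requires_grad is False
assert layer_2[5].weight.requires_grad is True
assert layer_3[3].weight.requires_grad is False  # trainable weight, but not owned by opt_1
```

The `optimizer_idx == 1` and `optimizer_idx == 2` cases follow the same pattern with `opt_2` and `opt_3` as the current optimizer.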
