
LearningRateFinder not working with CLI optimizers #16787

@rusmux


Bug description

`LearningRateFinder` does not update the optimizer's learning rate when the optimizer is defined from the CLI (`LightningCLI`) or a YAML config file.

For example, I define the following in `train.yaml`:

```yaml
...
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 1.5e-3
...
```

I also add the callback:

```python
from pytorch_lightning.callbacks import LearningRateFinder

LearningRateFinder(update_attr=True)
```

At the start of training, it finds the best learning rate:

Screenshot 82

After that, however, training still uses the learning rate I provided in the config:

Screenshot 83

I also tried updating it manually, like this:

Screenshot 84

but got the same result.
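The behaviour above is consistent with the following toy model, which is an assumption rather than a description of Lightning's internals: the optimizer is built only from the config's `class_path`/`init_args`, while `LearningRateFinder(update_attr=True)` writes its suggestion onto a module attribute that nothing reads afterwards. All names below (`instantiate_from_config`, `ToyModule`, `attach_config_optimizer`, `toy_lr_finder`) are hypothetical, the suggested value `3e-4` is arbitrary, and `types.SimpleNamespace` stands in for `torch.optim.AdamW` so the sketch runs without PyTorch:

```python
import importlib
from typing import Any, Dict

# Stand-in for the optimizer entry in train.yaml; types.SimpleNamespace replaces
# torch.optim.AdamW so this sketch runs without PyTorch installed.
OPTIMIZER_CONFIG: Dict[str, Any] = {
    "class_path": "types.SimpleNamespace",
    "init_args": {"lr": 1.5e-3},
}


def instantiate_from_config(config: Dict[str, Any]) -> Any:
    """Hypothetical helper: resolve class_path and call it with init_args."""
    module_name, class_name = config["class_path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("init_args", {}))


class ToyModule:
    """Hypothetical stand-in for a LightningModule exposing an `lr` attribute."""

    def __init__(self, lr: float) -> None:
        self.lr = lr


def attach_config_optimizer(module: ToyModule, config: Dict[str, Any]) -> None:
    """Assumed CLI behaviour: configure_optimizers is built from the config alone."""
    # The config is captured in this closure; module.lr is never consulted.
    module.configure_optimizers = lambda: instantiate_from_config(config)


def toy_lr_finder(module: ToyModule, suggested_lr: float) -> None:
    """Mimics update_attr=True: write the suggested value onto the module."""
    module.lr = suggested_lr


module = ToyModule(lr=1.5e-3)
attach_config_optimizer(module, OPTIMIZER_CONFIG)
toy_lr_finder(module, suggested_lr=3e-4)  # arbitrary "suggestion"

# Assumption: the optimizer used for training comes from configure_optimizers.
optimizer = module.configure_optimizers()
print(module.lr, optimizer.lr)  # 0.0003 0.0015
```

In this model the attribute update succeeds, but the optimizer still receives the config value, which matches what the screenshots show.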

How to reproduce the bug

Define an optimizer in a YAML config file and add the `LearningRateFinder` callback.

Error messages and logs


Environment

Current environment
* CUDA:
	- GPU:
		- NVIDIA RTX A4000
	- available:         True
	- version:           11.7
* Lightning:
	- lightning-utilities: 0.6.0.post0
	- pytorch-lightning: 1.9.1
	- torch:             1.13.1
	- torchmetrics:      0.11.1
	- torchvision:       0.14.1
* Packages:
	- aiobotocore:       2.4.2
	- aiofiles:          22.1.0
	- aiohttp:           3.8.4
	- aiohttp-retry:     2.8.3
	- aioitertools:      0.11.0
	- aiosignal:         1.3.1
	- aiosqlite:         0.18.0
	- albumentations:    1.3.0
	- amqp:              5.1.1
	- antlr4-python3-runtime: 4.9.3
	- anyio:             3.6.2
	- appdirs:           1.4.4
	- argcomplete:       2.0.0
	- argon2-cffi:       21.3.0
	- argon2-cffi-bindings: 21.2.0
	- arrow:             1.2.3
	- astor:             0.8.1
	- asttokens:         2.2.1
	- async-timeout:     4.0.2
	- asyncssh:          2.13.0
	- atpublic:          3.1.1
	- attrs:             22.2.0
	- babel:             2.11.0
	- backcall:          0.2.0
	- bandit:            1.7.4
	- beautifulsoup4:    4.11.2
	- billiard:          3.6.4.0
	- bleach:            6.0.0
	- boto3:             1.24.59
	- botocore:          1.27.59
	- celery:            5.2.7
	- certifi:           2022.12.7
	- cffi:              1.15.1
	- cfgv:              3.3.1
	- charset-normalizer: 3.0.1
	- clearml:           1.9.1
	- click:             8.1.3
	- click-didyoumean:  0.3.0
	- click-plugins:     1.1.1
	- click-repl:        0.2.0
	- colorama:          0.4.6
	- comm:              0.1.2
	- configobj:         5.0.8
	- contourpy:         1.0.7
	- cryptography:      39.0.1
	- cycler:            0.11.0
	- dacite:            1.8.0
	- darglint:          1.8.1
	- debugpy:           1.6.6
	- decorator:         5.1.1
	- defusedxml:        0.7.1
	- deprecated:        1.2.13
	- dictdiffer:        0.9.0
	- dill:              0.3.6
	- diskcache:         5.4.0
	- distlib:           0.3.6
	- distro:            1.8.0
	- dnspython:         2.3.0
	- docstring-parser:  0.15
	- docutils:          0.19
	- dpath:             2.1.4
	- dulwich:           0.21.2
	- dvc:               2.45.0
	- dvc-data:          0.40.1
	- dvc-http:          2.30.2
	- dvc-objects:       0.19.3
	- dvc-render:        0.1.2
	- dvc-s3:            2.21.0
	- dvc-studio-client: 0.4.0
	- dvc-task:          0.1.11
	- dvclive:           2.0.2
	- eradicate:         2.1.0
	- eventlet:          0.33.3
	- exceptiongroup:    1.1.0
	- executing:         1.2.0
	- fastjsonschema:    2.16.2
	- fiftyone:          0.18.0
	- fiftyone-brain:    0.9.2
	- fiftyone-db:       0.4.0
	- filelock:          3.9.0
	- flake8:            4.0.1
	- flake8-bandit:     3.0.0
	- flake8-broken-line: 0.5.0
	- flake8-bugbear:    22.12.6
	- flake8-commas:     2.1.0
	- flake8-comprehensions: 3.10.1
	- flake8-debugger:   4.1.2
	- flake8-docstrings: 1.7.0
	- flake8-eradicate:  1.4.0
	- flake8-isort:      4.2.0
	- flake8-polyfill:   1.0.2
	- flake8-quotes:     3.3.2
	- flake8-rst-docstrings: 0.2.7
	- flake8-string-format: 0.3.0
	- flatten-dict:      0.4.2
	- flufl.lock:        7.1.1
	- fonttools:         4.38.0
	- fqdn:              1.5.1
	- frozenlist:        1.3.3
	- fsspec:            2023.1.0
	- funcy:             1.18
	- furl:              2.1.3
	- future:            0.18.3
	- gitdb:             4.0.10
	- gitpython:         3.1.30
	- glob2:             0.7
	- grandalf:          0.8
	- graphql-core:      3.2.3
	- greenlet:          2.0.2
	- h11:               0.14.0
	- h2:                4.1.0
	- hpack:             4.0.0
	- httpcore:          0.16.3
	- httpx:             0.23.3
	- huggingface-hub:   0.12.0
	- hydra-core:        1.3.1
	- hypercorn:         0.14.3
	- hyperframe:        6.0.1
	- identify:          2.5.18
	- idna:              3.4
	- imageio:           2.25.1
	- importlib-resources: 5.10.2
	- iniconfig:         2.0.0
	- ipykernel:         6.21.2
	- ipython:           8.10.0
	- ipython-genutils:  0.2.0
	- ipywidgets:        8.0.4
	- isoduration:       20.11.0
	- isort:             5.12.0
	- iterative-telemetry: 0.0.7
	- jedi:              0.18.2
	- jinja2:            3.1.2
	- jmespath:          1.0.1
	- joblib:            1.2.0
	- json5:             0.9.11
	- jsonargparse:      4.19.0
	- jsonpointer:       2.3
	- jsonschema:        4.17.3
	- jupyter-client:    8.0.2
	- jupyter-contrib-core: 0.4.2
	- jupyter-contrib-nbextensions: 0.7.0
	- jupyter-core:      5.2.0
	- jupyter-events:    0.5.0
	- jupyter-highlight-selected-word: 0.2.0
	- jupyter-nbextensions-configurator: 0.6.1
	- jupyter-server:    2.2.1
	- jupyter-server-fileid: 0.6.0
	- jupyter-server-terminals: 0.4.4
	- jupyter-server-ydoc: 0.6.1
	- jupyter-ydoc:      0.2.2
	- jupyterlab:        3.6.1
	- jupyterlab-pygments: 0.2.2
	- jupyterlab-server: 2.19.0
	- jupyterlab-widgets: 3.0.5
	- kaleido:           0.2.1
	- kiwisolver:        1.4.4
	- kombu:             5.2.4
	- lightning-utilities: 0.6.0.post0
	- lxml:              4.9.2
	- markdown-it-py:    2.1.0
	- markupsafe:        2.1.2
	- matplotlib:        3.7.0
	- matplotlib-inline: 0.1.6
	- mccabe:            0.6.1
	- mdurl:             0.1.2
	- mistune:           2.0.5
	- mongoengine:       0.24.2
	- motor:             3.1.1
	- multidict:         6.0.4
	- nanotime:          0.5.2
	- nbclassic:         0.5.1
	- nbclient:          0.7.2
	- nbconvert:         7.2.9
	- nbformat:          5.7.3
	- ndjson:            0.3.1
	- nest-asyncio:      1.5.6
	- networkx:          3.0
	- nodeenv:           1.7.0
	- notebook:          6.5.2
	- notebook-shim:     0.2.2
	- numpy:             1.24.2
	- nvidia-cublas-cu11: 11.10.3.66
	- nvidia-cuda-nvrtc-cu11: 11.7.99
	- nvidia-cuda-runtime-cu11: 11.7.99
	- nvidia-cudnn-cu11: 8.5.0.96
	- omegaconf:         2.3.0
	- onnx:              1.13.0
	- opencv-python-headless: 4.7.0.68
	- orderedmultidict:  1.0.1
	- orjson:            3.8.6
	- packaging:         23.0
	- pandas:            1.5.3
	- pandocfilters:     1.5.0
	- parso:             0.8.3
	- pathlib2:          2.3.7.post1
	- pathspec:          0.11.0
	- patool:            1.12
	- pbr:               5.11.1
	- pep8-naming:       0.13.2
	- pexpect:           4.8.0
	- pickleshare:       0.7.5
	- pillow:            9.4.0
	- pip:               23.0
	- platformdirs:      3.0.0
	- plotly:            5.13.0
	- pluggy:            1.0.0
	- pprintpp:          0.4.0
	- pre-commit:        2.21.0
	- priority:          2.0.0
	- prometheus-client: 0.16.0
	- prompt-toolkit:    3.0.36
	- protobuf:          3.20.3
	- psutil:            5.9.4
	- ptyprocess:        0.7.0
	- pure-eval:         0.2.2
	- pycodestyle:       2.8.0
	- pycparser:         2.21
	- pydocstyle:        6.3.0
	- pydot:             1.4.2
	- pyflakes:          2.4.0
	- pygit2:            1.11.1
	- pygments:          2.14.0
	- pygtrie:           2.5.0
	- pyjwt:             2.4.0
	- pymongo:           4.3.3
	- pyparsing:         3.0.9
	- pyrsistent:        0.19.3
	- pytest:            7.2.1
	- python-dateutil:   2.8.2
	- python-json-logger: 2.0.6
	- pytorch-lightning: 1.9.1
	- pytz:              2022.7.1
	- pytz-deprecation-shim: 0.1.0.post0
	- pywavelets:        1.4.1
	- pyyaml:            6.0
	- pyzmq:             25.0.0
	- qudida:            0.0.4
	- requests:          2.28.2
	- restructuredtext-lint: 1.4.0
	- retrying:          1.3.4
	- rfc3339-validator: 0.1.4
	- rfc3986:           1.5.0
	- rfc3986-validator: 0.1.1
	- rich:              13.3.1
	- ruamel.yaml:       0.17.21
	- ruamel.yaml.clib:  0.2.7
	- s3fs:              2023.1.0
	- s3transfer:        0.6.0
	- scikit-image:      0.19.3
	- scikit-learn:      1.2.1
	- scipy:             1.10.0
	- scmrepo:           0.1.9
	- send2trash:        1.8.0
	- setuptools:        67.3.1
	- shortuuid:         1.0.11
	- shtab:             1.5.8
	- six:               1.16.0
	- smmap:             5.0.0
	- sniffio:           1.3.0
	- snowballstemmer:   2.2.0
	- sortedcontainers:  2.4.0
	- soupsieve:         2.4
	- sqltrie:           0.0.28
	- sse-starlette:     0.10.3
	- sseclient-py:      1.7.2
	- stack-data:        0.6.2
	- starlette:         0.20.4
	- stevedore:         5.0.0
	- strawberry-graphql: 0.138.1
	- tabulate:          0.9.0
	- tenacity:          8.2.1
	- tensorboardx:      2.6
	- terminado:         0.17.1
	- threadpoolctl:     3.1.0
	- tifffile:          2023.2.3
	- timm:              0.6.12
	- tinycss2:          1.2.1
	- toml:              0.10.2
	- tomli:             2.0.1
	- tomlkit:           0.11.6
	- torch:             1.13.1
	- torchmetrics:      0.11.1
	- torchvision:       0.14.1
	- tornado:           6.2
	- tqdm:              4.64.1
	- traitlets:         5.9.0
	- typeshed-client:   2.2.0
	- typing-extensions: 4.5.0
	- tzdata:            2022.7
	- tzlocal:           4.2
	- universal-analytics-python3: 1.1.1
	- uri-template:      1.2.0
	- urllib3:           1.26.14
	- vine:              5.0.0
	- virtualenv:        20.19.0
	- voluptuous:        0.13.1
	- voxel51-eta:       0.8.3
	- wcwidth:           0.2.6
	- webcolors:         1.12
	- webencodings:      0.5.1
	- websocket-client:  1.5.1
	- wemake-python-styleguide: 0.17.0
	- wheel:             0.38.4
	- widgetsnbextension: 4.0.5
	- wrapt:             1.14.1
	- wsproto:           1.2.0
	- xmltodict:         0.13.0
	- y-py:              0.5.5
	- yarl:              1.8.2
	- ypy-websocket:     0.8.2
	- zc.lockfile:       2.0
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		-
	- processor:
	- python:            3.10.10
	- version:           #152-Ubuntu SMP Wed Nov 23 20:19:22 UTC 2022

More info

I think the problem is specific to how and when optimizers and schedulers are instantiated, because when I ran the same code for the batch size instead, it worked as expected:

Screenshot 85

The batch size it found was used in training.

As I understand it, the current way to use `LearningRateFinder` is to define `configure_optimizers()` manually in the `LightningModule`. But then I can't change the optimizer from the YAML config file.
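A possible middle ground, sketched here under assumptions and not verified against `LightningCLI`: pass the optimizer's `class_path`/`init_args` to the module as a constructor argument (the `optimizer_config` parameter and `ToyModule` are hypothetical) and write `configure_optimizers()` yourself, overriding `lr` with the module attribute at call time. The optimizer class and its other arguments then stay configurable from YAML, while a value written onto the module by the finder takes effect. `types.SimpleNamespace` again stands in for `torch.optim.AdamW`:

```python
import importlib
from typing import Any, Dict


class ToyModule:
    """Hypothetical module that takes the optimizer config but owns the lr."""

    def __init__(self, optimizer_config: Dict[str, Any]) -> None:
        self.optimizer_config = optimizer_config
        # Seed the attribute from the config so the finder has something to update.
        self.lr = optimizer_config.get("init_args", {}).get("lr", 1e-3)

    def configure_optimizers(self) -> Any:
        module_name, class_name = self.optimizer_config["class_path"].rsplit(".", 1)
        optimizer_cls = getattr(importlib.import_module(module_name), class_name)
        # Copy so the stored config is not mutated, then override lr at call time.
        init_args = dict(self.optimizer_config.get("init_args", {}))
        init_args["lr"] = self.lr
        # With PyTorch this would be optimizer_cls(self.parameters(), **init_args).
        return optimizer_cls(**init_args)


config = {
    "class_path": "types.SimpleNamespace",  # stand-in for torch.optim.AdamW
    "init_args": {"lr": 1.5e-3, "weight_decay": 0.01},
}
module = ToyModule(config)
module.lr = 3e-4  # what a finder with update_attr=True is expected to do
optimizer = module.configure_optimizers()
print(optimizer.lr, optimizer.weight_decay)  # 0.0003 0.01
```

The design choice is to read `self.lr` lazily inside `configure_optimizers()` instead of binding the config value once, so any attribute update made before the optimizer is built is picked up.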

Labels: bug, tuner