[Fix] Deprecate torch.cuda.amp API
#1675
base: main
Conversation
…radScaler("cuda", args...)`` instead.
Pull Request Overview
This PR updates the codebase to use the new `torch.amp.GradScaler` API instead of the deprecated `torch.cuda.amp.GradScaler` API, following PyTorch's deprecation guidance.
Key Changes:
- Replaced deprecated `torch.cuda.amp.GradScaler` imports with `torch.amp.GradScaler`, using a `partial` function to bind the device to `'cuda'` (see the sketch after this list)
- Applied this change to both production code (`amp_optimizer_wrapper.py`) and test code (`test_optimizer_wrapper.py`)
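A minimal sketch of this pattern, assuming PyTorch 2.4+ where `torch.amp.GradScaler` accepts a `device` argument (the `scaler` variable is illustrative, not taken from the PR):

```python
from functools import partial

# Deprecated spelling (emits a FutureWarning on recent PyTorch):
#     from torch.cuda.amp import GradScaler
# New spelling: import torch.amp.GradScaler and pre-bind the device,
# so existing call sites that do `GradScaler(...)` keep working unchanged.
from torch.amp import GradScaler as amp_GradScaler

GradScaler = partial(amp_GradScaler, device='cuda')

scaler = GradScaler()  # equivalent to torch.amp.GradScaler('cuda')
```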
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| mmengine/optim/optimizer/amp_optimizer_wrapper.py | Updated GradScaler import from deprecated torch.cuda.amp to torch.amp with device parameter |
| tests/test_optim/test_optimizer/test_optimizer_wrapper.py | Updated GradScaler import in tests to use the new torch.amp API |
```python
from functools import partial

import unittest
from unittest import TestCase
from unittest.mock import MagicMock
```
Copilot AI · Oct 26, 2025
The blank line after `from functools import partial` creates inconsistent import grouping. Move the `functools` import to be with the other standard library imports (`os`, `unittest`) before the blank line that separates standard library imports from third-party imports.
Suggested change:

```diff
-from functools import partial
 import unittest
 from unittest import TestCase
 from unittest.mock import MagicMock
+from functools import partial
```
```diff
 from parameterized import parameterized
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler as amp_GradScaler
+GradScaler = partial(amp_GradScaler, device='cuda')
```
Copilot AI · Oct 26, 2025
[nitpick] Creating a module-level variable `GradScaler` through `partial` assignment makes the code less maintainable and harder to understand. Consider either: (1) using `amp_GradScaler('cuda', ...)` directly at call sites, or (2) creating a proper wrapper function with a docstring explaining the device binding.
Suggested change:

```diff
-GradScaler = partial(amp_GradScaler, device='cuda')
+def get_cuda_grad_scaler(*args, **kwargs):
+    """Return a torch.amp.GradScaler instance bound to the 'cuda' device.
+
+    Args:
+        *args: Positional arguments for torch.amp.GradScaler.
+        **kwargs: Keyword arguments for torch.amp.GradScaler.
+
+    Returns:
+        amp_GradScaler: An instance of GradScaler with device='cuda'.
+    """
+    return amp_GradScaler(*args, device='cuda', **kwargs)
```
```python
GradScaler = partial(amp_GradScaler, device='cuda')
```
Copilot AI · Oct 26, 2025
[nitpick] Creating a module-level variable `GradScaler` through `partial` assignment makes the code less maintainable and harder to understand. Consider either: (1) using `amp_GradScaler('cuda', ...)` directly at call sites, or (2) creating a proper wrapper function with a docstring explaining the device binding.
Suggested change:

```diff
-GradScaler = partial(amp_GradScaler, device='cuda')
+def get_grad_scaler(*args, **kwargs):
+    """Create a torch.amp.GradScaler instance bound to device='cuda'.
+
+    Args:
+        *args: Positional arguments passed to torch.amp.GradScaler.
+        **kwargs: Keyword arguments passed to torch.amp.GradScaler.
+
+    Returns:
+        amp_GradScaler: An instance of torch.amp.GradScaler with device='cuda'.
+    """
+    return amp_GradScaler(*args, device='cuda', **kwargs)
```
@HAOCHENYE This one is ready to be reviewed. The copilot review suggested a different implementation which also LGTM. You may choose one to merge~~
…dScaler` to validate instead.
This workaround LGTM, I'll merge it after fixing the lint
This is a sub-PR of #1665
Brief
According to PyTorch:

> `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler("cuda", args...)` instead.
This includes two related replacements:
- `amp_optimizer_wrapper`
- `test_optimizer_wrapper`
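For reference, a minimal training-step sketch with the new-style `torch.amp` API that these modules now rely on. The model, optimizer, and loss below are illustrative only (the PR itself migrates `GradScaler`; `autocast` is shown for context), and running it assumes a CUDA-capable machine with PyTorch >= 2.4:

```python
import torch
from torch.amp import GradScaler, autocast

model = torch.nn.Linear(4, 2).cuda()          # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler('cuda')                   # replaces torch.cuda.amp.GradScaler()

data = torch.randn(8, 4, device='cuda')
target = torch.randn(8, 2, device='cuda')

optimizer.zero_grad()
with autocast('cuda'):                        # replaces torch.cuda.amp.autocast()
    loss = torch.nn.functional.mse_loss(model(data), target)

scaler.scale(loss).backward()                 # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)                        # unscales gradients, then calls optimizer.step()
scaler.update()                               # adjust the loss scale for the next iteration
```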
PyTest Result

`pytest tests/test_optim/test_optimizer/test_optimizer_wrapper.py`