From cf889630d04d61acfc0f9a42cb1518bf61f9080e Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Sun, 26 Oct 2025 15:47:31 +0800 Subject: [PATCH 1/2] ``torch.cuda.amp.GradScaler(args...)`` is deprecated, use ``torch.amp.GradScaler("cuda", args...)`` instead. --- mmengine/optim/optimizer/amp_optimizer_wrapper.py | 4 +++- tests/test_optim/test_optimizer/test_optimizer_wrapper.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 4f3323f2cc..60200924b5 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import contextmanager +from functools import partial from typing import Union import torch @@ -17,7 +18,8 @@ elif is_mlu_available(): from torch.mlu.amp import GradScaler else: - from torch.cuda.amp import GradScaler + from torch.amp import GradScaler as amp_GradScaler + GradScaler = partial(amp_GradScaler, device='cuda') @OPTIM_WRAPPERS.register_module() diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index ef1db241dd..3fbb588678 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os +from functools import partial + import unittest from unittest import TestCase from unittest.mock import MagicMock @@ -8,7 +10,8 @@ import torch.distributed as torch_dist import torch.nn as nn from parameterized import parameterized -from torch.cuda.amp import GradScaler +from torch.amp import GradScaler as amp_GradScaler +GradScaler = partial(amp_GradScaler, device='cuda') from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import SGD, Adam, Optimizer From 74a0c8b91b3120545df9ae29d72d533f31bcf46b Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Sun, 26 Oct 2025 16:39:48 +0800 Subject: [PATCH 2/2] [Fix] unittest cannot recognize partialled func, use the vanilla `GradScaler` to validate instead. --- tests/test_optim/test_optimizer/test_optimizer_wrapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index 3fbb588678..d8a0c03812 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -426,13 +426,13 @@ def setUp(self) -> None: def test_init(self): # Test with default arguments. amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer) - self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) + self.assertIsInstance(amp_optim_wrapper.loss_scaler, amp_GradScaler) # Test with dynamic. amp_optim_wrapper = AmpOptimWrapper( 'dynamic', optimizer=self.optimizer) self.assertIsNone(amp_optim_wrapper._scale_update_param) - self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) + self.assertIsInstance(amp_optim_wrapper.loss_scaler, amp_GradScaler) # Test with dtype float16 amp_optim_wrapper = AmpOptimWrapper( @@ -447,7 +447,7 @@ def test_init(self): # Test with dict loss_scale. 
amp_optim_wrapper = AmpOptimWrapper( dict(init_scale=1, growth_factor=2), optimizer=self.optimizer) - self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) + self.assertIsInstance(amp_optim_wrapper.loss_scaler, amp_GradScaler) self.assertIsNone(amp_optim_wrapper._scale_update_param) with self.assertRaisesRegex(TypeError, 'loss_scale must be of type float'):