 from torchao.optim.subclass_fp8 import OptimStateFp8
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_7,
     get_available_devices,
@@ -128,8 +127,6 @@ class TestOptim(TestCase):
     @skip_if_rocm("ROCm enablement in progress")
     def test_optim_smoke(self, optim_name, dtype, device):
         if optim_name.endswith("Fp8") and device == "cuda":
-            if not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 
@@ -166,6 +163,30 @@ def test_optim_smoke(self, optim_name, dtype, device):
         for p1, p2 in zip(model.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1)
 
+    @parametrize("optim_name", ["Adam8bit", "Adam4bit", "AdamFp8"])
+    @parametrize("device", _DEVICES)
+    def test_optim_default_dtype_bf16(self, optim_name, device):
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 CUDA requires compute capability >= 8.9")
+
+        old_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(torch.bfloat16)
+
+        try:
+            model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))
+            model.to(device=device)
+            optimizer = getattr(optim, optim_name)(model.parameters())
+
+            x = torch.randn(4, 32, device=device)
+            loss = model(x).sum()
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        finally:
+            torch.set_default_dtype(old_dtype)
+
     # aten.slice is required for dcp.load() when world size changes i.e. re-sharding
     # however, it's cumbersome to test it directly, since we would need to run distributed
     # test 2 times with different world size, and persist checkpoint across the 2 runs.
@@ -178,8 +199,6 @@ def test_subclass_slice(self, subclass, shape, device):
         if subclass == OptimStateFp8:
             if device == "cpu" and len(shape) > 1 and not TORCH_VERSION_AT_LEAST_2_5:
                 pytest.skip("fill_cpu not implemented for Float8_e4m3fn for torch<2.5")
-            if device == "cuda" and not TORCH_VERSION_AT_LEAST_2_4:
-                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
             if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
                 pytest.skip("FP8 CUDA requires compute capability >= 8.9")
 