@@ -24,18 +24,19 @@
 from transformers import is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
-    get_gpu_count,
     mockenv_context,
     require_accelerate,
     require_fsdp,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
+    torch_device,
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import FSDPOption, set_seed
-from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
+from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device


 if is_torch_available():
@@ -46,7 +47,7 @@
 # default torch.distributed port
 DEFAULT_MASTER_PORT = "10999"
 dtypes = ["fp16"]
-if is_torch_bf16_gpu_available():
+if is_torch_bf16_available_on_device(torch_device):
     dtypes += ["bf16"]
 sharding_strategies = ["full_shard", "shard_grad_op"]
 state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
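
Note: the bf16 gate above is now device-agnostic — `torch_device` names whatever accelerator the test session resolved, and `is_torch_bf16_available_on_device` checks support on that device instead of assuming CUDA. A minimal standalone sketch of the same selection logic, assuming plain PyTorch and only CUDA/CPU backends (the transformers helpers cover more):

    # Standalone approximation of the device-agnostic dtype gate.
    # Assumption: only CUDA and CPU backends need handling here.
    import torch

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    dtypes = ["fp16"]
    if torch_device == "cuda" and torch.cuda.is_bf16_supported():
        dtypes += ["bf16"]
    print(dtypes)  # e.g. ['fp16', 'bf16'] on Ampere-or-newer GPUs
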
@@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
     #    - it won't be able to handle that
     # 2. for now testing with just 2 gpus max (since some quality tests may give different
     #    results with more gpus because we use very little data)
-    num_gpus = min(2, get_gpu_count()) if distributed else 1
+    num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
     master_port = get_master_port(real_launcher=True)
     if use_accelerate:
         return f"""accelerate launch
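
Here `backend_device_count(torch_device)` generalizes the old CUDA-only `get_gpu_count()` so the launcher sizes itself on whichever backend is active. A rough stand-in, under the same CUDA/CPU assumption (the real helper lives in `transformers.testing_utils` and knows about other backends):

    import torch

    def backend_device_count_sketch(device: str) -> int:
        # Hypothetical approximation of transformers.testing_utils.backend_device_count;
        # assumes only "cuda" and "cpu" need handling.
        if device == "cuda":
            return torch.cuda.device_count()
        return 1  # treat CPU (or any single-device backend) as one device

    distributed = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Mirrors get_launcher above: cap distributed runs at 2 processes.
    num_gpus = min(2, backend_device_count_sketch(device)) if distributed else 1
    print(num_gpus)
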
@@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):


 @require_accelerate
-@require_torch_gpu
+@require_torch_accelerator
 @require_fsdp_version
 class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def setUp(self):
@@ -170,7 +171,7 @@ def test_fsdp_config(self, sharding_strategy, dtype):
         self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_basic_run(self, sharding_strategy, dtype):
         launcher = get_launcher(distributed=True, use_accelerate=False)
@@ -182,7 +183,7 @@ def test_basic_run(self, sharding_strategy, dtype):
         execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(dtypes)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     @unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
     def test_basic_run_with_cpu_offload(self, dtype):
@@ -195,7 +196,7 @@ def test_basic_run_with_cpu_offload(self, dtype):
         execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_training_and_can_resume_normally(self, state_dict_type):
         output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
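
The multi-device tests above now gate on `require_torch_multi_accelerator`, which skips unless the active backend exposes more than one device. A hedged sketch of such a skip decorator, assuming unittest and CUDA as the only multi-device backend (the real decorator in `transformers.testing_utils` presumably covers other accelerator backends as well):

    import unittest

    import torch

    def require_multi_accelerator_sketch(test_case):
        # Illustrative stand-in for require_torch_multi_accelerator;
        # only CUDA device counts are checked in this sketch.
        return unittest.skipUnless(
            torch.cuda.device_count() > 1, "test requires multiple accelerators"
        )(test_case)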