diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py
index 69103dcd8c3f..2a9473c862ff 100644
--- a/tests/fsdp/test_fsdp.py
+++ b/tests/fsdp/test_fsdp.py
@@ -24,18 +24,19 @@
 from transformers import is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
-    get_gpu_count,
     mockenv_context,
     require_accelerate,
     require_fsdp,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
+    torch_device,
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import FSDPOption, set_seed
-from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
+from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device
 
 
 if is_torch_available():
@@ -46,7 +47,7 @@
 # default torch.distributed port
 DEFAULT_MASTER_PORT = "10999"
 dtypes = ["fp16"]
-if is_torch_bf16_gpu_available():
+if is_torch_bf16_available_on_device(torch_device):
     dtypes += ["bf16"]
 sharding_strategies = ["full_shard", "shard_grad_op"]
 state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
@@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
     #    - it won't be able to handle that
     # 2. for now testing with just 2 gpus max (since some quality tests may give different
     #    results with mode gpus because we use very little data)
-    num_gpus = min(2, get_gpu_count()) if distributed else 1
+    num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
     master_port = get_master_port(real_launcher=True)
     if use_accelerate:
         return f"""accelerate launch
@@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):
 
 
 @require_accelerate
-@require_torch_gpu
+@require_torch_accelerator
 @require_fsdp_version
 class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def setUp(self):
@@ -170,7 +171,7 @@ def test_fsdp_config(self, sharding_strategy, dtype):
         self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")
 
     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_basic_run(self, sharding_strategy, dtype):
         launcher = get_launcher(distributed=True, use_accelerate=False)
@@ -182,7 +183,7 @@ def test_basic_run(self, sharding_strategy, dtype):
         execute_subprocess_async(cmd, env=self.get_env())
 
     @parameterized.expand(dtypes)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     @unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
     def test_basic_run_with_cpu_offload(self, dtype):
@@ -195,7 +196,7 @@ def test_basic_run_with_cpu_offload(self, dtype):
         execute_subprocess_async(cmd, env=self.get_env())
 
     @parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
    def test_training_and_can_resume_normally(self, state_dict_type):
         output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
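
For reference, a minimal sketch (not part of the patch) of how the device-agnostic helpers used above replace the CUDA-only ones; `ExampleAcceleratorTest` is a hypothetical test case, for illustration only:

```python
import unittest

from transformers.testing_utils import (
    TestCasePlus,
    backend_device_count,
    require_torch_accelerator,
    torch_device,
)
from transformers.utils import is_torch_bf16_available_on_device


@require_torch_accelerator  # skips unless some torch accelerator (CUDA, NPU, XPU, ...) is available
class ExampleAcceleratorTest(TestCasePlus):  # hypothetical test case, for illustration only
    def test_device_agnostic_queries(self):
        # backend_device_count(torch_device) replaces get_gpu_count(): it counts
        # devices for whatever backend `torch_device` names, not just CUDA GPUs.
        num_devices = min(2, backend_device_count(torch_device))
        self.assertGreaterEqual(num_devices, 1)

        # is_torch_bf16_available_on_device(torch_device) replaces the CUDA-only
        # is_torch_bf16_gpu_available() check when building the dtype matrix.
        dtypes = ["fp16"]
        if is_torch_bf16_available_on_device(torch_device):
            dtypes += ["bf16"]
        self.assertIn("fp16", dtypes)


if __name__ == "__main__":
    unittest.main()
```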