
Commit 82c7e87

device agnostic fsdp testing (#27120)
* make fsdp test cases device agnostic
* make style
1 parent 7d8ff36 commit 82c7e87

File tree

1 file changed (+11 lines, -10 lines)


tests/fsdp/test_fsdp.py

Lines changed: 11 additions & 10 deletions
@@ -24,18 +24,19 @@
 from transformers import is_torch_available
 from transformers.testing_utils import (
     TestCasePlus,
+    backend_device_count,
     execute_subprocess_async,
-    get_gpu_count,
     mockenv_context,
     require_accelerate,
     require_fsdp,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
+    torch_device,
 )
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import FSDPOption, set_seed
-from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
+from transformers.utils import is_accelerate_available, is_torch_bf16_available_on_device


 if is_torch_available():
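Aside (not part of the commit): a minimal sketch of how the newly imported helpers stand in for the removed GPU-only ones, assuming backend_device_count reports the device count for whichever backend torch_device names (cuda, xpu, npu, cpu, ...):

from transformers.testing_utils import backend_device_count, torch_device

# Where get_gpu_count() counted only CUDA devices, backend_device_count()
# is keyed on the active backend, so the same check also works on
# non-CUDA accelerators.
def has_multiple_accelerators() -> bool:
    return backend_device_count(torch_device) >= 2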
@@ -46,7 +47,7 @@
 # default torch.distributed port
 DEFAULT_MASTER_PORT = "10999"
 dtypes = ["fp16"]
-if is_torch_bf16_gpu_available():
+if is_torch_bf16_available_on_device(torch_device):
     dtypes += ["bf16"]
 sharding_strategies = ["full_shard", "shard_grad_op"]
 state_dict_types = ["FULL_STATE_DICT", "SHARDED_STATE_DICT"]
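The dtypes list gated here feeds the parameterized test cases further down. The construction of params is outside this diff, so the following is only a sketch of its assumed shape, one case per (sharding_strategy, dtype) pair:

import itertools

# Assumed shape of the module-level params consumed by
# @parameterized.expand in the tests below (not shown in this diff),
# built from the sharding_strategies and dtypes lists defined above.
params = list(itertools.product(sharding_strategies, dtypes))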
@@ -100,7 +101,7 @@ def get_launcher(distributed=False, use_accelerate=False):
     # - it won't be able to handle that
     # 2. for now testing with just 2 gpus max (since some quality tests may give different
     # results with mode gpus because we use very little data)
-    num_gpus = min(2, get_gpu_count()) if distributed else 1
+    num_gpus = min(2, backend_device_count(torch_device)) if distributed else 1
     master_port = get_master_port(real_launcher=True)
     if use_accelerate:
         return f"""accelerate launch
@@ -121,7 +122,7 @@ def _parameterized_custom_name_func(func, param_num, param):


 @require_accelerate
-@require_torch_gpu
+@require_torch_accelerator
 @require_fsdp_version
 class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
     def setUp(self):
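Moving from require_torch_gpu to require_torch_accelerator means the whole class is skipped when no supported accelerator backend is available, rather than only when CUDA is missing. A small illustrative sketch (class and test names made up), assuming the decorator skips whenever torch_device resolves to "cpu" or to nothing:

from transformers.testing_utils import TestCasePlus, require_torch_accelerator, torch_device

@require_torch_accelerator
class DeviceAgnosticSmokeTest(TestCasePlus):
    def test_backend_is_not_cpu(self):
        # Only runs when some accelerator backend (CUDA, XPU, NPU, ...) is active.
        self.assertNotEqual(torch_device, "cpu")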
@@ -170,7 +171,7 @@ def test_fsdp_config(self, sharding_strategy, dtype):
         self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

     @parameterized.expand(params, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_basic_run(self, sharding_strategy, dtype):
         launcher = get_launcher(distributed=True, use_accelerate=False)
@@ -182,7 +183,7 @@ def test_basic_run(self, sharding_strategy, dtype):
         execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(dtypes)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     @unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
     def test_basic_run_with_cpu_offload(self, dtype):
@@ -195,7 +196,7 @@ def test_basic_run_with_cpu_offload(self, dtype):
         execute_subprocess_async(cmd, env=self.get_env())

     @parameterized.expand(state_dict_types, name_func=_parameterized_custom_name_func)
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
     @slow
     def test_training_and_can_resume_normally(self, state_dict_type):
         output_dir = self.get_auto_remove_tmp_dir("./xxx", after=False)
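The three method-level changes above are the same substitution: require_torch_multi_gpu, which only counted CUDA devices, becomes require_torch_multi_accelerator, which gates on having at least two devices of whatever backend is active. A hedged sketch of the resulting pattern, with a made-up class and test body, using the decorators and helpers imported in the first hunk:

from parameterized import parameterized
from transformers.testing_utils import (
    TestCasePlus,
    backend_device_count,
    require_torch_multi_accelerator,
    slow,
    torch_device,
)

class MultiAcceleratorExample(TestCasePlus):
    @parameterized.expand(["fp16", "bf16"])
    @require_torch_multi_accelerator
    @slow
    def test_visible_devices(self, dtype):
        # Reached only when the active backend exposes two or more devices.
        self.assertGreaterEqual(backend_device_count(torch_device), 2)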
