Skip to content

Commit 64bcf77

Browse files
pacman100 authored and amyeroberts committed
fix resuming from ckpt when using FSDP with FULL_STATE_DICT (#27891)
* fix resuming from ckpt when using FSDP with FULL_STATE_DICT
* update tests
* fix tests
1 parent 780376f commit 64bcf77

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

src/transformers/trainer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2030,10 +2030,15 @@ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
20302030
weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
20312031
safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
20322032
safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
2033-
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and any(
2034-
FSDP_MODEL_NAME in folder_name
2035-
for folder_name in os.listdir(resume_from_checkpoint)
2036-
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
2033+
is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
2034+
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
2035+
any(
2036+
FSDP_MODEL_NAME in folder_name
2037+
for folder_name in os.listdir(resume_from_checkpoint)
2038+
if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
2039+
)
2040+
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
2041+
or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
20372042
)
20382043

20392044
if is_fsdp_ckpt and not self.is_fsdp_enabled:

tests/fsdp/test_fsdp.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
if is_torch_available():
4343
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
44+
from transformers.trainer import FSDP_MODEL_NAME
4445
else:
4546
is_torch_greater_or_equal_than_2_1 = False
4647

@@ -211,6 +212,19 @@ def test_training_and_can_resume_normally(self, state_dict_type):
211212
# resume from ckpt
212213
checkpoint = os.path.join(output_dir, "checkpoint-115")
213214
resume_args = args + f"--resume_from_checkpoint {checkpoint}".split()
215+
216+
is_fsdp_ckpt = os.path.isdir(checkpoint) and (
217+
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
218+
any(
219+
FSDP_MODEL_NAME in folder_name
220+
for folder_name in os.listdir(checkpoint)
221+
if os.path.isdir(os.path.join(checkpoint, folder_name))
222+
)
223+
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
224+
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
225+
)
226+
self.assertTrue(is_fsdp_ckpt)
227+
214228
logs_resume = self.run_cmd_and_get_logs(
215229
use_accelerate, sharding_strategy, launcher, script, resume_args, output_dir
216230
)

0 commit comments

Comments
 (0)