Skip to content

Commit 9041240

Browse files
Bring back set_epoch for Accelerate-based dataloaders (#26850)
* Working tests!

* Fix sampler

* Fix

* Update src/transformers/trainer.py

Co-authored-by: Arthur <[email protected]>

* Fix check

* Clean

---------

Co-authored-by: Arthur <[email protected]>
1 parent 3c26924 commit 9041240

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

src/transformers/trainer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@
         save_fsdp_model,
         save_fsdp_optimizer,
     )
+    DATA_SAMPLERS = [RandomSampler]
+    if version.parse(accelerate_version) > version.parse("0.23.0"):
+        from accelerate.data_loader import SeedableRandomSampler
+
+        DATA_SAMPLERS += [SeedableRandomSampler]

     if is_deepspeed_available():
         from accelerate.utils import DeepSpeedSchedulerWrapper
@@ -1738,7 +1743,10 @@ def _inner_training_loop(
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
                 sampler = get_dataloader_sampler(train_dataloader)
-                is_random_sampler = isinstance(sampler, RandomSampler)
+                sampler_kinds = [RandomSampler]
+                if version.parse(accelerate_version) > version.parse("0.23.0"):
+                    sampler_kinds.append(SeedableRandomSampler)
+                is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
                 if is_torch_less_than_1_11 or not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     for _ in train_dataloader:
@@ -1752,6 +1760,8 @@ def _inner_training_loop(
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
+            if hasattr(epoch_iterator, "set_epoch"):
+                epoch_iterator.set_epoch(epoch)

             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:

0 commit comments

Comments (0)