Merged
src/transformers/trainer.py: 11 additions & 1 deletion
@@ -200,6 +200,11 @@
         save_fsdp_model,
         save_fsdp_optimizer,
     )
+    DATA_SAMPLERS = [RandomSampler]
+    if version.parse(accelerate_version) > version.parse("0.23.0"):
+        from accelerate.data_loader import SeedableRandomSampler
+
+        DATA_SAMPLERS += [SeedableRandomSampler]
 
     if is_deepspeed_available():
         from accelerate.utils import DeepSpeedSchedulerWrapper
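
For context, `RandomSampler`, `version`, and `accelerate_version` are bound by imports earlier in trainer.py that this hunk does not show. A standalone sketch of the same version gate (assuming accelerate is installed; the final print is purely illustrative):

    from accelerate import __version__ as accelerate_version
    from packaging import version
    from torch.utils.data import RandomSampler

    # SeedableRandomSampler only exists in accelerate releases after 0.23.0,
    # so importing it unconditionally would break older installs.
    DATA_SAMPLERS = [RandomSampler]
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        DATA_SAMPLERS += [SeedableRandomSampler]

    print([cls.__name__ for cls in DATA_SAMPLERS])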
@@ -1738,7 +1743,10 @@ def _inner_training_loop(
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
                 sampler = get_dataloader_sampler(train_dataloader)
-                is_random_sampler = isinstance(sampler, RandomSampler)
+                sampler_kinds = [RandomSampler]
+                if version.parse(accelerate_version) > version.parse("0.23.0"):
+                    sampler_kinds.append(SeedableRandomSampler)
+                is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
                 if is_torch_less_than_1_11 or not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     for _ in train_dataloader:
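
Why the isinstance check is widened: when resuming with `ignore_data_skip` unset, the Trainer replays the already-finished epochs so the RNG ends up where an uninterrupted run would have left it, and `SeedableRandomSampler` now needs to take the same path as `RandomSampler`. A minimal standalone sketch of the replay mechanism (not Trainer code; the dataset and sizes are made up):

    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    torch.manual_seed(0)
    dataset = TensorDataset(torch.arange(8))
    loader = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))

    epochs_trained = 2  # epochs already consumed before the checkpoint
    for _ in range(epochs_trained):
        # Starting each pass makes the sampler draw a fresh seed from
        # torch's global RNG, so after the loop the RNG sits exactly
        # where the next (resumed) epoch expects it.
        for _ in loader:
            pass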
@@ -1752,6 +1760,8 @@
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
+            if hasattr(epoch_iterator, "set_epoch"):
+                epoch_iterator.set_epoch(epoch)
 
             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
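
The new `set_epoch` call is what makes `SeedableRandomSampler` useful: it derives each epoch's shuffle from a fixed base seed plus the epoch number, so a resumed run that announces the right epoch replays the same order, and the `hasattr` guard keeps plain dataloaders without a `set_epoch` method working. A standalone sketch (assumes accelerate > 0.23.0; it calls `set_epoch` on the sampler directly rather than on an accelerate-prepared dataloader as the diff does):

    import torch
    from accelerate.data_loader import SeedableRandomSampler
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(42)  # the sampler picks up its base seed from the global RNG
    dataset = TensorDataset(torch.arange(8))
    sampler = SeedableRandomSampler(dataset)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(2):
        if hasattr(sampler, "set_epoch"):  # mirrors the guard added above
            sampler.set_epoch(epoch)
        for batch in loader:
            pass  # batch order depends only on (base seed, epoch)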