             save_fsdp_model,
             save_fsdp_optimizer,
         )
+    DATA_SAMPLERS = [RandomSampler]
+    if version.parse(accelerate_version) > version.parse("0.23.0"):
+        from accelerate.data_loader import SeedableRandomSampler
+
+        DATA_SAMPLERS += [SeedableRandomSampler]

     if is_deepspeed_available():
         from accelerate.utils import DeepSpeedSchedulerWrapper
@@ -1738,7 +1743,10 @@ def _inner_training_loop(
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
                 sampler = get_dataloader_sampler(train_dataloader)
-                is_random_sampler = isinstance(sampler, RandomSampler)
+                sampler_kinds = [RandomSampler]
+                if version.parse(accelerate_version) > version.parse("0.23.0"):
+                    sampler_kinds.append(SeedableRandomSampler)
+                is_random_sampler = isinstance(sampler, tuple(sampler_kinds))
                 if is_torch_less_than_1_11 or not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     for _ in train_dataloader:
@@ -1752,6 +1760,8 @@ def _inner_training_loop(
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
+            if hasattr(epoch_iterator, "set_epoch"):
+                epoch_iterator.set_epoch(epoch)

             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
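
For reference, here is a minimal standalone sketch of the version-gated sampler check these hunks introduce. It assumes torch and accelerate are installed; the toy TensorDataset and DataLoader below are illustrative only and not part of this commit.

# Minimal sketch of the version-gated sampler check (not the Trainer code itself).
# Assumes torch and accelerate are installed; the toy dataset/dataloader are
# illustrative only.
import torch
from packaging import version
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

from accelerate import __version__ as accelerate_version

sampler_kinds = [RandomSampler]
if version.parse(accelerate_version) > version.parse("0.23.0"):
    # SeedableRandomSampler is only importable on accelerate versions newer than 0.23.0
    from accelerate.data_loader import SeedableRandomSampler

    sampler_kinds.append(SeedableRandomSampler)

dataset = TensorDataset(torch.arange(10))
dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=2)

# isinstance accepts a tuple of classes, so a single check covers every
# registered "random" sampler type.
is_random_sampler = isinstance(dataloader.sampler, tuple(sampler_kinds))
print(is_random_sampler)  # True

The set_epoch hunk follows the same defensive pattern: it only calls set_epoch when the epoch iterator actually exposes that method, so dataloaders without it keep working unchanged.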