@@ -3145,5 +3145,208 @@ def test_out_of_order_iterable_ds(self):
 instantiate_device_type_tests(TestDataLoaderDeviceType, globals())
 
 
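+# Tests for the stateful DataLoader's custom enumerate() method, which,
+# unlike the builtin enumerate, resumes batch indices after load_state_dict().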
+@unittest.skipIf(
+    TEST_WITH_TSAN,
+    "Fails with TSAN with the following error: starting new threads after multi-threaded "
+    "fork is not supported. Dying (set die_after_fork=0 to override)",
+)
+class TestStatefulDataLoaderEnumerate(TestCase):
+    def setUp(self):
+        super().setUp()
+        self.data = torch.arange(20)
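+        # 20 scalar items -> 10 batches at the batch_size=2 used by most tests below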
+        self.dataset = TensorDataset(self.data)
+
+    def test_custom_enumerate_basic(self):
+        """Test that custom enumerate works correctly without state restoration."""
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)
+
+        # Test custom enumerate produces correct indices
+        custom_results = list(dataloader.enumerate())
+        builtin_results = list(enumerate(dataloader))
+
+        # Both should produce the same results when no state is loaded
+        self.assertEqual(len(custom_results), len(builtin_results))
+        for (custom_idx, custom_data), (builtin_idx, builtin_data) in zip(custom_results, builtin_results):
+            self.assertEqual(custom_idx, builtin_idx)
+            self.assertTrue(torch.equal(custom_data[0], builtin_data[0]))
+
+    def test_custom_enumerate_with_start_parameter(self):
+        """Test that custom enumerate works correctly with the start parameter."""
+        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)
+
+        start_value = 100
+        results = list(dataloader.enumerate(start=start_value))
+
+        expected_indices = list(range(start_value, start_value + len(dataloader)))
+        actual_indices = [idx for idx, _ in results]
+
+        self.assertEqual(actual_indices, expected_indices)
+
+    def test_custom_enumerate_with_state_restoration(self):
+        """Test that custom enumerate correctly handles state restoration."""
+        # Create initial dataloader and process some batches
+        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+
+        # Process first 3 batches (indices 0, 1, 2) and save state
+        processed_count = 0
+        for i, (batch,) in enumerate(dataloader1):
+            processed_count += 1
+            if i == 2:  # After processing batches 0, 1, 2
+                state = dataloader1.state_dict()
+                break
+
+        self.assertEqual(processed_count, 3)
+
+        # Create new dataloader and restore state
+        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+        dataloader2.load_state_dict(state)
+
+        # Use custom enumerate to continue
+        remaining_results = list(dataloader2.enumerate())
+
+        # Should start from index 3 (since we processed 0, 1, 2)
+        expected_start_index = 3
+        expected_indices = list(range(expected_start_index, len(dataloader1)))
+        actual_indices = [idx for idx, _ in remaining_results]
+
+        self.assertEqual(actual_indices, expected_indices)
+
+        # Verify data correctness
+        expected_data_start = 6  # batch 3 should contain [6, 7]
+        first_batch_data = remaining_results[0][1][0]
+        self.assertTrue(torch.equal(first_batch_data, torch.tensor([expected_data_start, expected_data_start + 1])))
+
+    def test_custom_enumerate_vs_builtin_after_restoration(self):
+        """Demonstrate the difference between custom and builtin enumerate after state restoration."""
+        # Create initial dataloader and process some batches
+        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+
+        # Process the first 2 batches and save state
+        for i, batch in enumerate(dataloader1):
+            if i == 1:  # After processing batches 0, 1
+                state = dataloader1.state_dict()
+                break
+
+        # Builtin enumerate (demonstrates the problem)
+        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+        dataloader2.load_state_dict(state)
+        builtin_results = list(enumerate(dataloader2))
+        builtin_indices = [idx for idx, _ in builtin_results]
+
+        # Custom enumerate (shows the fix)
+        dataloader3 = DataLoader(self.dataset, batch_size=2, shuffle=False)
+        dataloader3.load_state_dict(state)
+        custom_results = list(dataloader3.enumerate())
+        custom_indices = [idx for idx, _ in custom_results]
+
+        # Builtin enumerate restarts from 0 even though 2 batches were consumed (the problem)
+        self.assertEqual(builtin_indices, [0, 1, 2, 3, 4, 5, 6, 7])
+
+        # Custom enumerate resumes from 2 (the fix)
+        self.assertEqual(custom_indices, [2, 3, 4, 5, 6, 7, 8, 9])
+
+        # The data itself should be identical for both
+        for (_, builtin_data), (_, custom_data) in zip(builtin_results, custom_results):
+            self.assertTrue(torch.equal(builtin_data[0], custom_data[0]))
+
+    def test_custom_enumerate_with_multiprocessing(self):
+        """Test that custom enumerate works correctly with multiprocessing."""
+        # Test with 2 workers
+        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)
+
+        # Process some batches and save state
+        for i, batch in enumerate(dataloader1):
+            if i == 2:
+                state = dataloader1.state_dict()
+                break
+
+        # Restore state and use custom enumerate
+        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)
+        dataloader2.load_state_dict(state)
+        results = list(dataloader2.enumerate())
+
+        # Should start from the correct index
+        expected_start_index = 3
+        actual_indices = [idx for idx, _ in results]
+        self.assertEqual(actual_indices[0], expected_start_index)
+
+    def test_custom_enumerate_empty_after_restoration(self):
+        """Test custom enumerate when only the final batch remains after state restoration."""
+        # Use a small dataset
+        small_data = torch.arange(4)
+        small_dataset = TensorDataset(small_data)
+        dataloader1 = DataLoader(small_dataset, batch_size=2, shuffle=False)
+
+        # Iterate through all batches, saving state after the first one so the
+        # saved state has exactly one batch left
+        state = None
+        batches_processed = 0
+        for i, batch in enumerate(dataloader1):
+            batches_processed += 1
+            if i == 0:  # After the first batch, only one batch remains
+                state = dataloader1.state_dict()
+        self.assertEqual(batches_processed, 2)
+
+        # Restore state and process remaining data
+        dataloader2 = DataLoader(small_dataset, batch_size=2, shuffle=False)
+        dataloader2.load_state_dict(state)
+        remaining_results = list(dataloader2.enumerate())
+
+        # Should have exactly one batch remaining, with the correct index
+        self.assertEqual(len(remaining_results), 1)
+        self.assertEqual(remaining_results[0][0], 1)  # Should be index 1
+
+    def test_custom_enumerate_single_batch(self):
+        """Test custom enumerate with single-batch scenarios."""
+        # Create a dataset with exactly one batch
+        single_batch_data = torch.arange(4)
+        single_batch_dataset = TensorDataset(single_batch_data)
+        dataloader = DataLoader(single_batch_dataset, batch_size=4, shuffle=False)
+
+        # Should produce one result with index 0
+        results = list(dataloader.enumerate())
+        self.assertEqual(len(results), 1)
+        self.assertEqual(results[0][0], 0)
+
+        # Test with the start parameter
+        results_with_start = list(dataloader.enumerate(start=50))
+        self.assertEqual(len(results_with_start), 1)
+        self.assertEqual(results_with_start[0][0], 50)
+
+    def test_custom_enumerate_iterable_dataset(self):
+        """Test custom enumerate with an IterableDataset."""
+        class SimpleIterableDataset(IterableDataset):
+            def __init__(self, data):
+                self.data = data
+
+            def __iter__(self):
+                return iter(self.data)
+
+            def __len__(self):
+                return len(self.data)
+
+        iterable_dataset = SimpleIterableDataset(list(range(10)))
+        dataloader = DataLoader(iterable_dataset, batch_size=2, shuffle=False)
+
+        # Test basic custom enumerate
+        results = list(dataloader.enumerate())
+        expected_indices = list(range(5))  # 10 items / batch_size 2 = 5 batches
+        actual_indices = [idx for idx, _ in results]
+
+        self.assertEqual(actual_indices, expected_indices)
+
+    def test_custom_enumerate_consistency(self):
+        """Test that multiple calls to custom enumerate produce consistent results."""
+        dataloader = DataLoader(self.dataset, batch_size=3, shuffle=False)
+
+        # Call enumerate multiple times
+        results1 = list(dataloader.enumerate())
+        results2 = list(dataloader.enumerate(start=0))
+
+        # Results should be identical
+        self.assertEqual(len(results1), len(results2))
+        for (idx1, data1), (idx2, data2) in zip(results1, results2):
+            self.assertEqual(idx1, idx2)
+            self.assertTrue(torch.equal(data1[0], data2[0]))
+
+
 if __name__ == "__main__":
     run_tests()
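
For reference, the resume-aware indexing these tests pin down could be implemented along the following lines. This is a minimal sketch, not the actual patch: the mixin name and the `_num_yielded` counter are hypothetical stand-ins for whatever load_state_dict() actually restores.

import builtins

class ResumableEnumerateMixin:
    # Hypothetical: assumes load_state_dict() restores a counter of
    # batches already consumed before state_dict() was taken.
    _num_yielded = 0

    def enumerate(self, start=0):
        # Offset the running index by the number of batches consumed before
        # the state was saved, so iteration resumes at the correct batch
        # index; builtins.enumerate would always restart at `start`.
        return builtins.enumerate(self, start=start + self._num_yielded)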