@@ -3159,11 +3159,11 @@ def setUp(self):
31593159 def test_custom_enumerate_basic (self ):
31603160 """Test that custom enumerate works correctly without state restoration."""
31613161 dataloader = DataLoader (self .dataset , batch_size = 2 , shuffle = False )
3162-
3162+
31633163 # Test custom enumerate produces correct indices
31643164 custom_results = list (dataloader .enumerate ())
31653165 builtin_results = list (enumerate (dataloader ))
3166-
3166+
31673167 # Both should produce the same results when no state is loaded
31683168 self .assertEqual (len (custom_results ), len (builtin_results ))
31693169 for (custom_idx , custom_data ), (builtin_idx , builtin_data ) in zip (custom_results , builtin_results ):
@@ -3173,44 +3173,44 @@ def test_custom_enumerate_basic(self):
31733173 def test_custom_enumerate_with_start_parameter (self ):
31743174 """Test that custom enumerate works correctly with start parameter."""
31753175 dataloader = DataLoader (self .dataset , batch_size = 2 , shuffle = False )
3176-
3176+
31773177 start_value = 100
31783178 results = list (dataloader .enumerate (start = start_value ))
3179-
3179+
31803180 expected_indices = list (range (start_value , start_value + len (dataloader )))
31813181 actual_indices = [idx for idx , _ in results ]
3182-
3182+
31833183 self .assertEqual (actual_indices , expected_indices )
31843184
31853185 def test_custom_enumerate_with_state_restoration (self ):
31863186 """Test that custom enumerate correctly handles state restoration."""
31873187 # Create initial dataloader and process some batches
31883188 dataloader1 = DataLoader (self .dataset , batch_size = 2 , shuffle = False )
3189-
3189+
31903190 # Process first 3 batches (indices 0, 1, 2) and save state
31913191 processed_count = 0
31923192 for i , (batch ,) in enumerate (dataloader1 ):
31933193 processed_count += 1
31943194 if i == 2 : # After processing batches 0, 1, 2
31953195 state = dataloader1 .state_dict ()
31963196 break
3197-
3197+
31983198 self .assertEqual (processed_count , 3 )
3199-
3199+
32003200 # Create new dataloader and restore state
32013201 dataloader2 = DataLoader (self .dataset , batch_size = 2 , shuffle = False )
32023202 dataloader2 .load_state_dict (state )
3203-
3203+
32043204 # Use custom enumerate to continue
32053205 remaining_results = list (dataloader2 .enumerate ())
3206-
3206+
32073207 # Should start from index 3 (since we processed 0, 1, 2)
32083208 expected_start_index = 3
32093209 expected_indices = list (range (expected_start_index , len (dataloader1 )))
32103210 actual_indices = [idx for idx , _ in remaining_results ]
3211-
3211+
32123212 self .assertEqual (actual_indices , expected_indices )
3213-
3213+
32143214 # Verify data correctness
32153215 expected_data_start = 6 # batch 3 should contain [6, 7]
32163216 first_batch_data = remaining_results [0 ][1 ][0 ]
@@ -3220,31 +3220,31 @@ def test_custom_enumerate_vs_builtin_after_restoration(self):
32203220 """Test that demonstrates the difference between custom and builtin enumerate after state restoration."""
32213221 # Create initial dataloader and process some batches
32223222 dataloader1 = DataLoader (self .dataset , batch_size = 2 , shuffle = False )
3223-
3223+
32243224 # Process first 2 batches and save state
32253225 for i , batch in enumerate (dataloader1 ):
32263226 if i == 1 : # After processing batches 0, 1
32273227 state = dataloader1 .state_dict ()
32283228 break
3229-
3229+
32303230 # Test builtin enumerate (demonstrates the problem)
32313231 dataloader2 = DataLoader (self .dataset , batch_size = 2 , shuffle = False )
32323232 dataloader2 .load_state_dict (state )
32333233 builtin_results = list (enumerate (dataloader2 ))
32343234 builtin_indices = [idx for idx , _ in builtin_results ]
3235-
3235+
32363236 # Test custom enumerate (shows the fix)
32373237 dataloader3 = DataLoader (self .dataset , batch_size = 2 , shuffle = False )
32383238 dataloader3 .load_state_dict (state )
32393239 custom_results = list (dataloader3 .enumerate ())
32403240 custom_indices = [idx for idx , _ in custom_results ]
3241-
3241+
32423242 # Builtin enumerate should start from 0 (the problem)
32433243 self .assertEqual (builtin_indices , [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ])
3244-
3244+
32453245 # Custom enumerate should start from 2 (the fix)
32463246 self .assertEqual (custom_indices , [2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ])
3247-
3247+
32483248 # Data should be the same for both
32493249 for (_ , builtin_data ), (_ , custom_data ) in zip (builtin_results , custom_results ):
32503250 self .assertTrue (torch .equal (builtin_data [0 ], custom_data [0 ]))
@@ -3253,18 +3253,18 @@ def test_custom_enumerate_with_multiprocessing(self):
32533253 """Test that custom enumerate works correctly with multiprocessing."""
32543254 # Test with 2 workers
32553255 dataloader1 = DataLoader (self .dataset , batch_size = 2 , shuffle = False , num_workers = 2 )
3256-
3256+
32573257 # Process some batches and save state
32583258 for i , batch in enumerate (dataloader1 ):
32593259 if i == 2 :
32603260 state = dataloader1 .state_dict ()
32613261 break
3262-
3262+
32633263 # Restore state and use custom enumerate
32643264 dataloader2 = DataLoader (self .dataset , batch_size = 2 , shuffle = False , num_workers = 2 )
32653265 dataloader2 .load_state_dict (state )
32663266 results = list (dataloader2 .enumerate ())
3267-
3267+
32683268 # Should start from the correct index
32693269 expected_start_index = 3
32703270 actual_indices = [idx for idx , _ in results ]
@@ -3276,20 +3276,20 @@ def test_custom_enumerate_empty_after_restoration(self):
32763276 small_data = torch .arange (4 )
32773277 small_dataset = TensorDataset (small_data )
32783278 dataloader1 = DataLoader (small_dataset , batch_size = 2 , shuffle = False )
3279-
3279+
32803280 # Process all but the last batch
32813281 state = None
32823282 batches_processed = 0
32833283 for i , batch in enumerate (dataloader1 ):
32843284 batches_processed += 1
32853285 if i == 0 : # After first batch, only one batch remains
32863286 state = dataloader1 .state_dict ()
3287-
3287+
32883288 # Restore state and process remaining data
32893289 dataloader2 = DataLoader (small_dataset , batch_size = 2 , shuffle = False )
32903290 dataloader2 .load_state_dict (state )
32913291 remaining_results = list (dataloader2 .enumerate ())
3292-
3292+
32933293 # Should have exactly one batch remaining with correct index
32943294 self .assertEqual (len (remaining_results ), 1 )
32953295 self .assertEqual (remaining_results [0 ][0 ], 1 ) # Should be index 1
@@ -3300,47 +3300,48 @@ def test_custom_enumerate_single_batch(self):
33003300 single_batch_data = torch .arange (4 )
33013301 single_batch_dataset = TensorDataset (single_batch_data )
33023302 dataloader = DataLoader (single_batch_dataset , batch_size = 4 , shuffle = False )
3303-
3303+
33043304 # Should produce one result with index 0
33053305 results = list (dataloader .enumerate ())
33063306 self .assertEqual (len (results ), 1 )
33073307 self .assertEqual (results [0 ][0 ], 0 )
3308-
3308+
33093309 # Test with start parameter
33103310 results_with_start = list (dataloader .enumerate (start = 50 ))
33113311 self .assertEqual (len (results_with_start ), 1 )
33123312 self .assertEqual (results_with_start [0 ][0 ], 50 )
33133313
33143314 def test_custom_enumerate_iterable_dataset (self ):
33153315 """Test custom enumerate with IterableDataset."""
3316+
33163317 class SimpleIterableDataset (IterableDataset ):
33173318 def __init__ (self , data ):
33183319 self .data = data
3319-
3320+
33203321 def __iter__ (self ):
33213322 return iter (self .data )
3322-
3323+
33233324 def __len__ (self ):
33243325 return len (self .data )
3325-
3326+
33263327 iterable_dataset = SimpleIterableDataset (list (range (10 )))
33273328 dataloader = DataLoader (iterable_dataset , batch_size = 2 , shuffle = False )
3328-
3329+
33293330 # Test basic custom enumerate
33303331 results = list (dataloader .enumerate ())
33313332 expected_indices = list (range (5 )) # 10 items / 2 batch_size = 5 batches
33323333 actual_indices = [idx for idx , _ in results ]
3333-
3334+
33343335 self .assertEqual (actual_indices , expected_indices )
33353336
33363337 def test_custom_enumerate_consistency (self ):
33373338 """Test that multiple calls to custom enumerate produce consistent results."""
33383339 dataloader = DataLoader (self .dataset , batch_size = 3 , shuffle = False )
3339-
3340+
33403341 # Call enumerate multiple times
33413342 results1 = list (dataloader .enumerate ())
33423343 results2 = list (dataloader .enumerate (start = 0 ))
3343-
3344+
33443345 # Results should be identical
33453346 self .assertEqual (len (results1 ), len (results2 ))
33463347 for (idx1 , data1 ), (idx2 , data2 ) in zip (results1 , results2 ):
0 commit comments