Skip to content

Commit eb26833

Browse files
committed
run precommit
1 parent cb1bf57 commit eb26833

File tree

2 files changed

+87
-192
lines changed

2 files changed

+87
-192
lines changed

test/stateful_dataloader/test_dataloader.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3159,11 +3159,11 @@ def setUp(self):
31593159
def test_custom_enumerate_basic(self):
31603160
"""Test that custom enumerate works correctly without state restoration."""
31613161
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)
3162-
3162+
31633163
# Test custom enumerate produces correct indices
31643164
custom_results = list(dataloader.enumerate())
31653165
builtin_results = list(enumerate(dataloader))
3166-
3166+
31673167
# Both should produce the same results when no state is loaded
31683168
self.assertEqual(len(custom_results), len(builtin_results))
31693169
for (custom_idx, custom_data), (builtin_idx, builtin_data) in zip(custom_results, builtin_results):
@@ -3173,44 +3173,44 @@ def test_custom_enumerate_basic(self):
31733173
def test_custom_enumerate_with_start_parameter(self):
31743174
"""Test that custom enumerate works correctly with start parameter."""
31753175
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)
3176-
3176+
31773177
start_value = 100
31783178
results = list(dataloader.enumerate(start=start_value))
3179-
3179+
31803180
expected_indices = list(range(start_value, start_value + len(dataloader)))
31813181
actual_indices = [idx for idx, _ in results]
3182-
3182+
31833183
self.assertEqual(actual_indices, expected_indices)
31843184

31853185
def test_custom_enumerate_with_state_restoration(self):
31863186
"""Test that custom enumerate correctly handles state restoration."""
31873187
# Create initial dataloader and process some batches
31883188
dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)
3189-
3189+
31903190
# Process first 3 batches (indices 0, 1, 2) and save state
31913191
processed_count = 0
31923192
for i, (batch,) in enumerate(dataloader1):
31933193
processed_count += 1
31943194
if i == 2: # After processing batches 0, 1, 2
31953195
state = dataloader1.state_dict()
31963196
break
3197-
3197+
31983198
self.assertEqual(processed_count, 3)
3199-
3199+
32003200
# Create new dataloader and restore state
32013201
dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
32023202
dataloader2.load_state_dict(state)
3203-
3203+
32043204
# Use custom enumerate to continue
32053205
remaining_results = list(dataloader2.enumerate())
3206-
3206+
32073207
# Should start from index 3 (since we processed 0, 1, 2)
32083208
expected_start_index = 3
32093209
expected_indices = list(range(expected_start_index, len(dataloader1)))
32103210
actual_indices = [idx for idx, _ in remaining_results]
3211-
3211+
32123212
self.assertEqual(actual_indices, expected_indices)
3213-
3213+
32143214
# Verify data correctness
32153215
expected_data_start = 6 # batch 3 should contain [6, 7]
32163216
first_batch_data = remaining_results[0][1][0]
@@ -3220,31 +3220,31 @@ def test_custom_enumerate_vs_builtin_after_restoration(self):
32203220
"""Test that demonstrates the difference between custom and builtin enumerate after state restoration."""
32213221
# Create initial dataloader and process some batches
32223222
dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)
3223-
3223+
32243224
# Process first 2 batches and save state
32253225
for i, batch in enumerate(dataloader1):
32263226
if i == 1: # After processing batches 0, 1
32273227
state = dataloader1.state_dict()
32283228
break
3229-
3229+
32303230
# Test builtin enumerate (demonstrates the problem)
32313231
dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
32323232
dataloader2.load_state_dict(state)
32333233
builtin_results = list(enumerate(dataloader2))
32343234
builtin_indices = [idx for idx, _ in builtin_results]
3235-
3235+
32363236
# Test custom enumerate (shows the fix)
32373237
dataloader3 = DataLoader(self.dataset, batch_size=2, shuffle=False)
32383238
dataloader3.load_state_dict(state)
32393239
custom_results = list(dataloader3.enumerate())
32403240
custom_indices = [idx for idx, _ in custom_results]
3241-
3241+
32423242
# Builtin enumerate should start from 0 (the problem)
32433243
self.assertEqual(builtin_indices, [0, 1, 2, 3, 4, 5, 6, 7])
3244-
3244+
32453245
# Custom enumerate should start from 2 (the fix)
32463246
self.assertEqual(custom_indices, [2, 3, 4, 5, 6, 7, 8, 9])
3247-
3247+
32483248
# Data should be the same for both
32493249
for (_, builtin_data), (_, custom_data) in zip(builtin_results, custom_results):
32503250
self.assertTrue(torch.equal(builtin_data[0], custom_data[0]))
@@ -3253,18 +3253,18 @@ def test_custom_enumerate_with_multiprocessing(self):
32533253
"""Test that custom enumerate works correctly with multiprocessing."""
32543254
# Test with 2 workers
32553255
dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)
3256-
3256+
32573257
# Process some batches and save state
32583258
for i, batch in enumerate(dataloader1):
32593259
if i == 2:
32603260
state = dataloader1.state_dict()
32613261
break
3262-
3262+
32633263
# Restore state and use custom enumerate
32643264
dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)
32653265
dataloader2.load_state_dict(state)
32663266
results = list(dataloader2.enumerate())
3267-
3267+
32683268
# Should start from the correct index
32693269
expected_start_index = 3
32703270
actual_indices = [idx for idx, _ in results]
@@ -3276,20 +3276,20 @@ def test_custom_enumerate_empty_after_restoration(self):
32763276
small_data = torch.arange(4)
32773277
small_dataset = TensorDataset(small_data)
32783278
dataloader1 = DataLoader(small_dataset, batch_size=2, shuffle=False)
3279-
3279+
32803280
# Iterate through every batch, but snapshot state after the first one so exactly one batch remains on restore
32813281
state = None
32823282
batches_processed = 0
32833283
for i, batch in enumerate(dataloader1):
32843284
batches_processed += 1
32853285
if i == 0: # After first batch, only one batch remains
32863286
state = dataloader1.state_dict()
3287-
3287+
32883288
# Restore state and process remaining data
32893289
dataloader2 = DataLoader(small_dataset, batch_size=2, shuffle=False)
32903290
dataloader2.load_state_dict(state)
32913291
remaining_results = list(dataloader2.enumerate())
3292-
3292+
32933293
# Should have exactly one batch remaining with correct index
32943294
self.assertEqual(len(remaining_results), 1)
32953295
self.assertEqual(remaining_results[0][0], 1) # Should be index 1
@@ -3300,47 +3300,48 @@ def test_custom_enumerate_single_batch(self):
33003300
single_batch_data = torch.arange(4)
33013301
single_batch_dataset = TensorDataset(single_batch_data)
33023302
dataloader = DataLoader(single_batch_dataset, batch_size=4, shuffle=False)
3303-
3303+
33043304
# Should produce one result with index 0
33053305
results = list(dataloader.enumerate())
33063306
self.assertEqual(len(results), 1)
33073307
self.assertEqual(results[0][0], 0)
3308-
3308+
33093309
# Test with start parameter
33103310
results_with_start = list(dataloader.enumerate(start=50))
33113311
self.assertEqual(len(results_with_start), 1)
33123312
self.assertEqual(results_with_start[0][0], 50)
33133313

33143314
def test_custom_enumerate_iterable_dataset(self):
33153315
"""Test custom enumerate with IterableDataset."""
3316+
33163317
class SimpleIterableDataset(IterableDataset):
33173318
def __init__(self, data):
33183319
self.data = data
3319-
3320+
33203321
def __iter__(self):
33213322
return iter(self.data)
3322-
3323+
33233324
def __len__(self):
33243325
return len(self.data)
3325-
3326+
33263327
iterable_dataset = SimpleIterableDataset(list(range(10)))
33273328
dataloader = DataLoader(iterable_dataset, batch_size=2, shuffle=False)
3328-
3329+
33293330
# Test basic custom enumerate
33303331
results = list(dataloader.enumerate())
33313332
expected_indices = list(range(5)) # 10 items / 2 batch_size = 5 batches
33323333
actual_indices = [idx for idx, _ in results]
3333-
3334+
33343335
self.assertEqual(actual_indices, expected_indices)
33353336

33363337
def test_custom_enumerate_consistency(self):
33373338
"""Test that multiple calls to custom enumerate produce consistent results."""
33383339
dataloader = DataLoader(self.dataset, batch_size=3, shuffle=False)
3339-
3340+
33403341
# Call enumerate multiple times
33413342
results1 = list(dataloader.enumerate())
33423343
results2 = list(dataloader.enumerate(start=0))
3343-
3344+
33443345
# Results should be identical
33453346
self.assertEqual(len(results1), len(results2))
33463347
for (idx1, data1), (idx2, data2) in zip(results1, results2):

0 commit comments

Comments
 (0)