Commit cb1bf57

add custom iterator
1 parent a05a54f commit cb1bf57

File tree

2 files changed: +402 -53 lines changed


test/stateful_dataloader/test_dataloader.py

Lines changed: 203 additions & 0 deletions
@@ -3145,5 +3145,208 @@ def test_out_of_order_iterable_ds(self):
instantiate_device_type_tests(TestDataLoaderDeviceType, globals())


@unittest.skipIf(
    TEST_WITH_TSAN,
    "Fails with TSAN with the following error: starting new threads after multi-threaded "
    "fork is not supported. Dying (set die_after_fork=0 to override)",
)
class TestStatefulDataLoaderEnumerate(TestCase):
    def setUp(self):
        super().setUp()
        self.data = torch.arange(20)
        self.dataset = TensorDataset(self.data)

    def test_custom_enumerate_basic(self):
        """Test that custom enumerate works correctly without state restoration."""
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)

        # Test custom enumerate produces correct indices
        custom_results = list(dataloader.enumerate())
        builtin_results = list(enumerate(dataloader))

        # Both should produce the same results when no state is loaded
        self.assertEqual(len(custom_results), len(builtin_results))
        for (custom_idx, custom_data), (builtin_idx, builtin_data) in zip(custom_results, builtin_results):
            self.assertEqual(custom_idx, builtin_idx)
            self.assertTrue(torch.equal(custom_data[0], builtin_data[0]))

    def test_custom_enumerate_with_start_parameter(self):
        """Test that custom enumerate works correctly with start parameter."""
        dataloader = DataLoader(self.dataset, batch_size=2, shuffle=False)

        start_value = 100
        results = list(dataloader.enumerate(start=start_value))

        expected_indices = list(range(start_value, start_value + len(dataloader)))
        actual_indices = [idx for idx, _ in results]

        self.assertEqual(actual_indices, expected_indices)

    def test_custom_enumerate_with_state_restoration(self):
        """Test that custom enumerate correctly handles state restoration."""
        # Create initial dataloader and process some batches
        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)

        # Process first 3 batches (indices 0, 1, 2) and save state
        processed_count = 0
        for i, (batch,) in enumerate(dataloader1):
            processed_count += 1
            if i == 2:  # After processing batches 0, 1, 2
                state = dataloader1.state_dict()
                break

        self.assertEqual(processed_count, 3)

        # Create new dataloader and restore state
        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
        dataloader2.load_state_dict(state)

        # Use custom enumerate to continue
        remaining_results = list(dataloader2.enumerate())

        # Should start from index 3 (since we processed 0, 1, 2)
        expected_start_index = 3
        expected_indices = list(range(expected_start_index, len(dataloader1)))
        actual_indices = [idx for idx, _ in remaining_results]

        self.assertEqual(actual_indices, expected_indices)

        # Verify data correctness
        expected_data_start = 6  # batch 3 should contain [6, 7]
        first_batch_data = remaining_results[0][1][0]
        self.assertTrue(torch.equal(first_batch_data, torch.tensor([expected_data_start, expected_data_start + 1])))

    def test_custom_enumerate_vs_builtin_after_restoration(self):
        """Test that demonstrates the difference between custom and builtin enumerate after state restoration."""
        # Create initial dataloader and process some batches
        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False)

        # Process first 2 batches and save state
        for i, batch in enumerate(dataloader1):
            if i == 1:  # After processing batches 0, 1
                state = dataloader1.state_dict()
                break

        # Test builtin enumerate (demonstrates the problem)
        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False)
        dataloader2.load_state_dict(state)
        builtin_results = list(enumerate(dataloader2))
        builtin_indices = [idx for idx, _ in builtin_results]

        # Test custom enumerate (shows the fix)
        dataloader3 = DataLoader(self.dataset, batch_size=2, shuffle=False)
        dataloader3.load_state_dict(state)
        custom_results = list(dataloader3.enumerate())
        custom_indices = [idx for idx, _ in custom_results]

        # Builtin enumerate should start from 0 (the problem)
        self.assertEqual(builtin_indices, [0, 1, 2, 3, 4, 5, 6, 7])

        # Custom enumerate should start from 2 (the fix)
        self.assertEqual(custom_indices, [2, 3, 4, 5, 6, 7, 8, 9])

        # Data should be the same for both
        for (_, builtin_data), (_, custom_data) in zip(builtin_results, custom_results):
            self.assertTrue(torch.equal(builtin_data[0], custom_data[0]))

    def test_custom_enumerate_with_multiprocessing(self):
        """Test that custom enumerate works correctly with multiprocessing."""
        # Test with 2 workers
        dataloader1 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)

        # Process some batches and save state
        for i, batch in enumerate(dataloader1):
            if i == 2:
                state = dataloader1.state_dict()
                break

        # Restore state and use custom enumerate
        dataloader2 = DataLoader(self.dataset, batch_size=2, shuffle=False, num_workers=2)
        dataloader2.load_state_dict(state)
        results = list(dataloader2.enumerate())

        # Should start from the correct index
        expected_start_index = 3
        actual_indices = [idx for idx, _ in results]
        self.assertEqual(actual_indices[0], expected_start_index)

    def test_custom_enumerate_empty_after_restoration(self):
        """Test custom enumerate when no data remains after state restoration."""
        # Use a small dataset
        small_data = torch.arange(4)
        small_dataset = TensorDataset(small_data)
        dataloader1 = DataLoader(small_dataset, batch_size=2, shuffle=False)

        # Process all but the last batch
        state = None
        batches_processed = 0
        for i, batch in enumerate(dataloader1):
            batches_processed += 1
            if i == 0:  # After first batch, only one batch remains
                state = dataloader1.state_dict()

        # Restore state and process remaining data
        dataloader2 = DataLoader(small_dataset, batch_size=2, shuffle=False)
        dataloader2.load_state_dict(state)
        remaining_results = list(dataloader2.enumerate())

        # Should have exactly one batch remaining with correct index
        self.assertEqual(len(remaining_results), 1)
        self.assertEqual(remaining_results[0][0], 1)  # Should be index 1

    def test_custom_enumerate_single_batch(self):
        """Test custom enumerate with single batch scenarios."""
        # Create dataset with exactly one batch
        single_batch_data = torch.arange(4)
        single_batch_dataset = TensorDataset(single_batch_data)
        dataloader = DataLoader(single_batch_dataset, batch_size=4, shuffle=False)

        # Should produce one result with index 0
        results = list(dataloader.enumerate())
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0][0], 0)

        # Test with start parameter
        results_with_start = list(dataloader.enumerate(start=50))
        self.assertEqual(len(results_with_start), 1)
        self.assertEqual(results_with_start[0][0], 50)

    def test_custom_enumerate_iterable_dataset(self):
        """Test custom enumerate with IterableDataset."""
        class SimpleIterableDataset(IterableDataset):
            def __init__(self, data):
                self.data = data

            def __iter__(self):
                return iter(self.data)

            def __len__(self):
                return len(self.data)

        iterable_dataset = SimpleIterableDataset(list(range(10)))
        dataloader = DataLoader(iterable_dataset, batch_size=2, shuffle=False)

        # Test basic custom enumerate
        results = list(dataloader.enumerate())
        expected_indices = list(range(5))  # 10 items / 2 batch_size = 5 batches
        actual_indices = [idx for idx, _ in results]

        self.assertEqual(actual_indices, expected_indices)

    def test_custom_enumerate_consistency(self):
        """Test that multiple calls to custom enumerate produce consistent results."""
        dataloader = DataLoader(self.dataset, batch_size=3, shuffle=False)

        # Call enumerate multiple times
        results1 = list(dataloader.enumerate())
        results2 = list(dataloader.enumerate(start=0))

        # Results should be identical
        self.assertEqual(len(results1), len(results2))
        for (idx1, data1), (idx2, data2) in zip(results1, results2):
            self.assertEqual(idx1, idx2)
            self.assertTrue(torch.equal(data1[0], data2[0]))


if __name__ == "__main__":
    run_tests()
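
Only the test file is shown in this view; the iterator change itself lives in the other changed file. As a rough illustration of the index-offsetting behavior these tests exercise, here is a minimal sketch that wraps a plain torch DataLoader. The names SketchStatefulLoader and _batches_yielded are illustrative assumptions, not the actual StatefulDataLoader implementation, and the sketch only offsets the reported index; the real loader also restores the iteration position from its state_dict.

import builtins

import torch
from torch.utils.data import DataLoader, TensorDataset


class SketchStatefulLoader(DataLoader):
    """Toy loader that counts consumed batches and offsets enumerate() accordingly."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._batches_yielded = 0  # assumed counter, for illustration only

    def __iter__(self):
        # Count each batch as it is handed to the caller.
        for batch in super().__iter__():
            self._batches_yielded += 1
            yield batch

    def state_dict(self):
        return {"batches_yielded": self._batches_yielded}

    def load_state_dict(self, state):
        self._batches_yielded = state["batches_yielded"]

    def enumerate(self, start=0):
        # Offset the index by the number of batches already consumed, so a
        # resumed loop reports a globally consistent batch index.
        return builtins.enumerate(self, start=start + self._batches_yielded)


if __name__ == "__main__":
    loader = SketchStatefulLoader(TensorDataset(torch.arange(8)), batch_size=2, shuffle=False)
    loader.load_state_dict({"batches_yielded": 2})  # pretend 2 batches were consumed pre-checkpoint
    print([idx for idx, _ in loader.enumerate()])  # prints [2, 3, 4, 5]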
