Commit ee96d00

jianoaix authored and elliottower committed
[data] Streaming executor fixes ray-project#2 (ray-project#32759)
Signed-off-by: elliottower <[email protected]>
1 parent d4e017c commit ee96d00

File tree: 3 files changed, +41 -11 lines changed

python/ray/data/_internal/execution/legacy_compat.py

Lines changed: 3 additions & 1 deletion
@@ -266,11 +266,13 @@ def bulk_fn(


 def _bundles_to_block_list(bundles: Iterator[RefBundle]) -> BlockList:
     blocks, metadata = [], []
+    owns_blocks = True
     for ref_bundle in bundles:
+        if not ref_bundle.owns_blocks:
+            owns_blocks = False
         for block, meta in ref_bundle.blocks:
             blocks.append(block)
             metadata.append(meta)
-    owns_blocks = all(b.owns_blocks for b in bundles)
     return BlockList(blocks, metadata, owned_by_consumer=owns_blocks)

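Why the change: bundles is an Iterator, so by the time the old owns_blocks = all(b.owns_blocks for b in bundles) line ran, the loop above it had already drained the iterator; all() over an empty iterator is vacuously True, so the BlockList could wrongly claim ownership of blocks it does not own. The fix folds the ownership check into the single pass that already consumes the bundles. A minimal standalone sketch of the pitfall (illustrative names, not Ray code):

    from dataclasses import dataclass
    from typing import Iterator

    @dataclass
    class Bundle:
        owns_blocks: bool

    def buggy_owns(bundles: Iterator[Bundle]) -> bool:
        for _ in bundles:  # first pass drains the iterator
            pass
        return all(b.owns_blocks for b in bundles)  # sees nothing; all([]) is True

    def fixed_owns(bundles: Iterator[Bundle]) -> bool:
        owns_blocks = True
        for b in bundles:  # single pass: consume and track ownership together
            if not b.owns_blocks:
                owns_blocks = False
        return owns_blocks

    data = [Bundle(owns_blocks=False)]
    assert buggy_owns(iter(data)) is True    # wrong: the bundle does not own its blocks
    assert fixed_owns(iter(data)) is False   # correct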

python/ray/data/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -2825,7 +2825,7 @@ def iter_rows(self, *, prefetch_blocks: int = 0) -> Iterator[Union[T, TableRow]]
         for batch in self.iter_batches(
             batch_size=None, prefetch_blocks=prefetch_blocks, batch_format=batch_format
         ):
-            batch = BlockAccessor.for_block(batch)
+            batch = BlockAccessor.for_block(BlockAccessor.batch_to_block(batch))
             for row in batch.iter_rows():
                 yield row
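Why the change: with the streaming executor, iter_batches() can yield user-facing batch formats (for example a pandas DataFrame) rather than raw Block objects, so the batch is first normalized back into a Block via BlockAccessor.batch_to_block before being wrapped. A hedged usage sketch, assuming a pandas batch is accepted here as in this codebase's BlockAccessor:

    import pandas as pd
    from ray.data.block import BlockAccessor

    batch = pd.DataFrame({"value": [0, 1, 2]})  # a "default"-format batch

    # Old path: BlockAccessor.for_block(batch) assumed `batch` was already a Block.
    # New path: normalize the batch into a Block first, then wrap it.
    block = BlockAccessor.batch_to_block(batch)
    for row in BlockAccessor.for_block(block).iter_rows():
        print(row)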

python/ray/data/tests/test_dataset.py

Lines changed: 37 additions & 9 deletions
@@ -1369,17 +1369,26 @@ def test_count_lazy(ray_start_regular_shared):

 def test_lazy_loading_exponential_rampup(ray_start_regular_shared):
     ds = ray.data.range(100, parallelism=20)
-    assert ds._plan.execute()._num_computed() == 0
+
+    def check_num_computed(expected):
+        if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+            # In streaming executor, ds.take() will not invoke partial execution
+            # in LazyBlockList.
+            assert ds._plan.execute()._num_computed() == 0
+        else:
+            assert ds._plan.execute()._num_computed() == expected
+
+    check_num_computed(0)
     assert ds.take(10) == list(range(10))
-    assert ds._plan.execute()._num_computed() == 2
+    check_num_computed(2)
     assert ds.take(20) == list(range(20))
-    assert ds._plan.execute()._num_computed() == 4
+    check_num_computed(4)
     assert ds.take(30) == list(range(30))
-    assert ds._plan.execute()._num_computed() == 8
+    check_num_computed(8)
     assert ds.take(50) == list(range(50))
-    assert ds._plan.execute()._num_computed() == 16
+    check_num_computed(16)
     assert ds.take(100) == list(range(100))
-    assert ds._plan.execute()._num_computed() == 20
+    check_num_computed(20)


 def test_dataset_repr(ray_start_regular_shared):
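The check_num_computed helper branches on the streaming-executor flag. For local experimentation, the same flag the test reads can be toggled on the current DatasetContext; a sketch using only names that appear in this diff:

    import ray
    from ray.data.context import DatasetContext

    ctx = DatasetContext.get_current()
    ctx.use_streaming_executor = True   # route execution through the streaming path

    ds = ray.data.range(100, parallelism=20)
    assert ds.take(10) == list(range(10))
    # Under streaming, the LazyBlockList reports no partially computed blocks:
    assert ds._plan.execute()._num_computed() == 0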
@@ -1696,7 +1705,14 @@ def to_pylist(table):
     # Default ArrowRows.
     for row, t_row in zip(ds.iter_rows(), to_pylist(t)):
         assert isinstance(row, TableRow)
-        assert isinstance(row, ArrowRow)
+        # In streaming, we set batch_format to "default" because calling
+        # ds.dataset_format() will still invoke bulk execution and we want
+        # to avoid that. As a result, it's receiving PandasRow (the default
+        # batch format).
+        if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+            assert isinstance(row, PandasRow)
+        else:
+            assert isinstance(row, ArrowRow)
         assert row == t_row

     # PandasRows after conversion.
@@ -1710,7 +1726,14 @@ def to_pylist(table):
     # Prefetch.
     for row, t_row in zip(ds.iter_rows(prefetch_blocks=1), to_pylist(t)):
         assert isinstance(row, TableRow)
-        assert isinstance(row, ArrowRow)
+        # In streaming, we set batch_format to "default" because calling
+        # ds.dataset_format() will still invoke bulk execution and we want
+        # to avoid that. As a result, it's receiving PandasRow (the default
+        # batch format).
+        if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+            assert isinstance(row, PandasRow)
+        else:
+            assert isinstance(row, ArrowRow)
         assert row == t_row

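Both row-type hunks make the same point: under the streaming executor the rows arrive as PandasRow (since the test pins batch_format to "default"), and as ArrowRow otherwise; both subclass TableRow, so the row == t_row comparison is unaffected. A small probe of this, assuming a local Ray with the public from_arrow and iter_rows APIs:

    import pyarrow as pa
    import ray
    from ray.data.row import TableRow

    ds = ray.data.from_arrow(pa.table({"a": [1, 2]}))
    row = next(iter(ds.iter_rows()))
    assert isinstance(row, TableRow)      # holds on either execution path
    print(type(row).__name__)             # "ArrowRow" or "PandasRow" by executor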

@@ -2181,7 +2204,12 @@ def test_lazy_loading_iter_batches_exponential_rampup(ray_start_regular_shared):
     ds = ray.data.range(32, parallelism=8)
     expected_num_blocks = [1, 2, 4, 4, 8, 8, 8, 8]
     for _, expected in zip(ds.iter_batches(batch_size=None), expected_num_blocks):
-        assert ds._plan.execute()._num_computed() == expected
+        if ray.data.context.DatasetContext.get_current().use_streaming_executor:
+            # In streaming execution of ds.iter_batches(), there is no partial
+            # execution so _num_computed() in LazyBlockList is 0.
+            assert ds._plan.execute()._num_computed() == 0
+        else:
+            assert ds._plan.execute()._num_computed() == expected


 def test_add_column(ray_start_regular_shared):
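For context on expected_num_blocks: on the bulk path, iter_batches() materializes lazy blocks in exponentially growing windows, roughly doubling the number of computed blocks whenever consumption catches up, which produces [1, 2, 4, 4, 8, 8, 8, 8] for the 8 blocks here. A toy model of that schedule (not the actual LazyBlockList code):

    def rampup_schedule(num_blocks: int, num_batches: int) -> list:
        computed, schedule = 0, []
        for consumed in range(1, num_batches + 1):
            if consumed > computed:
                # Fetch the next window: double the blocks computed so far.
                computed = min(num_blocks, computed + max(1, computed))
            schedule.append(computed)
        return schedule

    assert rampup_schedule(8, 8) == [1, 2, 4, 4, 8, 8, 8, 8]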
