Skip to content

Commit 303ac3b

Browse files
authored
[Datasets] Streaming executor fixes #5 (ray-project#32951)
1 parent 5ab758b commit 303ac3b

File tree

4 files changed

+56
-15
lines changed

4 files changed

+56
-15
lines changed

python/ray/data/_internal/lazy_block_list.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import uuid
33
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
44

5-
import numpy as np
6-
75
import ray
86
from ray.data._internal.block_list import BlockList
97
from ray.data._internal.progress_bar import ProgressBar
108
from ray.data._internal.remote_fn import cached_remote_fn
119
from ray.data._internal.memory_tracing import trace_allocation
1210
from ray.data._internal.stats import DatasetStats, _get_or_create_stats_actor
11+
from ray.data._internal.util import _split_list
1312
from ray.data.block import (
1413
Block,
1514
BlockAccessor,
@@ -162,22 +161,22 @@ def _check_if_cleared(self):
162161
# Note: does not force execution prior to splitting.
163162
def split(self, split_size: int) -> List["LazyBlockList"]:
164163
num_splits = math.ceil(len(self._tasks) / split_size)
165-
tasks = np.array_split(self._tasks, num_splits)
166-
block_partition_refs = np.array_split(self._block_partition_refs, num_splits)
167-
block_partition_meta_refs = np.array_split(
164+
tasks = _split_list(self._tasks, num_splits)
165+
block_partition_refs = _split_list(self._block_partition_refs, num_splits)
166+
block_partition_meta_refs = _split_list(
168167
self._block_partition_meta_refs, num_splits
169168
)
170-
cached_metadata = np.array_split(self._cached_metadata, num_splits)
169+
cached_metadata = _split_list(self._cached_metadata, num_splits)
171170
output = []
172171
for t, b, m, c in zip(
173172
tasks, block_partition_refs, block_partition_meta_refs, cached_metadata
174173
):
175174
output.append(
176175
LazyBlockList(
177-
t.tolist(),
178-
b.tolist(),
179-
m.tolist(),
180-
c.tolist(),
176+
t,
177+
b,
178+
m,
179+
c,
181180
owned_by_consumer=self._owned_by_consumer,
182181
)
183182
)

python/ray/data/_internal/util.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import importlib
22
import logging
33
import os
4-
from typing import List, Union, Optional, TYPE_CHECKING
4+
from typing import Any, List, Union, Optional, TYPE_CHECKING
55
from types import ModuleType
66
import sys
77

@@ -380,3 +380,20 @@ def ConsumptionAPI(*args, **kwargs):
380380
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
381381
return _consumption_api()(args[0])
382382
return _consumption_api(*args, **kwargs)
383+
384+
385+
def _split_list(arr: List[Any], num_splits: int) -> List[List[Any]]:
386+
"""Split the list into `num_splits` lists.
387+
388+
The splits will be even if the `num_splits` divides the length of list, otherwise
389+
the remainder (suppose it's R) will be allocated to the first R splits (one for
390+
each).
391+
This is the same as numpy.array_split(). The reason we make this a separate
392+
implementation is to allow the heterogeneity in the elements in the list.
393+
"""
394+
assert num_splits > 0
395+
q, r = divmod(len(arr), num_splits)
396+
splits = [
397+
arr[i * q + min(i, r) : (i + 1) * q + min(i + 1, r)] for i in range(num_splits)
398+
]
399+
return splits

python/ray/data/tests/test_dataset.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4581,9 +4581,19 @@ def test_warning_execute_with_no_cpu(ray_start_cluster):
45814581
ds = ray.data.range(10)
45824582
ds = ds.map_batches(lambda x: x)
45834583
ds.take()
4584-
except LoggerWarningCalled:
4585-
logger_args, logger_kwargs = mock_logger.call_args
4586-
assert "Warning: The Ray cluster currently does not have " in logger_args[0]
4584+
except Exception as e:
4585+
if ray.data.context.DatasetContext.get_current().use_streaming_executor:
4586+
assert isinstance(e, ValueError)
4587+
assert "exceeds the execution limits ExecutionResources(cpu=0.0" in str(
4588+
e
4589+
)
4590+
else:
4591+
assert isinstance(e, LoggerWarningCalled)
4592+
logger_args, logger_kwargs = mock_logger.call_args
4593+
assert (
4594+
"Warning: The Ray cluster currently does not have "
4595+
in logger_args[0]
4596+
)
45874597

45884598

45894599
def test_nowarning_execute_with_cpu(ray_start_cluster_init):

python/ray/data/tests/test_util.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import ray
33
import numpy as np
44

5-
from ray.data._internal.util import _check_pyarrow_version
5+
from ray.data._internal.util import _check_pyarrow_version, _split_list
66
from ray.data._internal.memory_tracing import (
77
trace_allocation,
88
trace_deallocation,
@@ -72,6 +72,21 @@ def test_memory_tracing(enabled):
7272
assert "test5" not in report, report
7373

7474

75+
def test_list_splits():
76+
with pytest.raises(AssertionError):
77+
_split_list(list(range(5)), 0)
78+
79+
with pytest.raises(AssertionError):
80+
_split_list(list(range(5)), -1)
81+
82+
assert _split_list(list(range(5)), 7) == [[0], [1], [2], [3], [4], [], []]
83+
assert _split_list(list(range(5)), 2) == [[0, 1, 2], [3, 4]]
84+
assert _split_list(list(range(6)), 2) == [[0, 1, 2], [3, 4, 5]]
85+
assert _split_list(list(range(5)), 1) == [[0, 1, 2, 3, 4]]
86+
assert _split_list(["foo", 1, [0], None], 2) == [["foo", 1], [[0], None]]
87+
assert _split_list(["foo", 1, [0], None], 3) == [["foo", 1], [[0]], [None]]
88+
89+
7590
if __name__ == "__main__":
7691
import sys
7792

0 commit comments

Comments
 (0)