
Commit aaac9cd

[Data] Improve state initialization for ActorPoolMapOperator (#34037)
ActorPoolMapOperator takes in a callable class that initializes some state to be reused for every batch. In the current implementation, this state is initialized on the first batch rather than during actor init. In this PR, we separate out the state initialization and call it during actor init. This allows state to be initialized for fixed-size actor pools even when tasks are not yet ready to be dispatched, enabling better pipelining. It also supports multithreaded actors, so state gets initialized once per actor instead of once per thread.

Signed-off-by: amogkam <[email protected]>
1 parent 372a1eb commit aaac9cd
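
For context, the user-facing pattern this change targets looks roughly like the following. This is a minimal sketch: the Model class, the dataset contents, and the pool sizes are illustrative, not taken from this PR.

import ray
from ray.data import ActorPoolStrategy

class Model:
    def __init__(self):
        # Expensive one-time setup (e.g., loading model weights). With this
        # change, this runs when each actor starts, not on the first batch.
        self.factor = 2

    def __call__(self, batch):
        return [x * self.factor for x in batch]

ds = ray.data.range(8)
# A fixed-size actor pool can now warm up its actors (run Model.__init__)
# before any task is dispatched to them.
ds.map_batches(Model, compute=ActorPoolStrategy(2, 2)).take_all()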

File tree: 5 files changed (+68, -21 lines)

python/ray/data/_internal/execution/legacy_compat.py

Lines changed: 8 additions & 8 deletions
@@ -259,25 +259,24 @@ def _stage_to_operator(stage: Stage, input_op: PhysicalOperator) -> PhysicalOperator:
             fn_ = stage.fn
 
             def fn(item: Any) -> Any:
-                # Wrapper providing cached instantiation of stateful callable class
-                # UDFs.
+                assert ray.data._cached_fn is not None
+                assert ray.data._cached_cls == fn_
+                return ray.data._cached_fn(item)
+
+            def init_fn():
                 if ray.data._cached_fn is None:
                     ray.data._cached_cls = fn_
                     ray.data._cached_fn = fn_(
                         *fn_constructor_args, **fn_constructor_kwargs
                     )
-                else:
-                    # A worker is destroyed when its actor is killed, so we
-                    # shouldn't have any worker reuse across different UDF
-                    # applications (i.e. different map operators).
-                    assert ray.data._cached_cls == fn_
-                return ray.data._cached_fn(item)
 
         else:
             fn = stage.fn
+            init_fn = None
         fn_args = (fn,)
     else:
         fn_args = ()
+        init_fn = None
     if stage.fn_args:
         fn_args += stage.fn_args
     fn_kwargs = stage.fn_kwargs or {}
@@ -288,6 +287,7 @@ def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
     return MapOperator.create(
         do_map,
         input_op,
+        init_fn=init_fn,
         name=stage.name,
         compute_strategy=compute,
         min_rows_per_bundle=stage.target_block_size,

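The shape of this change is the same in both compat paths: the lazy "initialize on first call" wrapper is split into an explicit initializer plus a thin per-item call. In isolation the pattern looks like the sketch below, where ExpensiveState is a hypothetical stateful UDF class, not code from the PR.

_cached = None

# Before: state is created lazily inside the per-item path, so nothing
# happens until the first item arrives.
def call_lazy(item):
    global _cached
    if _cached is None:
        _cached = ExpensiveState()
    return _cached(item)

# After: initialization is a separate step that can run eagerly at actor
# startup; the per-item path only asserts the invariant and dispatches.
def init():
    global _cached
    if _cached is None:
        _cached = ExpensiveState()

def call(item):
    assert _cached is not None
    return _cached(item)
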
python/ray/data/_internal/execution/operators/actor_pool_map_operator.py

Lines changed: 11 additions & 3 deletions
@@ -3,7 +3,7 @@
 from typing import Dict, Any, Iterator, Callable, List, Tuple, Union, Optional
 
 import ray
-from ray.data.block import Block, BlockMetadata
+from ray.data.block import Block, BlockMetadata, _CallableClassProtocol
 from ray.data.context import DatasetContext, DEFAULT_SCHEDULING_STRATEGY
 from ray.data._internal.compute import ActorPoolStrategy
 from ray.data._internal.dataset_logger import DatasetLogger
@@ -37,6 +37,7 @@ class ActorPoolMapOperator(MapOperator):
     def __init__(
         self,
         transform_fn: Callable[[Iterator[Block]], Iterator[Block]],
+        init_fn: Callable[[], None],
         input_op: PhysicalOperator,
         autoscaling_policy: "AutoscalingPolicy",
         name: str = "ActorPoolMap",
@@ -47,6 +48,7 @@
 
         Args:
             transform_fn: The function to apply to each ref bundle input.
+            init_fn: The callable class to instantiate on each actor.
             input_op: Operator generating input data for this op.
             autoscaling_policy: A policy controlling when the actor pool should be
                 scaled up and scaled down.
@@ -60,6 +62,7 @@
         super().__init__(
             transform_fn, input_op, name, min_rows_per_bundle, ray_remote_args
         )
+        self._init_fn = init_fn
         self._ray_remote_args = self._apply_default_remote_args(self._ray_remote_args)
 
         # Create autoscaling policy from compute strategy.
@@ -105,7 +108,7 @@ def _start_actor(self):
         """Start a new actor and add it to the actor pool as a pending actor."""
         assert self._cls is not None
         ctx = DatasetContext.get_current()
-        actor = self._cls.remote(ctx, src_fn_name=self.name)
+        actor = self._cls.remote(ctx, src_fn_name=self.name, init_fn=self._init_fn)
         self._actor_pool.add_pending_actor(actor, actor.get_location.remote())
 
     def _add_bundled_input(self, bundle: RefBundle):
@@ -307,10 +310,15 @@ def _apply_default_remote_args(ray_remote_args: Dict[str, Any]) -> Dict[str, Any]:
 class _MapWorker:
     """An actor worker for MapOperator."""
 
-    def __init__(self, ctx: DatasetContext, src_fn_name: str):
+    def __init__(
+        self, ctx: DatasetContext, src_fn_name: str, init_fn: _CallableClassProtocol
+    ):
         DatasetContext._set_current(ctx)
         self.src_fn_name: str = src_fn_name
 
+        # Initialize state for this actor.
+        init_fn()
+
     def get_location(self) -> NodeIdStr:
         return ray.get_runtime_context().get_node_id()

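Because _MapWorker.__init__ now invokes init_fn, an exception raised during initialization fails actor construction itself rather than the first task. Below is a standalone sketch of that Ray behavior, using a hypothetical Worker actor rather than the operator code.

import ray

ray.init()

@ray.remote
class Worker:
    def __init__(self, init_fn):
        init_fn()  # runs at actor startup; if it raises, the actor dies

    def ping(self):
        return "ok"

def failing_init():
    raise ValueError("init_failed")

w = Worker.remote(failing_init)
try:
    # Any method call on the dead actor surfaces the init failure
    # as a RayActorError.
    ray.get(w.ping.remote())
except ray.exceptions.RayActorError as e:
    print(e)
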
python/ray/data/_internal/execution/operators/map_operator.py

Lines changed: 16 additions & 2 deletions
@@ -2,10 +2,15 @@
 import copy
 from dataclasses import dataclass
 import itertools
-from typing import List, Iterator, Any, Dict, Optional, Union
+from typing import Callable, List, Iterator, Any, Dict, Optional, Union
 
 import ray
-from ray.data.block import Block, BlockAccessor, BlockMetadata, BlockExecStats
+from ray.data.block import (
+    Block,
+    BlockAccessor,
+    BlockMetadata,
+    BlockExecStats,
+)
 from ray.data._internal.compute import (
     ComputeStrategy,
     TaskPoolStrategy,
@@ -66,6 +71,7 @@ def create(
         cls,
         transform_fn: MapTransformFn,
         input_op: PhysicalOperator,
+        init_fn: Optional[Callable[[], None]] = None,
         name: str = "Map",
         # TODO(ekl): slim down ComputeStrategy to only specify the compute
         # config and not contain implementation code.
@@ -83,6 +89,7 @@
         Args:
             transform_fn: The function to apply to each ref bundle input.
             input_op: Operator generating input data for this op.
+            init_fn: The callable class to instantiate if using ActorPoolMapOperator.
             name: The name of this operator.
             compute_strategy: Customize the compute strategy for this op.
             min_rows_per_bundle: The number of rows to gather per batch passed to the
@@ -117,8 +124,15 @@ def create(
                 compute_strategy
             )
             autoscaling_policy = AutoscalingPolicy(autoscaling_config)
+
+            if init_fn is None:
+
+                def init_fn():
+                    pass
+
             return ActorPoolMapOperator(
                 transform_fn,
+                init_fn,
                 input_op,
                 autoscaling_policy=autoscaling_policy,
                 name=name,

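Note the defaulting choice in create(): rather than threading an Optional[Callable] all the way into the actor, a no-op init_fn is substituted once at the operator boundary, so _MapWorker can call it unconditionally. The same idiom in miniature (hypothetical names):

from typing import Callable, Optional

def normalize_init(init_fn: Optional[Callable[[], None]] = None) -> Callable[[], None]:
    # Normalize None to a no-op here, so every downstream consumer can
    # simply call init_fn() without a None check.
    if init_fn is None:

        def init_fn():
            pass

    return init_fn

init = normalize_init()
init()  # safe no-op
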
python/ray/data/_internal/planner/plan_udf_map_op.py

Lines changed: 7 additions & 8 deletions
@@ -61,20 +61,18 @@ def _plan_udf_map_op(
         fn_ = op._fn
 
         def fn(item: Any) -> Any:
-            # Wrapper providing cached instantiation of stateful callable class
-            # UDFs.
+            assert ray.data._cached_fn is not None
+            assert ray.data._cached_cls == fn_
+            return ray.data._cached_fn(item)
+
+        def init_fn():
             if ray.data._cached_fn is None:
                 ray.data._cached_cls = fn_
                 ray.data._cached_fn = fn_(*fn_constructor_args, **fn_constructor_kwargs)
-            else:
-                # A worker is destroyed when its actor is killed, so we
-                # shouldn't have any worker reuse across different UDF
-                # applications (i.e. different map operators).
-                assert ray.data._cached_cls == fn_
-            return ray.data._cached_fn(item)
 
     else:
         fn = op._fn
+        init_fn = None
     fn_args = (fn,)
     if op._fn_args:
         fn_args += op._fn_args
@@ -86,6 +84,7 @@ def do_map(blocks: Iterator[Block], ctx: TaskContext) -> Iterator[Block]:
     return MapOperator.create(
         do_map,
         input_physical_dag,
+        init_fn=init_fn,
         name=op.name,
         compute_strategy=compute,
         min_rows_per_bundle=op._target_block_size,

python/ray/data/tests/test_operators.py

Lines changed: 26 additions & 0 deletions
@@ -506,6 +506,32 @@ def _sleep(block_iter: Iterable[Block]) -> Iterable[Block]:
     wait_for_condition(lambda: (ray.available_resources().get("GPU", 0) == 1.0))
 
 
+def test_actor_pool_map_operator_init(ray_start_regular_shared):
+    """Tests that ActorPoolMapOperator runs init_fn on start."""
+
+    from ray.exceptions import RayActorError
+
+    def _sleep(block_iter: Iterable[Block]) -> Iterable[Block]:
+        time.sleep(999)
+
+    def _fail():
+        raise ValueError("init_failed")
+
+    input_op = InputDataBuffer(make_ref_bundles([[i] for i in range(10)]))
+    compute_strategy = ActorPoolStrategy(min_size=1)
+
+    op = MapOperator.create(
+        _sleep,
+        input_op=input_op,
+        init_fn=_fail,
+        name="TestMapper",
+        compute_strategy=compute_strategy,
+    )
+
+    with pytest.raises(RayActorError, match=r"init_failed"):
+        op.start(ExecutionOptions())
+
+
 @pytest.mark.parametrize(
     "compute,expected",
     [

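A note on the test's construction: _sleep blocks indefinitely, so no batch can ever be processed; the RayActorError matched by pytest.raises can therefore only originate from _fail running during actor startup, which is exactly the behavior this PR introduces.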