
Commit e1fd58b

Handle multiple large subtrees
This addresses the issue in #5325 (comment). It feels a little hacky, since it can still be wrong (what if there are multiple root groups that have large subtrees?). We're trying to infer global graph structure (how many sibling tasks there are) using TaskGroups, which don't necessarily reflect graph structure. It's also hard to explain the intuition for why this is right-ish (beyond "well, we need the `len(dts._dependents)` number to be smaller if the task has siblings").
1 parent aaa12a3 commit e1fd58b
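
The concern in the message is easy to make concrete. Below is a hypothetical sketch, not code from this commit: since a TaskGroup is keyed by the task-name prefix, four isolated subtrees whose roots use four different name prefixes form four groups of size 1, so dividing the dependent count by the group size changes nothing for them. The `operator.add` tasks and key names here are invented for illustration.

    import operator
    from dask import delayed

    # Four isolated subtrees, but the roots use distinct key prefixes
    # ("a-0", "b-0", "c-0", "d-0"), so each root is alone in its TaskGroup.
    roots = [
        delayed(operator.add)(i, 1, dask_key_name=f"{name}-0")
        for i, name in enumerate("abcd")
    ]
    deps = [
        delayed(operator.add)(r, j, dask_key_name=f"dep-{i}-{j}")
        for i, r in enumerate(roots)
        for j in range(8)
    ]
    # For each root, len(dts._dependents) / len(dts._group) == 8 / 1 == 8,
    # which is not below a 4-worker cluster size, so the scheduler would
    # still ignore the root's location -- the same spreading this commit
    # fixes for the case where the sibling roots share one group.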

2 files changed: +115, -9 lines changed


distributed/scheduler.py

Lines changed: 2 additions & 2 deletions

@@ -7985,10 +7985,10 @@ def decide_worker(
         candidates = {
             wws
             for dts in deps
-            for wws in dts._who_has
             # Ignore dependencies that will need to be, or already are, copied to all workers
-            if max(len(dts._who_has), len(dts._dependents))
+            if max(len(dts._dependents) / len(dts._group), len(dts._who_has))
             < len(valid_workers if valid_workers is not None else all_workers)
+            for wws in dts._who_has
         }
     if valid_workers is None:
         if not candidates:
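
To see what the one-line change buys, here is a minimal standalone sketch of the filter predicate, with plain integers standing in for `dts._who_has`, `dts._dependents`, and `dts._group`; the function and parameter names are illustrative, not scheduler API. The numbers come from the new subtree test below (4 workers, 4 sibling roots, 8 dependents each).

    def old_rule(n_dependents, n_replicas, n_workers):
        # Old: a dependency with many dependents (or many replicas) is assumed
        # to end up on every worker, so its location is ignored.
        return max(n_replicas, n_dependents) < n_workers

    def new_rule(n_dependents, group_size, n_replicas, n_workers):
        # New: scale the dependent count by the size of the dependency's
        # TaskGroup, approximating fan-out per sibling rather than in total.
        return max(n_dependents / group_size, n_replicas) < n_workers

    # 4 workers; a root with 8 dependents, 3 sibling roots, data on 1 worker:
    print(old_rule(8, 1, 4))     # False -> root's location ignored, dependents spread out
    print(new_rule(8, 4, 1, 4))  # True  -> 8 / 4 == 2 < 4, dependents stay with their root

Dividing by the group size makes a root that is one of several siblings look narrower than a genuinely shared dependency like `everywhere` in the test below (16 dependents, a group of 1), whose location is still rightly ignored.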

distributed/tests/test_scheduler.py

Lines changed: 113 additions & 7 deletions

@@ -237,12 +237,24 @@ def random(**kwargs):
     test()


-@gen_cluster(
-    client=True,
-    nthreads=[("127.0.0.1", 1)] * 4,
-    config={"distributed.scheduler.work-stealing": False},
-)
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4)
 async def test_decide_worker_common_dep_ignored(client, s, *workers):
+    r"""
+    When we have basic linear chains, but all the downstream tasks also share a common dependency, ignore that dependency.
+
+    i j k l m n o p
+    \__\__\__\___/__/__/__/
+     |  |  |  |  |  |  |  |  |
+     |  |  |  |  X  |  |  |  |
+     a  b  c  d     e  f  g  h
+
+    ^ Ignore the location of X when picking a worker for i..p.
+    It will end up being copied to all workers anyway.
+
+    If a dependency will end up on every worker regardless, because many things depend on it,
+    we should ignore it when selecting our candidate workers. Otherwise, we'll end up considering
+    every worker as a candidate, which is 1) slow and 2) often leads to poor choices.
+    """
     roots = [
         delayed(slowinc)(1, 0.1 / (i + 1), dask_key_name=f"root-{i}") for i in range(16)
     ]

@@ -261,15 +273,15 @@ async def test_decide_worker_common_dep_ignored(client, s, *workers):
             root_keys=sorted(
                 [int(k.split("-")[1]) for k in worker.data if k.startswith("root")]
             ),
-            dep_keys=sorted(
+            deps_of_root=sorted(
                 [int(k.split("-")[1]) for k in worker.data if k.startswith("dep")]
             ),
         )
         for worker in workers
     }

     for k in keys.values():
-        assert k["root_keys"] == k["dep_keys"]
+        assert k["root_keys"] == k["deps_of_root"]

     for worker in workers:
         log = worker.incoming_transfer_log

@@ -278,6 +290,100 @@ async def test_decide_worker_common_dep_ignored(client, s, *workers):
         assert list(log[0]["keys"]) == ["everywhere"]


+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4)
+async def test_decide_worker_large_subtrees_colocated(client, s, *workers):
+    r"""
+    Ensure that the above "ignore common dependencies" logic doesn't affect wide (but isolated) subtrees.
+
+    ........ ........ ........ ........
+    \\\\//// \\\\//// \\\\//// \\\\////
+       a        b        c        d
+
+    Each one of a, b, etc. has more dependents than there are workers. But just because a has
+    lots of dependents doesn't necessarily mean it will end up copied to every worker.
+    Because a also has a few siblings, a's dependents shouldn't spread out over the whole cluster.
+    """
+    roots = [delayed(inc)(i, dask_key_name=f"root-{i}") for i in range(len(workers))]
+    deps = [
+        delayed(inc)(r, dask_key_name=f"dep-{i}-{j}")
+        for i, r in enumerate(roots)
+        for j in range(len(workers) * 2)
+    ]
+
+    rs, ds = dask.persist(roots, deps)
+    await wait(ds)
+
+    keys = {
+        worker.name: dict(
+            root_keys=set(
+                int(k.split("-")[1]) for k in worker.data if k.startswith("root")
+            ),
+            deps_of_root=set(
+                int(k.split("-")[1]) for k in worker.data if k.startswith("dep")
+            ),
+        )
+        for worker in workers
+    }
+
+    for k in keys.values():
+        assert k["root_keys"] == k["deps_of_root"]
+        assert len(k["root_keys"]) == len(roots) / len(workers)
+
+    for worker in workers:
+        assert not worker.incoming_transfer_log
+
+
+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 1)] * 4,
+    config={"distributed.scheduler.work-stealing": False},
+)
+async def test_decide_worker_large_multiroot_subtrees_colocated(client, s, *workers):
+    r"""
+    Same as the above test, but also check isolated trees with multiple roots.
+
+    ........ ........ ........ ........
+    \\\\//// \\\\//// \\\\//// \\\\////
+      a b      c d      e f      g h
+    """
+    roots = [
+        delayed(inc)(i, dask_key_name=f"root-{i}") for i in range(len(workers) * 2)
+    ]
+    deps = [
+        delayed(lambda x, y: None)(
+            r, roots[i * 2 + 1], dask_key_name=f"dep-{i * 2}-{j}"
+        )
+        for i, r in enumerate(roots[::2])
+        for j in range(len(workers) * 2)
+    ]
+
+    rs, ds = dask.persist(roots, deps)
+    await wait(ds)
+
+    keys = {
+        worker.name: dict(
+            root_keys=set(
+                int(k.split("-")[1]) for k in worker.data if k.startswith("root")
+            ),
+            deps_of_root=set().union(
+                *(
+                    (int(k.split("-")[1]), int(k.split("-")[1]) + 1)
+                    for k in worker.data
+                    if k.startswith("dep")
+                )
+            ),
+        )
+        for worker in workers
+    }
+
+    for k in keys.values():
+        assert k["root_keys"] == k["deps_of_root"]
+        assert len(k["root_keys"]) == len(roots) / len(workers)
+
+    for worker in workers:
+        assert not worker.incoming_transfer_log
+
+
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
 async def test_move_data_over_break_restrictions(client, s, a, b, c):
     [x] = await client.scatter([1], workers=b.address)
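
As a back-of-envelope check, the new predicate sorts all three tests above the right way, assuming their 4-worker clusters and taking group sizes from the key prefixes (`location_considered` is an illustrative helper, not a real function in the codebase):

    N_WORKERS = 4

    def location_considered(n_dependents, group_size, n_replicas=1):
        # Mirrors: max(len(dts._dependents) / len(dts._group), len(dts._who_has)) < n_workers
        return max(n_dependents / group_size, n_replicas) < N_WORKERS

    # test_decide_worker_common_dep_ignored: "everywhere" has 16 dependents
    # and no siblings -> 16 / 1 >= 4, so its location is ignored (as desired).
    assert not location_considered(16, 1)

    # test_decide_worker_large_subtrees_colocated: each root has 8 dependents
    # and 4 siblings in the "root" group -> 8 / 4 == 2 < 4, location respected.
    assert location_considered(8, 4)

    # test_decide_worker_large_multiroot_subtrees_colocated: 8 roots, each
    # with 8 dependents -> 8 / 8 == 1 < 4, location respected.
    assert location_considered(8, 8)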
