@@ -237,12 +237,24 @@ def random(**kwargs):
     test()
 
 
-@gen_cluster(
-    client=True,
-    nthreads=[("127.0.0.1", 1)] * 4,
-    config={"distributed.scheduler.work-stealing": False},
-)
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4)
 async def test_decide_worker_common_dep_ignored(client, s, *workers):
+    r"""
+    When we have basic linear chains, but all the downstream tasks also share a common dependency, ignore that dependency.
+
+    i  j  k  l   m  n  o  p
+    \__\__\__\___/__/__/__/
+    |  |  |  | | |  |  |  |
+    |  |  |  | X |  |  |  |
+    a  b  c  d   e  f  g  h
+
+    ^ Ignore the location of X when picking a worker for i..p.
+    It will end up being copied to all workers anyway.
+
+    If a dependency will end up on every worker regardless, because many things depend on it,
+    we should ignore it when selecting our candidate workers. Otherwise, we'll end up considering
+    every worker as a candidate, which is 1) slow and 2) often leads to poor choices.
+    """
     roots = [
         delayed(slowinc)(1, 0.1 / (i + 1), dask_key_name=f"root-{i}") for i in range(16)
     ]
@@ -261,15 +273,15 @@ async def test_decide_worker_common_dep_ignored(client, s, *workers):
             root_keys=sorted(
                 [int(k.split("-")[1]) for k in worker.data if k.startswith("root")]
             ),
-            dep_keys=sorted(
+            deps_of_root=sorted(
                 [int(k.split("-")[1]) for k in worker.data if k.startswith("dep")]
             ),
         )
         for worker in workers
     }
 
     for k in keys.values():
-        assert k["root_keys"] == k["dep_keys"]
+        assert k["root_keys"] == k["deps_of_root"]
 
     for worker in workers:
         log = worker.incoming_transfer_log
@@ -278,6 +290,100 @@ async def test_decide_worker_common_dep_ignored(client, s, *workers):
278290 assert list (log [0 ]["keys" ]) == ["everywhere" ]
279291
280292
293+ @gen_cluster (client = True , nthreads = [("127.0.0.1" , 1 )] * 4 )
294+ async def test_decide_worker_large_subtrees_colocated (client , s , * workers ):
295+ r"""
296+ Ensure that the above "ignore common dependencies" logic doesn't affect wide (but isolated) subtrees.
297+
298+ ........ ........ ........ ........
299+ \\\\//// \\\\//// \\\\//// \\\\////
300+ a b c d
301+
302+ Each one of a, b, etc. has more dependents than there are workers. But just because a has
303+ lots of dependents doesn't necessarily mean it will end up copied to every worker.
304+ Because a also has a few siblings, a's dependents shouldn't spread out over the whole cluster.
305+ """
306+ roots = [delayed (inc )(i , dask_key_name = f"root-{ i } " ) for i in range (len (workers ))]
307+ deps = [
308+ delayed (inc )(r , dask_key_name = f"dep-{ i } -{ j } " )
309+ for i , r in enumerate (roots )
310+ for j in range (len (workers ) * 2 )
311+ ]
312+
313+ rs , ds = dask .persist (roots , deps )
314+ await wait (ds )
315+
316+ keys = {
317+ worker .name : dict (
318+ root_keys = set (
319+ int (k .split ("-" )[1 ]) for k in worker .data if k .startswith ("root" )
320+ ),
321+ deps_of_root = set (
322+ int (k .split ("-" )[1 ]) for k in worker .data if k .startswith ("dep" )
323+ ),
324+ )
325+ for worker in workers
326+ }
327+
328+ for k in keys .values ():
329+ assert k ["root_keys" ] == k ["deps_of_root" ]
330+ assert len (k ["root_keys" ]) == len (roots ) / len (workers )
331+
332+ for worker in workers :
333+ assert not worker .incoming_transfer_log
334+
335+
336+ @gen_cluster (
337+ client = True ,
338+ nthreads = [("127.0.0.1" , 1 )] * 4 ,
339+ config = {"distributed.scheduler.work-stealing" : False },
340+ )
341+ async def test_decide_worker_large_multiroot_subtrees_colocated (client , s , * workers ):
342+ r"""
343+ Same as the above test, but also check isolated trees with multiple roots.
344+
345+ ........ ........ ........ ........
346+ \\\\//// \\\\//// \\\\//// \\\\////
347+ a b c d e f g h
348+ """
349+ roots = [
350+ delayed (inc )(i , dask_key_name = f"root-{ i } " ) for i in range (len (workers ) * 2 )
351+ ]
352+ deps = [
353+ delayed (lambda x , y : None )(
354+ r , roots [i * 2 + 1 ], dask_key_name = f"dep-{ i * 2 } -{ j } "
355+ )
356+ for i , r in enumerate (roots [::2 ])
357+ for j in range (len (workers ) * 2 )
358+ ]
359+
360+ rs , ds = dask .persist (roots , deps )
361+ await wait (ds )
362+
363+ keys = {
364+ worker .name : dict (
365+ root_keys = set (
366+ int (k .split ("-" )[1 ]) for k in worker .data if k .startswith ("root" )
367+ ),
368+ deps_of_root = set ().union (
369+ * (
370+ (int (k .split ("-" )[1 ]), int (k .split ("-" )[1 ]) + 1 )
371+ for k in worker .data
372+ if k .startswith ("dep" )
373+ )
374+ ),
375+ )
376+ for worker in workers
377+ }
378+
379+ for k in keys .values ():
380+ assert k ["root_keys" ] == k ["deps_of_root" ]
381+ assert len (k ["root_keys" ]) == len (roots ) / len (workers )
382+
383+ for worker in workers :
384+ assert not worker .incoming_transfer_log
385+
386+
281387@gen_cluster (client = True , nthreads = [("127.0.0.1" , 1 )] * 3 )
282388async def test_move_data_over_break_restrictions (client , s , a , b , c ):
283389 [x ] = await client .scatter ([1 ], workers = b .address )
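The heuristic these new tests pin down (skip dependencies that will be replicated on nearly every worker when collecting candidate workers) can be illustrated outside the diff. Below is a minimal, self-contained sketch; the Task and pick_candidate_workers names and the replication cutoff are hypothetical assumptions for illustration, not the scheduler's real decide_worker code.

# Sketch only: a toy model of the "ignore widely-replicated dependencies"
# candidate selection described in the docstring above. Names (Task,
# pick_candidate_workers) and the cutoff are illustrative assumptions,
# not distributed's actual decide_worker implementation.
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Task:
    key: str
    dependencies: list[Task] = field(default_factory=list)
    who_has: set[str] = field(default_factory=set)  # workers holding this key


def pick_candidate_workers(ts: Task, all_workers: set[str]) -> set[str]:
    """Workers worth considering for ``ts``, based on dependency locality."""
    cutoff = len(all_workers) / 2  # assumed threshold for "lives everywhere"
    candidates: set[str] = set()
    for dep in ts.dependencies:
        if len(dep.who_has) > cutoff:
            # Like X / "everywhere" above: it adds no locality signal,
            # so considering its holders would mean considering every worker.
            continue
        candidates |= dep.who_has
    # If nothing narrowed the choice down, fall back to the whole cluster.
    return candidates or set(all_workers)


workers = {"w1", "w2", "w3", "w4"}
everywhere = Task("everywhere", who_has=set(workers))   # the shared dep X
root = Task("root-0", who_has={"w2"})                   # one of a..h
chain = Task("dep-0", dependencies=[root, everywhere])  # one of i..p
assert pick_candidate_workers(chain, workers) == {"w2"}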