@@ -2293,6 +2293,57 @@ def test_compute(self):
22932293 assert actual .chunksizes == expected_chunksizes , "mismatching chunksizes"
22942294 assert tree .chunksizes == original_chunksizes , "original tree was modified"
22952295
def test_persist(self):
    """Check that ``DataTree.persist`` flattens each node's dask graph.

    Every node of the input tree is built as ``fn(ds.chunk(...))`` so its
    dask graph has two layers.  After ``persist`` each node of the result is
    expected to have a single layer and the same chunk sizes, while the
    original tree keeps both its chunk sizes and its two-layer graphs.
    """
    ds1 = xr.Dataset({"a": ("x", np.arange(10))})
    ds2 = xr.Dataset({"b": ("y", np.arange(5))})
    ds3 = xr.Dataset({"c": ("z", np.arange(4))})
    ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

    def fn(x):
        return 2 * x

    def hlg_depths(dt):
        """Map each node path to the number of layers in its dask graph."""
        return {
            node.path: len(node.dataset.__dask_graph__().layers)
            for node in dt.subtree
        }

    expected = xr.DataTree.from_dict(
        {
            "/": fn(ds1).chunk({"x": 5}),
            "/group1": fn(ds2).chunk({"y": 3}),
            "/group2": fn(ds3).chunk({"z": 2}),
            "/group1/subgroup1": fn(ds4).chunk({"x": 5}),
        }
    )
    # Add trivial second layer to the task graph, persist should reduce to one
    tree = xr.DataTree.from_dict(
        {
            "/": fn(ds1.chunk({"x": 5})),
            "/group1": fn(ds2.chunk({"y": 3})),
            "/group2": fn(ds3.chunk({"z": 2})),
            "/group1/subgroup1": fn(ds4.chunk({"x": 5})),
        }
    )
    original_chunksizes = tree.chunksizes
    original_hlg_depths = hlg_depths(tree)

    actual = tree.persist()
    actual_hlg_depths = hlg_depths(actual)
    # Measured *after* persist: the pre-persist snapshot above cannot reveal
    # whether persist() mutated the original tree's graphs.
    post_persist_hlg_depths = hlg_depths(tree)

    assert_identical(actual, expected)

    assert actual.chunksizes == original_chunksizes, "chunksizes were modified"
    assert (
        tree.chunksizes == original_chunksizes
    ), "original chunksizes were modified"
    assert all(
        d == 1 for d in actual_hlg_depths.values()
    ), "unexpected dask graph depth"
    # Precondition: the input tree really starts with two layers per node.
    assert all(
        d == 2 for d in original_hlg_depths.values()
    ), "unexpected initial dask graph depth"
    assert all(
        d == 2 for d in post_persist_hlg_depths.values()
    ), "original dask graph was modified"
22962347 def test_chunk (self ):
22972348 ds1 = xr .Dataset ({"a" : ("x" , np .arange (10 ))})
22982349 ds2 = xr .Dataset ({"b" : ("y" , np .arange (5 ))})
0 commit comments