diff --git a/datatree/datatree.py b/datatree/datatree.py index 6f78d8c8..ff7a417b 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1014,6 +1014,66 @@ def map_over_subtree_inplace( if node.has_data: node.ds = func(node.ds, *args, **kwargs) + def pipe( + self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any + ) -> Any: + """Apply ``func(self, *args, **kwargs)`` + + This method replicates the pandas method of the same name. + + Parameters + ---------- + func : callable + function to apply to this xarray object (Dataset/DataArray). + ``args``, and ``kwargs`` are passed into ``func``. + Alternatively a ``(callable, data_keyword)`` tuple where + ``data_keyword`` is a string indicating the keyword of + ``callable`` that expects the xarray object. + *args + positional arguments passed into ``func``. + **kwargs + a dictionary of keyword arguments passed into ``func``. + + Returns + ------- + object : Any + the return type of ``func``. + + Notes + ----- + Use ``.pipe`` when chaining together functions that expect + xarray or pandas objects, e.g., instead of writing + + .. code:: python + + f(g(h(dt), arg1=a), arg2=b, arg3=c) + + You can write + + .. code:: python + + (dt.pipe(h).pipe(g, arg1=a).pipe(f, arg2=b, arg3=c)) + + If you have a function that takes the data as (say) the second + argument, pass a tuple indicating which keyword expects the + data. For example, suppose ``f`` takes its data as ``arg2``: + + .. code:: python + + (dt.pipe(h).pipe(g, arg1=a).pipe((f, "arg2"), arg1=a, arg3=c)) + + """ + if isinstance(func, tuple): + func, target = func + if target in kwargs: + raise ValueError( + f"{target} is both the pipe target and a keyword argument" + ) + kwargs[target] = self + else: + args = (self,) + args + return func(*args, **kwargs) + def render(self): """Print tree structure, including any data stored at each node.""" for pre, fill, node in RenderTree(self): diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index ca9fae5f..dd08618d 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -448,3 +448,35 @@ def test_arithmetic(self, create_test_datatree): class TestRestructuring: ... + + +class TestPipe: + def test_noop(self, create_test_datatree): + dt = create_test_datatree() + + actual = dt.pipe(lambda tree: tree) + assert actual.identical(dt) + + def test_params(self, create_test_datatree): + dt = create_test_datatree() + + def f(tree, **attrs): + return tree.assign(arr_with_attrs=xr.Variable("dim0", [], attrs=attrs)) + + attrs = {"x": 1, "y": 2, "z": 3} + + actual = dt.pipe(f, **attrs) + assert actual["arr_with_attrs"].attrs == attrs + + def test_named_self(self, create_test_datatree): + dt = create_test_datatree() + + def f(x, tree, y): + tree.attrs.update({"x": x, "y": y}) + return tree + + attrs = {"x": 1, "y": 2} + + actual = dt.pipe((f, "tree"), **attrs) + + assert actual is dt and actual.attrs == attrs diff --git a/docs/source/api.rst b/docs/source/api.rst index 209d4ab9..49caaea8 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -96,6 +96,7 @@ For manipulating, traversing, navigating, or mapping over the tree structure. DataTree.iter_lineage DataTree.find_common_ancestor map_over_subtree + DataTree.pipe DataTree Contents ----------------- diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst index 25bb8614..daee3fd8 100644 --- a/docs/source/whats-new.rst +++ b/docs/source/whats-new.rst @@ -25,6 +25,8 @@ New Features - Add the ability to register accessors on ``DataTree`` objects, by using ``register_datatree_accessor``. (:pull:`144`) By `Tom Nicholas `_. +- Allow method chaining with a new :py:meth:`DataTree.pipe` method (:issue:`151`, :pull:`156`). + By `Justus Magin `_. Breaking changes ~~~~~~~~~~~~~~~~