diff --git a/datatree/datatree.py b/datatree/datatree.py index 85cec7d9..a1bcf02f 100644 --- a/datatree/datatree.py +++ b/datatree/datatree.py @@ -1102,6 +1102,28 @@ def identical(self, other: DataTree, from_root=True) -> bool: for node, other_node in zip(self.subtree, other.subtree) ) + def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: + """ + Filter nodes according to a specified condition. + + Returns a new tree containing only the nodes in the original tree for which `filterfunc(node)` is True. + Will also contain empty nodes at intermediate positions if required to support leaves. + + Parameters + ---------- + filterfunc: function + A function which accepts only one DataTree - the node on which filterfunc will be called. + + See Also + -------- + pipe + map_over_subtree + """ + filtered_nodes = { + node.path: node.ds for node in self.subtree if filterfunc(node) + } + return DataTree.from_dict(filtered_nodes, name=self.root.name) + def map_over_subtree( self, func: Callable, diff --git a/datatree/tests/test_datatree.py b/datatree/tests/test_datatree.py index 9fa1b91f..004eca56 100644 --- a/datatree/tests/test_datatree.py +++ b/datatree/tests/test_datatree.py @@ -596,3 +596,28 @@ def f(x, tree, y): actual = dt.pipe((f, "tree"), **attrs) assert actual is dt and actual.attrs == attrs + + +class TestSubset: + def test_filter(self): + simpsons = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + "/Homer/Bart": xr.Dataset({"age": 10}), + "/Homer/Lisa": xr.Dataset({"age": 8}), + "/Homer/Maggie": xr.Dataset({"age": 1}), + }, + name="Abe", + ) + expected = DataTree.from_dict( + d={ + "/": xr.Dataset({"age": 83}), + "/Herbert": xr.Dataset({"age": 40}), + "/Homer": xr.Dataset({"age": 39}), + }, + name="Abe", + ) + elders = simpsons.filter(lambda node: node["age"] > 18) + dtt.assert_identical(elders, expected) diff --git a/docs/source/api.rst 
b/docs/source/api.rst index feccdcfd..835b18d4 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -99,6 +99,7 @@ For manipulating, traversing, navigating, or mapping over the tree structure. DataTree.find_common_ancestor map_over_subtree DataTree.pipe + DataTree.filter DataTree Contents ----------------- diff --git a/docs/source/whats-new.rst b/docs/source/whats-new.rst index 93fad0e1..27ba0eb5 100644 --- a/docs/source/whats-new.rst +++ b/docs/source/whats-new.rst @@ -31,6 +31,8 @@ New Features By `Tom Nicholas <https://github.com/TomNicholas>`_. - Added a :py:meth:`DataTree.leaves` property (:pull:`177`). By `Tom Nicholas <https://github.com/TomNicholas>`_. +- Added a :py:meth:`DataTree.filter` method (:pull:`184`). + By `Tom Nicholas <https://github.com/TomNicholas>`_. Breaking changes ~~~~~~~~~~~~~~~~