1212from xarray .core .variable import Variable
1313from xarray .core .combine import merge
1414from xarray .core import dtypes , utils
15- from xarray .core ._typed_ops import DatasetOpsMixin
1615
1716from .treenode import TreeNode , PathType , _init_single_treenode
1817
@@ -188,7 +187,7 @@ def imag(self):
188187 else :
189188 raise AttributeError ("property is not defined for a node with no data" )
190189
191- # TODO .loc
190+ # TODO .loc, __contains__, __iter__, __array__, __len__
192191
193192 dims .__doc__ = Dataset .dims .__doc__
194193 variables .__doc__ = Dataset .variables .__doc__
@@ -207,68 +206,71 @@ def imag(self):
207206 "See the `map_over_subtree` decorator for more details." , width = 117 )
208207
209208
210- def _expose_methods_wrapped_to_map_over_subtree ( obj , method_name , method ):
209+ def _wrap_then_attach_to_cls ( cls_dict , methods_to_expose , wrap_func = None ):
211210 """
212- Expose given method on node object, but wrapped to map over whole subtree, not just that node object.
213-
214- Result is like having written this in obj's class definition:
211+ Attach given methods to a class, optionally wrapping each method first (e.g. with map_over_subtree).
215212
213+ Result is like having written this in the class's definition:
216214 ```
217- @map_over_subtree
215+ @wrap_func
218216 def method_name(self, *args, **kwargs):
219217 return self.method(*args, **kwargs)
220218 ```
221- """
222-
223- # Expose Dataset method, but wrapped to map over whole subtree when called
224- # TODO should we be using functools.partialmethod here instead?
225- mapped_over_tree = functools .partial (map_over_subtree (method ), obj )
226- setattr (obj , method_name , mapped_over_tree )
227-
228- # TODO do we really need this for ops like __add__?
229- # Add a line to the method's docstring explaining how it's been mapped
230- method_docstring = method .__doc__
231- if method_docstring is not None :
232- updated_method_docstring = method_docstring .replace ('\n ' , _MAPPED_DOCSTRING_ADDENDUM , 1 )
233- obj_method = getattr (obj , method_name )
234- setattr (obj_method , '__doc__' , updated_method_docstring )
235219
220+ Parameters
221+ ----------
222+ cls_dict
223+ The __dict__ attribute of a class, which can also be accessed by calling vars() from within that class's
224+ definition.
225+ methods_to_expose : Iterable[Tuple[str, callable]]
226+ The method names and definitions supplied as a list of (method_name_string, method) pairs.
227+ This format matches the output of inspect.getmembers().
228+ """
229+ for method_name , method in methods_to_expose :
230+ wrapped_method = wrap_func (method ) if wrap_func is not None else method
231+ cls_dict [method_name ] = wrapped_method
236232
237- # TODO equals, broadcast_equals etc.
238- # TODO do dask-related private methods need to be exposed?
239- _DATASET_DASK_METHODS_TO_EXPOSE = ['load' , 'compute' , 'persist' , 'unify_chunks' , 'chunk' , 'map_blocks' ]
240- _DATASET_METHODS_TO_EXPOSE = ['copy' , 'as_numpy' , '__copy__' , '__deepcopy__' , '__contains__' , '__len__' ,
241- '__bool__' , '__iter__' , '__array__' , 'set_coords' , 'reset_coords' , 'info' ,
242- 'isel' , 'sel' , 'head' , 'tail' , 'thin' , 'broadcast_like' , 'reindex_like' ,
243- 'reindex' , 'interp' , 'interp_like' , 'rename' , 'rename_dims' , 'rename_vars' ,
244- 'swap_dims' , 'expand_dims' , 'set_index' , 'reset_index' , 'reorder_levels' , 'stack' ,
245- 'unstack' , 'update' , 'merge' , 'drop_vars' , 'drop_sel' , 'drop_isel' , 'drop_dims' ,
246- 'transpose' , 'dropna' , 'fillna' , 'interpolate_na' , 'ffill' , 'bfill' , 'combine_first' ,
247- 'reduce' , 'map' , 'assign' , 'diff' , 'shift' , 'roll' , 'sortby' , 'quantile' , 'rank' ,
248- 'differentiate' , 'integrate' , 'cumulative_integrate' , 'filter_by_attrs' , 'polyfit' ,
249- 'pad' , 'idxmin' , 'idxmax' , 'argmin' , 'argmax' , 'query' , 'curvefit' ]
250- _DATASET_OPS_TO_EXPOSE = ['_unary_op' , '_binary_op' , '_inplace_binary_op' ]
251- _ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
252-
253- # TODO methods which should not or cannot act over the whole tree, such as .to_array
254-
233+ # TODO do we really need this for ops like __add__?
234+ # Add a line to the method's docstring explaining how it's been mapped
235+ method_docstring = method .__doc__
236+ if method_docstring is not None :
237+ updated_method_docstring = method_docstring .replace ('\n ' , _MAPPED_DOCSTRING_ADDENDUM , 1 )
238+ setattr (cls_dict [method_name ], '__doc__' , updated_method_docstring )
255239
256- class DatasetMethodsMixin :
257- """Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree."""
258240
259- # TODO is there a way to put this code in the class definition so we don't have to specifically call this method?
260- def _add_dataset_methods (self ):
261- methods_to_expose = [(method_name , getattr (Dataset , method_name ))
262- for method_name in _ALL_DATASET_METHODS_TO_EXPOSE ]
241+ class MappedDatasetMethodsMixin :
242+ """
243+ Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree.
263244
264- for method_name , method in methods_to_expose :
265- _expose_methods_wrapped_to_map_over_subtree (self , method_name , method )
245+ Every method wrapped here needs to have a return value of Dataset or DataArray in order to construct a new tree.
246+ """
247+ # TODO equals, broadcast_equals etc.
248+ # TODO do dask-related private methods need to be exposed?
249+ _DATASET_DASK_METHODS_TO_EXPOSE = ['load' , 'compute' , 'persist' , 'unify_chunks' , 'chunk' , 'map_blocks' ]
250+ _DATASET_METHODS_TO_EXPOSE = ['copy' , 'as_numpy' , '__copy__' , '__deepcopy__' , 'set_coords' , 'reset_coords' , 'info' ,
251+ 'isel' , 'sel' , 'head' , 'tail' , 'thin' , 'broadcast_like' , 'reindex_like' ,
252+ 'reindex' , 'interp' , 'interp_like' , 'rename' , 'rename_dims' , 'rename_vars' ,
253+ 'swap_dims' , 'expand_dims' , 'set_index' , 'reset_index' , 'reorder_levels' , 'stack' ,
254+ 'unstack' , 'update' , 'merge' , 'drop_vars' , 'drop_sel' , 'drop_isel' , 'drop_dims' ,
255+ 'transpose' , 'dropna' , 'fillna' , 'interpolate_na' , 'ffill' , 'bfill' , 'combine_first' ,
256+ 'reduce' , 'map' , 'assign' , 'diff' , 'shift' , 'roll' , 'sortby' , 'quantile' , 'rank' ,
257+ 'differentiate' , 'integrate' , 'cumulative_integrate' , 'filter_by_attrs' , 'polyfit' ,
258+ 'pad' , 'idxmin' , 'idxmax' , 'argmin' , 'argmax' , 'query' , 'curvefit' ]
259+ # TODO unsure if these are called by external functions or not?
260+ _DATASET_OPS_TO_EXPOSE = ['_unary_op' , '_binary_op' , '_inplace_binary_op' ]
261+ _ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
262+
263+ # TODO methods which should not or cannot act over the whole tree, such as .to_array
264+
265+ methods_to_expose = [(method_name , getattr (Dataset , method_name ))
266+ for method_name in _ALL_DATASET_METHODS_TO_EXPOSE ]
267+ _wrap_then_attach_to_cls (vars (), methods_to_expose , wrap_func = map_over_subtree )
266268
267269
268270# TODO implement ArrayReduce type methods
269271
270272
271- class DataTree (TreeNode , DatasetPropertiesMixin , DatasetMethodsMixin ):
273+ class DataTree (TreeNode , DatasetPropertiesMixin , MappedDatasetMethodsMixin ):
272274 """
273275 A tree-like hierarchical collection of xarray objects.
274276
@@ -339,15 +341,6 @@ def __init__(
339341 new_node = self .get_node (path )
340342 new_node [path ] = data
341343
342- # TODO this has to be
343- self ._add_all_dataset_api ()
344-
345- def _add_all_dataset_api (self ):
346- # Add methods like .isel(), but wrapped to map over subtrees
347- self ._add_dataset_methods ()
348-
349- # TODO add dataset ops here
350-
351344 @property
352345 def ds (self ) -> Dataset :
353346 return self ._ds
@@ -396,9 +389,6 @@ def _init_single_datatree_node(
396389 obj = object .__new__ (cls )
397390 obj = _init_single_treenode (obj , name = name , parent = parent , children = children )
398391 obj .ds = data
399-
400- obj ._add_all_dataset_api ()
401-
402392 return obj
403393
404394 def __str__ (self ):
@@ -435,7 +425,7 @@ def _single_node_repr(self):
435425 def __repr__ (self ):
436426 """Information about this node, including its relationships to other nodes."""
437427 # TODO redo this to look like the Dataset repr, but just with child and parent info
438- parent = self .parent .name if self .parent else "None"
428+ parent = self .parent .name if self .parent is not None else "None"
439429 node_str = f"DataNode(name='{ self .name } ', parent='{ parent } ', children={ [c .name for c in self .children ]} ,"
440430
441431 if self .has_data :
@@ -554,7 +544,7 @@ def __setitem__(
554544 except anytree .resolver .ResolverError :
555545 existing_node = None
556546
557- if existing_node :
547+ if existing_node is not None :
558548 if isinstance (value , Dataset ):
559549 # replace whole dataset
560550 existing_node .ds = Dataset
0 commit comments