@@ -439,7 +439,7 @@ def map( # type: ignore[override]
439439
440440
441441class DataTree (
442- NamedNode [ "DataTree" ] ,
442+ NamedNode ,
443443 DataTreeAggregations ,
444444 DataTreeOpsMixin ,
445445 TreeAttrAccessMixin ,
@@ -559,9 +559,12 @@ def _node_coord_variables_with_index(self) -> Mapping[Hashable, Variable]:
559559
560560 @property
561561 def _coord_variables (self ) -> ChainMap [Hashable , Variable ]:
562+ # ChainMap is incorrected typed in typeshed (only the first argument
563+ # needs to be mutable)
564+ # https://github.com/python/typeshed/issues/8430
562565 return ChainMap (
563566 self ._node_coord_variables ,
564- * (p ._node_coord_variables_with_index for p in self .parents ),
567+ * (p ._node_coord_variables_with_index for p in self .parents ), # type: ignore[arg-type]
565568 )
566569
567570 @property
@@ -1340,7 +1343,7 @@ def equals(self, other: DataTree) -> bool:
13401343 )
13411344
13421345 def _inherited_coords_set (self ) -> set [str ]:
1343- return set (self .parent .coords if self .parent else [])
1346+ return set (self .parent .coords if self .parent else []) # type: ignore[arg-type]
13441347
13451348 def identical (self , other : DataTree ) -> bool :
13461349 """
@@ -1563,9 +1566,33 @@ def match(self, pattern: str) -> DataTree:
15631566 }
15641567 return DataTree .from_dict (matching_nodes , name = self .name )
15651568
1569+ @overload
15661570 def map_over_datasets (
15671571 self ,
1568- func : Callable ,
1572+ func : Callable [..., Dataset | None ],
1573+ * args : Any ,
1574+ kwargs : Mapping [str , Any ] | None = None ,
1575+ ) -> DataTree : ...
1576+
1577+ @overload
1578+ def map_over_datasets (
1579+ self ,
1580+ func : Callable [..., tuple [Dataset | None , Dataset | None ]],
1581+ * args : Any ,
1582+ kwargs : Mapping [str , Any ] | None = None ,
1583+ ) -> tuple [DataTree , DataTree ]: ...
1584+
1585+ @overload
1586+ def map_over_datasets (
1587+ self ,
1588+ func : Callable [..., tuple [Dataset | None , ...]],
1589+ * args : Any ,
1590+ kwargs : Mapping [str , Any ] | None = None ,
1591+ ) -> tuple [DataTree , ...]: ...
1592+
1593+ def map_over_datasets (
1594+ self ,
1595+ func : Callable [..., Dataset | None | tuple [Dataset | None , ...]],
15691596 * args : Any ,
15701597 kwargs : Mapping [str , Any ] | None = None ,
15711598 ) -> DataTree | tuple [DataTree , ...]:
@@ -1600,8 +1627,7 @@ def map_over_datasets(
16001627 map_over_datasets
16011628 """
16021629 # TODO this signature means that func has no way to know which node it is being called upon - change?
1603- # TODO fix this typing error
1604- return map_over_datasets (func , self , * args , kwargs = kwargs )
1630+ return map_over_datasets (func , self , * args , kwargs = kwargs ) # type: ignore[arg-type]
16051631
16061632 @overload
16071633 def pipe (
@@ -1695,7 +1721,7 @@ def groups(self):
16951721
16961722 def _unary_op (self , f , * args , ** kwargs ) -> DataTree :
16971723 # TODO do we need to any additional work to avoid duplication etc.? (Similar to aggregations)
1698- return self .map_over_datasets (functools .partial (f , ** kwargs ), * args ) # type: ignore[return-value]
1724+ return self .map_over_datasets (functools .partial (f , ** kwargs ), * args )
16991725
17001726 def _binary_op (self , other , f , reflexive = False , join = None ) -> DataTree :
17011727 from xarray .core .groupby import GroupBy
@@ -1911,7 +1937,7 @@ def to_zarr(
19111937 )
19121938
19131939 def _get_all_dims (self ) -> set :
1914- all_dims = set ()
1940+ all_dims : set [ Any ] = set ()
19151941 for node in self .subtree :
19161942 all_dims .update (node ._node_dims )
19171943 return all_dims
0 commit comments