Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ v2025.10.2 (unreleased)
New Features
~~~~~~~~~~~~

- :py:func:`merge` and :py:func:`concat` now support :py:class:`DataTree`
objects (:issue:`9790`, :issue:`9778`).
- :py:func:`merge`, :py:func:`concat` and :py:func:`combine_nested` now
support :py:class:`DataTree` objects (:issue:`9790`, :issue:`9778`,
:pull:`10849`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- The ``h5netcdf`` engine has support for pseudo ``NETCDF4_CLASSIC`` files, meaning variables and attributes are cast to supported types. Note that the saved files won't be recognized as genuine ``NETCDF4_CLASSIC`` files until ``h5netcdf`` adds support with version 1.7.0. (:issue:`10676`, :pull:`10686`).
By `David Huard <https://github.com/huard>`_.
Expand Down
145 changes: 109 additions & 36 deletions xarray/structure/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from collections import Counter, defaultdict
from collections.abc import Callable, Hashable, Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, Literal, TypeVar, Union, cast
from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar, cast, overload

import pandas as pd

from xarray.core import dtypes
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.utils import iterate_nested
from xarray.structure.alignment import AlignmentError
from xarray.structure.concat import concat
Expand Down Expand Up @@ -96,27 +97,28 @@ def _ensure_same_types(series, dim):
raise TypeError(error_msg)


def _infer_concat_order_from_coords(datasets):
def _infer_concat_order_from_coords(datasets: list[Dataset] | list[DataTree]):
concat_dims = []
tile_ids = [() for ds in datasets]
tile_ids: list[tuple[int, ...]] = [() for ds in datasets]

# All datasets have same variables because they've been grouped as such
ds0 = datasets[0]
for dim in ds0.dims:
# Check if dim is a coordinate dimension
if dim in ds0:
# Need to read coordinate values to do ordering
indexes = [ds._indexes.get(dim) for ds in datasets]
if any(index is None for index in indexes):
error_msg = (
f"Every dimension requires a corresponding 1D coordinate "
f"and index for inferring concatenation order but the "
f"coordinate '{dim}' has no corresponding index"
)
raise ValueError(error_msg)

# TODO (benbovy, flexible indexes): support flexible indexes?
indexes = [index.to_pandas_index() for index in indexes]
indexes: list[pd.Index] = []
for ds in datasets:
index = ds._indexes.get(dim)
if index is None:
error_msg = (
f"Every dimension requires a corresponding 1D coordinate "
f"and index for inferring concatenation order but the "
f"coordinate '{dim}' has no corresponding index"
)
raise ValueError(error_msg)
# TODO (benbovy, flexible indexes): support flexible indexes?
indexes.append(index.to_pandas_index())

# If dimension coordinate values are same on every dataset then
# should be leaving this dimension alone (it's just a "bystander")
Expand Down Expand Up @@ -153,7 +155,7 @@ def _infer_concat_order_from_coords(datasets):
rank = series.rank(
method="dense", ascending=ascending, numeric_only=False
)
order = rank.astype(int).values - 1
order = (rank.astype(int).values - 1).tolist()

# Append positions along extra dimension to structure which
# encodes the multi-dimensional concatenation order
Expand All @@ -163,10 +165,16 @@ def _infer_concat_order_from_coords(datasets):
]

if len(datasets) > 1 and not concat_dims:
raise ValueError(
"Could not find any dimension coordinates to use to "
"order the datasets for concatenation"
)
if any(isinstance(data, DataTree) for data in datasets):
raise ValueError(
"Did not find any dimension coordinates at root nodes "
"to order the DataTree objects for concatenation"
)
else:
raise ValueError(
"Could not find any dimension coordinates to use to "
"order the Dataset objects for concatenation"
)

combined_ids = dict(zip(tile_ids, datasets, strict=True))

Expand Down Expand Up @@ -224,7 +232,7 @@ def _combine_nd(

Parameters
----------
combined_ids : Dict[Tuple[int, ...]], xarray.Dataset]
combined_ids : Dict[Tuple[int, ...]], xarray.Dataset | xarray.DataTree]
Structure containing all datasets to be concatenated with "tile_IDs" as
keys, which specify position within the desired final combined result.
concat_dims : sequence of str
Expand All @@ -235,7 +243,7 @@ def _combine_nd(

Returns
-------
combined_ds : xarray.Dataset
combined_ds : xarray.Dataset | xarray.DataTree
"""

example_tile_id = next(iter(combined_ids.keys()))
Expand Down Expand Up @@ -399,20 +407,74 @@ def _nested_combine(
return combined


# Define type for arbitrarily-nested list of lists recursively:
DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]]
# Define types for arbitrarily-nested list of lists.
# Mypy doesn't seem to handle overloads properly with recursive types, so we
# explicitly expand the first handful of levels of recursion.
DatasetLike: TypeAlias = DataArray | Dataset
DatasetHyperCube: TypeAlias = (
DatasetLike
| Sequence[DatasetLike]
| Sequence[Sequence[DatasetLike]]
| Sequence[Sequence[Sequence[DatasetLike]]]
| Sequence[Sequence[Sequence[Sequence[DatasetLike]]]]
)
DataTreeHyperCube: TypeAlias = (
DataTree
| Sequence[DataTree]
| Sequence[Sequence[DataTree]]
| Sequence[Sequence[Sequence[DataTree]]]
| Sequence[Sequence[Sequence[Sequence[DataTree]]]]
)


@overload
def combine_nested(
datasets: DatasetHyperCube,
concat_dim: str
| DataArray
| list[str]
| Sequence[str | DataArray | pd.Index | None]
| None,
compat: str | CombineKwargDefault = ...,
data_vars: str | CombineKwargDefault = ...,
coords: str | CombineKwargDefault = ...,
fill_value: object = ...,
join: JoinOptions | CombineKwargDefault = ...,
combine_attrs: CombineAttrsOptions = ...,
) -> Dataset: ...


@overload
def combine_nested(
datasets: DataTreeHyperCube,
concat_dim: str
| DataArray
| list[str]
| Sequence[str | DataArray | pd.Index | None]
| None,
compat: str | CombineKwargDefault = ...,
data_vars: str | CombineKwargDefault = ...,
coords: str | CombineKwargDefault = ...,
fill_value: object = ...,
join: JoinOptions | CombineKwargDefault = ...,
combine_attrs: CombineAttrsOptions = ...,
) -> DataTree: ...


def combine_nested(
datasets: DATASET_HYPERCUBE,
concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None] | None,
datasets: DatasetHyperCube | DataTreeHyperCube,
concat_dim: str
| DataArray
| list[str]
| Sequence[str | DataArray | pd.Index | None]
| None,
compat: str | CombineKwargDefault = _COMPAT_DEFAULT,
data_vars: str | CombineKwargDefault = _DATA_VARS_DEFAULT,
coords: str | CombineKwargDefault = _COORDS_DEFAULT,
fill_value: object = dtypes.NA,
join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
combine_attrs: CombineAttrsOptions = "drop",
) -> Dataset:
) -> Dataset | DataTree:
"""
Explicitly combine an N-dimensional grid of datasets into one by using a
succession of concat and merge operations along each dimension of the grid.
Expand All @@ -433,7 +495,7 @@ def combine_nested(

Parameters
----------
datasets : list or nested list of Dataset
datasets : list or nested list of Dataset, DataArray or DataTree
Dataset objects to combine.
If concatenation or merging along more than one dimension is desired,
then datasets must be supplied in a nested list-of-lists.
Expand Down Expand Up @@ -527,7 +589,7 @@ def combine_nested(

Returns
-------
combined : xarray.Dataset
combined : xarray.Dataset or xarray.DataTree

Examples
--------
Expand Down Expand Up @@ -621,22 +683,29 @@ def combine_nested(
concat
merge
"""
mixed_datasets_and_arrays = any(
isinstance(obj, Dataset) for obj in iterate_nested(datasets)
) and any(
any_datasets = any(isinstance(obj, Dataset) for obj in iterate_nested(datasets))
any_unnamed_arrays = any(
isinstance(obj, DataArray) and obj.name is None
for obj in iterate_nested(datasets)
)
if mixed_datasets_and_arrays:
if any_datasets and any_unnamed_arrays:
raise ValueError("Can't combine datasets with unnamed arrays.")

if isinstance(concat_dim, str | DataArray) or concat_dim is None:
concat_dim = [concat_dim]
any_datatrees = any(isinstance(obj, DataTree) for obj in iterate_nested(datasets))
all_datatrees = all(isinstance(obj, DataTree) for obj in iterate_nested(datasets))
if any_datatrees and not all_datatrees:
raise ValueError("Can't combine a mix of DataTree and non-DataTree objects.")

concat_dims = (
[concat_dim]
if isinstance(concat_dim, str | DataArray) or concat_dim is None
else concat_dim
)

# The IDs argument tells _nested_combine that datasets aren't yet sorted
return _nested_combine(
datasets,
concat_dims=concat_dim,
concat_dims=concat_dims,
compat=compat,
data_vars=data_vars,
coords=coords,
Expand Down Expand Up @@ -988,6 +1057,10 @@ def combine_by_coords(
Finally, if you attempt to combine a mix of unnamed DataArrays with either named
DataArrays or Datasets, a ValueError will be raised (as this is an ambiguous operation).
"""
if any(isinstance(data_object, DataTree) for data_object in data_objects):
raise NotImplementedError(
"combine_by_coords() does not yet support DataTree objects."
)

if not data_objects:
return Dataset()
Expand Down Expand Up @@ -1018,7 +1091,7 @@ def combine_by_coords(
# Must be a mix of unnamed dataarrays with either named dataarrays or with datasets
# Can't combine these as we wouldn't know whether to merge or concatenate the arrays
raise ValueError(
"Can't automatically combine unnamed DataArrays with either named DataArrays or Datasets."
"Can't automatically combine unnamed DataArrays with named DataArrays or Datasets."
)
else:
# Promote any named DataArrays to single-variable Datasets to simplify combining
Expand Down
32 changes: 29 additions & 3 deletions xarray/tests/test_combine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
from itertools import product

import numpy as np
Expand All @@ -8,6 +9,7 @@
from xarray import (
DataArray,
Dataset,
DataTree,
MergeError,
combine_by_coords,
combine_nested,
Expand Down Expand Up @@ -624,8 +626,8 @@ def test_auto_combine_2d_combine_attrs_kwarg(self):
datasets,
concat_dim=["dim1", "dim2"],
data_vars="all",
combine_attrs=combine_attrs, # type: ignore[arg-type]
)
combine_attrs=combine_attrs,
) # type: ignore[call-overload]
assert_identical(result, expected)

def test_combine_nested_missing_data_new_dim(self):
Expand Down Expand Up @@ -764,7 +766,21 @@ def test_nested_combine_mixed_datasets_arrays(self):
with pytest.raises(
ValueError, match=r"Can't combine datasets with unnamed arrays."
):
combine_nested(objs, "x")
combine_nested(objs, "x") # type: ignore[arg-type]

def test_nested_combine_mixed_datatrees_and_datasets(self):
objs = [DataTree.from_dict({"foo": 0}), Dataset({"foo": 1})]
with pytest.raises(
ValueError,
match=r"Can't combine a mix of DataTree and non-DataTree objects.",
):
combine_nested(objs, concat_dim="x") # type: ignore[arg-type]

def test_datatree(self):
objs = [DataTree.from_dict({"foo": 0}), DataTree.from_dict({"foo": 1})]
expected = DataTree.from_dict({"foo": ("x", [0, 1])})
actual = combine_nested(objs, concat_dim="x")
assert expected.identical(actual)


class TestCombineDatasetsbyCoords:
Expand Down Expand Up @@ -1210,6 +1226,16 @@ def test_combine_by_coords_all_dataarrays_with_the_same_name(self):
expected = merge([named_da1, named_da2], compat="no_conflicts", join="outer")
assert_identical(expected, actual)

def test_combine_by_coords_datatree(self):
tree = DataTree.from_dict({"/nested/foo": ("x", [10])}, coords={"x": [1]})
with pytest.raises(
NotImplementedError,
match=re.escape(
"combine_by_coords() does not yet support DataTree objects."
),
):
combine_by_coords([tree]) # type: ignore[list-item]


class TestNewDefaults:
def test_concat_along_existing_dim(self):
Expand Down
Loading