Skip to content

Commit 31d540f

Browse files
mesejomax-sixty
andauthored
Closes #4647 DataArray transpose inconsistent with Dataset Ellipsis usage (#4767)
- Add missing_dims parameter to transpose to mimic isel behavior - Add missing_dims to infix_dims to make function consistent across different methods. Co-authored-by: Maximilian Roos <[email protected]>
1 parent 7298df0 commit 31d540f

File tree

8 files changed

+90
-26
lines changed

8 files changed

+90
-26
lines changed

doc/internals.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,4 +230,4 @@ re-open it directly with Zarr:
230230
231231
zgroup = zarr.open("rasm.zarr")
232232
print(zgroup.tree())
233-
dict(zgroup["Tair"].attrs)
233+
dict(zgroup["Tair"].attrs)

doc/plotting.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -955,4 +955,4 @@ One can also make line plots with multidimensional coordinates. In this case, ``
955955
f, ax = plt.subplots(2, 1)
956956
da.plot.line(x="lon", hue="y", ax=ax[0])
957957
@savefig plotting_example_2d_hue_xy.png
958-
da.plot.line(x="lon", hue="x", ax=ax[1])
958+
da.plot.line(x="lon", hue="x", ax=ax[1])

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ Bug fixes
5555
- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`).
5656
By `Alessandro Amici <https://github.com/alexamici>`_
5757
- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling <https://github.com/illviljan>`_.
58+
- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo <https://github.com/mesejo>`_.
5859

5960
Documentation
6061
~~~~~~~~~~~~~
@@ -76,6 +77,7 @@ Internal Changes
7677
- Run the tests in parallel using pytest-xdist (:pull:`4694`).
7778

7879
By `Justus Magin <https://github.com/keewis>`_ and `Mathias Hauser <https://github.com/mathause>`_.
80+
7981
- Replace all usages of ``assert x.identical(y)`` with ``assert_identical(x, y)``
8082
for clearer error messages.
8183
(:pull:`4752`);

xarray/core/dataarray.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,7 +2120,12 @@ def to_unstacked_dataset(self, dim, level=0):
21202120
# unstacked dataset
21212121
return Dataset(data_dict)
21222122

2123-
def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArray":
2123+
def transpose(
2124+
self,
2125+
*dims: Hashable,
2126+
transpose_coords: bool = True,
2127+
missing_dims: str = "raise",
2128+
) -> "DataArray":
21242129
"""Return a new DataArray object with transposed dimensions.
21252130
21262131
Parameters
@@ -2130,6 +2135,12 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArra
21302135
dimensions to this order.
21312136
transpose_coords : bool, default: True
21322137
If True, also transpose the coordinates of this DataArray.
2138+
missing_dims : {"raise", "warn", "ignore"}, default: "raise"
2139+
What to do if dimensions that should be selected from are not present in the
2140+
DataArray:
2141+
- "raise": raise an exception
2142+
- "warning": raise a warning, and ignore the missing dimensions
2143+
- "ignore": ignore the missing dimensions
21332144
21342145
Returns
21352146
-------
@@ -2148,7 +2159,7 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = True) -> "DataArra
21482159
Dataset.transpose
21492160
"""
21502161
if dims:
2151-
dims = tuple(utils.infix_dims(dims, self.dims))
2162+
dims = tuple(utils.infix_dims(dims, self.dims, missing_dims))
21522163
variable = self.variable.transpose(*dims)
21532164
if transpose_coords:
21542165
coords: Dict[Hashable, Variable] = {}

xarray/core/utils.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -744,28 +744,32 @@ def __len__(self) -> int:
744744
return len(self._data) - num_hidden
745745

746746

747-
def infix_dims(dims_supplied: Collection, dims_all: Collection) -> Iterator:
747+
def infix_dims(
748+
dims_supplied: Collection, dims_all: Collection, missing_dims: str = "raise"
749+
) -> Iterator:
748750
"""
749-
Resolves a supplied list containing an ellispsis representing other items, to
751+
Resolves a supplied list containing an ellipsis representing other items, to
750752
a generator with the 'realized' list of all items
751753
"""
752754
if ... in dims_supplied:
753755
if len(set(dims_all)) != len(dims_all):
754756
raise ValueError("Cannot use ellipsis with repeated dims")
755-
if len([d for d in dims_supplied if d == ...]) > 1:
757+
if list(dims_supplied).count(...) > 1:
756758
raise ValueError("More than one ellipsis supplied")
757759
other_dims = [d for d in dims_all if d not in dims_supplied]
758-
for d in dims_supplied:
759-
if d == ...:
760+
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
761+
for d in existing_dims:
762+
if d is ...:
760763
yield from other_dims
761764
else:
762765
yield d
763766
else:
764-
if set(dims_supplied) ^ set(dims_all):
767+
existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims)
768+
if set(existing_dims) ^ set(dims_all):
765769
raise ValueError(
766770
f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included"
767771
)
768-
yield from dims_supplied
772+
yield from existing_dims
769773

770774

771775
def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable:
@@ -805,7 +809,7 @@ def drop_dims_from_indexers(
805809
invalid = indexers.keys() - set(dims)
806810
if invalid:
807811
raise ValueError(
808-
f"dimensions {invalid} do not exist. Expected one or more of {dims}"
812+
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
809813
)
810814

811815
return indexers
@@ -818,7 +822,7 @@ def drop_dims_from_indexers(
818822
invalid = indexers.keys() - set(dims)
819823
if invalid:
820824
warnings.warn(
821-
f"dimensions {invalid} do not exist. Expected one or more of {dims}"
825+
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
822826
)
823827
for key in invalid:
824828
indexers.pop(key)
@@ -834,6 +838,48 @@ def drop_dims_from_indexers(
834838
)
835839

836840

841+
def drop_missing_dims(
842+
supplied_dims: Collection, dims: Collection, missing_dims: str
843+
) -> Collection:
844+
"""Depending on the setting of missing_dims, drop any dimensions from supplied_dims that
845+
are not present in dims.
846+
847+
Parameters
848+
----------
849+
supplied_dims : dict
850+
dims : sequence
851+
missing_dims : {"raise", "warn", "ignore"}
852+
"""
853+
854+
if missing_dims == "raise":
855+
supplied_dims_set = set(val for val in supplied_dims if val is not ...)
856+
invalid = supplied_dims_set - set(dims)
857+
if invalid:
858+
raise ValueError(
859+
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
860+
)
861+
862+
return supplied_dims
863+
864+
elif missing_dims == "warn":
865+
866+
invalid = set(supplied_dims) - set(dims)
867+
if invalid:
868+
warnings.warn(
869+
f"Dimensions {invalid} do not exist. Expected one or more of {dims}"
870+
)
871+
872+
return [val for val in supplied_dims if val in dims or val is ...]
873+
874+
elif missing_dims == "ignore":
875+
return [val for val in supplied_dims if val in dims or val is ...]
876+
877+
else:
878+
raise ValueError(
879+
f"Unrecognised option {missing_dims} for missing_dims argument"
880+
)
881+
882+
837883
class UncachedAccessor:
838884
"""Acts like a property, but on both classes and class instances
839885

xarray/tests/test_dataarray.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -797,13 +797,13 @@ def test_isel(self):
797797
assert_identical(self.dv[:3, :5], self.dv.isel(x=slice(3), y=slice(5)))
798798
with raises_regex(
799799
ValueError,
800-
r"dimensions {'not_a_dim'} do not exist. Expected "
800+
r"Dimensions {'not_a_dim'} do not exist. Expected "
801801
r"one or more of \('x', 'y'\)",
802802
):
803803
self.dv.isel(not_a_dim=0)
804804
with pytest.warns(
805805
UserWarning,
806-
match=r"dimensions {'not_a_dim'} do not exist. "
806+
match=r"Dimensions {'not_a_dim'} do not exist. "
807807
r"Expected one or more of \('x', 'y'\)",
808808
):
809809
self.dv.isel(not_a_dim=0, missing_dims="warn")
@@ -2231,9 +2231,21 @@ def test_transpose(self):
22312231
actual = da.transpose("z", ..., "x", transpose_coords=True)
22322232
assert_equal(expected, actual)
22332233

2234+
# same as previous but with a missing dimension
2235+
actual = da.transpose(
2236+
"z", "y", "x", "not_a_dim", transpose_coords=True, missing_dims="ignore"
2237+
)
2238+
assert_equal(expected, actual)
2239+
22342240
with pytest.raises(ValueError):
22352241
da.transpose("x", "y")
22362242

2243+
with pytest.raises(ValueError):
2244+
da.transpose("not_a_dim", "z", "x", ...)
2245+
2246+
with pytest.warns(UserWarning):
2247+
da.transpose("not_a_dim", "y", "x", ..., missing_dims="warn")
2248+
22372249
def test_squeeze(self):
22382250
assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable)
22392251

@@ -6227,7 +6239,6 @@ def da_dask(seed=123):
62276239

62286240
@pytest.mark.parametrize("da", ("repeating_ints",), indirect=True)
62296241
def test_isin(da):
6230-
62316242
expected = DataArray(
62326243
np.asarray([[0, 0, 0], [1, 0, 0]]),
62336244
dims=list("yx"),
@@ -6277,7 +6288,6 @@ def test_coarsen_keep_attrs():
62776288

62786289
@pytest.mark.parametrize("da", (1, 2), indirect=True)
62796290
def test_rolling_iter(da):
6280-
62816291
rolling_obj = da.rolling(time=7)
62826292
rolling_obj_mean = rolling_obj.mean()
62836293

@@ -6452,7 +6462,6 @@ def test_rolling_construct(center, window):
64526462
@pytest.mark.parametrize("window", (1, 2, 3, 4))
64536463
@pytest.mark.parametrize("name", ("sum", "mean", "std", "max"))
64546464
def test_rolling_reduce(da, center, min_periods, window, name):
6455-
64566465
if min_periods is not None and window < min_periods:
64576466
min_periods = window
64586467

@@ -6491,7 +6500,6 @@ def test_rolling_reduce_nonnumeric(center, min_periods, window, name):
64916500

64926501

64936502
def test_rolling_count_correct():
6494-
64956503
da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time")
64966504

64976505
kwargs = [
@@ -6579,7 +6587,6 @@ def test_ndrolling_construct(center, fill_value):
65796587
],
65806588
)
65816589
def test_rolling_keep_attrs(funcname, argument):
6582-
65836590
attrs_da = {"da_attr": "test"}
65846591

65856592
data = np.linspace(10, 15, 100)
@@ -6623,7 +6630,6 @@ def test_rolling_keep_attrs(funcname, argument):
66236630

66246631

66256632
def test_rolling_keep_attrs_deprecated():
6626-
66276633
attrs_da = {"da_attr": "test"}
66286634

66296635
data = np.linspace(10, 15, 100)
@@ -6957,7 +6963,6 @@ def test_rolling_exp(da, dim, window_type, window):
69576963

69586964
@requires_numbagg
69596965
def test_rolling_exp_keep_attrs(da):
6960-
69616966
attrs = {"attrs": "da"}
69626967
da.attrs = attrs
69636968

xarray/tests/test_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,14 +1024,14 @@ def test_isel(self):
10241024
data.isel(not_a_dim=slice(0, 2))
10251025
with raises_regex(
10261026
ValueError,
1027-
r"dimensions {'not_a_dim'} do not exist. Expected "
1027+
r"Dimensions {'not_a_dim'} do not exist. Expected "
10281028
r"one or more of "
10291029
r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*",
10301030
):
10311031
data.isel(not_a_dim=slice(0, 2))
10321032
with pytest.warns(
10331033
UserWarning,
1034-
match=r"dimensions {'not_a_dim'} do not exist. "
1034+
match=r"Dimensions {'not_a_dim'} do not exist. "
10351035
r"Expected one or more of "
10361036
r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*",
10371037
):

xarray/tests/test_variable.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,13 +1270,13 @@ def test_isel(self):
12701270
assert_identical(v.isel(time=[]), v[[]])
12711271
with raises_regex(
12721272
ValueError,
1273-
r"dimensions {'not_a_dim'} do not exist. Expected one or more of "
1273+
r"Dimensions {'not_a_dim'} do not exist. Expected one or more of "
12741274
r"\('time', 'x'\)",
12751275
):
12761276
v.isel(not_a_dim=0)
12771277
with pytest.warns(
12781278
UserWarning,
1279-
match=r"dimensions {'not_a_dim'} do not exist. Expected one or more of "
1279+
match=r"Dimensions {'not_a_dim'} do not exist. Expected one or more of "
12801280
r"\('time', 'x'\)",
12811281
):
12821282
v.isel(not_a_dim=0, missing_dims="warn")

0 commit comments

Comments
 (0)