Commit f0f838c

Merge branch 'main' into custom-groupers
* main:
  fix cf decoding of grid_mapping (pydata#9765)
  Allow wrapping `np.ndarray` subclasses (pydata#9760)
  Optimize polyfit (pydata#9766)
  Use `map_overlap` for rolling reductions with Dask (pydata#9770)
  fix html repr indexes section (pydata#9768)
2 parents 2512d53 + e674286 commit f0f838c

15 files changed (+309, -95 lines)

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
@@ -29,6 +29,11 @@ New Features
 - Support lazy grouping by dask arrays, and allow specifying ordered groups with ``UniqueGrouper(labels=["a", "b", "c"])``
   (:issue:`2852`, :issue:`757`).
   By `Deepak Cherian <https://github.com/dcherian>`_.
+- Allow wrapping ``np.ndarray`` subclasses, e.g. ``astropy.units.Quantity`` (:issue:`9704`, :pull:`9760`).
+  By `Sam Levang <https://github.com/slevang>`_ and `Tien Vo <https://github.com/tien-vo>`_.
+- Optimize :py:meth:`DataArray.polyfit` and :py:meth:`Dataset.polyfit` with dask, when used with
+  arrays with more than two dimensions.
+  (:issue:`5629`). By `Deepak Cherian <https://github.com/dcherian>`_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
@@ -53,6 +58,8 @@ Bug fixes
   By `Stephan Hoyer <https://github.com/shoyer>`_.
 - Fix regression in the interoperability of :py:meth:`DataArray.polyfit` and :py:meth:`xr.polyval` for date-time coordinates. (:pull:`9691`).
   By `Pascal Bourgault <https://github.com/aulemahal>`_.
+- Fix CF decoding of ``grid_mapping`` to allow all possible formats, add tests (:issue:`9761`, :pull:`9765`).
+  By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
 
 Documentation
 ~~~~~~~~~~~~~

xarray/conventions.py

Lines changed: 38 additions & 13 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import itertools
 from collections import defaultdict
 from collections.abc import Hashable, Iterable, Mapping, MutableMapping
 from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union
@@ -31,6 +32,7 @@
     "formula_terms",
 )
 CF_RELATED_DATA_NEEDS_PARSING = (
+    "grid_mapping",
     "cell_measures",
     "formula_terms",
 )
@@ -476,18 +478,41 @@ def stackable(dim: Hashable) -> bool:
     if decode_coords == "all":
         for attr_name in CF_RELATED_DATA:
             if attr_name in var_attrs:
-                attr_val = var_attrs[attr_name]
-                if attr_name not in CF_RELATED_DATA_NEEDS_PARSING:
-                    var_names = attr_val.split()
-                else:
-                    roles_and_names = [
-                        role_or_name
-                        for part in attr_val.split(":")
-                        for role_or_name in part.split()
-                    ]
-                    if len(roles_and_names) % 2 == 1:
-                        emit_user_level_warning(f"Attribute {attr_name} malformed")
-                    var_names = roles_and_names[1::2]
+                # fixes stray colon
+                attr_val = var_attrs[attr_name].replace(" :", ":")
+                var_names = attr_val.split()
+                # if grid_mapping is a single string, do not enter here
+                if (
+                    attr_name in CF_RELATED_DATA_NEEDS_PARSING
+                    and len(var_names) > 1
+                ):
+                    # map the keys to list of strings
+                    # "A: b c d E: f g" returns
+                    # {"A": ["b", "c", "d"], "E": ["f", "g"]}
+                    roles_and_names = defaultdict(list)
+                    key = None
+                    for vname in var_names:
+                        if ":" in vname:
+                            key = vname.strip(":")
+                        else:
+                            if key is None:
+                                raise ValueError(
+                                    f"First element {vname!r} of [{attr_val!r}] misses ':', "
+                                    f"cannot decode {attr_name!r}."
+                                )
+                            roles_and_names[key].append(vname)
+                    # for grid_mapping keys are var_names
+                    if attr_name == "grid_mapping":
+                        var_names = list(roles_and_names.keys())
+                    else:
+                        # for cell_measures and formula_terms values are var names
+                        var_names = list(itertools.chain(*roles_and_names.values()))
+                    # consistency check (one element per key)
+                    if len(var_names) != len(roles_and_names.keys()):
+                        emit_user_level_warning(
+                            f"Attribute {attr_name!r} has malformed content [{attr_val!r}], "
+                            f"decoding {var_names!r} to coordinates."
+                        )
                 if all(var_name in variables for var_name in var_names):
                     new_vars[k].encoding[attr_name] = attr_val
                     coord_names.update(var_names)
@@ -732,7 +757,7 @@ def _encode_coordinates(
     # the dataset faithfully. Because this serialization goes beyond CF
    # conventions, only do it if necessary.
     # Reference discussion:
-    # http://mailman.cgd.ucar.edu/pipermail/cf-metadata/2014/007571.html
+    # https://cfconventions.org/mailing-list-archive/Data/7400.html
     global_coordinates.difference_update(written_coords)
     if global_coordinates:
         attributes = dict(attributes)
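
For context, the new parsing turns a multi-entry CF attribute such as `grid_mapping = "crs_wgs84: lat lon crs_utm: x y"` into a role-to-names mapping. Below is a minimal standalone sketch of that logic; the attribute value and variable names are invented for illustration, the real implementation lives in `decode_cf_variables` above.

import itertools
from collections import defaultdict


def parse_cf_roles(attr_val: str) -> dict[str, list[str]]:
    """Split 'A: b c d E: f g' into {'A': ['b', 'c', 'd'], 'E': ['f', 'g']}."""
    attr_val = attr_val.replace(" :", ":")  # fix a stray colon, as above
    roles_and_names: dict[str, list[str]] = defaultdict(list)
    key = None
    for token in attr_val.split():
        if ":" in token:
            key = token.strip(":")
        else:
            if key is None:
                raise ValueError(f"First element {token!r} misses ':'")
            roles_and_names[key].append(token)
    return dict(roles_and_names)


parsed = parse_cf_roles("crs_wgs84: lat lon crs_utm: x y")
# For grid_mapping the *keys* are the grid-mapping variable names ...
print(list(parsed.keys()))                      # ['crs_wgs84', 'crs_utm']
# ... while for cell_measures/formula_terms the *values* are the variable names.
print(list(itertools.chain(*parsed.values())))  # ['lat', 'lon', 'x', 'y']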

xarray/core/dask_array_compat.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+from typing import Any
+
+from xarray.namedarray.utils import module_available
+
+
+def reshape_blockwise(
+    x: Any,
+    shape: int | tuple[int, ...],
+    chunks: tuple[tuple[int, ...], ...] | None = None,
+):
+    if module_available("dask", "2024.08.2"):
+        from dask.array import reshape_blockwise
+
+        return reshape_blockwise(x, shape=shape, chunks=chunks)
+    else:
+        return x.reshape(shape)
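
A rough usage sketch of this compat shim (array shape and chunking invented for illustration): with a new enough dask it reshapes each block in place, avoiding the cross-chunk shuffle that a plain `reshape` can trigger; on older dask it silently falls back to `x.reshape`.

import dask.array as da

from xarray.core.dask_array_compat import reshape_blockwise

x = da.ones((10, 4, 3), chunks=(5, 4, 3))
# merge the trailing dimensions block by block: (10, 4, 3) -> (10, 12)
y = reshape_blockwise(x, (10, 12))
print(y.shape)   # (10, 12)
print(y.chunks)  # ((5, 5), (12,))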

xarray/core/dask_array_ops.py

Lines changed: 38 additions & 18 deletions
@@ -1,34 +1,43 @@
 from __future__ import annotations
 
+import math
+
 from xarray.core import dtypes, nputils
 
 
 def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1):
     """Wrapper to apply bottleneck moving window funcs on dask arrays"""
-    import dask.array as da
-
-    dtype, fill_value = dtypes.maybe_promote(a.dtype)
-    a = a.astype(dtype)
-    # inputs for overlap
-    if axis < 0:
-        axis = a.ndim + axis
-    depth = {d: 0 for d in range(a.ndim)}
-    depth[axis] = (window + 1) // 2
-    boundary = {d: fill_value for d in range(a.ndim)}
-    # Create overlap array.
-    ag = da.overlap.overlap(a, depth=depth, boundary=boundary)
-    # apply rolling func
-    out = da.map_blocks(
-        moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype
+    dtype, _ = dtypes.maybe_promote(a.dtype)
+    return a.data.map_overlap(
+        moving_func,
+        depth={axis: (window - 1, 0)},
+        axis=axis,
+        dtype=dtype,
+        window=window,
+        min_count=min_count,
     )
-    # trim array
-    result = da.overlap.trim_internal(out, depth)
-    return result
 
 
 def least_squares(lhs, rhs, rcond=None, skipna=False):
     import dask.array as da
 
+    from xarray.core.dask_array_compat import reshape_blockwise
+
+    # The trick here is that the core dimension is axis 0.
+    # All other dimensions need to be reshaped down to one axis for `lstsq`
+    # (which only accepts 2D input)
+    # and this needs to be undone after running `lstsq`
+    # The order of values in the reshaped axes is irrelevant.
+    # There are big gains to be had by simply reshaping the blocks on a blockwise
+    # basis, and then undoing that transform.
+    # We use a specific `reshape_blockwise` method in dask for this optimization
+    if rhs.ndim > 2:
+        out_shape = rhs.shape
+        reshape_chunks = rhs.chunks
+        rhs = reshape_blockwise(rhs, (rhs.shape[0], math.prod(rhs.shape[1:])))
+    else:
+        out_shape = None
+
     lhs_da = da.from_array(lhs, chunks=(rhs.chunks[0], lhs.shape[1]))
     if skipna:
         added_dim = rhs.ndim == 1
@@ -52,6 +61,17 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
     # Residuals here are (1, 1) but should be (K,) as rhs is (N, K)
     # See issue dask/dask#6516
     coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs)
+
+    if out_shape is not None:
+        coeffs = reshape_blockwise(
+            coeffs,
+            shape=(coeffs.shape[0], *out_shape[1:]),
+            chunks=((coeffs.shape[0],), *reshape_chunks[1:]),
+        )
+        residuals = reshape_blockwise(
+            residuals, shape=out_shape[1:], chunks=reshape_chunks[1:]
+        )
+
     return coeffs, residuals
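
The reshape round-trip above is easiest to see with plain NumPy standing in for the blockwise dask version (array sizes invented for illustration): collapse everything but the fit axis into one column axis, run `lstsq` once, then restore the original trailing shape on the outputs.

import math

import numpy as np

rhs = np.random.rand(100, 4, 5)              # (fit_dim, y, x)
lhs = np.vander(np.linspace(0, 1, 100), 3)   # (fit_dim, order)

out_shape = rhs.shape
rhs_2d = rhs.reshape(rhs.shape[0], math.prod(rhs.shape[1:]))   # (100, 20)
coeffs, residuals, _, _ = np.linalg.lstsq(lhs, rhs_2d, rcond=None)

coeffs = coeffs.reshape(coeffs.shape[0], *out_shape[1:])       # (3, 4, 5)
residuals = residuals.reshape(*out_shape[1:])                  # (4, 5)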

xarray/core/dataset.py

Lines changed: 50 additions & 46 deletions
@@ -9086,15 +9086,14 @@ def polyfit(
         numpy.polyval
         xarray.polyval
         """
-        from xarray.core.dataarray import DataArray
-
-        variables = {}
+        variables: dict[Hashable, Variable] = {}
         skipna_da = skipna
 
         x = np.asarray(_ensure_numeric(self.coords[dim]).astype(np.float64))
 
         xname = f"{self[dim].name}_"
         order = int(deg) + 1
+        degree_coord_values = np.arange(order)[::-1]
         lhs = np.vander(x, order)
 
         if rcond is None:
@@ -9120,46 +9119,48 @@ def polyfit(
         rank = np.linalg.matrix_rank(lhs)
 
         if full:
-            rank = DataArray(rank, name=xname + "matrix_rank")
-            variables[rank.name] = rank
+            rank = Variable(dims=(), data=rank)
+            variables[xname + "matrix_rank"] = rank
             _sing = np.linalg.svd(lhs, compute_uv=False)
-            sing = DataArray(
-                _sing,
+            variables[xname + "singular_values"] = Variable(
                 dims=(degree_dim,),
-                coords={degree_dim: np.arange(rank - 1, -1, -1)},
-                name=xname + "singular_values",
+                data=np.concatenate([np.full((order - rank.data,), np.nan), _sing]),
             )
-            variables[sing.name] = sing
 
         # If we have a coordinate get its underlying dimension.
-        true_dim = self.coords[dim].dims[0]
+        (true_dim,) = self.coords[dim].dims
 
-        for name, da in self.data_vars.items():
-            if true_dim not in da.dims:
+        other_coords = {
+            dim: self._variables[dim]
+            for dim in set(self.dims) - {true_dim}
+            if dim in self._variables
+        }
+        present_dims: set[Hashable] = set()
+        for name, var in self._variables.items():
+            if name in self._coord_names or name in self.dims:
+                continue
+            if true_dim not in var.dims:
                 continue
 
-            if is_duck_dask_array(da.data) and (
+            if is_duck_dask_array(var._data) and (
                 rank != order or full or skipna is None
             ):
                 # Current algorithm with dask and skipna=False neither supports
                 # deficient ranks nor does it output the "full" info (issue dask/dask#6516)
                 skipna_da = True
             elif skipna is None:
-                skipna_da = bool(np.any(da.isnull()))
-
-            dims_to_stack = [dimname for dimname in da.dims if dimname != true_dim]
-            stacked_coords: dict[Hashable, DataArray] = {}
-            if dims_to_stack:
-                stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked")
-                rhs = da.transpose(true_dim, *dims_to_stack).stack(
-                    {stacked_dim: dims_to_stack}
-                )
-                stacked_coords = {stacked_dim: rhs[stacked_dim]}
-                scale_da = scale[:, np.newaxis]
+                skipna_da = bool(np.any(var.isnull()))
+
+            if var.ndim > 1:
+                rhs = var.transpose(true_dim, ...)
+                other_dims = rhs.dims[1:]
+                scale_da = scale.reshape(-1, *((1,) * len(other_dims)))
             else:
-                rhs = da
+                rhs = var
                 scale_da = scale
+                other_dims = ()
 
+            present_dims.update(other_dims)
             if w is not None:
                 rhs = rhs * w[:, np.newaxis]
 
@@ -9179,42 +9180,45 @@ def polyfit(
             # Thus a ReprObject => polyfit was called on a DataArray
             name = ""
 
-            coeffs = DataArray(
-                coeffs / scale_da,
-                dims=[degree_dim] + list(stacked_coords.keys()),
-                coords={degree_dim: np.arange(order)[::-1], **stacked_coords},
-                name=name + "polyfit_coefficients",
+            variables[name + "polyfit_coefficients"] = Variable(
+                data=coeffs / scale_da, dims=(degree_dim,) + other_dims
             )
-            if dims_to_stack:
-                coeffs = coeffs.unstack(stacked_dim)
-            variables[coeffs.name] = coeffs
 
             if full or (cov is True):
-                residuals = DataArray(
-                    residuals if dims_to_stack else residuals.squeeze(),
-                    dims=list(stacked_coords.keys()),
-                    coords=stacked_coords,
-                    name=name + "polyfit_residuals",
+                variables[name + "polyfit_residuals"] = Variable(
+                    data=residuals if var.ndim > 1 else residuals.squeeze(),
+                    dims=other_dims,
                 )
-                if dims_to_stack:
-                    residuals = residuals.unstack(stacked_dim)
-                variables[residuals.name] = residuals
 
             if cov:
                 Vbase = np.linalg.inv(np.dot(lhs.T, lhs))
                 Vbase /= np.outer(scale, scale)
+                if TYPE_CHECKING:
+                    fac: int | Variable
                 if cov == "unscaled":
                     fac = 1
                 else:
                     if x.shape[0] <= order:
                         raise ValueError(
                             "The number of data points must exceed order to scale the covariance matrix."
                         )
-                    fac = residuals / (x.shape[0] - order)
-                covariance = DataArray(Vbase, dims=("cov_i", "cov_j")) * fac
-                variables[name + "polyfit_covariance"] = covariance
+                    fac = variables[name + "polyfit_residuals"] / (x.shape[0] - order)
+                variables[name + "polyfit_covariance"] = (
+                    Variable(data=Vbase, dims=("cov_i", "cov_j")) * fac
+                )
 
-        return type(self)(data_vars=variables, attrs=self.attrs.copy())
+        return type(self)(
+            data_vars=variables,
+            coords={
+                degree_dim: degree_coord_values,
+                **{
+                    name: coord
+                    for name, coord in other_coords.items()
+                    if name in present_dims
+                },
+            },
+            attrs=self.attrs.copy(),
+        )
 
     def pad(
         self,
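
A hypothetical call pattern showing what the rewritten method produces (dimension names, sizes, and chunking invented for illustration); the non-fit dimensions now flow through `Variable` objects and blockwise reshapes instead of a temporary stacked dimension:

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"temperature": (("time", "y", "x"), np.random.rand(100, 20, 30))},
    coords={"time": np.arange(100)},
).chunk({"y": 10, "x": 10})

# skipna=False keeps full-rank dask input on the direct lstsq path
fit = ds.polyfit(dim="time", deg=2, skipna=False)
print(fit.temperature_polyfit_coefficients.dims)  # ('degree', 'y', 'x')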

xarray/core/formatting_html.py

Lines changed: 3 additions & 1 deletion
@@ -155,7 +155,9 @@ def summarize_index(coord_names, index) -> str:
     return (
         f"<div class='xr-index-name'><div>{name}</div></div>"
         f"<div class='xr-index-preview'>{preview}</div>"
-        f"<div></div>"
+        # need empty input + label here to conform to the fixed CSS grid layout
+        f"<input type='checkbox' disabled/>"
+        f"<label></label>"
         f"<input id='{index_id}' class='xr-index-data-in' type='checkbox'/>"
         f"<label for='{index_id}' title='Show/Hide index repr'>{data_icon}</label>"
         f"<div class='xr-index-data'>{details}</div>"

xarray/core/nputils.py

Lines changed: 10 additions & 0 deletions
@@ -255,6 +255,12 @@ def warn_on_deficient_rank(rank, order):
 
 
 def least_squares(lhs, rhs, rcond=None, skipna=False):
+    if rhs.ndim > 2:
+        out_shape = rhs.shape
+        rhs = rhs.reshape(rhs.shape[0], -1)
+    else:
+        out_shape = None
+
     if skipna:
         added_dim = rhs.ndim == 1
         if added_dim:
@@ -281,6 +287,10 @@ def least_squares(lhs, rhs, rcond=None, skipna=False):
     if residuals.size == 0:
         residuals = coeffs[0] * np.nan
     warn_on_deficient_rank(rank, lhs.shape[1])
+
+    if out_shape is not None:
+        coeffs = coeffs.reshape(-1, *out_shape[1:])
+        residuals = residuals.reshape(*out_shape[1:])
     return coeffs, residuals
