diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json
index d35a2a223a2..f118aef565c 100644
--- a/asv_bench/asv.conf.json
+++ b/asv_bench/asv.conf.json
@@ -29,7 +29,7 @@
     // If missing or the empty string, the tool will be automatically
     // determined by looking for tools on the PATH environment
     // variable.
-    "environment_type": "conda",
+    // "environment_type": "conda",
 
     // timeout in seconds for installing any dependencies in environment
     // defaults to 10 min
diff --git a/asv_bench/benchmarks/unstacking.py b/asv_bench/benchmarks/unstacking.py
index 8d0c3932870..008bea2a56a 100644
--- a/asv_bench/benchmarks/unstacking.py
+++ b/asv_bench/benchmarks/unstacking.py
@@ -7,23 +7,30 @@
 
 class Unstacking:
     def setup(self):
-        data = np.random.RandomState(0).randn(500, 1000)
-        self.da_full = xr.DataArray(data, dims=list("ab")).stack(flat_dim=[...])
-        self.da_missing = self.da_full[:-1]
-        self.df_missing = self.da_missing.to_pandas()
+        # data = np.random.RandomState(0).randn(1, 1000, 500)
+        data = np.random.RandomState(0).randn(1000, 500)
+        self.da = xr.DataArray(data, dims=list("ab")).stack(c=[...])
+        self.da_slow = self.da[::-1]
+        self.df = self.da.to_pandas()
+        self.df_slow = self.da_slow.to_pandas()
 
     def time_unstack_fast(self):
-        self.da_full.unstack("flat_dim")
+        self.da.unstack("c")
 
     def time_unstack_slow(self):
-        self.da_missing.unstack("flat_dim")
+        self.da_slow.unstack("c")
+
+    def time_unstack_pandas_fast(self):
+        # For comparison
+        self.df.unstack()
 
     def time_unstack_pandas_slow(self):
-        self.df_missing.unstack()
+        # For comparison
+        self.df_slow.unstack()
 
 
 class UnstackingDask(Unstacking):
     def setup(self, *args, **kwargs):
         requires_dask()
         super().setup(**kwargs)
-        self.da_full = self.da_full.chunk({"flat_dim": 50})
+        self.da = self.da.chunk({"c": 50})
diff --git a/xarray/core/common.py b/xarray/core/common.py
index a69ba03a7a4..4e7377449e1 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -3,6 +3,7 @@
 from html import escape
 from textwrap import dedent
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -28,6 +29,12 @@
 from .rolling_exp import RollingExp
 from .utils import Frozen, either_dict_or_kwargs, is_scalar
 
+if TYPE_CHECKING:
+    from xarray.core.variable import IndexVariable
+
+    from .dataarray import DataArray
+
+
 # Used as a sentinel value to indicate all dimensions
 ALL_DIMS = ...
 
@@ -580,12 +587,15 @@ def pipe(
         >>> def adder(data, arg):
         ...     return data + arg
         ...
+
         >>> def div(data, arg):
         ...     return data / arg
         ...
+
         >>> def sub_mult(data, sub_arg, mult_arg):
         ...     return (data * mult_arg) - sub_arg
         ...
+
         >>> x.pipe(adder, 2)
         <xarray.Dataset>
         Dimensions:  (lat: 2, lon: 2)
@@ -635,7 +645,12 @@ def pipe(
         else:
             return func(self, *args, **kwargs)
 
-    def groupby(self, group, squeeze: bool = True, restore_coord_dims: bool = None):
+    def groupby(
+        self,
+        group: Union[Hashable, "DataArray", "IndexVariable"],
+        squeeze: bool = True,
+        restore_coord_dims: bool = None,
+    ):
         """Returns a GroupBy object for performing grouped operations.
 
         Parameters
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index a73e299e27a..1a59892b567 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -3528,7 +3528,7 @@ def reorder_levels(
 
         return self._replace(variables, indexes=indexes)
 
-    def _stack_once(self, dims, new_dim):
+    def _stack_once(self, dims: Sequence[Hashable], new_dim: Hashable) -> "Dataset":
         if ... in dims:
             dims = list(infix_dims(dims, self.dims))
         variables = {}
@@ -6486,7 +6486,9 @@ def polyfit(
                 name=name + "polyfit_coefficients",
             )
             if dims_to_stack:
-                coeffs = coeffs.unstack(stacked_dim)
+                # Pylance can't see that `stacked_dim` always exists when
+                # `dims_to_stack` is truthy
+                coeffs = coeffs.unstack(stacked_dim)  # type: ignore
             variables[coeffs.name] = coeffs
 
         if full or (cov is True):
@@ -6497,7 +6499,7 @@ def polyfit(
                 name=name + "polyfit_residuals",
             )
             if dims_to_stack:
-                residuals = residuals.unstack(stacked_dim)
+                residuals = residuals.unstack(stacked_dim)  # type: ignore
             variables[residuals.name] = residuals
 
         if cov:
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py
index e1e5a0fabe8..2214116207b 100644
--- a/xarray/core/groupby.py
+++ b/xarray/core/groupby.py
@@ -1,6 +1,7 @@
 import datetime
 import functools
 import warnings
+from typing import TYPE_CHECKING, Hashable, Iterator, Union
 
 import numpy as np
 import pandas as pd
@@ -11,7 +12,7 @@
 from .concat import concat
 from .formatting import format_array_flat
 from .indexes import propagate_indexes
-from .options import _get_keep_attrs
+from .options import OPTIONS, _get_keep_attrs
 from .pycompat import integer_types
 from .utils import (
     either_dict_or_kwargs,
@@ -23,6 +24,9 @@
 )
 from .variable import IndexVariable, Variable, as_variable
 
+if TYPE_CHECKING:
+    from .dataarray import DataArray
+
 
 def check_reduce_dims(reduce_dims, dimensions):
 
@@ -36,6 +40,15 @@ def check_reduce_dims(reduce_dims, dimensions):
     )
 
 
+# Notes
+# - Check out squeeze; it removes dimensions of size one when squeeze=True
+# - Only call unique_value_groups once if possible
+# - Probably don't use _infer_concat_args
+# - Always aggregated
+# - We're always going to be on the second path
+# - Old dims maybe shouldn't be public
+
+
 def unique_value_groups(ar, sort=True):
     """Group an array by its unique values.
 
@@ -55,12 +68,12 @@ def unique_value_groups(ar, sort=True):
         the corresponding value in `unique_values`.
     """
     inverse, values = pd.factorize(ar, sort=sort)
-    groups = [[] for _ in range(len(values))]
+    indices = [[] for _ in range(len(values))]
     for n, g in enumerate(inverse):
         if g >= 0:
             # pandas uses -1 to mark NaN, but doesn't include them in values
-            groups[g].append(n)
-    return values, groups
+            indices[g].append(n)
+    return values, indices
 
 
 def _dummy_copy(xarray_obj):
@@ -267,7 +280,7 @@ class GroupBy(SupportsArithmetic):
     def __init__(
         self,
         obj,
-        group,
+        group_or_name: Union[Hashable, "DataArray", IndexVariable],
         squeeze=False,
         grouper=None,
         bins=None,
@@ -280,7 +293,7 @@
         ----------
         obj : Dataset or DataArray
             Object to group.
-        group : DataArray
+        group_or_name : str, DataArray, IndexVariable
            Array with the group values.
         squeeze : bool, optional
             If "group" is a coordinate of object, `squeeze` controls whether
@@ -305,20 +318,22 @@
         if grouper is not None and bins is not None:
             raise TypeError("can't specify both `grouper` and `bins`")
 
-        if not isinstance(group, (DataArray, IndexVariable)):
-            if not hashable(group):
+        if not isinstance(group_or_name, (DataArray, IndexVariable)):
+            if not hashable(group_or_name):
                 raise TypeError(
                     "`group` must be an xarray.DataArray or the "
                     "name of an xarray variable or dimension."
-                    f"Received {group!r} instead."
+                    f"Received {group_or_name!r} instead."
                 )
-            group = obj[group]
+            group = obj[group_or_name]
             if len(group) == 0:
                 raise ValueError(f"{group.name} must not be empty")
 
             if group.name not in obj.coords and group.name in obj.dims:
                 # DummyGroups should not appear on groupby results
                 group = _DummyGroup(obj, group.name, group.coords)
+        else:
+            group = group_or_name
 
         if getattr(group, "name", None) is None:
             group.name = "group"
@@ -361,7 +376,8 @@
             if not squeeze:
                 # use slices to do views instead of fancy indexing
                 # equivalent to: group_indices = group_indices.reshape(-1, 1)
-                group_indices = [slice(i, i + 1) for i in group_indices]
+                # TODO: fix type error `Unsupported operand types for + ("slice" and "int")`
+                group_indices = [slice(i, i + 1) for i in group_indices]  # type: ignore
                 unique_coord = group
         else:
             if group.isnull().any():
@@ -392,8 +408,15 @@
         # specification for the groupby operation
         self._obj = obj
         self._group = group
+        # The dimension to group over. When grouping by a non-index
+        # coordinate, this differs from the group name.
         self._group_dim = group_dim
+        # A list with one entry per group; each entry holds the integer
+        # indices of the points that belong to that group (or a slice,
+        # when `squeeze=False` converted the indices to slices above).
         self._group_indices = group_indices
+        # IndexVariable holding the unique group values (the group labels),
+        # in the order the groups appear in the result.
         self._unique_coord = unique_coord
         self._stacked_dim = stacked_dim
         self._inserted_dims = inserted_dims
@@ -404,6 +427,8 @@
         self._groups = None
         self._dims = None
 
+    # TODO: is this correct? Should we be returning the dims of the result?
+    # This will use the original dim where we're grouping by a coord.
     @property
     def dims(self):
         if self._dims is None:
@@ -882,9 +907,105 @@ def reduce_array(ar):
             return self.map(reduce_array, shortcut=shortcut)
 
+    def dims_(self) -> Iterator[Hashable]:
+        """The dims of the resulting object (before any further reduction)."""
+
+        for d in self.dims:
+            if d != self._group_dim:
+                yield d
+            # Grouping on a dimension
+            elif self._group_dim == self._group.name:
+                yield d
+            # Grouping on a coord that isn't a dimension
+            else:
+                yield self._group.name
+
+    def _npg_groupby(self, func: str):
+
+        # `.values` seems to be required for datetimes
+        codes, _ = pd.factorize(self._group.values)
+
+        array = npg_aggregate(self._obj, self._group_dim, func, codes)
+
+        # FIXME: Currently we're trying to use as much of the existing
+        # infrastructure as possible, but I'm struggling to fit it in; it
+        # may be easier to start from scratch.
+        #
+        # The existing model checks a single result `applied_example`, which
+        # we don't have here: we generate the whole array at once.
+
+        applied = applied_example = type(self._obj)(
+            data=array,
+            dims=tuple(self.dims_()),
+        )
+
+        # The remainder is mostly copied from `_combine`
+
+        # FIXME: this part seems broken at the moment; `_infer_concat_args`
+        # doesn't return the correct coord when the group isn't a dimensioned
+        # coordinate, so take `coord` from the unique values instead
+        _, dim, positions = self._infer_concat_args(applied_example)
+        coord = self._unique_coord
+        # NB: the shortcut branches are commented out for simplicity.
+        # if shortcut:
+        #     combined = self._concat_shortcut(applied, dim, positions)
+        # else:
+        combined = concat(applied, dim)
+        combined = _maybe_reorder(combined, dim, positions)
+
+        if isinstance(combined, type(self._obj)):
+            # only restore dimension order for arrays
+            combined = self._restore_dim_order(combined)
+        # assign coord when the applied function does not return that coord
+        if coord is not None:  # and dim not in applied_example.dims:
+            # if shortcut:
+            #     coord_var = as_variable(coord)
+            #     combined._coords[coord.name] = coord_var
+            # else:
+            combined.coords[coord.name] = coord
+        combined = self._maybe_restore_empty_groups(combined)
+        combined = self._maybe_unstack(combined)
+
+        return combined
+
+    if OPTIONS["numpy_groupies"]:
+
+        def sum(self, dim=None):
+            grouped = self._npg_groupby(func="sum")
+            if dim:
+                return grouped.sum(dim)
+            else:
+                return grouped
+
+        def count(self, dim=None):
+            grouped = self._npg_groupby(func="count")
+            if dim:
+                return grouped.count(dim)
+            else:
+                return grouped
+
+        def mean(self, dim=None):
+            grouped = self._npg_groupby(func="mean")
+            if dim:
+                return grouped.mean(dim)
+            else:
+                return grouped
+
+
+def npg_aggregate(
+    da: "DataArray", dim: Hashable, func: str, group_idx: np.ndarray
+) -> np.ndarray:
+    from numpy_groupies.aggregate_numba import aggregate
+
+    axis = da.get_axis_num(dim)
+    return aggregate(group_idx=group_idx, a=da, func=func, axis=axis)
+
+
+if not OPTIONS["numpy_groupies"]:
+
-ops.inject_reduce_methods(DataArrayGroupBy)
-ops.inject_binary_ops(DataArrayGroupBy)
+    ops.inject_reduce_methods(DataArrayGroupBy)
+    ops.inject_binary_ops(DataArrayGroupBy)
 
 
 class DatasetGroupBy(GroupBy, ImplementsDatasetReduce):
diff --git a/xarray/core/options.py b/xarray/core/options.py
index d421b4c4f17..9b822a9caaa 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -9,6 +9,8 @@
 ENABLE_CFTIMEINDEX = "enable_cftimeindex"
 FILE_CACHE_MAXSIZE = "file_cache_maxsize"
 KEEP_ATTRS = "keep_attrs"
+DISPLAY_STYLE = "display_style"
+NUMPY_GROUPIES = "numpy_groupies"
 WARN_FOR_UNCLOSED_FILES = "warn_for_unclosed_files"
 
@@ -22,6 +24,8 @@
     ENABLE_CFTIMEINDEX: True,
     FILE_CACHE_MAXSIZE: 128,
     KEEP_ATTRS: "default",
+    DISPLAY_STYLE: "html",
+    NUMPY_GROUPIES: True,
     WARN_FOR_UNCLOSED_FILES: False,
 }
 
@@ -107,6 +111,7 @@ class set_options:
       Default: ``'default'``.
     - ``display_style``: display style to use in jupyter for xarray objects.
       Default: ``'text'``. Other options are ``'html'``.
+    - ``numpy_groupies``: whether to use ``numpy_groupies`` for accelerated groupby aggregations. Default: ``True``.
 
     You can use ``set_options`` either as a context manager:
diff --git a/xarray/core/utils.py b/xarray/core/utils.py
index ced688f32dd..9d4426be63f 100644
--- a/xarray/core/utils.py
+++ b/xarray/core/utils.py
@@ -117,7 +117,7 @@ def safe_cast_to_index(array: Any) -> pd.Index:
 
 
 def multiindex_from_product_levels(
-    levels: Sequence[pd.Index], names: Sequence[str] = None
+    levels: Sequence[pd.Index], names: Sequence[Hashable] = None
 ) -> pd.MultiIndex:
     """Creating a MultiIndex from a product without refactorizing levels.
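
For reference, a minimal, self-contained sketch of the aggregation path that
`_npg_groupby` takes in the patch above. It assumes only what the diff itself
uses (`pd.factorize` plus `numpy_groupies.aggregate_numba.aggregate`); the
example array and its `label` coordinate are made up for illustration:

    import numpy as np
    import pandas as pd
    import xarray as xr
    from numpy_groupies.aggregate_numba import aggregate

    da = xr.DataArray(
        np.arange(6.0).reshape(2, 3),
        dims=("x", "y"),
        coords={"label": ("y", ["a", "b", "a"])},
    )

    # Mirrors `npg_aggregate`: factorize the group labels into integer codes,
    # then aggregate along the grouped axis in a single call.
    codes, _ = pd.factorize(da["label"].values)
    result = aggregate(
        group_idx=codes, a=da.values, func="sum", axis=da.get_axis_num("y")
    )
    # result has shape (2, 2): one column per unique label ("a", "b")

Note that the patch reads ``OPTIONS["numpy_groupies"]`` at class-definition
time (and again at module level for the `ops` injection), so the option takes
effect only if set before `xarray.core.groupby` is first imported; toggling it
later via `set_options` will not switch implementations.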