Skip to content

Commit 4b5c87b

Browse files
Use ddof in numbagg>=0.7.0 for aggregations (#8624)
* Use numbagg ddof for aggregations (not yet for rolling)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f07e895 commit 4b5c87b

File tree

4 files changed

+27
-18
lines changed

4 files changed

+27
-18
lines changed

xarray/core/nputils.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -185,8 +185,12 @@ def f(values, axis=None, **kwargs):
185185
and pycompat.mod_version("numbagg") >= Version("0.5.0")
186186
and OPTIONS["use_numbagg"]
187187
and isinstance(values, np.ndarray)
188-
# numbagg uses ddof=1 only, but numpy uses ddof=0 by default
189-
and (("var" in name or "std" in name) and kwargs.get("ddof", 0) == 1)
188+
# numbagg<0.7.0 uses ddof=1 only, but numpy uses ddof=0 by default
189+
and (
190+
pycompat.mod_version("numbagg") >= Version("0.7.0")
191+
or ("var" not in name and "std" not in name)
192+
or kwargs.get("ddof", 0) == 1
193+
)
190194
# TODO: bool?
191195
and values.dtype.kind in "uifc"
192196
# and values.dtype.isnative
@@ -196,9 +200,12 @@ def f(values, axis=None, **kwargs):
196200

197201
nba_func = getattr(numbagg, name, None)
198202
if nba_func is not None:
199-
# numbagg does not take care dtype, ddof
203+
# numbagg does not use dtype
200204
kwargs.pop("dtype", None)
201-
kwargs.pop("ddof", None)
205+
# prior to 0.7.0, numbagg did not support ddof; we ensure it's limited
206+
# to ddof=1 above.
207+
if pycompat.mod_version("numbagg") < Version("0.7.0"):
208+
kwargs.pop("ddof", None)
202209
return nba_func(values, axis=axis, **kwargs)
203210
if (
204211
_BOTTLENECK_AVAILABLE

xarray/tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import pytest
66

7+
import xarray as xr
78
from xarray import DataArray, Dataset
89
from xarray.tests import create_test_data, requires_dask
910

@@ -13,6 +14,19 @@ def backend(request):
1314
return request.param
1415

1516

17+
@pytest.fixture(params=["numbagg", "bottleneck"])
18+
def compute_backend(request):
19+
if request.param == "bottleneck":
20+
options = dict(use_bottleneck=True, use_numbagg=False)
21+
elif request.param == "numbagg":
22+
options = dict(use_bottleneck=False, use_numbagg=True)
23+
else:
24+
raise ValueError
25+
26+
with xr.set_options(**options):
27+
yield request.param
28+
29+
1630
@pytest.fixture(params=[1])
1731
def ds(request, backend):
1832
if request.param == 1:

xarray/tests/test_rolling.py

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -23,19 +23,6 @@
2323
]
2424

2525

26-
@pytest.fixture(params=["numbagg", "bottleneck"])
27-
def compute_backend(request):
28-
if request.param == "bottleneck":
29-
options = dict(use_bottleneck=True, use_numbagg=False)
30-
elif request.param == "numbagg":
31-
options = dict(use_bottleneck=False, use_numbagg=True)
32-
else:
33-
raise ValueError
34-
35-
with xr.set_options(**options):
36-
yield request.param
37-
38-
3926
@pytest.mark.parametrize("func", ["mean", "sum"])
4027
@pytest.mark.parametrize("min_periods", [1, 10])
4128
def test_cumulative(d, func, min_periods) -> None:

xarray/tests/test_variable.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1754,7 +1754,8 @@ def test_reduce(self):
17541754
v.mean(dim="x", axis=0)
17551755

17561756
@requires_bottleneck
1757-
def test_reduce_use_bottleneck(self, monkeypatch):
1757+
@pytest.mark.parametrize("compute_backend", ["bottleneck"], indirect=True)
1758+
def test_reduce_use_bottleneck(self, monkeypatch, compute_backend):
17581759
def raise_if_called(*args, **kwargs):
17591760
raise RuntimeError("should not have been called")
17601761

0 commit comments

Comments
 (0)