Skip to content

Commit 39038f3

Browse files
authored
Merge branch 'main' into test-unification
2 parents 4c8abcb + 9acc411 commit 39038f3

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

xarray/tests/test_rolling.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,29 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None:
510510
):
511511
da.rolling_exp(time=10, keep_attrs=True)
512512

513+
@pytest.mark.parametrize("func", ["mean", "sum"])
514+
@pytest.mark.parametrize("min_periods", [1, 20])
515+
def test_cumulative(self, da, func, min_periods) -> None:
516+
# One dim
517+
result = getattr(da.cumulative("time", min_periods=min_periods), func)()
518+
expected = getattr(
519+
da.rolling(time=da.time.size, min_periods=min_periods), func
520+
)()
521+
assert_identical(result, expected)
522+
523+
# Multiple dim
524+
result = getattr(da.cumulative(["time", "a"], min_periods=min_periods), func)()
525+
expected = getattr(
526+
da.rolling(time=da.time.size, a=da.a.size, min_periods=min_periods),
527+
func,
528+
)()
529+
assert_identical(result, expected)
530+
531+
def test_cumulative_vs_cum(self, da) -> None:
532+
result = da.cumulative("time").sum()
533+
expected = da.cumsum("time")
534+
assert_identical(result, expected)
535+
513536

514537
class TestDatasetRolling:
515538
@pytest.mark.parametrize(
@@ -834,6 +857,25 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
834857
expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)()
835858
assert_allclose(actual, expected)
836859

860+
@pytest.mark.parametrize("func", ["mean", "sum"])
861+
@pytest.mark.parametrize("ds", (2,), indirect=True)
862+
@pytest.mark.parametrize("min_periods", [1, 10])
863+
def test_cumulative(self, ds, func, min_periods) -> None:
864+
# One dim
865+
result = getattr(ds.cumulative("time", min_periods=min_periods), func)()
866+
expected = getattr(
867+
ds.rolling(time=ds.time.size, min_periods=min_periods), func
868+
)()
869+
assert_identical(result, expected)
870+
871+
# Multiple dim
872+
result = getattr(ds.cumulative(["time", "x"], min_periods=min_periods), func)()
873+
expected = getattr(
874+
ds.rolling(time=ds.time.size, x=ds.x.size, min_periods=min_periods),
875+
func,
876+
)()
877+
assert_identical(result, expected)
878+
837879

838880
@requires_numbagg
839881
class TestDatasetRollingExp:

0 commit comments

Comments
 (0)