@@ -510,6 +510,29 @@ def test_rolling_exp_keep_attrs(self, da, func) -> None:
510510 ):
511511 da .rolling_exp (time = 10 , keep_attrs = True )
512512
513+ @pytest .mark .parametrize ("func" , ["mean" , "sum" ])
514+ @pytest .mark .parametrize ("min_periods" , [1 , 20 ])
515+ def test_cumulative (self , da , func , min_periods ) -> None :
516+ # One dim
517+ result = getattr (da .cumulative ("time" , min_periods = min_periods ), func )()
518+ expected = getattr (
519+ da .rolling (time = da .time .size , min_periods = min_periods ), func
520+ )()
521+ assert_identical (result , expected )
522+
523+ # Multiple dim
524+ result = getattr (da .cumulative (["time" , "a" ], min_periods = min_periods ), func )()
525+ expected = getattr (
526+ da .rolling (time = da .time .size , a = da .a .size , min_periods = min_periods ),
527+ func ,
528+ )()
529+ assert_identical (result , expected )
530+
531+ def test_cumulative_vs_cum (self , da ) -> None :
532+ result = da .cumulative ("time" ).sum ()
533+ expected = da .cumsum ("time" )
534+ assert_identical (result , expected )
535+
513536
514537class TestDatasetRolling :
515538 @pytest .mark .parametrize (
@@ -834,6 +857,25 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None:
834857 expected = getattr (getattr (ds .rolling (time = 4 ), name )().rolling (x = 3 ), name )()
835858 assert_allclose (actual , expected )
836859
860+ @pytest .mark .parametrize ("func" , ["mean" , "sum" ])
861+ @pytest .mark .parametrize ("ds" , (2 ,), indirect = True )
862+ @pytest .mark .parametrize ("min_periods" , [1 , 10 ])
863+ def test_cumulative (self , ds , func , min_periods ) -> None :
864+ # One dim
865+ result = getattr (ds .cumulative ("time" , min_periods = min_periods ), func )()
866+ expected = getattr (
867+ ds .rolling (time = ds .time .size , min_periods = min_periods ), func
868+ )()
869+ assert_identical (result , expected )
870+
871+ # Multiple dim
872+ result = getattr (ds .cumulative (["time" , "x" ], min_periods = min_periods ), func )()
873+ expected = getattr (
874+ ds .rolling (time = ds .time .size , x = ds .x .size , min_periods = min_periods ),
875+ func ,
876+ )()
877+ assert_identical (result , expected )
878+
837879
838880@requires_numbagg
839881class TestDatasetRollingExp :
0 commit comments