1111
1212import numpy as np
1313import pandas as pd
14+ from numpy import all as array_all # noqa
15+ from numpy import any as array_any # noqa
16+ from numpy import zeros_like # noqa
17+ from numpy import around , broadcast_to # noqa
18+ from numpy import concatenate as _concatenate
19+ from numpy import einsum , isclose , isin , isnan , isnat , pad # noqa
20+ from numpy import stack as _stack
21+ from numpy import take , tensordot , transpose , unravel_index # noqa
22+ from numpy import where as _where
1423
1524from . import dask_array_compat , dask_array_ops , dtypes , npcompat , nputils
1625from .nputils import nanfirst , nanlast
def _dask_or_eager_func(
    name,
    eager_module=np,
    dask_module=dask_array,
):
    """Create a function that dispatches to dask for dask array inputs.

    Parameters
    ----------
    name : str
        Attribute name looked up on the selected module at call time.
    eager_module : module, optional
        Module providing the implementation for non-dask inputs.
    dask_module : module or None, optional
        Module providing the implementation for dask inputs. When it is
        ``None`` the eager implementation is always used.

    Returns
    -------
    callable
        ``f(*args, **kwargs)`` which forwards its arguments unchanged to
        ``getattr(dask_module, name)`` if any positional argument is a duck
        dask array, and to ``getattr(eager_module, name)`` otherwise.
    """

    def f(*args, **kwargs):
        # Only positional arguments are inspected for dask arrays.  The
        # ``dask_module is not None`` guard keeps callers that pass an
        # optional module (e.g. ``getattr(dask_array, "ma", None)``) from
        # failing with ``getattr(None, name)``; they fall back to eager.
        use_dask = dask_module is not None and any(
            is_duck_dask_array(a) for a in args
        )
        module = dask_module if use_dask else eager_module
        # The attribute is resolved lazily, on every call.
        wrapped = getattr(module, name)
        return wrapped(*args, **kwargs)

    return f
6457
@@ -72,16 +65,40 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
7265 raise NotImplementedError (msg % func_name )
7366
7467
75- around = _dask_or_eager_func ("around" )
76- isclose = _dask_or_eager_func ("isclose" )
77-
68+ # Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
69+ pandas_isnull = _dask_or_eager_func ("isnull" , eager_module = pd , dask_module = dask_array )
7870
79- isnat = np .isnat
80- isnan = _dask_or_eager_func ("isnan" )
81- zeros_like = _dask_or_eager_func ("zeros_like" )
82-
83-
84- pandas_isnull = _dask_or_eager_func ("isnull" , eager_module = pd )
# np.around has failing doctests, overwrite it so they pass:
# https://github.com/numpy/numpy/issues/19759
#
# The docstring examples are expected to show arrays with two spaces after
# each comma, while the doctest output uses a single space.  Each search
# string below therefore keeps the doubled space; if a given numpy release no
# longer contains the text, the replacement is simply a no-op.
# ``or ""`` covers interpreters that strip docstrings (``python -OO``).
around.__doc__ = str.replace(
    around.__doc__ or "",
    "array([0.,  2.])",
    "array([0., 2.])",
)
around.__doc__ = str.replace(
    around.__doc__ or "",
    "array([0.4,  1.6])",
    "array([0.4, 1.6])",
)
around.__doc__ = str.replace(
    around.__doc__ or "",
    "array([0.,  2.,  2.,  4.,  4.])",
    "array([0., 2., 2., 4., 4.])",
)
# Drop the second bibliography entry from the docstring.
# NOTE(review): the leading indentation must match numpy's docstring exactly
# for this to take effect -- confirm against the supported numpy versions.
around.__doc__ = str.replace(
    around.__doc__ or "",
    (
        '    .. [2] "How Futile are Mindless Assessments of\n'
        '           Roundoff in Floating-Point Computation?", William Kahan,\n'
        "           https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf\n"
    ),
    "",
)
85102
86103
87104def isnull (data ):
@@ -114,21 +131,10 @@ def notnull(data):
114131 return ~ isnull (data )
115132
116133
117- transpose = _dask_or_eager_func ("transpose" )
118- _where = _dask_or_eager_func ("where" , array_args = slice (3 ))
119- isin = _dask_or_eager_func ("isin" , array_args = slice (2 ))
120- take = _dask_or_eager_func ("take" )
121- broadcast_to = _dask_or_eager_func ("broadcast_to" )
122- pad = _dask_or_eager_func ("pad" , dask_module = dask_array_compat )
123-
124- _concatenate = _dask_or_eager_func ("concatenate" , list_of_args = True )
125- _stack = _dask_or_eager_func ("stack" , list_of_args = True )
126-
127- array_all = _dask_or_eager_func ("all" )
128- array_any = _dask_or_eager_func ("any" )
129-
130- tensordot = _dask_or_eager_func ("tensordot" , array_args = slice (2 ))
131- einsum = _dask_or_eager_func ("einsum" , array_args = slice (1 , None ))
134+ # TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
135+ masked_invalid = _dask_or_eager_func (
136+ "masked_invalid" , eager_module = np .ma , dask_module = getattr (dask_array , "ma" , None )
137+ )
132138
133139
134140def gradient (x , coord , axis , edge_order ):
@@ -166,11 +172,6 @@ def cumulative_trapezoid(y, x, axis):
166172 return cumsum (integrand , axis = axis , skipna = False )
167173
168174
169- masked_invalid = _dask_or_eager_func (
170- "masked_invalid" , eager_module = np .ma , dask_module = getattr (dask_array , "ma" , None )
171- )
172-
173-
174175def astype (data , dtype , ** kwargs ):
175176 if (
176177 isinstance (data , sparse_array_type )
@@ -317,9 +318,7 @@ def _ignore_warnings_if(condition):
317318 yield
318319
319320
320- def _create_nan_agg_method (
321- name , dask_module = dask_array , coerce_strings = False , invariant_0d = False
322- ):
321+ def _create_nan_agg_method (name , coerce_strings = False , invariant_0d = False ):
323322 from . import nanops
324323
325324 def f (values , axis = None , skipna = None , ** kwargs ):
@@ -344,7 +343,8 @@ def f(values, axis=None, skipna=None, **kwargs):
344343 else :
345344 if name in ["sum" , "prod" ]:
346345 kwargs .pop ("min_count" , None )
347- func = _dask_or_eager_func (name , dask_module = dask_module )
346+
347+ func = getattr (np , name )
348348
349349 try :
350350 with warnings .catch_warnings ():
@@ -378,9 +378,7 @@ def f(values, axis=None, skipna=None, **kwargs):
378378std .numeric_only = True
379379var = _create_nan_agg_method ("var" )
380380var .numeric_only = True
381- median = _create_nan_agg_method (
382- "median" , dask_module = dask_array_compat , invariant_0d = True
383- )
381+ median = _create_nan_agg_method ("median" , invariant_0d = True )
384382median .numeric_only = True
385383prod = _create_nan_agg_method ("prod" , invariant_0d = True )
386384prod .numeric_only = True
@@ -389,7 +387,6 @@ def f(values, axis=None, skipna=None, **kwargs):
389387cumprod_1d .numeric_only = True
390388cumsum_1d = _create_nan_agg_method ("cumsum" , invariant_0d = True )
391389cumsum_1d .numeric_only = True
392- unravel_index = _dask_or_eager_func ("unravel_index" )
393390
394391
395392_mean = _create_nan_agg_method ("mean" , invariant_0d = True )
0 commit comments