66import numpy as np
77from pandas .api .types import is_extension_array_dtype
88
9- from xarray .core import npcompat , utils
9+ from xarray .core import array_api_compat , npcompat , utils
1010
1111# Use as a sentinel value to indicate a dtype appropriate NA value.
1212NA = utils .ReprObject ("<NA>" )
@@ -131,7 +131,10 @@ def get_pos_infinity(dtype, max_for_int=False):
131131 if isdtype (dtype , "complex floating" ):
132132 return np .inf + 1j * np .inf
133133
134- return INF
134+ if isdtype (dtype , "bool" ):
135+ return True
136+
137+ return np .array (INF , dtype = object )
135138
136139
137140def get_neg_infinity (dtype , min_for_int = False ):
@@ -159,7 +162,10 @@ def get_neg_infinity(dtype, min_for_int=False):
159162 if isdtype (dtype , "complex floating" ):
160163 return - np .inf - 1j * np .inf
161164
162- return NINF
165+ if isdtype (dtype , "bool" ):
166+ return False
167+
168+ return np .array (NINF , dtype = object )
163169
164170
165171def is_datetime_like (dtype ) -> bool :
@@ -209,8 +215,16 @@ def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
209215 return xp .isdtype (dtype , kind )
210216
211217
218+ def preprocess_scalar_types (t ):
219+ if isinstance (t , (str , bytes )):
220+ return type (t )
221+ else :
222+ return t
223+
224+
212225def result_type (
213226 * arrays_and_dtypes : np .typing .ArrayLike | np .typing .DTypeLike ,
227+ xp = None ,
214228) -> np .dtype :
215229 """Like np.result_type, but with type promotion rules matching pandas.
216230
@@ -227,26 +241,26 @@ def result_type(
227241 -------
228242 numpy.dtype for the result.
229243 """
244+ # TODO (keewis): replace `array_api_compat.result_type` with `xp.result_type` once we
245+ # can require a version of the Array API that supports passing scalars to it.
230246 from xarray .core .duck_array_ops import get_array_namespace
231247
232- # TODO(shoyer): consider moving this logic into get_array_namespace()
233- # or another helper function.
234- namespaces = {get_array_namespace (t ) for t in arrays_and_dtypes }
235- non_numpy = namespaces - {np }
236- if non_numpy :
237- [xp ] = non_numpy
238- else :
239- xp = np
240-
241- types = {xp .result_type (t ) for t in arrays_and_dtypes }
248+ if xp is None :
249+ xp = get_array_namespace (arrays_and_dtypes )
242250
251+ types = {
252+ array_api_compat .result_type (preprocess_scalar_types (t ), xp = xp )
253+ for t in arrays_and_dtypes
254+ }
243255 if any (isinstance (t , np .dtype ) for t in types ):
244256 # only check if there's numpy dtypes – the array API does not
245257 # define the types we're checking for
246258 for left , right in PROMOTE_TO_OBJECT :
247259 if any (np .issubdtype (t , left ) for t in types ) and any (
248260 np .issubdtype (t , right ) for t in types
249261 ):
250- return xp .dtype (object )
262+ return np .dtype (object )
251263
252- return xp .result_type (* arrays_and_dtypes )
264+ return array_api_compat .result_type (
265+ * map (preprocess_scalar_types , arrays_and_dtypes ), xp = xp
266+ )
0 commit comments