88from __future__ import annotations
99
1010import collections
11+ import functools
1112from typing import (
1213 Dict ,
1314 Generic ,
9596 get_indexer_dict ,
9697)
9798
99+ _CYTHON_FUNCTIONS = {
100+ "aggregate" : {
101+ "add" : "group_add" ,
102+ "prod" : "group_prod" ,
103+ "min" : "group_min" ,
104+ "max" : "group_max" ,
105+ "mean" : "group_mean" ,
106+ "median" : "group_median" ,
107+ "var" : "group_var" ,
108+ "first" : "group_nth" ,
109+ "last" : "group_last" ,
110+ "ohlc" : "group_ohlc" ,
111+ },
112+ "transform" : {
113+ "cumprod" : "group_cumprod" ,
114+ "cumsum" : "group_cumsum" ,
115+ "cummin" : "group_cummin" ,
116+ "cummax" : "group_cummax" ,
117+ "rank" : "group_rank" ,
118+ },
119+ }
120+
121+
122+ @functools .lru_cache (maxsize = None )
123+ def _get_cython_function (kind : str , how : str , dtype : np .dtype , is_numeric : bool ):
124+
125+ dtype_str = dtype .name
126+ ftype = _CYTHON_FUNCTIONS [kind ][how ]
127+
128+ # see if there is a fused-type version of function
129+ # only valid for numeric
130+ f = getattr (libgroupby , ftype , None )
131+ if f is not None and is_numeric :
132+ return f
133+
134+ # otherwise find dtype-specific version, falling back to object
135+ for dt in [dtype_str , "object" ]:
136+ f2 = getattr (libgroupby , f"{ ftype } _{ dt } " , None )
137+ if f2 is not None :
138+ return f2
139+
140+ if hasattr (f , "__signatures__" ):
141+ # inspect what fused types are implemented
142+ if dtype_str == "object" and "object" not in f .__signatures__ :
143+ # disallow this function so we get a NotImplementedError below
144+ # instead of a TypeError at runtime
145+ f = None
146+
147+ func = f
148+
149+ if func is None :
150+ raise NotImplementedError (
151+ f"function is not implemented for this dtype: "
152+ f"[how->{ how } ,dtype->{ dtype_str } ]"
153+ )
154+
155+ return func
156+
98157
99158class BaseGrouper :
100159 """
@@ -385,28 +444,6 @@ def get_group_levels(self) -> List[Index]:
385444 # ------------------------------------------------------------
386445 # Aggregation functions
387446
388- _cython_functions = {
389- "aggregate" : {
390- "add" : "group_add" ,
391- "prod" : "group_prod" ,
392- "min" : "group_min" ,
393- "max" : "group_max" ,
394- "mean" : "group_mean" ,
395- "median" : "group_median" ,
396- "var" : "group_var" ,
397- "first" : "group_nth" ,
398- "last" : "group_last" ,
399- "ohlc" : "group_ohlc" ,
400- },
401- "transform" : {
402- "cumprod" : "group_cumprod" ,
403- "cumsum" : "group_cumsum" ,
404- "cummin" : "group_cummin" ,
405- "cummax" : "group_cummax" ,
406- "rank" : "group_rank" ,
407- },
408- }
409-
410447 _cython_arity = {"ohlc" : 4 } # OHLC
411448
412449 @final
@@ -417,43 +454,6 @@ def _is_builtin_func(self, arg):
417454 """
418455 return SelectionMixin ._builtin_table .get (arg , arg )
419456
420- @final
421- def _get_cython_function (
422- self , kind : str , how : str , values : np .ndarray , is_numeric : bool
423- ):
424-
425- dtype_str = values .dtype .name
426- ftype = self ._cython_functions [kind ][how ]
427-
428- # see if there is a fused-type version of function
429- # only valid for numeric
430- f = getattr (libgroupby , ftype , None )
431- if f is not None and is_numeric :
432- return f
433-
434- # otherwise find dtype-specific version, falling back to object
435- for dt in [dtype_str , "object" ]:
436- f2 = getattr (libgroupby , f"{ ftype } _{ dt } " , None )
437- if f2 is not None :
438- return f2
439-
440- if hasattr (f , "__signatures__" ):
441- # inspect what fused types are implemented
442- if dtype_str == "object" and "object" not in f .__signatures__ :
443- # disallow this function so we get a NotImplementedError below
444- # instead of a TypeError at runtime
445- f = None
446-
447- func = f
448-
449- if func is None :
450- raise NotImplementedError (
451- f"function is not implemented for this dtype: "
452- f"[how->{ how } ,dtype->{ dtype_str } ]"
453- )
454-
455- return func
456-
457457 @final
458458 def _get_cython_func_and_vals (
459459 self , kind : str , how : str , values : np .ndarray , is_numeric : bool
@@ -474,7 +474,7 @@ def _get_cython_func_and_vals(
474474 values : np.ndarray
475475 """
476476 try :
477- func = self . _get_cython_function (kind , how , values , is_numeric )
477+ func = _get_cython_function (kind , how , values . dtype , is_numeric )
478478 except NotImplementedError :
479479 if is_numeric :
480480 try :
@@ -484,7 +484,7 @@ def _get_cython_func_and_vals(
484484 values = values .astype (complex )
485485 else :
486486 raise
487- func = self . _get_cython_function (kind , how , values , is_numeric )
487+ func = _get_cython_function (kind , how , values . dtype , is_numeric )
488488 else :
489489 raise
490490 return func , values
0 commit comments