@@ -1219,6 +1219,29 @@ class GroupBy(_GroupBy):
12191219 """
12201220 _apply_whitelist = _common_apply_whitelist
12211221
def _bool_agg(self, how, skipna):
    """Shared func to call any / all Cython GroupBy implementations

    Parameters
    ----------
    how : str
        Name of the Cython function to dispatch to, forwarded unchanged
        to `_get_cythonized_result`
    skipna : bool
        Forwarded as a keyword argument to the Cython function

    Returns
    -------
    The aggregated output produced by `_get_cythonized_result`
    """

    def objs_to_bool(vals):
        # The builtin `bool` is used instead of the `np.bool` alias, which
        # is deprecated / unavailable on several NumPy releases.
        try:
            vals = vals.astype(bool)
        except ValueError:  # for objects
            vals = np.array([bool(x) for x in vals])

        # Reinterpret the boolean buffer as uint8 to match the
        # `cython_dtype` requested below.
        return vals.view(np.uint8)

    def result_to_bool(result):
        # Convert the uint8 result buffer back to booleans.
        return result.astype(bool, copy=False)

    return self._get_cythonized_result(how, self.grouper,
                                       aggregate=True,
                                       cython_dtype=np.uint8,
                                       needs_values=True,
                                       needs_mask=True,
                                       pre_processing=objs_to_bool,
                                       post_processing=result_to_bool,
                                       skipna=skipna)
@Substitution(name='groupby')
@Appender(_doc_template)
def any(self, skipna=True):
    """Truth-test each group, delegating to the 'group_any' routine

    Parameters
    ----------
    skipna : bool, default True
        Flag to ignore nan values during truth testing
    """
    cython_name = 'group_any'
    return self._bool_agg(cython_name, skipna)
12341256
1235- for name , obj in self . _iterate_slices ():
1236- result = np . zeros ( self . ngroups , dtype = np . int64 )
1237- libgroupby . group_any ( result , obj . values , labels , skipna )
1238- output [ name ] = result . astype ( np . bool )
@Substitution(name='groupby')
@Appender(_doc_template)
def all(self, skipna=True):
    """Truth-test each group, delegating to the 'group_all' routine

    Parameters
    ----------
    skipna : bool, default True
        Flag to ignore nan values during truth testing
    """
    cython_name = 'group_all'
    return self._bool_agg(cython_name, skipna)
12411268
12421269 @Substitution (name = 'groupby' )
12431270 @Appender (_doc_template )
@@ -1505,6 +1532,8 @@ def _fill(self, direction, limit=None):
15051532
15061533 return self ._get_cythonized_result ('group_fillna_indexer' ,
15071534 self .grouper , needs_mask = True ,
1535+ cython_dtype = np .int64 ,
1536+ result_is_index = True ,
15081537 direction = direction , limit = limit )
15091538
15101539 @Substitution (name = 'groupby' )
@@ -1893,33 +1922,81 @@ def cummax(self, axis=0, **kwargs):
18931922
18941923 return self ._cython_transform ('cummax' , numeric_only = False )
18951924
def _get_cythonized_result(self, how, grouper, aggregate=False,
                           cython_dtype=None, needs_values=False,
                           needs_mask=False, needs_ngroups=False,
                           result_is_index=False,
                           pre_processing=None, post_processing=None,
                           **kwargs):
    """Get result for Cythonized functions

    Parameters
    ----------
    how : str, Cythonized function name to be called
    grouper : Grouper object containing pertinent group info
    aggregate : bool, default False
        Whether the result should be aggregated to match the number of
        groups
    cython_dtype : default None
        Type of the array that will be modified by the Cython call. If
        `None`, the type will be inferred from the values of each slice
    needs_values : bool, default False
        Whether the values should be a part of the Cython call
        signature
    needs_mask : bool, default False
        Whether boolean mask needs to be part of the Cython call
        signature
    needs_ngroups : bool, default False
        Whether number of groups is part of the Cython call signature
    result_is_index : bool, default False
        Whether the result of the Cython operation is an index of
        values to be retrieved, instead of the actual values themselves
    pre_processing : function, default None
        Function to be applied to `values` prior to passing to Cython
        Raises if `needs_values` is False
    post_processing : function, default None
        Function to be applied to result of Cython function
    **kwargs : dict
        Extra arguments to be passed back to Cython funcs

    Returns
    -------
    `Series` or `DataFrame`, wrapped as aggregated output when
    `aggregate` is True and as transformed output otherwise

    Raises
    ------
    ValueError
        If `result_is_index` and `aggregate` are both True, if
        `pre_processing` / `post_processing` is given but not callable,
        or if `pre_processing` is given without `needs_values`
    """
    if result_is_index and aggregate:
        raise ValueError("'result_is_index' and 'aggregate' cannot both "
                         "be True!")
    if post_processing:
        # Validate the post-processing hook itself, not the
        # pre-processing one.
        if not callable(post_processing):
            raise ValueError("'post_processing' must be a callable!")
    if pre_processing:
        if not callable(pre_processing):
            raise ValueError("'pre_processing' must be a callable!")
        if not needs_values:
            raise ValueError("Cannot use 'pre_processing' without "
                             "specifying 'needs_values'!")

    labels, _, ngroups = grouper.group_info
    output = collections.OrderedDict()
    base_func = getattr(libgroupby, how)

    for name, obj in self._iterate_slices():
        values = obj.values

        # Aggregations produce one entry per group; otherwise the result
        # has one entry per input value.
        result_sz = ngroups if aggregate else len(values)

        # Resolve the dtype per slice without rebinding `cython_dtype`,
        # so a dtype inferred for one slice never leaks into the next.
        # `is None` is used because some dtype objects are falsy.
        if cython_dtype is None:
            result_dtype = values.dtype
        else:
            result_dtype = cython_dtype

        result = np.zeros(result_sz, dtype=result_dtype)

        # Positional arguments are bound in the order the Cython function
        # receives them: result, labels, [values], [mask], [ngroups].
        func = partial(base_func, result, labels)
        if needs_values:
            vals = values
            if pre_processing:
                vals = pre_processing(vals)
            func = partial(func, vals)

        if needs_mask:
            mask = isnull(values).view(np.uint8)
            func = partial(func, mask)

        if needs_ngroups:
            func = partial(func, ngroups)

        func(**kwargs)  # Call func to modify result in place

        if result_is_index:
            # The Cython call filled `result` with positions; fetch the
            # corresponding values.
            result = algorithms.take_nd(values, result)

        if post_processing:
            result = post_processing(result)

        output[name] = result

    if aggregate:
        return self._wrap_aggregated_output(output)
    else:
        return self._wrap_transformed_output(output)
19342021
19352022 @Substitution (name = 'groupby' )
19362023 @Appender (_doc_template )
@@ -1950,7 +2037,9 @@ def shift(self, periods=1, freq=None, axis=0):
19502037 return self .apply (lambda x : x .shift (periods , freq , axis ))
19512038
19522039 return self ._get_cythonized_result ('group_shift_indexer' ,
1953- self .grouper , needs_ngroups = True ,
2040+ self .grouper , cython_dtype = np .int64 ,
2041+ needs_ngroups = True ,
2042+ result_is_index = True ,
19542043 periods = periods )
19552044
19562045 @Substitution (name = 'groupby' )
0 commit comments