@@ -1219,6 +1219,53 @@ class GroupBy(_GroupBy):
12191219 """
12201220 _apply_whitelist = _common_apply_whitelist
12211221
def _bool_agg(self, val_test, skipna):
    """Shared func to call any / all Cython GroupBy implementations

    Parameters
    ----------
    val_test : str
        Name of the truth test, forwarded to the Cython function as a
        keyword argument; the callers in this class pass 'any' or 'all'
    skipna : bool
        Forwarded to the Cython function as a keyword argument

    Returns
    -------
    Aggregated output of `_get_cythonized_result`, with each result
    converted to boolean dtype
    """

    def objs_to_bool(vals):
        # The builtin ``bool`` is used instead of the ``np.bool`` alias,
        # which was deprecated in NumPy 1.20 and removed in NumPy 1.24.
        try:
            vals = vals.astype(bool)
        except ValueError:  # for objects
            # Fall back to Python truthiness, one element at a time
            vals = np.array([bool(x) for x in vals])

        # Reinterpret the boolean buffer as uint8 to match the
        # ``cython_dtype`` requested below
        return vals.view(np.uint8)

    def result_to_bool(result):
        # The Cython call fills a uint8 array; hand back booleans
        return result.astype(bool, copy=False)

    return self._get_cythonized_result('group_any_all', self.grouper,
                                       aggregate=True,
                                       cython_dtype=np.uint8,
                                       needs_values=True,
                                       needs_mask=True,
                                       pre_processing=objs_to_bool,
                                       post_processing=result_to_bool,
                                       val_test=val_test, skipna=skipna)
1244+
@Substitution(name='groupby')
@Appender(_doc_template)
def any(self, skipna=True):
    """Return True if at least one value in the group is truthy, else False

    Parameters
    ----------
    skipna : bool, default True
        Whether nan values are ignored while testing for truth
    """
    return self._bool_agg(val_test='any', skipna=skipna)
1256+
@Substitution(name='groupby')
@Appender(_doc_template)
def all(self, skipna=True):
    """Return True if every value in the group is truthy, else False

    Parameters
    ----------
    skipna : bool, default True
        Whether nan values are ignored while testing for truth
    """
    return self._bool_agg(val_test='all', skipna=skipna)
1268+
12221269 @Substitution (name = 'groupby' )
12231270 @Appender (_doc_template )
12241271 def count (self ):
@@ -1485,6 +1532,8 @@ def _fill(self, direction, limit=None):
14851532
14861533 return self ._get_cythonized_result ('group_fillna_indexer' ,
14871534 self .grouper , needs_mask = True ,
1535+ cython_dtype = np .int64 ,
1536+ result_is_index = True ,
14881537 direction = direction , limit = limit )
14891538
14901539 @Substitution (name = 'groupby' )
@@ -1873,33 +1922,81 @@ def cummax(self, axis=0, **kwargs):
18731922
18741923 return self ._cython_transform ('cummax' , numeric_only = False )
18751924
def _get_cythonized_result(self, how, grouper, aggregate=False,
                           cython_dtype=None, needs_values=False,
                           needs_mask=False, needs_ngroups=False,
                           result_is_index=False,
                           pre_processing=None, post_processing=None,
                           **kwargs):
    """Get result for Cythonized functions

    Parameters
    ----------
    how : str, Cythonized function name to be called
    grouper : Grouper object containing pertinent group info
    aggregate : bool, default False
        Whether the result should be aggregated to match the number of
        groups
    cython_dtype : default None
        Type of the array that will be modified by the Cython call. If
        `None`, the type will be inferred from the values of each slice
    needs_values : bool, default False
        Whether the values should be a part of the Cython call
        signature
    needs_mask : bool, default False
        Whether boolean mask needs to be part of the Cython call
        signature
    needs_ngroups : bool, default False
        Whether number of groups is part of the Cython call signature
    result_is_index : bool, default False
        Whether the result of the Cython operation is an index of
        values to be retrieved, instead of the actual values themselves
    pre_processing : function, default None
        Function to be applied to `values` prior to passing to Cython
        Raises if `needs_values` is False
    post_processing : function, default None
        Function to be applied to result of Cython function
    **kwargs : dict
        Extra arguments to be passed back to Cython funcs

    Returns
    -------
    `Series` or `DataFrame`
        Built by `_wrap_aggregated_output` when `aggregate` is True,
        otherwise by `_wrap_transformed_output`

    Raises
    ------
    ValueError
        If `result_is_index` and `aggregate` are both True, if
        `pre_processing` or `post_processing` is given but not callable,
        or if `pre_processing` is given without `needs_values`
    """
    if result_is_index and aggregate:
        raise ValueError("'result_is_index' and 'aggregate' cannot both "
                         "be True!")
    # Validate each hook against itself; checking `pre_processing` here
    # would reject a valid `post_processing` supplied on its own.
    if post_processing is not None and not callable(post_processing):
        raise ValueError("'post_processing' must be a callable!")
    if pre_processing is not None:
        if not callable(pre_processing):
            raise ValueError("'pre_processing' must be a callable!")
        if not needs_values:
            raise ValueError("Cannot use 'pre_processing' without "
                             "specifying 'needs_values'!")

    labels, _, ngroups = grouper.group_info
    output = collections.OrderedDict()
    base_func = getattr(libgroupby, how)

    for name, obj in self._iterate_slices():
        values = obj.values
        result_sz = ngroups if aggregate else len(values)

        # Resolve the dtype per slice without rebinding `cython_dtype`,
        # so a dtype inferred for one slice never leaks into the next.
        # Compare against None explicitly: a plain `np.dtype` instance
        # has length 0 and would be treated as missing by `not`.
        if cython_dtype is None:
            slice_dtype = values.dtype
        else:
            slice_dtype = cython_dtype

        result = np.zeros(result_sz, dtype=slice_dtype)

        # Positional arguments are bound in the order the Cython
        # signature expects: result, labels, [values], [mask], [ngroups]
        func = partial(base_func, result, labels)
        if needs_values:
            vals = values
            if pre_processing is not None:
                vals = pre_processing(vals)
            func = partial(func, vals)

        if needs_mask:
            mask = isnull(values).view(np.uint8)
            func = partial(func, mask)

        if needs_ngroups:
            func = partial(func, ngroups)

        func(**kwargs)  # Call func to modify result values in place

        if result_is_index:
            # `result` holds positions; fetch the corresponding values
            result = algorithms.take_nd(values, result)

        if post_processing is not None:
            result = post_processing(result)

        output[name] = result

    if aggregate:
        return self._wrap_aggregated_output(output)
    return self._wrap_transformed_output(output)
19142021
19152022 @Substitution (name = 'groupby' )
19162023 @Appender (_doc_template )
@@ -1930,7 +2037,9 @@ def shift(self, periods=1, freq=None, axis=0):
19302037 return self .apply (lambda x : x .shift (periods , freq , axis ))
19312038
19322039 return self ._get_cythonized_result ('group_shift_indexer' ,
1933- self .grouper , needs_ngroups = True ,
2040+ self .grouper , cython_dtype = np .int64 ,
2041+ needs_ngroups = True ,
2042+ result_is_index = True ,
19342043 periods = periods )
19352044
19362045 @Substitution (name = 'groupby' )
0 commit comments