@@ -25,9 +25,7 @@ use datafusion_expr::expr::{
2525 AggregateFunction , AggregateFunctionDefinition , WindowFunction ,
2626} ;
2727use datafusion_expr:: utils:: COUNT_STAR_EXPANSION ;
28- use datafusion_expr:: {
29- aggregate_function, lit, Expr , LogicalPlan , WindowFunctionDefinition ,
30- } ;
28+ use datafusion_expr:: { lit, Expr , LogicalPlan , WindowFunctionDefinition } ;
3129
3230/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`.
3331///
@@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool {
5654}
5755
5856fn is_count_star_aggregate ( aggregate_function : & AggregateFunction ) -> bool {
59- match aggregate_function {
57+ matches ! ( aggregate_function,
6058 AggregateFunction {
6159 func_def: AggregateFunctionDefinition :: UDF ( udf) ,
6260 args,
6361 ..
64- } if udf. name ( ) == "COUNT" && args. len ( ) == 1 && is_wildcard ( & args[ 0 ] ) => true ,
65- AggregateFunction {
66- func_def :
67- AggregateFunctionDefinition :: BuiltIn (
68- datafusion_expr:: aggregate_function:: AggregateFunction :: Count ,
69- ) ,
70- args,
71- ..
72- } if args. len ( ) == 1 && is_wildcard ( & args[ 0 ] ) => true ,
73- _ => false ,
74- }
62+ } if udf. name( ) == "COUNT" && args. len( ) == 1 && is_wildcard( & args[ 0 ] ) )
7563}
7664
7765fn is_count_star_window_aggregate ( window_function : & WindowFunction ) -> bool {
7866 let args = & window_function. args ;
79- match window_function. fun {
80- WindowFunctionDefinition :: AggregateFunction (
81- aggregate_function:: AggregateFunction :: Count ,
82- ) if args. len ( ) == 1 && is_wildcard ( & args[ 0 ] ) => true ,
67+ matches ! ( window_function. fun,
8368 WindowFunctionDefinition :: AggregateUDF ( ref udaf)
84- if udaf. name ( ) == "COUNT" && args. len ( ) == 1 && is_wildcard ( & args[ 0 ] ) =>
85- {
86- true
87- }
88- _ => false ,
89- }
69+ if udaf. name( ) == "COUNT" && args. len( ) == 1 && is_wildcard( & args[ 0 ] ) )
9070}
9171
9272fn analyze_internal ( plan : LogicalPlan ) -> Result < Transformed < LogicalPlan > > {
@@ -121,14 +101,16 @@ mod tests {
121101 use arrow:: datatypes:: DataType ;
122102 use datafusion_common:: ScalarValue ;
123103 use datafusion_expr:: expr:: Sort ;
124- use datafusion_expr:: test:: function_stub:: sum;
125104 use datafusion_expr:: {
126- col, count , exists, expr, in_subquery, logical_plan:: LogicalPlanBuilder , max,
127- out_ref_col, scalar_subquery, wildcard, AggregateFunction , WindowFrame ,
128- WindowFrameBound , WindowFrameUnits ,
105+ col, exists, expr, in_subquery, logical_plan:: LogicalPlanBuilder , max,
106+ out_ref_col, scalar_subquery, wildcard, WindowFrame , WindowFrameBound ,
107+ WindowFrameUnits ,
129108 } ;
109+ use datafusion_functions_aggregate:: count:: count_udaf;
130110 use std:: sync:: Arc ;
131111
112+ use datafusion_functions_aggregate:: expr_fn:: { count, sum} ;
113+
132114 fn assert_plan_eq ( plan : LogicalPlan , expected : & str ) -> Result < ( ) > {
133115 assert_analyzed_plan_eq_display_indent (
134116 Arc :: new ( CountWildcardRule :: new ( ) ) ,
@@ -239,7 +221,7 @@ mod tests {
239221
240222 let plan = LogicalPlanBuilder :: from ( table_scan)
241223 . window ( vec ! [ Expr :: WindowFunction ( expr:: WindowFunction :: new(
242- WindowFunctionDefinition :: AggregateFunction ( AggregateFunction :: Count ) ,
224+ WindowFunctionDefinition :: AggregateUDF ( count_udaf ( ) ) ,
243225 vec![ wildcard( ) ] ,
244226 vec![ ] ,
245227 vec![ Expr :: Sort ( Sort :: new( Box :: new( col( "a" ) ) , false , true ) ) ] ,
0 commit comments