@@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};
2424
2525use datafusion_common:: { DFSchema , Result } ;
2626use datafusion_expr:: {
27+ aggregate_function:: AggregateFunction :: { Max , Min , Sum } ,
2728 col,
2829 expr:: AggregateFunction ,
2930 logical_plan:: { Aggregate , LogicalPlan , Projection } ,
@@ -35,17 +36,19 @@ use hashbrown::HashSet;
3536
3637/// single distinct to group by optimizer rule
3738/// ```text
38- /// SELECT F1(DISTINCT s),F2(DISTINCT s)
39- /// ...
40- /// GROUP BY k
39+ /// Before:
40+ /// SELECT a, COUNT(DINSTINCT b), SUM(c)
41+ /// FROM t
42+ /// GROUP BY a
4143///
42- /// Into
43- ///
44- /// SELECT F1(alias1),F2(alias1)
44+ /// After:
45+ /// SELECT a, COUNT(alias1), SUM(alias2)
4546/// FROM (
46- /// SELECT s as alias1, k ... GROUP BY s, k
47+ /// SELECT a, b as alias1, SUM(c) as alias2
48+ /// FROM t
49+ /// GROUP BY a, b
4750/// )
48- /// GROUP BY k
51+ /// GROUP BY a
4952/// ```
5053#[ derive( Default ) ]
5154pub struct SingleDistinctToGroupBy { }
@@ -64,22 +67,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
6467 match plan {
6568 LogicalPlan :: Aggregate ( Aggregate { aggr_expr, .. } ) => {
6669 let mut fields_set = HashSet :: new ( ) ;
67- let mut distinct_count = 0 ;
70+ let mut aggregate_count = 0 ;
6871 for expr in aggr_expr {
6972 if let Expr :: AggregateFunction ( AggregateFunction {
70- distinct, args, ..
73+ fun,
74+ distinct,
75+ args,
76+ filter,
77+ order_by,
7178 } ) = expr
7279 {
73- if * distinct {
74- distinct_count += 1 ;
80+ if filter . is_some ( ) || order_by . is_some ( ) {
81+ return Ok ( false ) ;
7582 }
76- for e in args {
77- fields_set. insert ( e. canonical_name ( ) ) ;
83+ aggregate_count += 1 ;
84+ if * distinct {
85+ for e in args {
86+ fields_set. insert ( e. canonical_name ( ) ) ;
87+ }
88+ } else if !matches ! ( fun, Sum | Min | Max ) {
89+ return Ok ( false ) ;
7890 }
7991 }
8092 }
81- let res = distinct_count == aggr_expr. len ( ) && fields_set. len ( ) == 1 ;
82- Ok ( res)
93+ Ok ( aggregate_count == aggr_expr. len ( ) && fields_set. len ( ) == 1 )
8394 }
8495 _ => Ok ( false ) ,
8596 }
@@ -152,30 +163,57 @@ impl OptimizerRule for SingleDistinctToGroupBy {
152163 . collect :: < Vec < _ > > ( ) ;
153164
154165 // replace the distinct arg with alias
166+ let mut index = 1 ;
155167 let mut group_fields_set = HashSet :: new ( ) ;
156- let new_aggr_exprs = aggr_expr
168+ let mut inner_aggr_exprs = vec ! [ ] ;
169+ let outer_aggr_exprs = aggr_expr
157170 . iter ( )
158171 . map ( |aggr_expr| match aggr_expr {
159172 Expr :: AggregateFunction ( AggregateFunction {
160173 fun,
161174 args,
162- filter,
163- order_by,
175+ distinct,
164176 ..
165177 } ) => {
166178 // is_single_distinct_agg ensure args.len=1
167- if group_fields_set. insert ( args[ 0 ] . display_name ( ) ?) {
179+ if * distinct
180+ && group_fields_set. insert ( args[ 0 ] . display_name ( ) ?)
181+ {
168182 inner_group_exprs. push (
169183 args[ 0 ] . clone ( ) . alias ( SINGLE_DISTINCT_ALIAS ) ,
170184 ) ;
171185 }
172- Ok ( Expr :: AggregateFunction ( AggregateFunction :: new (
173- fun. clone ( ) ,
174- vec ! [ col( SINGLE_DISTINCT_ALIAS ) ] ,
175- false , // intentional to remove distinct here
176- filter. clone ( ) ,
177- order_by. clone ( ) ,
178- ) ) )
186+
187+ // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation
188+ if !( * distinct) {
189+ index += 1 ;
190+ let alias_str = format ! ( "alias{}" , index) ;
191+ inner_aggr_exprs. push (
192+ Expr :: AggregateFunction ( AggregateFunction :: new (
193+ fun. clone ( ) ,
194+ args. clone ( ) ,
195+ false ,
196+ None ,
197+ None ,
198+ ) )
199+ . alias ( & alias_str) ,
200+ ) ;
201+ Ok ( Expr :: AggregateFunction ( AggregateFunction :: new (
202+ fun. clone ( ) ,
203+ vec ! [ col( & alias_str) ] ,
204+ false ,
205+ None ,
206+ None ,
207+ ) ) )
208+ } else {
209+ Ok ( Expr :: AggregateFunction ( AggregateFunction :: new (
210+ fun. clone ( ) ,
211+ vec ! [ col( SINGLE_DISTINCT_ALIAS ) ] ,
212+ false , // intentional to remove distinct here
213+ None ,
214+ None ,
215+ ) ) )
216+ }
179217 }
180218 _ => Ok ( aggr_expr. clone ( ) ) ,
181219 } )
@@ -184,6 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
184222 // construct the inner AggrPlan
185223 let inner_fields = inner_group_exprs
186224 . iter ( )
225+ . chain ( inner_aggr_exprs. iter ( ) )
187226 . map ( |expr| expr. to_field ( input. schema ( ) ) )
188227 . collect :: < Result < Vec < _ > > > ( ) ?;
189228 let inner_schema = DFSchema :: new_with_metadata (
@@ -193,12 +232,12 @@ impl OptimizerRule for SingleDistinctToGroupBy {
193232 let inner_agg = LogicalPlan :: Aggregate ( Aggregate :: try_new (
194233 input. clone ( ) ,
195234 inner_group_exprs,
196- Vec :: new ( ) ,
235+ inner_aggr_exprs ,
197236 ) ?) ;
198237
199238 let outer_fields = outer_group_exprs
200239 . iter ( )
201- . chain ( new_aggr_exprs . iter ( ) )
240+ . chain ( outer_aggr_exprs . iter ( ) )
202241 . map ( |expr| expr. to_field ( & inner_schema) )
203242 . collect :: < Result < Vec < _ > > > ( ) ?;
204243 let outer_aggr_schema = Arc :: new ( DFSchema :: new_with_metadata (
@@ -220,7 +259,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
220259 group_expr
221260 }
222261 } )
223- . chain ( new_aggr_exprs . iter ( ) . enumerate ( ) . map ( |( idx, expr) | {
262+ . chain ( outer_aggr_exprs . iter ( ) . enumerate ( ) . map ( |( idx, expr) | {
224263 let idx = idx + group_size;
225264 let name = fields[ idx] . qualified_name ( ) ;
226265 columnize_expr ( expr. clone ( ) . alias ( name) , & outer_aggr_schema)
@@ -230,7 +269,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
230269 let outer_aggr = LogicalPlan :: Aggregate ( Aggregate :: try_new (
231270 Arc :: new ( inner_agg) ,
232271 outer_group_exprs,
233- new_aggr_exprs ,
272+ outer_aggr_exprs ,
234273 ) ?) ;
235274
236275 Ok ( Some ( LogicalPlan :: Projection ( Projection :: try_new (
@@ -262,7 +301,7 @@ mod tests {
262301 use datafusion_expr:: expr:: GroupingSet ;
263302 use datafusion_expr:: {
264303 col, count, count_distinct, lit, logical_plan:: builder:: LogicalPlanBuilder , max,
265- AggregateFunction ,
304+ min , sum , AggregateFunction ,
266305 } ;
267306
268307 fn assert_optimized_plan_equal ( plan : & LogicalPlan , expected : & str ) -> Result < ( ) > {
@@ -478,4 +517,181 @@ mod tests {
478517
479518 assert_optimized_plan_equal ( & plan, expected)
480519 }
520+
521+ #[ test]
522+ fn two_distinct_and_one_common ( ) -> Result < ( ) > {
523+ let table_scan = test_table_scan ( ) ?;
524+
525+ let plan = LogicalPlanBuilder :: from ( table_scan)
526+ . aggregate (
527+ vec ! [ col( "a" ) ] ,
528+ vec ! [
529+ sum( col( "c" ) ) ,
530+ count_distinct( col( "b" ) ) ,
531+ Expr :: AggregateFunction ( expr:: AggregateFunction :: new(
532+ AggregateFunction :: Max ,
533+ vec![ col( "b" ) ] ,
534+ true ,
535+ None ,
536+ None ,
537+ ) ) ,
538+ ] ,
539+ ) ?
540+ . build ( ) ?;
541+ // Should work
542+ let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
543+ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
544+ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
545+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
546+
547+ assert_optimized_plan_equal ( & plan, expected)
548+ }
549+
550+ #[ test]
551+ fn one_distinctand_and_two_common ( ) -> Result < ( ) > {
552+ let table_scan = test_table_scan ( ) ?;
553+
554+ let plan = LogicalPlanBuilder :: from ( table_scan)
555+ . aggregate (
556+ vec ! [ col( "a" ) ] ,
557+ vec ! [ sum( col( "c" ) ) , max( col( "c" ) ) , count_distinct( col( "b" ) ) ] ,
558+ ) ?
559+ . build ( ) ?;
560+ // Should work
561+ let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
562+ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\
563+ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
564+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
565+
566+ assert_optimized_plan_equal ( & plan, expected)
567+ }
568+
569+ #[ test]
570+ fn one_distinct_and_one_common ( ) -> Result < ( ) > {
571+ let table_scan = test_table_scan ( ) ?;
572+
573+ let plan = LogicalPlanBuilder :: from ( table_scan)
574+ . aggregate (
575+ vec ! [ col( "c" ) ] ,
576+ vec ! [ min( col( "a" ) ) , count_distinct( col( "b" ) ) ] ,
577+ ) ?
578+ . build ( ) ?;
579+ // Should work
580+ let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
581+ \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\
582+ \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
583+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
584+
585+ assert_optimized_plan_equal ( & plan, expected)
586+ }
587+
588+ #[ test]
589+ fn common_with_filter ( ) -> Result < ( ) > {
590+ let table_scan = test_table_scan ( ) ?;
591+
592+ // SUM(a) FILTER (WHERE a > 5)
593+ let expr = Expr :: AggregateFunction ( expr:: AggregateFunction :: new (
594+ AggregateFunction :: Sum ,
595+ vec ! [ col( "a" ) ] ,
596+ false ,
597+ Some ( Box :: new ( col ( "a" ) . gt ( lit ( 5 ) ) ) ) ,
598+ None ,
599+ ) ) ;
600+ let plan = LogicalPlanBuilder :: from ( table_scan)
601+ . aggregate ( vec ! [ col( "c" ) ] , vec ! [ expr, count_distinct( col( "b" ) ) ] ) ?
602+ . build ( ) ?;
603+ // Do nothing
604+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
605+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
606+
607+ assert_optimized_plan_equal ( & plan, expected)
608+ }
609+
610+ #[ test]
611+ fn distinct_with_filter ( ) -> Result < ( ) > {
612+ let table_scan = test_table_scan ( ) ?;
613+
614+ // COUNT(DISTINCT a) FILTER (WHERE a > 5)
615+ let expr = Expr :: AggregateFunction ( expr:: AggregateFunction :: new (
616+ AggregateFunction :: Count ,
617+ vec ! [ col( "a" ) ] ,
618+ true ,
619+ Some ( Box :: new ( col ( "a" ) . gt ( lit ( 5 ) ) ) ) ,
620+ None ,
621+ ) ) ;
622+ let plan = LogicalPlanBuilder :: from ( table_scan)
623+ . aggregate ( vec ! [ col( "c" ) ] , vec ! [ sum( col( "a" ) ) , expr] ) ?
624+ . build ( ) ?;
625+ // Do nothing
626+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\
627+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
628+
629+ assert_optimized_plan_equal ( & plan, expected)
630+ }
631+
632+ #[ test]
633+ fn common_with_order_by ( ) -> Result < ( ) > {
634+ let table_scan = test_table_scan ( ) ?;
635+
636+ // SUM(a ORDER BY a)
637+ let expr = Expr :: AggregateFunction ( expr:: AggregateFunction :: new (
638+ AggregateFunction :: Sum ,
639+ vec ! [ col( "a" ) ] ,
640+ false ,
641+ None ,
642+ Some ( vec ! [ col( "a" ) ] ) ,
643+ ) ) ;
644+ let plan = LogicalPlanBuilder :: from ( table_scan)
645+ . aggregate ( vec ! [ col( "c" ) ] , vec ! [ expr, count_distinct( col( "b" ) ) ] ) ?
646+ . build ( ) ?;
647+ // Do nothing
648+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
649+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
650+
651+ assert_optimized_plan_equal ( & plan, expected)
652+ }
653+
654+ #[ test]
655+ fn distinct_with_order_by ( ) -> Result < ( ) > {
656+ let table_scan = test_table_scan ( ) ?;
657+
658+ // COUNT(DISTINCT a ORDER BY a)
659+ let expr = Expr :: AggregateFunction ( expr:: AggregateFunction :: new (
660+ AggregateFunction :: Count ,
661+ vec ! [ col( "a" ) ] ,
662+ true ,
663+ None ,
664+ Some ( vec ! [ col( "a" ) ] ) ,
665+ ) ) ;
666+ let plan = LogicalPlanBuilder :: from ( table_scan)
667+ . aggregate ( vec ! [ col( "c" ) ] , vec ! [ sum( col( "a" ) ) , expr] ) ?
668+ . build ( ) ?;
669+ // Do nothing
670+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\
671+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
672+
673+ assert_optimized_plan_equal ( & plan, expected)
674+ }
675+
676+ #[ test]
677+ fn aggregate_with_filter_and_order_by ( ) -> Result < ( ) > {
678+ let table_scan = test_table_scan ( ) ?;
679+
680+ // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
681+ let expr = Expr :: AggregateFunction ( expr:: AggregateFunction :: new (
682+ AggregateFunction :: Count ,
683+ vec ! [ col( "a" ) ] ,
684+ true ,
685+ Some ( Box :: new ( col ( "a" ) . gt ( lit ( 5 ) ) ) ) ,
686+ Some ( vec ! [ col( "a" ) ] ) ,
687+ ) ) ;
688+ let plan = LogicalPlanBuilder :: from ( table_scan)
689+ . aggregate ( vec ! [ col( "c" ) ] , vec ! [ sum( col( "a" ) ) , expr] ) ?
690+ . build ( ) ?;
691+ // Do nothing
692+ let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
693+ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
694+
695+ assert_optimized_plan_equal ( & plan, expected)
696+ }
481697}
0 commit comments