Skip to content

Commit f29bcf3

Browse files
authored
Support no distinct aggregate sum/min/max in single_distinct_to_group_by rule (#8266)
* init impl * add some tests * add filter tests * minor * add more tests * update test
1 parent f8dcc64 commit f29bcf3

File tree

2 files changed

+330
-32
lines changed

2 files changed

+330
-32
lines changed

datafusion/optimizer/src/single_distinct_to_groupby.rs

Lines changed: 248 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};
2424

2525
use datafusion_common::{DFSchema, Result};
2626
use datafusion_expr::{
27+
aggregate_function::AggregateFunction::{Max, Min, Sum},
2728
col,
2829
expr::AggregateFunction,
2930
logical_plan::{Aggregate, LogicalPlan, Projection},
@@ -35,17 +36,19 @@ use hashbrown::HashSet;
3536

3637
/// single distinct to group by optimizer rule
3738
/// ```text
38-
/// SELECT F1(DISTINCT s),F2(DISTINCT s)
39-
/// ...
40-
/// GROUP BY k
39+
/// Before:
40+
/// SELECT a, COUNT(DISTINCT b), SUM(c)
41+
/// FROM t
42+
/// GROUP BY a
4143
///
42-
/// Into
43-
///
44-
/// SELECT F1(alias1),F2(alias1)
44+
/// After:
45+
/// SELECT a, COUNT(alias1), SUM(alias2)
4546
/// FROM (
46-
/// SELECT s as alias1, k ... GROUP BY s, k
47+
/// SELECT a, b as alias1, SUM(c) as alias2
48+
/// FROM t
49+
/// GROUP BY a, b
4750
/// )
48-
/// GROUP BY k
51+
/// GROUP BY a
4952
/// ```
5053
#[derive(Default)]
5154
pub struct SingleDistinctToGroupBy {}
@@ -64,22 +67,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
6467
match plan {
6568
LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
6669
let mut fields_set = HashSet::new();
67-
let mut distinct_count = 0;
70+
let mut aggregate_count = 0;
6871
for expr in aggr_expr {
6972
if let Expr::AggregateFunction(AggregateFunction {
70-
distinct, args, ..
73+
fun,
74+
distinct,
75+
args,
76+
filter,
77+
order_by,
7178
}) = expr
7279
{
73-
if *distinct {
74-
distinct_count += 1;
80+
if filter.is_some() || order_by.is_some() {
81+
return Ok(false);
7582
}
76-
for e in args {
77-
fields_set.insert(e.canonical_name());
83+
aggregate_count += 1;
84+
if *distinct {
85+
for e in args {
86+
fields_set.insert(e.canonical_name());
87+
}
88+
} else if !matches!(fun, Sum | Min | Max) {
89+
return Ok(false);
7890
}
7991
}
8092
}
81-
let res = distinct_count == aggr_expr.len() && fields_set.len() == 1;
82-
Ok(res)
93+
Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
8394
}
8495
_ => Ok(false),
8596
}
@@ -152,30 +163,57 @@ impl OptimizerRule for SingleDistinctToGroupBy {
152163
.collect::<Vec<_>>();
153164

154165
// replace the distinct arg with alias
166+
let mut index = 1;
155167
let mut group_fields_set = HashSet::new();
156-
let new_aggr_exprs = aggr_expr
168+
let mut inner_aggr_exprs = vec![];
169+
let outer_aggr_exprs = aggr_expr
157170
.iter()
158171
.map(|aggr_expr| match aggr_expr {
159172
Expr::AggregateFunction(AggregateFunction {
160173
fun,
161174
args,
162-
filter,
163-
order_by,
175+
distinct,
164176
..
165177
}) => {
166178
// is_single_distinct_agg ensures args.len() == 1
167-
if group_fields_set.insert(args[0].display_name()?) {
179+
if *distinct
180+
&& group_fields_set.insert(args[0].display_name()?)
181+
{
168182
inner_group_exprs.push(
169183
args[0].clone().alias(SINGLE_DISTINCT_ALIAS),
170184
);
171185
}
172-
Ok(Expr::AggregateFunction(AggregateFunction::new(
173-
fun.clone(),
174-
vec![col(SINGLE_DISTINCT_ALIAS)],
175-
false, // intentional to remove distinct here
176-
filter.clone(),
177-
order_by.clone(),
178-
)))
186+
187+
// if the aggregate function is not distinct, we need to rewrite it as a two-phase aggregation
188+
if !(*distinct) {
189+
index += 1;
190+
let alias_str = format!("alias{}", index);
191+
inner_aggr_exprs.push(
192+
Expr::AggregateFunction(AggregateFunction::new(
193+
fun.clone(),
194+
args.clone(),
195+
false,
196+
None,
197+
None,
198+
))
199+
.alias(&alias_str),
200+
);
201+
Ok(Expr::AggregateFunction(AggregateFunction::new(
202+
fun.clone(),
203+
vec![col(&alias_str)],
204+
false,
205+
None,
206+
None,
207+
)))
208+
} else {
209+
Ok(Expr::AggregateFunction(AggregateFunction::new(
210+
fun.clone(),
211+
vec![col(SINGLE_DISTINCT_ALIAS)],
212+
false, // intentional to remove distinct here
213+
None,
214+
None,
215+
)))
216+
}
179217
}
180218
_ => Ok(aggr_expr.clone()),
181219
})
@@ -184,6 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
184222
// construct the inner AggrPlan
185223
let inner_fields = inner_group_exprs
186224
.iter()
225+
.chain(inner_aggr_exprs.iter())
187226
.map(|expr| expr.to_field(input.schema()))
188227
.collect::<Result<Vec<_>>>()?;
189228
let inner_schema = DFSchema::new_with_metadata(
@@ -193,12 +232,12 @@ impl OptimizerRule for SingleDistinctToGroupBy {
193232
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
194233
input.clone(),
195234
inner_group_exprs,
196-
Vec::new(),
235+
inner_aggr_exprs,
197236
)?);
198237

199238
let outer_fields = outer_group_exprs
200239
.iter()
201-
.chain(new_aggr_exprs.iter())
240+
.chain(outer_aggr_exprs.iter())
202241
.map(|expr| expr.to_field(&inner_schema))
203242
.collect::<Result<Vec<_>>>()?;
204243
let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata(
@@ -220,7 +259,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
220259
group_expr
221260
}
222261
})
223-
.chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| {
262+
.chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| {
224263
let idx = idx + group_size;
225264
let name = fields[idx].qualified_name();
226265
columnize_expr(expr.clone().alias(name), &outer_aggr_schema)
@@ -230,7 +269,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
230269
let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
231270
Arc::new(inner_agg),
232271
outer_group_exprs,
233-
new_aggr_exprs,
272+
outer_aggr_exprs,
234273
)?);
235274

236275
Ok(Some(LogicalPlan::Projection(Projection::try_new(
@@ -262,7 +301,7 @@ mod tests {
262301
use datafusion_expr::expr::GroupingSet;
263302
use datafusion_expr::{
264303
col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max,
265-
AggregateFunction,
304+
min, sum, AggregateFunction,
266305
};
267306

268307
fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
@@ -478,4 +517,181 @@ mod tests {
478517

479518
assert_optimized_plan_equal(&plan, expected)
480519
}
520+
521+
#[test]
522+
fn two_distinct_and_one_common() -> Result<()> {
523+
let table_scan = test_table_scan()?;
524+
525+
let plan = LogicalPlanBuilder::from(table_scan)
526+
.aggregate(
527+
vec![col("a")],
528+
vec![
529+
sum(col("c")),
530+
count_distinct(col("b")),
531+
Expr::AggregateFunction(expr::AggregateFunction::new(
532+
AggregateFunction::Max,
533+
vec![col("b")],
534+
true,
535+
None,
536+
None,
537+
)),
538+
],
539+
)?
540+
.build()?;
541+
// Should work
542+
let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
543+
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
544+
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
545+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
546+
547+
assert_optimized_plan_equal(&plan, expected)
548+
}
549+
550+
#[test]
551+
fn one_distinctand_and_two_common() -> Result<()> {
552+
let table_scan = test_table_scan()?;
553+
554+
let plan = LogicalPlanBuilder::from(table_scan)
555+
.aggregate(
556+
vec![col("a")],
557+
vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
558+
)?
559+
.build()?;
560+
// Should work
561+
let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
562+
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\
563+
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
564+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
565+
566+
assert_optimized_plan_equal(&plan, expected)
567+
}
568+
569+
#[test]
570+
fn one_distinct_and_one_common() -> Result<()> {
571+
let table_scan = test_table_scan()?;
572+
573+
let plan = LogicalPlanBuilder::from(table_scan)
574+
.aggregate(
575+
vec![col("c")],
576+
vec![min(col("a")), count_distinct(col("b"))],
577+
)?
578+
.build()?;
579+
// Should work
580+
let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
581+
\n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\
582+
\n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
583+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
584+
585+
assert_optimized_plan_equal(&plan, expected)
586+
}
587+
588+
#[test]
589+
fn common_with_filter() -> Result<()> {
590+
let table_scan = test_table_scan()?;
591+
592+
// SUM(a) FILTER (WHERE a > 5)
593+
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
594+
AggregateFunction::Sum,
595+
vec![col("a")],
596+
false,
597+
Some(Box::new(col("a").gt(lit(5)))),
598+
None,
599+
));
600+
let plan = LogicalPlanBuilder::from(table_scan)
601+
.aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
602+
.build()?;
603+
// Do nothing
604+
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
605+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
606+
607+
assert_optimized_plan_equal(&plan, expected)
608+
}
609+
610+
#[test]
611+
fn distinct_with_filter() -> Result<()> {
612+
let table_scan = test_table_scan()?;
613+
614+
// COUNT(DISTINCT a) FILTER (WHERE a > 5)
615+
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
616+
AggregateFunction::Count,
617+
vec![col("a")],
618+
true,
619+
Some(Box::new(col("a").gt(lit(5)))),
620+
None,
621+
));
622+
let plan = LogicalPlanBuilder::from(table_scan)
623+
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
624+
.build()?;
625+
// Do nothing
626+
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\
627+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
628+
629+
assert_optimized_plan_equal(&plan, expected)
630+
}
631+
632+
#[test]
633+
fn common_with_order_by() -> Result<()> {
634+
let table_scan = test_table_scan()?;
635+
636+
// SUM(a ORDER BY a)
637+
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
638+
AggregateFunction::Sum,
639+
vec![col("a")],
640+
false,
641+
None,
642+
Some(vec![col("a")]),
643+
));
644+
let plan = LogicalPlanBuilder::from(table_scan)
645+
.aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
646+
.build()?;
647+
// Do nothing
648+
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
649+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
650+
651+
assert_optimized_plan_equal(&plan, expected)
652+
}
653+
654+
#[test]
655+
fn distinct_with_order_by() -> Result<()> {
656+
let table_scan = test_table_scan()?;
657+
658+
// COUNT(DISTINCT a ORDER BY a)
659+
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
660+
AggregateFunction::Count,
661+
vec![col("a")],
662+
true,
663+
None,
664+
Some(vec![col("a")]),
665+
));
666+
let plan = LogicalPlanBuilder::from(table_scan)
667+
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
668+
.build()?;
669+
// Do nothing
670+
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\
671+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
672+
673+
assert_optimized_plan_equal(&plan, expected)
674+
}
675+
676+
#[test]
677+
fn aggregate_with_filter_and_order_by() -> Result<()> {
678+
let table_scan = test_table_scan()?;
679+
680+
// COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
681+
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
682+
AggregateFunction::Count,
683+
vec![col("a")],
684+
true,
685+
Some(Box::new(col("a").gt(lit(5)))),
686+
Some(vec![col("a")]),
687+
));
688+
let plan = LogicalPlanBuilder::from(table_scan)
689+
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
690+
.build()?;
691+
// Do nothing
692+
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
693+
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
694+
695+
assert_optimized_plan_equal(&plan, expected)
696+
}
481697
}

0 commit comments

Comments
 (0)