Skip to content

Commit 37dc663

Browse files
committed
Add grouping_id to the logical plan
1 parent e82e069 commit 37dc663

File tree

6 files changed

+99
-97
lines changed

6 files changed

+99
-97
lines changed

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::cmp::Ordering;
2121
use std::collections::{HashMap, HashSet};
2222
use std::fmt::{self, Debug, Display, Formatter};
2323
use std::hash::{Hash, Hasher};
24-
use std::sync::Arc;
24+
use std::sync::{Arc, OnceLock};
2525

2626
use super::dml::CopyTo;
2727
use super::DdlStatement;
@@ -2964,6 +2964,10 @@ impl Aggregate {
29642964
.into_iter()
29652965
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
29662966
.collect::<Vec<_>>();
2967+
qualified_fields.push((
2968+
None,
2969+
Field::new(Self::INTERNAL_GROUPING_ID, DataType::UInt8, false).into(),
2970+
));
29672971
}
29682972

29692973
qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
@@ -3015,9 +3019,19 @@ impl Aggregate {
30153019
})
30163020
}
30173021

3022+
fn is_grouping_set(&self) -> bool {
3023+
matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
3024+
}
3025+
30183026
/// Get the output expressions.
30193027
fn output_expressions(&self) -> Result<Vec<&Expr>> {
3028+
static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
30203029
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
3030+
if self.is_grouping_set() {
3031+
exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
3032+
Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
3033+
}));
3034+
}
30213035
exprs.extend(self.aggr_expr.iter());
30223036
debug_assert!(exprs.len() == self.schema.fields().len());
30233037
Ok(exprs)
@@ -3029,6 +3043,8 @@ impl Aggregate {
30293043
pub fn group_expr_len(&self) -> Result<usize> {
30303044
grouping_set_expr_count(&self.group_expr)
30313045
}
3046+
3047+
pub const INTERNAL_GROUPING_ID: &str = "__grouping_id";
30323048
}
30333049

30343050
// Manual implementation needed because of `schema` field. Comparison excludes this field.

datafusion/expr/src/utils.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
6565
"Invalid group by expressions, GroupingSet must be the only expression"
6666
);
6767
}
68-
Ok(grouping_set.distinct_expr().len())
68+
// Groupings sets have an additional interal column for the grouping id
69+
Ok(grouping_set.distinct_expr().len() + 1)
6970
} else {
7071
Ok(group_expr.len())
7172
}

datafusion/physical-optimizer/src/combine_partial_final_agg.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool {
135135

136136
// Compare output expressions of the partial, and input expressions of the final operator.
137137
physical_exprs_equal(
138-
&input_group_by.output_exprs(&AggregateMode::Partial),
138+
&input_group_by.output_exprs(),
139139
&final_group_by.input_exprs(),
140140
) && input_group_by.groups() == final_group_by.groups()
141141
&& input_group_by.null_expr().len() == final_group_by.null_expr().len()

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 66 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ use arrow_schema::DataType;
4040
use datafusion_common::stats::Precision;
4141
use datafusion_common::{internal_err, not_impl_err, Result};
4242
use datafusion_execution::TaskContext;
43-
use datafusion_expr::Accumulator;
43+
use datafusion_expr::{Accumulator, Aggregate};
4444
use datafusion_physical_expr::{
4545
equivalence::{collapse_lex_req, ProjectionMapping},
4646
expressions::Column,
@@ -110,8 +110,6 @@ impl AggregateMode {
110110
}
111111
}
112112

113-
const INTERNAL_GROUPING_ID: &str = "grouping_id";
114-
115113
/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
116114
/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
117115
/// and a single group [false, false].
@@ -141,10 +139,6 @@ pub struct PhysicalGroupBy {
141139
/// expression in null_expr. If `groups[i][j]` is true, then the
142140
/// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
143141
groups: Vec<Vec<bool>>,
144-
// The number of internal expressions that are used to implement grouping
145-
// sets. These output are removed from the final output and not in `expr`
146-
// as they are generated based on the value in `groups`
147-
num_internal_exprs: usize,
148142
}
149143

150144
impl PhysicalGroupBy {
@@ -154,12 +148,10 @@ impl PhysicalGroupBy {
154148
null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
155149
groups: Vec<Vec<bool>>,
156150
) -> Self {
157-
let num_internal_exprs = if !null_expr.is_empty() { 1 } else { 0 };
158151
Self {
159152
expr,
160153
null_expr,
161154
groups,
162-
num_internal_exprs,
163155
}
164156
}
165157

@@ -171,7 +163,6 @@ impl PhysicalGroupBy {
171163
expr,
172164
null_expr: vec![],
173165
groups: vec![vec![false; num_exprs]],
174-
num_internal_exprs: 0,
175166
}
176167
}
177168

@@ -222,20 +213,17 @@ impl PhysicalGroupBy {
222213
}
223214

224215
/// The number of expressions in the output schema.
225-
fn num_output_exprs(&self, mode: &AggregateMode) -> usize {
216+
fn num_output_exprs(&self) -> usize {
226217
let mut num_exprs = self.expr.len();
227218
if !self.is_single() {
228-
num_exprs += self.num_internal_exprs;
229-
}
230-
if *mode != AggregateMode::Partial {
231-
num_exprs -= self.num_internal_exprs;
219+
num_exprs += 1
232220
}
233221
num_exprs
234222
}
235223

236224
/// Return grouping expressions as they occur in the output schema.
237-
pub fn output_exprs(&self, mode: &AggregateMode) -> Vec<Arc<dyn PhysicalExpr>> {
238-
let num_output_exprs = self.num_output_exprs(mode);
225+
pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
226+
let num_output_exprs = self.num_output_exprs();
239227
let mut output_exprs = Vec::with_capacity(num_output_exprs);
240228
output_exprs.extend(
241229
self.expr
@@ -244,9 +232,11 @@ impl PhysicalGroupBy {
244232
.take(num_output_exprs)
245233
.map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
246234
);
247-
if !self.is_single() && *mode == AggregateMode::Partial {
248-
output_exprs
249-
.push(Arc::new(Column::new(INTERNAL_GROUPING_ID, self.expr.len())) as _);
235+
if !self.is_single() {
236+
output_exprs.push(Arc::new(Column::new(
237+
Aggregate::INTERNAL_GROUPING_ID,
238+
self.expr.len(),
239+
)) as _);
250240
}
251241
output_exprs
252242
}
@@ -256,7 +246,7 @@ impl PhysicalGroupBy {
256246
if self.is_single() {
257247
self.expr.len()
258248
} else {
259-
self.expr.len() + self.num_internal_exprs
249+
self.expr.len() + 1
260250
}
261251
}
262252

@@ -290,7 +280,7 @@ impl PhysicalGroupBy {
290280
}
291281
if !self.is_single() {
292282
fields.push(Field::new(
293-
INTERNAL_GROUPING_ID,
283+
Aggregate::INTERNAL_GROUPING_ID,
294284
self.grouping_id_type(),
295285
false,
296286
));
@@ -302,35 +292,29 @@ impl PhysicalGroupBy {
302292
///
303293
/// This might be different from the `group_fields` that might contain internal expressions that
304294
/// should not be part of the output schema.
305-
fn output_fields(
306-
&self,
307-
input_schema: &Schema,
308-
mode: &AggregateMode,
309-
) -> Result<Vec<Field>> {
295+
fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
310296
let mut fields = self.group_fields(input_schema)?;
311-
fields.truncate(self.num_output_exprs(mode));
297+
fields.truncate(self.num_output_exprs());
312298
Ok(fields)
313299
}
314300

315301
/// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial
316302
/// aggregation.
317303
pub fn as_final(&self) -> PhysicalGroupBy {
318-
let expr: Vec<_> = self
319-
.output_exprs(&AggregateMode::Partial)
320-
.into_iter()
321-
.zip(
322-
self.expr
323-
.iter()
324-
.map(|t| t.1.clone())
325-
.chain(std::iter::once(INTERNAL_GROUPING_ID.to_owned())),
326-
)
327-
.collect();
304+
let expr: Vec<_> =
305+
self.output_exprs()
306+
.into_iter()
307+
.zip(
308+
self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once(
309+
Aggregate::INTERNAL_GROUPING_ID.to_owned(),
310+
)),
311+
)
312+
.collect();
328313
let num_exprs = expr.len();
329314
Self {
330315
expr,
331316
null_expr: vec![],
332317
groups: vec![vec![false; num_exprs]],
333-
num_internal_exprs: self.num_internal_exprs,
334318
}
335319
}
336320
}
@@ -567,7 +551,7 @@ impl AggregateExec {
567551

568552
/// Grouping expressions as they occur in the output schema
569553
pub fn output_group_expr(&self) -> Vec<Arc<dyn PhysicalExpr>> {
570-
self.group_by.output_exprs(&AggregateMode::Partial)
554+
self.group_by.output_exprs()
571555
}
572556

573557
/// Aggregate expressions
@@ -901,9 +885,8 @@ fn create_schema(
901885
aggr_expr: &[AggregateFunctionExpr],
902886
mode: AggregateMode,
903887
) -> Result<Schema> {
904-
let mut fields =
905-
Vec::with_capacity(group_by.num_output_exprs(&mode) + aggr_expr.len());
906-
fields.extend(group_by.output_fields(input_schema, &mode)?);
888+
let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len());
889+
fields.extend(group_by.output_fields(input_schema)?);
907890

908891
match mode {
909892
AggregateMode::Partial => {
@@ -1506,49 +1489,49 @@ mod tests {
15061489
// In spill mode, we test with the limited memory, if the mem usage exceeds,
15071490
// we trigger the early emit rule, which turns out the partial aggregate result.
15081491
vec![
1509-
"+---+-----+-------------+-----------------+",
1510-
"| a | b | grouping_id | COUNT(1)[count] |",
1511-
"+---+-----+-------------+-----------------+",
1512-
"| | 1.0 | 2 | 1 |",
1513-
"| | 1.0 | 2 | 1 |",
1514-
"| | 2.0 | 2 | 1 |",
1515-
"| | 2.0 | 2 | 1 |",
1516-
"| | 3.0 | 2 | 1 |",
1517-
"| | 3.0 | 2 | 1 |",
1518-
"| | 4.0 | 2 | 1 |",
1519-
"| | 4.0 | 2 | 1 |",
1520-
"| 2 | | 1 | 1 |",
1521-
"| 2 | | 1 | 1 |",
1522-
"| 2 | 1.0 | 0 | 1 |",
1523-
"| 2 | 1.0 | 0 | 1 |",
1524-
"| 3 | | 1 | 1 |",
1525-
"| 3 | | 1 | 2 |",
1526-
"| 3 | 2.0 | 0 | 2 |",
1527-
"| 3 | 3.0 | 0 | 1 |",
1528-
"| 4 | | 1 | 1 |",
1529-
"| 4 | | 1 | 2 |",
1530-
"| 4 | 3.0 | 0 | 1 |",
1531-
"| 4 | 4.0 | 0 | 2 |",
1532-
"+---+-----+-------------+-----------------+",
1492+
"+---+-----+---------------+-----------------+",
1493+
"| a | b | __grouping_id | COUNT(1)[count] |",
1494+
"+---+-----+---------------+-----------------+",
1495+
"| | 1.0 | 2 | 1 |",
1496+
"| | 1.0 | 2 | 1 |",
1497+
"| | 2.0 | 2 | 1 |",
1498+
"| | 2.0 | 2 | 1 |",
1499+
"| | 3.0 | 2 | 1 |",
1500+
"| | 3.0 | 2 | 1 |",
1501+
"| | 4.0 | 2 | 1 |",
1502+
"| | 4.0 | 2 | 1 |",
1503+
"| 2 | | 1 | 1 |",
1504+
"| 2 | | 1 | 1 |",
1505+
"| 2 | 1.0 | 0 | 1 |",
1506+
"| 2 | 1.0 | 0 | 1 |",
1507+
"| 3 | | 1 | 1 |",
1508+
"| 3 | | 1 | 2 |",
1509+
"| 3 | 2.0 | 0 | 2 |",
1510+
"| 3 | 3.0 | 0 | 1 |",
1511+
"| 4 | | 1 | 1 |",
1512+
"| 4 | | 1 | 2 |",
1513+
"| 4 | 3.0 | 0 | 1 |",
1514+
"| 4 | 4.0 | 0 | 2 |",
1515+
"+---+-----+---------------+-----------------+",
15331516
]
15341517
} else {
15351518
vec![
1536-
"+---+-----+-------------+-----------------+",
1537-
"| a | b | grouping_id | COUNT(1)[count] |",
1538-
"+---+-----+-------------+-----------------+",
1539-
"| | 1.0 | 2 | 2 |",
1540-
"| | 2.0 | 2 | 2 |",
1541-
"| | 3.0 | 2 | 2 |",
1542-
"| | 4.0 | 2 | 2 |",
1543-
"| 2 | | 1 | 2 |",
1544-
"| 2 | 1.0 | 0 | 2 |",
1545-
"| 3 | | 1 | 3 |",
1546-
"| 3 | 2.0 | 0 | 2 |",
1547-
"| 3 | 3.0 | 0 | 1 |",
1548-
"| 4 | | 1 | 3 |",
1549-
"| 4 | 3.0 | 0 | 1 |",
1550-
"| 4 | 4.0 | 0 | 2 |",
1551-
"+---+-----+-------------+-----------------+",
1519+
"+---+-----+---------------+-----------------+",
1520+
"| a | b | __grouping_id | COUNT(1)[count] |",
1521+
"+---+-----+---------------+-----------------+",
1522+
"| | 1.0 | 2 | 2 |",
1523+
"| | 2.0 | 2 | 2 |",
1524+
"| | 3.0 | 2 | 2 |",
1525+
"| | 4.0 | 2 | 2 |",
1526+
"| 2 | | 1 | 2 |",
1527+
"| 2 | 1.0 | 0 | 2 |",
1528+
"| 3 | | 1 | 3 |",
1529+
"| 3 | 2.0 | 0 | 2 |",
1530+
"| 3 | 3.0 | 0 | 1 |",
1531+
"| 4 | | 1 | 3 |",
1532+
"| 4 | 3.0 | 0 | 1 |",
1533+
"| 4 | 4.0 | 0 | 2 |",
1534+
"+---+-----+---------------+-----------------+",
15521535
]
15531536
};
15541537
assert_batches_sorted_eq!(expected, &result);

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ impl GroupedHashAggregateStream {
491491
let (ordering, _) = agg
492492
.properties()
493493
.equivalence_properties()
494-
.find_longest_permutation(&agg_group_by.output_exprs(&agg.mode));
494+
.find_longest_permutation(&agg_group_by.output_exprs());
495495
let group_ordering = GroupOrdering::try_new(
496496
&group_schema,
497497
&agg.input_order_mode,
@@ -845,7 +845,7 @@ impl GroupedHashAggregateStream {
845845

846846
let mut output = self.group_values.emit(emit_to)?;
847847
if !spilling {
848-
output.truncate(self.group_by.num_output_exprs(&self.mode));
848+
output.truncate(self.group_by.num_output_exprs());
849849
}
850850
if let EmitTo::First(n) = emit_to {
851851
self.group_ordering.remove_groups(n);

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4883,16 +4883,18 @@ query TT
48834883
EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;
48844884
----
48854885
logical_plan
4886-
01)Limit: skip=0, fetch=3
4887-
02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
4888-
03)----TableScan: aggregate_test_100 projection=[c2, c3]
4886+
01)Projection: aggregate_test_100.c2, aggregate_test_100.c3
4887+
02)--Limit: skip=0, fetch=3
4888+
03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]]
4889+
04)------TableScan: aggregate_test_100 projection=[c2, c3]
48894890
physical_plan
4890-
01)GlobalLimitExec: skip=0, fetch=3
4891-
02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, grouping_id@2 as grouping_id], aggr=[], lim=[3]
4892-
03)----CoalescePartitionsExec
4893-
04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
4894-
05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
4895-
06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true
4891+
01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3]
4892+
02)--GlobalLimitExec: skip=0, fetch=3
4893+
03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, __grouping_id@2 as __grouping_id], aggr=[], lim=[3]
4894+
04)------CoalescePartitionsExec
4895+
05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[]
4896+
06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
4897+
07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true
48964898

48974899
query II
48984900
SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3;

0 commit comments

Comments
 (0)