Skip to content

Commit 8729398

Browse files
squash
1 parent 3bda91a commit 8729398

File tree

13 files changed

+594
-32
lines changed

13 files changed

+594
-32
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! An optimizer rule that detects aggregate operations that could use a limited bucket count
19+
20+
use crate::physical_optimizer::PhysicalOptimizerRule;
21+
use crate::physical_plan::aggregates::AggregateExec;
22+
use crate::physical_plan::sorts::sort::SortExec;
23+
use crate::physical_plan::ExecutionPlan;
24+
use datafusion_common::config::ConfigOptions;
25+
use datafusion_common::{DataFusionError, Result};
26+
use std::sync::Arc;
27+
28+
/// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed
29+
pub struct LimitAggregation {}
30+
31+
impl LimitAggregation {
32+
/// Create a new `LimitAggregation`
33+
pub fn new() -> Self {
34+
Self {}
35+
}
36+
37+
fn recurse(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
38+
// Not a sort
39+
let sort = if let Some(sort) = plan.as_any().downcast_ref::<SortExec>() {
40+
sort
41+
} else {
42+
return Ok(plan);
43+
};
44+
45+
// Error if sorting with two inputs
46+
let children = sort.children();
47+
let child = match children.as_slice() {
48+
[] => Err(DataFusionError::Execution(
49+
"Sorts should have children".to_string(),
50+
))?,
51+
[child] => child,
52+
_ => Err(DataFusionError::Execution(
53+
"Sorts should have 1 child".to_string(),
54+
))?,
55+
};
56+
57+
// Sort doesn't have an aggregate before it
58+
let binding = (*child).as_any();
59+
let aggr = if let Some(aggr) = binding.downcast_ref::<AggregateExec>() {
60+
aggr
61+
} else {
62+
return Ok(plan);
63+
};
64+
65+
// We found what we want: clone, copy the limit down, and return modified node
66+
let mut new_aggr = AggregateExec::try_new(
67+
aggr.mode,
68+
aggr.group_by.clone(),
69+
aggr.aggr_expr.clone(),
70+
aggr.filter_expr.clone(),
71+
aggr.order_by_expr.clone(),
72+
aggr.input.clone(),
73+
aggr.input_schema.clone(),
74+
)?;
75+
new_aggr.limit = sort.fetch();
76+
let plan = Arc::new(SortExec::new(sort.expr().to_vec(), Arc::new(new_aggr)));
77+
Ok(plan)
78+
}
79+
}
80+
81+
impl Default for LimitAggregation {
82+
fn default() -> Self {
83+
Self::new()
84+
}
85+
}
86+
87+
impl PhysicalOptimizerRule for LimitAggregation {
88+
fn optimize(
89+
&self,
90+
plan: Arc<dyn ExecutionPlan>,
91+
_config: &ConfigOptions,
92+
) -> Result<Arc<dyn ExecutionPlan>> {
93+
LimitAggregation::recurse(plan.clone())
94+
}
95+
96+
fn name(&self) -> &str {
97+
"limit aggregation"
98+
}
99+
100+
fn schema_check(&self) -> bool {
101+
true
102+
}
103+
}

datafusion/core/src/physical_optimizer/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod coalesce_batches;
2626
pub mod combine_partial_final_agg;
2727
pub mod dist_enforcement;
2828
pub mod join_selection;
29+
pub mod limit_aggregation;
2930
pub mod optimizer;
3031
pub mod pipeline_checker;
3132
pub mod pruning;

datafusion/core/src/physical_optimizer/optimizer.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use crate::physical_optimizer::coalesce_batches::CoalesceBatches;
2525
use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate;
2626
use crate::physical_optimizer::dist_enforcement::EnforceDistribution;
2727
use crate::physical_optimizer::join_selection::JoinSelection;
28+
use crate::physical_optimizer::limit_aggregation::LimitAggregation;
2829
use crate::physical_optimizer::pipeline_checker::PipelineChecker;
2930
use crate::physical_optimizer::repartition::Repartition;
3031
use crate::physical_optimizer::sort_enforcement::EnforceSorting;
@@ -101,6 +102,7 @@ impl PhysicalOptimizer {
101102
// diagnostic error message when this happens. It makes no changes to the
102103
// given query plan; i.e. it only acts as a final gatekeeping rule.
103104
Arc::new(PipelineChecker::new()),
105+
Arc::new(LimitAggregation::new()),
104106
];
105107

106108
Self::with_rules(rules)

datafusion/core/src/physical_plan/aggregates/mod.rs

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ use std::sync::Arc;
4949
mod group_values;
5050
mod no_grouping;
5151
mod order;
52+
mod priority_queue;
5253
mod row_hash;
5354

55+
use crate::physical_plan::aggregates::priority_queue::GroupedPriorityQueueAggregateStream;
5456
pub use datafusion_expr::AggregateFunction;
5557
use datafusion_physical_expr::aggregate::is_order_sensitive;
5658
pub use datafusion_physical_expr::expressions::create_aggregate_expr;
@@ -228,14 +230,16 @@ impl PartialEq for PhysicalGroupBy {
228230

229231
enum StreamType {
230232
AggregateStream(AggregateStream),
231-
GroupedHashAggregateStream(GroupedHashAggregateStream),
233+
GroupedHash(GroupedHashAggregateStream),
234+
GroupedPriorityQueue(GroupedPriorityQueueAggregateStream),
232235
}
233236

234237
impl From<StreamType> for SendableRecordBatchStream {
235238
fn from(stream: StreamType) -> Self {
236239
match stream {
237240
StreamType::AggregateStream(stream) => Box::pin(stream),
238-
StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream),
241+
StreamType::GroupedHash(stream) => Box::pin(stream),
242+
StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
239243
}
240244
}
241245
}
@@ -265,6 +269,8 @@ pub struct AggregateExec {
265269
pub(crate) filter_expr: Vec<Option<Arc<dyn PhysicalExpr>>>,
266270
/// (ORDER BY clause) expression for each aggregate expression
267271
pub(crate) order_by_expr: Vec<Option<LexOrdering>>,
272+
/// Set if the output of this aggregation is truncated by a upstream sort/limit clause
273+
pub(crate) limit: Option<usize>,
268274
/// Input plan, could be a partial aggregate or the input to the aggregate
269275
pub(crate) input: Arc<dyn ExecutionPlan>,
270276
/// Schema after the aggregate is applied
@@ -669,6 +675,7 @@ impl AggregateExec {
669675
metrics: ExecutionPlanMetricsSet::new(),
670676
aggregation_ordering,
671677
required_input_ordering,
678+
limit: None,
672679
})
673680
}
674681

@@ -717,15 +724,29 @@ impl AggregateExec {
717724
partition: usize,
718725
context: Arc<TaskContext>,
719726
) -> Result<StreamType> {
727+
// no group by at all
720728
if self.group_by.expr.is_empty() {
721-
Ok(StreamType::AggregateStream(AggregateStream::new(
729+
return Ok(StreamType::AggregateStream(AggregateStream::new(
722730
self, context, partition,
723-
)?))
724-
} else {
725-
Ok(StreamType::GroupedHashAggregateStream(
726-
GroupedHashAggregateStream::new(self, context, partition)?,
727-
))
731+
)?));
732+
}
733+
734+
// grouping by an expression that has a sort/limit upstream
735+
let is_minmax =
736+
GroupedPriorityQueueAggregateStream::get_minmax_desc(self).is_some();
737+
if self.limit.is_some() && is_minmax {
738+
println!("Using limited priority queue aggregation");
739+
return Ok(StreamType::GroupedPriorityQueue(
740+
GroupedPriorityQueueAggregateStream::new(
741+
self, context, partition, self.limit,
742+
)?,
743+
));
728744
}
745+
746+
// grouping by something else and we need to just materialize all results
747+
Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
748+
self, context, partition,
749+
)?))
729750
}
730751
}
731752

@@ -1148,7 +1169,7 @@ fn evaluate(
11481169
}
11491170

11501171
/// Evaluates expressions against a record batch.
1151-
fn evaluate_many(
1172+
pub fn evaluate_many(
11521173
expr: &[Vec<Arc<dyn PhysicalExpr>>],
11531174
batch: &RecordBatch,
11541175
) -> Result<Vec<Vec<ArrayRef>>> {
@@ -1171,7 +1192,17 @@ fn evaluate_optional(
11711192
.collect::<Result<Vec<_>>>()
11721193
}
11731194

1174-
fn evaluate_group_by(
1195+
/// Evaluate a group by expression against a `RecordBatch`
1196+
///
1197+
/// Arguments:
1198+
/// `group_by`: the expression to evaluate
1199+
/// `batch`: the `RecordBatch` to evaluate against
1200+
///
1201+
/// Returns: A Vec of Vecs of Array of results
1202+
/// The outer Vect appears to be for grouping sets
1203+
/// The inner Vect contains the results per expression
1204+
/// The inner-inner Array contains the results per row
1205+
pub fn evaluate_group_by(
11751206
group_by: &PhysicalGroupBy,
11761207
batch: &RecordBatch,
11771208
) -> Result<Vec<Vec<ArrayRef>>> {
@@ -1840,10 +1871,10 @@ mod tests {
18401871
assert!(matches!(stream, StreamType::AggregateStream(_)));
18411872
}
18421873
1 => {
1843-
assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
1874+
assert!(matches!(stream, StreamType::GroupedHash(_)));
18441875
}
18451876
2 => {
1846-
assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_)));
1877+
assert!(matches!(stream, StreamType::GroupedHash(_)));
18471878
}
18481879
_ => panic!("Unknown version: {version}"),
18491880
}

0 commit comments

Comments
 (0)