@@ -49,8 +49,10 @@ use std::sync::Arc;
4949mod group_values;
5050mod no_grouping;
5151mod order;
52+ mod priority_queue;
5253mod row_hash;
5354
55+ use crate :: physical_plan:: aggregates:: priority_queue:: GroupedPriorityQueueAggregateStream ;
5456pub use datafusion_expr:: AggregateFunction ;
5557use datafusion_physical_expr:: aggregate:: is_order_sensitive;
5658pub use datafusion_physical_expr:: expressions:: create_aggregate_expr;
@@ -228,14 +230,16 @@ impl PartialEq for PhysicalGroupBy {
228230
229231enum StreamType {
230232 AggregateStream ( AggregateStream ) ,
231- GroupedHashAggregateStream ( GroupedHashAggregateStream ) ,
233+ GroupedHash ( GroupedHashAggregateStream ) ,
234+ GroupedPriorityQueue ( GroupedPriorityQueueAggregateStream ) ,
232235}
233236
234237impl From < StreamType > for SendableRecordBatchStream {
235238 fn from ( stream : StreamType ) -> Self {
236239 match stream {
237240 StreamType :: AggregateStream ( stream) => Box :: pin ( stream) ,
238- StreamType :: GroupedHashAggregateStream ( stream) => Box :: pin ( stream) ,
241+ StreamType :: GroupedHash ( stream) => Box :: pin ( stream) ,
242+ StreamType :: GroupedPriorityQueue ( stream) => Box :: pin ( stream) ,
239243 }
240244 }
241245}
@@ -265,6 +269,8 @@ pub struct AggregateExec {
265269 pub ( crate ) filter_expr : Vec < Option < Arc < dyn PhysicalExpr > > > ,
266270 /// (ORDER BY clause) expression for each aggregate expression
267271 pub ( crate ) order_by_expr : Vec < Option < LexOrdering > > ,
272+ /// Set if the output of this aggregation is truncated by an upstream sort/limit clause
273+ pub ( crate ) limit : Option < usize > ,
268274 /// Input plan, could be a partial aggregate or the input to the aggregate
269275 pub ( crate ) input : Arc < dyn ExecutionPlan > ,
270276 /// Schema after the aggregate is applied
@@ -669,6 +675,7 @@ impl AggregateExec {
669675 metrics : ExecutionPlanMetricsSet :: new ( ) ,
670676 aggregation_ordering,
671677 required_input_ordering,
678+ limit : None ,
672679 } )
673680 }
674681
@@ -717,15 +724,29 @@ impl AggregateExec {
717724 partition : usize ,
718725 context : Arc < TaskContext > ,
719726 ) -> Result < StreamType > {
727+ // no group by at all
720728 if self . group_by . expr . is_empty ( ) {
721- Ok ( StreamType :: AggregateStream ( AggregateStream :: new (
729+ return Ok ( StreamType :: AggregateStream ( AggregateStream :: new (
722730 self , context, partition,
723- ) ?) )
724- } else {
725- Ok ( StreamType :: GroupedHashAggregateStream (
726- GroupedHashAggregateStream :: new ( self , context, partition) ?,
727- ) )
731+ ) ?) ) ;
732+ }
733+
734+ // grouping by an expression that has a sort/limit upstream
735+ let is_minmax =
736+ GroupedPriorityQueueAggregateStream :: get_minmax_desc ( self ) . is_some ( ) ;
737+ if self . limit . is_some ( ) && is_minmax {
738+ println ! ( "Using limited priority queue aggregation" ) ;
739+ return Ok ( StreamType :: GroupedPriorityQueue (
740+ GroupedPriorityQueueAggregateStream :: new (
741+ self , context, partition, self . limit ,
742+ ) ?,
743+ ) ) ;
728744 }
745+
746+ // grouping by something else and we need to just materialize all results
747+ Ok ( StreamType :: GroupedHash ( GroupedHashAggregateStream :: new (
748+ self , context, partition,
749+ ) ?) )
729750 }
730751}
731752
@@ -1148,7 +1169,7 @@ fn evaluate(
11481169}
11491170
11501171/// Evaluates expressions against a record batch.
1151- fn evaluate_many (
1172+ pub fn evaluate_many (
11521173 expr : & [ Vec < Arc < dyn PhysicalExpr > > ] ,
11531174 batch : & RecordBatch ,
11541175) -> Result < Vec < Vec < ArrayRef > > > {
@@ -1171,7 +1192,17 @@ fn evaluate_optional(
11711192 . collect :: < Result < Vec < _ > > > ( )
11721193}
11731194
1174- fn evaluate_group_by (
1195+ /// Evaluate a group by expression against a `RecordBatch`
1196+ ///
1197+ /// Arguments:
1198+ /// `group_by`: the expression to evaluate
1199+ /// `batch`: the `RecordBatch` to evaluate against
1200+ ///
1201+ /// Returns: A Vec of Vecs of Array of results
1202+ /// The outer Vec appears to be for grouping sets
1203+ /// The inner Vec contains the results per expression
1204+ /// The inner-inner Array contains the results per row
1205+ pub fn evaluate_group_by (
11751206 group_by : & PhysicalGroupBy ,
11761207 batch : & RecordBatch ,
11771208) -> Result < Vec < Vec < ArrayRef > > > {
@@ -1840,10 +1871,10 @@ mod tests {
18401871 assert ! ( matches!( stream, StreamType :: AggregateStream ( _) ) ) ;
18411872 }
18421873 1 => {
1843- assert ! ( matches!( stream, StreamType :: GroupedHashAggregateStream ( _) ) ) ;
1874+ assert ! ( matches!( stream, StreamType :: GroupedHash ( _) ) ) ;
18441875 }
18451876 2 => {
1846- assert ! ( matches!( stream, StreamType :: GroupedHashAggregateStream ( _) ) ) ;
1877+ assert ! ( matches!( stream, StreamType :: GroupedHash ( _) ) ) ;
18471878 }
18481879 _ => panic ! ( "Unknown version: {version}" ) ,
18491880 }
0 commit comments