1717
1818//! Aggregate without grouping columns
1919
20+ use super :: AggregateExec ;
2021use crate :: aggregates:: {
2122 aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem ,
2223 AggregateMode ,
2324} ;
25+ use crate :: filter:: batch_filter;
2426use crate :: metrics:: { BaselineMetrics , RecordOutput } ;
27+ use crate :: poll_budget:: PollBudget ;
2528use crate :: { RecordBatchStream , SendableRecordBatchStream } ;
2629use arrow:: datatypes:: SchemaRef ;
2730use arrow:: record_batch:: RecordBatch ;
2831use datafusion_common:: Result ;
32+ use datafusion_execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
2933use datafusion_execution:: TaskContext ;
3034use datafusion_physical_expr:: PhysicalExpr ;
31- use futures:: stream:: BoxStream ;
35+ use futures:: stream:: { Stream , StreamExt } ;
36+ use futures:: FutureExt ;
3237use std:: borrow:: Cow ;
3338use std:: sync:: Arc ;
34- use std:: task:: { Context , Poll } ;
35-
36- use super :: AggregateExec ;
37- use crate :: filter:: batch_filter;
38- use crate :: poll_budget:: PollBudget ;
39- use datafusion_execution:: memory_pool:: { MemoryConsumer , MemoryReservation } ;
40- use futures:: stream:: { Stream , StreamExt } ;
39+ use std:: task:: { ready, Context , Poll } ;
4140
4241/// stream struct for aggregation without grouping columns
4342pub ( crate ) struct AggregateStream {
44- stream : BoxStream < ' static , Result < RecordBatch > > ,
45- schema : SchemaRef ,
46- }
47-
48- /// Actual implementation of [`AggregateStream`].
49- ///
50- /// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
51- /// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
52- /// [`futures::stream::unfold`].
53- ///
54- /// The latter requires a state object, which is [`AggregateStreamInner`].
55- struct AggregateStreamInner {
5643 schema : SchemaRef ,
5744 mode : AggregateMode ,
5845 input : SendableRecordBatchStream ,
@@ -62,6 +49,7 @@ struct AggregateStreamInner {
6249 accumulators : Vec < AccumulatorItem > ,
6350 reservation : MemoryReservation ,
6451 finished : bool ,
52+ poll_budget : PollBudget ,
6553}
6654
6755impl AggregateStream {
@@ -71,7 +59,6 @@ impl AggregateStream {
7159 context : Arc < TaskContext > ,
7260 partition : usize ,
7361 ) -> Result < Self > {
74- let agg_schema = Arc :: clone ( & agg. schema ) ;
7562 let agg_filter_expr = agg. filter_expr . clone ( ) ;
7663
7764 let baseline_metrics = BaselineMetrics :: new ( & agg. metrics , partition) ;
@@ -91,81 +78,17 @@ impl AggregateStream {
9178 let reservation = MemoryConsumer :: new ( format ! ( "AggregateStream[{partition}]" ) )
9279 . register ( context. memory_pool ( ) ) ;
9380
94- let inner = AggregateStreamInner {
81+ Ok ( AggregateStream {
9582 schema : Arc :: clone ( & agg. schema ) ,
9683 mode : agg. mode ,
97- input : PollBudget :: from ( context . as_ref ( ) ) . wrap_stream ( input ) ,
84+ input,
9885 baseline_metrics,
9986 aggregate_expressions,
10087 filter_expressions,
10188 accumulators,
10289 reservation,
10390 finished : false ,
104- } ;
105- let stream = futures:: stream:: unfold ( inner, |mut this| async move {
106- if this. finished {
107- return None ;
108- }
109-
110- let elapsed_compute = this. baseline_metrics . elapsed_compute ( ) ;
111-
112- loop {
113- let result = match this. input . next ( ) . await {
114- Some ( Ok ( batch) ) => {
115- let timer = elapsed_compute. timer ( ) ;
116- let result = aggregate_batch (
117- & this. mode ,
118- batch,
119- & mut this. accumulators ,
120- & this. aggregate_expressions ,
121- & this. filter_expressions ,
122- ) ;
123-
124- timer. done ( ) ;
125-
126- // allocate memory
127- // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
128- // overshooting a bit. Also this means we either store the whole record batch or not.
129- match result
130- . and_then ( |allocated| this. reservation . try_grow ( allocated) )
131- {
132- Ok ( _) => continue ,
133- Err ( e) => Err ( e) ,
134- }
135- }
136- Some ( Err ( e) ) => Err ( e) ,
137- None => {
138- this. finished = true ;
139- let timer = this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
140- let result =
141- finalize_aggregation ( & mut this. accumulators , & this. mode )
142- . and_then ( |columns| {
143- RecordBatch :: try_new (
144- Arc :: clone ( & this. schema ) ,
145- columns,
146- )
147- . map_err ( Into :: into)
148- } )
149- . record_output ( & this. baseline_metrics ) ;
150-
151- timer. done ( ) ;
152-
153- result
154- }
155- } ;
156-
157- this. finished = true ;
158- return Some ( ( result, this) ) ;
159- }
160- } ) ;
161-
162- // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
163- let stream = stream. fuse ( ) ;
164- let stream = Box :: pin ( stream) ;
165-
166- Ok ( Self {
167- schema : agg_schema,
168- stream,
91+ poll_budget : PollBudget :: from ( context. as_ref ( ) ) ,
16992 } )
17093 }
17194}
@@ -178,7 +101,61 @@ impl Stream for AggregateStream {
178101 cx : & mut Context < ' _ > ,
179102 ) -> Poll < Option < Self :: Item > > {
180103 let this = & mut * self ;
181- this. stream . poll_next_unpin ( cx)
104+
105+ if this. finished {
106+ return Poll :: Ready ( None ) ;
107+ }
108+
109+ let elapsed_compute = this. baseline_metrics . elapsed_compute ( ) ;
110+
111+ let mut consume_budget = this. poll_budget . consume_budget ( ) ;
112+
113+ loop {
114+ let result = match ready ! ( this. input. poll_next_unpin( cx) ) {
115+ Some ( Ok ( batch) ) => {
116+ let timer = elapsed_compute. timer ( ) ;
117+ let result = aggregate_batch (
118+ & this. mode ,
119+ batch,
120+ & mut this. accumulators ,
121+ & this. aggregate_expressions ,
122+ & this. filter_expressions ,
123+ ) ;
124+
125+ timer. done ( ) ;
126+
127+ // allocate memory
128+ // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
129+ // overshooting a bit. Also this means we either store the whole record batch or not.
130+ match result
131+ . and_then ( |allocated| this. reservation . try_grow ( allocated) )
132+ {
133+ Ok ( _) => {
134+ ready ! ( consume_budget. poll_unpin( cx) ) ;
135+ continue ;
136+ }
137+ Err ( e) => Err ( e) ,
138+ }
139+ }
140+ Some ( Err ( e) ) => Err ( e) ,
141+ None => {
142+ let timer = this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
143+ let result = finalize_aggregation ( & mut this. accumulators , & this. mode )
144+ . and_then ( |columns| {
145+ RecordBatch :: try_new ( Arc :: clone ( & this. schema ) , columns)
146+ . map_err ( Into :: into)
147+ } )
148+ . record_output ( & this. baseline_metrics ) ;
149+
150+ timer. done ( ) ;
151+
152+ result
153+ }
154+ } ;
155+
156+ this. finished = true ;
157+ return Poll :: Ready ( Some ( result) ) ;
158+ }
182159 }
183160}
184161
0 commit comments