@@ -31,6 +31,7 @@ use arrow::datatypes::{DataType, SchemaRef};
3131use arrow:: record_batch:: RecordBatch ;
3232use arrow_array:: * ;
3333use datafusion_common:: Result ;
34+ use datafusion_execution:: memory_pool:: MemoryReservation ;
3435use futures:: Stream ;
3536use std:: pin:: Pin ;
3637use std:: task:: { ready, Context , Poll } ;
@@ -42,14 +43,15 @@ macro_rules! primitive_merge_helper {
4243}
4344
// Helper macro for a merge specialised on a single array type `$t`.
//
// The expansion wraps `$streams` in a `FieldCursorStream::<$t>` built from
// `$sort`, boxes it, and hands it to `SortPreservingMergeStream::new`
// together with `$schema`, `$tracking_metrics`, `$batch_size` and `$fetch`.
// The expansion is a `return Ok(Box::pin(..))`, so invoking this macro exits
// the enclosing function with the pinned stream.
//
// This hunk changes the matcher: the removed arm takes seven arguments, the
// added arm takes an extra trailing `$reservation` identifier, which is
// forwarded as the last argument of `SortPreservingMergeStream::new`.
4445macro_rules! merge_helper {
45- ( $t: ty, $sort: ident, $streams: ident, $schema: ident, $tracking_metrics: ident, $batch_size: ident, $fetch: ident) => { {
46+ ( $t: ty, $sort: ident, $streams: ident, $schema: ident, $tracking_metrics: ident, $batch_size: ident, $fetch: ident, $reservation : ident ) => { {
4647 let streams = FieldCursorStream :: <$t>:: new( $sort, $streams) ;
4748 return Ok ( Box :: pin( SortPreservingMergeStream :: new(
4849 Box :: new( streams) ,
4950 $schema,
5051 $tracking_metrics,
5152 $batch_size,
5253 $fetch,
54+ $reservation,
5355 ) ) ) ;
5456 } } ;
5557}
@@ -63,28 +65,36 @@ pub fn streaming_merge(
6365 metrics : BaselineMetrics ,
6466 batch_size : usize ,
6567 fetch : Option < usize > ,
68+ reservation : MemoryReservation ,
6669) -> Result < SendableRecordBatchStream > {
6770 // Special case single column comparisons with optimized cursor implementations
6871 if expressions. len ( ) == 1 {
6972 let sort = expressions[ 0 ] . clone ( ) ;
7073 let data_type = sort. expr . data_type ( schema. as_ref ( ) ) ?;
7174 downcast_primitive ! {
72- data_type => ( primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch) ,
73- DataType :: Utf8 => merge_helper!( StringArray , sort, streams, schema, metrics, batch_size, fetch)
74- DataType :: LargeUtf8 => merge_helper!( LargeStringArray , sort, streams, schema, metrics, batch_size, fetch)
75- DataType :: Binary => merge_helper!( BinaryArray , sort, streams, schema, metrics, batch_size, fetch)
76- DataType :: LargeBinary => merge_helper!( LargeBinaryArray , sort, streams, schema, metrics, batch_size, fetch)
75+ data_type => ( primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation ) ,
76+ DataType :: Utf8 => merge_helper!( StringArray , sort, streams, schema, metrics, batch_size, fetch, reservation )
77+ DataType :: LargeUtf8 => merge_helper!( LargeStringArray , sort, streams, schema, metrics, batch_size, fetch, reservation )
78+ DataType :: Binary => merge_helper!( BinaryArray , sort, streams, schema, metrics, batch_size, fetch, reservation )
79+ DataType :: LargeBinary => merge_helper!( LargeBinaryArray , sort, streams, schema, metrics, batch_size, fetch, reservation )
7780 _ => { }
7881 }
7982 }
8083
81- let streams = RowCursorStream :: try_new ( schema. as_ref ( ) , expressions, streams) ?;
84+ let streams = RowCursorStream :: try_new (
85+ schema. as_ref ( ) ,
86+ expressions,
87+ streams,
88+ reservation. new_empty ( ) ,
89+ ) ?;
90+
8291 Ok ( Box :: pin ( SortPreservingMergeStream :: new (
8392 Box :: new ( streams) ,
8493 schema,
8594 metrics,
8695 batch_size,
8796 fetch,
97+ reservation,
8898 ) ) )
8999}
90100
@@ -162,11 +172,12 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
162172 metrics : BaselineMetrics ,
163173 batch_size : usize ,
164174 fetch : Option < usize > ,
175+ reservation : MemoryReservation ,
165176 ) -> Self {
166177 let stream_count = streams. partitions ( ) ;
167178
168179 Self {
169- in_progress : BatchBuilder :: new ( schema, stream_count, batch_size) ,
180+ in_progress : BatchBuilder :: new ( schema, stream_count, batch_size, reservation ) ,
170181 streams,
171182 metrics,
172183 aborted : false ,
@@ -197,8 +208,7 @@ impl<C: Cursor> SortPreservingMergeStream<C> {
197208 Some ( Err ( e) ) => Poll :: Ready ( Err ( e) ) ,
198209 Some ( Ok ( ( cursor, batch) ) ) => {
199210 self . cursors [ idx] = Some ( cursor) ;
200- self . in_progress . push_batch ( idx, batch) ;
201- Poll :: Ready ( Ok ( ( ) ) )
211+ Poll :: Ready ( self . in_progress . push_batch ( idx, batch) )
202212 }
203213 }
204214 }
0 commit comments