@@ -22,8 +22,11 @@ use crate::expressions::format_state_name;
2222use crate :: { AggregateExpr , PhysicalExpr } ;
2323use arrow:: array:: ArrayRef ;
2424use arrow:: datatypes:: { DataType , Field } ;
25+ use arrow_array:: Array ;
26+ use datafusion_common:: cast:: as_list_array;
27+ use datafusion_common:: utils:: wrap_into_list_array;
28+ use datafusion_common:: Result ;
2529use datafusion_common:: ScalarValue ;
26- use datafusion_common:: { internal_err, DataFusionError , Result } ;
2730use datafusion_expr:: Accumulator ;
2831use std:: any:: Any ;
2932use std:: sync:: Arc ;
@@ -102,7 +105,7 @@ impl PartialEq<dyn Any> for ArrayAgg {
102105
103106#[ derive( Debug ) ]
104107pub ( crate ) struct ArrayAggAccumulator {
105- values : Vec < ScalarValue > ,
108+ values : Vec < ArrayRef > ,
106109 datatype : DataType ,
107110}
108111
@@ -117,50 +120,60 @@ impl ArrayAggAccumulator {
117120}
118121
119122impl Accumulator for ArrayAggAccumulator {
123+ // Append value like Int64Array(1,2,3)
120124 fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
121125 if values. is_empty ( ) {
122126 return Ok ( ( ) ) ;
123127 }
124128 assert ! ( values. len( ) == 1 , "array_agg can only take 1 param!" ) ;
125- let arr = & values[ 0 ] ;
126- ( 0 ..arr. len ( ) ) . try_for_each ( |index| {
127- let scalar = ScalarValue :: try_from_array ( arr, index) ?;
128- self . values . push ( scalar) ;
129- Ok ( ( ) )
130- } )
129+ let val = values[ 0 ] . clone ( ) ;
130+ self . values . push ( val) ;
131+ Ok ( ( ) )
131132 }
132133
134+ // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6))
133135 fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
134136 if states. is_empty ( ) {
135137 return Ok ( ( ) ) ;
136138 }
137139 assert ! ( states. len( ) == 1 , "array_agg states must be singleton!" ) ;
138- let arr = & states[ 0 ] ;
139- ( 0 ..arr. len ( ) ) . try_for_each ( |index| {
140- let scalar = ScalarValue :: try_from_array ( arr, index) ?;
141- if let ScalarValue :: List ( Some ( values) , _) = scalar {
142- self . values . extend ( values) ;
143- Ok ( ( ) )
144- } else {
145- internal_err ! ( "array_agg state must be list!" )
146- }
147- } )
140+
141+ let list_arr = as_list_array ( & states[ 0 ] ) ?;
142+ for arr in list_arr. iter ( ) . flatten ( ) {
143+ self . values . push ( arr) ;
144+ }
145+ Ok ( ( ) )
148146 }
149147
150148 fn state ( & self ) -> Result < Vec < ScalarValue > > {
151149 Ok ( vec ! [ self . evaluate( ) ?] )
152150 }
153151
154152 fn evaluate ( & self ) -> Result < ScalarValue > {
155- Ok ( ScalarValue :: new_list (
156- Some ( self . values . clone ( ) ) ,
157- self . datatype . clone ( ) ,
158- ) )
153+ // Transform Vec<ListArr> to ListArr
154+
155+ let element_arrays: Vec < & dyn Array > =
156+ self . values . iter ( ) . map ( |a| a. as_ref ( ) ) . collect ( ) ;
157+
158+ if element_arrays. is_empty ( ) {
159+ let arr = ScalarValue :: new_list ( & [ ] , & self . datatype ) ;
160+ return Ok ( ScalarValue :: List ( arr) ) ;
161+ }
162+
163+ let concated_array = arrow:: compute:: concat ( & element_arrays) ?;
164+ let list_array = wrap_into_list_array ( concated_array) ;
165+
166+ Ok ( ScalarValue :: List ( Arc :: new ( list_array) ) )
159167 }
160168
161169 fn size ( & self ) -> usize {
162- std:: mem:: size_of_val ( self ) + ScalarValue :: size_of_vec ( & self . values )
163- - std:: mem:: size_of_val ( & self . values )
170+ std:: mem:: size_of_val ( self )
171+ + ( std:: mem:: size_of :: < ArrayRef > ( ) * self . values . capacity ( ) )
172+ + self
173+ . values
174+ . iter ( )
175+ . map ( |arr| arr. get_array_memory_size ( ) )
176+ . sum :: < usize > ( )
164177 + self . datatype . size ( )
165178 - std:: mem:: size_of_val ( & self . datatype )
166179 }
@@ -176,72 +189,78 @@ mod tests {
176189 use arrow:: array:: Int32Array ;
177190 use arrow:: datatypes:: * ;
178191 use arrow:: record_batch:: RecordBatch ;
192+ use arrow_array:: Array ;
193+ use arrow_array:: ListArray ;
194+ use arrow_buffer:: OffsetBuffer ;
195+ use datafusion_common:: DataFusionError ;
179196 use datafusion_common:: Result ;
180197
181198 #[ test]
182199 fn array_agg_i32 ( ) -> Result < ( ) > {
183200 let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
184201
185- let list = ScalarValue :: new_list (
186- Some ( vec ! [
187- ScalarValue :: Int32 ( Some ( 1 ) ) ,
188- ScalarValue :: Int32 ( Some ( 2 ) ) ,
189- ScalarValue :: Int32 ( Some ( 3 ) ) ,
190- ScalarValue :: Int32 ( Some ( 4 ) ) ,
191- ScalarValue :: Int32 ( Some ( 5 ) ) ,
192- ] ) ,
193- DataType :: Int32 ,
194- ) ;
202+ let list = ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![
203+ Some ( 1 ) ,
204+ Some ( 2 ) ,
205+ Some ( 3 ) ,
206+ Some ( 4 ) ,
207+ Some ( 5 ) ,
208+ ] ) ] ) ;
209+ let list = ScalarValue :: List ( Arc :: new ( list) ) ;
195210
196211 generic_test_op ! ( a, DataType :: Int32 , ArrayAgg , list, DataType :: Int32 )
197212 }
198213
199214 #[ test]
200215 fn array_agg_nested ( ) -> Result < ( ) > {
201- let l1 = ScalarValue :: new_list (
202- Some ( vec ! [
203- ScalarValue :: new_list(
204- Some ( vec![
205- ScalarValue :: from( 1i32 ) ,
206- ScalarValue :: from( 2i32 ) ,
207- ScalarValue :: from( 3i32 ) ,
208- ] ) ,
209- DataType :: Int32 ,
210- ) ,
211- ScalarValue :: new_list(
212- Some ( vec![ ScalarValue :: from( 4i32 ) , ScalarValue :: from( 5i32 ) ] ) ,
213- DataType :: Int32 ,
214- ) ,
215- ] ) ,
216- DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
216+ let a1 = ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![
217+ Some ( 1 ) ,
218+ Some ( 2 ) ,
219+ Some ( 3 ) ,
220+ ] ) ] ) ;
221+ let a2 = ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![
222+ Some ( 4 ) ,
223+ Some ( 5 ) ,
224+ ] ) ] ) ;
225+ let l1 = ListArray :: new (
226+ Arc :: new ( Field :: new ( "item" , a1. data_type ( ) . to_owned ( ) , true ) ) ,
227+ OffsetBuffer :: from_lengths ( [ a1. len ( ) + a2. len ( ) ] ) ,
228+ arrow:: compute:: concat ( & [ & a1, & a2] ) ?,
229+ None ,
217230 ) ;
218231
219- let l2 = ScalarValue :: new_list (
220- Some ( vec ! [
221- ScalarValue :: new_list(
222- Some ( vec![ ScalarValue :: from( 6i32 ) ] ) ,
223- DataType :: Int32 ,
224- ) ,
225- ScalarValue :: new_list(
226- Some ( vec![ ScalarValue :: from( 7i32 ) , ScalarValue :: from( 8i32 ) ] ) ,
227- DataType :: Int32 ,
228- ) ,
229- ] ) ,
230- DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
232+ let a1 =
233+ ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![ Some ( 6 ) ] ) ] ) ;
234+ let a2 = ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![
235+ Some ( 7 ) ,
236+ Some ( 8 ) ,
237+ ] ) ] ) ;
238+ let l2 = ListArray :: new (
239+ Arc :: new ( Field :: new ( "item" , a1. data_type ( ) . to_owned ( ) , true ) ) ,
240+ OffsetBuffer :: from_lengths ( [ a1. len ( ) + a2. len ( ) ] ) ,
241+ arrow:: compute:: concat ( & [ & a1, & a2] ) ?,
242+ None ,
231243 ) ;
232244
233- let l3 = ScalarValue :: new_list (
234- Some ( vec ! [ ScalarValue :: new_list(
235- Some ( vec![ ScalarValue :: from( 9i32 ) ] ) ,
236- DataType :: Int32 ,
237- ) ] ) ,
238- DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
245+ let a1 =
246+ ListArray :: from_iter_primitive :: < Int32Type , _ , _ > ( vec ! [ Some ( vec![ Some ( 9 ) ] ) ] ) ;
247+ let l3 = ListArray :: new (
248+ Arc :: new ( Field :: new ( "item" , a1. data_type ( ) . to_owned ( ) , true ) ) ,
249+ OffsetBuffer :: from_lengths ( [ a1. len ( ) ] ) ,
250+ arrow:: compute:: concat ( & [ & a1] ) ?,
251+ None ,
239252 ) ;
240253
241- let list = ScalarValue :: new_list (
242- Some ( vec ! [ l1. clone( ) , l2. clone( ) , l3. clone( ) ] ) ,
243- DataType :: List ( Arc :: new ( Field :: new ( "item" , DataType :: Int32 , true ) ) ) ,
254+ let list = ListArray :: new (
255+ Arc :: new ( Field :: new ( "item" , l1. data_type ( ) . to_owned ( ) , true ) ) ,
256+ OffsetBuffer :: from_lengths ( [ l1. len ( ) + l2. len ( ) + l3. len ( ) ] ) ,
257+ arrow:: compute:: concat ( & [ & l1, & l2, & l3] ) ?,
258+ None ,
244259 ) ;
260+ let list = ScalarValue :: List ( Arc :: new ( list) ) ;
261+ let l1 = ScalarValue :: List ( Arc :: new ( l1) ) ;
262+ let l2 = ScalarValue :: List ( Arc :: new ( l2) ) ;
263+ let l3 = ScalarValue :: List ( Arc :: new ( l3) ) ;
245264
246265 let array = ScalarValue :: iter_to_array ( vec ! [ l1, l2, l3] ) . unwrap ( ) ;
247266
0 commit comments