|  | 
| 18 | 18 | //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] | 
| 19 | 19 | 
 | 
| 20 | 20 | use arrow::array::{ | 
| 21 |  | -    new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray, | 
|  | 21 | +    make_array, new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, | 
|  | 22 | +    StructArray, | 
| 22 | 23 | }; | 
| 23 | 24 | use arrow::compute::{filter, SortOptions}; | 
| 24 | 25 | use arrow::datatypes::{DataType, Field, FieldRef, Fields}; | 
| 25 | 26 | 
 | 
| 26 | 27 | use datafusion_common::cast::as_list_array; | 
|  | 28 | +use datafusion_common::scalar::copy_array_data; | 
| 27 | 29 | use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder}; | 
| 28 | 30 | use datafusion_common::{exec_err, ScalarValue}; | 
| 29 | 31 | use datafusion_common::{internal_err, Result}; | 
| @@ -319,7 +321,11 @@ impl Accumulator for ArrayAggAccumulator { | 
| 319 | 321 |         }; | 
| 320 | 322 | 
 | 
| 321 | 323 |         if !val.is_empty() { | 
| 322 |  | -            self.values.push(val); | 
|  | 324 | +            // The ArrayRef might be holding a reference to its original input buffer, so | 
|  | 325 | +            // storing it here directly copied/compacted avoids over accounting memory | 
|  | 326 | +            // not used here. | 
|  | 327 | +            self.values | 
|  | 328 | +                .push(make_array(copy_array_data(&val.to_data()))); | 
| 323 | 329 |         } | 
| 324 | 330 | 
 | 
| 325 | 331 |         Ok(()) | 
| @@ -429,7 +435,8 @@ impl Accumulator for DistinctArrayAggAccumulator { | 
| 429 | 435 |         if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { | 
| 430 | 436 |             for i in 0..val.len() { | 
| 431 | 437 |                 if nulls.is_none_or(|nulls| nulls.is_valid(i)) { | 
| 432 |  | -                    self.values.insert(ScalarValue::try_from_array(val, i)?); | 
|  | 438 | +                    self.values | 
|  | 439 | +                        .insert(ScalarValue::try_from_array(val, i)?.compacted()); | 
| 433 | 440 |                 } | 
| 434 | 441 |             } | 
| 435 | 442 |         } | 
| @@ -558,8 +565,14 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { | 
| 558 | 565 |         if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) { | 
| 559 | 566 |             for i in 0..val.len() { | 
| 560 | 567 |                 if nulls.is_none_or(|nulls| nulls.is_valid(i)) { | 
| 561 |  | -                    self.values.push(ScalarValue::try_from_array(val, i)?); | 
| 562 |  | -                    self.ordering_values.push(get_row_at_idx(ord, i)?) | 
|  | 568 | +                    self.values | 
|  | 569 | +                        .push(ScalarValue::try_from_array(val, i)?.compacted()); | 
|  | 570 | +                    self.ordering_values.push( | 
|  | 571 | +                        get_row_at_idx(ord, i)? | 
|  | 572 | +                            .into_iter() | 
|  | 573 | +                            .map(|v| v.compacted()) | 
|  | 574 | +                            .collect(), | 
|  | 575 | +                    ) | 
| 563 | 576 |                 } | 
| 564 | 577 |             } | 
| 565 | 578 |         } | 
| @@ -722,6 +735,7 @@ impl OrderSensitiveArrayAggAccumulator { | 
| 722 | 735 | #[cfg(test)] | 
| 723 | 736 | mod tests { | 
| 724 | 737 |     use super::*; | 
|  | 738 | +    use arrow::array::{ListBuilder, StringBuilder}; | 
| 725 | 739 |     use arrow::datatypes::{FieldRef, Schema}; | 
| 726 | 740 |     use datafusion_common::cast::as_generic_string_array; | 
| 727 | 741 |     use datafusion_common::internal_err; | 
| @@ -988,6 +1002,56 @@ mod tests { | 
| 988 | 1002 |         Ok(()) | 
| 989 | 1003 |     } | 
| 990 | 1004 | 
 | 
|  | 1005 | +    #[test] | 
|  | 1006 | +    fn does_not_over_account_memory() -> Result<()> { | 
|  | 1007 | +        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string().build_two()?; | 
|  | 1008 | + | 
|  | 1009 | +        acc1.update_batch(&[data(["a", "c", "b"])])?; | 
|  | 1010 | +        acc2.update_batch(&[data(["b", "c", "a"])])?; | 
|  | 1011 | +        acc1 = merge(acc1, acc2)?; | 
|  | 1012 | + | 
|  | 1013 | +        // without compaction, the size is 2652. | 
|  | 1014 | +        assert_eq!(acc1.size(), 732); | 
|  | 1015 | + | 
|  | 1016 | +        Ok(()) | 
|  | 1017 | +    } | 
|  | 1018 | +    #[test] | 
|  | 1019 | +    fn does_not_over_account_memory_distinct() -> Result<()> { | 
|  | 1020 | +        let (mut acc1, mut acc2) = ArrayAggAccumulatorBuilder::string() | 
|  | 1021 | +            .distinct() | 
|  | 1022 | +            .build_two()?; | 
|  | 1023 | + | 
|  | 1024 | +        acc1.update_batch(&[string_list_data([ | 
|  | 1025 | +            vec!["a", "b", "c"], | 
|  | 1026 | +            vec!["d", "e", "f"], | 
|  | 1027 | +        ])])?; | 
|  | 1028 | +        acc2.update_batch(&[string_list_data([vec!["e", "f", "g"]])])?; | 
|  | 1029 | +        acc1 = merge(acc1, acc2)?; | 
|  | 1030 | + | 
|  | 1031 | +        // without compaction, the size is 16660 | 
|  | 1032 | +        assert_eq!(acc1.size(), 1660); | 
|  | 1033 | + | 
|  | 1034 | +        Ok(()) | 
|  | 1035 | +    } | 
|  | 1036 | + | 
|  | 1037 | +    #[test] | 
|  | 1038 | +    fn does_not_over_account_memory_ordered() -> Result<()> { | 
|  | 1039 | +        let mut acc = ArrayAggAccumulatorBuilder::string() | 
|  | 1040 | +            .order_by_col("col", SortOptions::new(false, false)) | 
|  | 1041 | +            .build()?; | 
|  | 1042 | + | 
|  | 1043 | +        acc.update_batch(&[string_list_data([ | 
|  | 1044 | +            vec!["a", "b", "c"], | 
|  | 1045 | +            vec!["c", "d", "e"], | 
|  | 1046 | +            vec!["b", "c", "d"], | 
|  | 1047 | +        ])])?; | 
|  | 1048 | + | 
|  | 1049 | +        // without compaction, the size is 17112 | 
|  | 1050 | +        assert_eq!(acc.size(), 2080); | 
|  | 1051 | + | 
|  | 1052 | +        Ok(()) | 
|  | 1053 | +    } | 
|  | 1054 | + | 
| 991 | 1055 |     struct ArrayAggAccumulatorBuilder { | 
| 992 | 1056 |         return_field: FieldRef, | 
| 993 | 1057 |         distinct: bool, | 
| @@ -1066,6 +1130,15 @@ mod tests { | 
| 1066 | 1130 |             .collect() | 
| 1067 | 1131 |     } | 
| 1068 | 1132 | 
 | 
|  | 1133 | +    fn string_list_data<'a>(data: impl IntoIterator<Item = Vec<&'a str>>) -> ArrayRef { | 
|  | 1134 | +        let mut builder = ListBuilder::new(StringBuilder::new()); | 
|  | 1135 | +        for string_list in data.into_iter() { | 
|  | 1136 | +            builder.append_value(string_list.iter().map(Some).collect::<Vec<_>>()); | 
|  | 1137 | +        } | 
|  | 1138 | + | 
|  | 1139 | +        Arc::new(builder.finish()) | 
|  | 1140 | +    } | 
|  | 1141 | + | 
| 1069 | 1142 |     fn data<T, const N: usize>(list: [T; N]) -> ArrayRef | 
| 1070 | 1143 |     where | 
| 1071 | 1144 |         ScalarValue: From<T>, | 
|  | 
0 commit comments