From 01fa4859e476fc7c190537096924533529910066 Mon Sep 17 00:00:00 2001 From: YuNing Chen Date: Tue, 25 Mar 2025 00:08:51 +0800 Subject: [PATCH 01/14] chore: mv `DistinctSumAccumulator` to common --- .../src/aggregate.rs | 1 + .../src/aggregate/sum_distinct/mod.rs | 22 ++++ .../src/aggregate/sum_distinct/numeric.rs | 123 ++++++++++++++++++ datafusion/functions-aggregate/src/sum.rs | 88 +------------ 4 files changed, 148 insertions(+), 86 deletions(-) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs create mode 100644 datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index c9cbaa8396fc..56dc58570ac7 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -17,3 +17,4 @@ pub mod count_distinct; pub mod groups_accumulator; +pub mod sum_distinct; diff --git a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs new file mode 100644 index 000000000000..3a645b3d6ef4 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Sum distinct accumulator implementations + +pub mod numeric; + +pub use numeric::DistinctSumAccumulator; \ No newline at end of file diff --git a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs new file mode 100644 index 000000000000..2c56e681c683 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the accumulator for `SUM DISTINCT` for primitive numeric types + +use std::collections::HashSet; +use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; + +use ahash::RandomState; +use arrow::array::Array; +use arrow::array::ArrowNativeTypeOp; +use arrow::array::ArrowPrimitiveType; +use arrow::array::ArrayRef; +use arrow::array::AsArray; +use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::DataType; + +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr_common::accumulator::Accumulator; + +use crate::utils::Hashable; + +/// Accumulator for computing SUM(DISTINCT expr) +pub struct DistinctSumAccumulator { + values: HashSet, RandomState>, + data_type: DataType, +} + +impl Debug for DistinctSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctSumAccumulator({})", self.data_type) + } +} + +impl DistinctSumAccumulator { + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + values: HashSet::default(), + data_type: data_type.clone(), + }) + } + + pub fn distinct_count(&self) -> usize { + self.values.len() + } +} + +impl Accumulator for DistinctSumAccumulator { + fn state(&mut self) -> Result> { + // 1. Stores aggregate state in `ScalarValue::List` + // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set + let state_out = { + let distinct_values = self + .values + .iter() + .map(|value| { + ScalarValue::new_primitive::(Some(value.0), &self.data_type) + }) + .collect::>>()?; + + vec![ScalarValue::List(ScalarValue::new_list_nullable( + &distinct_values, + &self.data_type, + ))] + }; + Ok(state_out) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(Hashable(array.value(idx))); + } + } + None => array.values().iter().for_each(|x| { + self.values.insert(Hashable(*x)); + }), + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + for x in states[0].as_list::().iter().flatten() { + self.update_batch(&[x])? + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc.add_wrapping(distinct_value.0) + } + let v = (!self.values.is_empty()).then_some(acc); + ScalarValue::new_primitive::(v, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + self.values.capacity() * size_of::() + } +} \ No newline at end of file diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 76a1315c2d88..6539ca920ebc 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -17,17 +17,14 @@ //! Defines `SUM` and `SUM DISTINCT` aggregate accumulators -use ahash::RandomState; use datafusion_expr::utils::AggregateOrderSensitivity; use std::any::Any; -use std::collections::HashSet; -use std::mem::{size_of, size_of_val}; +use std::mem::size_of_val; use arrow::array::Array; use arrow::array::ArrowNativeTypeOp; use arrow::array::{ArrowNumericType, AsArray}; use arrow::datatypes::ArrowNativeType; -use arrow::datatypes::ArrowPrimitiveType; use arrow::datatypes::{ DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, @@ -44,7 +41,7 @@ use datafusion_expr::{ SetMonotonicity, Signature, Volatility, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use datafusion_functions_aggregate_common::utils::Hashable; +use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator; use datafusion_macros::user_doc; make_udaf_expr_and_func!( @@ -388,84 +385,3 @@ impl Accumulator for SlidingSumAccumulator { true } } - -struct DistinctSumAccumulator { - values: HashSet, RandomState>, - data_type: DataType, -} - -impl std::fmt::Debug for DistinctSumAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DistinctSumAccumulator({})", self.data_type) - } -} - -impl DistinctSumAccumulator { - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - values: HashSet::default(), - data_type: data_type.clone(), - }) - } -} - -impl Accumulator for DistinctSumAccumulator { - fn state(&mut self) -> Result> { - // 1. Stores aggregate state in `ScalarValue::List` - // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set - let state_out = { - let distinct_values = self - .values - .iter() - .map(|value| { - ScalarValue::new_primitive::(Some(value.0), &self.data_type) - }) - .collect::>>()?; - - vec![ScalarValue::List(ScalarValue::new_list_nullable( - &distinct_values, - &self.data_type, - ))] - }; - Ok(state_out) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let array = values[0].as_primitive::(); - match array.nulls().filter(|x| x.null_count() > 0) { - Some(n) => { - for idx in n.valid_indices() { - self.values.insert(Hashable(array.value(idx))); - } - } - None => array.values().iter().for_each(|x| { - self.values.insert(Hashable(*x)); - }), - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - for x in states[0].as_list::().iter().flatten() { - self.update_batch(&[x])? - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let mut acc = T::Native::usize_as(0); - for distinct_value in self.values.iter() { - acc = acc.add_wrapping(distinct_value.0) - } - let v = (!self.values.is_empty()).then_some(acc); - ScalarValue::new_primitive::(v, &self.data_type) - } - - fn size(&self) -> usize { - size_of_val(self) + self.values.capacity() * size_of::() - } -} From 3eb166b198b7c4c8fa9ff964fc839a2fcfaa6f76 Mon Sep 17 00:00:00 2001 From: YuNing Chen Date: Tue, 25 Mar 2025 17:04:05 +0800 Subject: [PATCH 02/14] feat: add avg distinct support for float64 type --- .../src/aggregate.rs | 1 + .../src/aggregate/avg_distinct.rs | 20 +++++ .../src/aggregate/avg_distinct/numeric.rs | 78 +++++++++++++++++++ datafusion/functions-aggregate/src/average.rs | 77 +++++++++--------- .../sqllogictest/test_files/aggregate.slt | 36 ++++++++- 5 files changed, 176 insertions(+), 36 deletions(-) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs create mode 100644 datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index 56dc58570ac7..8072900a12bc 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -18,3 +18,4 @@ pub mod count_distinct; pub mod groups_accumulator; pub mod sum_distinct; +pub mod avg_distinct; diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs new file mode 100644 index 000000000000..3d6889431d61 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod numeric; + +pub use numeric::Float64DistinctAvgAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs new file mode 100644 index 000000000000..e4ff58a07284 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Debug; + +use arrow::array::ArrayRef; +use arrow::datatypes::Float64Type; +use datafusion_common::ScalarValue; +use datafusion_expr_common::accumulator::Accumulator; + +use crate::aggregate::sum_distinct::DistinctSumAccumulator; + +/// Specialized implementation of `AVG DISTINCT` for Float64 values, leveraging +/// the existing DistinctSumAccumulator implementation. +#[derive(Debug)] +pub struct Float64DistinctAvgAccumulator { + // We use the DistinctSumAccumulator to handle the set of distinct values + sum_accumulator: DistinctSumAccumulator, +} + +impl Float64DistinctAvgAccumulator { + pub fn new() -> datafusion_common::Result { + Ok(Self { + sum_accumulator: DistinctSumAccumulator::::try_new( + &arrow::datatypes::DataType::Float64, + )?, + }) + } +} + +impl Accumulator for Float64DistinctAvgAccumulator { + fn state(&mut self) -> datafusion_common::Result> { + self.sum_accumulator.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + self.sum_accumulator.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + self.sum_accumulator.merge_batch(states) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + // Get the sum from the DistinctSumAccumulator + let sum_result = self.sum_accumulator.evaluate()?; + + // Extract the sum value + if let ScalarValue::Float64(Some(sum)) = sum_result { + // Get the count of distinct values + let count = self.sum_accumulator.distinct_count() as f64; + // Calculate average + let avg = sum / count; + Ok(ScalarValue::Float64(Some(avg))) + } else { + // If sum is None, return None (null) + Ok(ScalarValue::Float64(None)) + } + } + + fn size(&self) -> usize { + self.sum_accumulator.size() + } +} diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 141771b0412f..758a775e06d7 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -39,6 +39,7 @@ use datafusion_expr::{ ReversedUDAF, Signature, }; +use datafusion_functions_aggregate_common::aggregate::avg_distinct::Float64DistinctAvgAccumulator; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ filtered_null_mask, set_nulls, @@ -113,43 +114,49 @@ impl AggregateUDFImpl for Avg { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - if acc_args.is_distinct { - return exec_err!("avg(DISTINCT) aggregations are not available"); - } - use DataType::*; - let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; - // instantiate specialized accumulator based for the type - match (&data_type, acc_args.return_type) { - (Float64, Float64) => Ok(Box::::default()), - ( - Decimal128(sum_precision, sum_scale), - Decimal128(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator:: { - sum: None, - count: 0, - sum_scale: *sum_scale, - sum_precision: *sum_precision, - target_precision: *target_precision, - target_scale: *target_scale, - })), + use DataType::*; - ( - Decimal256(sum_precision, sum_scale), - Decimal256(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator:: { - sum: None, - count: 0, - sum_scale: *sum_scale, - sum_precision: *sum_precision, - target_precision: *target_precision, - target_scale: *target_scale, - })), - _ => exec_err!( - "AvgAccumulator for ({} --> {})", - &data_type, - acc_args.return_type - ), + if acc_args.is_distinct { + // instantiate specialized accumulator based for the type + match &data_type { + // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation + Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::new()?)), + _ => exec_err!("AVG(DISTINCT) for {} not supported", data_type), + } + } else { + // instantiate specialized accumulator based for the type + match (&data_type, acc_args.return_type) { + (Float64, Float64) => Ok(Box::::default()), + ( + Decimal128(sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + + ( + Decimal256(sum_precision, sum_scale), + Decimal256(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator:: { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + _ => exec_err!( + "AvgAccumulator for ({} --> {})", + &data_type, + acc_args.return_type + ), + } } } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 9d8620b100f3..93a532548626 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -4962,8 +4962,10 @@ select avg(distinct x_dict) from value_dict; ---- 3 -query error +query RR select avg(x_dict), avg(distinct x_dict) from value_dict; +---- +2.625 3 query I select min(x_dict) from value_dict; @@ -6686,3 +6688,35 @@ SELECT a, median(b), arrow_typeof(median(b)) FROM group_median_all_nulls GROUP B ---- group0 NULL Int32 group1 NULL Int32 + +statement ok +create table t_decimal (c decimal(10, 4)) as values (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null); + +# Test avg_distinct for Decimal128 +query RT +select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal; +---- +180 Decimal128(14, 8) + +statement ok +drop table t_decimal; + +# Test avg_distinct for Decimal256 +statement ok +create table t_decimal256 (c decimal(50, 2)) as values + (100.00), + (125.00), + (175.00), + (200.00), + (200.00), + (300.00), + (null), + (null); + +query RT +select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal256; +---- +180 Decimal256(54, 6) + +statement ok +drop table t_decimal256; From 6ae50ab2e86e0c5b6ce064e6bd39ba9bd9552399 Mon Sep 17 00:00:00 2001 From: YuNing Chen Date: Tue, 25 Mar 2025 17:29:37 +0800 Subject: [PATCH 03/14] chore: fmt --- datafusion/functions-aggregate-common/src/aggregate.rs | 2 +- .../src/aggregate/sum_distinct/mod.rs | 2 +- .../src/aggregate/sum_distinct/numeric.rs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index 8072900a12bc..aadce907e7cc 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod avg_distinct; pub mod count_distinct; pub mod groups_accumulator; pub mod sum_distinct; -pub mod avg_distinct; diff --git a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs index 3a645b3d6ef4..932bfba0bf0d 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs @@ -19,4 +19,4 @@ pub mod numeric; -pub use numeric::DistinctSumAccumulator; \ No newline at end of file +pub use numeric::DistinctSumAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs index 2c56e681c683..859c82d95660 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/numeric.rs @@ -23,9 +23,9 @@ use std::mem::{size_of, size_of_val}; use ahash::RandomState; use arrow::array::Array; +use arrow::array::ArrayRef; use arrow::array::ArrowNativeTypeOp; use arrow::array::ArrowPrimitiveType; -use arrow::array::ArrayRef; use arrow::array::AsArray; use arrow::datatypes::ArrowNativeType; use arrow::datatypes::DataType; @@ -120,4 +120,4 @@ impl Accumulator for DistinctSumAccumulator { fn size(&self) -> usize { size_of_val(self) + self.values.capacity() * size_of::() } -} \ No newline at end of file +} From 4a8868dae394288c589b12e5785ec27b368caaac Mon Sep 17 00:00:00 2001 From: YuNing Chen Date: Tue, 25 Mar 2025 21:52:52 +0800 Subject: [PATCH 04/14] refactor: update import for DataType in Float64DistinctAvgAccumulator and remove unused sum_distinct module --- .../src/aggregate/avg_distinct/numeric.rs | 4 ++-- .../src/aggregate/{sum_distinct/mod.rs => sum_distinct.rs} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename datafusion/functions-aggregate-common/src/aggregate/{sum_distinct/mod.rs => sum_distinct.rs} (100%) diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs index e4ff58a07284..c9fb14fb1069 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/numeric.rs @@ -18,7 +18,7 @@ use std::fmt::Debug; use arrow::array::ArrayRef; -use arrow::datatypes::Float64Type; +use arrow::datatypes::{DataType, Float64Type}; use datafusion_common::ScalarValue; use datafusion_expr_common::accumulator::Accumulator; @@ -36,7 +36,7 @@ impl Float64DistinctAvgAccumulator { pub fn new() -> datafusion_common::Result { Ok(Self { sum_accumulator: DistinctSumAccumulator::::try_new( - &arrow::datatypes::DataType::Float64, + &DataType::Float64, )?, }) } diff --git a/datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs b/datafusion/functions-aggregate-common/src/aggregate/sum_distinct.rs similarity index 100% rename from datafusion/functions-aggregate-common/src/aggregate/sum_distinct/mod.rs rename to datafusion/functions-aggregate-common/src/aggregate/sum_distinct.rs From 3bb414da12e85eb0ea51d0faf05370e900bbc720 Mon Sep 17 00:00:00 2001 From: YuNing Chen Date: Tue, 25 Mar 2025 17:04:05 +0800 Subject: [PATCH 05/14] feat: add avg distinct support for float64 type --- datafusion/functions-aggregate-common/src/aggregate.rs | 1 + datafusion/functions-aggregate/src/average.rs | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs index aadce907e7cc..5c2a41c774b6 100644 --- a/datafusion/functions-aggregate-common/src/aggregate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -19,3 +19,4 @@ pub mod avg_distinct; pub mod count_distinct; pub mod groups_accumulator; pub mod sum_distinct; +pub mod avg_distinct; diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 758a775e06d7..f0c98b7d9c97 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -39,7 +39,9 @@ use datafusion_expr::{ ReversedUDAF, Signature, }; -use datafusion_functions_aggregate_common::aggregate::avg_distinct::Float64DistinctAvgAccumulator; +use datafusion_functions_aggregate_common::aggregate::avg_distinct::{ + DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator, +}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ filtered_null_mask, set_nulls, From 0b9d74998525b12d20b1b8cb47faab7b06afb473 Mon Sep 17 00:00:00 2001 From: YuNing Chen Date: Tue, 25 Mar 2025 17:07:24 +0800 Subject: [PATCH 06/14] feat: add avg distinct support for decimal --- .../src/aggregate/avg_distinct.rs | 2 + .../src/aggregate/avg_distinct/decimal.rs | 230 ++++++++++++++++++ datafusion/functions-aggregate/src/average.rs | 24 ++ 3 files changed, 256 insertions(+) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs index 3d6889431d61..56cdaf6618de 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +mod decimal; mod numeric; +pub use decimal::DecimalDistinctAvgAccumulator; pub use numeric::Float64DistinctAvgAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs new file mode 100644 index 000000000000..cf30db5789ed --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{ArrayRef, ArrowNumericType}, + datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType}, +}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr_common::accumulator::Accumulator; +use std::fmt::Debug; +use std::mem::size_of_val; + +use crate::aggregate::sum_distinct::DistinctSumAccumulator; +use crate::utils::DecimalAverager; + +/// Generic implementation of `AVG DISTINCT` for Decimal types. +/// Handles both Decimal128Type and Decimal256Type. +#[derive(Debug)] +pub struct DecimalDistinctAvgAccumulator { + sum_accumulator: DistinctSumAccumulator, + sum_scale: i8, + target_precision: u8, + target_scale: i8, +} + +impl DecimalDistinctAvgAccumulator { + pub fn with_decimal_params( + sum_scale: i8, + target_precision: u8, + target_scale: i8, + ) -> Self { + let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale); + + Self { + sum_accumulator: DistinctSumAccumulator::try_new(&data_type).unwrap(), + sum_scale, + target_precision, + target_scale, + } + } +} + +impl Accumulator + for DecimalDistinctAvgAccumulator +{ + fn state(&mut self) -> Result> { + self.sum_accumulator.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + if self.sum_accumulator.distinct_count() == 0 { + return ScalarValue::new_primitive::( + None, + &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), + ); + } + + let sum_scalar = self.sum_accumulator.evaluate()?; + + match sum_scalar { + ScalarValue::Decimal128(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + let avg = decimal_averager + .avg(sum, self.sum_accumulator.distinct_count() as i128)?; + Ok(ScalarValue::Decimal128( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + ScalarValue::Decimal256(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + // `distinct_count` returns `u64`, but `avg` expects `i256` + // first convert `u64` to `i128`, then convert `i128` to `i256` to avoid overflow + let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128; + let count: i256 = i256::from_i128(distinct_cnt); + let avg = decimal_averager.avg(sum, count)?; + Ok(ScalarValue::Decimal256( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + + _ => unreachable!("Unsupported decimal type: {:?}", sum_scalar), + } + } + + fn size(&self) -> usize { + let fixed_size = size_of_val(self); + + // Account for the size of the sum_accumulator with its contained values + fixed_size + self.sum_accumulator.size() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Decimal128Array, Decimal256Array}; + use std::sync::Arc; + + #[test] + fn test_decimal128_distinct_avg_accumulator() -> Result<()> { + // (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null) + // with precision 10, scale 4 + // As `single_distinct_to_groupby` will convert the input to a `GroupBy` plan, + // we need to test it with rust api + // See also `aggregate.slt` + let precision = 10_u8; + let scale = 4_i8; + let array = Decimal128Array::from(vec![ + Some(100_0000), // 100.0000 + Some(125_0000), // 125.0000 + Some(175_0000), // 175.0000 + Some(200_0000), // 200.0000 + Some(200_0000), // 200.0000 (duplicate) + Some(300_0000), // 300.0000 + None, // null + None, // null + ]) + .with_precision_and_scale(precision, scale)?; + + // Expected result for avg(distinct) should be 180.0000 with precision 14, scale 8 + let expected_result = ScalarValue::Decimal128( + Some(180_00000000), // 180.00000000 + 14, // target precision + 8, // target scale + ); + + let arrays: Vec = vec![Arc::new(array)]; + + // Create accumulator with appropriate parameters + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, // input scale + 14, // target precision + 8, // target scale + ); + + // Update the accumulator with input values + accumulator.update_batch(&arrays)?; + + // Evaluate the result + let result = accumulator.evaluate()?; + + // Assert that the result matches the expected value + assert_eq!(result, expected_result); + + Ok(()) + } + + #[test] + fn test_decimal256_distinct_avg_accumulator() -> Result<()> { + // (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null) + // with precision 50, scale 2 + let precision = 50_u8; + let scale = 2_i8; + + let array = Decimal256Array::from(vec![ + Some(i256::from_i128(100_00)), // 100.00 + Some(i256::from_i128(125_00)), // 125.00 + Some(i256::from_i128(175_00)), // 175.00 + Some(i256::from_i128(200_00)), // 200.00 + Some(i256::from_i128(200_00)), // 200.00 (duplicate) + Some(i256::from_i128(300_00)), // 300.00 + None, // null + None, // null + ]) + .with_precision_and_scale(precision, scale)?; + + // Expected result for avg(distinct) should be 180.000000 with precision 54, scale 6 + let expected_result = ScalarValue::Decimal256( + Some(i256::from_i128(180_000000)), // 180.000000 + 54, // target precision + 6, // target scale + ); + + let arrays: Vec = vec![Arc::new(array)]; + + // Create accumulator with appropriate parameters + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, // input scale + 54, // target precision + 6, // target scale + ); + + // Update the accumulator with input values + accumulator.update_batch(&arrays)?; + + // Evaluate the result + let result = accumulator.evaluate()?; + + // Assert that the result matches the expected value + assert_eq!(result, expected_result); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index f0c98b7d9c97..ef736ff06e47 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -124,6 +124,30 @@ impl AggregateUDFImpl for Avg { match &data_type { // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::new()?)), + Decimal128(_, scale) => { + let target_type = &acc_args.return_type; + if let Decimal128(target_precision, target_scale) = target_type { + Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))) + } else { + exec_err!("AVG(DISTINCT) for Decimal128 expected Decimal128 return type, got {}", target_type) + } + } + Decimal256(_, scale) => { + let target_type = &acc_args.return_type; + if let Decimal256(target_precision, target_scale) = target_type { + Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))) + } else { + exec_err!("AVG(DISTINCT) for Decimal256 expected Decimal256 return type, got {}", target_type) + } + } _ => exec_err!("AVG(DISTINCT) for {} not supported", data_type), } } else { From 2ee004d1f479dd759674b3e547aceb7a80b470a8 Mon Sep 17 00:00:00 2001 From: YuNing Chen Date: Tue, 25 Mar 2025 18:02:21 +0800 Subject: [PATCH 07/14] feat: more test for avg distinct in rust api --- datafusion/core/tests/dataframe/mod.rs | 110 +++++++++--------- .../tests/cases/roundtrip_logical_plan.rs | 1 + 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index b19c0b978605..b4b503d1deed 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -496,6 +496,7 @@ async fn aggregate() -> Result<()> { // build plan using DataFrame API let df = test_table().await?; let group_expr = vec![col("c1")]; + let avg_distinct = avg(col("c12")).distinct().build().unwrap(); let aggr_expr = vec![ min(col("c12")), max(col("c12")), @@ -503,23 +504,24 @@ async fn aggregate() -> Result<()> { sum(col("c12")), count(col("c12")), count_distinct(col("c12")), + avg_distinct, ]; let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; assert_snapshot!( batches_to_sort_string(&df), - @r###" - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - | c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - | a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 | - | b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 | - | c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 | - | d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 | - | e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ - "### + @r" + +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+ + | c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) | avg(DISTINCT aggregate_test_100.c12) | + +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+ + | a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 | 0.48754517466109415 | + | b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 | 0.41040709263815384 | + | c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 | 0.6600456536439784 | + | d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 | 0.48855379387549824 | + | e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 | 0.48600669271341534 | + +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+ + " ); Ok(()) @@ -530,6 +532,7 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> { // build plan using DataFrame API let df = test_table().await?; let group_expr = vec![col("c1")]; + let avg_distinct = avg(col("c12")).distinct().build().unwrap(); let aggr_expr = vec![ min(col("c12")), max(col("c12")), @@ -538,6 +541,7 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> { count(col("c12")), count_distinct(col("c12")), median(col("c12")), + avg_distinct, ]; let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; @@ -3354,6 +3358,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { vec![col("c1"), col("c2")], ])); + let avg_distinct = avg(col("c3")).distinct().build().unwrap(); let df = aggregates_table(&ctx) .await? .aggregate( @@ -3361,6 +3366,7 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { vec![ sum(col("c3")).alias("sum_c3"), avg(col("c3")).alias("avg_c3"), + avg_distinct.alias("avg_distinct_c3"), ], )? .sort(vec![ @@ -3372,47 +3378,47 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r###" - +----+----+--------+---------------------+ - | c1 | c2 | sum_c3 | avg_c3 | - +----+----+--------+---------------------+ - | | 5 | -194 | -13.857142857142858 | - | | 4 | 29 | 1.2608695652173914 | - | | 3 | 395 | 20.789473684210527 | - | | 2 | 184 | 8.363636363636363 | - | | 1 | 367 | 16.681818181818183 | - | e | | 847 | 40.333333333333336 | - | e | 5 | -22 | -11.0 | - | e | 4 | 261 | 37.285714285714285 | - | e | 3 | 192 | 48.0 | - | e | 2 | 189 | 37.8 | - | e | 1 | 227 | 75.66666666666667 | - | d | | 458 | 25.444444444444443 | - | d | 5 | -99 | -49.5 | - | d | 4 | 162 | 54.0 | - | d | 3 | 124 | 41.333333333333336 | - | d | 2 | 328 | 109.33333333333333 | - | d | 1 | -57 | -8.142857142857142 | - | c | | -28 | -1.3333333333333333 | - | c | 5 | 24 | 12.0 | - | c | 4 | -43 | -10.75 | - | c | 3 | 190 | 47.5 | - | c | 2 | -389 | -55.57142857142857 | - | c | 1 | 190 | 47.5 | - | b | | -111 | -5.842105263157895 | - | b | 5 | -1 | -0.2 | - | b | 4 | -223 | -44.6 | - | b | 3 | -84 | -42.0 | - | b | 2 | 102 | 25.5 | - | b | 1 | 95 | 31.666666666666668 | - | a | | -385 | -18.333333333333332 | - | a | 5 | -96 | -32.0 | - | a | 4 | -128 | -32.0 | - | a | 3 | -27 | -4.5 | - | a | 2 | -46 | -15.333333333333334 | - | a | 1 | -88 | -17.6 | - +----+----+--------+---------------------+ - "### + @r" + +----+----+--------+---------------------+---------------------+ + | c1 | c2 | sum_c3 | avg_c3 | avg_distinct_c3 | + +----+----+--------+---------------------+---------------------+ + | | 5 | -194 | -13.857142857142858 | -13.857142857142858 | + | | 4 | 29 | 1.2608695652173914 | 1.2608695652173914 | + | | 3 | 395 | 20.789473684210527 | 20.789473684210527 | + | | 2 | 184 | 8.363636363636363 | 8.363636363636363 | + | | 1 | 367 | 16.681818181818183 | 16.681818181818183 | + | e | | 847 | 40.333333333333336 | 40.333333333333336 | + | e | 5 | -22 | -11.0 | -11.0 | + | e | 4 | 261 | 37.285714285714285 | 37.285714285714285 | + | e | 3 | 192 | 48.0 | 48.0 | + | e | 2 | 189 | 37.8 | 37.8 | + | e | 1 | 227 | 75.66666666666667 | 75.66666666666667 | + | d | | 458 | 25.444444444444443 | 25.444444444444443 | + | d | 5 | -99 | -49.5 | -49.5 | + | d | 4 | 162 | 54.0 | 54.0 | + | d | 3 | 124 | 41.333333333333336 | 41.333333333333336 | + | d | 2 | 328 | 109.33333333333333 | 109.33333333333333 | + | d | 1 | -57 | -8.142857142857142 | -8.142857142857142 | + | c | | -28 | -1.3333333333333333 | -1.3333333333333333 | + | c | 5 | 24 | 12.0 | 12.0 | + | c | 4 | -43 | -10.75 | -10.75 | + | c | 3 | 190 | 47.5 | 47.5 | + | c | 2 | -389 | -55.57142857142857 | -55.57142857142857 | + | c | 1 | 190 | 47.5 | 47.5 | + | b | | -111 | -5.842105263157895 | -5.842105263157895 | + | b | 5 | -1 | -0.2 | -0.2 | + | b | 4 | -223 | -44.6 | -44.6 | + | b | 3 | -84 | -42.0 | -42.0 | + | b | 2 | 102 | 25.5 | 25.5 | + | b | 1 | 95 | 31.666666666666668 | 31.666666666666668 | + | a | | -385 | -18.333333333333332 | -18.333333333333332 | + | a | 5 | -96 | -32.0 | -32.0 | + | a | 4 | -128 | -32.0 | -32.0 | + | a | 3 | -27 | -4.5 | -4.5 | + | a | 2 | -46 | -15.333333333333334 | -15.333333333333334 | + | a | 1 | -88 | -17.6 | -17.6 | + +----+----+--------+---------------------+---------------------+ + " ); Ok(()) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 9fa1f74ae188..1cdca2df8bde 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -957,6 +957,7 @@ async fn roundtrip_expr_api() -> Result<()> { functions_window::nth_value::last_value(lit(1)), functions_window::nth_value::nth_value(lit(1), 1), avg(lit(1.5)), + avg(lit(1.5)).distinct().build().unwrap(), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), corr(lit(1.5), lit(2.2)), From bc500e5992c3b9ac531546189e20284021a6cf56 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Sun, 24 Aug 2025 10:32:03 +0900 Subject: [PATCH 08/14] Remove DataFrame API tests for avg(distinct) --- datafusion/core/tests/dataframe/mod.rs | 110 ++++++++++++------------- 1 file changed, 52 insertions(+), 58 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 2e84d742fd09..38dc0dc73569 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -498,7 +498,6 @@ async fn aggregate() -> Result<()> { // build plan using DataFrame API let df = test_table().await?; let group_expr = vec![col("c1")]; - let avg_distinct = avg(col("c12")).distinct().build().unwrap(); let aggr_expr = vec![ min(col("c12")), max(col("c12")), @@ -506,24 +505,23 @@ async fn aggregate() -> Result<()> { sum(col("c12")), count(col("c12")), count_distinct(col("c12")), - avg_distinct, ]; let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; assert_snapshot!( batches_to_sort_string(&df), - @r" - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+ - | c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) | avg(DISTINCT aggregate_test_100.c12) | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+ - | a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 | 0.48754517466109415 | - | b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 | 0.41040709263815384 | - | c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 | 0.6600456536439784 | - | d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 | 0.48855379387549824 | - | e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 | 0.48600669271341534 | - +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+ - " + @r###" + +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ + | c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) | + +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ + | a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 | + | b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 | + | c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 | + | d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 | + | e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 | + +----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+ + "### ); Ok(()) @@ -534,7 +532,6 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> { // build plan using DataFrame API let df = test_table().await?; let group_expr = vec![col("c1")]; - let avg_distinct = avg(col("c12")).distinct().build().unwrap(); let aggr_expr = vec![ min(col("c12")), max(col("c12")), @@ -543,7 +540,6 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> { count(col("c12")), count_distinct(col("c12")), median(col("c12")), - avg_distinct, ]; let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; @@ -3437,7 +3433,6 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { vec![col("c1"), col("c2")], ])); - let avg_distinct = avg(col("c3")).distinct().build().unwrap(); let df = aggregates_table(&ctx) .await? .aggregate( @@ -3445,7 +3440,6 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { vec![ sum(col("c3")).alias("sum_c3"), avg(col("c3")).alias("avg_c3"), - avg_distinct.alias("avg_distinct_c3"), ], )? .sort(vec![ @@ -3457,47 +3451,47 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { assert_snapshot!( batches_to_string(&results), - @r" - +----+----+--------+---------------------+---------------------+ - | c1 | c2 | sum_c3 | avg_c3 | avg_distinct_c3 | - +----+----+--------+---------------------+---------------------+ - | | 5 | -194 | -13.857142857142858 | -13.857142857142858 | - | | 4 | 29 | 1.2608695652173914 | 1.2608695652173914 | - | | 3 | 395 | 20.789473684210527 | 20.789473684210527 | - | | 2 | 184 | 8.363636363636363 | 8.363636363636363 | - | | 1 | 367 | 16.681818181818183 | 16.681818181818183 | - | e | | 847 | 40.333333333333336 | 40.333333333333336 | - | e | 5 | -22 | -11.0 | -11.0 | - | e | 4 | 261 | 37.285714285714285 | 37.285714285714285 | - | e | 3 | 192 | 48.0 | 48.0 | - | e | 2 | 189 | 37.8 | 37.8 | - | e | 1 | 227 | 75.66666666666667 | 75.66666666666667 | - | d | | 458 | 25.444444444444443 | 25.444444444444443 | - | d | 5 | -99 | -49.5 | -49.5 | - | d | 4 | 162 | 54.0 | 54.0 | - | d | 3 | 124 | 41.333333333333336 | 41.333333333333336 | - | d | 2 | 328 | 109.33333333333333 | 109.33333333333333 | - | d | 1 | -57 | -8.142857142857142 | -8.142857142857142 | - | c | | -28 | -1.3333333333333333 | -1.3333333333333333 | - | c | 5 | 24 | 12.0 | 12.0 | - | c | 4 | -43 | -10.75 | -10.75 | - | c | 3 | 190 | 47.5 | 47.5 | - | c | 2 | -389 | -55.57142857142857 | -55.57142857142857 | - | c | 1 | 190 | 47.5 | 47.5 | - | b | | -111 | -5.842105263157895 | -5.842105263157895 | - | b | 5 | -1 | -0.2 | -0.2 | - | b | 4 | -223 | -44.6 | -44.6 | - | b | 3 | -84 | -42.0 | -42.0 | - | b | 2 | 102 | 25.5 | 25.5 | - | b | 1 | 95 | 31.666666666666668 | 31.666666666666668 | - | a | | -385 | -18.333333333333332 | -18.333333333333332 | - | a | 5 | -96 | -32.0 | -32.0 | - | a | 4 | -128 | -32.0 | -32.0 | - | a | 3 | -27 | -4.5 | -4.5 | - | a | 2 | -46 | -15.333333333333334 | -15.333333333333334 | - | a | 1 | -88 | -17.6 | -17.6 | - +----+----+--------+---------------------+---------------------+ - " + @r###" + +----+----+--------+---------------------+ + | c1 | c2 | sum_c3 | avg_c3 | + +----+----+--------+---------------------+ + | | 5 | -194 | -13.857142857142858 | + | | 4 | 29 | 1.2608695652173914 | + | | 3 | 395 | 20.789473684210527 | + | | 2 | 184 | 8.363636363636363 | + | | 1 | 367 | 16.681818181818183 | + | e | | 847 | 40.333333333333336 | + | e | 5 | -22 | -11.0 | + | e | 4 | 261 | 37.285714285714285 | + | e | 3 | 192 | 48.0 | + | e | 2 | 189 | 37.8 | + | e | 1 | 227 | 75.66666666666667 | + | d | | 458 | 25.444444444444443 | + | d | 5 | -99 | -49.5 | + | d | 4 | 162 | 54.0 | + | d | 3 | 124 | 41.333333333333336 | + | d | 2 | 328 | 109.33333333333333 | + | d | 1 | -57 | -8.142857142857142 | + | c | | -28 | -1.3333333333333333 | + | c | 5 | 24 | 12.0 | + | c | 4 | -43 | -10.75 | + | c | 3 | 190 | 47.5 | + | c | 2 | -389 | -55.57142857142857 | + | c | 1 | 190 | 47.5 | + | b | | -111 | -5.842105263157895 | + | b | 5 | -1 | -0.2 | + | b | 4 | -223 | -44.6 | + | b | 3 | -84 | -42.0 | + | b | 2 | 102 | 25.5 | + | b | 1 | 95 | 31.666666666666668 | + | a | | -385 | -18.333333333333332 | + | a | 5 | -96 | -32.0 | + | a | 4 | -128 | -32.0 | + | a | 3 | -27 | -4.5 | + | a | 2 | -46 | -15.333333333333334 | + | a | 1 | -88 | -17.6 | + +----+----+--------+---------------------+ + "### ); Ok(()) From 9af76cf4d324aac636bdc389e6572d3fd87695c3 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Sun, 24 Aug 2025 10:33:10 +0900 Subject: [PATCH 09/14] Remove proto test --- datafusion/proto/tests/cases/roundtrip_logical_plan.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 78dca4cc280d..c76036a4344f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -967,7 +967,6 @@ async fn roundtrip_expr_api() -> Result<()> { functions_window::nth_value::last_value(lit(1)), functions_window::nth_value::nth_value(lit(1), 1), avg(lit(1.5)), - avg(lit(1.5)).distinct().build().unwrap(), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), corr(lit(1.5), lit(2.2)), From 7daad506b61080d2ff26edd04b57226776a4d8ca Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Sun, 24 Aug 2025 10:34:08 +0900 Subject: [PATCH 10/14] Fix merge errors --- datafusion/functions-aggregate/src/average.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 70a874544633..8969b0529415 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -152,7 +152,7 @@ impl AggregateUDFImpl for Avg { _ => exec_err!("AVG(DISTINCT) for {} not supported", data_type), } } else { - match (&data_type, acc_args.return_type()) { + match (&data_type, acc_args.return_field.data_type()) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -177,6 +177,7 @@ impl AggregateUDFImpl for Avg { target_precision: *target_precision, target_scale: *target_scale, })), + (Duration(time_unit), Duration(result_unit)) => { Ok(Box::new(DurationAvgAccumulator { sum: None, @@ -185,6 +186,7 @@ impl AggregateUDFImpl for Avg { result_unit: *result_unit, })) } + _ => exec_err!( "AvgAccumulator for ({} --> {})", &data_type, From 5b7a8f05736d3def09c13ebc581940a60b1fc53d Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Sun, 24 Aug 2025 11:03:20 +0900 Subject: [PATCH 11/14] Refactoring --- .../src/aggregate/avg_distinct/decimal.rs | 12 ++-- datafusion/functions-aggregate/src/average.rs | 63 +++++++++---------- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs index 4b37b6368764..347109addf5b 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs @@ -188,12 +188,12 @@ mod tests { let scale = 2_i8; let array = Decimal256Array::from(vec![ - Some(i256::from_i128(100_00)), // 100.00 - Some(i256::from_i128(125_00)), // 125.00 - Some(i256::from_i128(175_00)), // 175.00 - Some(i256::from_i128(200_00)), // 200.00 - Some(i256::from_i128(200_00)), // 200.00 (duplicate) - Some(i256::from_i128(300_00)), // 300.00 + Some(i256::from_i128(10_000)), // 100.00 + Some(i256::from_i128(12_500)), // 125.00 + Some(i256::from_i128(17_500)), // 175.00 + Some(i256::from_i128(20_000)), // 200.00 + Some(i256::from_i128(20_000)), // 200.00 (duplicate) + Some(i256::from_i128(30_000)), // 300.00 None, // null None, // null ]) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 8969b0529415..4cf27d451e89 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -122,37 +122,36 @@ impl AggregateUDFImpl for Avg { // instantiate specialized accumulator based for the type if acc_args.is_distinct { - match &data_type { + match (&data_type, acc_args.return_type()) { // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation - Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::default())), - Decimal128(_, scale) => { - let target_type = &acc_args.return_type(); - if let Decimal128(target_precision, target_scale) = target_type { - Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( - *scale, - *target_precision, - *target_scale, - ))) - } else { - exec_err!("AVG(DISTINCT) for Decimal128 expected Decimal128 return type, got {}", target_type) - } - } - Decimal256(_, scale) => { - let target_type = &acc_args.return_type(); - if let Decimal256(target_precision, target_scale) = target_type { - Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( - *scale, - *target_precision, - *target_scale, - ))) - } else { - exec_err!("AVG(DISTINCT) for Decimal256 expected Decimal256 return type, got {}", target_type) - } - } - _ => exec_err!("AVG(DISTINCT) for {} not supported", data_type), + (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())), + + ( + Decimal128(_, scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + + ( + Decimal256(_, scale), + Decimal256(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + + (dt, return_type) => exec_err!( + "AVG(DISTINCT) for ({} --> {}) not supported", + dt, + return_type + ), } } else { - match (&data_type, acc_args.return_field.data_type()) { + match (&data_type, acc_args.return_type()) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -187,11 +186,9 @@ impl AggregateUDFImpl for Avg { })) } - _ => exec_err!( - "AvgAccumulator for ({} --> {})", - &data_type, - acc_args.return_field.data_type() - ), + (dt, return_type) => { + exec_err!("AvgAccumulator for ({} --> {})", dt, return_type) + } } } } From 54a32ffc81195df392200136c03ce3399163bf7b Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Sun, 14 Sep 2025 21:47:11 +0900 Subject: [PATCH 12/14] Minor cleanup --- .../src/aggregate/avg_distinct/decimal.rs | 84 +++++-------------- 1 file changed, 23 insertions(+), 61 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs index 347109addf5b..a71871b9b41e 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs @@ -132,49 +132,28 @@ mod tests { #[test] fn test_decimal128_distinct_avg_accumulator() -> Result<()> { - // (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null) - // with precision 10, scale 4 - // As `single_distinct_to_groupby` will convert the input to a `GroupBy` plan, - // we need to test it with rust api - // See also `aggregate.slt` let precision = 10_u8; let scale = 4_i8; let array = Decimal128Array::from(vec![ - Some(100_0000), // 100.0000 - Some(125_0000), // 125.0000 - Some(175_0000), // 175.0000 - Some(200_0000), // 200.0000 - Some(200_0000), // 200.0000 (duplicate) - Some(300_0000), // 300.0000 - None, // null - None, // null + Some(100_0000), + Some(125_0000), + Some(175_0000), + Some(200_0000), + Some(200_0000), + Some(300_0000), + None, + None, ]) .with_precision_and_scale(precision, scale)?; - // Expected result for avg(distinct) should be 180.0000 with precision 14, scale 8 - let expected_result = ScalarValue::Decimal128( - Some(180_00000000), // 180.00000000 - 14, // target precision - 8, // target scale - ); - - let arrays: Vec = vec![Arc::new(array)]; - - // Create accumulator with appropriate parameters let mut accumulator = DecimalDistinctAvgAccumulator::::with_decimal_params( - scale, // input scale - 14, // target precision - 8, // target scale + scale, 14, 8, ); + accumulator.update_batch(&[Arc::new(array)])?; - // Update the accumulator with input values - accumulator.update_batch(&arrays)?; - - // Evaluate the result let result = accumulator.evaluate()?; - - // Assert that the result matches the expected value + let expected_result = ScalarValue::Decimal128(Some(180_00000000), 14, 8); assert_eq!(result, expected_result); Ok(()) @@ -182,47 +161,30 @@ mod tests { #[test] fn test_decimal256_distinct_avg_accumulator() -> Result<()> { - // (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null) - // with precision 50, scale 2 let precision = 50_u8; let scale = 2_i8; let array = Decimal256Array::from(vec![ - Some(i256::from_i128(10_000)), // 100.00 - Some(i256::from_i128(12_500)), // 125.00 - Some(i256::from_i128(17_500)), // 175.00 - Some(i256::from_i128(20_000)), // 200.00 - Some(i256::from_i128(20_000)), // 200.00 (duplicate) - Some(i256::from_i128(30_000)), // 300.00 - None, // null - None, // null + Some(i256::from_i128(10_000)), + Some(i256::from_i128(12_500)), + Some(i256::from_i128(17_500)), + Some(i256::from_i128(20_000)), + Some(i256::from_i128(20_000)), + Some(i256::from_i128(30_000)), + None, + None, ]) .with_precision_and_scale(precision, scale)?; - // Expected result for avg(distinct) should be 180.000000 with precision 54, scale 6 - let expected_result = ScalarValue::Decimal256( - Some(i256::from_i128(180_000000)), // 180.000000 - 54, // target precision - 6, // target scale - ); - - let arrays: Vec = vec![Arc::new(array)]; - - // Create accumulator with appropriate parameters let mut accumulator = DecimalDistinctAvgAccumulator::::with_decimal_params( - scale, // input scale - 54, // target precision - 6, // target scale + scale, 54, 6, ); + accumulator.update_batch(&[Arc::new(array)])?; - // Update the accumulator with input values - accumulator.update_batch(&arrays)?; - - // Evaluate the result let result = accumulator.evaluate()?; - - // Assert that the result matches the expected value + let expected_result = + ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6); assert_eq!(result, expected_result); Ok(()) From 8099d27865d8ba85e5c51aa868a5cc8fdaa79b61 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Sun, 14 Sep 2025 21:47:37 +0900 Subject: [PATCH 13/14] Decimal slt tests for avg(distinct) --- .../sqllogictest/test_files/aggregate.slt | 64 +++++++++++-------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 4101e2503120..eed3721078c7 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -7476,55 +7476,65 @@ FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); # distinct average statement ok -create table distinct_avg (a int, b double) as values - (3, null), - (2, null), - (5, 100.5), - (5, 1.0), - (5, 44.112), - (null, 1.0), - (5, 100.5), - (1, 4.09), - (5, 100.5), - (5, 100.5), - (4, null), - (null, null) +create table distinct_avg (a int, b double, c decimal(10, 4), d decimal(50, 2)) as values + (3, null, 100.2562, 90251.21), + (2, null, 100.2562, null), + (5, 100.5, null, 10000000.11), + (5, 1.0, 100.2563, -1.0), + (5, 44.112, -132.12, null), + (null, 1.0, 100.2562, 90251.21), + (5, 100.5, -100.2562, -10000000.11), + (1, 4.09, 4222.124, 0.0), + (5, 100.5, null, 10000000.11), + (5, 100.5, 1.1, 1.0), + (4, null, 4222.124, null), + (null, null, null, null) ; # Need two columns to ensure single_distinct_to_group_by rule doesn't kick in, so we know our actual avg(distinct) code is being tested -query RTRTRR +query RTRTRTRTRRRR select avg(distinct a), arrow_typeof(avg(distinct a)), avg(distinct b), arrow_typeof(avg(distinct b)), + avg(distinct c), + arrow_typeof(avg(distinct c)), + avg(distinct d), + arrow_typeof(avg(distinct d)), avg(a), - avg(b) + avg(b), + avg(c), + avg(d) from distinct_avg; ---- -3 Float64 37.4255 Float64 4 56.52525 +3 Float64 37.4255 Float64 698.56005 Decimal128(14, 8) 15041.868333 Decimal256(54, 6) 4 56.52525 957.11074444 1272562.81625 -query RR rowsort +query RRRR rowsort select avg(distinct a), - avg(distinct b) + avg(distinct b), + avg(distinct c), + avg(distinct d) from distinct_avg group by b; ---- -1 4.09 -3 NULL -5 1 -5 100.5 -5 44.112 +1 4.09 4222.124 0 +3 NULL 2161.1901 90251.21 +5 1 100.25625 45125.105 +5 100.5 -49.5781 0.333333 +5 44.112 -132.12 NULL -query RR +query RRRR select avg(distinct a), - avg(distinct b) + avg(distinct b), + avg(distinct c), + avg(distinct d) from distinct_avg -where a is null and b is null; +where a is null and b is null and c is null and d is null; ---- -NULL NULL +NULL NULL NULL NULL statement ok drop table distinct_avg; From 1a98eb1f2f6ff4b6746838d91d7f844e49e96a82 Mon Sep 17 00:00:00 2001 From: Jefffrey Date: Sun, 14 Sep 2025 21:48:23 +0900 Subject: [PATCH 14/14] Fix state_fields for decimal distinct avg --- datafusion/functions-aggregate/src/average.rs | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 4cf27d451e89..5223ef533603 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -27,6 +27,7 @@ use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use datafusion_common::{ exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, @@ -195,11 +196,22 @@ impl AggregateUDFImpl for Avg { fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { - // Copied from datafusion_functions_aggregate::sum::Sum::state_fields + // Decimal accumulator actually uses a different precision during accumulation, + // see DecimalDistinctAvgAccumulator::with_decimal_params + let dt = match args.input_fields[0].data_type() { + DataType::Decimal128(_, scale) => { + DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale) + } + DataType::Decimal256(_, scale) => { + DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale) + } + _ => args.return_type().clone(), + }; + // Similar to datafusion_functions_aggregate::sum::Sum::state_fields // since the accumulator uses DistinctSumAccumulator internally. Ok(vec![Field::new_list( format_state_name(args.name, "avg distinct"), - Field::new_list_field(args.return_type().clone(), true), + Field::new_list_field(dt, true), false, ) .into()])