From 5906fcac26d3aabfbf3c1d04cc6550c191f8f0a7 Mon Sep 17 00:00:00 2001 From: hhj Date: Sun, 12 Nov 2023 09:42:42 +0800 Subject: [PATCH 01/14] init impl --- datafusion/physical-expr/src/aggregate/mod.rs | 1 + .../physical-expr/src/aggregate/string_agg.rs | 219 ++++++++++++++++++ 2 files changed, 220 insertions(+) create mode 100644 datafusion/physical-expr/src/aggregate/string_agg.rs diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 442d018b87d5..329bb1e6415e 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -43,6 +43,7 @@ pub(crate) mod covariance; pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; +pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub mod build_in; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs new file mode 100644 index 000000000000..5fe74fc64df4 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::format_state_name; +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use arrow_array::Array; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::Accumulator; +use std::any::Any; +use std::sync::Arc; +// use arrow::array::OffsetSizeTrait; + +/// STRING_AGG aggregate expression +#[derive(Debug)] +pub struct StringAgg { + /// Column name + name: String, + /// The DataType for the input expression + input_data_type: DataType, + /// The input expression + expr: Arc, + /// If the input expression can have NULLs + nullable: bool, +} + +impl StringAgg { + /// Create a new StringAgg aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + nullable: bool, + ) -> Self { + Self { + name: name.into(), + input_data_type: data_type, + expr, + nullable, + } + } +} + +impl AggregateExpr for StringAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new( + &self.name, + self.input_data_type.clone(), + self.nullable, + )) + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(StringAggAccumulator::new( + self.expr.clone(), + &self.input_data_type, + ))) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new( + format_state_name(&self.name, "string_agg"), + self.input_data_type.clone(), + self.nullable, + )]) + } + + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for StringAgg { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.input_data_type == x.input_data_type + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +pub(crate) struct StringAggAccumulator { + values: String, + datatype: DataType, +} + +impl StringAggAccumulator { + pub fn new( + expr: Arc, + datatype: &DataType, + ) -> Self { + Self { + values: String::new(), + datatype: datatype.clone(), + } + } +} + +impl Accumulator for StringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // use match to select utf8 or largeutf8 + let string_aggay = as_generic_string_array::(&values[0])?; + for i in 0..string_aggay.len() { + self.values.push_str(string_aggay.value(i)); + } + Ok(()) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_aggay = as_generic_string_array::(&values[0])?; + for i in 0..string_aggay.len() { + self.values.push_str(string_aggay.value(i)); + } + Ok(()) + } + + fn state(&self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Utf8(Some(self.values.clone()))) + } + + fn size(&self) -> usize { + std::mem::size_of::() + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::col; + use crate::expressions::tests::aggregate; + use arrow::array::ArrayRef; + use arrow::array::Int32Array; + use arrow::datatypes::*; + use arrow::record_batch::RecordBatch; + use arrow_array::Array; + use arrow_array::ListArray; + use arrow_buffer::OffsetBuffer; + use datafusion_common::DataFusionError; + use datafusion_common::Result; + + macro_rules! test_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) + }; + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + + let agg = Arc::new(<$OP>::new( + col("a", &schema)?, + "bla".to_string(), + $EXPECTED_DATATYPE, + true, + )); + let actual = aggregate(&batch, agg)?; + let expected = ScalarValue::from($EXPECTED); + + assert_eq!(expected, actual); + + Ok(()) as Result<(), DataFusionError> + }}; + } + + #[test] + fn array_agg_i32() -> Result<()> { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); + + let list = ListArray::from_iter_primitive::(vec![Some(vec![ + Some(1), + Some(2), + Some(3), + Some(4), + Some(5), + ])]); + let list = ScalarValue::List(Arc::new(list)); + + test_op!(a, DataType::Int32, StringAgg, list, DataType::Int32) + } +} From a2b2e088db3f52d4ebe5f1b17f07570de1a72d0a Mon Sep 17 00:00:00 2001 From: hhj Date: Sun, 12 Nov 2023 21:09:04 +0800 Subject: [PATCH 02/14] add support for larget utf8 --- datafusion/expr/src/aggregate_function.rs | 9 ++ .../expr/src/type_coercion/aggregates.rs | 1 + .../physical-expr/src/aggregate/build_in.rs | 16 ++ .../physical-expr/src/aggregate/string_agg.rs | 139 ++++++++++-------- .../physical-expr/src/expressions/mod.rs | 1 + datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 4 + 10 files changed, 113 insertions(+), 65 deletions(-) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index ea0b01825170..f0f31705d4a3 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -100,6 +100,8 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, + /// string_agg + StringAgg, } impl AggregateFunction { @@ -141,6 +143,7 @@ impl AggregateFunction { BitXor => "BIT_XOR", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", + StringAgg => "STRING_AGG", } } } @@ -171,6 +174,7 @@ impl FromStr for AggregateFunction { "array_agg" => AggregateFunction::ArrayAgg, "first_value" => AggregateFunction::FirstValue, "last_value" => AggregateFunction::LastValue, + "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, "covar" => AggregateFunction::Covariance, @@ -299,6 +303,8 @@ impl AggregateFunction { AggregateFunction::FirstValue | AggregateFunction::LastValue => { Ok(coerced_data_types[0].clone()) } + // TODO + AggregateFunction::StringAgg => Ok(DataType::Utf8), } } } @@ -408,6 +414,9 @@ impl AggregateFunction { .collect(), Volatility::Immutable, ), + AggregateFunction::StringAgg => { + Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) + } } } } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 261c406d5d5e..8dc8351df41c 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -298,6 +298,7 @@ pub fn coerce_types( | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), + AggregateFunction::StringAgg => Ok(vec![LargeUtf8, input_types[0].clone()]), } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 596197b4eebe..c40f0db19405 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -369,6 +369,22 @@ pub fn create_aggregate_expr( ordering_req.to_vec(), ordering_types, )), + (AggregateFunction::StringAgg, false) => { + if !ordering_req.is_empty() { + return not_impl_err!( + "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" + ); + } + Arc::new(expressions::StringAgg::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + data_type, + )) + } + (AggregateFunction::StringAgg, true) => { + return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); + } }) } diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 5fe74fc64df4..321ec65abbaf 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -18,29 +18,25 @@ //! Defines physical expressions that can evaluated at runtime during query execution use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; +use crate::expressions::{format_state_name, Literal}; use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; +use arrow_array::OffsetSizeTrait; use datafusion_common::cast::as_generic_string_array; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use std::any::Any; +use std::marker::PhantomData; use std::sync::Arc; -// use arrow::array::OffsetSizeTrait; /// STRING_AGG aggregate expression #[derive(Debug)] pub struct StringAgg { - /// Column name name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression + data_type: DataType, expr: Arc, - /// If the input expression can have NULLs + delimiter: Arc, nullable: bool, } @@ -48,15 +44,16 @@ impl StringAgg { /// Create a new StringAgg aggregate function pub fn new( expr: Arc, + delimiter: Arc, name: impl Into, data_type: DataType, - nullable: bool, ) -> Self { Self { name: name.into(), - input_data_type: data_type, + data_type, + delimiter, expr, - nullable, + nullable: true, } } } @@ -69,29 +66,42 @@ impl AggregateExpr for StringAgg { fn field(&self) -> Result { Ok(Field::new( &self.name, - self.input_data_type.clone(), + self.data_type.clone(), self.nullable, )) } fn create_accumulator(&self) -> Result> { - Ok(Box::new(StringAggAccumulator::new( - self.expr.clone(), - &self.input_data_type, - ))) + if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { + if let ScalarValue::Utf8(Some(delimiter)) = delimiter.value() { + match self.data_type { + DataType::Utf8 => { + return Ok(Box::new(StringAggAccumulator::::new(delimiter))); + } + DataType::LargeUtf8 => { + return Ok(Box::new(StringAggAccumulator::::new(delimiter))); + } + _ => { + return not_impl_err!( + "StringAgg separator only support literal string" + ) + } + }; + } + } + not_impl_err!("StringAgg separator only support literal string") } fn state_fields(&self) -> Result> { Ok(vec![Field::new( format_state_name(&self.name, "string_agg"), - self.input_data_type.clone(), + self.data_type.clone(), self.nullable, )]) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![self.expr.clone(), self.delimiter.clone()] } fn name(&self) -> &str { @@ -105,46 +115,44 @@ impl PartialEq for StringAgg { .downcast_ref::() .map(|x| { self.name == x.name - && self.input_data_type == x.input_data_type + && self.data_type == x.data_type && self.expr.eq(&x.expr) + && self.delimiter.eq(&x.delimiter) }) .unwrap_or(false) } } #[derive(Debug)] -pub(crate) struct StringAggAccumulator { +pub(crate) struct StringAggAccumulator { values: String, - datatype: DataType, + sep: String, + phantom: PhantomData, } -impl StringAggAccumulator { - pub fn new( - expr: Arc, - datatype: &DataType, - ) -> Self { +impl StringAggAccumulator { + pub fn new(sep: &str) -> Self { Self { values: String::new(), - datatype: datatype.clone(), + sep: sep.to_string(), + phantom: PhantomData, } } } -impl Accumulator for StringAggAccumulator { +impl Accumulator for StringAggAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // use match to select utf8 or largeutf8 - let string_aggay = as_generic_string_array::(&values[0])?; - for i in 0..string_aggay.len() { - self.values.push_str(string_aggay.value(i)); - } + let string_array: Vec<_> = as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); + self.values + .push_str(string_array.join(self.sep.as_str()).as_str()); Ok(()) } fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_aggay = as_generic_string_array::(&values[0])?; - for i in 0..string_aggay.len() { - self.values.push_str(string_aggay.value(i)); - } + self.update_batch(values)?; Ok(()) } @@ -153,7 +161,7 @@ impl Accumulator for StringAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::Utf8(Some(self.values.clone()))) + Ok(ScalarValue::LargeUtf8(Some(self.values.clone()))) } fn size(&self) -> usize { @@ -161,36 +169,30 @@ impl Accumulator for StringAggAccumulator { } } - #[cfg(test)] mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; use arrow::array::ArrayRef; - use arrow::array::Int32Array; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; - use arrow_array::Array; - use arrow_array::ListArray; - use arrow_buffer::OffsetBuffer; + use arrow_array::LargeStringArray; + use arrow_array::StringArray; use datafusion_common::DataFusionError; use datafusion_common::Result; macro_rules! test_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr, $SEP:expr) => {{ let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; let agg = Arc::new(<$OP>::new( col("a", &schema)?, - "bla".to_string(), + $SEP, + "str".to_string(), $EXPECTED_DATATYPE, - true, )); let actual = aggregate(&batch, agg)?; let expected = ScalarValue::from($EXPECTED); @@ -202,18 +204,25 @@ mod tests { } #[test] - fn array_agg_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - - let list = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - Some(4), - Some(5), - ])]); - let list = ScalarValue::List(Arc::new(list)); - - test_op!(a, DataType::Int32, StringAgg, list, DataType::Int32) + fn string_agg_utf8() -> Result<()> { + let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); + let list = ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())); + let sep = Arc::new(Literal::new(ScalarValue::Utf8(Some(",".to_owned())))); + test_op!(a, DataType::Utf8, StringAgg, list, DataType::Utf8, sep) + } + + #[test] + fn string_agg_largeutf8() -> Result<()> { + let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); + let list = ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())); + let sep = Arc::new(Literal::new(ScalarValue::Utf8(Some(",".to_owned())))); + test_op!( + a, + DataType::LargeUtf8, + StringAgg, + list, + DataType::LargeUtf8, + sep + ) } } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index c44b3cf01d36..f2ecfef297ff 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -63,6 +63,7 @@ pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; +pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; pub use crate::aggregate::variance::{Variance, VariancePop}; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f9deca2f1e52..9f7de18ff558 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -667,6 +667,7 @@ enum AggregateFunction { REGR_SXX = 32; REGR_SYY = 33; REGR_SXY = 34; + STRING_AGG = 35; } message AggregateExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 81f260c28bed..c357b3cf8db2 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -474,6 +474,7 @@ impl serde::Serialize for AggregateFunction { Self::RegrSxx => "REGR_SXX", Self::RegrSyy => "REGR_SYY", Self::RegrSxy => "REGR_SXY", + Self::StringAgg => "STRING_AGG", }; serializer.serialize_str(variant) } @@ -520,6 +521,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX", "REGR_SYY", "REGR_SXY", + "STRING_AGG", ]; struct GeneratedVisitor; @@ -595,6 +597,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX" => Ok(AggregateFunction::RegrSxx), "REGR_SYY" => Ok(AggregateFunction::RegrSyy), "REGR_SXY" => Ok(AggregateFunction::RegrSxy), + "STRING_AGG" => Ok(AggregateFunction::StringAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ae64c11b3b74..b18d2f1f9c42 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2837,6 +2837,7 @@ pub enum AggregateFunction { RegrSxx = 32, RegrSyy = 33, RegrSxy = 34, + StringAgg = 35, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2882,6 +2883,7 @@ impl AggregateFunction { AggregateFunction::RegrSxx => "REGR_SXX", AggregateFunction::RegrSyy => "REGR_SYY", AggregateFunction::RegrSxy => "REGR_SXY", + AggregateFunction::StringAgg => "STRING_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2924,6 +2926,7 @@ impl AggregateFunction { "REGR_SXX" => Some(Self::RegrSxx), "REGR_SYY" => Some(Self::RegrSyy), "REGR_SXY" => Some(Self::RegrSxy), + "STRING_AGG" => Some(Self::StringAgg), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 31fffca3bbed..5f755e8eb40e 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -588,6 +588,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, + protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 803becbcaece..e7adfcd89547 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -398,6 +398,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, + AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -706,6 +707,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { AggregateFunction::LastValue => { protobuf::AggregateFunction::LastValueAgg } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } }; let aggregate_expr = protobuf::AggregateExprNode { From 75b033e0fdd4aed673493431af0a34515d28bb09 Mon Sep 17 00:00:00 2001 From: hhj Date: Sun, 12 Nov 2023 23:12:07 +0800 Subject: [PATCH 03/14] add some test --- .../physical-expr/src/aggregate/string_agg.rs | 3 + .../sqllogictest/test_files/aggregate.slt | 70 +++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 321ec65abbaf..cd8bc2f25811 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -146,6 +146,9 @@ impl Accumulator for StringAggAccumulator { .iter() .filter_map(|v| v.as_ref().map(ToString::to_string)) .collect(); + if self.values.len() > 0 { + self.values.push_str(self.sep.as_str()); + } self.values .push_str(string_array.join(self.sep.as_str()).as_str()); Ok(()) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a1bb93ed53c4..b59e09fba637 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2987,3 +2987,73 @@ NULL NULL 1 NULL 3 6 0 0 0 NULL NULL 1 NULL 5 15 0 0 0 3 0 2 1 5.5 16.5 0.5 4.5 1.5 3 0 3 1 6 18 2 18 6 + +statement error +SELECT STRING_AGG() + +statement error +SELECT STRING_AGG(1,2,3) + +statement error +SELECT STRING_AGG(STRING_AGG('a', ',')) + +query T +SELECT STRING_AGG('a', ',') +---- +a + +# TODO: deal with NULL +statement error +SELECT STRING_AGG(NULL,',') + +statement error +STRING_AGG('a', NULL) + +statement error +STRING_AGG(NULL,NULL) + +statement ok +CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) + +query ITT +INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'), (2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+') +---- +9 + +query IT +SELECT g, STRING_AGG(x,'|') FROM strings GROUP BY g ORDER BY g +---- +1 a|b +2 i|j +3 p +4 x|y|z + +query T +SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 +---- +(empty) + + +query T +WITH my_data as ( +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +---- +text1, text1, text1 + + +query T +WITH my_data as ( +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +GROUP BY dummy +---- +text1, text1, text1 From 15ff644b21fd82bcab152a40c2ee5893905bbd73 Mon Sep 17 00:00:00 2001 From: hhj Date: Mon, 13 Nov 2023 13:52:00 +0800 Subject: [PATCH 04/14] support null --- datafusion/expr/src/aggregate_function.rs | 3 +- .../expr/src/type_coercion/aggregates.rs | 2 +- .../physical-expr/src/aggregate/string_agg.rs | 60 ++++++++++++------- .../sqllogictest/test_files/aggregate.slt | 17 ++---- 4 files changed, 45 insertions(+), 37 deletions(-) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index f0f31705d4a3..4611c7fb10d7 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -303,8 +303,7 @@ impl AggregateFunction { AggregateFunction::FirstValue | AggregateFunction::LastValue => { Ok(coerced_data_types[0].clone()) } - // TODO - AggregateFunction::StringAgg => Ok(DataType::Utf8), + AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 8dc8351df41c..3d540b1164f6 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -298,7 +298,7 @@ pub fn coerce_types( | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), - AggregateFunction::StringAgg => Ok(vec![LargeUtf8, input_types[0].clone()]), + AggregateFunction::StringAgg => Ok(vec![LargeUtf8, input_types[1].clone()]), } } diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index cd8bc2f25811..47b4a5ac10ee 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -73,23 +73,36 @@ impl AggregateExpr for StringAgg { fn create_accumulator(&self) -> Result> { if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { - if let ScalarValue::Utf8(Some(delimiter)) = delimiter.value() { - match self.data_type { - DataType::Utf8 => { - return Ok(Box::new(StringAggAccumulator::::new(delimiter))); - } - DataType::LargeUtf8 => { - return Ok(Box::new(StringAggAccumulator::::new(delimiter))); - } - _ => { - return not_impl_err!( - "StringAgg separator only support literal string" - ) - } - }; + match (self.data_type.clone(), delimiter.value()) { + (DataType::Utf8, ScalarValue::Utf8(Some(delimiter))) + | (DataType::Utf8, ScalarValue::LargeUtf8(Some(delimiter))) => { + return Ok(Box::new(StringAggAccumulator::::new(delimiter))); + } + (DataType::LargeUtf8, ScalarValue::Utf8(Some(delimiter))) + | (DataType::LargeUtf8, ScalarValue::LargeUtf8(Some(delimiter))) => { + return Ok(Box::new(StringAggAccumulator::::new(delimiter))); + } + (DataType::Utf8, ScalarValue::Null) => { + return Ok(Box::new(StringAggAccumulator::::new(""))) + } + (DataType::LargeUtf8, ScalarValue::Null) => { + return Ok(Box::new(StringAggAccumulator::::new(""))) + } + (_, _) => { + return not_impl_err!( + "StringAgg not support for {}: {} with delimiter {}", + self.name, + self.data_type, + delimiter.value() + ) + } } } - not_impl_err!("StringAgg separator only support literal string") + not_impl_err!( + "StringAgg not support for {}: {} with no Literal delimiter", + self.name, + self.data_type + ) } fn state_fields(&self) -> Result> { @@ -125,7 +138,7 @@ impl PartialEq for StringAgg { #[derive(Debug)] pub(crate) struct StringAggAccumulator { - values: String, + values: Option, sep: String, phantom: PhantomData, } @@ -133,7 +146,7 @@ pub(crate) struct StringAggAccumulator { impl StringAggAccumulator { pub fn new(sep: &str) -> Self { Self { - values: String::new(), + values: None, sep: sep.to_string(), phantom: PhantomData, } @@ -146,11 +159,14 @@ impl Accumulator for StringAggAccumulator { .iter() .filter_map(|v| v.as_ref().map(ToString::to_string)) .collect(); - if self.values.len() > 0 { - self.values.push_str(self.sep.as_str()); + let s = string_array.join(self.sep.as_str()); + if s.len() > 0 { + let v = self.values.get_or_insert("".to_string()); + if v.len() > 0 { + v.push_str(self.sep.as_str()); + } + v.push_str(s.as_str()); } - self.values - .push_str(string_array.join(self.sep.as_str()).as_str()); Ok(()) } @@ -164,7 +180,7 @@ impl Accumulator for StringAggAccumulator { } fn evaluate(&self) -> Result { - Ok(ScalarValue::LargeUtf8(Some(self.values.clone()))) + Ok(ScalarValue::LargeUtf8(self.values.clone())) } fn size(&self) -> usize { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index b59e09fba637..3526e280ebe2 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3002,15 +3002,10 @@ SELECT STRING_AGG('a', ',') ---- a -# TODO: deal with NULL -statement error -SELECT STRING_AGG(NULL,',') - -statement error -STRING_AGG('a', NULL) - -statement error -STRING_AGG(NULL,NULL) +query TTTT +SELECT STRING_AGG('a',','), STRING_AGG('a', NULL), STRING_AGG(NULL, ','), STRING_AGG(NULL, NULL) +---- +a a NULL NULL statement ok CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) @@ -3031,8 +3026,7 @@ SELECT g, STRING_AGG(x,'|') FROM strings GROUP BY g ORDER BY g query T SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 ---- -(empty) - +NULL query T WITH my_data as ( @@ -3045,7 +3039,6 @@ FROM my_data ---- text1, text1, text1 - query T WITH my_data as ( SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all From 19acf200ccfb9642fe78962e3fc23ddf6ccf43f4 Mon Sep 17 00:00:00 2001 From: hhj Date: Mon, 13 Nov 2023 15:16:10 +0800 Subject: [PATCH 05/14] remove redundance code --- .../physical-expr/src/aggregate/string_agg.rs | 106 ++++++++++-------- 1 file changed, 59 insertions(+), 47 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 47b4a5ac10ee..23c582220b12 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -73,22 +73,15 @@ impl AggregateExpr for StringAgg { fn create_accumulator(&self) -> Result> { if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { - match (self.data_type.clone(), delimiter.value()) { - (DataType::Utf8, ScalarValue::Utf8(Some(delimiter))) - | (DataType::Utf8, ScalarValue::LargeUtf8(Some(delimiter))) => { - return Ok(Box::new(StringAggAccumulator::::new(delimiter))); - } - (DataType::LargeUtf8, ScalarValue::Utf8(Some(delimiter))) - | (DataType::LargeUtf8, ScalarValue::LargeUtf8(Some(delimiter))) => { + match delimiter.value() { + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) => { return Ok(Box::new(StringAggAccumulator::::new(delimiter))); } - (DataType::Utf8, ScalarValue::Null) => { - return Ok(Box::new(StringAggAccumulator::::new(""))) - } - (DataType::LargeUtf8, ScalarValue::Null) => { - return Ok(Box::new(StringAggAccumulator::::new(""))) + ScalarValue::Null => { + return Ok(Box::new(StringAggAccumulator::::new(""))); } - (_, _) => { + _ => { return not_impl_err!( "StringAgg not support for {}: {} with delimiter {}", self.name, @@ -160,9 +153,9 @@ impl Accumulator for StringAggAccumulator { .filter_map(|v| v.as_ref().map(ToString::to_string)) .collect(); let s = string_array.join(self.sep.as_str()); - if s.len() > 0 { + if !s.is_empty() { let v = self.values.get_or_insert("".to_string()); - if v.len() > 0 { + if !v.is_empty() { v.push_str(self.sep.as_str()); } v.push_str(s.as_str()); @@ -191,57 +184,76 @@ impl Accumulator for StringAggAccumulator { #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; use crate::expressions::tests::aggregate; + use crate::expressions::{col, create_aggregate_expr, try_cast}; use arrow::array::ArrayRef; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; use arrow_array::LargeStringArray; use arrow_array::StringArray; - use datafusion_common::DataFusionError; - use datafusion_common::Result; + use datafusion_expr::type_coercion::aggregates::coerce_types; + use datafusion_expr::AggregateFunction; - macro_rules! test_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr, $SEP:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); + fn assert_string_aggregate( + array: ArrayRef, + function: AggregateFunction, + distinct: bool, + expected: ScalarValue, + delimiter: String, + ) { + let data_type = array.data_type(); + let sig = function.signature(); + let coerced = + coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap(); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; + let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + let batch = + RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - $SEP, - "str".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); + let input = try_cast( + col("a", &input_schema).unwrap(), + &input_schema, + coerced[0].clone(), + ) + .unwrap(); - assert_eq!(expected, actual); + let delimiter = Arc::new(Literal::new(ScalarValue::Utf8(Some(delimiter)))); + let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); + let agg = create_aggregate_expr( + &function, + distinct, + &[input, delimiter], + &[], + &schema, + "agg", + ) + .unwrap(); - Ok(()) as Result<(), DataFusionError> - }}; + let result = aggregate(&batch, agg).unwrap(); + assert_eq!(expected, result); } #[test] - fn string_agg_utf8() -> Result<()> { + fn string_agg_utf8() { let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); - let list = ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())); - let sep = Arc::new(Literal::new(ScalarValue::Utf8(Some(",".to_owned())))); - test_op!(a, DataType::Utf8, StringAgg, list, DataType::Utf8, sep) + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())), + ",".to_owned(), + ); } #[test] - fn string_agg_largeutf8() -> Result<()> { + fn string_agg_largeutf8() { let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); - let list = ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())); - let sep = Arc::new(Literal::new(ScalarValue::Utf8(Some(",".to_owned())))); - test_op!( + assert_string_aggregate( a, - DataType::LargeUtf8, - StringAgg, - list, - DataType::LargeUtf8, - sep - ) + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())), + "|".to_owned(), + ); } } From 3175e63c4b91e043768c6bab6160c7bd220f243d Mon Sep 17 00:00:00 2001 From: hhj Date: Mon, 13 Nov 2023 15:43:33 +0800 Subject: [PATCH 06/14] remove redundance code --- .../physical-expr/src/aggregate/string_agg.rs | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 23c582220b12..3af15968cd64 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -22,12 +22,10 @@ use crate::expressions::{format_state_name, Literal}; use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; -use arrow_array::OffsetSizeTrait; use datafusion_common::cast::as_generic_string_array; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use std::any::Any; -use std::marker::PhantomData; use std::sync::Arc; /// STRING_AGG aggregate expression @@ -76,10 +74,10 @@ impl AggregateExpr for StringAgg { match delimiter.value() { ScalarValue::Utf8(Some(delimiter)) | ScalarValue::LargeUtf8(Some(delimiter)) => { - return Ok(Box::new(StringAggAccumulator::::new(delimiter))); + return Ok(Box::new(StringAggAccumulator::new(delimiter))); } ScalarValue::Null => { - return Ok(Box::new(StringAggAccumulator::::new(""))); + return Ok(Box::new(StringAggAccumulator::new(""))); } _ => { return not_impl_err!( @@ -130,33 +128,31 @@ impl PartialEq for StringAgg { } #[derive(Debug)] -pub(crate) struct StringAggAccumulator { +pub(crate) struct StringAggAccumulator { values: Option, - sep: String, - phantom: PhantomData, + delimiter: String, } -impl StringAggAccumulator { - pub fn new(sep: &str) -> Self { +impl StringAggAccumulator { + pub fn new(delimiter: &str) -> Self { Self { values: None, - sep: sep.to_string(), - phantom: PhantomData, + delimiter: delimiter.to_string(), } } } -impl Accumulator for StringAggAccumulator { +impl Accumulator for StringAggAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_array: Vec<_> = as_generic_string_array::(&values[0])? + let string_array: Vec<_> = as_generic_string_array::(&values[0])? .iter() .filter_map(|v| v.as_ref().map(ToString::to_string)) .collect(); - let s = string_array.join(self.sep.as_str()); + let s = string_array.join(self.delimiter.as_str()); if !s.is_empty() { let v = self.values.get_or_insert("".to_string()); if !v.is_empty() { - v.push_str(self.sep.as_str()); + v.push_str(self.delimiter.as_str()); } v.push_str(s.as_str()); } @@ -177,7 +173,9 @@ impl Accumulator for StringAggAccumulator { } fn size(&self) -> usize { - std::mem::size_of::() + std::mem::size_of_val(self) + + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + + self.delimiter.capacity() } } From 6107aeacce1d340ec1d396f40dad2de26b45d99a Mon Sep 17 00:00:00 2001 From: hhj Date: Wed, 15 Nov 2023 16:49:48 +0800 Subject: [PATCH 07/14] add more test --- .../physical-expr/src/aggregate/string_agg.rs | 4 ++-- datafusion/sqllogictest/test_files/aggregate.slt | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 3af15968cd64..47953463cd5a 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -148,8 +148,8 @@ impl Accumulator for StringAggAccumulator { .iter() .filter_map(|v| v.as_ref().map(ToString::to_string)) .collect(); - let s = string_array.join(self.delimiter.as_str()); - if !s.is_empty() { + if !string_array.is_empty() { + let s = string_array.join(self.delimiter.as_str()); let v = self.values.get_or_insert("".to_string()); if !v.is_empty() { v.push_str(self.delimiter.as_str()); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 3526e280ebe2..0a495dd2b0c9 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3007,6 +3007,16 @@ SELECT STRING_AGG('a',','), STRING_AGG('a', NULL), STRING_AGG(NULL, ','), STRING ---- a a NULL NULL +query TT +select string_agg('', '|'), string_agg('a', ''); +---- +(empty) a + +query T +SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); +---- +| + statement ok CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) @@ -3028,6 +3038,9 @@ SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 ---- NULL +statement ok +drop table strings + query T WITH my_data as ( SELECT 'text1'::varchar(1000) as my_column union all From e1939f9731e533d5fbdd5da32cce603b412aa7f1 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Thu, 16 Nov 2023 08:55:39 +0800 Subject: [PATCH 08/14] Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: universalmind303 --- datafusion/physical-expr/src/aggregate/string_agg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 47953463cd5a..ff787c01ea01 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -81,7 +81,7 @@ impl AggregateExpr for StringAgg { } _ => { return not_impl_err!( - "StringAgg not support for {}: {} with delimiter {}", + "StringAgg not supported for {}: {} with delimiter {}", self.name, self.data_type, delimiter.value() From 3936978f1ea54d0670866990ac437f4f62b7de07 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Thu, 16 Nov 2023 08:55:49 +0800 Subject: [PATCH 09/14] Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: universalmind303 --- datafusion/physical-expr/src/aggregate/string_agg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index ff787c01ea01..0d94fcf7baa5 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -90,7 +90,7 @@ impl AggregateExpr for StringAgg { } } not_impl_err!( - "StringAgg not support for {}: {} with no Literal delimiter", + "StringAgg not supported for {}: {} with no Literal delimiter", self.name, self.data_type ) From 225cd2a45a63497caef8b7234722bd6d573d77d3 Mon Sep 17 00:00:00 2001 From: hhj Date: Thu, 16 Nov 2023 09:31:46 +0800 Subject: [PATCH 10/14] add suggest --- .../expr/src/type_coercion/aggregates.rs | 27 ++++++++++++++++++- .../physical-expr/src/aggregate/string_agg.rs | 15 ++--------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 3d540b1164f6..7128b575978a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -298,7 +298,23 @@ pub fn coerce_types( | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), - AggregateFunction::StringAgg => Ok(vec![LargeUtf8, input_types[1].clone()]), + AggregateFunction::StringAgg => { + if !is_string_agg_supported_arg_type(&input_types[0]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[0] + ); + } + if !is_string_agg_supported_arg_type(&input_types[1]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[1] + ); + } + Ok(vec![LargeUtf8, input_types[1].clone()]) + } } } @@ -566,6 +582,15 @@ pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool ) } +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`AggregateFunction::StringAgg`] aggregation can operate on. +pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 0d94fcf7baa5..7666331c53ff 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -79,21 +79,10 @@ impl AggregateExpr for StringAgg { ScalarValue::Null => { return Ok(Box::new(StringAggAccumulator::new(""))); } - _ => { - return not_impl_err!( - "StringAgg not supported for {}: {} with delimiter {}", - self.name, - self.data_type, - delimiter.value() - ) - } + _ => return not_impl_err!("StringAgg not supported for {}", self.name), } } - not_impl_err!( - "StringAgg not supported for {}: {} with no Literal delimiter", - self.name, - self.data_type - ) + not_impl_err!("StringAgg not supported for {}", self.name) } fn state_fields(&self) -> Result> { From b8e6efbab07a60fbb5d70140bb9e586c7fea4e1a Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 18 Nov 2023 10:32:17 +0800 Subject: [PATCH 11/14] Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: Andrew Lamb --- datafusion/physical-expr/src/aggregate/string_agg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs index 7666331c53ff..74c083959ed8 100644 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::{format_state_name, Literal}; From 35ff4be63d4f87a89687687e5feaf476ab805bee Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 18 Nov 2023 10:32:27 +0800 Subject: [PATCH 12/14] Update datafusion/sqllogictest/test_files/aggregate.slt Co-authored-by: Andrew Lamb --- datafusion/sqllogictest/test_files/aggregate.slt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 0a495dd2b0c9..80adc7fb506c 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3044,8 +3044,8 @@ drop table strings query T WITH my_data as ( SELECT 'text1'::varchar(1000) as my_column union all -SELECT 'text1'::varchar(1000) as my_column union all -SELECT 'text1'::varchar(1000) as my_column +SELECT 'text2'::varchar(1000) as my_column union all +SELECT 'text3'::varchar(1000) as my_column ) SELECT string_agg(my_column,', ') as my_string_agg FROM my_data From 2ea3c264d2b04cba8cd788575f607aecf3b8067a Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sat, 18 Nov 2023 10:32:53 +0800 Subject: [PATCH 13/14] Update datafusion/sqllogictest/test_files/aggregate.slt Co-authored-by: Andrew Lamb --- datafusion/sqllogictest/test_files/aggregate.slt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 80adc7fb506c..a40286577211 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3055,8 +3055,8 @@ text1, text1, text1 query T WITH my_data as ( SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all -SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all -SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +SELECT 1 as dummy, 'text2'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text3'::varchar(1000) as my_column ) SELECT string_agg(my_column,', ') as my_string_agg FROM my_data From bb5c21956cb556abebaa1c65d47ded28f96577ed Mon Sep 17 00:00:00 2001 From: hhj Date: Sat, 18 Nov 2023 10:53:24 +0800 Subject: [PATCH 14/14] fix ci --- datafusion/sqllogictest/test_files/aggregate.slt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a40286577211..0a495dd2b0c9 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3044,8 +3044,8 @@ drop table strings query T WITH my_data as ( SELECT 'text1'::varchar(1000) as my_column union all -SELECT 'text2'::varchar(1000) as my_column union all -SELECT 'text3'::varchar(1000) as my_column +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column ) SELECT string_agg(my_column,', ') as my_string_agg FROM my_data @@ -3055,8 +3055,8 @@ text1, text1, text1 query T WITH my_data as ( SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all -SELECT 1 as dummy, 'text2'::varchar(1000) as my_column union all -SELECT 1 as dummy, 'text3'::varchar(1000) as my_column +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column ) SELECT string_agg(my_column,', ') as my_string_agg FROM my_data