diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index ea0b01825170..4611c7fb10d7 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -100,6 +100,8 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, + /// string_agg + StringAgg, } impl AggregateFunction { @@ -141,6 +143,7 @@ impl AggregateFunction { BitXor => "BIT_XOR", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", + StringAgg => "STRING_AGG", } } } @@ -171,6 +174,7 @@ impl FromStr for AggregateFunction { "array_agg" => AggregateFunction::ArrayAgg, "first_value" => AggregateFunction::FirstValue, "last_value" => AggregateFunction::LastValue, + "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, "covar" => AggregateFunction::Covariance, @@ -299,6 +303,7 @@ impl AggregateFunction { AggregateFunction::FirstValue | AggregateFunction::LastValue => { Ok(coerced_data_types[0].clone()) } + AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } } @@ -408,6 +413,9 @@ impl AggregateFunction { .collect(), Volatility::Immutable, ), + AggregateFunction::StringAgg => { + Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) + } } } } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 261c406d5d5e..7128b575978a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -298,6 +298,23 @@ pub fn coerce_types( | AggregateFunction::FirstValue | AggregateFunction::LastValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), + AggregateFunction::StringAgg => { + if !is_string_agg_supported_arg_type(&input_types[0]) { + return plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[0] + ); + } + if !is_string_agg_supported_arg_type(&input_types[1]) { + return 
plan_err!( + "The function {:?} does not support inputs of type {:?}", + agg_fun, + input_types[1] + ); + } + Ok(vec![LargeUtf8, input_types[1].clone()]) + } } } @@ -565,6 +582,15 @@ pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool ) } +/// Return `true` if `arg_type` is of a [`DataType`] that the +/// [`AggregateFunction::StringAgg`] aggregation can operate on. +pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 596197b4eebe..c40f0db19405 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -369,6 +369,22 @@ pub fn create_aggregate_expr( ordering_req.to_vec(), ordering_types, )), + (AggregateFunction::StringAgg, false) => { + if !ordering_req.is_empty() { + return not_impl_err!( + "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" + ); + } + Arc::new(expressions::StringAgg::new( + input_phy_exprs[0].clone(), + input_phy_exprs[1].clone(), + name, + data_type, + )) + } + (AggregateFunction::StringAgg, true) => { + return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); + } }) } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 442d018b87d5..329bb1e6415e 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -43,6 +43,7 @@ pub(crate) mod covariance; pub(crate) mod first_last; pub(crate) mod grouping; pub(crate) mod median; +pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub mod build_in; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs new file 
mode 100644 index 000000000000..74c083959ed8 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/string_agg.rs @@ -0,0 +1,249 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function + +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::{format_state_name, Literal}; +use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::Accumulator; +use std::any::Any; +use std::sync::Arc; + +/// STRING_AGG aggregate expression +#[derive(Debug)] +pub struct StringAgg { + name: String, + data_type: DataType, + expr: Arc<dyn PhysicalExpr>, + delimiter: Arc<dyn PhysicalExpr>, + nullable: bool, +} + +impl StringAgg { + /// Create a new StringAgg aggregate function + pub fn new( + expr: Arc<dyn PhysicalExpr>, + delimiter: Arc<dyn PhysicalExpr>, + name: impl Into<String>, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + data_type, + delimiter, + expr, + nullable: true, + } + } +} + +impl AggregateExpr for StringAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn 
field(&self) -> Result<Field> { + Ok(Field::new( + &self.name, + self.data_type.clone(), + self.nullable, + )) + } + + fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + if let Some(delimiter) = self.delimiter.as_any().downcast_ref::<Literal>() { + match delimiter.value() { + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) => { + return Ok(Box::new(StringAggAccumulator::new(delimiter))); + } + ScalarValue::Null => { + return Ok(Box::new(StringAggAccumulator::new(""))); + } + _ => return not_impl_err!("StringAgg not supported for {}", self.name), + } + } + not_impl_err!("StringAgg not supported for {}", self.name) + } + + fn state_fields(&self) -> Result<Vec<Field>> { + Ok(vec![Field::new( + format_state_name(&self.name, "string_agg"), + self.data_type.clone(), + self.nullable, + )]) + } + + fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { + vec![self.expr.clone(), self.delimiter.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq<dyn Any> for StringAgg { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::<Self>() + .map(|x| { + self.name == x.name + && self.data_type == x.data_type + && self.expr.eq(&x.expr) + && self.delimiter.eq(&x.delimiter) + }) + .unwrap_or(false) + } +} + +/// Accumulates the non-null input strings into one delimiter-separated string. +#[derive(Debug)] +pub(crate) struct StringAggAccumulator { + values: Option<String>, + delimiter: String, +} + +impl StringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + values: None, + delimiter: delimiter.to_string(), + } + } +} + +impl Accumulator for StringAggAccumulator { + /// Appends the non-null strings of `values[0]`, separated by the delimiter. + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])? 
+ .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); + if !string_array.is_empty() { + let s = string_array.join(self.delimiter.as_str()); + // `Some` (even an empty string) means earlier values exist. + if let Some(v) = self.values.as_mut() { + v.push_str(self.delimiter.as_str()); + } + self.values.get_or_insert_with(String::new).push_str(&s); + } + Ok(()) + } + + /// Partial states are string arrays (see `state`), so merging reuses `update_batch`. + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values)?; + Ok(()) + } + + fn state(&self) -> Result<Vec<ScalarValue>> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&self) -> Result<ScalarValue> { + Ok(ScalarValue::LargeUtf8(self.values.clone())) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + + self.delimiter.capacity() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::tests::aggregate; + use crate::expressions::{col, create_aggregate_expr, try_cast}; + use arrow::array::ArrayRef; + use arrow::datatypes::*; + use arrow::record_batch::RecordBatch; + use arrow_array::LargeStringArray; + use arrow_array::StringArray; + use datafusion_expr::type_coercion::aggregates::coerce_types; + use datafusion_expr::AggregateFunction; + + fn assert_string_aggregate( + array: ArrayRef, + function: AggregateFunction, + distinct: bool, + expected: ScalarValue, + delimiter: String, + ) { + let data_type = array.data_type(); + let sig = function.signature(); + let coerced = + coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap(); + + let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + let batch = + RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); + + let input = try_cast( + col("a", &input_schema).unwrap(), + &input_schema, + coerced[0].clone(), + ) + .unwrap(); + + let delimiter = Arc::new(Literal::new(ScalarValue::Utf8(Some(delimiter)))); + let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); + let agg = create_aggregate_expr( + &function, 
distinct, + &[input, delimiter], + &[], + &schema, + "agg", + ) + .unwrap(); + + let result = aggregate(&batch, agg).unwrap(); + assert_eq!(expected, result); + } + + #[test] + fn string_agg_utf8() { + let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())), + ",".to_owned(), + ); + } + + #[test] + fn string_agg_largeutf8() { + let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); + assert_string_aggregate( + a, + AggregateFunction::StringAgg, + false, + ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())), + "|".to_owned(), + ); + } +} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 1919cac97986..b6d0ad5b9104 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -63,6 +63,7 @@ pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::stddev::{Stddev, StddevPop}; +pub use crate::aggregate::string_agg::StringAgg; pub use crate::aggregate::sum::Sum; pub use crate::aggregate::sum_distinct::DistinctSum; pub use crate::aggregate::variance::{Variance, VariancePop}; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index fa080518d50c..c2e02e6e281f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -683,6 +683,7 @@ enum AggregateFunction { REGR_SXX = 32; REGR_SYY = 33; REGR_SXY = 34; + STRING_AGG = 35; } message AggregateExprNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 08e7413102e8..661f4b0a2e5d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ 
b/datafusion/proto/src/generated/pbjson.rs @@ -474,6 +474,7 @@ impl serde::Serialize for AggregateFunction { Self::RegrSxx => "REGR_SXX", Self::RegrSyy => "REGR_SYY", Self::RegrSxy => "REGR_SXY", + Self::StringAgg => "STRING_AGG", }; serializer.serialize_str(variant) } @@ -520,6 +521,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX", "REGR_SYY", "REGR_SXY", + "STRING_AGG", ]; struct GeneratedVisitor; @@ -595,6 +597,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "REGR_SXX" => Ok(AggregateFunction::RegrSxx), "REGR_SYY" => Ok(AggregateFunction::RegrSyy), "REGR_SXY" => Ok(AggregateFunction::RegrSxy), + "STRING_AGG" => Ok(AggregateFunction::StringAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 15606488b33a..986b1710d165 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2872,6 +2872,7 @@ pub enum AggregateFunction { RegrSxx = 32, RegrSyy = 33, RegrSxy = 34, + StringAgg = 35, } impl AggregateFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2917,6 +2918,7 @@ impl AggregateFunction { AggregateFunction::RegrSxx => "REGR_SXX", AggregateFunction::RegrSyy => "REGR_SYY", AggregateFunction::RegrSxy => "REGR_SXY", + AggregateFunction::StringAgg => "STRING_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -2959,6 +2961,7 @@ impl AggregateFunction { "REGR_SXX" => Some(Self::RegrSxx), "REGR_SYY" => Some(Self::RegrSyy), "REGR_SXY" => Some(Self::RegrSxy), + "STRING_AGG" => Some(Self::StringAgg), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b3d68570038c..96141892b813 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -594,6 +594,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, + protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 144f28531041..4f239ad05a70 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -405,6 +405,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, + AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -721,6 +722,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { AggregateFunction::LastValue => { protobuf::AggregateFunction::LastValueAgg } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a1bb93ed53c4..0a495dd2b0c9 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2987,3 +2987,79 @@ NULL NULL 1 NULL 3 6 0 0 0 NULL NULL 1 NULL 5 15 0 0 0 3 0 2 1 5.5 16.5 0.5 4.5 1.5 3 0 3 1 6 18 2 18 6 + +statement 
error +SELECT STRING_AGG() + +statement error +SELECT STRING_AGG(1,2,3) + +statement error +SELECT STRING_AGG(STRING_AGG('a', ',')) + +query T +SELECT STRING_AGG('a', ',') +---- +a + +query TTTT +SELECT STRING_AGG('a',','), STRING_AGG('a', NULL), STRING_AGG(NULL, ','), STRING_AGG(NULL, NULL) +---- +a a NULL NULL + +query TT +select string_agg('', '|'), string_agg('a', ''); +---- +(empty) a + +query T +SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); +---- +| + +statement ok +CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) + +query ITT +INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'), (2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+') +---- +9 + +query IT +SELECT g, STRING_AGG(x,'|') FROM strings GROUP BY g ORDER BY g +---- +1 a|b +2 i|j +3 p +4 x|y|z + +query T +SELECT STRING_AGG(x,',') FROM strings WHERE g > 100 +---- +NULL + +statement ok +drop table strings + +query T +WITH my_data as ( +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column union all +SELECT 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +---- +text1, text1, text1 + +query T +WITH my_data as ( +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column union all +SELECT 1 as dummy, 'text1'::varchar(1000) as my_column +) +SELECT string_agg(my_column,', ') as my_string_agg +FROM my_data +GROUP BY dummy +---- +text1, text1, text1