From b94f70fdcfb41b5661879df489298c6072335a51 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 16 Feb 2024 22:37:21 +0800 Subject: [PATCH 01/18] first draft Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 14 +- .../user_defined/user_defined_aggregates.rs | 128 ++++++++++++++++- datafusion/expr/src/expr_fn.rs | 130 +++++++++++++++++- datafusion/expr/src/function.rs | 12 +- .../physical-expr/src/aggregate/first_last.rs | 2 +- .../physical-expr/src/expressions/mod.rs | 2 +- datafusion/physical-plan/src/udaf.rs | 9 +- datafusion/physical-plan/src/windows/mod.rs | 11 +- datafusion/proto/src/physical_plan/mod.rs | 4 +- datafusion/sql/src/expr/function.rs | 5 +- 10 files changed, 294 insertions(+), 23 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d348e28ededa..c32443ce0098 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -246,7 +246,6 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { args, filter, order_by, - null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { create_function_physical_name(func_def.name(), *distinct, args) @@ -258,11 +257,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { "aggregate expression with filter is not supported" ); } - if order_by.is_some() { - return exec_err!( - "aggregate expression with order_by is not supported" - ); - } + let names = args .iter() .map(|e| create_physical_name(e, false)) @@ -1709,13 +1704,16 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( (agg_expr, filter, order_by) } AggregateFunctionDefinition::UDF(fun) => { + let ordering_reqs: Vec = + order_by.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, &args, + &ordering_reqs, physical_input_schema, name, - ); - (agg_expr?, filter, order_by) + )?; + (agg_expr, filter, order_by) } AggregateFunctionDefinition::Name(_) => { return internal_err!( diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 9e231d25f298..a36aec698438 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,7 +20,7 @@ use arrow::{array::AsArray, datatypes::Fields}; use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use arrow_schema::{Schema, SortOptions}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -42,11 +42,15 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; +use datafusion_common::{ + assert_contains, cast::as_primitive_array, exec_err, Column, DataFusionError, +}; use datafusion_expr::{ - create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, + create_udaf, create_udaf_with_ordering, expr::Sort, AggregateUDFImpl, Expr, + GroupsAccumulator, SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::AvgAccumulator; +use datafusion_physical_expr::expressions::{self, FirstValueAccumulator}; +use datafusion_physical_expr::{expressions::AvgAccumulator, PhysicalSortExpr}; /// Test to show the contents of the setup #[tokio::test] @@ -208,6 +212,122 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } +#[tokio::test] +async fn simple_udaf_order() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(Int32Array::from(vec![1, 1, 2, 2])), + ], + )?; + + let ctx = SessionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema.clone()), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + // let expected_result = ctx + // .sql("SELECT FIRST_VALUE(a order by a desc) FROM t group by b order by b") + // .await? + // .collect() + // .await?; + + fn create_accumulator( + data_type: &DataType, + order_by: Vec>, + schema: &Schema, + ) -> Result> { + let mut all_sort_orders = vec![]; + + assert_eq!(order_by.len(), 1); + + for exprs in order_by { + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in exprs { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::col(name, schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); + } + } + + let ordering_req = all_sort_orders; + + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let acc = FirstValueAccumulator::try_new( + data_type, + ordering_types.as_slice(), + ordering_req, + )?; + // let acc = FirstValueAccumulator::try_new(data_type, &[], vec![])?; + Ok(Box::new(acc)) + } + + let order_by = Expr::Sort(Sort { + expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))), + asc: false, + nulls_first: false, + }); + let order_by = vec![vec![order_by]]; + + // define a udaf, using a DataFusion's accumulator + let my_first = create_udaf_with_ordering( + "my_first", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + // Arc::new(|d| create_accumulator(d, None, &dfs, &p, &schema)), + Arc::new(|d, order_by, schema| create_accumulator(d, order_by, schema)), + Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), + order_by, + schema, + ); + + ctx.register_udaf(my_first); + + // Should be the same as `SELECT FIRST_VALUE(a order by a) FROM t group by b order by b` + let result = ctx + .sql("SELECT MY_FIRST(a order by a desc) FROM t group by b order by b") + .await? + .collect() + .await?; + + let expected = [ + "+---------------+", + "| my_first(t.a) |", + "+---------------+", + "| 2 |", + "| 4 |", + "+---------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + /// tests the creation, registration and usage of a UDAF #[tokio::test] async fn simple_udaf() -> Result<()> { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 99f44a73c1dd..201d4fca2814 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -21,7 +21,9 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Placeholder, ScalarFunction, TryCast, }; -use crate::function::PartitionEvaluatorFactory; +use crate::function::{ + AccumulatorFactoryFunctionWithOrdering, PartitionEvaluatorFactory, +}; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, @@ -29,7 +31,7 @@ use crate::{ ScalarUDF, Signature, Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::{Column, Result}; use std::any::Any; use std::fmt::Debug; @@ -1017,6 +1019,34 @@ pub fn create_udaf( )) } +// TODO: Merge with ordering +/// Creates a new UDAF with a specific signature, state type and return type. +/// The signature and state type must match the `Accumulator's implementation`. +pub fn create_udaf_with_ordering( + name: &str, + input_type: Vec, + return_type: Arc, + volatility: Volatility, + accumulator: AccumulatorFactoryFunctionWithOrdering, + state_type: Arc>, + ordering_req: Vec>, + schema: Schema, +) -> AggregateUDF { + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + + AggregateUDF::from(SimpleOrderedAggregateUDF::new( + name, + input_type, + return_type, + volatility, + accumulator, + state_type, + ordering_req, + schema, + )) +} + /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. pub struct SimpleAggregateUDF { @@ -1103,6 +1133,102 @@ impl AggregateUDFImpl for SimpleAggregateUDF { } } +/// Implements [`AggregateUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleOrderedAggregateUDF { + name: String, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunctionWithOrdering, + state_type: Vec, + ordering_req: Vec>, + schema: Schema, +} + +impl Debug for SimpleOrderedAggregateUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl SimpleOrderedAggregateUDF { + /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and + /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: Vec, + return_type: DataType, + volatility: Volatility, + accumulator: AccumulatorFactoryFunctionWithOrdering, + state_type: Vec, + ordering_req: Vec>, + schema: Schema, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_type, volatility); + Self { + name, + signature, + return_type, + accumulator, + state_type, + ordering_req, + schema, + } + } + + pub fn new_with_signature( + name: impl Into, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunctionWithOrdering, + state_type: Vec, + ordering_req: Vec>, + schema: Schema, + ) -> Self { + let name = name.into(); + Self { + name, + signature, + return_type, + accumulator, + state_type, + ordering_req, + schema, + } + } +} + +impl AggregateUDFImpl for SimpleOrderedAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn accumulator(&self, arg: &DataType) -> Result> { + (self.accumulator)(arg, self.ordering_req.clone(), &self.schema) + } + + fn state_type(&self, _return_type: &DataType) -> Result> { + Ok(self.state_type.clone()) + } +} + /// Creates a new UDWF with a specific signature, state type and return type. /// /// The signature and state type must match the [`PartitionEvaluator`]'s implementation`. diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 3e30a5574be0..1c00dcd665f5 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,9 +17,9 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature}; +use crate::{Accumulator, BuiltinScalarFunction, Expr, PartitionEvaluator, Signature}; use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use std::sync::Arc; @@ -45,6 +45,14 @@ pub type ReturnTypeFunction = pub type AccumulatorFactoryFunction = Arc Result> + Send + Sync>; +/// Factory that returns an accumulator for the given aggregate, given +/// its return datatype, the ordering of the input arguments and the schema that are needed for ordering. +pub type AccumulatorFactoryFunctionWithOrdering = Arc< + dyn Fn(&DataType, Vec>, &Schema) -> Result> + + Send + + Sync, +>; + /// Factory that creates a PartitionEvaluator for the given window /// function pub type PartitionEvaluatorFactory = diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 17dd3ef1206d..264f52f93a3a 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -210,7 +210,7 @@ impl PartialEq for FirstValue { } #[derive(Debug)] -struct FirstValueAccumulator { +pub struct FirstValueAccumulator { first: ScalarValue, // At the beginning, `is_set` is false, which means `first` is not seen yet. // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 26d649f57201..d900c00e9db0 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -54,7 +54,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{FirstValue, LastValue}; +pub use crate::aggregate::first_last::{FirstValue, FirstValueAccumulator, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index fd9279dfd552..f30f95924a82 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -30,7 +30,7 @@ use arrow::{ use super::{expressions::format_state_name, Accumulator, AggregateExpr}; use datafusion_common::{not_impl_err, Result}; pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; use std::sync::Arc; @@ -40,6 +40,7 @@ use std::sync::Arc; pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], + ordering_req: &[PhysicalSortExpr], input_schema: &Schema, name: impl Into, ) -> Result> { @@ -53,6 +54,7 @@ pub fn create_aggregate_expr( args: input_phy_exprs.to_vec(), data_type: fun.return_type(&input_exprs_types)?, name: name.into(), + ordering_req: ordering_req.to_vec(), })) } @@ -64,6 +66,7 @@ pub struct AggregateFunctionExpr { /// Output / return type of this aggregate data_type: DataType, name: String, + ordering_req: LexOrdering, } impl AggregateFunctionExpr { @@ -175,6 +178,10 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_groups_accumulator(&self) -> Result> { self.fun.create_groups_accumulator() } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } } impl PartialEq for AggregateFunctionExpr { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 54731f0d812b..4ddcfd38f1db 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -92,8 +92,15 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = - udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; + // TODO: Ordering not supported for Window UDFs + let ordering_req = &[]; + let aggregate = udaf::create_aggregate_expr( + fun.as_ref(), + args, + ordering_req, + input_schema, + name, + )?; window_expr_from_aggregate_expr( partition_by, order_by, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index a4c08d76867d..f6ff2d3efa40 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -473,7 +473,9 @@ impl AsExecutionPlan for PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name) + // TODO: Ordering not supported for UDAF here + let ordering_req = &[]; + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, ordering_req, &physical_schema, name) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index bcf641e4b5a0..bb0a71a7ab5b 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -176,9 +176,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { + let order_by = + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; + let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fm, args, false, None, None, + fm, args, false, None, order_by, ))); } From c743d13febd81654ccde0fb8a7ec42ac9e93d45a Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 18 Feb 2024 16:35:27 +0800 Subject: [PATCH 02/18] clippy fix Signed-off-by: jayzhan211 --- .../core/tests/user_defined/user_defined_aggregates.rs | 9 +-------- datafusion/expr/src/expr_fn.rs | 3 ++- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 1 + 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index a36aec698438..10f1494799a2 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -232,12 +232,6 @@ async fn simple_udaf_order() -> Result<()> { let provider = MemTable::try_new(Arc::new(schema.clone()), vec![vec![batch]])?; ctx.register_table("t", Arc::new(provider))?; - // let expected_result = ctx - // .sql("SELECT FIRST_VALUE(a order by a desc) FROM t group by b order by b") - // .await? - // .collect() - // .await?; - fn create_accumulator( data_type: &DataType, order_by: Vec>, @@ -299,8 +293,7 @@ async fn simple_udaf_order() -> Result<()> { vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, - // Arc::new(|d| create_accumulator(d, None, &dfs, &p, &schema)), - Arc::new(|d, order_by, schema| create_accumulator(d, order_by, schema)), + Arc::new(create_accumulator), Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), order_by, schema, diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 201d4fca2814..4a07bf7bfb43 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1019,9 +1019,9 @@ pub fn create_udaf( )) } -// TODO: Merge with ordering /// Creates a new UDAF with a specific signature, state type and return type. /// The signature and state type must match the `Accumulator's implementation`. +#[allow(clippy::too_many_arguments)] pub fn create_udaf_with_ordering( name: &str, input_type: Vec, @@ -1158,6 +1158,7 @@ impl Debug for SimpleOrderedAggregateUDF { impl SimpleOrderedAggregateUDF { /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility + #[allow(clippy::too_many_arguments)] pub fn new( name: impl Into, input_type: Vec, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7df22e01469b..243e3777db9d 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -425,6 +425,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let aggregates: Vec> = vec![udaf::create_aggregate_expr( &udaf, &[col("b", &schema)?], + &[], &schema, "example_agg", )?]; From 3a7e9652af4ac039b4414963525bd87aff1cb90d Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 18 Feb 2024 17:04:25 +0800 Subject: [PATCH 03/18] cleanup Signed-off-by: jayzhan211 --- .../user_defined/user_defined_aggregates.rs | 1 - datafusion/expr/src/expr_fn.rs | 21 ------------------- datafusion/physical-plan/src/windows/mod.rs | 2 +- datafusion/proto/src/physical_plan/mod.rs | 1 - 4 files changed, 1 insertion(+), 24 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 10f1494799a2..711f27443712 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -276,7 +276,6 @@ async fn simple_udaf_order() -> Result<()> { ordering_types.as_slice(), ordering_req, )?; - // let acc = FirstValueAccumulator::try_new(data_type, &[], vec![])?; Ok(Box::new(acc)) } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 4a07bf7bfb43..dae7269fa285 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1181,27 +1181,6 @@ impl SimpleOrderedAggregateUDF { schema, } } - - pub fn new_with_signature( - name: impl Into, - signature: Signature, - return_type: DataType, - accumulator: AccumulatorFactoryFunctionWithOrdering, - state_type: Vec, - ordering_req: Vec>, - schema: Schema, - ) -> Self { - let name = name.into(); - Self { - name, - signature, - return_type, - accumulator, - state_type, - ordering_req, - schema, - } - } } impl AggregateUDFImpl for SimpleOrderedAggregateUDF { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 4ddcfd38f1db..5143ccc49944 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -92,7 +92,7 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - // TODO: Ordering not supported for Window UDFs + // TODO: Ordering not supported for Window UDFs yet let ordering_req = &[]; let aggregate = udaf::create_aggregate_expr( fun.as_ref(), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index f6ff2d3efa40..43fb7e9e87d0 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -473,7 +473,6 @@ impl AsExecutionPlan for PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; - // TODO: Ordering not supported for UDAF here let ordering_req = &[]; udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, ordering_req, &physical_schema, name) } From 4917f56d33d8b6ff02c9797cacd99e17c3e6419b Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 21:55:11 +0800 Subject: [PATCH 04/18] use one vector for ordering req Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 2 +- .../user_defined/user_defined_aggregates.rs | 49 +++++++++---------- datafusion/expr/src/expr_fn.rs | 8 +-- datafusion/expr/src/function.rs | 2 +- datafusion/expr/src/udaf.rs | 2 +- 5 files changed, 31 insertions(+), 32 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index c32443ce0098..528f746ca9e6 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -251,7 +251,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { create_function_physical_name(func_def.name(), *distinct, args) } AggregateFunctionDefinition::UDF(fun) => { - // TODO: Add support for filter and order by in AggregateUDF + // TODO: Add support for filter by in AggregateUDF if filter.is_some() { return exec_err!( "aggregate expression with filter is not supported" diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 711f27443712..5f20a8bc0fc9 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -234,41 +234,40 @@ async fn simple_udaf_order() -> Result<()> { fn create_accumulator( data_type: &DataType, - order_by: Vec>, - schema: &Schema, + order_by: Vec, + schema: Option, ) -> Result> { + // test with ordering so schema is required + let schema = schema.unwrap(); + let mut all_sort_orders = vec![]; - assert_eq!(order_by.len(), 1); - - for exprs in order_by { - // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in exprs { - if let Expr::Sort(sort) = expr { - if let Expr::Column(col) = sort.expr.as_ref() { - let name = &col.name; - let e = expressions::col(name, schema)?; - sort_exprs.push(PhysicalSortExpr { - expr: e, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in order_by { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::col(name, &schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); } } - if !sort_exprs.is_empty() { - all_sort_orders.extend(sort_exprs); - } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); } let ordering_req = all_sort_orders; let ordering_types = ordering_req .iter() - .map(|e| e.expr.data_type(schema)) + .map(|e| e.expr.data_type(&schema)) .collect::>>()?; let acc = FirstValueAccumulator::try_new( @@ -284,7 +283,7 @@ async fn simple_udaf_order() -> Result<()> { asc: false, nulls_first: false, }); - let order_by = vec![vec![order_by]]; + let order_by = vec![order_by]; // define a udaf, using a DataFusion's accumulator let my_first = create_udaf_with_ordering( diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index dae7269fa285..fb8dd9cdf234 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1029,7 +1029,7 @@ pub fn create_udaf_with_ordering( volatility: Volatility, accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Arc>, - ordering_req: Vec>, + ordering_req: Vec, schema: Schema, ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -1141,7 +1141,7 @@ pub struct SimpleOrderedAggregateUDF { return_type: DataType, accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Vec, - ordering_req: Vec>, + ordering_req: Vec, schema: Schema, } @@ -1166,7 +1166,7 @@ impl SimpleOrderedAggregateUDF { volatility: Volatility, accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Vec, - ordering_req: Vec>, + ordering_req: Vec, schema: Schema, ) -> Self { let name = name.into(); @@ -1201,7 +1201,7 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { } fn accumulator(&self, arg: &DataType) -> Result> { - (self.accumulator)(arg, self.ordering_req.clone(), &self.schema) + (self.accumulator)(arg, self.ordering_req.clone(), Some(self.schema.clone())) } fn state_type(&self, _return_type: &DataType) -> Result> { diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 1c00dcd665f5..f10a82ba6f12 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -48,7 +48,7 @@ pub type AccumulatorFactoryFunction = /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the ordering of the input arguments and the schema that are needed for ordering. pub type AccumulatorFactoryFunctionWithOrdering = Arc< - dyn Fn(&DataType, Vec>, &Schema) -> Result> + dyn Fn(&DataType, Vec, Option) -> Result> + Send + Sync, >; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e56723063e41..9b883359b5e8 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -264,7 +264,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// If the aggregate expression has a specialized /// [`GroupsAccumulator`] implementation. If this returns true, - /// `[Self::create_groups_accumulator`] will be called. + /// `[Self::create_groups_accumulator]` will be called. fn groups_accumulator_supported(&self) -> bool { false } From c9e8641bb5b0136ea28ecf024a9bd24e7bd2eee9 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 22:27:59 +0800 Subject: [PATCH 05/18] add sort exprs to accumulator Signed-off-by: jayzhan211 --- .../user_defined/user_defined_aggregates.rs | 22 ++++++------ datafusion/expr/src/expr_fn.rs | 30 ++++++++++++---- datafusion/expr/src/udaf.rs | 35 ++++++++++++++++--- datafusion/physical-plan/src/udaf.rs | 5 +-- 4 files changed, 67 insertions(+), 25 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 5f20a8bc0fc9..0f7416c0a419 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -278,13 +278,6 @@ async fn simple_udaf_order() -> Result<()> { Ok(Box::new(acc)) } - let order_by = Expr::Sort(Sort { - expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))), - asc: false, - nulls_first: false, - }); - let order_by = vec![order_by]; - // define a udaf, using a DataFusion's accumulator let my_first = create_udaf_with_ordering( "my_first", @@ -293,8 +286,12 @@ async fn simple_udaf_order() -> Result<()> { Volatility::Immutable, Arc::new(create_accumulator), Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), - order_by, - schema, + vec![Expr::Sort(Sort { + expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))), + asc: false, + nulls_first: false, + })], + Some(schema), ); ctx.register_udaf(my_first); @@ -791,7 +788,12 @@ impl AggregateUDFImpl for TestGroupsAccumulator { Ok(DataType::UInt64) } - fn accumulator(&self, _arg: &DataType) -> Result> { + fn accumulator( + &self, + _arg: &DataType, + _sort_exprs: Vec, + _schmea: Option, + ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index fb8dd9cdf234..debff818b35f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1030,7 +1030,7 @@ pub fn create_udaf_with_ordering( accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Arc>, ordering_req: Vec, - schema: Schema, + schema: Option, ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -1124,7 +1124,12 @@ impl AggregateUDFImpl for SimpleAggregateUDF { Ok(self.return_type.clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { + fn accumulator( + &self, + arg: &DataType, + sort_exprs: Vec, + schema: Option, + ) -> Result> { (self.accumulator)(arg) } @@ -1142,7 +1147,7 @@ pub struct SimpleOrderedAggregateUDF { accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Vec, ordering_req: Vec, - schema: Schema, + schema: Option, } impl Debug for SimpleOrderedAggregateUDF { @@ -1167,7 +1172,7 @@ impl SimpleOrderedAggregateUDF { accumulator: AccumulatorFactoryFunctionWithOrdering, state_type: Vec, ordering_req: Vec, - schema: Schema, + schema: Option, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -1200,13 +1205,26 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { Ok(self.return_type.clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { - (self.accumulator)(arg, self.ordering_req.clone(), Some(self.schema.clone())) + fn accumulator( + &self, + arg: &DataType, + sort_exprs: Vec, + schema: Option, + ) -> Result> { + (self.accumulator)(arg, sort_exprs, schema) } fn state_type(&self, _return_type: &DataType) -> Result> { Ok(self.state_type.clone()) } + + fn sort_exprs(&self) -> Vec { + self.ordering_req.clone() + } + + fn schema(&self) -> Option { + self.schema.clone() + } } /// Creates a new UDWF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 9b883359b5e8..8707c803001b 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -22,8 +22,8 @@ use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; -use arrow::datatypes::DataType; -use datafusion_common::{not_impl_err, Result}; +use arrow::datatypes::{DataType, Schema}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -155,8 +155,11 @@ impl AggregateUDF { /// Return an accumulator the given aggregate, given /// its return datatype. + // pub fn accumulator(&self, return_type: &DataType, sort_exprs: Vec, schema: Option) -> Result> { pub fn accumulator(&self, return_type: &DataType) -> Result> { - self.inner.accumulator(return_type) + let sort_exprs = self.inner.sort_exprs(); + let schema = self.inner.schema(); + self.inner.accumulator(return_type, sort_exprs, schema) } /// Return the type of the intermediate state used by this aggregator, given @@ -174,6 +177,10 @@ impl AggregateUDF { pub fn create_groups_accumulator(&self) -> Result> { self.inner.create_groups_accumulator() } + + pub fn sort_exprs() -> Vec { + vec![] + } } impl From for AggregateUDF @@ -256,7 +263,12 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return a new [`Accumulator`] that aggregates values for a specific /// group during query execution. - fn accumulator(&self, arg: &DataType) -> Result>; + fn accumulator( + &self, + arg: &DataType, + sort_exprs: Vec, + schema: Option, + ) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. /// See [`Accumulator::state()`] for more details @@ -277,6 +289,14 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn create_groups_accumulator(&self) -> Result> { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } + + fn sort_exprs(&self) -> Vec { + vec![] + } + + fn schema(&self) -> Option { + None + } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers @@ -323,7 +343,12 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { Ok(res.as_ref().clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { + fn accumulator( + &self, + arg: &DataType, + _sort_exprs: Vec, + _schema: Option, + ) -> Result> { (self.accumulator)(arg) } diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index f30f95924a82..3049f6cf642a 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -22,10 +22,7 @@ use fmt::Debug; use std::any::Any; use std::fmt; -use arrow::{ - datatypes::Field, - datatypes::{DataType, Schema}, -}; +use arrow::datatypes::{DataType, Field, Schema}; use super::{expressions::format_state_name, Accumulator, AggregateExpr}; use datafusion_common::{not_impl_err, Result}; From 3a5f0d1fe7bc6b32a5eb27285fe16a8aedfed733 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 22:52:04 +0800 Subject: [PATCH 06/18] clippy Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 10 ++++++++-- datafusion-examples/examples/simple_udaf.rs | 2 +- .../tests/user_defined/user_defined_aggregates.rs | 8 ++++---- .../user_defined/user_defined_scalar_functions.rs | 2 +- datafusion/expr/src/expr_fn.rs | 12 +++++------- datafusion/expr/src/function.rs | 9 ++------- datafusion/expr/src/udaf.rs | 6 +++--- datafusion/optimizer/src/analyzer/type_coercion.rs | 4 ++-- datafusion/optimizer/src/common_subexpr_eliminate.rs | 3 ++- datafusion/proto/src/bytes/mod.rs | 2 +- .../proto/tests/cases/roundtrip_logical_plan.rs | 4 ++-- .../proto/tests/cases/roundtrip_physical_plan.rs | 3 ++- .../substrait/tests/cases/roundtrip_logical_plan.rs | 2 +- 13 files changed, 34 insertions(+), 33 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 10164a850bfb..9fed6efe70f1 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow_schema::Schema; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; use datafusion_physical_expr::NullState; use std::{any::Any, sync::Arc}; @@ -85,7 +86,12 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// is supported, DataFusion will use this row oriented /// accumulator when the aggregate function is used as a window function /// or when there are only aggregates (no GROUP BY columns) in the plan. - fn accumulator(&self, _arg: &DataType) -> Result> { + fn accumulator( + &self, + _arg: &DataType, + _sort_exprs: Vec, + _schema: Option, + ) -> Result> { Ok(Box::new(GeometricMean::new())) } @@ -191,7 +197,7 @@ impl Accumulator for GeometricMean { // create local session context with an in-memory table fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; + use datafusion::arrow::datatypes::Field; use datafusion::datasource::MemTable; // define a schema. let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 0996a67245a8..3a62e28b0568 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -150,7 +150,7 @@ async fn main() -> Result<()> { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(GeometricMean::new()))), + Arc::new(|_, _, _| Ok(Box::new(GeometricMean::new()))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 0f7416c0a419..720eef42831c 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -341,7 +341,7 @@ async fn simple_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -397,7 +397,7 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -568,7 +568,7 @@ impl TimeSum { let captured_state = Arc::clone(&test_state); let accumulator: AccumulatorFactoryFunction = - Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); + Arc::new(move |_, _, _| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); let time_sum = AggregateUDF::from(SimpleAggregateUDF::new( name, @@ -667,7 +667,7 @@ impl FirstSelector { let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; let accumulator: AccumulatorFactoryFunction = - Arc::new(|_| Ok(Box::new(Self::new()))); + Arc::new(|_, _, _| Ok(Box::new(Self::new()))); let volatility = Volatility::Immutable; diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index d9b60134b3d9..4229220e697f 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -294,7 +294,7 @@ async fn udaf_as_window_func() -> Result<()> { vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, - Arc::new(|_| Ok(Box::new(MyAccumulator))), + Arc::new(|_, _, _| Ok(Box::new(MyAccumulator))), Arc::new(vec![DataType::Int32]), ); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index debff818b35f..c77c1aa48375 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -21,9 +21,7 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Placeholder, ScalarFunction, TryCast, }; -use crate::function::{ - AccumulatorFactoryFunctionWithOrdering, PartitionEvaluatorFactory, -}; +use crate::function::PartitionEvaluatorFactory; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, @@ -1027,7 +1025,7 @@ pub fn create_udaf_with_ordering( input_type: Vec, return_type: Arc, volatility: Volatility, - accumulator: AccumulatorFactoryFunctionWithOrdering, + accumulator: AccumulatorFactoryFunction, state_type: Arc>, ordering_req: Vec, schema: Option, @@ -1130,7 +1128,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { sort_exprs: Vec, schema: Option, ) -> Result> { - (self.accumulator)(arg) + (self.accumulator)(arg, sort_exprs, schema) } fn state_type(&self, _return_type: &DataType) -> Result> { @@ -1144,7 +1142,7 @@ pub struct SimpleOrderedAggregateUDF { name: String, signature: Signature, return_type: DataType, - accumulator: AccumulatorFactoryFunctionWithOrdering, + accumulator: AccumulatorFactoryFunction, state_type: Vec, ordering_req: Vec, schema: Option, @@ -1169,7 +1167,7 @@ impl SimpleOrderedAggregateUDF { input_type: Vec, return_type: DataType, volatility: Volatility, - accumulator: AccumulatorFactoryFunctionWithOrdering, + accumulator: AccumulatorFactoryFunction, state_type: Vec, ordering_req: Vec, schema: Option, diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index f10a82ba6f12..c1e162286cb6 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -41,13 +41,8 @@ pub type ReturnTypeFunction = Arc Result> + Send + Sync>; /// Factory that returns an accumulator for the given aggregate, given -/// its return datatype. -pub type AccumulatorFactoryFunction = - Arc Result> + Send + Sync>; - -/// Factory that returns an accumulator for the given aggregate, given -/// its return datatype, the ordering of the input arguments and the schema that are needed for ordering. -pub type AccumulatorFactoryFunctionWithOrdering = Arc< +/// its return datatype, the sorting expressions and the schema for ordering. +pub type AccumulatorFactoryFunction = Arc< dyn Fn(&DataType, Vec, Option) -> Result> + Send + Sync, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 8707c803001b..534e45e13521 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -346,10 +346,10 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator( &self, arg: &DataType, - _sort_exprs: Vec, - _schema: Option, + sort_exprs: Vec, + schema: Option, ) -> Result> { - (self.accumulator)(arg) + (self.accumulator)(arg, sort_exprs, schema) } fn state_type(&self, return_type: &DataType) -> Result> { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 496def95e1bc..7281425a5721 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -895,7 +895,7 @@ mod test { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -916,7 +916,7 @@ mod test { let return_type = DataType::Float64; let state_type = vec![DataType::UInt64, DataType::Float64]; let accumulator: AccumulatorFactoryFunction = - Arc::new(|_| Ok(Box::::default())); + Arc::new(|_, _, _| Ok(Box::::default())); let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "MY_AVG", Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 30c184a28e33..487e06674e7a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -975,7 +975,8 @@ mod test { let table_scan = test_table_scan()?; let return_type = DataType::UInt32; - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_, _, _| unimplemented!()); let state_type = vec![DataType::UInt32]; let udf_agg = |inner: Expr| { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 610c533d574c..4c570d343574 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -127,7 +127,7 @@ impl Serializeable for Expr { vec![arrow::datatypes::DataType::Null], Arc::new(arrow::datatypes::DataType::Null), Volatility::Immutable, - Arc::new(|_| unimplemented!()), + Arc::new(|_, _, _| unimplemented!()), Arc::new(vec![]), ))) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index fad50d3ecddc..42e8718ff097 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1738,7 +1738,7 @@ fn roundtrip_aggregate_udf() { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(Dummy {}))), + Arc::new(|_, _, _| Ok(Box::new(Dummy {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); @@ -1953,7 +1953,7 @@ fn roundtrip_window() { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(DummyAggr {}))), + Arc::new(|_, _, _| Ok(Box::new(DummyAggr {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 243e3777db9d..c7ff0cf782d0 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -405,7 +405,8 @@ fn roundtrip_aggregate_udaf() -> Result<()> { } let return_type = DataType::Int64; - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_, _, _| Ok(Box::new(Example))); let state_type = vec![DataType::Int64]; let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index bc9cc66b7626..a24a84ad76be 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -750,7 +750,7 @@ async fn roundtrip_aggregate_udf() -> Result<()> { Arc::new(DataType::Int64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(Dummy {}))), + Arc::new(|_, _, _| Ok(Box::new(Dummy {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); From a3ea00adf0595d6862ee6ce76fefd4564a3ff3b2 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 22:55:08 +0800 Subject: [PATCH 07/18] cleanup Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 534e45e13521..1dc1cb12707c 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -153,9 +153,7 @@ impl AggregateUDF { self.inner.return_type(args) } - /// Return an accumulator the given aggregate, given - /// its return datatype. - // pub fn accumulator(&self, return_type: &DataType, sort_exprs: Vec, schema: Option) -> Result> { + /// Return an accumulator the given aggregate, given its return datatype pub fn accumulator(&self, return_type: &DataType) -> Result> { let sort_exprs = self.inner.sort_exprs(); let schema = self.inner.schema(); From f349f215a1e10b129c7a910378d91c32084168ab Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 21 Feb 2024 23:10:43 +0800 Subject: [PATCH 08/18] fix doc test Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 1dc1cb12707c..3409df62a0c6 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -205,8 +205,9 @@ where /// # use std::any::Any; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator}; +/// # use arrow::datatypes::Schema; /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature @@ -232,7 +233,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType) -> Result> { unimplemented!() } +/// fn accumulator(&self, _arg: &DataType, _sort_exprs: Vec, _schema: Option) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } From 6fcdaac8e484cef16df5d905c0178e17fa35c840 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 27 Feb 2024 22:46:24 +0800 Subject: [PATCH 09/18] change to ref Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 2 +- .../tests/user_defined/user_defined_aggregates.rs | 10 +++++----- datafusion/expr/src/expr_fn.rs | 14 +++++++------- datafusion/expr/src/function.rs | 2 +- datafusion/expr/src/udaf.rs | 11 +++++++---- 5 files changed, 21 insertions(+), 18 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 9fed6efe70f1..ef43dc47f2ce 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -90,7 +90,7 @@ impl AggregateUDFImpl for GeoMeanUdaf { &self, _arg: &DataType, _sort_exprs: Vec, - _schema: Option, + _schema: Option<&Schema>, ) -> Result> { Ok(Box::new(GeometricMean::new())) } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 720eef42831c..1f637c0d97a7 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -235,7 +235,7 @@ async fn simple_udaf_order() -> Result<()> { fn create_accumulator( data_type: &DataType, order_by: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result> { // test with ordering so schema is required let schema = schema.unwrap(); @@ -248,7 +248,7 @@ async fn simple_udaf_order() -> Result<()> { if let Expr::Sort(sort) = expr { if let Expr::Column(col) = sort.expr.as_ref() { let name = &col.name; - let e = expressions::col(name, &schema)?; + let e = expressions::col(name, schema)?; sort_exprs.push(PhysicalSortExpr { expr: e, options: SortOptions { @@ -267,7 +267,7 @@ async fn simple_udaf_order() -> Result<()> { let ordering_types = ordering_req .iter() - .map(|e| e.expr.data_type(&schema)) + .map(|e| e.expr.data_type(schema)) .collect::>>()?; let acc = FirstValueAccumulator::try_new( @@ -291,7 +291,7 @@ async fn simple_udaf_order() -> Result<()> { asc: false, nulls_first: false, })], - Some(schema), + Some(&schema), ); ctx.register_udaf(my_first); @@ -792,7 +792,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { &self, _arg: &DataType, _sort_exprs: Vec, - _schmea: Option, + _schmea: Option<&Schema>, ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c77c1aa48375..62856873c833 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1028,7 +1028,7 @@ pub fn create_udaf_with_ordering( accumulator: AccumulatorFactoryFunction, state_type: Arc>, ordering_req: Vec, - schema: Option, + schema: Option<&Schema>, ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -1126,7 +1126,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { &self, arg: &DataType, sort_exprs: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } @@ -1170,7 +1170,7 @@ impl SimpleOrderedAggregateUDF { accumulator: AccumulatorFactoryFunction, state_type: Vec, ordering_req: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -1181,7 +1181,7 @@ impl SimpleOrderedAggregateUDF { accumulator, state_type, ordering_req, - schema, + schema: schema.cloned(), } } } @@ -1207,7 +1207,7 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { &self, arg: &DataType, sort_exprs: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } @@ -1220,8 +1220,8 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { self.ordering_req.clone() } - fn schema(&self) -> Option { - self.schema.clone() + fn schema(&self) -> Option<&Schema> { + self.schema.as_ref() } } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index c1e162286cb6..bbb631af8fd5 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -43,7 +43,7 @@ pub type ReturnTypeFunction = /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the sorting expressions and the schema for ordering. pub type AccumulatorFactoryFunction = Arc< - dyn Fn(&DataType, Vec, Option) -> Result> + dyn Fn(&DataType, Vec, Option<&Schema>) -> Result> + Send + Sync, >; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 3409df62a0c6..0e4785b25673 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -261,12 +261,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn return_type(&self, arg_types: &[DataType]) -> Result; /// Return a new [`Accumulator`] that aggregates values for a specific - /// group during query execution. + /// group during query execution. sort_exprs is a list of ordering expressions, + /// and schema is used while ordering. fn accumulator( &self, arg: &DataType, sort_exprs: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. @@ -289,11 +290,13 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } + /// Return the ordering expressions for the accumulator fn sort_exprs(&self) -> Vec { vec![] } - fn schema(&self) -> Option { + /// Return the schema for the accumulator + fn schema(&self) -> Option<&Schema> { None } } @@ -346,7 +349,7 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { &self, arg: &DataType, sort_exprs: Vec, - schema: Option, + schema: Option<&Schema>, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } From c3512a6c150914f288f7dc19b4dea3274f9cc295 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 27 Feb 2024 22:47:09 +0800 Subject: [PATCH 10/18] fix typo Signed-off-by: jayzhan211 --- datafusion/core/tests/user_defined/user_defined_aggregates.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 1f637c0d97a7..41c19c3ebb7e 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -792,7 +792,7 @@ impl AggregateUDFImpl for TestGroupsAccumulator { &self, _arg: &DataType, _sort_exprs: Vec, - _schmea: Option<&Schema>, + _schema: Option<&Schema>, ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); From 092d46e50666a422748dbf4755fc8ef260ba3e7f Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 27 Feb 2024 23:03:35 +0800 Subject: [PATCH 11/18] fix doc Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 0e4785b25673..2a55904e01b7 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -233,7 +233,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType, _sort_exprs: Vec, _schema: Option) -> Result> { unimplemented!() } +/// fn accumulator(&self, _arg: &DataType, _sort_exprs: Vec, _schema: Option<&Schema>) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } From 8592e6bccc64c45a7c120b94c02b1665bd81b170 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 19:26:06 +0800 Subject: [PATCH 12/18] fmt Signed-off-by: jayzhan211 --- datafusion/expr/src/udaf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 2a55904e01b7..9b038b7c38a9 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -23,7 +23,7 @@ use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_common::{not_impl_err, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; From 0f8fc2427abc405feff8b2c8781a3ad1fb043307 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 20:34:23 +0800 Subject: [PATCH 13/18] move schema and logical ordering exprs Signed-off-by: jayzhan211 --- datafusion-examples/examples/advanced_udaf.rs | 4 +-- datafusion/core/src/physical_planner.rs | 16 +++++++----- .../user_defined/user_defined_aggregates.rs | 26 +++++++------------ datafusion/expr/src/expr_fn.rs | 8 +++--- datafusion/expr/src/function.rs | 4 +-- datafusion/expr/src/udaf.rs | 19 +++++++++----- datafusion/physical-plan/src/udaf.rs | 22 ++++++++++++---- datafusion/physical-plan/src/windows/mod.rs | 3 +++ datafusion/proto/src/physical_plan/mod.rs | 2 +- .../tests/cases/roundtrip_physical_plan.rs | 1 + 10 files changed, 60 insertions(+), 45 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index ef43dc47f2ce..06995863a245 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -89,8 +89,8 @@ impl AggregateUDFImpl for GeoMeanUdaf { fn accumulator( &self, _arg: &DataType, - _sort_exprs: Vec, - _schema: Option<&Schema>, + _sort_exprs: &[Expr], + _schema: &Schema, ) -> Result> { Ok(Box::new(GeometricMean::new())) } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 528f746ca9e6..4d1cf1b28d8f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1672,7 +1672,9 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; - let order_by = match order_by { + + let sort_exprs = order_by.clone().unwrap_or(vec![]); + let phy_order_by = match order_by { Some(e) => Some( e.iter() .map(|expr| { @@ -1691,7 +1693,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let ordering_reqs = phy_order_by.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, @@ -1701,19 +1703,21 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( name, ignore_nulls, )?; - (agg_expr, filter, order_by) + (agg_expr, filter, phy_order_by) } AggregateFunctionDefinition::UDF(fun) => { let ordering_reqs: Vec = - order_by.clone().unwrap_or(vec![]); + phy_order_by.clone().unwrap_or(vec![]); + let agg_expr = udaf::create_aggregate_expr( fun, &args, + &sort_exprs, &ordering_reqs, physical_input_schema, name, )?; - (agg_expr, filter, order_by) + (agg_expr, filter, phy_order_by) } AggregateFunctionDefinition::Name(_) => { return internal_err!( @@ -1721,7 +1725,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) } }; - Ok((agg_expr, filter, order_by)) + Ok((agg_expr, filter, phy_order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 41c19c3ebb7e..85441838c19d 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -42,9 +42,7 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::{ - assert_contains, cast::as_primitive_array, exec_err, Column, DataFusionError, -}; +use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err, Column}; use datafusion_expr::{ create_udaf, create_udaf_with_ordering, expr::Sort, AggregateUDFImpl, Expr, GroupsAccumulator, SimpleAggregateUDF, @@ -234,12 +232,9 @@ async fn simple_udaf_order() -> Result<()> { fn create_accumulator( data_type: &DataType, - order_by: Vec, - schema: Option<&Schema>, + order_by: &[Expr], + schema: &Schema, ) -> Result> { - // test with ordering so schema is required - let schema = schema.unwrap(); - let mut all_sort_orders = vec![]; // Construct PhysicalSortExpr objects from Expr objects: @@ -265,16 +260,13 @@ async fn simple_udaf_order() -> Result<()> { let ordering_req = all_sort_orders; - let ordering_types = ordering_req + let ordering_dtypes = ordering_req .iter() .map(|e| e.expr.data_type(schema)) .collect::>>()?; - let acc = FirstValueAccumulator::try_new( - data_type, - ordering_types.as_slice(), - ordering_req, - )?; + let acc = + FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req)?; Ok(Box::new(acc)) } @@ -369,7 +361,7 @@ async fn deregister_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -791,8 +783,8 @@ impl AggregateUDFImpl for TestGroupsAccumulator { fn accumulator( &self, _arg: &DataType, - _sort_exprs: Vec, - _schema: Option<&Schema>, + _sort_exprs: &[Expr], + _schema: &Schema, ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 62856873c833..3ba4808b9694 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1125,8 +1125,8 @@ impl AggregateUDFImpl for SimpleAggregateUDF { fn accumulator( &self, arg: &DataType, - sort_exprs: Vec, - schema: Option<&Schema>, + sort_exprs: &[Expr], + schema: &Schema, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } @@ -1206,8 +1206,8 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { fn accumulator( &self, arg: &DataType, - sort_exprs: Vec, - schema: Option<&Schema>, + sort_exprs: &[Expr], + schema: &Schema, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index bbb631af8fd5..7a92724227e9 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -43,9 +43,7 @@ pub type ReturnTypeFunction = /// Factory that returns an accumulator for the given aggregate, given /// its return datatype, the sorting expressions and the schema for ordering. pub type AccumulatorFactoryFunction = Arc< - dyn Fn(&DataType, Vec, Option<&Schema>) -> Result> - + Send - + Sync, + dyn Fn(&DataType, &[Expr], &Schema) -> Result> + Send + Sync, >; /// Factory that creates a PartitionEvaluator for the given window diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 9b038b7c38a9..121cd6834306 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -154,9 +154,14 @@ impl AggregateUDF { } /// Return an accumulator the given aggregate, given its return datatype - pub fn accumulator(&self, return_type: &DataType) -> Result> { - let sort_exprs = self.inner.sort_exprs(); - let schema = self.inner.schema(); + pub fn accumulator( + &self, + return_type: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + // let sort_exprs = self.inner.sort_exprs(); + // let schema = self.inner.schema(); self.inner.accumulator(return_type, sort_exprs, schema) } @@ -266,8 +271,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn accumulator( &self, arg: &DataType, - sort_exprs: Vec, - schema: Option<&Schema>, + sort_exprs: &[Expr], + schema: &Schema, ) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. @@ -348,8 +353,8 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { fn accumulator( &self, arg: &DataType, - sort_exprs: Vec, - schema: Option<&Schema>, + sort_exprs: &[Expr], + schema: &Schema, ) -> Result> { (self.accumulator)(arg, sort_exprs, schema) } diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index 3049f6cf642a..b74a2d971d36 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -17,7 +17,7 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use datafusion_expr::GroupsAccumulator; +use datafusion_expr::{Expr, GroupsAccumulator}; use fmt::Debug; use std::any::Any; use std::fmt; @@ -37,13 +37,14 @@ use std::sync::Arc; pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], + sort_exprs: &[Expr], ordering_req: &[PhysicalSortExpr], - input_schema: &Schema, + schema: &Schema, name: impl Into, ) -> Result> { let input_exprs_types = input_phy_exprs .iter() - .map(|arg| arg.data_type(input_schema)) + .map(|arg| arg.data_type(schema)) .collect::>>()?; Ok(Arc::new(AggregateFunctionExpr { @@ -51,6 +52,8 @@ pub fn create_aggregate_expr( args: input_phy_exprs.to_vec(), data_type: fun.return_type(&input_exprs_types)?, name: name.into(), + schema: schema.clone(), + sort_exprs: sort_exprs.to_vec(), ordering_req: ordering_req.to_vec(), })) } @@ -63,6 +66,10 @@ pub struct AggregateFunctionExpr { /// Output / return type of this aggregate data_type: DataType, name: String, + schema: Schema, + // The logical order by expressions + sort_exprs: Vec, + // The physical order by expressions ordering_req: LexOrdering, } @@ -106,11 +113,16 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - self.fun.accumulator(&self.data_type) + self.fun + .accumulator(&self.data_type, self.sort_exprs.as_slice(), &self.schema) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.fun.accumulator(&self.data_type)?; + let accumulator = self.fun.accumulator( + &self.data_type, + self.sort_exprs.as_slice(), + &self.schema, + )?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 5143ccc49944..3082bec9134d 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -93,10 +93,13 @@ pub fn create_window_expr( } WindowFunctionDefinition::AggregateUDF(fun) => { // TODO: Ordering not supported for Window UDFs yet + let sort_exprs = &[]; let ordering_req = &[]; + let aggregate = udaf::create_aggregate_expr( fun.as_ref(), args, + sort_exprs, ordering_req, input_schema, name, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 43fb7e9e87d0..b8da5ea7a092 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -474,7 +474,7 @@ impl AsExecutionPlan for PhysicalPlanNode { AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, ordering_req, &physical_schema, name) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &[], ordering_req, &physical_schema, name) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index c7ff0cf782d0..a51da2ead738 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -427,6 +427,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &udaf, &[col("b", &schema)?], &[], + &[], &schema, "example_agg", )?]; From 3185f9f77280c8a3c2fc2d65016b0a7fce045f63 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 20:40:59 +0800 Subject: [PATCH 14/18] remove redudant info Signed-off-by: jayzhan211 --- .../user_defined/user_defined_aggregates.rs | 12 +++--------- datafusion/expr/src/expr_fn.rs | 18 ------------------ datafusion/expr/src/udaf.rs | 10 ---------- 3 files changed, 3 insertions(+), 37 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 85441838c19d..e9583fff35c3 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -42,10 +42,10 @@ use datafusion::{ prelude::SessionContext, scalar::ScalarValue, }; -use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err, Column}; +use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, create_udaf_with_ordering, expr::Sort, AggregateUDFImpl, Expr, - GroupsAccumulator, SimpleAggregateUDF, + create_udaf, create_udaf_with_ordering, AggregateUDFImpl, Expr, GroupsAccumulator, + SimpleAggregateUDF, }; use datafusion_physical_expr::expressions::{self, FirstValueAccumulator}; use datafusion_physical_expr::{expressions::AvgAccumulator, PhysicalSortExpr}; @@ -278,12 +278,6 @@ async fn simple_udaf_order() -> Result<()> { Volatility::Immutable, Arc::new(create_accumulator), Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), - vec![Expr::Sort(Sort { - expr: Box::new(Expr::Column(Column::new(Some("t"), "a"))), - asc: false, - nulls_first: false, - })], - Some(&schema), ); ctx.register_udaf(my_first); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 3ba4808b9694..1a84a49d0949 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1027,8 +1027,6 @@ pub fn create_udaf_with_ordering( volatility: Volatility, accumulator: AccumulatorFactoryFunction, state_type: Arc>, - ordering_req: Vec, - schema: Option<&Schema>, ) -> AggregateUDF { let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); @@ -1040,8 +1038,6 @@ pub fn create_udaf_with_ordering( volatility, accumulator, state_type, - ordering_req, - schema, )) } @@ -1144,8 +1140,6 @@ pub struct SimpleOrderedAggregateUDF { return_type: DataType, accumulator: AccumulatorFactoryFunction, state_type: Vec, - ordering_req: Vec, - schema: Option, } impl Debug for SimpleOrderedAggregateUDF { @@ -1169,8 +1163,6 @@ impl SimpleOrderedAggregateUDF { volatility: Volatility, accumulator: AccumulatorFactoryFunction, state_type: Vec, - ordering_req: Vec, - schema: Option<&Schema>, ) -> Self { let name = name.into(); let signature = Signature::exact(input_type, volatility); @@ -1180,8 +1172,6 @@ impl SimpleOrderedAggregateUDF { return_type, accumulator, state_type, - ordering_req, - schema: schema.cloned(), } } } @@ -1215,14 +1205,6 @@ impl AggregateUDFImpl for SimpleOrderedAggregateUDF { fn state_type(&self, _return_type: &DataType) -> Result> { Ok(self.state_type.clone()) } - - fn sort_exprs(&self) -> Vec { - self.ordering_req.clone() - } - - fn schema(&self) -> Option<&Schema> { - self.schema.as_ref() - } } /// Creates a new UDWF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 121cd6834306..b5caf860163d 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -294,16 +294,6 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn create_groups_accumulator(&self) -> Result> { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } - - /// Return the ordering expressions for the accumulator - fn sort_exprs(&self) -> Vec { - vec![] - } - - /// Return the schema for the accumulator - fn schema(&self) -> Option<&Schema> { - None - } } /// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers From 3ecc772c2a99b185741868d562257a2e84d15dcb Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 20:44:24 +0800 Subject: [PATCH 15/18] rename Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4d1cf1b28d8f..e065a0d791e4 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -246,6 +246,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { args, filter, order_by, + null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { create_function_physical_name(func_def.name(), *distinct, args) @@ -1674,7 +1675,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( }; let sort_exprs = order_by.clone().unwrap_or(vec![]); - let phy_order_by = match order_by { + let order_by = match order_by { Some(e) => Some( e.iter() .map(|expr| { @@ -1693,7 +1694,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( == NullTreatment::IgnoreNulls; let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - let ordering_reqs = phy_order_by.clone().unwrap_or(vec![]); + let ordering_reqs = order_by.clone().unwrap_or(vec![]); let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, @@ -1703,11 +1704,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( name, ignore_nulls, )?; - (agg_expr, filter, phy_order_by) + (agg_expr, filter, order_by) } AggregateFunctionDefinition::UDF(fun) => { let ordering_reqs: Vec = - phy_order_by.clone().unwrap_or(vec![]); + order_by.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, @@ -1717,7 +1718,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( physical_input_schema, name, )?; - (agg_expr, filter, phy_order_by) + (agg_expr, filter, order_by) } AggregateFunctionDefinition::Name(_) => { return internal_err!( @@ -1725,7 +1726,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) } }; - Ok((agg_expr, filter, phy_order_by)) + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), } From faadc63ce4607090cdc7b3cfa8860bc18fc28ac1 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 1 Mar 2024 20:55:15 +0800 Subject: [PATCH 16/18] cleanup Signed-off-by: jayzhan211 --- datafusion/expr/src/expr_fn.rs | 2 -- datafusion/expr/src/udaf.rs | 19 ++++++++++--------- datafusion/physical-plan/src/udaf.rs | 8 +++----- datafusion/proto/src/physical_plan/mod.rs | 4 +++- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1a84a49d0949..4808c7197cfa 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -1019,7 +1019,6 @@ pub fn create_udaf( /// Creates a new UDAF with a specific signature, state type and return type. /// The signature and state type must match the `Accumulator's implementation`. -#[allow(clippy::too_many_arguments)] pub fn create_udaf_with_ordering( name: &str, input_type: Vec, @@ -1155,7 +1154,6 @@ impl Debug for SimpleOrderedAggregateUDF { impl SimpleOrderedAggregateUDF { /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility - #[allow(clippy::too_many_arguments)] pub fn new( name: impl Into, input_type: Vec, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b5caf860163d..8119a45e9f0c 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -160,8 +160,6 @@ impl AggregateUDF { sort_exprs: &[Expr], schema: &Schema, ) -> Result> { - // let sort_exprs = self.inner.sort_exprs(); - // let schema = self.inner.schema(); self.inner.accumulator(return_type, sort_exprs, schema) } @@ -180,10 +178,6 @@ impl AggregateUDF { pub fn create_groups_accumulator(&self) -> Result> { self.inner.create_groups_accumulator() } - - pub fn sort_exprs() -> Vec { - vec![] - } } impl From for AggregateUDF @@ -238,7 +232,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType, _sort_exprs: Vec, _schema: Option<&Schema>) -> Result> { unimplemented!() } +/// fn accumulator(&self, _arg: &DataType, _sort_exprs: &[Expr], _schema: &Schema) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } @@ -266,8 +260,15 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn return_type(&self, arg_types: &[DataType]) -> Result; /// Return a new [`Accumulator`] that aggregates values for a specific - /// group during query execution. sort_exprs is a list of ordering expressions, - /// and schema is used while ordering. + /// group during query execution. + /// + /// `arg`: the type of the argument to this accumulator + /// + /// `sort_exprs`: contains a list of `Expr::SortExpr`s if the + /// aggregate is called with an explicit `ORDER BY`. For example, + /// `ARRAY_AGG(x ORDER BY y ASC)`. In this case, `sort_exprs` would contain `[y ASC]` + /// + /// `schema` is the input schema to the udaf fn accumulator( &self, arg: &DataType, diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index b74a2d971d36..3d15e3d012c5 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -118,11 +118,9 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.fun.accumulator( - &self.data_type, - self.sort_exprs.as_slice(), - &self.schema, - )?; + let accumulator = + self.fun + .accumulator(&self.data_type, &self.sort_exprs, &self.schema)?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index b8da5ea7a092..4c098a10737d 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -473,8 +473,10 @@ impl AsExecutionPlan for PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; + // TODO: `order by` is not supported for UDAF yet + let sort_exprs = &[]; let ordering_req = &[]; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &[], ordering_req, &physical_schema, name) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name) } } }).transpose()?.ok_or_else(|| { From 7e339101f76acb2fa19ee4d689ad9535b43e9e45 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 8 Mar 2024 06:22:54 +0800 Subject: [PATCH 17/18] add ignore nulls Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 2 +- .../core/tests/user_defined/user_defined_aggregates.rs | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index e065a0d791e4..cf56e09cb32e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -245,7 +245,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { distinct, args, filter, - order_by, + order_by: _, null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index e9583fff35c3..523fb697af9f 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -265,8 +265,12 @@ async fn simple_udaf_order() -> Result<()> { .map(|e| e.expr.data_type(schema)) .collect::>>()?; - let acc = - FirstValueAccumulator::try_new(data_type, &ordering_dtypes, ordering_req)?; + let acc = FirstValueAccumulator::try_new( + data_type, + &ordering_dtypes, + ordering_req, + false, + )?; Ok(Box::new(acc)) } From 6aaa15ca70f14f088d548e28eb762ccf87eee254 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 25 Mar 2024 20:18:10 +0800 Subject: [PATCH 18/18] fix conflict Signed-off-by: jayzhan211 --- .../core/tests/user_defined/user_defined_aggregates.rs | 2 +- datafusion/expr/src/function.rs | 5 ++--- datafusion/expr/src/udaf.rs | 9 +++++++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 3b4c4a69f3b8..dfe17430147b 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -431,7 +431,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ) .with_aliases(vec!["dummy_alias"]); diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 2cf89a4fd39c..f08411550124 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,10 +17,9 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, BuiltinScalarFunction, Expr, PartitionEvaluator, Signature}; -use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; +use crate::ColumnarValue; +use crate::{Accumulator, Expr, PartitionEvaluator}; use arrow::datatypes::{DataType, Schema}; -use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use std::sync::Arc; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index d54152485890..d232a2b09c7e 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -355,8 +355,13 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.return_type(arg_types) } - fn accumulator(&self, arg: &DataType) -> Result> { - self.inner.accumulator(arg) + fn accumulator( + &self, + arg: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + self.inner.accumulator(arg, sort_exprs, schema) } fn state_type(&self, return_type: &DataType) -> Result> {