diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 10164a850bfb..06995863a245 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow_schema::Schema; use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; use datafusion_physical_expr::NullState; use std::{any::Any, sync::Arc}; @@ -85,7 +86,12 @@ impl AggregateUDFImpl for GeoMeanUdaf { /// is supported, DataFusion will use this row oriented /// accumulator when the aggregate function is used as a window function /// or when there are only aggregates (no GROUP BY columns) in the plan. - fn accumulator(&self, _arg: &DataType) -> Result> { + fn accumulator( + &self, + _arg: &DataType, + _sort_exprs: &[Expr], + _schema: &Schema, + ) -> Result> { Ok(Box::new(GeometricMean::new())) } @@ -191,7 +197,7 @@ impl Accumulator for GeometricMean { // create local session context with an in-memory table fn create_context() -> Result { - use datafusion::arrow::datatypes::{Field, Schema}; + use datafusion::arrow::datatypes::Field; use datafusion::datasource::MemTable; // define a schema. let schema = Arc::new(Schema::new(vec![ diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 0996a67245a8..3a62e28b0568 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -150,7 +150,7 @@ async fn main() -> Result<()> { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(GeometricMean::new()))), + Arc::new(|_, _, _| Ok(Box::new(GeometricMean::new()))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ca708b05823e..15764d84bdd5 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -246,24 +246,20 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { distinct, args, filter, - order_by, + order_by: _, null_treatment: _, }) => match func_def { AggregateFunctionDefinition::BuiltIn(..) => { create_function_physical_name(func_def.name(), *distinct, args) } AggregateFunctionDefinition::UDF(fun) => { - // TODO: Add support for filter and order by in AggregateUDF + // TODO: Add support for filter by in AggregateUDF if filter.is_some() { return exec_err!( "aggregate expression with filter is not supported" ); } - if order_by.is_some() { - return exec_err!( - "aggregate expression with order_by is not supported" - ); - } + let names = args .iter() .map(|e| create_physical_name(e, false)) @@ -1682,6 +1678,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?), None => None, }; + + let sort_exprs = order_by.clone().unwrap_or(vec![]); let order_by = match order_by { Some(e) => Some( e.iter() @@ -1714,13 +1712,18 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( (agg_expr, filter, order_by) } AggregateFunctionDefinition::UDF(fun) => { + let ordering_reqs: Vec = + order_by.clone().unwrap_or(vec![]); + let agg_expr = udaf::create_aggregate_expr( fun, &args, + &sort_exprs, + &ordering_reqs, physical_input_schema, name, - ); - (agg_expr?, filter, order_by) + )?; + (agg_expr, filter, order_by) } AggregateFunctionDefinition::Name(_) => { return internal_err!( diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index a58a8cf51681..dfe17430147b 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -20,7 +20,7 @@ use arrow::{array::AsArray, datatypes::Fields}; use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use arrow_schema::{Schema, SortOptions}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, @@ -45,9 +45,11 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, AggregateUDFImpl, GroupsAccumulator, SimpleAggregateUDF, + create_udaf, create_udaf_with_ordering, AggregateUDFImpl, Expr, GroupsAccumulator, + SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::AvgAccumulator; +use datafusion_physical_expr::expressions::{self, FirstValueAccumulator}; +use datafusion_physical_expr::{expressions::AvgAccumulator, PhysicalSortExpr}; /// Test to show the contents of the setup #[tokio::test] @@ -209,6 +211,102 @@ async fn execute(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } +#[tokio::test] +async fn simple_udaf_order() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3, 4])), + Arc::new(Int32Array::from(vec![1, 1, 2, 2])), + ], + )?; + + let ctx = SessionContext::new(); + + let provider = MemTable::try_new(Arc::new(schema.clone()), vec![vec![batch]])?; + ctx.register_table("t", Arc::new(provider))?; + + fn create_accumulator( + data_type: &DataType, + order_by: &[Expr], + schema: &Schema, + ) -> Result> { + let mut all_sort_orders = vec![]; + + // Construct PhysicalSortExpr objects from Expr objects: + let mut sort_exprs = vec![]; + for expr in order_by { + if let Expr::Sort(sort) = expr { + if let Expr::Column(col) = sort.expr.as_ref() { + let name = &col.name; + let e = expressions::col(name, schema)?; + sort_exprs.push(PhysicalSortExpr { + expr: e, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); + } + } + } + if !sort_exprs.is_empty() { + all_sort_orders.extend(sort_exprs); + } + + let ordering_req = all_sort_orders; + + let ordering_dtypes = ordering_req + .iter() + .map(|e| e.expr.data_type(schema)) + .collect::>>()?; + + let acc = FirstValueAccumulator::try_new( + data_type, + &ordering_dtypes, + ordering_req, + false, + )?; + Ok(Box::new(acc)) + } + + // define a udaf, using a DataFusion's accumulator + let my_first = create_udaf_with_ordering( + "my_first", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + Arc::new(create_accumulator), + Arc::new(vec![DataType::Int32, DataType::Int32, DataType::Boolean]), + ); + + ctx.register_udaf(my_first); + + // Should be the same as `SELECT FIRST_VALUE(a order by a) FROM t group by b order by b` + let result = ctx + .sql("SELECT MY_FIRST(a order by a desc) FROM t group by b order by b") + .await? + .collect() + .await?; + + let expected = [ + "+---------------+", + "| my_first(t.a) |", + "+---------------+", + "| 2 |", + "| 4 |", + "+---------------+", + ]; + assert_batches_eq!(expected, &result); + + Ok(()) +} + /// tests the creation, registration and usage of a UDAF #[tokio::test] async fn simple_udaf() -> Result<()> { @@ -234,7 +332,7 @@ async fn simple_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -262,7 +360,7 @@ async fn deregister_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -290,7 +388,7 @@ async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); @@ -333,7 +431,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ) .with_aliases(vec!["dummy_alias"]); @@ -497,7 +595,7 @@ impl TimeSum { let captured_state = Arc::clone(&test_state); let accumulator: AccumulatorFactoryFunction = - Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); + Arc::new(move |_, _, _| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); let time_sum = AggregateUDF::from(SimpleAggregateUDF::new( name, @@ -596,7 +694,7 @@ impl FirstSelector { let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; let accumulator: AccumulatorFactoryFunction = - Arc::new(|_| Ok(Box::new(Self::new()))); + Arc::new(|_, _, _| Ok(Box::new(Self::new()))); let volatility = Volatility::Immutable; @@ -717,7 +815,12 @@ impl AggregateUDFImpl for TestGroupsAccumulator { Ok(DataType::UInt64) } - fn accumulator(&self, _arg: &DataType) -> Result> { + fn accumulator( + &self, + _arg: &DataType, + _sort_exprs: &[Expr], + _schema: &Schema, + ) -> Result> { // should use groups accumulator panic!("accumulator shouldn't invoke"); } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index b525e4fc6341..df8976736572 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -291,7 +291,7 @@ async fn udaf_as_window_func() -> Result<()> { vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, - Arc::new(|_| Ok(Box::new(MyAccumulator))), + Arc::new(|_, _, _| Ok(Box::new(MyAccumulator))), Arc::new(vec![DataType::Int32]), ); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0ea946288e0f..1a021cedbabd 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -29,7 +29,7 @@ use crate::{ ScalarUDF, Signature, Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::{Column, Result}; use std::any::Any; use std::fmt::Debug; @@ -746,6 +746,29 @@ pub fn create_udaf( )) } +/// Creates a new UDAF with a specific signature, state type and return type. +/// The signature and state type must match the `Accumulator's implementation`. +pub fn create_udaf_with_ordering( + name: &str, + input_type: Vec, + return_type: Arc, + volatility: Volatility, + accumulator: AccumulatorFactoryFunction, + state_type: Arc>, +) -> AggregateUDF { + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + + AggregateUDF::from(SimpleOrderedAggregateUDF::new( + name, + input_type, + return_type, + volatility, + accumulator, + state_type, + )) +} + /// Implements [`AggregateUDFImpl`] for functions that have a single signature and /// return type. pub struct SimpleAggregateUDF { @@ -823,8 +846,87 @@ impl AggregateUDFImpl for SimpleAggregateUDF { Ok(self.return_type.clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { - (self.accumulator)(arg) + fn accumulator( + &self, + arg: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + (self.accumulator)(arg, sort_exprs, schema) + } + + fn state_type(&self, _return_type: &DataType) -> Result> { + Ok(self.state_type.clone()) + } +} + +/// Implements [`AggregateUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleOrderedAggregateUDF { + name: String, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunction, + state_type: Vec, +} + +impl Debug for SimpleOrderedAggregateUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl SimpleOrderedAggregateUDF { + /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and + /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: Vec, + return_type: DataType, + volatility: Volatility, + accumulator: AccumulatorFactoryFunction, + state_type: Vec, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_type, volatility); + Self { + name, + signature, + return_type, + accumulator, + state_type, + } + } +} + +impl AggregateUDFImpl for SimpleOrderedAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn accumulator( + &self, + arg: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + (self.accumulator)(arg, sort_exprs, schema) } fn state_type(&self, _return_type: &DataType) -> Result> { diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index adf4dd3fef20..f08411550124 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,8 +17,9 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, ColumnarValue, PartitionEvaluator}; -use arrow::datatypes::DataType; +use crate::ColumnarValue; +use crate::{Accumulator, Expr, PartitionEvaluator}; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::Result; use std::sync::Arc; @@ -38,9 +39,10 @@ pub type ReturnTypeFunction = Arc Result> + Send + Sync>; /// Factory that returns an accumulator for the given aggregate, given -/// its return datatype. -pub type AccumulatorFactoryFunction = - Arc Result> + Send + Sync>; +/// its return datatype, the sorting expressions and the schema for ordering. +pub type AccumulatorFactoryFunction = Arc< + dyn Fn(&DataType, &[Expr], &Schema) -> Result> + Send + Sync, +>; /// Factory that creates a PartitionEvaluator for the given window /// function diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index c46dd9cd3a6f..d232a2b09c7e 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -22,7 +22,7 @@ use crate::{Accumulator, Expr}; use crate::{ AccumulatorFactoryFunction, ReturnTypeFunction, Signature, StateTypeFunction, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::{not_impl_err, Result}; use std::any::Any; use std::fmt::{self, Debug, Formatter}; @@ -166,10 +166,14 @@ impl AggregateUDF { self.inner.return_type(args) } - /// Return an accumulator the given aggregate, given - /// its return datatype. - pub fn accumulator(&self, return_type: &DataType) -> Result> { - self.inner.accumulator(return_type) + /// Return an accumulator the given aggregate, given its return datatype + pub fn accumulator( + &self, + return_type: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + self.inner.accumulator(return_type, sort_exprs, schema) } /// Return the type of the intermediate state used by this aggregator, given @@ -213,8 +217,9 @@ where /// # use std::any::Any; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator}; +/// # use arrow::datatypes::Schema; /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature @@ -240,7 +245,7 @@ where /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. -/// fn accumulator(&self, _arg: &DataType) -> Result> { unimplemented!() } +/// fn accumulator(&self, _arg: &DataType, _sort_exprs: &[Expr], _schema: &Schema) -> Result> { unimplemented!() } /// fn state_type(&self, _return_type: &DataType) -> Result> { /// Ok(vec![DataType::Float64, DataType::UInt32]) /// } @@ -269,7 +274,20 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return a new [`Accumulator`] that aggregates values for a specific /// group during query execution. - fn accumulator(&self, arg: &DataType) -> Result>; + /// + /// `arg`: the type of the argument to this accumulator + /// + /// `sort_exprs`: contains a list of `Expr::SortExpr`s if the + /// aggregate is called with an explicit `ORDER BY`. For example, + /// `ARRAY_AGG(x ORDER BY y ASC)`. In this case, `sort_exprs` would contain `[y ASC]` + /// + /// `schema` is the input schema to the udaf + fn accumulator( + &self, + arg: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result>; /// Return the type used to serialize the [`Accumulator`]'s intermediate state. /// See [`Accumulator::state()`] for more details @@ -277,7 +295,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// If the aggregate expression has a specialized /// [`GroupsAccumulator`] implementation. If this returns true, - /// `[Self::create_groups_accumulator`] will be called. + /// `[Self::create_groups_accumulator]` will be called. fn groups_accumulator_supported(&self) -> bool { false } @@ -337,8 +355,13 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { self.inner.return_type(arg_types) } - fn accumulator(&self, arg: &DataType) -> Result> { - self.inner.accumulator(arg) + fn accumulator( + &self, + arg: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + self.inner.accumulator(arg, sort_exprs, schema) } fn state_type(&self, return_type: &DataType) -> Result> { @@ -394,8 +417,13 @@ impl AggregateUDFImpl for AggregateUDFLegacyWrapper { Ok(res.as_ref().clone()) } - fn accumulator(&self, arg: &DataType) -> Result> { - (self.accumulator)(arg) + fn accumulator( + &self, + arg: &DataType, + sort_exprs: &[Expr], + schema: &Schema, + ) -> Result> { + (self.accumulator)(arg, sort_exprs, schema) } fn state_type(&self, return_type: &DataType) -> Result> { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c76c1c8a7bd0..4f861ffe9967 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -887,7 +887,7 @@ mod test { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| Ok(Box::::default())), + Arc::new(|_, _, _| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( @@ -908,7 +908,7 @@ mod test { let return_type = DataType::Float64; let state_type = vec![DataType::UInt64, DataType::Float64]; let accumulator: AccumulatorFactoryFunction = - Arc::new(|_| Ok(Box::::default())); + Arc::new(|_, _, _| Ok(Box::::default())); let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( "MY_AVG", Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 0c9064d0641f..8b4d60aafd19 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -963,7 +963,8 @@ mod test { let table_scan = test_table_scan()?; let return_type = DataType::UInt32; - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_, _, _| unimplemented!()); let state_type = vec![DataType::UInt32]; let udf_agg = |inner: Expr| { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 17dd3ef1206d..264f52f93a3a 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -210,7 +210,7 @@ impl PartialEq for FirstValue { } #[derive(Debug)] -struct FirstValueAccumulator { +pub struct FirstValueAccumulator { first: ScalarValue, // At the beginning, `is_set` is false, which means `first` is not seen yet. // Once we see the first value, we set the `is_set` flag and do not update `first` anymore. diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7c4ea07dfbcb..fcd656173355 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,7 +53,7 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::first_last::{FirstValue, LastValue}; +pub use crate::aggregate::first_last::{FirstValue, FirstValueAccumulator, LastValue}; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/physical-plan/src/udaf.rs b/datafusion/physical-plan/src/udaf.rs index fd9279dfd552..3d15e3d012c5 100644 --- a/datafusion/physical-plan/src/udaf.rs +++ b/datafusion/physical-plan/src/udaf.rs @@ -17,20 +17,17 @@ //! This module contains functions and structs supporting user-defined aggregate functions. -use datafusion_expr::GroupsAccumulator; +use datafusion_expr::{Expr, GroupsAccumulator}; use fmt::Debug; use std::any::Any; use std::fmt; -use arrow::{ - datatypes::Field, - datatypes::{DataType, Schema}, -}; +use arrow::datatypes::{DataType, Field, Schema}; use super::{expressions::format_state_name, Accumulator, AggregateExpr}; use datafusion_common::{not_impl_err, Result}; pub use datafusion_expr::AggregateUDF; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_expr::aggregate::utils::down_cast_any_ref; use std::sync::Arc; @@ -40,12 +37,14 @@ use std::sync::Arc; pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], - input_schema: &Schema, + sort_exprs: &[Expr], + ordering_req: &[PhysicalSortExpr], + schema: &Schema, name: impl Into, ) -> Result> { let input_exprs_types = input_phy_exprs .iter() - .map(|arg| arg.data_type(input_schema)) + .map(|arg| arg.data_type(schema)) .collect::>>()?; Ok(Arc::new(AggregateFunctionExpr { @@ -53,6 +52,9 @@ pub fn create_aggregate_expr( args: input_phy_exprs.to_vec(), data_type: fun.return_type(&input_exprs_types)?, name: name.into(), + schema: schema.clone(), + sort_exprs: sort_exprs.to_vec(), + ordering_req: ordering_req.to_vec(), })) } @@ -64,6 +66,11 @@ pub struct AggregateFunctionExpr { /// Output / return type of this aggregate data_type: DataType, name: String, + schema: Schema, + // The logical order by expressions + sort_exprs: Vec, + // The physical order by expressions + ordering_req: LexOrdering, } impl AggregateFunctionExpr { @@ -106,11 +113,14 @@ impl AggregateExpr for AggregateFunctionExpr { } fn create_accumulator(&self) -> Result> { - self.fun.accumulator(&self.data_type) + self.fun + .accumulator(&self.data_type, self.sort_exprs.as_slice(), &self.schema) } fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.fun.accumulator(&self.data_type)?; + let accumulator = + self.fun + .accumulator(&self.data_type, &self.sort_exprs, &self.schema)?; // Accumulators that have window frame startings different // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to @@ -175,6 +185,10 @@ impl AggregateExpr for AggregateFunctionExpr { fn create_groups_accumulator(&self) -> Result> { self.fun.create_groups_accumulator() } + + fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { + (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + } } impl PartialEq for AggregateFunctionExpr { diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 21f42f41fb5c..e0ad9363051f 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -92,8 +92,18 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = - udaf::create_aggregate_expr(fun.as_ref(), args, input_schema, name)?; + // TODO: Ordering not supported for Window UDFs yet + let sort_exprs = &[]; + let ordering_req = &[]; + + let aggregate = udaf::create_aggregate_expr( + fun.as_ref(), + args, + sort_exprs, + ordering_req, + input_schema, + name, + )?; window_expr_from_aggregate_expr( partition_by, order_by, diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 610c533d574c..4c570d343574 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -127,7 +127,7 @@ impl Serializeable for Expr { vec![arrow::datatypes::DataType::Null], Arc::new(arrow::datatypes::DataType::Null), Volatility::Immutable, - Arc::new(|_| unimplemented!()), + Arc::new(|_, _, _| unimplemented!()), Arc::new(vec![]), ))) } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index da31c5e762bc..a45b08c333da 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -517,7 +517,10 @@ impl AsExecutionPlan for PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, &physical_schema, name) + // TODO: `order by` is not supported for UDAF yet + let sort_exprs = &[]; + let ordering_req = &[]; + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3c43f100750f..c40bfc97677c 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1762,7 +1762,7 @@ fn roundtrip_aggregate_udf() { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(Dummy {}))), + Arc::new(|_, _, _| Ok(Box::new(Dummy {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); @@ -1977,7 +1977,7 @@ fn roundtrip_window() { Arc::new(DataType::Float64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(DummyAggr {}))), + Arc::new(|_, _, _| Ok(Box::new(DummyAggr {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), ); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 4924128ae190..33dbb92d8d3a 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -411,7 +411,8 @@ fn roundtrip_aggregate_udaf() -> Result<()> { } let return_type = DataType::Int64; - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_, _, _| Ok(Box::new(Example))); let state_type = vec![DataType::Int64]; let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( @@ -431,6 +432,8 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let aggregates: Vec> = vec![udaf::create_aggregate_expr( &udaf, &[col("b", &schema)?], + &[], + &[], &schema, "example_agg", )?]; diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 582404b29749..9b92067b9ec6 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -221,9 +221,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { + let order_by = + self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; + let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fm, args, false, None, None, + fm, args, false, None, order_by, ))); } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index bc9cc66b7626..a24a84ad76be 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -750,7 +750,7 @@ async fn roundtrip_aggregate_udf() -> Result<()> { Arc::new(DataType::Int64), Volatility::Immutable, // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(Dummy {}))), + Arc::new(|_, _, _| Ok(Box::new(Dummy {}))), // This is the description of the state. `state()` must match the types here. Arc::new(vec![DataType::Float64, DataType::UInt32]), );