-
Notifications
You must be signed in to change notification settings - Fork 1.7k
fix: Support Substrait's compound names also for window functions #11163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
538760c
413d32c
a3a6867
0480fff
80c7bb3
a08e62b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,8 @@ use datafusion::arrow::datatypes::{ | |
| DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, | ||
| }; | ||
| use datafusion::common::{ | ||
| not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, | ||
| not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, | ||
| substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, | ||
| }; | ||
| use substrait::proto::expression::literal::IntervalDayToSecond; | ||
| use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; | ||
|
|
@@ -30,8 +31,7 @@ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; | |
| use datafusion::execution::FunctionRegistry; | ||
| use datafusion::logical_expr::{ | ||
| aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case, | ||
| EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF, | ||
| Values, | ||
| EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values, | ||
| }; | ||
|
|
||
| use datafusion::logical_expr::{ | ||
|
|
@@ -57,7 +57,7 @@ use substrait::proto::{ | |
| reference_segment::ReferenceType::StructField, | ||
| window_function::bound as SubstraitBound, | ||
| window_function::bound::Kind as BoundKind, window_function::Bound, | ||
| MaskExpression, RexType, | ||
| window_function::BoundsType, MaskExpression, RexType, | ||
| }, | ||
| extensions::simple_extension_declaration::MappingType, | ||
| function_argument::ArgType, | ||
|
|
@@ -71,7 +71,6 @@ use substrait::proto::{ | |
| use substrait::proto::{FunctionArgument, SortField}; | ||
|
|
||
| use datafusion::arrow::array::GenericListArray; | ||
| use datafusion::common::plan_err; | ||
| use datafusion::common::scalar::ScalarStructBuilder; | ||
| use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; | ||
| use std::collections::HashMap; | ||
|
|
@@ -89,12 +88,6 @@ use crate::variation_const::{ | |
| UNSIGNED_INTEGER_TYPE_VARIATION_REF, | ||
| }; | ||
|
|
||
| enum ScalarFunctionType { | ||
| Op(Operator), | ||
| Expr(BuiltinExprBuilder), | ||
| Udf(Arc<ScalarUDF>), | ||
| } | ||
|
|
||
| pub fn name_to_op(name: &str) -> Result<Operator> { | ||
| match name { | ||
| "equal" => Ok(Operator::Eq), | ||
|
|
@@ -128,28 +121,6 @@ pub fn name_to_op(name: &str) -> Result<Operator> { | |
| } | ||
| } | ||
|
|
||
| fn scalar_function_type_from_str( | ||
| ctx: &SessionContext, | ||
| name: &str, | ||
| ) -> Result<ScalarFunctionType> { | ||
| let s = ctx.state(); | ||
| let name = substrait_fun_name(name); | ||
|
|
||
| if let Some(func) = s.scalar_functions().get(name) { | ||
| return Ok(ScalarFunctionType::Udf(func.to_owned())); | ||
| } | ||
|
|
||
| if let Ok(op) = name_to_op(name) { | ||
| return Ok(ScalarFunctionType::Op(op)); | ||
| } | ||
|
|
||
| if let Some(builder) = BuiltinExprBuilder::try_from_name(name) { | ||
| return Ok(ScalarFunctionType::Expr(builder)); | ||
| } | ||
|
|
||
| not_impl_err!("Unsupported function name: {name:?}") | ||
| } | ||
|
|
||
| pub fn substrait_fun_name(name: &str) -> &str { | ||
| let name = match name.rsplit_once(':') { | ||
| // Since 0.32.0, Substrait requires the function names to be in a compound format | ||
|
|
@@ -972,7 +943,7 @@ pub async fn from_substrait_rex_vec( | |
| } | ||
|
|
||
| /// Convert Substrait FunctionArguments to DataFusion Exprs | ||
| pub async fn from_substriat_func_args( | ||
| pub async fn from_substrait_func_args( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fix'd the spelling - this would break consumer if someone is using this function, but I dunno why anyone would be. Still it's not necessary change so I'm happy to revert if that'd be better There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is a nice improvement |
||
| ctx: &SessionContext, | ||
| arguments: &Vec<FunctionArgument>, | ||
| input_schema: &DFSchema, | ||
|
|
@@ -984,9 +955,7 @@ pub async fn from_substriat_func_args( | |
| Some(ArgType::Value(e)) => { | ||
| from_substrait_rex(ctx, e, input_schema, extensions).await | ||
| } | ||
| _ => { | ||
| not_impl_err!("Aggregated function argument non-Value type not supported") | ||
| } | ||
| _ => not_impl_err!("Function argument non-Value type not supported"), | ||
| }; | ||
| args.push(arg_expr?.as_ref().clone()); | ||
| } | ||
|
|
@@ -1003,33 +972,25 @@ pub async fn from_substrait_agg_func( | |
| order_by: Option<Vec<Expr>>, | ||
| distinct: bool, | ||
| ) -> Result<Arc<Expr>> { | ||
| let mut args: Vec<Expr> = vec![]; | ||
| for arg in &f.arguments { | ||
| let arg_expr = match &arg.arg_type { | ||
| Some(ArgType::Value(e)) => { | ||
| from_substrait_rex(ctx, e, input_schema, extensions).await | ||
| } | ||
| _ => { | ||
| not_impl_err!("Aggregated function argument non-Value type not supported") | ||
| } | ||
| }; | ||
| args.push(arg_expr?.as_ref().clone()); | ||
| } | ||
| let args = | ||
| from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the code of |
||
|
|
||
| let Some(function_name) = extensions.get(&f.function_reference) else { | ||
| return plan_err!( | ||
| "Aggregate function not registered: function anchor = {:?}", | ||
| f.function_reference | ||
| ); | ||
| }; | ||
| // function_name.split(':').next().unwrap_or(function_name); | ||
|
|
||
| let function_name = substrait_fun_name((**function_name).as_str()); | ||
| // try udaf first, then built-in aggr fn. | ||
| if let Ok(fun) = ctx.udaf(function_name) { | ||
| // deal with situation that count(*) got no arguments | ||
| if fun.name() == "count" && args.is_empty() { | ||
| args.push(Expr::Literal(ScalarValue::Int64(Some(1)))); | ||
| } | ||
| let args = if fun.name() == "count" && args.is_empty() { | ||
| vec![Expr::Literal(ScalarValue::Int64(Some(1)))] | ||
| } else { | ||
| args | ||
| }; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this way |
||
|
|
||
| Ok(Arc::new(Expr::AggregateFunction( | ||
| expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), | ||
|
|
@@ -1041,7 +1002,7 @@ pub async fn from_substrait_agg_func( | |
| ))) | ||
| } else { | ||
| not_impl_err!( | ||
| "Aggregated function {} is not supported: function anchor = {:?}", | ||
| "Aggregate function {} is not supported: function anchor = {:?}", | ||
| function_name, | ||
| f.function_reference | ||
| ) | ||
|
|
@@ -1145,84 +1106,40 @@ pub async fn from_substrait_rex( | |
| }))) | ||
| } | ||
| Some(RexType::ScalarFunction(f)) => { | ||
| let fn_name = extensions.get(&f.function_reference).ok_or_else(|| { | ||
| DataFusionError::NotImplemented(format!( | ||
| "Aggregated function not found: function reference = {:?}", | ||
| let Some(fn_name) = extensions.get(&f.function_reference) else { | ||
| return plan_err!( | ||
| "Scalar function not found: function reference = {:?}", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just aligning this check for all three function types |
||
| f.function_reference | ||
| )) | ||
| })?; | ||
|
|
||
| // Convert function arguments from Substrait to DataFusion | ||
| async fn decode_arguments( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was same as |
||
| ctx: &SessionContext, | ||
| input_schema: &DFSchema, | ||
| extensions: &HashMap<u32, &String>, | ||
| function_args: &[FunctionArgument], | ||
| ) -> Result<Vec<Expr>> { | ||
| let mut args = Vec::with_capacity(function_args.len()); | ||
| for arg in function_args { | ||
| let arg_expr = match &arg.arg_type { | ||
| Some(ArgType::Value(e)) => { | ||
| from_substrait_rex(ctx, e, input_schema, extensions).await | ||
| } | ||
| _ => not_impl_err!( | ||
| "Aggregated function argument non-Value type not supported" | ||
| ), | ||
| }?; | ||
| args.push(arg_expr.as_ref().clone()); | ||
| } | ||
| Ok(args) | ||
| } | ||
| ); | ||
| }; | ||
| let fn_name = substrait_fun_name(fn_name); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this and below is coming from |
||
|
|
||
| let fn_type = scalar_function_type_from_str(ctx, fn_name)?; | ||
| match fn_type { | ||
| ScalarFunctionType::Udf(fun) => { | ||
| let args = decode_arguments( | ||
| ctx, | ||
| input_schema, | ||
| extensions, | ||
| f.arguments.as_slice(), | ||
| ) | ||
| let args = | ||
| from_substrait_func_args(ctx, &f.arguments, input_schema, extensions) | ||
| .await?; | ||
| Ok(Arc::new(Expr::ScalarFunction( | ||
| expr::ScalarFunction::new_udf(fun, args), | ||
| ))) | ||
| } | ||
| ScalarFunctionType::Op(op) => { | ||
| if f.arguments.len() != 2 { | ||
| return not_impl_err!( | ||
| "Expect two arguments for binary operator {op:?}" | ||
| ); | ||
| } | ||
| let lhs = &f.arguments[0].arg_type; | ||
| let rhs = &f.arguments[1].arg_type; | ||
|
|
||
| match (lhs, rhs) { | ||
| (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { | ||
| Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { | ||
| left: Box::new( | ||
| from_substrait_rex(ctx, l, input_schema, extensions) | ||
| .await? | ||
| .as_ref() | ||
| .clone(), | ||
| ), | ||
| op, | ||
| right: Box::new( | ||
| from_substrait_rex(ctx, r, input_schema, extensions) | ||
| .await? | ||
| .as_ref() | ||
| .clone(), | ||
| ), | ||
| }))) | ||
| } | ||
| (l, r) => not_impl_err!( | ||
| "Invalid arguments for binary expression: {l:?} and {r:?}" | ||
| ), | ||
| } | ||
| } | ||
| ScalarFunctionType::Expr(builder) => { | ||
| builder.build(ctx, f, input_schema, extensions).await | ||
|
|
||
| // try to first match the requested function into registered udfs, then built-in ops | ||
| // and finally built-in expressions | ||
| if let Some(func) = ctx.state().scalar_functions().get(fn_name) { | ||
| Ok(Arc::new(Expr::ScalarFunction( | ||
| expr::ScalarFunction::new_udf(func.to_owned(), args), | ||
| ))) | ||
| } else if let Ok(op) = name_to_op(fn_name) { | ||
| if args.len() != 2 { | ||
| return not_impl_err!( | ||
| "Expect two arguments for binary operator {op:?}" | ||
| ); | ||
| } | ||
|
|
||
| Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { | ||
| left: Box::new(args[0].to_owned()), | ||
| op, | ||
| right: Box::new(args[1].to_owned()), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there is a bit of a change where before we used to check the types of these args more explicitly, and now it only happens in |
||
| }))) | ||
| } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { | ||
| builder.build(ctx, f, input_schema, extensions).await | ||
| } else { | ||
| not_impl_err!("Unsupported function name: {fn_name:?}") | ||
| } | ||
| } | ||
| Some(RexType::Literal(lit)) => { | ||
|
|
@@ -1247,36 +1164,50 @@ pub async fn from_substrait_rex( | |
| None => substrait_err!("Cast expression without output type is not allowed"), | ||
| }, | ||
| Some(RexType::WindowFunction(window)) => { | ||
| let fun = match extensions.get(&window.function_reference) { | ||
| Some(function_name) => { | ||
| // check udaf | ||
| match ctx.udaf(function_name) { | ||
| Ok(udaf) => { | ||
| Ok(Some(WindowFunctionDefinition::AggregateUDF(udaf))) | ||
| } | ||
| Err(_) => Ok(find_df_window_func(function_name)), | ||
| } | ||
| } | ||
| None => not_impl_err!( | ||
| "Window function not found: function anchor = {:?}", | ||
| &window.function_reference | ||
| ), | ||
| let Some(fn_name) = extensions.get(&window.function_reference) else { | ||
| return plan_err!( | ||
| "Window function not found: function reference = {:?}", | ||
| window.function_reference | ||
| ); | ||
| }; | ||
| let fn_name = substrait_fun_name(fn_name); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this was the line I originally wanted to add to fix the issue |
||
|
|
||
| // check udaf first, then built-in functions | ||
| let fun = match ctx.udaf(fn_name) { | ||
| Ok(udaf) => Ok(WindowFunctionDefinition::AggregateUDF(udaf)), | ||
| Err(_) => find_df_window_func(fn_name).ok_or_else(|| { | ||
| not_impl_datafusion_err!( | ||
| "Window function {} is not supported: function anchor = {:?}", | ||
| fn_name, | ||
| window.function_reference | ||
| ) | ||
| }), | ||
| }?; | ||
|
|
||
| let order_by = | ||
| from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) | ||
| .await?; | ||
| // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units | ||
| // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary | ||
| // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row | ||
| // TODO: Consider the cases where window frame is specified in query and is different from default | ||
| let units = if order_by.is_empty() { | ||
| WindowFrameUnits::Rows | ||
| } else { | ||
| WindowFrameUnits::Range | ||
| }; | ||
|
|
||
| let bound_units = | ||
| match BoundsType::try_from(window.bounds_type).map_err(|e| { | ||
| plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) | ||
| })? { | ||
| BoundsType::Rows => WindowFrameUnits::Rows, | ||
| BoundsType::Range => WindowFrameUnits::Range, | ||
| BoundsType::Unspecified => { | ||
| // If the plan does not specify the bounds type, then we use a simple logic to determine the units | ||
| // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary | ||
| // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row | ||
| if order_by.is_empty() { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I dunno how well this logic works in reality, but I guess there is some logic to it - sorting is less necessary for a RANGE bound while a ROWS bound without sort is quite meaningless. Anyways we can easily keep it around for the unspecified case for backwards compatibility. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be better to simply return an error here. It looks to me like the substrait spec doesn't explicitly allow for unspecified -- I think the fact this field may not be set is because of how protobuf encodes the fields. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, the UNSPECIFIED is a proper option in Substrait spec, see https://github.com/substrait-io/substrait/blob/7dbbf0468083d932a61b9c720700bd6083558fa9/proto/substrait/algebra.proto#L1059. I don't have much opinion here (I don't think the plans we produce would ever have UNSPECIFIED), so I'm happy to leave it like this or to make it return a not_impl_err - do you have a preference? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would personally suggest a |
||
| WindowFrameUnits::Rows | ||
| } else { | ||
| WindowFrameUnits::Range | ||
| } | ||
| } | ||
| }; | ||
| Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { | ||
| fun: fun?.unwrap(), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this made it hard to debug since we'd throw at the unwrap w/o context on why |
||
| args: from_substriat_func_args( | ||
| fun, | ||
| args: from_substrait_func_args( | ||
| ctx, | ||
| &window.arguments, | ||
| input_schema, | ||
|
|
@@ -1292,7 +1223,7 @@ pub async fn from_substrait_rex( | |
| .await?, | ||
| order_by, | ||
| window_frame: datafusion::logical_expr::WindowFrame::new_bounds( | ||
| units, | ||
| bound_units, | ||
| from_substrait_bound(&window.lower_bound, true)?, | ||
| from_substrait_bound(&window.upper_bound, false)?, | ||
| ), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This felt like unnecessary indirection to go from name -> ScalarFunctionType -> actual function, we can just skip the ScalarFunctionType step
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree this is cleaner -- I think the old code is left over from the time when we had BuiltInScalarFunction