Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 85 additions & 154 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use datafusion::arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::{
not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err,
substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
};
use substrait::proto::expression::literal::IntervalDayToSecond;
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
Expand All @@ -30,8 +31,7 @@ use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, ScalarUDF,
Values,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values,
};

use datafusion::logical_expr::{
Expand All @@ -57,7 +57,7 @@ use substrait::proto::{
reference_segment::ReferenceType::StructField,
window_function::bound as SubstraitBound,
window_function::bound::Kind as BoundKind, window_function::Bound,
MaskExpression, RexType,
window_function::BoundsType, MaskExpression, RexType,
},
extensions::simple_extension_declaration::MappingType,
function_argument::ArgType,
Expand All @@ -71,7 +71,6 @@ use substrait::proto::{
use substrait::proto::{FunctionArgument, SortField};

use datafusion::arrow::array::GenericListArray;
use datafusion::common::plan_err;
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
Expand All @@ -89,12 +88,6 @@ use crate::variation_const::{
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
};

enum ScalarFunctionType {
Op(Operator),
Expr(BuiltinExprBuilder),
Udf(Arc<ScalarUDF>),
}

pub fn name_to_op(name: &str) -> Result<Operator> {
match name {
"equal" => Ok(Operator::Eq),
Expand Down Expand Up @@ -128,28 +121,6 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
}
}

fn scalar_function_type_from_str(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This felt like unnecessary indirection to go from name -> ScalarFunctionType -> actual function, we can just skip the ScalarFunctionType step

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this is cleaner -- I think the old code is left over from the time when we had BuiltInScalarFunction

ctx: &SessionContext,
name: &str,
) -> Result<ScalarFunctionType> {
let s = ctx.state();
let name = substrait_fun_name(name);

if let Some(func) = s.scalar_functions().get(name) {
return Ok(ScalarFunctionType::Udf(func.to_owned()));
}

if let Ok(op) = name_to_op(name) {
return Ok(ScalarFunctionType::Op(op));
}

if let Some(builder) = BuiltinExprBuilder::try_from_name(name) {
return Ok(ScalarFunctionType::Expr(builder));
}

not_impl_err!("Unsupported function name: {name:?}")
}

pub fn substrait_fun_name(name: &str) -> &str {
let name = match name.rsplit_once(':') {
// Since 0.32.0, Substrait requires the function names to be in a compound format
Expand Down Expand Up @@ -972,7 +943,7 @@ pub async fn from_substrait_rex_vec(
}

/// Convert Substrait FunctionArguments to DataFusion Exprs
pub async fn from_substriat_func_args(
pub async fn from_substrait_func_args(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix'd the spelling - this would break consumer if someone is using this function, but I dunno why anyone would be. Still it's not necessary change so I'm happy to revert if that'd be better

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a nice improvement

ctx: &SessionContext,
arguments: &Vec<FunctionArgument>,
input_schema: &DFSchema,
Expand All @@ -984,9 +955,7 @@ pub async fn from_substriat_func_args(
Some(ArgType::Value(e)) => {
from_substrait_rex(ctx, e, input_schema, extensions).await
}
_ => {
not_impl_err!("Aggregated function argument non-Value type not supported")
}
_ => not_impl_err!("Function argument non-Value type not supported"),
};
args.push(arg_expr?.as_ref().clone());
}
Expand All @@ -1003,33 +972,25 @@ pub async fn from_substrait_agg_func(
order_by: Option<Vec<Expr>>,
distinct: bool,
) -> Result<Arc<Expr>> {
let mut args: Vec<Expr> = vec![];
for arg in &f.arguments {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
from_substrait_rex(ctx, e, input_schema, extensions).await
}
_ => {
not_impl_err!("Aggregated function argument non-Value type not supported")
}
};
args.push(arg_expr?.as_ref().clone());
}
let args =
from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the code of from_substrait_func_args was duplicated inline here and for scalar functions


let Some(function_name) = extensions.get(&f.function_reference) else {
return plan_err!(
"Aggregate function not registered: function anchor = {:?}",
f.function_reference
);
};
// function_name.split(':').next().unwrap_or(function_name);

let function_name = substrait_fun_name((**function_name).as_str());
// try udaf first, then built-in aggr fn.
if let Ok(fun) = ctx.udaf(function_name) {
// deal with situation that count(*) got no arguments
if fun.name() == "count" && args.is_empty() {
args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
}
let args = if fun.name() == "count" && args.is_empty() {
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
} else {
args
};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this way args doesn't need to be mutable


Ok(Arc::new(Expr::AggregateFunction(
expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None),
Expand All @@ -1041,7 +1002,7 @@ pub async fn from_substrait_agg_func(
)))
} else {
not_impl_err!(
"Aggregated function {} is not supported: function anchor = {:?}",
"Aggregate function {} is not supported: function anchor = {:?}",
function_name,
f.function_reference
)
Expand Down Expand Up @@ -1145,84 +1106,40 @@ pub async fn from_substrait_rex(
})))
}
Some(RexType::ScalarFunction(f)) => {
let fn_name = extensions.get(&f.function_reference).ok_or_else(|| {
DataFusionError::NotImplemented(format!(
"Aggregated function not found: function reference = {:?}",
let Some(fn_name) = extensions.get(&f.function_reference) else {
return plan_err!(
"Scalar function not found: function reference = {:?}",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just aligning this check for all three function types

f.function_reference
))
})?;

// Convert function arguments from Substrait to DataFusion
async fn decode_arguments(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was same as from_substrait_func_args

ctx: &SessionContext,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
function_args: &[FunctionArgument],
) -> Result<Vec<Expr>> {
let mut args = Vec::with_capacity(function_args.len());
for arg in function_args {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
from_substrait_rex(ctx, e, input_schema, extensions).await
}
_ => not_impl_err!(
"Aggregated function argument non-Value type not supported"
),
}?;
args.push(arg_expr.as_ref().clone());
}
Ok(args)
}
);
};
let fn_name = substrait_fun_name(fn_name);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this and below is coming from scalar_function_type_from_str


let fn_type = scalar_function_type_from_str(ctx, fn_name)?;
match fn_type {
ScalarFunctionType::Udf(fun) => {
let args = decode_arguments(
ctx,
input_schema,
extensions,
f.arguments.as_slice(),
)
let args =
from_substrait_func_args(ctx, &f.arguments, input_schema, extensions)
.await?;
Ok(Arc::new(Expr::ScalarFunction(
expr::ScalarFunction::new_udf(fun, args),
)))
}
ScalarFunctionType::Op(op) => {
if f.arguments.len() != 2 {
return not_impl_err!(
"Expect two arguments for binary operator {op:?}"
);
}
let lhs = &f.arguments[0].arg_type;
let rhs = &f.arguments[1].arg_type;

match (lhs, rhs) {
(Some(ArgType::Value(l)), Some(ArgType::Value(r))) => {
Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(
from_substrait_rex(ctx, l, input_schema, extensions)
.await?
.as_ref()
.clone(),
),
op,
right: Box::new(
from_substrait_rex(ctx, r, input_schema, extensions)
.await?
.as_ref()
.clone(),
),
})))
}
(l, r) => not_impl_err!(
"Invalid arguments for binary expression: {l:?} and {r:?}"
),
}
}
ScalarFunctionType::Expr(builder) => {
builder.build(ctx, f, input_schema, extensions).await

// try to first match the requested function into registered udfs, then built-in ops
// and finally built-in expressions
if let Some(func) = ctx.state().scalar_functions().get(fn_name) {
Ok(Arc::new(Expr::ScalarFunction(
expr::ScalarFunction::new_udf(func.to_owned(), args),
)))
} else if let Ok(op) = name_to_op(fn_name) {
if args.len() != 2 {
return not_impl_err!(
"Expect two arguments for binary operator {op:?}"
);
}

Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(args[0].to_owned()),
op,
right: Box::new(args[1].to_owned()),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is a bit of a change where before we used to check the types of these args more explicitly, and now it only happens in from_substrait_func_args. I think that should be fine

})))
} else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
builder.build(ctx, f, input_schema, extensions).await
} else {
not_impl_err!("Unsupported function name: {fn_name:?}")
}
}
Some(RexType::Literal(lit)) => {
Expand All @@ -1247,36 +1164,50 @@ pub async fn from_substrait_rex(
None => substrait_err!("Cast expression without output type is not allowed"),
},
Some(RexType::WindowFunction(window)) => {
let fun = match extensions.get(&window.function_reference) {
Some(function_name) => {
// check udaf
match ctx.udaf(function_name) {
Ok(udaf) => {
Ok(Some(WindowFunctionDefinition::AggregateUDF(udaf)))
}
Err(_) => Ok(find_df_window_func(function_name)),
}
}
None => not_impl_err!(
"Window function not found: function anchor = {:?}",
&window.function_reference
),
let Some(fn_name) = extensions.get(&window.function_reference) else {
return plan_err!(
"Window function not found: function reference = {:?}",
window.function_reference
);
};
let fn_name = substrait_fun_name(fn_name);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was the line I originally wanted to add to fix the issue


// check udaf first, then built-in functions
let fun = match ctx.udaf(fn_name) {
Ok(udaf) => Ok(WindowFunctionDefinition::AggregateUDF(udaf)),
Err(_) => find_df_window_func(fn_name).ok_or_else(|| {
not_impl_datafusion_err!(
"Window function {} is not supported: function anchor = {:?}",
fn_name,
window.function_reference
)
}),
}?;

let order_by =
from_substrait_sorts(ctx, &window.sorts, input_schema, extensions)
.await?;
// Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units
// If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary
// If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row
// TODO: Consider the cases where window frame is specified in query and is different from default
let units = if order_by.is_empty() {
WindowFrameUnits::Rows
} else {
WindowFrameUnits::Range
};

let bound_units =
match BoundsType::try_from(window.bounds_type).map_err(|e| {
plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type)
})? {
BoundsType::Rows => WindowFrameUnits::Rows,
BoundsType::Range => WindowFrameUnits::Range,
BoundsType::Unspecified => {
// If the plan does not specify the bounds type, then we use a simple logic to determine the units
// If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary
// If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row
if order_by.is_empty() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dunno how well this logic works in reality, but I guess there is some logic to it - sorting is less necessary for a RANGE bound while a ROWS bound without sort is quite meaningless. Anyways we can easily keep it around for the unspecified case for backwards compatibility.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to simply return an error here. It looks to me like the substrait spec doesn't explicitly allow for unspecified -- I think the fact this field may not be set is because of how protobuf encodes the fields.

https://github.com/substrait-io/substrait/blob/7dbbf0468083d932a61b9c720700bd6083558fa9/proto/substrait/algebra.proto#L1038-L1039

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, the UNSPECIFIED is a proper option in Substrait spec, see https://github.com/substrait-io/substrait/blob/7dbbf0468083d932a61b9c720700bd6083558fa9/proto/substrait/algebra.proto#L1059. I don't have much opinion here (I don't think the plans we produce would ever have UNSPECIFIED), so I'm happy to leave it like this or to make it return a not_impl_err - do you have a preference?

Copy link
Contributor

@alamb alamb Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would personally suggest a not_impl_err to avoid silently ignored errors, but we can do it as a follow on PR (or never)

WindowFrameUnits::Rows
} else {
WindowFrameUnits::Range
}
}
};
Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction {
fun: fun?.unwrap(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this made it hard to debug since we'd throw at the unwrap w/o context on why

args: from_substriat_func_args(
fun,
args: from_substrait_func_args(
ctx,
&window.arguments,
input_schema,
Expand All @@ -1292,7 +1223,7 @@ pub async fn from_substrait_rex(
.await?,
order_by,
window_frame: datafusion::logical_expr::WindowFrame::new_bounds(
units,
bound_units,
from_substrait_bound(&window.lower_bound, true)?,
from_substrait_bound(&window.upper_bound, false)?,
),
Expand Down
Loading