Skip to content

Commit 0480fff

Browse files
author
Arttu Voutilainen
committed
simplify scalar function handling
1 parent a3a6867 commit 0480fff

File tree

1 file changed

+21
-45
lines changed

1 file changed

+21
-45
lines changed

datafusion/substrait/src/logical_plan/consumer.rs

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,6 @@ use crate::variation_const::{
8989
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
9090
};
9191

92-
enum ScalarFunctionType {
93-
Op(Operator),
94-
Expr(BuiltinExprBuilder),
95-
Udf(Arc<ScalarUDF>),
96-
}
97-
9892
pub fn name_to_op(name: &str) -> Result<Operator> {
9993
match name {
10094
"equal" => Ok(Operator::Eq),
@@ -128,25 +122,6 @@ pub fn name_to_op(name: &str) -> Result<Operator> {
128122
}
129123
}
130124

131-
fn scalar_function_type_from_str(
132-
ctx: &SessionContext,
133-
name: &str,
134-
) -> Result<ScalarFunctionType> {
135-
if let Some(func) = ctx.state().scalar_functions().get(name) {
136-
return Ok(ScalarFunctionType::Udf(func.to_owned()));
137-
}
138-
139-
if let Ok(op) = name_to_op(name) {
140-
return Ok(ScalarFunctionType::Op(op));
141-
}
142-
143-
if let Some(builder) = BuiltinExprBuilder::try_from_name(name) {
144-
return Ok(ScalarFunctionType::Expr(builder));
145-
}
146-
147-
not_impl_err!("Unsupported function name: {name:?}")
148-
}
149-
150125
pub fn substrait_fun_name(name: &str) -> &str {
151126
let name = match name.rsplit_once(':') {
152127
// Since 0.32.0, Substrait requires the function names to be in a compound format
@@ -1144,27 +1119,28 @@ pub async fn from_substrait_rex(
11441119
from_substrait_func_args(ctx, &f.arguments, input_schema, extensions)
11451120
.await?;
11461121

1147-
let fn_type = scalar_function_type_from_str(ctx, fn_name)?;
1148-
match fn_type {
1149-
ScalarFunctionType::Udf(fun) => Ok(Arc::new(Expr::ScalarFunction(
1150-
expr::ScalarFunction::new_udf(fun, args),
1151-
))),
1152-
ScalarFunctionType::Op(op) => {
1153-
if args.len() != 2 {
1154-
return not_impl_err!(
1155-
"Expect two arguments for binary operator {op:?}"
1156-
);
1157-
}
1158-
1159-
Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
1160-
left: Box::new(args[0].to_owned()),
1161-
op,
1162-
right: Box::new(args[1].to_owned()),
1163-
})))
1164-
}
1165-
ScalarFunctionType::Expr(builder) => {
1166-
builder.build(ctx, f, input_schema, extensions).await
1122+
// try to first match the requested function into registered udfs, then built-in ops
1123+
// and finally built-in expressions
1124+
if let Some(func) = ctx.state().scalar_functions().get(fn_name) {
1125+
Ok(Arc::new(Expr::ScalarFunction(
1126+
expr::ScalarFunction::new_udf(func.to_owned(), args),
1127+
)))
1128+
} else if let Ok(op) = name_to_op(fn_name) {
1129+
if args.len() != 2 {
1130+
return not_impl_err!(
1131+
"Expect two arguments for binary operator {op:?}"
1132+
);
11671133
}
1134+
1135+
Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
1136+
left: Box::new(args[0].to_owned()),
1137+
op,
1138+
right: Box::new(args[1].to_owned()),
1139+
})))
1140+
} else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
1141+
builder.build(ctx, f, input_schema, extensions).await
1142+
} else {
1143+
not_impl_err!("Unsupported function name: {name:?}")
11681144
}
11691145
}
11701146
Some(RexType::Literal(lit)) => {

0 commit comments

Comments
 (0)