@@ -19,7 +19,6 @@ use std::{
1919 any:: Any ,
2020 cmp:: min,
2121 fmt:: { Debug , Write } ,
22- str:: FromStr ,
2322 sync:: Arc ,
2423} ;
2524
@@ -35,17 +34,15 @@ use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array, StringArray};
3534use arrow_schema:: DataType ;
3635use datafusion:: {
3736 execution:: FunctionRegistry ,
38- logical_expr:: {
39- BuiltinScalarFunction , ScalarFunctionDefinition , ScalarFunctionImplementation ,
40- ScalarUDFImpl , Signature , Volatility ,
41- } ,
37+ functions:: math:: round:: round,
38+ logical_expr:: { ScalarFunctionImplementation , ScalarUDFImpl , Signature , Volatility } ,
4239 physical_plan:: ColumnarValue ,
4340} ;
4441use datafusion_common:: {
4542 cast:: { as_binary_array, as_generic_string_array} ,
4643 exec_err, internal_err, DataFusionError , Result as DataFusionResult , ScalarValue ,
4744} ;
48- use datafusion_physical_expr :: { math_expressions , udf :: ScalarUDF } ;
45+ use datafusion_expr :: ScalarUDF ;
4946use num:: {
5047 integer:: { div_ceil, div_floor} ,
5148 BigInt , Signed , ToPrimitive ,
@@ -63,9 +60,7 @@ macro_rules! make_comet_scalar_udf {
6360 $data_type. clone( ) ,
6461 Arc :: new( move |args| $func( args, & $data_type) ) ,
6562 ) ;
66- Ok ( ScalarFunctionDefinition :: UDF ( Arc :: new(
67- ScalarUDF :: new_from_impl( scalar_func) ,
68- ) ) )
63+ Ok ( Arc :: new( ScalarUDF :: new_from_impl( scalar_func) ) )
6964 } } ;
7065 ( $name: expr, $func: expr, without $data_type: ident) => { {
7166 let scalar_func = CometScalarFunction :: new(
@@ -74,9 +69,7 @@ macro_rules! make_comet_scalar_udf {
7469 $data_type,
7570 $func,
7671 ) ;
77- Ok ( ScalarFunctionDefinition :: UDF ( Arc :: new(
78- ScalarUDF :: new_from_impl( scalar_func) ,
79- ) ) )
72+ Ok ( Arc :: new( ScalarUDF :: new_from_impl( scalar_func) ) )
8073 } } ;
8174}
8275
@@ -85,7 +78,7 @@ pub fn create_comet_physical_fun(
8578 fun_name : & str ,
8679 data_type : DataType ,
8780 registry : & dyn FunctionRegistry ,
88- ) -> Result < ScalarFunctionDefinition , DataFusionError > {
81+ ) -> Result < Arc < ScalarUDF > , DataFusionError > {
8982 let sha2_functions = [ "sha224" , "sha256" , "sha384" , "sha512" ] ;
9083 match fun_name {
9184 "ceil" => {
@@ -129,13 +122,11 @@ pub fn create_comet_physical_fun(
129122 let spark_func_name = "spark" . to_owned ( ) + sha;
130123 make_comet_scalar_udf ! ( spark_func_name, wrapped_func, without data_type)
131124 }
132- _ => {
133- if let Ok ( fun) = BuiltinScalarFunction :: from_str ( fun_name) {
134- Ok ( ScalarFunctionDefinition :: BuiltIn ( fun) )
135- } else {
136- Ok ( ScalarFunctionDefinition :: UDF ( registry. udf ( fun_name) ?) )
137- }
138- }
125+ _ => registry. udf ( fun_name) . map_err ( |e| {
126+ DataFusionError :: Execution ( format ! (
127+ "Function {fun_name} not found in the registry: {e}" ,
128+ ) )
129+ } ) ,
139130 }
140131}
141132
@@ -498,9 +489,7 @@ fn spark_round(
498489 make_decimal_array ( array, precision, scale, & f)
499490 }
500491 DataType :: Float32 | DataType :: Float64 => {
501- Ok ( ColumnarValue :: Array ( math_expressions:: round ( & [
502- array. clone ( )
503- ] ) ?) )
492+ Ok ( ColumnarValue :: Array ( round ( & [ array. clone ( ) ] ) ?) )
504493 }
505494 dt => exec_err ! ( "Not supported datatype for ROUND: {dt}" ) ,
506495 } ,
@@ -523,7 +512,7 @@ fn spark_round(
523512 make_decimal_scalar ( a, precision, scale, & f)
524513 }
525514 ScalarValue :: Float32 ( _) | ScalarValue :: Float64 ( _) => Ok ( ColumnarValue :: Scalar (
526- ScalarValue :: try_from_array ( & math_expressions :: round ( & [ a. to_array ( ) ?] ) ?, 0 ) ?,
515+ ScalarValue :: try_from_array ( & round ( & [ a. to_array ( ) ?] ) ?, 0 ) ?,
527516 ) ) ,
528517 dt => exec_err ! ( "Not supported datatype for ROUND: {dt}" ) ,
529518 } ,
0 commit comments