Skip to content

Commit 070517a

Browse files
authored
Derive UDF equality from PartialEq, Hash (#16842)
* Derive UDF equality from PartialEq, Hash

  Reduce boilerplate in cases where the implementation of `{ScalarUDFImpl,AggregateUDFImpl,WindowUDFImpl}::{equals,hash_value}` can be derived using the standard `PartialEq` and `Hash` traits. This is a code complexity reduction. While valuable on its own, it also prepares for more automatic derivation of UDF equals/hash and/or removal of the default implementations (which are currently error-prone).
* udf_equals_hash example
* test udf_equals_hash
* empty: roll the dice 🎲
1 parent bb1b55c commit 070517a

File tree

8 files changed

+335
-292
lines changed

8 files changed

+335
-292
lines changed

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 40 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use std::any::Any;
1919
use std::collections::HashMap;
20-
use std::hash::{DefaultHasher, Hash, Hasher};
20+
use std::hash::{Hash, Hasher};
2121
use std::sync::Arc;
2222

2323
use arrow::array::{as_string_array, create_array, record_batch, Int8Array, UInt64Array};
@@ -43,9 +43,9 @@ use datafusion_common::{
4343
use datafusion_expr::expr::FieldMetadata;
4444
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
4545
use datafusion_expr::{
46-
lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody,
47-
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
48-
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
46+
lit_with_metadata, udf_equals_hash, Accumulator, ColumnarValue, CreateFunction,
47+
CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs,
48+
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
4949
};
5050
use datafusion_functions_nested::range::range_udf;
5151
use parking_lot::Mutex;
@@ -181,6 +181,7 @@ async fn scalar_udf() -> Result<()> {
181181
Ok(())
182182
}
183183

184+
#[derive(PartialEq, Hash)]
184185
struct Simple0ArgsScalarUDF {
185186
name: String,
186187
signature: Signature,
@@ -218,33 +219,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
218219
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
219220
}
220221

221-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
222-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
223-
return false;
224-
};
225-
let Self {
226-
name,
227-
signature,
228-
return_type,
229-
} = self;
230-
name == &other.name
231-
&& signature == &other.signature
232-
&& return_type == &other.return_type
233-
}
234-
235-
fn hash_value(&self) -> u64 {
236-
let Self {
237-
name,
238-
signature,
239-
return_type,
240-
} = self;
241-
let mut hasher = DefaultHasher::new();
242-
std::any::type_name::<Self>().hash(&mut hasher);
243-
name.hash(&mut hasher);
244-
signature.hash(&mut hasher);
245-
return_type.hash(&mut hasher);
246-
hasher.finish()
247-
}
222+
udf_equals_hash!(ScalarUDFImpl);
248223
}
249224

250225
#[tokio::test]
@@ -517,7 +492,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
517492
}
518493

519494
/// Volatile UDF that should append a different value to each row
520-
#[derive(Debug)]
495+
#[derive(Debug, PartialEq, Hash)]
521496
struct AddIndexToStringVolatileScalarUDF {
522497
name: String,
523498
signature: Signature,
@@ -586,33 +561,7 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
586561
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
587562
}
588563

589-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
590-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
591-
return false;
592-
};
593-
let Self {
594-
name,
595-
signature,
596-
return_type,
597-
} = self;
598-
name == &other.name
599-
&& signature == &other.signature
600-
&& return_type == &other.return_type
601-
}
602-
603-
fn hash_value(&self) -> u64 {
604-
let Self {
605-
name,
606-
signature,
607-
return_type,
608-
} = self;
609-
let mut hasher = DefaultHasher::new();
610-
std::any::type_name::<Self>().hash(&mut hasher);
611-
name.hash(&mut hasher);
612-
signature.hash(&mut hasher);
613-
return_type.hash(&mut hasher);
614-
hasher.finish()
615-
}
564+
udf_equals_hash!(ScalarUDFImpl);
616565
}
617566

618567
#[tokio::test]
@@ -992,7 +941,7 @@ impl FunctionFactory for CustomFunctionFactory {
992941
//
993942
// it also defines custom [ScalarUDFImpl::simplify()]
994943
// to replace ScalarUDF expression with one instance contains.
995-
#[derive(Debug)]
944+
#[derive(Debug, PartialEq, Hash)]
996945
struct ScalarFunctionWrapper {
997946
name: String,
998947
expr: Expr,
@@ -1031,37 +980,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
1031980
Ok(ExprSimplifyResult::Simplified(replacement))
1032981
}
1033982

1034-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1035-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
1036-
return false;
1037-
};
1038-
let Self {
1039-
name,
1040-
expr,
1041-
signature,
1042-
return_type,
1043-
} = self;
1044-
name == &other.name
1045-
&& expr == &other.expr
1046-
&& signature == &other.signature
1047-
&& return_type == &other.return_type
1048-
}
1049-
1050-
fn hash_value(&self) -> u64 {
1051-
let Self {
1052-
name,
1053-
expr,
1054-
signature,
1055-
return_type,
1056-
} = self;
1057-
let mut hasher = DefaultHasher::new();
1058-
std::any::type_name::<Self>().hash(&mut hasher);
1059-
name.hash(&mut hasher);
1060-
expr.hash(&mut hasher);
1061-
signature.hash(&mut hasher);
1062-
return_type.hash(&mut hasher);
1063-
hasher.finish()
1064-
}
983+
udf_equals_hash!(ScalarUDFImpl);
1065984
}
1066985

1067986
impl ScalarFunctionWrapper {
@@ -1296,6 +1215,21 @@ struct MyRegexUdf {
12961215
regex: Regex,
12971216
}
12981217

1218+
impl PartialEq for MyRegexUdf {
1219+
fn eq(&self, other: &Self) -> bool {
1220+
let Self { signature, regex } = self;
1221+
signature == &other.signature && regex.as_str() == other.regex.as_str()
1222+
}
1223+
}
1224+
1225+
impl Hash for MyRegexUdf {
1226+
fn hash<H: Hasher>(&self, state: &mut H) {
1227+
let Self { signature, regex } = self;
1228+
signature.hash(state);
1229+
regex.as_str().hash(state);
1230+
}
1231+
}
1232+
12991233
impl MyRegexUdf {
13001234
fn new(pattern: &str) -> Self {
13011235
Self {
@@ -1348,19 +1282,7 @@ impl ScalarUDFImpl for MyRegexUdf {
13481282
}
13491283
}
13501284

1351-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1352-
if let Some(other) = other.as_any().downcast_ref::<MyRegexUdf>() {
1353-
self.regex.as_str() == other.regex.as_str()
1354-
} else {
1355-
false
1356-
}
1357-
}
1358-
1359-
fn hash_value(&self) -> u64 {
1360-
let hasher = &mut DefaultHasher::new();
1361-
self.regex.as_str().hash(hasher);
1362-
hasher.finish()
1363-
}
1285+
udf_equals_hash!(ScalarUDFImpl);
13641286
}
13651287

13661288
#[tokio::test]
@@ -1458,13 +1380,25 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordB
14581380
ctx.sql(sql).await?.collect().await
14591381
}
14601382

1461-
#[derive(Debug)]
1383+
#[derive(Debug, PartialEq)]
14621384
struct MetadataBasedUdf {
14631385
name: String,
14641386
signature: Signature,
14651387
metadata: HashMap<String, String>,
14661388
}
14671389

1390+
impl Hash for MetadataBasedUdf {
1391+
fn hash<H: Hasher>(&self, state: &mut H) {
1392+
let Self {
1393+
name,
1394+
signature,
1395+
metadata: _, // unhashable
1396+
} = self;
1397+
name.hash(state);
1398+
signature.hash(state);
1399+
}
1400+
}
1401+
14681402
impl MetadataBasedUdf {
14691403
fn new(metadata: HashMap<String, String>) -> Self {
14701404
// The name we return must be unique. Otherwise we will not call distinct
@@ -1537,32 +1471,7 @@ impl ScalarUDFImpl for MetadataBasedUdf {
15371471
}
15381472
}
15391473

1540-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1541-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
1542-
return false;
1543-
};
1544-
let Self {
1545-
name,
1546-
signature,
1547-
metadata,
1548-
} = self;
1549-
name == &other.name
1550-
&& signature == &other.signature
1551-
&& metadata == &other.metadata
1552-
}
1553-
1554-
fn hash_value(&self) -> u64 {
1555-
let Self {
1556-
name,
1557-
signature,
1558-
metadata: _, // unhashable
1559-
} = self;
1560-
let mut hasher = DefaultHasher::new();
1561-
std::any::type_name::<Self>().hash(&mut hasher);
1562-
name.hash(&mut hasher);
1563-
signature.hash(&mut hasher);
1564-
hasher.finish()
1565-
}
1474+
udf_equals_hash!(ScalarUDFImpl);
15661475
}
15671476

15681477
#[tokio::test]

datafusion/expr/src/async_udf.rs

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
18+
use crate::utils::{arc_ptr_eq, arc_ptr_hash};
19+
use crate::{
20+
udf_equals_hash, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
21+
};
1922
use arrow::array::ArrayRef;
2023
use arrow::datatypes::{DataType, FieldRef};
2124
use async_trait::async_trait;
@@ -26,7 +29,7 @@ use datafusion_expr_common::columnar_value::ColumnarValue;
2629
use datafusion_expr_common::signature::Signature;
2730
use std::any::Any;
2831
use std::fmt::{Debug, Display};
29-
use std::hash::{DefaultHasher, Hash, Hasher};
32+
use std::hash::{Hash, Hasher};
3033
use std::sync::Arc;
3134

3235
/// A scalar UDF that can invoke using async methods
@@ -62,6 +65,21 @@ pub struct AsyncScalarUDF {
6265
inner: Arc<dyn AsyncScalarUDFImpl>,
6366
}
6467

68+
impl PartialEq for AsyncScalarUDF {
69+
fn eq(&self, other: &Self) -> bool {
70+
let Self { inner } = self;
71+
// TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting.
72+
arc_ptr_eq(inner, &other.inner)
73+
}
74+
}
75+
76+
impl Hash for AsyncScalarUDF {
77+
fn hash<H: Hasher>(&self, state: &mut H) {
78+
let Self { inner } = self;
79+
arc_ptr_hash(inner, state);
80+
}
81+
}
82+
6583
impl AsyncScalarUDF {
6684
pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self {
6785
Self { inner }
@@ -113,21 +131,7 @@ impl ScalarUDFImpl for AsyncScalarUDF {
113131
internal_err!("async functions should not be called directly")
114132
}
115133

116-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
117-
let Some(other) = other.as_any().downcast_ref::<Self>() else {
118-
return false;
119-
};
120-
let Self { inner } = self;
121-
// TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting
122-
Arc::ptr_eq(inner, &other.inner)
123-
}
124-
125-
fn hash_value(&self) -> u64 {
126-
let Self { inner } = self;
127-
let mut hasher = DefaultHasher::new();
128-
Arc::as_ptr(inner).hash(&mut hasher);
129-
hasher.finish()
130-
}
134+
udf_equals_hash!(ScalarUDFImpl);
131135
}
132136

133137
impl Display for AsyncScalarUDF {

0 commit comments

Comments
 (0)