Skip to content

Commit afd8235

Browse files
authored
Implement equals for stateful functions (apache#16781)
* Implement equals for stateful functions The default implementation of `ScalarUDFImpl::equals`, `AggregateUDFImpl::equals` and `WindowUDFImpl::equals` is correct for stateless functions and for those whose only state is the `Signature`, which covers most functions. This implements `equals` and `hash_value` for functions that have state other than the `Signature` object. This fixes correctness issues which could occur when such stateful functions are used together in one query. * downgrade for MSRV * Improve doc * Update default UDF::equals to compare aliases too * Update default UDF::equals to compare type too (‼️) * remove now-obsolete UDF equals/hash customizations: remove those which compare signature and aliases, as the default handles these now * remove equals impl which compares name, signature -- default does that * cleanup imports
1 parent d4d5bfd commit afd8235

File tree

25 files changed

+862
-31
lines changed

25 files changed

+862
-31
lines changed

datafusion-examples/examples/function_factory.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use datafusion::logical_expr::{
2828
ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
2929
Signature, Volatility,
3030
};
31+
use std::hash::{DefaultHasher, Hash, Hasher};
3132
use std::result::Result as RResult;
3233
use std::sync::Arc;
3334

@@ -153,6 +154,38 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
153154
/// Reports the output as unordered; the input expression properties are ignored.
fn output_ordering(&self, _input: &[ExprProperties]) -> Result<SortProperties> {
154155
Ok(SortProperties::Unordered)
155156
}
157+
158+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
159+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
160+
return false;
161+
};
162+
let Self {
163+
name,
164+
expr,
165+
signature,
166+
return_type,
167+
} = self;
168+
name == &other.name
169+
&& expr == &other.expr
170+
&& signature == &other.signature
171+
&& return_type == &other.return_type
172+
}
173+
174+
fn hash_value(&self) -> u64 {
175+
let Self {
176+
name,
177+
expr,
178+
signature,
179+
return_type,
180+
} = self;
181+
let mut hasher = DefaultHasher::new();
182+
std::any::type_name::<Self>().hash(&mut hasher);
183+
name.hash(&mut hasher);
184+
expr.hash(&mut hasher);
185+
signature.hash(&mut hasher);
186+
return_type.hash(&mut hasher);
187+
hasher.finish()
188+
}
156189
}
157190

158191
impl ScalarFunctionWrapper {

datafusion/core/tests/user_defined/user_defined_aggregates.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,33 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
957957
curr_sum: 0,
958958
}))
959959
}
960+
961+
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
962+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
963+
return false;
964+
};
965+
let Self {
966+
name,
967+
signature,
968+
metadata,
969+
} = self;
970+
name == &other.name
971+
&& signature == &other.signature
972+
&& metadata == &other.metadata
973+
}
974+
975+
fn hash_value(&self) -> u64 {
976+
let Self {
977+
name,
978+
signature,
979+
metadata: _, // unhashable
980+
} = self;
981+
let mut hasher = DefaultHasher::new();
982+
std::any::type_name::<Self>().hash(&mut hasher);
983+
name.hash(&mut hasher);
984+
signature.hash(&mut hasher);
985+
hasher.finish()
986+
}
960987
}
961988

962989
#[derive(Debug)]

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,34 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
217217
/// Ignores its arguments and always returns the scalar `Int32` value 100.
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
218218
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
219219
}
220+
221+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
222+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
223+
return false;
224+
};
225+
let Self {
226+
name,
227+
signature,
228+
return_type,
229+
} = self;
230+
name == &other.name
231+
&& signature == &other.signature
232+
&& return_type == &other.return_type
233+
}
234+
235+
fn hash_value(&self) -> u64 {
236+
let Self {
237+
name,
238+
signature,
239+
return_type,
240+
} = self;
241+
let mut hasher = DefaultHasher::new();
242+
std::any::type_name::<Self>().hash(&mut hasher);
243+
name.hash(&mut hasher);
244+
signature.hash(&mut hasher);
245+
return_type.hash(&mut hasher);
246+
hasher.finish()
247+
}
220248
}
221249

222250
#[tokio::test]
@@ -557,6 +585,34 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
557585
};
558586
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
559587
}
588+
589+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
590+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
591+
return false;
592+
};
593+
let Self {
594+
name,
595+
signature,
596+
return_type,
597+
} = self;
598+
name == &other.name
599+
&& signature == &other.signature
600+
&& return_type == &other.return_type
601+
}
602+
603+
fn hash_value(&self) -> u64 {
604+
let Self {
605+
name,
606+
signature,
607+
return_type,
608+
} = self;
609+
let mut hasher = DefaultHasher::new();
610+
std::any::type_name::<Self>().hash(&mut hasher);
611+
name.hash(&mut hasher);
612+
signature.hash(&mut hasher);
613+
return_type.hash(&mut hasher);
614+
hasher.finish()
615+
}
560616
}
561617

562618
#[tokio::test]
@@ -974,6 +1030,38 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
9741030

9751031
Ok(ExprSimplifyResult::Simplified(replacement))
9761032
}
1033+
1034+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1035+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
1036+
return false;
1037+
};
1038+
let Self {
1039+
name,
1040+
expr,
1041+
signature,
1042+
return_type,
1043+
} = self;
1044+
name == &other.name
1045+
&& expr == &other.expr
1046+
&& signature == &other.signature
1047+
&& return_type == &other.return_type
1048+
}
1049+
1050+
fn hash_value(&self) -> u64 {
1051+
let Self {
1052+
name,
1053+
expr,
1054+
signature,
1055+
return_type,
1056+
} = self;
1057+
let mut hasher = DefaultHasher::new();
1058+
std::any::type_name::<Self>().hash(&mut hasher);
1059+
name.hash(&mut hasher);
1060+
expr.hash(&mut hasher);
1061+
signature.hash(&mut hasher);
1062+
return_type.hash(&mut hasher);
1063+
hasher.finish()
1064+
}
9771065
}
9781066

9791067
impl ScalarFunctionWrapper {
@@ -1450,7 +1538,30 @@ impl ScalarUDFImpl for MetadataBasedUdf {
14501538
}
14511539

14521540
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1453-
self.name == other.name()
1541+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
1542+
return false;
1543+
};
1544+
let Self {
1545+
name,
1546+
signature,
1547+
metadata,
1548+
} = self;
1549+
name == &other.name
1550+
&& signature == &other.signature
1551+
&& metadata == &other.metadata
1552+
}
1553+
1554+
fn hash_value(&self) -> u64 {
1555+
let Self {
1556+
name,
1557+
signature,
1558+
metadata: _, // unhashable
1559+
} = self;
1560+
let mut hasher = DefaultHasher::new();
1561+
std::any::type_name::<Self>().hash(&mut hasher);
1562+
name.hash(&mut hasher);
1563+
signature.hash(&mut hasher);
1564+
hasher.finish()
14541565
}
14551566
}
14561567

@@ -1669,10 +1780,6 @@ impl ScalarUDFImpl for ExtensionBasedUdf {
16691780
}
16701781
}
16711782
}
1672-
1673-
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1674-
self.name == other.name()
1675-
}
16761783
}
16771784

16781785
struct MyUserExtentionType {}

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use datafusion_physical_expr::{
4040
PhysicalExpr,
4141
};
4242
use std::collections::HashMap;
43+
use std::hash::{DefaultHasher, Hash, Hasher};
4344
use std::{
4445
any::Any,
4546
ops::Range,
@@ -568,6 +569,33 @@ impl OddCounter {
568569
/// Builds the output field: a nullable `Int64` named after `field_args.name()`.
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
569570
Ok(Field::new(field_args.name(), DataType::Int64, true).into())
570571
}
572+
573+
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
574+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
575+
return false;
576+
};
577+
let Self {
578+
signature,
579+
test_state,
580+
aliases,
581+
} = self;
582+
signature == &other.signature
583+
&& Arc::ptr_eq(test_state, &other.test_state)
584+
&& aliases == &other.aliases
585+
}
586+
587+
fn hash_value(&self) -> u64 {
588+
let Self {
589+
signature,
590+
test_state,
591+
aliases,
592+
} = self;
593+
let mut hasher = DefaultHasher::new();
594+
signature.hash(&mut hasher);
595+
Arc::as_ptr(test_state).hash(&mut hasher);
596+
aliases.hash(&mut hasher);
597+
hasher.finish()
598+
}
571599
}
572600

573601
ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))
@@ -815,6 +843,33 @@ impl WindowUDFImpl for MetadataBasedWindowUdf {
815843
.with_metadata(self.metadata.clone())
816844
.into())
817845
}
846+
847+
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
848+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
849+
return false;
850+
};
851+
let Self {
852+
name,
853+
signature,
854+
metadata,
855+
} = self;
856+
name == &other.name
857+
&& signature == &other.signature
858+
&& metadata == &other.metadata
859+
}
860+
861+
fn hash_value(&self) -> u64 {
862+
let Self {
863+
name,
864+
signature,
865+
metadata: _, // unhashable
866+
} = self;
867+
let mut hasher = DefaultHasher::new();
868+
std::any::type_name::<Self>().hash(&mut hasher);
869+
name.hash(&mut hasher);
870+
signature.hash(&mut hasher);
871+
hasher.finish()
872+
}
818873
}
819874

820875
#[derive(Debug)]

datafusion/doc/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
/// thus all text should be in English.
4040
///
4141
/// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html
42-
#[derive(Debug, Clone)]
42+
#[derive(Debug, Clone, PartialEq, Hash)]
4343
pub struct Documentation {
4444
/// The section in the documentation where the UDF will be documented
4545
pub doc_section: DocSection,
@@ -158,7 +158,7 @@ impl Documentation {
158158
}
159159
}
160160

161-
#[derive(Debug, Clone, PartialEq)]
161+
#[derive(Debug, Clone, PartialEq, Hash)]
162162
pub struct DocSection {
163163
/// True to include this doc section in the public
164164
/// documentation, false otherwise

datafusion/expr/src/async_udf.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use datafusion_expr_common::columnar_value::ColumnarValue;
2626
use datafusion_expr_common::signature::Signature;
2727
use std::any::Any;
2828
use std::fmt::{Debug, Display};
29+
use std::hash::{DefaultHasher, Hash, Hasher};
2930
use std::sync::Arc;
3031

3132
/// A scalar UDF that can invoke using async methods
@@ -111,6 +112,22 @@ impl ScalarUDFImpl for AsyncScalarUDF {
111112
/// Always fails with an internal error: async functions are not meant to be
/// invoked through this synchronous entry point.
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
112113
internal_err!("async functions should not be called directly")
113114
}
115+
116+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
117+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
118+
return false;
119+
};
120+
let Self { inner } = self;
121+
// TODO when MSRV >= 1.86.0, switch to `inner.equals(other.inner.as_ref())` leveraging trait upcasting
122+
Arc::ptr_eq(inner, &other.inner)
123+
}
124+
125+
fn hash_value(&self) -> u64 {
126+
let Self { inner } = self;
127+
let mut hasher = DefaultHasher::new();
128+
Arc::as_ptr(inner).hash(&mut hasher);
129+
hasher.finish()
130+
}
114131
}
115132

116133
impl Display for AsyncScalarUDF {

0 commit comments

Comments
 (0)