Skip to content

Commit d9ea729

Browse files
committed
Implement equals for stateful functions (apache#16781)
* Implement equals for stateful functions Default implementation of `ScalarUDFImpl::equals`, `AggregateUDFImpl::equals` and `WindowUDFImpl::equals` is correct for stateless functions and those which only state is the `Signature`, which is most of the functions. This implements `equals` and `hash_value` for functions that have state other than `Signature` object. This fixes correctness issues which could occur when such stateful functions are used together in one query. * downgrade for MSRV * Improve doc * Update default UDF:: equals to compare aliases too * Update default UDF:: equals to compare type too (‼️) * remove now-obsoleted UDF equals/hash customizations remove these which compare signature, aliases, as the default handles these now * remove equals impl which compares name, signature -- default does that * cleanup imports (cherry picked from commit afd8235)
1 parent 0935058 commit d9ea729

File tree

18 files changed

+572
-22
lines changed

18 files changed

+572
-22
lines changed

datafusion-examples/examples/function_factory.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use datafusion::logical_expr::{
2828
ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
2929
Signature, Volatility,
3030
};
31+
use std::hash::{DefaultHasher, Hash, Hasher};
3132
use std::result::Result as RResult;
3233
use std::sync::Arc;
3334

@@ -157,6 +158,38 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
157158
fn output_ordering(&self, _input: &[ExprProperties]) -> Result<SortProperties> {
158159
Ok(SortProperties::Unordered)
159160
}
161+
162+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
163+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
164+
return false;
165+
};
166+
let Self {
167+
name,
168+
expr,
169+
signature,
170+
return_type,
171+
} = self;
172+
name == &other.name
173+
&& expr == &other.expr
174+
&& signature == &other.signature
175+
&& return_type == &other.return_type
176+
}
177+
178+
fn hash_value(&self) -> u64 {
179+
let Self {
180+
name,
181+
expr,
182+
signature,
183+
return_type,
184+
} = self;
185+
let mut hasher = DefaultHasher::new();
186+
std::any::type_name::<Self>().hash(&mut hasher);
187+
name.hash(&mut hasher);
188+
expr.hash(&mut hasher);
189+
signature.hash(&mut hasher);
190+
return_type.hash(&mut hasher);
191+
hasher.finish()
192+
}
160193
}
161194

162195
impl ScalarFunctionWrapper {

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,34 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
215215
) -> Result<ColumnarValue> {
216216
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
217217
}
218+
219+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
220+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
221+
return false;
222+
};
223+
let Self {
224+
name,
225+
signature,
226+
return_type,
227+
} = self;
228+
name == &other.name
229+
&& signature == &other.signature
230+
&& return_type == &other.return_type
231+
}
232+
233+
fn hash_value(&self) -> u64 {
234+
let Self {
235+
name,
236+
signature,
237+
return_type,
238+
} = self;
239+
let mut hasher = DefaultHasher::new();
240+
std::any::type_name::<Self>().hash(&mut hasher);
241+
name.hash(&mut hasher);
242+
signature.hash(&mut hasher);
243+
return_type.hash(&mut hasher);
244+
hasher.finish()
245+
}
218246
}
219247

220248
#[tokio::test]
@@ -556,6 +584,34 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
556584
};
557585
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
558586
}
587+
588+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
589+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
590+
return false;
591+
};
592+
let Self {
593+
name,
594+
signature,
595+
return_type,
596+
} = self;
597+
name == &other.name
598+
&& signature == &other.signature
599+
&& return_type == &other.return_type
600+
}
601+
602+
fn hash_value(&self) -> u64 {
603+
let Self {
604+
name,
605+
signature,
606+
return_type,
607+
} = self;
608+
let mut hasher = DefaultHasher::new();
609+
std::any::type_name::<Self>().hash(&mut hasher);
610+
name.hash(&mut hasher);
611+
signature.hash(&mut hasher);
612+
return_type.hash(&mut hasher);
613+
hasher.finish()
614+
}
559615
}
560616

561617
#[tokio::test]
@@ -985,6 +1041,38 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
9851041
fn aliases(&self) -> &[String] {
9861042
&[]
9871043
}
1044+
1045+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
1046+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
1047+
return false;
1048+
};
1049+
let Self {
1050+
name,
1051+
expr,
1052+
signature,
1053+
return_type,
1054+
} = self;
1055+
name == &other.name
1056+
&& expr == &other.expr
1057+
&& signature == &other.signature
1058+
&& return_type == &other.return_type
1059+
}
1060+
1061+
fn hash_value(&self) -> u64 {
1062+
let Self {
1063+
name,
1064+
expr,
1065+
signature,
1066+
return_type,
1067+
} = self;
1068+
let mut hasher = DefaultHasher::new();
1069+
std::any::type_name::<Self>().hash(&mut hasher);
1070+
name.hash(&mut hasher);
1071+
expr.hash(&mut hasher);
1072+
signature.hash(&mut hasher);
1073+
return_type.hash(&mut hasher);
1074+
hasher.finish()
1075+
}
9881076
}
9891077

9901078
impl ScalarFunctionWrapper {

datafusion/doc/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
/// thus all text should be in English.
3434
///
3535
/// [SQL function documentation]: https://datafusion.apache.org/user-guide/sql/index.html
36-
#[derive(Debug, Clone)]
36+
#[derive(Debug, Clone, PartialEq, Hash)]
3737
pub struct Documentation {
3838
/// The section in the documentation where the UDF will be documented
3939
pub doc_section: DocSection,
@@ -153,7 +153,7 @@ impl Documentation {
153153
}
154154
}
155155

156-
#[derive(Debug, Clone, PartialEq)]
156+
#[derive(Debug, Clone, PartialEq, Hash)]
157157
pub struct DocSection {
158158
/// True to include this doc section in the public
159159
/// documentation, false otherwise

datafusion/expr/src/expr_fn.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
4343
use sqlparser::ast::NullTreatment;
4444
use std::any::Any;
4545
use std::fmt::Debug;
46+
use std::hash::{DefaultHasher, Hash, Hasher};
4647
use std::ops::Not;
4748
use std::sync::Arc;
4849

@@ -484,6 +485,38 @@ impl ScalarUDFImpl for SimpleScalarUDF {
484485
) -> Result<ColumnarValue> {
485486
(self.fun)(args)
486487
}
488+
489+
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
490+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
491+
return false;
492+
};
493+
let Self {
494+
name,
495+
signature,
496+
return_type,
497+
fun,
498+
} = self;
499+
name == &other.name
500+
&& signature == &other.signature
501+
&& return_type == &other.return_type
502+
&& Arc::ptr_eq(fun, &other.fun)
503+
}
504+
505+
fn hash_value(&self) -> u64 {
506+
let Self {
507+
name,
508+
signature,
509+
return_type,
510+
fun,
511+
} = self;
512+
let mut hasher = DefaultHasher::new();
513+
std::any::type_name::<Self>().hash(&mut hasher);
514+
name.hash(&mut hasher);
515+
signature.hash(&mut hasher);
516+
return_type.hash(&mut hasher);
517+
Arc::as_ptr(fun).hash(&mut hasher);
518+
hasher.finish()
519+
}
487520
}
488521

489522
/// Creates a new UDAF with a specific signature, state type and return type.
@@ -603,6 +636,42 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
603636
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
604637
Ok(self.state_fields.clone())
605638
}
639+
640+
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
641+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
642+
return false;
643+
};
644+
let Self {
645+
name,
646+
signature,
647+
return_type,
648+
accumulator,
649+
state_fields,
650+
} = self;
651+
name == &other.name
652+
&& signature == &other.signature
653+
&& return_type == &other.return_type
654+
&& Arc::ptr_eq(accumulator, &other.accumulator)
655+
&& state_fields == &other.state_fields
656+
}
657+
658+
fn hash_value(&self) -> u64 {
659+
let Self {
660+
name,
661+
signature,
662+
return_type,
663+
accumulator,
664+
state_fields,
665+
} = self;
666+
let mut hasher = DefaultHasher::new();
667+
std::any::type_name::<Self>().hash(&mut hasher);
668+
name.hash(&mut hasher);
669+
signature.hash(&mut hasher);
670+
return_type.hash(&mut hasher);
671+
Arc::as_ptr(accumulator).hash(&mut hasher);
672+
state_fields.hash(&mut hasher);
673+
hasher.finish()
674+
}
606675
}
607676

608677
/// Creates a new UDWF with a specific signature, state type and return type.
@@ -695,6 +764,41 @@ impl WindowUDFImpl for SimpleWindowUDF {
695764
true,
696765
))
697766
}
767+
768+
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
769+
let Some(other) = other.as_any().downcast_ref::<Self>() else {
770+
return false;
771+
};
772+
let Self {
773+
name,
774+
signature,
775+
return_type,
776+
partition_evaluator_factory,
777+
} = self;
778+
name == &other.name
779+
&& signature == &other.signature
780+
&& return_type == &other.return_type
781+
&& Arc::ptr_eq(
782+
partition_evaluator_factory,
783+
&other.partition_evaluator_factory,
784+
)
785+
}
786+
787+
fn hash_value(&self) -> u64 {
788+
let Self {
789+
name,
790+
signature,
791+
return_type,
792+
partition_evaluator_factory,
793+
} = self;
794+
let mut hasher = DefaultHasher::new();
795+
std::any::type_name::<Self>().hash(&mut hasher);
796+
name.hash(&mut hasher);
797+
signature.hash(&mut hasher);
798+
return_type.hash(&mut hasher);
799+
Arc::as_ptr(partition_evaluator_factory).hash(&mut hasher);
800+
hasher.finish()
801+
}
698802
}
699803

700804
pub fn interval_year_month_lit(value: &str) -> Expr {

datafusion/expr/src/udaf.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -574,26 +574,35 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
574574
/// Return true if this aggregate UDF is equal to the other.
575575
///
576576
/// Allows customizing the equality of aggregate UDFs.
577+
/// *Must* be implemented explicitly if the UDF type has internal state.
577578
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
578579
///
579580
/// - reflexive: `a.equals(a)`;
580581
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
581582
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
582583
///
583-
/// By default, compares [`Self::name`] and [`Self::signature`].
584+
/// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
584585
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
585-
self.name() == other.name() && self.signature() == other.signature()
586+
self.as_any().type_id() == other.as_any().type_id()
587+
&& self.name() == other.name()
588+
&& self.aliases() == other.aliases()
589+
&& self.signature() == other.signature()
586590
}
587591

588592
/// Returns a hash value for this aggregate UDF.
589593
///
590-
/// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`],
591-
/// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
594+
/// Allows customizing the hash code of aggregate UDFs.
595+
/// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
592596
///
593-
/// By default, hashes [`Self::name`] and [`Self::signature`].
597+
/// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
598+
/// their `hash_value`s must be the same.
599+
///
600+
/// By default, it is consistent with default implementation of [`Self::equals`].
594601
fn hash_value(&self) -> u64 {
595602
let hasher = &mut DefaultHasher::new();
603+
self.as_any().type_id().hash(hasher);
596604
self.name().hash(hasher);
605+
self.aliases().hash(hasher);
597606
self.signature().hash(hasher);
598607
hasher.finish()
599608
}

datafusion/expr/src/udf.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -801,26 +801,35 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
801801
/// Return true if this scalar UDF is equal to the other.
802802
///
803803
/// Allows customizing the equality of scalar UDFs.
804+
/// *Must* be implemented explicitly if the UDF type has internal state.
804805
/// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]:
805806
///
806807
/// - reflexive: `a.equals(a)`;
807808
/// - symmetric: `a.equals(b)` implies `b.equals(a)`;
808809
/// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`.
809810
///
810-
/// By default, compares [`Self::name`] and [`Self::signature`].
811+
/// By default, compares type, [`Self::name`], [`Self::aliases`] and [`Self::signature`].
811812
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
812-
self.name() == other.name() && self.signature() == other.signature()
813+
self.as_any().type_id() == other.as_any().type_id()
814+
&& self.name() == other.name()
815+
&& self.aliases() == other.aliases()
816+
&& self.signature() == other.signature()
813817
}
814818

815819
/// Returns a hash value for this scalar UDF.
816820
///
817-
/// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`],
818-
/// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same.
821+
/// Allows customizing the hash code of scalar UDFs.
822+
/// *Must* be implemented explicitly whenever [`Self::equals`] is implemented.
819823
///
820-
/// By default, hashes [`Self::name`] and [`Self::signature`].
824+
/// Similarly to [`Hash`] and [`Eq`], if [`Self::equals`] returns true for two UDFs,
825+
/// their `hash_value`s must be the same.
826+
///
827+
/// By default, it is consistent with default implementation of [`Self::equals`].
821828
fn hash_value(&self) -> u64 {
822829
let hasher = &mut DefaultHasher::new();
830+
self.as_any().type_id().hash(hasher);
823831
self.name().hash(hasher);
832+
self.aliases().hash(hasher);
824833
self.signature().hash(hasher);
825834
hasher.finish()
826835
}
@@ -951,6 +960,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
951960

952961
fn hash_value(&self) -> u64 {
953962
let hasher = &mut DefaultHasher::new();
963+
std::any::type_name::<Self>().hash(hasher);
954964
self.inner.hash_value().hash(hasher);
955965
self.aliases.hash(hasher);
956966
hasher.finish()

0 commit comments

Comments
 (0)