Skip to content

Commit 8966dc0

Browse files
authored
Implementation of array_intersect (#8081)
* Initial Implementation of array_intersect Signed-off-by: veeupup <[email protected]> * fix comments Signed-off-by: veeupup <[email protected]> x --------- Signed-off-by: veeupup <[email protected]>
1 parent 4e8777d commit 8966dc0

File tree

11 files changed

+238
-6
lines changed

11 files changed

+238
-6
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ pub enum BuiltinScalarFunction {
174174
ArraySlice,
175175
/// array_to_string
176176
ArrayToString,
177+
/// array_intersect
178+
ArrayIntersect,
177179
/// cardinality
178180
Cardinality,
179181
/// construct an array from columns
@@ -398,6 +400,7 @@ impl BuiltinScalarFunction {
398400
BuiltinScalarFunction::Flatten => Volatility::Immutable,
399401
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
400402
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
403+
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
401404
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
402405
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
403406
BuiltinScalarFunction::Ascii => Volatility::Immutable,
@@ -577,6 +580,7 @@ impl BuiltinScalarFunction {
577580
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
578581
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
579582
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
583+
BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()),
580584
BuiltinScalarFunction::Cardinality => Ok(UInt64),
581585
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
582586
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
@@ -880,6 +884,7 @@ impl BuiltinScalarFunction {
880884
BuiltinScalarFunction::ArrayToString => {
881885
Signature::variadic_any(self.volatility())
882886
}
887+
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
883888
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
884889
BuiltinScalarFunction::MakeArray => {
885890
// 0 or more arguments of arbitrary type
@@ -1505,6 +1510,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
15051510
],
15061511
BuiltinScalarFunction::Cardinality => &["cardinality"],
15071512
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
1513+
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"],
15081514

15091515
// struct functions
15101516
BuiltinScalarFunction::Struct => &["struct"],

datafusion/expr/src/expr_fn.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,12 @@ nary_scalar_expr!(
715715
array,
716716
"returns an Arrow array using the specified input expressions."
717717
);
718+
scalar_expr!(
719+
ArrayIntersect,
720+
array_intersect,
721+
first_array second_array,
722+
"Returns an array of the elements in the intersection of array1 and array2."
723+
);
718724

719725
// string functions
720726
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use arrow::array::*;
2424
use arrow::buffer::OffsetBuffer;
2525
use arrow::compute;
2626
use arrow::datatypes::{DataType, Field, UInt64Type};
27+
use arrow::row::{RowConverter, SortField};
2728
use arrow_buffer::NullBuffer;
2829

2930
use datafusion_common::cast::{
@@ -35,6 +36,7 @@ use datafusion_common::{
3536
DataFusionError, Result,
3637
};
3738

39+
use hashbrown::HashSet;
3840
use itertools::Itertools;
3941

4042
macro_rules! downcast_arg {
@@ -347,7 +349,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
347349
let data_type = arrays[0].data_type();
348350
let field = Arc::new(Field::new("item", data_type.to_owned(), true));
349351
let elements = arrays.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
350-
let values = arrow::compute::concat(elements.as_slice())?;
352+
let values = compute::concat(elements.as_slice())?;
351353
let list_arr = ListArray::new(
352354
field,
353355
OffsetBuffer::from_lengths(array_lengths),
@@ -368,7 +370,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result<ArrayRef> {
368370
.iter()
369371
.map(|x| x as &dyn Array)
370372
.collect::<Vec<_>>();
371-
let values = arrow::compute::concat(elements.as_slice())?;
373+
let values = compute::concat(elements.as_slice())?;
372374
let list_arr = ListArray::new(
373375
field,
374376
OffsetBuffer::from_lengths(list_array_lengths),
@@ -767,7 +769,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
767769
.collect::<Vec<&dyn Array>>();
768770

769771
// Concatenated array on i-th row
770-
let concated_array = arrow::compute::concat(elements.as_slice())?;
772+
let concated_array = compute::concat(elements.as_slice())?;
771773
array_lengths.push(concated_array.len());
772774
arrays.push(concated_array);
773775
valid.append(true);
@@ -785,7 +787,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
785787
let list_arr = ListArray::new(
786788
Arc::new(Field::new("item", data_type, true)),
787789
OffsetBuffer::from_lengths(array_lengths),
788-
Arc::new(arrow::compute::concat(elements.as_slice())?),
790+
Arc::new(compute::concat(elements.as_slice())?),
789791
Some(NullBuffer::new(buffer)),
790792
);
791793
Ok(Arc::new(list_arr))
@@ -879,7 +881,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef
879881
}
880882

881883
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
882-
let values = arrow::compute::concat(&new_values)?;
884+
let values = compute::concat(&new_values)?;
883885

884886
Ok(Arc::new(ListArray::try_new(
885887
Arc::new(Field::new("item", data_type.to_owned(), true)),
@@ -947,7 +949,7 @@ fn general_list_repeat(
947949

948950
let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
949951
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
950-
let values = arrow::compute::concat(&new_values)?;
952+
let values = compute::concat(&new_values)?;
951953

952954
Ok(Arc::new(ListArray::try_new(
953955
Arc::new(Field::new("item", data_type.to_owned(), true)),
@@ -1798,6 +1800,61 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
17981800
Ok(Arc::new(list_array) as ArrayRef)
17991801
}
18001802

1803+
/// array_intersect SQL function
1804+
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
1805+
assert_eq!(args.len(), 2);
1806+
1807+
let first_array = as_list_array(&args[0])?;
1808+
let second_array = as_list_array(&args[1])?;
1809+
1810+
if first_array.value_type() != second_array.value_type() {
1811+
return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'");
1812+
}
1813+
let dt = first_array.value_type().clone();
1814+
1815+
let mut offsets = vec![0];
1816+
let mut new_arrays = vec![];
1817+
1818+
let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
1819+
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
1820+
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
1821+
let l_values = converter.convert_columns(&[first_arr])?;
1822+
let r_values = converter.convert_columns(&[second_arr])?;
1823+
1824+
let values_set: HashSet<_> = l_values.iter().collect();
1825+
let mut rows = Vec::with_capacity(r_values.num_rows());
1826+
for r_val in r_values.iter().sorted().dedup() {
1827+
if values_set.contains(&r_val) {
1828+
rows.push(r_val);
1829+
}
1830+
}
1831+
1832+
let last_offset: i32 = match offsets.last().copied() {
1833+
Some(offset) => offset,
1834+
None => return internal_err!("offsets should not be empty"),
1835+
};
1836+
offsets.push(last_offset + rows.len() as i32);
1837+
let arrays = converter.convert_rows(rows)?;
1838+
let array = match arrays.get(0) {
1839+
Some(array) => array.clone(),
1840+
None => {
1841+
return internal_err!(
1842+
"array_intersect: failed to get array from rows"
1843+
)
1844+
}
1845+
};
1846+
new_arrays.push(array);
1847+
}
1848+
}
1849+
1850+
let field = Arc::new(Field::new("item", dt, true));
1851+
let offsets = OffsetBuffer::new(offsets.into());
1852+
let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
1853+
let values = compute::concat(&new_arrays_ref)?;
1854+
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
1855+
Ok(arr)
1856+
}
1857+
18011858
#[cfg(test)]
18021859
mod tests {
18031860
use super::*;

datafusion/physical-expr/src/functions.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,9 @@ pub fn create_physical_fun(
398398
BuiltinScalarFunction::ArrayToString => Arc::new(|args| {
399399
make_scalar_function(array_expressions::array_to_string)(args)
400400
}),
401+
BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| {
402+
make_scalar_function(array_expressions::array_intersect)(args)
403+
}),
401404
BuiltinScalarFunction::Cardinality => {
402405
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
403406
}

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ enum ScalarFunction {
621621
ArrayPopBack = 116;
622622
StringToArray = 117;
623623
ToTimestampNanos = 118;
624+
ArrayIntersect = 119;
624625
}
625626

626627
message ScalarFunctionNode {

datafusion/proto/src/generated/pbjson.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
483483
ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
484484
ScalarFunction::ArraySlice => Self::ArraySlice,
485485
ScalarFunction::ArrayToString => Self::ArrayToString,
486+
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
486487
ScalarFunction::Cardinality => Self::Cardinality,
487488
ScalarFunction::Array => Self::MakeArray,
488489
ScalarFunction::NullIf => Self::NullIf,

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,6 +1481,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
14811481
BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
14821482
BuiltinScalarFunction::ArraySlice => Self::ArraySlice,
14831483
BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
1484+
BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
14841485
BuiltinScalarFunction::Cardinality => Self::Cardinality,
14851486
BuiltinScalarFunction::MakeArray => Self::Array,
14861487
BuiltinScalarFunction::NullIf => Self::NullIf,

0 commit comments

Comments
 (0)