Skip to content

Commit 522f698

Browse files
committed
Initial Implementation of array_intersect
Signed-off-by: veeupup <[email protected]>
1 parent 01d7dba commit 522f698

File tree

10 files changed

+142
-0
lines changed

10 files changed

+142
-0
lines changed

datafusion/core/tests/sqllogictests/test_files/array.slt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,6 +1695,15 @@ select array_has_all(make_array(1,2,3), make_array(1,3)),
16951695
----
16961696
true false true false false false true true false false true false true
16971697

1698+
query ????
1699+
SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)),
1700+
array_intersect(make_array(1,3,5), make_array(2,4,6)),
1701+
array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
1702+
array_intersect(make_array(true, false), make_array(true))
1703+
;
1704+
----
1705+
[2, 3] [] [cc, aa] [true]
1706+
16981707
query BBBB
16991708
select list_has_all(make_array(1,2,3), make_array(4,5,6)),
17001709
list_has_all(make_array(1,2,3), make_array(1,2)),

datafusion/expr/src/built_in_function.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ pub enum BuiltinScalarFunction {
153153
ArrayReplaceAll,
154154
/// array_to_string
155155
ArrayToString,
156+
/// array_intersect
157+
ArrayIntersect,
156158
/// cardinality
157159
Cardinality,
158160
/// construct an array from columns
@@ -359,6 +361,7 @@ impl BuiltinScalarFunction {
359361
BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable,
360362
BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable,
361363
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
364+
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
362365
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
363366
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
364367
BuiltinScalarFunction::TrimArray => Volatility::Immutable,
@@ -543,6 +546,34 @@ impl BuiltinScalarFunction {
543546
BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()),
544547
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
545548
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
549+
BuiltinScalarFunction::ArrayIntersect => {
550+
if input_expr_types.len() < 2 || input_expr_types.len() > 2 {
551+
Err(DataFusionError::Internal(format!(
552+
"The {self} function must have two arrays as parameters"
553+
)))
554+
} else {
555+
match (&input_expr_types[0], &input_expr_types[1]) {
556+
(List(l_field), List(r_field)) => {
557+
if !l_field.data_type().equals_datatype(r_field.data_type()) {
558+
Err(DataFusionError::Internal(format!(
559+
"The {self} function array data type not equal, [0]: {:?}, [1]: {:?}",
560+
l_field.data_type(), r_field.data_type()
561+
)))
562+
} else {
563+
Ok(List(Arc::new(Field::new(
564+
"item",
565+
l_field.data_type().clone(),
566+
true,
567+
))))
568+
}
569+
}
570+
_ => Err(DataFusionError::Internal(format!(
571+
"The {} parameters should be array, [0]: {:?}, [1]: {:?}",
572+
self, input_expr_types[0], input_expr_types[1]
573+
))),
574+
}
575+
}
576+
}
546577
BuiltinScalarFunction::Cardinality => Ok(UInt64),
547578
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
548579
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
@@ -834,6 +865,7 @@ impl BuiltinScalarFunction {
834865
BuiltinScalarFunction::ArrayToString => {
835866
Signature::variadic_any(self.volatility())
836867
}
868+
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
837869
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
838870
BuiltinScalarFunction::MakeArray => {
839871
Signature::variadic_any(self.volatility())
@@ -1324,6 +1356,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
13241356
BuiltinScalarFunction::Cardinality => &["cardinality"],
13251357
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
13261358
BuiltinScalarFunction::TrimArray => &["trim_array"],
1359+
// `list_intersect` mirrors `array_intersect`, following the `list_*` naming used by the other array functions.
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"],
13271360
}
13281361
}
13291362

datafusion/expr/src/expr_fn.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,12 @@ scalar_expr!(
654654
array n,
655655
"removes the last n elements from the array."
656656
);
657+
scalar_expr!(
658+
ArrayIntersect,
659+
array_intersect,
660+
first_array second_array,
661+
"Returns an array of the elements in the intersection of array1 and array2."
662+
);
657663

658664
// string functions
659665
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_a
2727
use datafusion_common::ScalarValue;
2828
use datafusion_common::{DataFusionError, Result};
2929
use datafusion_expr::ColumnarValue;
30+
use hashbrown::{HashMap, HashSet};
3031
use itertools::Itertools;
3132
use std::sync::Arc;
3233

@@ -1820,6 +1821,90 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
18201821
Ok(Arc::new(boolean_builder.finish()))
18211822
}
18221823

1824+
/// Builds the `ListArray` result of `array_intersect` for one concrete element type.
///
/// Arguments:
/// * `$FIRST_ARRAY`, `$SECOND_ARRAY` - the two list arrays, walked row by row in lockstep
/// * `$DATA_TYPE` - `DataType` of the list elements, used for the output field
/// * `$ARRAY_TYPE` - concrete array type every row is downcast to
/// * `$BUILDER` - builder type used to accumulate the output values
///
/// For each row the output list holds the distinct non-null elements of the
/// second list that also occur in the first list, in the order in which they
/// first appear in the second list. A row where either input list is null
/// produces an empty list, because the result is built without a validity
/// bitmap.
///
/// Evaluates to `Ok(Arc<ListArray>)`. It uses `?` internally, so it must be
/// expanded inside a function returning `Result`.
macro_rules! array_intersect_normal {
    ($FIRST_ARRAY:expr, $SECOND_ARRAY:expr, $DATA_TYPE:expr, $ARRAY_TYPE:ident, $BUILDER:ident) => {{
        // offsets[i]..offsets[i + 1] delimits row i inside the flat values array.
        let mut offsets: Vec<i32> = vec![0];
        let mut last_offset: i32 = 0;
        // A single builder collects the values of every row, so nothing has to
        // be re-concatenated (and re-copied) after each row.
        let mut builder = $BUILDER::new();

        for (first_row, second_row) in $FIRST_ARRAY.iter().zip($SECOND_ARRAY.iter()) {
            if let (Some(first_row), Some(second_row)) = (first_row, second_row) {
                let first_values = downcast_arg!(first_row, $ARRAY_TYPE);
                let second_values = downcast_arg!(second_row, $ARRAY_TYPE);

                // TODO(veeupup): maybe use stack-implemented map to avoid heap memory allocation
                let first_set = first_values.iter().flatten().collect::<HashSet<_>>();
                // Tracks what was already written for this row: itertools'
                // `dedup` only collapses *adjacent* duplicates, which is not
                // enough to make the result a set.
                let mut emitted = HashSet::new();

                for elem in second_values.iter().flatten() {
                    if first_set.contains(&elem) && emitted.insert(elem) {
                        builder.append_value(elem);
                        last_offset = last_offset.checked_add(1).ok_or_else(|| {
                            DataFusionError::Internal(
                                "array_intersect result exceeds the i32 offset range"
                                    .to_string(),
                            )
                        })?;
                    }
                }
            }
            // Rows with a null input contribute no values: the offset repeats.
            offsets.push(last_offset);
        }

        let field = Arc::new(Field::new("item", $DATA_TYPE, true));

        Ok(Arc::new(ListArray::try_new(
            field,
            OffsetBuffer::new(offsets.into()),
            Arc::new(builder.finish()),
            None,
        )?))
    }};
}
1876+
1877+
/// array_intersect SQL function
1878+
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
1879+
assert_eq!(args.len(), 2);
1880+
1881+
let first_array = as_list_array(&args[0])?;
1882+
let second_array = as_list_array(&args[1])?;
1883+
1884+
match (first_array.value_type(), second_array.value_type()) {
1885+
// (DataType::List(_), DataType::List(_)) => concat_internal(args)?,
1886+
(DataType::Utf8, DataType::Utf8) => array_intersect_normal!(first_array, second_array, DataType::Utf8, StringArray, StringBuilder),
1887+
(DataType::LargeUtf8, DataType::LargeUtf8) => array_intersect_normal!(first_array, second_array, DataType::LargeUtf8, LargeStringArray, LargeStringBuilder),
1888+
(DataType::Boolean, DataType::Boolean) => array_intersect_normal!(first_array, second_array, DataType::Boolean, BooleanArray, BooleanBuilder),
1889+
// (DataType::Float32, DataType::Float32) => array_intersect_normal!(arr, element, Float32Array),
1890+
// (DataType::Float64, DataType::Float64) => array_intersect_normal!(arr, element, Float64Array),
1891+
(DataType::Int8, DataType::Int8) => array_intersect_normal!(first_array, second_array, DataType::Int8, Int8Array, Int8Builder),
1892+
(DataType::Int16, DataType::Int16) => array_intersect_normal!(first_array, second_array, DataType::Int16, Int16Array, Int16Builder),
1893+
(DataType::Int32, DataType::Int32) => array_intersect_normal!(first_array, second_array, DataType::Int32, Int32Array, Int32Builder),
1894+
(DataType::Int64, DataType::Int64) => array_intersect_normal!(first_array, second_array, DataType::Int64, Int64Array, Int64Builder),
1895+
(DataType::UInt8, DataType::UInt8) => array_intersect_normal!(first_array, second_array, DataType::UInt8, UInt8Array, UInt8Builder),
1896+
(DataType::UInt16, DataType::UInt16) => array_intersect_normal!(first_array, second_array, DataType::UInt16, UInt16Array, UInt16Builder),
1897+
(DataType::UInt32, DataType::UInt32) => array_intersect_normal!(first_array, second_array, DataType::UInt32, UInt32Array, UInt32Builder),
1898+
(DataType::UInt64, DataType::UInt64) => array_intersect_normal!(first_array, second_array, DataType::UInt64, UInt64Array, UInt64Builder),
1899+
// (DataType::Null, _) => return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1)),
1900+
(first_value_dt, second_value_dt) => {
1901+
Err(DataFusionError::NotImplemented(format!(
1902+
"array_intersect is not implemented for '{first_value_dt:?}' and '{second_value_dt:?}'",
1903+
)))
1904+
}
1905+
}
1906+
}
1907+
18231908
#[cfg(test)]
18241909
mod tests {
18251910
use super::*;

datafusion/physical-expr/src/functions.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,9 @@ pub fn create_physical_fun(
465465
BuiltinScalarFunction::ArrayToString => Arc::new(|args| {
466466
make_scalar_function(array_expressions::array_to_string)(args)
467467
}),
468+
BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| {
469+
make_scalar_function(array_expressions::array_intersect)(args)
470+
}),
468471
BuiltinScalarFunction::Cardinality => {
469472
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
470473
}

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ enum ScalarFunction {
577577
ArrayReplaceN = 108;
578578
ArrayRemoveAll = 109;
579579
ArrayReplaceAll = 110;
580+
ArrayIntersect = 111;
580581
}
581582

582583
message ScalarFunctionNode {

datafusion/proto/src/generated/prost.rs

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
469469
ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN,
470470
ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
471471
ScalarFunction::ArrayToString => Self::ArrayToString,
472+
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
472473
ScalarFunction::Cardinality => Self::Cardinality,
473474
ScalarFunction::Array => Self::MakeArray,
474475
ScalarFunction::TrimArray => Self::TrimArray,

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
14171417
BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN,
14181418
BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
14191419
BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
1420+
BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
14201421
BuiltinScalarFunction::Cardinality => Self::Cardinality,
14211422
BuiltinScalarFunction::MakeArray => Self::Array,
14221423
BuiltinScalarFunction::TrimArray => Self::TrimArray,

docs/source/user-guide/expressions.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ Unlike to some databases the math functions in Datafusion works the same way as
200200
| array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` |
201201
| array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` |
202202
| array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` |
203+
| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` |
203204
| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` |
204205
| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` |
205206
| trim_array(array, n) | Removes the last n elements from the array. |

0 commit comments

Comments
 (0)