Skip to content

Commit db6c861

Browse files
Weijun-Happletreeisyellow
authored andcommitted
Refactor array_union and array_intersect functions to one general function (apache#8516)
* Refactor array_union and array_intersect functions * fix cli * fix ci * add tests for null * modify the return type * update tests * fix clippy * fix clippy * add tests for largelist * fix clippy * Add field parameter to generic_set_lists() function * Add large array drop statements * fix clippy
1 parent 5da2de4 commit db6c861

File tree

3 files changed

+446
-144
lines changed

3 files changed

+446
-144
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,18 @@ impl BuiltinScalarFunction {
618618
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
619619
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
620620
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
621-
BuiltinScalarFunction::ArrayUnion | BuiltinScalarFunction::ArrayIntersect => {
621+
BuiltinScalarFunction::ArrayIntersect => {
622+
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
623+
(DataType::Null, DataType::Null) | (DataType::Null, _) => {
624+
Ok(DataType::Null)
625+
}
626+
(_, DataType::Null) => {
627+
Ok(List(Arc::new(Field::new("item", Null, true))))
628+
}
629+
(dt, _) => Ok(dt),
630+
}
631+
}
632+
BuiltinScalarFunction::ArrayUnion => {
622633
match (input_expr_types[0].clone(), input_expr_types[1].clone()) {
623634
(DataType::Null, dt) => Ok(dt),
624635
(dt, DataType::Null) => Ok(dt),

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 146 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
2020
use std::any::type_name;
2121
use std::collections::HashSet;
22+
use std::fmt::{Display, Formatter};
2223
use std::sync::Arc;
2324

2425
use arrow::array::*;
@@ -1777,97 +1778,173 @@ macro_rules! to_string {
17771778
}};
17781779
}
17791780

1780-
fn union_generic_lists<OffsetSize: OffsetSizeTrait>(
1781+
#[derive(Debug, PartialEq)]
1782+
enum SetOp {
1783+
Union,
1784+
Intersect,
1785+
}
1786+
1787+
impl Display for SetOp {
1788+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1789+
match self {
1790+
SetOp::Union => write!(f, "array_union"),
1791+
SetOp::Intersect => write!(f, "array_intersect"),
1792+
}
1793+
}
1794+
}
1795+
1796+
fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
17811797
l: &GenericListArray<OffsetSize>,
17821798
r: &GenericListArray<OffsetSize>,
1783-
field: &FieldRef,
1784-
) -> Result<GenericListArray<OffsetSize>> {
1785-
let converter = RowConverter::new(vec![SortField::new(l.value_type())])?;
1799+
field: Arc<Field>,
1800+
set_op: SetOp,
1801+
) -> Result<ArrayRef> {
1802+
if matches!(l.value_type(), DataType::Null) {
1803+
let field = Arc::new(Field::new("item", r.value_type(), true));
1804+
return general_array_distinct::<OffsetSize>(r, &field);
1805+
} else if matches!(r.value_type(), DataType::Null) {
1806+
let field = Arc::new(Field::new("item", l.value_type(), true));
1807+
return general_array_distinct::<OffsetSize>(l, &field);
1808+
}
17861809

1787-
let nulls = NullBuffer::union(l.nulls(), r.nulls());
1788-
let l_values = l.values().clone();
1789-
let r_values = r.values().clone();
1790-
let l_values = converter.convert_columns(&[l_values])?;
1791-
let r_values = converter.convert_columns(&[r_values])?;
1810+
if l.value_type() != r.value_type() {
1811+
return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'");
1812+
}
17921813

1793-
// Might be worth adding an upstream OffsetBufferBuilder
1794-
let mut offsets = Vec::<OffsetSize>::with_capacity(l.len() + 1);
1795-
offsets.push(OffsetSize::usize_as(0));
1796-
let mut rows = Vec::with_capacity(l_values.num_rows() + r_values.num_rows());
1797-
let mut dedup = HashSet::new();
1798-
for (l_w, r_w) in l.offsets().windows(2).zip(r.offsets().windows(2)) {
1799-
let l_slice = l_w[0].as_usize()..l_w[1].as_usize();
1800-
let r_slice = r_w[0].as_usize()..r_w[1].as_usize();
1801-
for i in l_slice {
1802-
let left_row = l_values.row(i);
1803-
if dedup.insert(left_row) {
1804-
rows.push(left_row);
1805-
}
1806-
}
1807-
for i in r_slice {
1808-
let right_row = r_values.row(i);
1809-
if dedup.insert(right_row) {
1810-
rows.push(right_row);
1814+
let dt = l.value_type();
1815+
1816+
let mut offsets = vec![OffsetSize::usize_as(0)];
1817+
let mut new_arrays = vec![];
1818+
1819+
let converter = RowConverter::new(vec![SortField::new(dt)])?;
1820+
for (first_arr, second_arr) in l.iter().zip(r.iter()) {
1821+
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
1822+
let l_values = converter.convert_columns(&[first_arr])?;
1823+
let r_values = converter.convert_columns(&[second_arr])?;
1824+
1825+
let l_iter = l_values.iter().sorted().dedup();
1826+
let values_set: HashSet<_> = l_iter.clone().collect();
1827+
let mut rows = if set_op == SetOp::Union {
1828+
l_iter.collect::<Vec<_>>()
1829+
} else {
1830+
vec![]
1831+
};
1832+
for r_val in r_values.iter().sorted().dedup() {
1833+
match set_op {
1834+
SetOp::Union => {
1835+
if !values_set.contains(&r_val) {
1836+
rows.push(r_val);
1837+
}
1838+
}
1839+
SetOp::Intersect => {
1840+
if values_set.contains(&r_val) {
1841+
rows.push(r_val);
1842+
}
1843+
}
1844+
}
18111845
}
1846+
1847+
let last_offset = match offsets.last().copied() {
1848+
Some(offset) => offset,
1849+
None => return internal_err!("offsets should not be empty"),
1850+
};
1851+
offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
1852+
let arrays = converter.convert_rows(rows)?;
1853+
let array = match arrays.first() {
1854+
Some(array) => array.clone(),
1855+
None => {
1856+
return internal_err!("{set_op}: failed to get array from rows");
1857+
}
1858+
};
1859+
new_arrays.push(array);
18121860
}
1813-
offsets.push(OffsetSize::usize_as(rows.len()));
1814-
dedup.clear();
18151861
}
18161862

1817-
let values = converter.convert_rows(rows)?;
18181863
let offsets = OffsetBuffer::new(offsets.into());
1819-
let result = values[0].clone();
1820-
Ok(GenericListArray::<OffsetSize>::new(
1821-
field.clone(),
1822-
offsets,
1823-
result,
1824-
nulls,
1825-
))
1864+
let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
1865+
let values = compute::concat(&new_arrays_ref)?;
1866+
let arr = GenericListArray::<OffsetSize>::try_new(field, offsets, values, None)?;
1867+
Ok(Arc::new(arr))
18261868
}
18271869

1828-
/// Array_union SQL function
1829-
pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
1830-
if args.len() != 2 {
1831-
return exec_err!("array_union needs 2 arguments");
1832-
}
1833-
let array1 = &args[0];
1834-
let array2 = &args[1];
1870+
fn general_set_op(
1871+
array1: &ArrayRef,
1872+
array2: &ArrayRef,
1873+
set_op: SetOp,
1874+
) -> Result<ArrayRef> {
1875+
match (array1.data_type(), array2.data_type()) {
1876+
(DataType::Null, DataType::List(field)) => {
1877+
if set_op == SetOp::Intersect {
1878+
return Ok(new_empty_array(&DataType::Null));
1879+
}
1880+
let array = as_list_array(&array2)?;
1881+
general_array_distinct::<i32>(array, field)
1882+
}
18351883

1836-
fn union_arrays<O: OffsetSizeTrait>(
1837-
array1: &ArrayRef,
1838-
array2: &ArrayRef,
1839-
l_field_ref: &Arc<Field>,
1840-
r_field_ref: &Arc<Field>,
1841-
) -> Result<ArrayRef> {
1842-
match (l_field_ref.data_type(), r_field_ref.data_type()) {
1843-
(DataType::Null, _) => Ok(array2.clone()),
1844-
(_, DataType::Null) => Ok(array1.clone()),
1845-
(_, _) => {
1846-
let list1 = array1.as_list::<O>();
1847-
let list2 = array2.as_list::<O>();
1848-
let result = union_generic_lists::<O>(list1, list2, l_field_ref)?;
1849-
Ok(Arc::new(result))
1884+
(DataType::List(field), DataType::Null) => {
1885+
if set_op == SetOp::Intersect {
1886+
return make_array(&[]);
18501887
}
1888+
let array = as_list_array(&array1)?;
1889+
general_array_distinct::<i32>(array, field)
18511890
}
1852-
}
1891+
(DataType::Null, DataType::LargeList(field)) => {
1892+
if set_op == SetOp::Intersect {
1893+
return Ok(new_empty_array(&DataType::Null));
1894+
}
1895+
let array = as_large_list_array(&array2)?;
1896+
general_array_distinct::<i64>(array, field)
1897+
}
1898+
(DataType::LargeList(field), DataType::Null) => {
1899+
if set_op == SetOp::Intersect {
1900+
return make_array(&[]);
1901+
}
1902+
let array = as_large_list_array(&array1)?;
1903+
general_array_distinct::<i64>(array, field)
1904+
}
1905+
(DataType::Null, DataType::Null) => Ok(new_empty_array(&DataType::Null)),
18531906

1854-
match (array1.data_type(), array2.data_type()) {
1855-
(DataType::Null, _) => Ok(array2.clone()),
1856-
(_, DataType::Null) => Ok(array1.clone()),
1857-
(DataType::List(l_field_ref), DataType::List(r_field_ref)) => {
1858-
union_arrays::<i32>(array1, array2, l_field_ref, r_field_ref)
1907+
(DataType::List(field), DataType::List(_)) => {
1908+
let array1 = as_list_array(&array1)?;
1909+
let array2 = as_list_array(&array2)?;
1910+
generic_set_lists::<i32>(array1, array2, field.clone(), set_op)
18591911
}
1860-
(DataType::LargeList(l_field_ref), DataType::LargeList(r_field_ref)) => {
1861-
union_arrays::<i64>(array1, array2, l_field_ref, r_field_ref)
1912+
(DataType::LargeList(field), DataType::LargeList(_)) => {
1913+
let array1 = as_large_list_array(&array1)?;
1914+
let array2 = as_large_list_array(&array2)?;
1915+
generic_set_lists::<i64>(array1, array2, field.clone(), set_op)
18621916
}
1863-
_ => {
1917+
(data_type1, data_type2) => {
18641918
internal_err!(
1865-
"array_union only support list with offsets of type int32 and int64"
1919+
"{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'"
18661920
)
18671921
}
18681922
}
18691923
}
18701924

1925+
/// Array_union SQL function
1926+
pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
1927+
if args.len() != 2 {
1928+
return exec_err!("array_union needs two arguments");
1929+
}
1930+
let array1 = &args[0];
1931+
let array2 = &args[1];
1932+
1933+
general_set_op(array1, array2, SetOp::Union)
1934+
}
1935+
1936+
/// array_intersect SQL function
1937+
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
1938+
if args.len() != 2 {
1939+
return exec_err!("array_intersect needs two arguments");
1940+
}
1941+
1942+
let array1 = &args[0];
1943+
let array2 = &args[1];
1944+
1945+
general_set_op(array1, array2, SetOp::Intersect)
1946+
}
1947+
18711948
/// Array_to_string SQL function
18721949
pub fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
18731950
if args.len() < 2 || args.len() > 3 {
@@ -2228,7 +2305,7 @@ pub fn array_has(args: &[ArrayRef]) -> Result<ArrayRef> {
22282305
DataType::LargeList(_) => {
22292306
general_array_has_dispatch::<i64>(&args[0], &args[1], ComparisonType::Single)
22302307
}
2231-
_ => internal_err!("array_has does not support type '{array_type:?}'."),
2308+
_ => exec_err!("array_has does not support type '{array_type:?}'."),
22322309
}
22332310
}
22342311

@@ -2359,74 +2436,6 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
23592436
Ok(Arc::new(list_array) as ArrayRef)
23602437
}
23612438

2362-
/// array_intersect SQL function
2363-
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
2364-
if args.len() != 2 {
2365-
return exec_err!("array_intersect needs two arguments");
2366-
}
2367-
2368-
let first_array = &args[0];
2369-
let second_array = &args[1];
2370-
2371-
match (first_array.data_type(), second_array.data_type()) {
2372-
(DataType::Null, _) => Ok(second_array.clone()),
2373-
(_, DataType::Null) => Ok(first_array.clone()),
2374-
_ => {
2375-
let first_array = as_list_array(&first_array)?;
2376-
let second_array = as_list_array(&second_array)?;
2377-
2378-
if first_array.value_type() != second_array.value_type() {
2379-
return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'");
2380-
}
2381-
2382-
let dt = first_array.value_type();
2383-
2384-
let mut offsets = vec![0];
2385-
let mut new_arrays = vec![];
2386-
2387-
let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
2388-
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
2389-
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
2390-
let l_values = converter.convert_columns(&[first_arr])?;
2391-
let r_values = converter.convert_columns(&[second_arr])?;
2392-
2393-
let values_set: HashSet<_> = l_values.iter().collect();
2394-
let mut rows = Vec::with_capacity(r_values.num_rows());
2395-
for r_val in r_values.iter().sorted().dedup() {
2396-
if values_set.contains(&r_val) {
2397-
rows.push(r_val);
2398-
}
2399-
}
2400-
2401-
let last_offset: i32 = match offsets.last().copied() {
2402-
Some(offset) => offset,
2403-
None => return internal_err!("offsets should not be empty"),
2404-
};
2405-
offsets.push(last_offset + rows.len() as i32);
2406-
let arrays = converter.convert_rows(rows)?;
2407-
let array = match arrays.first() {
2408-
Some(array) => array.clone(),
2409-
None => {
2410-
return internal_err!(
2411-
"array_intersect: failed to get array from rows"
2412-
)
2413-
}
2414-
};
2415-
new_arrays.push(array);
2416-
}
2417-
}
2418-
2419-
let field = Arc::new(Field::new("item", dt, true));
2420-
let offsets = OffsetBuffer::new(offsets.into());
2421-
let new_arrays_ref =
2422-
new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
2423-
let values = compute::concat(&new_arrays_ref)?;
2424-
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
2425-
Ok(arr)
2426-
}
2427-
}
2428-
}
2429-
24302439
pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
24312440
array: &GenericListArray<OffsetSize>,
24322441
field: &FieldRef,

0 commit comments

Comments
 (0)