1919
2020use std:: any:: type_name;
2121use std:: collections:: HashSet ;
22+ use std:: fmt:: { Display , Formatter } ;
2223use std:: sync:: Arc ;
2324
2425use arrow:: array:: * ;
@@ -1535,11 +1536,25 @@ macro_rules! to_string {
15351536 } } ;
15361537}
15371538
1538- /// general function for array_union and array_intersect
1539- fn general_set_lists < OffsetSize : OffsetSizeTrait > (
1539+ #[ derive( Debug , PartialEq ) ]
1540+ enum SetOp {
1541+ Union ,
1542+ Intersect ,
1543+ }
1544+
1545+ impl Display for SetOp {
1546+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
1547+ match self {
1548+ SetOp :: Union => write ! ( f, "array_union" ) ,
1549+ SetOp :: Intersect => write ! ( f, "array_intersect" ) ,
1550+ }
1551+ }
1552+ }
1553+
1554+ fn generic_set_lists < OffsetSize : OffsetSizeTrait > (
15401555 l : & GenericListArray < OffsetSize > ,
15411556 r : & GenericListArray < OffsetSize > ,
1542- is_union : bool ,
1557+ set_op : SetOp ,
15431558) -> Result < ArrayRef > {
15441559 if matches ! ( l. value_type( ) , DataType :: Null ) {
15451560 let field = Arc :: new ( Field :: new ( "item" , r. value_type ( ) , true ) ) ;
@@ -1550,12 +1565,7 @@ fn general_set_lists<OffsetSize: OffsetSizeTrait>(
15501565 }
15511566
15521567 if l. value_type ( ) != r. value_type ( ) {
1553- let operation = if is_union {
1554- "array_union"
1555- } else {
1556- "array_intersect"
1557- } ;
1558- return internal_err ! ( "{operation} is not implemented for '{l:?}' and '{r:?}'" ) ;
1568+ return internal_err ! ( "{set_op} is not implemented for '{l:?}' and '{r:?}'" ) ;
15591569 }
15601570
15611571 let dt = l. value_type ( ) ;
@@ -1571,14 +1581,23 @@ fn general_set_lists<OffsetSize: OffsetSizeTrait>(
15711581
15721582 let l_iter = l_values. iter ( ) . sorted ( ) . dedup ( ) ;
15731583 let values_set: HashSet < _ > = l_iter. clone ( ) . collect ( ) ;
1574- let mut rows = if is_union {
1584+ let mut rows = if set_op == SetOp :: Union {
15751585 l_iter. collect :: < Vec < _ > > ( )
15761586 } else {
15771587 vec ! [ ]
15781588 } ;
15791589 for r_val in r_values. iter ( ) . sorted ( ) . dedup ( ) {
1580- if !values_set. contains ( & r_val) == is_union {
1581- rows. push ( r_val) ;
1590+ match set_op {
1591+ SetOp :: Union => {
1592+ if !values_set. contains ( & r_val) {
1593+ rows. push ( r_val) ;
1594+ }
1595+ }
1596+ SetOp :: Intersect => {
1597+ if values_set. contains ( & r_val) {
1598+ rows. push ( r_val) ;
1599+ }
1600+ }
15821601 }
15831602 }
15841603
@@ -1591,12 +1610,7 @@ fn general_set_lists<OffsetSize: OffsetSizeTrait>(
15911610 let array = match arrays. first ( ) {
15921611 Some ( array) => array. clone ( ) ,
15931612 None => {
1594- let operation = if is_union {
1595- "array_union"
1596- } else {
1597- "array_intersect"
1598- } ;
1599- return internal_err ! ( "{operation}: failed to get array from rows" ) ;
1613+ return internal_err ! ( "{set_op}: failed to get array from rows" ) ;
16001614 }
16011615 } ;
16021616 new_arrays. push ( array) ;
@@ -1611,15 +1625,13 @@ fn general_set_lists<OffsetSize: OffsetSizeTrait>(
16111625 Ok ( Arc :: new ( arr) )
16121626}
16131627
1614- /// Array_union SQL function
1615- pub fn array_union ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
1616- if args. len ( ) != 2 {
1617- return exec_err ! ( "array_union needs two arguments" ) ;
1618- }
1619- let array1 = & args[ 0 ] ;
1620- let array2 = & args[ 1 ] ;
1621-
1628+ fn general_set_op (
1629+ array1 : & ArrayRef ,
1630+ array2 : & ArrayRef ,
1631+ set_op : SetOp ,
1632+ ) -> Result < ArrayRef > {
16221633 match ( array1. data_type ( ) , array2. data_type ( ) ) {
1634+ // Null type
16231635 ( DataType :: Null , DataType :: List ( field) )
16241636 | ( DataType :: List ( field) , DataType :: Null ) => {
16251637 let array = match array1. data_type ( ) {
@@ -1637,24 +1649,36 @@ pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> {
16371649 general_array_distinct :: < i64 > ( array, field)
16381650 }
16391651 ( DataType :: Null , DataType :: Null ) => Ok ( array1. clone ( ) ) ,
1652+
16401653 ( DataType :: List ( _) , DataType :: List ( _) ) => {
16411654 let array1 = as_list_array ( & array1) ?;
16421655 let array2 = as_list_array ( & array2) ?;
1643- general_set_lists :: < i32 > ( array1, array2, true )
1656+ generic_set_lists :: < i32 > ( array1, array2, set_op )
16441657 }
16451658 ( DataType :: LargeList ( _) , DataType :: LargeList ( _) ) => {
16461659 let array1 = as_large_list_array ( & array1) ?;
16471660 let array2 = as_large_list_array ( & array2) ?;
1648- general_set_lists :: < i64 > ( array1, array2, true )
1661+ generic_set_lists :: < i64 > ( array1, array2, set_op )
16491662 }
16501663 ( data_type1, data_type2) => {
16511664 internal_err ! (
1652- "array_union does not support types '{data_type1:?}' and '{data_type2:?}'"
1665+ "{set_op} does not support types '{data_type1:?}' and '{data_type2:?}'"
16531666 )
16541667 }
16551668 }
16561669}
16571670
1671+ /// Array_union SQL function
1672+ pub fn array_union ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
1673+ if args. len ( ) != 2 {
1674+ return exec_err ! ( "array_union needs two arguments" ) ;
1675+ }
1676+ let array1 = & args[ 0 ] ;
1677+ let array2 = & args[ 1 ] ;
1678+
1679+ general_set_op ( array1, array2, SetOp :: Union )
1680+ }
1681+
16581682/// array_intersect SQL function
16591683pub fn array_intersect ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
16601684 if args. len ( ) != 2 {
@@ -1664,40 +1688,7 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
16641688 let array1 = & args[ 0 ] ;
16651689 let array2 = & args[ 1 ] ;
16661690
1667- match ( array1. data_type ( ) , array2. data_type ( ) ) {
1668- ( DataType :: Null , DataType :: List ( field) )
1669- | ( DataType :: List ( field) , DataType :: Null ) => {
1670- let array = match array1. data_type ( ) {
1671- DataType :: Null => as_list_array ( & array2) ?,
1672- _ => as_list_array ( & array1) ?,
1673- } ;
1674- general_array_distinct :: < i32 > ( array, field)
1675- }
1676- ( DataType :: Null , DataType :: LargeList ( field) )
1677- | ( DataType :: LargeList ( field) , DataType :: Null ) => {
1678- let array = match array1. data_type ( ) {
1679- DataType :: Null => as_large_list_array ( & array2) ?,
1680- _ => as_large_list_array ( & array1) ?,
1681- } ;
1682- general_array_distinct :: < i64 > ( array, field)
1683- }
1684- ( DataType :: Null , DataType :: Null ) => Ok ( array1. clone ( ) ) ,
1685- ( DataType :: List ( _) , DataType :: List ( _) ) => {
1686- let array1 = as_list_array ( & array1) ?;
1687- let array2 = as_list_array ( & array2) ?;
1688- general_set_lists :: < i32 > ( array1, array2, false )
1689- }
1690- ( DataType :: LargeList ( _) , DataType :: LargeList ( _) ) => {
1691- let array1 = as_large_list_array ( & array1) ?;
1692- let array2 = as_large_list_array ( & array2) ?;
1693- general_set_lists :: < i64 > ( array1, array2, false )
1694- }
1695- ( data_type1, data_type2) => {
1696- internal_err ! (
1697- "array_intersect does not support types '{data_type1:?}' and '{data_type2:?}'"
1698- )
1699- }
1700- }
1691+ general_set_op ( array1, array2, SetOp :: Intersect )
17011692}
17021693
17031694/// Array_to_string SQL function
0 commit comments