2222//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
2323//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
2424
25- use regex:: Regex ;
26- use std:: collections:: HashMap ;
27-
2825use crate :: array:: * ;
2926use crate :: buffer:: { bitwise_bin_op_helper, buffer_unary_not, Buffer , MutableBuffer } ;
3027use crate :: compute:: binary_boolean_kernel;
3128use crate :: compute:: util:: combine_option_bitmap;
32- use crate :: datatypes:: { ArrowNumericType , DataType } ;
29+ use crate :: datatypes:: {
30+ ArrowNumericType , DataType , Float32Type , Float64Type , Int16Type , Int32Type ,
31+ Int64Type , Int8Type , UInt16Type , UInt32Type , UInt64Type , UInt8Type ,
32+ } ;
3333use crate :: error:: { ArrowError , Result } ;
3434use crate :: util:: bit_util;
35+ use regex:: Regex ;
36+ use std:: any:: type_name;
37+ use std:: collections:: HashMap ;
3538
3639/// Helper function to perform boolean lambda function on values from two arrays, this
3740/// version does not attempt to use SIMD.
@@ -974,7 +977,142 @@ where
974977 Ok ( BooleanArray :: from ( data) )
975978}
976979
977- /// Perform `left == right` operation on two arrays.
/// Downcasts two dynamically typed arrays to the concrete array type `$T` and hands the
/// results to the comparison kernel `$OP`.
///
/// Two public forms exist:
/// * `typed_cmp!(left, right, ArrayType, kernel)` calls `kernel(left, right)`.
/// * `typed_cmp!(left, right, ArrayType, kernel, TypeArg)` calls
///   `kernel::<TypeArg>(left, right)`.
///
/// A failed downcast makes the enclosing function return early (via `?`) with an
/// `ArrowError::CastError` naming the side that failed and the target type.
macro_rules! typed_cmp {
    // Internal rule: downcast one operand, bailing out with `$MSG` on failure. It must
    // stay first so that the `@downcast` marker is never parsed as an expression.
    (@downcast $ARRAY: expr, $T: ident, $MSG: literal) => {
        $ARRAY
            .as_any()
            .downcast_ref::<$T>()
            .ok_or_else(|| ArrowError::CastError(format!($MSG, type_name::<$T>())))?
    };
    // Kernel called without an explicit type argument.
    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident) => {{
        let lhs = typed_cmp!(@downcast $LEFT, $T, "Left array cannot be cast to {}");
        let rhs = typed_cmp!(@downcast $RIGHT, $T, "Right array cannot be cast to {}");
        $OP(lhs, rhs)
    }};
    // Kernel called with the explicit type argument `$TT`.
    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
        let lhs = typed_cmp!(@downcast $LEFT, $T, "Left array cannot be cast to {}");
        let rhs = typed_cmp!(@downcast $RIGHT, $T, "Right array cannot be cast to {}");
        $OP::<$TT>(lhs, rhs)
    }};
}
1012+
/// Selects a comparison kernel from the runtime data types of two dynamically typed
/// arrays and applies it through `typed_cmp!`.
///
/// * `$OP_BOOL` is used when both sides are `DataType::Boolean`.
/// * `$OP_PRIM` is used for the signed/unsigned integer and float types, with the
///   matching `*Type` passed as the kernel's type argument.
/// * `$OP_STR` is used for `Utf8` and `LargeUtf8`, with `i32` and `i64` respectively
///   passed as the kernel's type argument.
///
/// Evaluates to `Err(ArrowError::NotYetImplemented(..))` when both sides share a data
/// type that has no arm here, and to `Err(ArrowError::CastError(..))` when the two data
/// types differ.
macro_rules! typed_compares {
    ($LEFT: expr, $RIGHT: expr, $OP_BOOL: ident, $OP_PRIM: ident, $OP_STR: ident) => {{
        match ($LEFT.data_type(), $RIGHT.data_type()) {
            // Boolean arrays.
            (DataType::Boolean, DataType::Boolean) => {
                typed_cmp!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
            }
            // Signed integers.
            (DataType::Int8, DataType::Int8) => {
                typed_cmp!($LEFT, $RIGHT, Int8Array, $OP_PRIM, Int8Type)
            }
            (DataType::Int16, DataType::Int16) => {
                typed_cmp!($LEFT, $RIGHT, Int16Array, $OP_PRIM, Int16Type)
            }
            (DataType::Int32, DataType::Int32) => {
                typed_cmp!($LEFT, $RIGHT, Int32Array, $OP_PRIM, Int32Type)
            }
            (DataType::Int64, DataType::Int64) => {
                typed_cmp!($LEFT, $RIGHT, Int64Array, $OP_PRIM, Int64Type)
            }
            // Unsigned integers.
            (DataType::UInt8, DataType::UInt8) => {
                typed_cmp!($LEFT, $RIGHT, UInt8Array, $OP_PRIM, UInt8Type)
            }
            (DataType::UInt16, DataType::UInt16) => {
                typed_cmp!($LEFT, $RIGHT, UInt16Array, $OP_PRIM, UInt16Type)
            }
            (DataType::UInt32, DataType::UInt32) => {
                typed_cmp!($LEFT, $RIGHT, UInt32Array, $OP_PRIM, UInt32Type)
            }
            (DataType::UInt64, DataType::UInt64) => {
                typed_cmp!($LEFT, $RIGHT, UInt64Array, $OP_PRIM, UInt64Type)
            }
            // Floating point.
            (DataType::Float32, DataType::Float32) => {
                typed_cmp!($LEFT, $RIGHT, Float32Array, $OP_PRIM, Float32Type)
            }
            (DataType::Float64, DataType::Float64) => {
                typed_cmp!($LEFT, $RIGHT, Float64Array, $OP_PRIM, Float64Type)
            }
            // Strings.
            (DataType::Utf8, DataType::Utf8) => {
                typed_cmp!($LEFT, $RIGHT, StringArray, $OP_STR, i32)
            }
            (DataType::LargeUtf8, DataType::LargeUtf8) => {
                typed_cmp!($LEFT, $RIGHT, LargeStringArray, $OP_STR, i64)
            }
            // Matching data types for which no arm exists above. This guard must stay
            // ahead of the final catch-all arm.
            (lhs_type, rhs_type) if lhs_type == rhs_type => {
                let message = format!(
                    "Comparing arrays of type {} is not yet implemented",
                    lhs_type
                );
                Err(ArrowError::NotYetImplemented(message))
            }
            // Differing data types are never compared.
            (lhs_type, rhs_type) => {
                let message = format!(
                    "Cannot compare two arrays of different types ({} and {})",
                    lhs_type, rhs_type
                );
                Err(ArrowError::CastError(message))
            }
        }
    }};
}
1066+
/// Perform `left == right` operation on two (dynamic) [`Array`]s.
///
/// The comparison only happens when both arrays have the same data type; otherwise it
/// errs with a casting error.
///
/// # Errors
///
/// Returns [`ArrowError::CastError`] when the data types of `left` and `right` differ,
/// or when an array cannot be downcast to the concrete array type selected for its data
/// type. Returns [`ArrowError::NotYetImplemented`] when both arrays share a data type
/// other than boolean, the integer and float types, `Utf8` or `LargeUtf8`. Any error
/// produced by the selected kernel is passed through.
pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
    // Dispatches on the runtime data type to `eq_bool`, `eq` or `eq_utf8`.
    typed_compares!(left, right, eq_bool, eq, eq_utf8)
}
1074+
/// Perform `left != right` operation on two (dynamic) [`Array`]s.
///
/// The comparison only happens when both arrays have the same data type; otherwise it
/// errs with a casting error.
///
/// # Errors
///
/// Returns [`ArrowError::CastError`] when the data types of `left` and `right` differ,
/// or when an array cannot be downcast to the concrete array type selected for its data
/// type. Returns [`ArrowError::NotYetImplemented`] when both arrays share a data type
/// other than boolean, the integer and float types, `Utf8` or `LargeUtf8`. Any error
/// produced by the selected kernel is passed through.
pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
    // Dispatches on the runtime data type to `neq_bool`, `neq` or `neq_utf8`.
    typed_compares!(left, right, neq_bool, neq, neq_utf8)
}
1082+
/// Perform `left < right` operation on two (dynamic) [`Array`]s.
///
/// The comparison only happens when both arrays have the same data type; otherwise it
/// errs with a casting error.
///
/// # Errors
///
/// Returns [`ArrowError::CastError`] when the data types of `left` and `right` differ,
/// or when an array cannot be downcast to the concrete array type selected for its data
/// type. Returns [`ArrowError::NotYetImplemented`] when both arrays share a data type
/// other than boolean, the integer and float types, `Utf8` or `LargeUtf8`. Any error
/// produced by the selected kernel is passed through.
pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
    // Dispatches on the runtime data type to `lt_bool`, `lt` or `lt_utf8`.
    typed_compares!(left, right, lt_bool, lt, lt_utf8)
}
1090+
/// Perform `left <= right` operation on two (dynamic) [`Array`]s.
///
/// The comparison only happens when both arrays have the same data type; otherwise it
/// errs with a casting error.
///
/// # Errors
///
/// Returns [`ArrowError::CastError`] when the data types of `left` and `right` differ,
/// or when an array cannot be downcast to the concrete array type selected for its data
/// type. Returns [`ArrowError::NotYetImplemented`] when both arrays share a data type
/// other than boolean, the integer and float types, `Utf8` or `LargeUtf8`. Any error
/// produced by the selected kernel is passed through.
pub fn lt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
    // Dispatches on the runtime data type to `lt_eq_bool`, `lt_eq` or `lt_eq_utf8`.
    typed_compares!(left, right, lt_eq_bool, lt_eq, lt_eq_utf8)
}
1098+
/// Perform `left > right` operation on two (dynamic) [`Array`]s.
///
/// The comparison only happens when both arrays have the same data type; otherwise it
/// errs with a casting error.
///
/// # Errors
///
/// Returns [`ArrowError::CastError`] when the data types of `left` and `right` differ,
/// or when an array cannot be downcast to the concrete array type selected for its data
/// type. Returns [`ArrowError::NotYetImplemented`] when both arrays share a data type
/// other than boolean, the integer and float types, `Utf8` or `LargeUtf8`. Any error
/// produced by the selected kernel is passed through.
pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
    // Dispatches on the runtime data type to `gt_bool`, `gt` or `gt_utf8`.
    typed_compares!(left, right, gt_bool, gt, gt_utf8)
}
1106+
/// Perform `left >= right` operation on two (dynamic) [`Array`]s.
///
/// The comparison only happens when both arrays have the same data type; otherwise it
/// errs with a casting error.
///
/// # Errors
///
/// Returns [`ArrowError::CastError`] when the data types of `left` and `right` differ,
/// or when an array cannot be downcast to the concrete array type selected for its data
/// type. Returns [`ArrowError::NotYetImplemented`] when both arrays share a data type
/// other than boolean, the integer and float types, `Utf8` or `LargeUtf8`. Any error
/// produced by the selected kernel is passed through.
pub fn gt_eq_dyn(left: &dyn Array, right: &dyn Array) -> Result<BooleanArray> {
    // Dispatches on the runtime data type to `gt_eq_bool`, `gt_eq` or `gt_eq_utf8`.
    typed_compares!(left, right, gt_eq_bool, gt_eq, gt_eq_utf8)
}
1114+
1115+ /// Perform `left == right` operation on two [`PrimitiveArray`]s.
9781116pub fn eq < T > ( left : & PrimitiveArray < T > , right : & PrimitiveArray < T > ) -> Result < BooleanArray >
9791117where
9801118 T : ArrowNumericType ,
@@ -985,7 +1123,7 @@ where
9851123 return compare_op ! ( left, right, |a, b| a == b) ;
9861124}
9871125
988- /// Perform `left == right` operation on an array and a scalar value.
1126+ /// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value.
9891127pub fn eq_scalar < T > ( left : & PrimitiveArray < T > , right : T :: Native ) -> Result < BooleanArray >
9901128where
9911129 T : ArrowNumericType ,
@@ -996,7 +1134,7 @@ where
9961134 return compare_op_scalar ! ( left, right, |a, b| a == b) ;
9971135}
9981136
999- /// Perform `left != right` operation on two arrays .
1137+ /// Perform `left != right` operation on two [`PrimitiveArray`]s .
10001138pub fn neq < T > ( left : & PrimitiveArray < T > , right : & PrimitiveArray < T > ) -> Result < BooleanArray >
10011139where
10021140 T : ArrowNumericType ,
@@ -1007,7 +1145,7 @@ where
10071145 return compare_op ! ( left, right, |a, b| a != b) ;
10081146}
10091147
1010- /// Perform `left != right` operation on an array and a scalar value.
1148+ /// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value.
10111149pub fn neq_scalar < T > ( left : & PrimitiveArray < T > , right : T :: Native ) -> Result < BooleanArray >
10121150where
10131151 T : ArrowNumericType ,
@@ -1018,7 +1156,7 @@ where
10181156 return compare_op_scalar ! ( left, right, |a, b| a != b) ;
10191157}
10201158
1021- /// Perform `left < right` operation on two arrays . Null values are less than non-null
1159+ /// Perform `left < right` operation on two [`PrimitiveArray`]s . Null values are less than non-null
10221160/// values.
10231161pub fn lt < T > ( left : & PrimitiveArray < T > , right : & PrimitiveArray < T > ) -> Result < BooleanArray >
10241162where
@@ -1030,7 +1168,7 @@ where
10301168 return compare_op ! ( left, right, |a, b| a < b) ;
10311169}
10321170
1033- /// Perform `left < right` operation on an array and a scalar value.
1171+ /// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value.
10341172/// Null values are less than non-null values.
10351173pub fn lt_scalar < T > ( left : & PrimitiveArray < T > , right : T :: Native ) -> Result < BooleanArray >
10361174where
@@ -1042,7 +1180,7 @@ where
10421180 return compare_op_scalar ! ( left, right, |a, b| a < b) ;
10431181}
10441182
1045- /// Perform `left <= right` operation on two arrays . Null values are less than non-null
1183+ /// Perform `left <= right` operation on two [`PrimitiveArray`]s . Null values are less than non-null
10461184/// values.
10471185pub fn lt_eq < T > (
10481186 left : & PrimitiveArray < T > ,
@@ -1057,7 +1195,7 @@ where
10571195 return compare_op ! ( left, right, |a, b| a <= b) ;
10581196}
10591197
1060- /// Perform `left <= right` operation on an array and a scalar value.
1198+ /// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value.
10611199/// Null values are less than non-null values.
10621200pub fn lt_eq_scalar < T > ( left : & PrimitiveArray < T > , right : T :: Native ) -> Result < BooleanArray >
10631201where
@@ -1069,7 +1207,7 @@ where
10691207 return compare_op_scalar ! ( left, right, |a, b| a <= b) ;
10701208}
10711209
1072- /// Perform `left > right` operation on two arrays . Non-null values are greater than null
1210+ /// Perform `left > right` operation on two [`PrimitiveArray`]s . Non-null values are greater than null
10731211/// values.
10741212pub fn gt < T > ( left : & PrimitiveArray < T > , right : & PrimitiveArray < T > ) -> Result < BooleanArray >
10751213where
@@ -1081,7 +1219,7 @@ where
10811219 return compare_op ! ( left, right, |a, b| a > b) ;
10821220}
10831221
1084- /// Perform `left > right` operation on an array and a scalar value.
1222+ /// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value.
10851223/// Non-null values are greater than null values.
10861224pub fn gt_scalar < T > ( left : & PrimitiveArray < T > , right : T :: Native ) -> Result < BooleanArray >
10871225where
@@ -1093,7 +1231,7 @@ where
10931231 return compare_op_scalar ! ( left, right, |a, b| a > b) ;
10941232}
10951233
1096- /// Perform `left >= right` operation on two arrays . Non-null values are greater than null
1234+ /// Perform `left >= right` operation on two [`PrimitiveArray`]s . Non-null values are greater than null
10971235/// values.
10981236pub fn gt_eq < T > (
10991237 left : & PrimitiveArray < T > ,
@@ -1108,7 +1246,7 @@ where
11081246 return compare_op ! ( left, right, |a, b| a >= b) ;
11091247}
11101248
1111- /// Perform `left >= right` operation on an array and a scalar value.
1249+ /// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value.
11121250/// Non-null values are greater than null values.
11131251pub fn gt_eq_scalar < T > ( left : & PrimitiveArray < T > , right : T :: Native ) -> Result < BooleanArray >
11141252where
@@ -1260,11 +1398,17 @@ mod tests {
12601398 /// `EXPECTED` can be either `Vec<bool>` or `Vec<Option<bool>>`.
12611399 /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`.
macro_rules! cmp_i64 {
    ($KERNEL: ident, $DYN_KERNEL: ident, $A_VEC: expr, $B_VEC: expr, $EXPECTED: expr) => {
        // First pass: the statically typed kernel on concrete `Int64Array`s.
        let typed_lhs = Int64Array::from($A_VEC);
        let typed_rhs = Int64Array::from($B_VEC);
        let typed_result = $KERNEL(&typed_lhs, &typed_rhs).unwrap();
        assert_eq!(BooleanArray::from($EXPECTED), typed_result);

        // Second pass: full-length slices of the same arrays go through the dynamic
        // kernel, which must produce the same expected output.
        let dyn_lhs = typed_lhs.slice(0, typed_lhs.len());
        let dyn_rhs = typed_rhs.slice(0, typed_rhs.len());
        let dyn_result = $DYN_KERNEL(dyn_lhs.as_ref(), dyn_rhs.as_ref()).unwrap();
        assert_eq!(BooleanArray::from($EXPECTED), dyn_result);
    };
}
12701414
@@ -1284,6 +1428,7 @@ mod tests {
12841428 fn test_primitive_array_eq ( ) {
12851429 cmp_i64 ! (
12861430 eq,
1431+ eq_dyn,
12871432 vec![ 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 ] ,
12881433 vec![ 6 , 7 , 8 , 9 , 10 , 6 , 7 , 8 , 9 , 10 ] ,
12891434 vec![ false , false , true , false , false , false , false , true , false , false ]
@@ -1330,6 +1475,7 @@ mod tests {
13301475 fn test_primitive_array_neq ( ) {
13311476 cmp_i64 ! (
13321477 neq,
1478+ neq_dyn,
13331479 vec![ 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 ] ,
13341480 vec![ 6 , 7 , 8 , 9 , 10 , 6 , 7 , 8 , 9 , 10 ] ,
13351481 vec![ true , true , false , true , true , true , true , false , true , true ]
@@ -1479,6 +1625,7 @@ mod tests {
14791625 fn test_primitive_array_lt ( ) {
14801626 cmp_i64 ! (
14811627 lt,
1628+ lt_dyn,
14821629 vec![ 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 ] ,
14831630 vec![ 6 , 7 , 8 , 9 , 10 , 6 , 7 , 8 , 9 , 10 ] ,
14841631 vec![ false , false , false , true , true , false , false , false , true , true ]
@@ -1499,6 +1646,7 @@ mod tests {
14991646 fn test_primitive_array_lt_nulls ( ) {
15001647 cmp_i64 ! (
15011648 lt,
1649+ lt_dyn,
15021650 vec![ None , None , Some ( 1 ) , Some ( 1 ) , None , None , Some ( 2 ) , Some ( 2 ) , ] ,
15031651 vec![ None , Some ( 1 ) , None , Some ( 1 ) , None , Some ( 3 ) , None , Some ( 3 ) , ] ,
15041652 vec![ None , None , None , Some ( false ) , None , None , None , Some ( true ) ]
@@ -1519,6 +1667,7 @@ mod tests {
15191667 fn test_primitive_array_lt_eq ( ) {
15201668 cmp_i64 ! (
15211669 lt_eq,
1670+ lt_eq_dyn,
15221671 vec![ 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 ] ,
15231672 vec![ 6 , 7 , 8 , 9 , 10 , 6 , 7 , 8 , 9 , 10 ] ,
15241673 vec![ false , false , true , true , true , false , false , true , true , true ]
@@ -1539,6 +1688,7 @@ mod tests {
15391688 fn test_primitive_array_lt_eq_nulls ( ) {
15401689 cmp_i64 ! (
15411690 lt_eq,
1691+ lt_eq_dyn,
15421692 vec![ None , None , Some ( 1 ) , None , None , Some ( 1 ) , None , None , Some ( 1 ) ] ,
15431693 vec![ None , Some ( 1 ) , Some ( 0 ) , None , Some ( 1 ) , Some ( 2 ) , None , None , Some ( 3 ) ] ,
15441694 vec![ None , None , Some ( false ) , None , None , Some ( true ) , None , None , Some ( true ) ]
@@ -1559,6 +1709,7 @@ mod tests {
15591709 fn test_primitive_array_gt ( ) {
15601710 cmp_i64 ! (
15611711 gt,
1712+ gt_dyn,
15621713 vec![ 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 ] ,
15631714 vec![ 6 , 7 , 8 , 9 , 10 , 6 , 7 , 8 , 9 , 10 ] ,
15641715 vec![ true , true , false , false , false , true , true , false , false , false ]
@@ -1579,6 +1730,7 @@ mod tests {
15791730 fn test_primitive_array_gt_nulls ( ) {
15801731 cmp_i64 ! (
15811732 gt,
1733+ gt_dyn,
15821734 vec![ None , None , Some ( 1 ) , None , None , Some ( 2 ) , None , None , Some ( 3 ) ] ,
15831735 vec![ None , Some ( 1 ) , Some ( 1 ) , None , Some ( 1 ) , Some ( 1 ) , None , Some ( 1 ) , Some ( 1 ) ] ,
15841736 vec![ None , None , Some ( false ) , None , None , Some ( true ) , None , None , Some ( true ) ]
@@ -1599,6 +1751,7 @@ mod tests {
15991751 fn test_primitive_array_gt_eq ( ) {
16001752 cmp_i64 ! (
16011753 gt_eq,
1754+ gt_eq_dyn,
16021755 vec![ 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 , 8 ] ,
16031756 vec![ 6 , 7 , 8 , 9 , 10 , 6 , 7 , 8 , 9 , 10 ] ,
16041757 vec![ true , true , true , false , false , true , true , true , false , false ]
@@ -1619,6 +1772,7 @@ mod tests {
16191772 fn test_primitive_array_gt_eq_nulls ( ) {
16201773 cmp_i64 ! (
16211774 gt_eq,
1775+ gt_eq_dyn,
16221776 vec![ None , None , Some ( 1 ) , None , Some ( 1 ) , Some ( 2 ) , None , None , Some ( 1 ) ] ,
16231777 vec![ None , Some ( 1 ) , None , None , Some ( 1 ) , Some ( 1 ) , None , Some ( 2 ) , Some ( 2 ) ] ,
16241778 vec![ None , None , None , None , Some ( true ) , Some ( true ) , None , None , Some ( false ) ]
0 commit comments