@@ -342,6 +342,29 @@ macro_rules! boolean_op {
342342 } } ;
343343}
344344
/// Invoke a boolean comparison kernel between a boolean array and a scalar.
///
/// * `$LEFT`  - array expression. It is downcast to `BooleanArray`; the macro
///   panics if that downcast fails, so callers must check the array's data
///   type before invoking it.
/// * `$RIGHT` - scalar expression. It is evaluated exactly once and must
///   produce a `ScalarValue::Boolean`.
/// * `$OP`    - kernel name prefix: `eq` dispatches to `eq_bool_scalar`,
///   `neq` to `neq_bool_scalar`.
///
/// Evaluates to `Some(Result<ArrayRef>)`. A non-boolean scalar yields
/// `Some(Err(DataFusionError::Internal(..)))`. An error returned by the kernel
/// itself is propagated with `?` out of the function the macro is expanded in.
macro_rules! boolean_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op_scalar failed to downcast array");

        // Borrow the scalar so that `$RIGHT` is expanded (and therefore
        // evaluated) only once, even on the error path that prints it.
        let result = match &$RIGHT {
            ScalarValue::Boolean(scalar) => Ok(Arc::new(
                paste::expr! {[<$OP _bool_scalar>]}(ll, scalar.as_ref())?,
            ) as ArrayRef),
            other => Err(DataFusionError::Internal(format!(
                "boolean_op_scalar failed to cast literal value {}",
                other
            ))),
        };
        Some(result)
    }};
}
367+
345368macro_rules! binary_string_array_flag_op {
346369 ( $LEFT: expr, $RIGHT: expr, $OP: ident, $NOT: expr, $FLAG: expr) => { {
347370 match $LEFT. data_type( ) {
@@ -592,9 +615,19 @@ impl BinaryExpr {
592615 Operator :: GtEq => {
593616 binary_array_op_scalar ! ( array, scalar. clone( ) , gt_eq)
594617 }
595- Operator :: Eq => binary_array_op_scalar ! ( array, scalar. clone( ) , eq) ,
618+ Operator :: Eq => {
619+ if array. data_type ( ) == & DataType :: Boolean {
620+ boolean_op_scalar ! ( array, scalar. clone( ) , eq)
621+ } else {
622+ binary_array_op_scalar ! ( array, scalar. clone( ) , eq)
623+ }
624+ }
596625 Operator :: NotEq => {
597- binary_array_op_scalar ! ( array, scalar. clone( ) , neq)
626+ if array. data_type ( ) == & DataType :: Boolean {
627+ boolean_op_scalar ! ( array, scalar. clone( ) , neq)
628+ } else {
629+ binary_array_op_scalar ! ( array, scalar. clone( ) , neq)
630+ }
598631 }
599632 Operator :: Like => {
600633 binary_string_array_op_scalar ! ( array, scalar. clone( ) , like)
@@ -659,9 +692,19 @@ impl BinaryExpr {
659692 Operator :: GtEq => {
660693 binary_array_op_scalar ! ( array, scalar. clone( ) , lt_eq)
661694 }
662- Operator :: Eq => binary_array_op_scalar ! ( array, scalar. clone( ) , eq) ,
695+ Operator :: Eq => {
696+ if array. data_type ( ) == & DataType :: Boolean {
697+ boolean_op_scalar ! ( array, scalar. clone( ) , eq)
698+ } else {
699+ binary_array_op_scalar ! ( array, scalar. clone( ) , eq)
700+ }
701+ }
663702 Operator :: NotEq => {
664- binary_array_op_scalar ! ( array, scalar. clone( ) , neq)
703+ if array. data_type ( ) == & DataType :: Boolean {
704+ boolean_op_scalar ! ( array, scalar. clone( ) , neq)
705+ } else {
706+ binary_array_op_scalar ! ( array, scalar. clone( ) , neq)
707+ }
665708 }
666709 // if scalar operation is not supported - fallback to array implementation
667710 _ => None ,
@@ -683,8 +726,21 @@ impl BinaryExpr {
683726 Operator :: LtEq => binary_array_op ! ( left, right, lt_eq) ,
684727 Operator :: Gt => binary_array_op ! ( left, right, gt) ,
685728 Operator :: GtEq => binary_array_op ! ( left, right, gt_eq) ,
686- Operator :: Eq => binary_array_op ! ( left, right, eq) ,
687- Operator :: NotEq => binary_array_op ! ( left, right, neq) ,
729+ Operator :: Eq => {
730+ if left_data_type == & DataType :: Boolean {
731+ boolean_op ! ( left, right, eq_bool)
732+ } else {
733+ binary_array_op ! ( left, right, eq)
734+ }
735+ }
736+ Operator :: NotEq => {
737+ if left_data_type == & DataType :: Boolean {
738+ boolean_op ! ( left, right, neq_bool)
739+ } else {
740+ binary_array_op ! ( left, right, neq)
741+ }
742+ }
743+
688744 Operator :: IsDistinctFrom => binary_array_op ! ( left, right, is_distinct_from) ,
689745 Operator :: IsNotDistinctFrom => {
690746 binary_array_op ! ( left, right, is_not_distinct_from)
@@ -814,14 +870,68 @@ pub fn binary(
814870 Ok ( Arc :: new ( BinaryExpr :: new ( l, op, r) ) )
815871}
816872
873+ // TODO file a ticket with arrow-rs to include these kernels
874+
875+ fn eq_bool ( lhs : & BooleanArray , rhs : & BooleanArray ) -> Result < BooleanArray > {
876+ let arr: BooleanArray = lhs
877+ . iter ( )
878+ . zip ( rhs. iter ( ) )
879+ . map ( |v| match v {
880+ // both lhs and rhs were non null
881+ ( Some ( lhs) , Some ( rhs) ) => Some ( lhs == rhs) ,
882+ _ => None ,
883+ } )
884+ . collect ( ) ;
885+
886+ Ok ( arr)
887+ }
888+
889+ fn eq_bool_scalar ( lhs : & BooleanArray , rhs : Option < & bool > ) -> Result < BooleanArray > {
890+ let arr: BooleanArray = lhs
891+ . iter ( )
892+ . map ( |v| match ( v, rhs) {
893+ // both lhs and rhs were non null
894+ ( Some ( lhs) , Some ( rhs) ) => Some ( lhs == * rhs) ,
895+ _ => None ,
896+ } )
897+ . collect ( ) ;
898+ Ok ( arr)
899+ }
900+
901+ fn neq_bool ( lhs : & BooleanArray , rhs : & BooleanArray ) -> Result < BooleanArray > {
902+ let arr: BooleanArray = lhs
903+ . iter ( )
904+ . zip ( rhs. iter ( ) )
905+ . map ( |v| match v {
906+ // both lhs and rhs were non null
907+ ( Some ( lhs) , Some ( rhs) ) => Some ( lhs != rhs) ,
908+ _ => None ,
909+ } )
910+ . collect ( ) ;
911+
912+ Ok ( arr)
913+ }
914+
915+ fn neq_bool_scalar ( lhs : & BooleanArray , rhs : Option < & bool > ) -> Result < BooleanArray > {
916+ let arr: BooleanArray = lhs
917+ . iter ( )
918+ . map ( |v| match ( v, rhs) {
919+ // both lhs and rhs were non null
920+ ( Some ( lhs) , Some ( rhs) ) => Some ( lhs != * rhs) ,
921+ _ => None ,
922+ } )
923+ . collect ( ) ;
924+ Ok ( arr)
925+ }
926+
817927#[ cfg( test) ]
818928mod tests {
819929 use arrow:: datatypes:: { ArrowNumericType , Field , Int32Type , SchemaRef } ;
820930 use arrow:: util:: display:: array_value_to_string;
821931
822932 use super :: * ;
823933 use crate :: error:: Result ;
824- use crate :: physical_plan:: expressions:: col;
934+ use crate :: physical_plan:: expressions:: { col, lit } ;
825935
826936 // Create a binary expression without coercion. Used here when we do not want to coerce the expressions
827937 // to valid types. Usage can result in an execution (after plan) error.
@@ -1371,6 +1481,42 @@ mod tests {
13711481 Ok ( ( ) )
13721482 }
13731483
1484+ // Test `scalar <op> arr` produces expected
1485+ fn apply_logic_op_scalar_arr (
1486+ schema : & SchemaRef ,
1487+ scalar : bool ,
1488+ arr : & ArrayRef ,
1489+ op : Operator ,
1490+ expected : & BooleanArray ,
1491+ ) -> Result < ( ) > {
1492+ let scalar = lit ( scalar. into ( ) ) ;
1493+
1494+ let arithmetic_op = binary_simple ( scalar, op, col ( "a" , schema) ?) ;
1495+ let batch = RecordBatch :: try_new ( Arc :: clone ( schema) , vec ! [ Arc :: clone( arr) ] ) ?;
1496+ let result = arithmetic_op. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
1497+ assert_eq ! ( result. as_ref( ) , expected) ;
1498+
1499+ Ok ( ( ) )
1500+ }
1501+
1502+ // Test `arr <op> scalar` produces expected
1503+ fn apply_logic_op_arr_scalar (
1504+ schema : & SchemaRef ,
1505+ arr : & ArrayRef ,
1506+ scalar : bool ,
1507+ op : Operator ,
1508+ expected : & BooleanArray ,
1509+ ) -> Result < ( ) > {
1510+ let scalar = lit ( scalar. into ( ) ) ;
1511+
1512+ let arithmetic_op = binary_simple ( col ( "a" , schema) ?, op, scalar) ;
1513+ let batch = RecordBatch :: try_new ( Arc :: clone ( schema) , vec ! [ Arc :: clone( arr) ] ) ?;
1514+ let result = arithmetic_op. evaluate ( & batch) ?. into_array ( batch. num_rows ( ) ) ;
1515+ assert_eq ! ( result. as_ref( ) , expected) ;
1516+
1517+ Ok ( ( ) )
1518+ }
1519+
13741520 #[ test]
13751521 fn and_with_nulls_op ( ) -> Result < ( ) > {
13761522 let schema = Schema :: new ( vec ! [
@@ -1461,6 +1607,58 @@ mod tests {
14611607 Ok ( ( ) )
14621608 }
14631609
/// `a = b` on two boolean arrays: null wherever `a` is null, otherwise the
/// element-wise equality.
#[test]
fn eq_op_bool() {
    let schema = Schema::new(vec![
        // Column "a" carries nulls below, so it must be declared nullable.
        Field::new("a", DataType::Boolean, true),
        Field::new("b", DataType::Boolean, false),
    ]);
    let a = BooleanArray::from(vec![Some(true), None, Some(false), None]);
    let b =
        BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);

    let expected = BooleanArray::from(vec![Some(true), None, Some(false), None]);
    apply_logic_op(Arc::new(schema), a, b, Operator::Eq, expected).unwrap();
}
1623+
/// `true = a` and `a = true` must both give the same result, with nulls in
/// `a` preserved as nulls.
#[test]
fn eq_op_bool_scalar() {
    // Column "a" carries a null below, so it must be declared nullable.
    let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
    let schema = Arc::new(schema);
    let a: ArrayRef =
        Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)]));

    let expected = BooleanArray::from(vec![Some(true), None, Some(false)]);
    apply_logic_op_scalar_arr(&schema, true, &a, Operator::Eq, &expected).unwrap();
    apply_logic_op_arr_scalar(&schema, &a, true, Operator::Eq, &expected).unwrap();
}
1635+
/// `a != b` on two boolean arrays: null wherever `a` is null, otherwise the
/// element-wise inequality.
#[test]
fn neq_op_bool() {
    let schema = Schema::new(vec![
        // Column "a" carries nulls below, so it must be declared nullable.
        Field::new("a", DataType::Boolean, true),
        Field::new("b", DataType::Boolean, false),
    ]);
    let a = BooleanArray::from(vec![Some(true), None, Some(false), None]);
    let b =
        BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);

    let expected = BooleanArray::from(vec![Some(false), None, Some(true), None]);
    apply_logic_op(Arc::new(schema), a, b, Operator::NotEq, expected).unwrap();
}
1649+
/// `true != a` and `a != true` must both give the same result, with nulls in
/// `a` preserved as nulls.
#[test]
fn neq_op_bool_scalar() {
    // Column "a" carries a null below, so it must be declared nullable.
    let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
    let schema = Arc::new(schema);
    let a: ArrayRef =
        Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)]));

    let expected = BooleanArray::from(vec![Some(false), None, Some(true)]);
    apply_logic_op_scalar_arr(&schema, true, &a, Operator::NotEq, &expected).unwrap();
    apply_logic_op_arr_scalar(&schema, &a, true, Operator::NotEq, &expected).unwrap();
}
1661+
14641662 #[ test]
14651663 fn test_coersion_error ( ) -> Result < ( ) > {
14661664 let expr =
0 commit comments