Skip to content

Commit fda35ac

Browse files
committed
Support <bool col> = <bool col> and <bool col> != <bool col>
1 parent 2b002e4 commit fda35ac

File tree

2 files changed

+243
-9
lines changed

2 files changed

+243
-9
lines changed

datafusion/src/physical_plan/expressions/binary.rs

Lines changed: 205 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -342,6 +342,29 @@ macro_rules! boolean_op {
342342
}};
343343
}
344344

345+
/// Invoke a boolean comparison kernel between a boolean array and a scalar.
///
/// * `$LEFT`  - an array expression; it is downcast to `BooleanArray` and the
///   macro panics if that downcast fails.
/// * `$RIGHT` - a `ScalarValue` expression; only `ScalarValue::Boolean` is
///   accepted.
/// * `$OP`    - kernel name prefix; `_bool_scalar` is appended to it, so `eq`
///   dispatches to `eq_bool_scalar` and `neq` to `neq_bool_scalar`.
///
/// Evaluates to `Some(Result<ArrayRef>)`. A non-boolean scalar yields
/// `Err(DataFusionError::Internal(..))`. The kernel call itself uses `?`, so
/// the macro must be expanded inside a function returning a compatible
/// `Result`.
macro_rules! boolean_op_scalar {
    ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
        let ll = $LEFT
            .as_any()
            .downcast_ref::<BooleanArray>()
            .expect("boolean_op_scalar failed to downcast array");

        // Bind the scalar once: expanding `$RIGHT` again in the error branch
        // would evaluate the argument expression a second time.
        let right = $RIGHT;
        let result = if let ScalarValue::Boolean(scalar) = &right {
            Ok(
                Arc::new(paste::expr! {[<$OP _bool_scalar>]}(ll, scalar.as_ref())?)
                    as ArrayRef,
            )
        } else {
            Err(DataFusionError::Internal(format!(
                "boolean_op_scalar failed to cast literal value {}",
                right
            )))
        };
        Some(result)
    }};
}
367+
345368
macro_rules! binary_string_array_flag_op {
346369
($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{
347370
match $LEFT.data_type() {
@@ -592,9 +615,19 @@ impl BinaryExpr {
592615
Operator::GtEq => {
593616
binary_array_op_scalar!(array, scalar.clone(), gt_eq)
594617
}
595-
Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
618+
Operator::Eq => {
619+
if array.data_type() == &DataType::Boolean {
620+
boolean_op_scalar!(array, scalar.clone(), eq)
621+
} else {
622+
binary_array_op_scalar!(array, scalar.clone(), eq)
623+
}
624+
}
596625
Operator::NotEq => {
597-
binary_array_op_scalar!(array, scalar.clone(), neq)
626+
if array.data_type() == &DataType::Boolean {
627+
boolean_op_scalar!(array, scalar.clone(), neq)
628+
} else {
629+
binary_array_op_scalar!(array, scalar.clone(), neq)
630+
}
598631
}
599632
Operator::Like => {
600633
binary_string_array_op_scalar!(array, scalar.clone(), like)
@@ -659,9 +692,19 @@ impl BinaryExpr {
659692
Operator::GtEq => {
660693
binary_array_op_scalar!(array, scalar.clone(), lt_eq)
661694
}
662-
Operator::Eq => binary_array_op_scalar!(array, scalar.clone(), eq),
695+
Operator::Eq => {
696+
if array.data_type() == &DataType::Boolean {
697+
boolean_op_scalar!(array, scalar.clone(), eq)
698+
} else {
699+
binary_array_op_scalar!(array, scalar.clone(), eq)
700+
}
701+
}
663702
Operator::NotEq => {
664-
binary_array_op_scalar!(array, scalar.clone(), neq)
703+
if array.data_type() == &DataType::Boolean {
704+
boolean_op_scalar!(array, scalar.clone(), neq)
705+
} else {
706+
binary_array_op_scalar!(array, scalar.clone(), neq)
707+
}
665708
}
666709
// if scalar operation is not supported - fallback to array implementation
667710
_ => None,
@@ -683,8 +726,21 @@ impl BinaryExpr {
683726
Operator::LtEq => binary_array_op!(left, right, lt_eq),
684727
Operator::Gt => binary_array_op!(left, right, gt),
685728
Operator::GtEq => binary_array_op!(left, right, gt_eq),
686-
Operator::Eq => binary_array_op!(left, right, eq),
687-
Operator::NotEq => binary_array_op!(left, right, neq),
729+
Operator::Eq => {
730+
if left_data_type == &DataType::Boolean {
731+
boolean_op!(left, right, eq_bool)
732+
} else {
733+
binary_array_op!(left, right, eq)
734+
}
735+
}
736+
Operator::NotEq => {
737+
if left_data_type == &DataType::Boolean {
738+
boolean_op!(left, right, neq_bool)
739+
} else {
740+
binary_array_op!(left, right, neq)
741+
}
742+
}
743+
688744
Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from),
689745
Operator::IsNotDistinctFrom => {
690746
binary_array_op!(left, right, is_not_distinct_from)
@@ -814,14 +870,68 @@ pub fn binary(
814870
Ok(Arc::new(BinaryExpr::new(l, op, r)))
815871
}
816872

873+
// TODO file a ticket with arrow-rs to include these kernels
874+
875+
/// Element-wise `lhs = rhs` over two boolean arrays.
///
/// An output slot is null when either input slot is null; otherwise it holds
/// whether the two values are equal. Slots are paired by zipping the two
/// iterators, so the output has as many slots as the shorter input.
fn eq_bool(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
    let compared: BooleanArray = lhs
        .iter()
        .zip(rhs.iter())
        // `Option::zip` is `None` as soon as one side is null
        .map(|(left, right)| left.zip(right).map(|(l, r)| l == r))
        .collect();

    Ok(compared)
}
888+
889+
/// Element-wise `lhs = rhs` between a boolean array and a single scalar.
///
/// A null scalar (`rhs == None`) makes every output slot null. Otherwise each
/// non-null slot of `lhs` is compared against the scalar and null slots stay
/// null.
fn eq_bool_scalar(lhs: &BooleanArray, rhs: Option<&bool>) -> Result<BooleanArray> {
    let compared: BooleanArray = match rhs {
        // non-null scalar: compare every non-null slot against it
        Some(&scalar) => lhs
            .iter()
            .map(|slot| slot.map(|value| value == scalar))
            .collect(),
        // null scalar: the comparison is null for every slot
        None => lhs.iter().map(|_| None::<bool>).collect(),
    };
    Ok(compared)
}
900+
901+
/// Element-wise `lhs != rhs` over two boolean arrays.
///
/// An output slot is null when either input slot is null; otherwise it holds
/// whether the two values differ. Slots are paired by zipping the two
/// iterators, so the output has as many slots as the shorter input.
fn neq_bool(lhs: &BooleanArray, rhs: &BooleanArray) -> Result<BooleanArray> {
    let compared: BooleanArray = lhs
        .iter()
        .zip(rhs.iter())
        // `Option::zip` is `None` as soon as one side is null
        .map(|(left, right)| left.zip(right).map(|(l, r)| l != r))
        .collect();

    Ok(compared)
}
914+
915+
/// Element-wise `lhs != rhs` between a boolean array and a single scalar.
///
/// A null scalar (`rhs == None`) makes every output slot null. Otherwise each
/// non-null slot of `lhs` is compared against the scalar and null slots stay
/// null.
fn neq_bool_scalar(lhs: &BooleanArray, rhs: Option<&bool>) -> Result<BooleanArray> {
    let compared: BooleanArray = match rhs {
        // non-null scalar: compare every non-null slot against it
        Some(&scalar) => lhs
            .iter()
            .map(|slot| slot.map(|value| value != scalar))
            .collect(),
        // null scalar: the comparison is null for every slot
        None => lhs.iter().map(|_| None::<bool>).collect(),
    };
    Ok(compared)
}
926+
817927
#[cfg(test)]
818928
mod tests {
819929
use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef};
820930
use arrow::util::display::array_value_to_string;
821931

822932
use super::*;
823933
use crate::error::Result;
824-
use crate::physical_plan::expressions::col;
934+
use crate::physical_plan::expressions::{col, lit};
825935

826936
// Create a binary expression without coercion. Used here when we do not want to coerce the expressions
827937
// to valid types. Usage can result in an execution (after plan) error.
@@ -1371,6 +1481,42 @@ mod tests {
13711481
Ok(())
13721482
}
13731483

1484+
// Test `scalar <op> arr` produces expected
1485+
fn apply_logic_op_scalar_arr(
1486+
schema: &SchemaRef,
1487+
scalar: bool,
1488+
arr: &ArrayRef,
1489+
op: Operator,
1490+
expected: &BooleanArray,
1491+
) -> Result<()> {
1492+
let scalar = lit(scalar.into());
1493+
1494+
let arithmetic_op = binary_simple(scalar, op, col("a", schema)?);
1495+
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
1496+
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
1497+
assert_eq!(result.as_ref(), expected);
1498+
1499+
Ok(())
1500+
}
1501+
1502+
// Test `arr <op> scalar` produces expected
1503+
fn apply_logic_op_arr_scalar(
1504+
schema: &SchemaRef,
1505+
arr: &ArrayRef,
1506+
scalar: bool,
1507+
op: Operator,
1508+
expected: &BooleanArray,
1509+
) -> Result<()> {
1510+
let scalar = lit(scalar.into());
1511+
1512+
let arithmetic_op = binary_simple(col("a", schema)?, op, scalar);
1513+
let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?;
1514+
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
1515+
assert_eq!(result.as_ref(), expected);
1516+
1517+
Ok(())
1518+
}
1519+
13741520
#[test]
13751521
fn and_with_nulls_op() -> Result<()> {
13761522
let schema = Schema::new(vec![
@@ -1461,6 +1607,58 @@ mod tests {
14611607
Ok(())
14621608
}
14631609

1610+
// `a = b` on two boolean columns: a null on either side yields null.
#[test]
fn eq_op_bool() {
    let schema = Schema::new(vec![
        // "a" carries nulls below, so it has to be declared nullable
        Field::new("a", DataType::Boolean, true),
        Field::new("b", DataType::Boolean, false),
    ]);
    let a = BooleanArray::from(vec![Some(true), None, Some(false), None]);
    let b =
        BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);

    let expected = BooleanArray::from(vec![Some(true), None, Some(false), None]);
    apply_logic_op(Arc::new(schema), a, b, Operator::Eq, expected).unwrap();
}
1623+
1624+
// `a = true` and `true = a` must agree; null slots stay null.
#[test]
fn eq_op_bool_scalar() {
    // "a" carries a null below, so it has to be declared nullable
    let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
    let schema = Arc::new(schema);
    let a: ArrayRef =
        Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)]));

    let expected = BooleanArray::from(vec![Some(true), None, Some(false)]);
    apply_logic_op_scalar_arr(&schema, true, &a, Operator::Eq, &expected).unwrap();
    apply_logic_op_arr_scalar(&schema, &a, true, Operator::Eq, &expected).unwrap();
}
1635+
1636+
// `a != b` on two boolean columns: a null on either side yields null.
#[test]
fn neq_op_bool() {
    let schema = Schema::new(vec![
        // "a" carries nulls below, so it has to be declared nullable
        Field::new("a", DataType::Boolean, true),
        Field::new("b", DataType::Boolean, false),
    ]);
    let a = BooleanArray::from(vec![Some(true), None, Some(false), None]);
    let b =
        BooleanArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);

    let expected = BooleanArray::from(vec![Some(false), None, Some(true), None]);
    apply_logic_op(Arc::new(schema), a, b, Operator::NotEq, expected).unwrap();
}
1649+
1650+
// `a != true` and `true != a` must agree; null slots stay null.
#[test]
fn neq_op_bool_scalar() {
    // "a" carries a null below, so it has to be declared nullable
    let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
    let schema = Arc::new(schema);
    let a: ArrayRef =
        Arc::new(BooleanArray::from(vec![Some(true), None, Some(false)]));

    let expected = BooleanArray::from(vec![Some(false), None, Some(true)]);
    apply_logic_op_scalar_arr(&schema, true, &a, Operator::NotEq, &expected).unwrap();
    apply_logic_op_arr_scalar(&schema, &a, true, Operator::NotEq, &expected).unwrap();
}
1661+
14641662
#[test]
14651663
fn test_coersion_error() -> Result<()> {
14661664
let expr =

datafusion/tests/sql.rs

Lines changed: 38 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -598,7 +598,7 @@ async fn select_distinct_simple_4() {
598598
async fn select_distinct_from() {
599599
let mut ctx = ExecutionContext::new();
600600

601-
let sql = "select
601+
let sql = "select
602602
1 IS DISTINCT FROM CAST(NULL as INT) as a,
603603
1 IS DISTINCT FROM 1 as b,
604604
1 IS NOT DISTINCT FROM CAST(NULL as INT) as c,
@@ -621,7 +621,7 @@ async fn select_distinct_from() {
621621
async fn select_distinct_from_utf8() {
622622
let mut ctx = ExecutionContext::new();
623623

624-
let sql = "select
624+
let sql = "select
625625
'x' IS DISTINCT FROM NULL as a,
626626
'x' IS DISTINCT FROM 'x' as b,
627627
'x' IS NOT DISTINCT FROM NULL as c,
@@ -812,6 +812,40 @@ async fn csv_query_having_without_group_by() -> Result<()> {
812812
Ok(())
813813
}
814814

815+
#[tokio::test]
816+
async fn csv_query_boolean_eq() -> Result<()> {
817+
let mut ctx = ExecutionContext::new();
818+
register_aggregate_simple_csv(&mut ctx).await?;
819+
820+
let sql = "SELECT c3, c3 = c3 as eq, c3 != c3 as neq FROM aggregate_simple";
821+
let actual = execute_to_batches(&mut ctx, sql).await;
822+
823+
let expected = vec![
824+
"+-------+------+-------+",
825+
"| c3 | eq | neq |",
826+
"+-------+------+-------+",
827+
"| true | true | false |",
828+
"| false | true | false |",
829+
"| false | true | false |",
830+
"| true | true | false |",
831+
"| true | true | false |",
832+
"| true | true | false |",
833+
"| false | true | false |",
834+
"| false | true | false |",
835+
"| false | true | false |",
836+
"| false | true | false |",
837+
"| true | true | false |",
838+
"| true | true | false |",
839+
"| true | true | false |",
840+
"| true | true | false |",
841+
"| true | true | false |",
842+
"+-------+------+-------+",
843+
];
844+
assert_batches_eq!(expected, &actual);
845+
846+
Ok(())
847+
}
848+
815849
#[tokio::test]
816850
async fn csv_query_avg_sqrt() -> Result<()> {
817851
let mut ctx = create_ctx()?;
@@ -4054,6 +4088,8 @@ macro_rules! test_expression {
40544088
async fn test_boolean_expressions() -> Result<()> {
40554089
test_expression!("true", "true");
40564090
test_expression!("false", "false");
4091+
test_expression!("false = false", "true");
4092+
test_expression!("true = false", "false");
40574093
Ok(())
40584094
}
40594095

0 commit comments

Comments
 (0)