diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index d14146a20d8b..2db599047bcd 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -205,10 +205,15 @@ impl CaseExpr { let mut current_value = new_null_array(&return_type, batch.num_rows()); // We only consider non-null values while comparing with whens let mut remainder = not(&base_nulls)?; + let mut non_null_remainder_count = remainder.true_count(); for i in 0..self.when_then_expr.len() { - let when_value = self.when_then_expr[i] - .0 - .evaluate_selection(batch, &remainder)?; + // If there are no rows left to process, break out of the loop early + if non_null_remainder_count == 0 { + break; + } + + let when_predicate = &self.when_then_expr[i].0; + let when_value = when_predicate.evaluate_selection(batch, &remainder)?; let when_value = when_value.into_array(batch.num_rows())?; // build boolean array representing which rows match the "when" value let when_match = compare_with_eq( @@ -224,41 +229,46 @@ impl CaseExpr { _ => Cow::Owned(prep_null_mask_filter(&when_match)), }; // Make sure we only consider rows that have not been matched yet - let when_match = and(&when_match, &remainder)?; + let when_value = and(&when_match, &remainder)?; - // When no rows available for when clause, skip then clause - if when_match.true_count() == 0 { + // If the predicate did not match any rows, continue to the next branch immediately + let when_match_count = when_value.true_count(); + if when_match_count == 0 { continue; } - let then_value = self.when_then_expr[i] - .1 - .evaluate_selection(batch, &when_match)?; + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate_selection(batch, &when_value)?; current_value = match then_value { ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_match)? 
+ nullif(current_value.as_ref(), &when_value)? } ColumnarValue::Scalar(then_value) => { - zip(&when_match, &then_value.to_scalar()?, &current_value)? + zip(&when_value, &then_value.to_scalar()?, &current_value)? } ColumnarValue::Array(then_value) => { - zip(&when_match, &then_value, &current_value)? + zip(&when_value, &then_value, &current_value)? } }; - remainder = and_not(&remainder, &when_match)?; + remainder = and_not(&remainder, &when_value)?; + non_null_remainder_count -= when_match_count; } if let Some(e) = self.else_expr() { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; // null and unmatched tuples should be assigned else value remainder = or(&base_nulls, &remainder)?; - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, &current_value)?; + + if remainder.true_count() > 0 { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + + let else_ = expr + .evaluate_selection(batch, &remainder)? 
+ .into_array(batch.num_rows())?; + current_value = zip(&remainder, &else_, &current_value)?; + } } Ok(ColumnarValue::Array(current_value)) @@ -277,10 +287,15 @@ impl CaseExpr { // start with nulls as default output let mut current_value = new_null_array(&return_type, batch.num_rows()); let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); + let mut remainder_count = batch.num_rows(); for i in 0..self.when_then_expr.len() { - let when_value = self.when_then_expr[i] - .0 - .evaluate_selection(batch, &remainder)?; + // If there are no rows left to process, break out of the loop early + if remainder_count == 0 { + break; + } + + let when_predicate = &self.when_then_expr[i].0; + let when_value = when_predicate.evaluate_selection(batch, &remainder)?; let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|_| { internal_datafusion_err!("WHEN expression did not return a BooleanArray") @@ -293,14 +308,14 @@ impl CaseExpr { // Make sure we only consider rows that have not been matched yet let when_value = and(&when_value, &remainder)?; - // When no rows available for when clause, skip then clause - if when_value.true_count() == 0 { + // If the predicate did not match any rows, continue to the next branch immediately + let when_match_count = when_value.true_count(); + if when_match_count == 0 { continue; } - let then_value = self.when_then_expr[i] - .1 - .evaluate_selection(batch, &when_value)?; + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate_selection(batch, &when_value)?; current_value = match then_value { ColumnarValue::Scalar(ScalarValue::Null) => { @@ -317,10 +332,11 @@ impl CaseExpr { // Succeed tuples should be filtered out for short-circuit evaluation, // null values for the current when expr should be kept remainder = and_not(&remainder, &when_value)?; + remainder_count -= when_match_count; } if let Some(e) = self.else_expr() { - if 
remainder.true_count() > 0 { + if remainder_count > 0 { // keep `else_expr`'s data type and return type consistent let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; let else_ = expr diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 9bc1f83ed119..2f9173d2dcbd 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -519,3 +519,79 @@ query I SELECT case when false then 1 / 0 else 1 / 1 end; ---- 1 + +# Else branch evaluation with case expression, 1 when branch, null input +query I +SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL)) t(a) +---- +1 + +# Else branch evaluation with case expression, 2 when branches, null input +query I +SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL)) t(a) +---- +2 + +# Else branch evaluation without case expression, 1 when branch, null input +query I +SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL)) t(a) +---- +1 + +# Else branch evaluation without case expression, 2 when branches, null input +query I +SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL)) t(a) +---- +2 + +# Else branch evaluation with case expression, 1 when branch, non-null input +query I +SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES ('z')) t(a) +---- +1 + +# Else branch evaluation with case expression, 2 when branches, non-null input +query I +SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES ('z')) t(a) +---- +2 + +# Else branch evaluation without case expression, 1 when branch, non-null input +query I +SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES ('z')) t(a) +---- +1 + +# Else branch evaluation without case expression, 2 when branches, non-null input +query I +SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES ('z')) t(a) +---- +2 + +# Else branch evaluation with case expression, 1 
when branch, mixed input +query I +SELECT CASE a WHEN 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL), ('z')) t(a) +---- +1 +1 + +# Else branch evaluation with case expression, 2 when branches, mixed input +query I +SELECT CASE a WHEN 'a' THEN 0 WHEN 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL), ('z')) t(a) +---- +2 +2 + +# Else branch evaluation without case expression, 1 when branch, mixed input +query I +SELECT CASE WHEN a = 'a' THEN 0 ELSE 1 END FROM (VALUES (NULL), ('z')) t(a) +---- +1 +1 + +# Else branch evaluation without case expression, 2 when branches, mixed input +query I +SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES (NULL), ('z')) t(a) +---- +2 +2 \ No newline at end of file