Skip to content

Commit 4cde998

Browse files
authored
fix: don't extract common sub expr in CASE WHEN clause (#8833)
* fix: don't extract common sub expr in CASE WHEN clause * fix ci * fix
1 parent e966a10 commit 4cde998

File tree

6 files changed

+58
-48
lines changed

6 files changed

+58
-48
lines changed

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
use std::collections::{BTreeSet, HashMap};
2121
use std::sync::Arc;
2222

23+
use crate::utils::is_volatile_expression;
2324
use crate::{utils, OptimizerConfig, OptimizerRule};
2425

2526
use arrow::datatypes::DataType;
@@ -29,7 +30,7 @@ use datafusion_common::tree_node::{
2930
use datafusion_common::{
3031
internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
3132
};
32-
use datafusion_expr::expr::{is_volatile, Alias};
33+
use datafusion_expr::expr::Alias;
3334
use datafusion_expr::logical_plan::{
3435
Aggregate, Filter, LogicalPlan, Projection, Sort, Window,
3536
};
@@ -518,7 +519,7 @@ enum ExprMask {
518519
}
519520

520521
impl ExprMask {
521-
fn ignores(&self, expr: &Expr) -> Result<bool> {
522+
fn ignores(&self, expr: &Expr) -> bool {
522523
let is_normal_minus_aggregates = matches!(
523524
expr,
524525
Expr::Literal(..)
@@ -529,14 +530,12 @@ impl ExprMask {
529530
| Expr::Wildcard { .. }
530531
);
531532

532-
let is_volatile = is_volatile(expr)?;
533-
534533
let is_aggr = matches!(expr, Expr::AggregateFunction(..));
535534

536-
Ok(match self {
537-
Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr,
538-
Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates,
539-
})
535+
match self {
536+
Self::Normal => is_normal_minus_aggregates || is_aggr,
537+
Self::NormalAndAggregates => is_normal_minus_aggregates,
538+
}
540539
}
541540
}
542541

@@ -614,7 +613,12 @@ impl ExprIdentifierVisitor<'_> {
614613
impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
615614
type N = Expr;
616615

617-
fn pre_visit(&mut self, _expr: &Expr) -> Result<VisitRecursion> {
616+
fn pre_visit(&mut self, expr: &Expr) -> Result<VisitRecursion> {
617+
// related to https://github.com/apache/arrow-datafusion/issues/8814
618+
// If the expr contain volatile expression or is a case expression, skip it.
619+
if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? {
620+
return Ok(VisitRecursion::Skip);
621+
}
618622
self.visit_stack
619623
.push(VisitRecord::EnterMark(self.node_count));
620624
self.node_count += 1;
@@ -628,7 +632,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
628632

629633
let (idx, sub_expr_desc) = self.pop_enter_mark();
630634
// skip exprs should not be recognize.
631-
if self.expr_mask.ignores(expr)? {
635+
if self.expr_mask.ignores(expr) {
632636
self.id_array[idx].0 = self.series_number;
633637
let desc = Self::desc_expr(expr);
634638
self.visit_stack.push(VisitRecord::ExprItem(desc));

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use std::collections::{HashMap, HashSet};
1919
use std::sync::Arc;
2020

2121
use crate::optimizer::ApplyOrder;
22+
use crate::utils::is_volatile_expression;
2223
use crate::{OptimizerConfig, OptimizerRule};
2324

2425
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
@@ -34,7 +35,7 @@ use datafusion_expr::logical_plan::{
3435
use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned};
3536
use datafusion_expr::{
3637
and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator,
37-
ScalarFunctionDefinition, TableProviderFilterPushDown, Volatility,
38+
ScalarFunctionDefinition, TableProviderFilterPushDown,
3839
};
3940

4041
use itertools::Itertools;
@@ -739,7 +740,9 @@ impl OptimizerRule for PushDownFilter {
739740

740741
(field.qualified_name(), expr)
741742
})
742-
.partition(|(_, value)| is_volatile_expression(value));
743+
.partition(|(_, value)| {
744+
is_volatile_expression(value).unwrap_or(true)
745+
});
743746

744747
let mut push_predicates = vec![];
745748
let mut keep_predicates = vec![];
@@ -1028,38 +1031,6 @@ pub fn replace_cols_by_name(
10281031
})
10291032
}
10301033

1031-
/// check whether the expression is volatile predicates
1032-
fn is_volatile_expression(e: &Expr) -> bool {
1033-
let mut is_volatile = false;
1034-
e.apply(&mut |expr| {
1035-
Ok(match expr {
1036-
Expr::ScalarFunction(f) => match &f.func_def {
1037-
ScalarFunctionDefinition::BuiltIn(fun)
1038-
if fun.volatility() == Volatility::Volatile =>
1039-
{
1040-
is_volatile = true;
1041-
VisitRecursion::Stop
1042-
}
1043-
ScalarFunctionDefinition::UDF(fun)
1044-
if fun.signature().volatility == Volatility::Volatile =>
1045-
{
1046-
is_volatile = true;
1047-
VisitRecursion::Stop
1048-
}
1049-
ScalarFunctionDefinition::Name(_) => {
1050-
return internal_err!(
1051-
"Function `Expr` with name should be resolved."
1052-
);
1053-
}
1054-
_ => VisitRecursion::Continue,
1055-
},
1056-
_ => VisitRecursion::Continue,
1057-
})
1058-
})
1059-
.unwrap();
1060-
is_volatile
1061-
}
1062-
10631034
/// check whether the expression uses the columns in `check_map`.
10641035
fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
10651036
let mut is_contain = false;

datafusion/optimizer/src/utils.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
//! Collection of utility functions that are leveraged by the query optimizer rules
1919
2020
use crate::{OptimizerConfig, OptimizerRule};
21+
use datafusion_common::tree_node::{TreeNode, VisitRecursion};
2122
use datafusion_common::{Column, DFSchemaRef};
2223
use datafusion_common::{DFSchema, Result};
24+
use datafusion_expr::expr::is_volatile;
2325
use datafusion_expr::expr_rewriter::replace_col;
2426
use datafusion_expr::utils as expr_utils;
2527
use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator};
@@ -92,6 +94,20 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) {
9294
trace!("{description}::\n{}\n", plan.display_indent_schema());
9395
}
9496

97+
/// check whether the expression is volatile predicates
98+
pub(crate) fn is_volatile_expression(e: &Expr) -> Result<bool> {
99+
let mut is_volatile_expr = false;
100+
e.apply(&mut |expr| {
101+
Ok(if is_volatile(expr)? {
102+
is_volatile_expr = true;
103+
VisitRecursion::Stop
104+
} else {
105+
VisitRecursion::Continue
106+
})
107+
})?;
108+
Ok(is_volatile_expr)
109+
}
110+
95111
/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
96112
///
97113
/// See [`split_conjunction_owned`] for more details and an example.

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,6 @@ NULL
998998

999999
# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
10001000
query B
1001-
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0)
1001+
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0)
10021002
----
10031003
false

datafusion/sqllogictest/test_files/select.slt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,3 +1112,22 @@ SELECT abs(x), abs(x) + abs(y) FROM t;
11121112

11131113
statement ok
11141114
DROP TABLE t;
1115+
1116+
# related to https://github.com/apache/arrow-datafusion/issues/8814
1117+
statement ok
1118+
create table t(x int, y int) as values (1,1), (2,2), (3,3), (0,0), (4,0);
1119+
1120+
query II
1121+
SELECT
1122+
CASE WHEN B.x > 0 THEN A.x / B.x ELSE 0 END AS value1,
1123+
CASE WHEN B.x > 0 AND B.y > 0 THEN A.x / B.x ELSE 0 END AS value3
1124+
FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B;
1125+
----
1126+
0 0
1127+
0 0
1128+
0 0
1129+
0 0
1130+
0 0
1131+
1132+
statement ok
1133+
DROP TABLE t;

datafusion/sqllogictest/test_files/tpch/q14.slt.part

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ where
3333
----
3434
logical_plan
3535
Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue
36-
--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
37-
----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, part.p_type
36+
--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
37+
----Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_type
3838
------Inner Join: lineitem.l_partkey = part.p_partkey
3939
--------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount
4040
----------Filter: lineitem.l_shipdate >= Date32("9374") AND lineitem.l_shipdate < Date32("9404")
@@ -45,7 +45,7 @@ ProjectionExec: expr=[100 * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%")
4545
--AggregateExec: mode=Final, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
4646
----CoalescePartitionsExec
4747
------AggregateExec: mode=Partial, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
48-
--------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, p_type@4 as p_type]
48+
--------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, p_type@4 as p_type]
4949
----------CoalesceBatchesExec: target_batch_size=8192
5050
------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)]
5151
--------------CoalesceBatchesExec: target_batch_size=8192

0 commit comments

Comments
 (0)