Skip to content
194 changes: 167 additions & 27 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
//! Push Down Filter optimizer rule ensures that filters are applied as early as possible in the plan

use crate::optimizer::ApplyOrder;
use crate::utils::{conjunction, split_conjunction};
use crate::utils::{conjunction, split_conjunction, split_conjunction_owned};
use crate::{utils, OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
use datafusion_common::{
internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::Volatility;
use datafusion_expr::{
and,
expr_rewriter::replace_col,
Expand Down Expand Up @@ -652,32 +653,60 @@ impl OptimizerRule for PushDownFilter {
child_plan.with_new_inputs(&[new_filter])?
}
LogicalPlan::Projection(projection) => {
// A projection is filter-commutable, but re-writes all predicate expressions
// A projection is filter-commutable if it do not contain volatile predicates or contain volatile
// predicates that are not used in the filter. However, we should re-writes all predicate expressions.
// collect projection.
let replace_map = projection
.schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
// strip alias, as they should not be part of filters
let expr = match &projection.expr[i] {
Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(),
expr => expr.clone(),
};

(field.qualified_name(), expr)
})
.collect::<HashMap<_, _>>();
let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) =
projection
.schema
.fields()
.iter()
.enumerate()
.map(|(i, field)| {
// strip alias, as they should not be part of filters
let expr = match &projection.expr[i] {
Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(),
expr => expr.clone(),
};

(field.qualified_name(), expr)
})
.partition(|(_, value)| is_volatile_expression(value));

// re-write all filters based on this projection
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
let new_filter = LogicalPlan::Filter(Filter::try_new(
replace_cols_by_name(filter.predicate.clone(), &replace_map)?,
projection.input.clone(),
)?);
let mut push_predicates = vec![];
let mut keep_predicates = vec![];
for expr in split_conjunction_owned(filter.predicate.clone()).into_iter()
{
if contain(&expr, &volatile_map) {
keep_predicates.push(expr);
} else {
push_predicates.push(expr);
}
}

child_plan.with_new_inputs(&[new_filter])?
match conjunction(push_predicates) {
Some(expr) => {
// re-write all filters based on this projection
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
let new_filter = LogicalPlan::Filter(Filter::try_new(
replace_cols_by_name(expr, &non_volatile_map)?,
projection.input.clone(),
)?);

match conjunction(keep_predicates) {
None => child_plan.with_new_inputs(&[new_filter])?,
Some(keep_predicate) => {
let child_plan =
child_plan.with_new_inputs(&[new_filter])?;
LogicalPlan::Filter(Filter::try_new(
keep_predicate,
Arc::new(child_plan),
)?)
}
}
}
None => return Ok(None),
}
}
LogicalPlan::Union(union) => {
let mut inputs = Vec::with_capacity(union.inputs.len());
Expand Down Expand Up @@ -881,6 +910,42 @@ pub fn replace_cols_by_name(
})
}

/// check whether the expression is volatile predicates
fn is_volatile_expression(e: &Expr) -> bool {
let mut is_volatile = false;
e.apply(&mut |expr| {
Ok(match expr {
Expr::ScalarFunction(f) if f.fun.volatility() == Volatility::Volatile => {
is_volatile = true;
VisitRecursion::Stop
}
_ => VisitRecursion::Continue,
})
})
.unwrap();
is_volatile
}

/// check whether the expression uses the columns in `check_map`.
fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {
let mut is_contain = false;
e.apply(&mut |expr| {
Ok(if let Expr::Column(c) = &expr {
match check_map.get(&c.flat_name()) {
Some(_) => {
is_contain = true;
VisitRecursion::Stop
}
None => VisitRecursion::Continue,
}
} else {
VisitRecursion::Continue
})
})
.unwrap();
is_contain
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -893,9 +958,9 @@ mod tests {
use datafusion_common::{DFSchema, DFSchemaRef};
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType,
UserDefinedLogicalNodeCore,
and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, random, sum,
BinaryExpr, Expr, Extension, LogicalPlanBuilder, Operator, TableSource,
TableType, UserDefinedLogicalNodeCore,
};
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
Expand Down Expand Up @@ -2712,4 +2777,79 @@ Projection: a, b
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}

#[test]
fn test_push_down_volatile_function_in_aggregate() -> Result<()> {
// SELECT t.a, t.r FROM (SELECT a, SUM(b), random()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5;
let table_scan = test_table_scan_with_name("test1")?;
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("a")], vec![sum(col("b"))])?
.project(vec![
col("a"),
sum(col("b")),
add(random(), lit(1)).alias("r"),
])?
.alias("t")?
.filter(col("t.a").gt(lit(5)).and(col("t.r").gt(lit(0.5))))?
.project(vec![col("t.a"), col("t.r")])?
.build()?;

let expected_before = "Projection: t.a, t.r\
\n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\
\n SubqueryAlias: t\
\n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\
\n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
\n TableScan: test1";
assert_eq!(format!("{plan:?}"), expected_before);

let expected_after = "Projection: t.a, t.r\
\n SubqueryAlias: t\
\n Filter: r > Float64(0.5)\
\n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\
\n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\
\n TableScan: test1, full_filters=[test1.a > Int32(5)]";
assert_optimized_plan_eq(&plan, expected_after)
}

#[test]
fn test_push_down_volatile_function_in_join() -> Result<()> {
// SELECT t.a, t.r FROM (SELECT test1.a AS a, random() AS r FROM test1 join test2 ON test1.a = test2.a) AS t WHERE t.r > 0.5;
let table_scan = test_table_scan_with_name("test1")?;
let left = LogicalPlanBuilder::from(table_scan).build()?;
let right_table_scan = test_table_scan_with_name("test2")?;
let right = LogicalPlanBuilder::from(right_table_scan).build()?;
let plan = LogicalPlanBuilder::from(left)
.join(
right,
JoinType::Inner,
(
vec![Column::from_qualified_name("test1.a")],
vec![Column::from_qualified_name("test2.a")],
),
None,
)?
.project(vec![col("test1.a").alias("a"), random().alias("r")])?
.alias("t")?
.filter(col("t.r").gt(lit(0.8)))?
.project(vec![col("t.a"), col("t.r")])?
.build()?;

let expected_before = "Projection: t.a, t.r\
\n Filter: t.r > Float64(0.8)\
\n SubqueryAlias: t\
\n Projection: test1.a AS a, random() AS r\
\n Inner Join: test1.a = test2.a\
\n TableScan: test1\
\n TableScan: test2";
assert_eq!(format!("{plan:?}"), expected_before);

let expected = "Projection: t.a, t.r\
\n SubqueryAlias: t\
\n Filter: r > Float64(0.8)\
\n Projection: test1.a AS a, random() AS r\
\n Inner Join: test1.a = test2.a\
\n TableScan: test1\
\n TableScan: test2";
assert_optimized_plan_eq(&plan, expected)
}
}