From 55b6c1ed8838fc678c9b2fddc6bdadb46cd90d84 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Oct 2021 16:07:10 -0400 Subject: [PATCH 1/4] Generic constant expression evaluation --- datafusion/src/optimizer/constant_folding.rs | 279 +++++++------- datafusion/src/optimizer/utils.rs | 381 ++++++++++++++++++- datafusion/src/test_util.rs | 46 +++ datafusion/tests/sql.rs | 50 +-- 4 files changed, 572 insertions(+), 184 deletions(-) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 4d8f06fb2844..eab936d438db 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -15,12 +15,10 @@ // specific language governing permissions and limitations // under the License. -//! Boolean comparison rule rewrites redundant comparison expression involving boolean literal into -//! unary expression. +//! Constant folding and algebraic simplification use std::sync::Arc; -use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::DataType; use crate::error::Result; @@ -30,11 +28,11 @@ use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use crate::physical_plan::functions::BuiltinScalarFunction; use crate::scalar::ScalarValue; -use arrow::compute::{kernels, DEFAULT_CAST_OPTIONS}; -/// Optimizer that simplifies comparison expressions involving boolean literals. +/// Simplifies plans by rewriting [`Expr`]`s evaluating constants +/// and applying algebraic simplifications /// -/// Recursively go through all expressions and simplify the following cases: +/// Example transformations that are applied: /// * `expr = true` and `expr != false` to `expr` when `expr` is of boolean type /// * `expr = false` and `expr != true` to `!expr` when `expr` is of boolean type /// * `true = true` and `false = false` to `true` @@ -61,14 +59,16 @@ impl OptimizerRule for ConstantFolding { // projected columns. With just the projected schema, it's not possible to infer types for // expressions that references non-projected columns within the same project plan or its // children plans. - let mut rewriter = ConstantRewriter { + let mut simplifier = Simplifier { schemas: plan.all_schemas(), execution_props, }; + let mut const_evaluator = utils::ConstEvaluator::new(); + match plan { LogicalPlan::Filter { predicate, input } => Ok(LogicalPlan::Filter { - predicate: predicate.clone().rewrite(&mut rewriter)?, + predicate: predicate.clone().rewrite(&mut simplifier)?, input: Arc::new(self.optimize(input, execution_props)?), }), // Rest: recurse into plan, apply optimization where possible @@ -95,7 +95,18 @@ impl OptimizerRule for ConstantFolding { let expr = plan .expressions() .into_iter() - .map(|e| e.rewrite(&mut rewriter)) + .map(|e| { + // TODO iterate until no changes are made + // during rewrite (evaluating constants can + // enable new simplifications and + // simplifications can enable new constant + // evaluation) + let new_e = e + // fold constants and then simplify + .rewrite(&mut const_evaluator)? 
+ .rewrite(&mut simplifier)?; + Ok(new_e) + }) .collect::>>()?; utils::from_plan(plan, &expr, &new_inputs) @@ -111,13 +122,17 @@ impl OptimizerRule for ConstantFolding { } } -struct ConstantRewriter<'a> { +/// Simplifies [`Expr`]s by applying algebraic transformation rules +/// +/// For example +/// `false && col` --> `col` where `col` is a boolean types +struct Simplifier<'a> { /// input schemas schemas: Vec<&'a DFSchemaRef>, execution_props: &'a ExecutionProps, } -impl<'a> ConstantRewriter<'a> { +impl<'a> Simplifier<'a> { fn is_boolean_type(&self, expr: &Expr) -> bool { for schema in &self.schemas { if let Ok(DataType::Boolean) = expr.get_type(schema) { @@ -129,7 +144,7 @@ impl<'a> ConstantRewriter<'a> { } } -impl<'a> ExprRewriter for ConstantRewriter<'a> { +impl<'a> ExprRewriter for Simplifier<'a> { /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result { let new_expr = match expr { @@ -204,14 +219,15 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { }, _ => Expr::BinaryExpr { left, op, right }, }, + // Not(Not(expr)) --> expr Expr::Not(inner) => { - // Not(Not(expr)) --> expr if let Expr::Not(negated_inner) = *inner { *negated_inner } else { Expr::Not(inner) } } + // convert now() --> the time in `ExecutionProps` Expr::ScalarFunction { fun: BuiltinScalarFunction::Now, .. @@ -220,56 +236,8 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { .query_execution_start_time .timestamp_nanos(), ))), - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - } => { - if !args.is_empty() { - match &args[0] { - Expr::Literal(ScalarValue::Utf8(Some(val))) => { - match string_to_timestamp_nanos(val) { - Ok(timestamp) => Expr::Literal( - ScalarValue::TimestampNanosecond(Some(timestamp)), - ), - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - }, - } - } - _ => Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - }, - } - } else { - Expr::ScalarFunction { - fun: BuiltinScalarFunction::ToTimestamp, - args, - } - } - } - Expr::Cast { - expr: inner, - data_type, - } => match inner.as_ref() { - Expr::Literal(val) => { - let scalar_array = val.to_array(); - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - &data_type, - &DEFAULT_CAST_OPTIONS, - )?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Expr::Literal(cast_scalar) - } - _ => Expr::Cast { - expr: inner, - data_type, - }, - }, expr => { - // no rewrite possible + // no additional rewrites possible expr } }; @@ -280,12 +248,13 @@ impl<'a> ExprRewriter for ConstantRewriter<'a> { #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::{ - col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder, + use crate::{ + assert_contains, + logical_plan::{col, lit, max, min, DFField, DFSchema, LogicalPlanBuilder}, }; use arrow::datatypes::*; - use chrono::{DateTime, Utc}; + use chrono::{DateTime, TimeZone, Utc}; fn test_table_scan() -> Result { let schema = Schema::new(vec![ @@ -310,7 +279,7 @@ mod tests { #[test] fn optimize_expr_not_not() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -326,7 +295,7 @@ mod tests { #[test] fn optimize_expr_null_comparison() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), 
}; @@ -362,7 +331,7 @@ mod tests { #[test] fn optimize_expr_eq() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -393,7 +362,7 @@ mod tests { #[test] fn optimize_expr_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -433,7 +402,7 @@ mod tests { #[test] fn optimize_expr_not_eq() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -469,7 +438,7 @@ mod tests { #[test] fn optimize_expr_not_eq_skip_nonboolean_type() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -505,7 +474,7 @@ mod tests { #[test] fn optimize_expr_case_when_then_else() -> Result<()> { let schema = expr_test_schema(); - let mut rewriter = ConstantRewriter { + let mut rewriter = Simplifier { schemas: vec![&schema], execution_props: &ExecutionProps::new(), }; @@ -668,6 +637,20 @@ mod tests { Ok(()) } + // expect optimizing will result in an error, returning the error string + fn get_optimized_plan_err(plan: &LogicalPlan, date_time: &DateTime) -> String { + let rule = ConstantFolding::new(); + let execution_props = ExecutionProps { + query_execution_start_time: *date_time, + }; + + let err = rule + .optimize(plan, &execution_props) + .expect_err("expected optimization to fail"); + + err.to_string() + } + fn get_optimized_plan_formatted( plan: &LogicalPlan, date_time: &DateTime, @@ -683,15 +666,19 @@ mod tests { return format!("{:?}", optimized_plan); } + /// Create a to_timestamp expr + fn to_timestamp_expr(arg: impl Into) -> Expr { + Expr::ScalarFunction { + args: vec![lit(arg.into())], + fun: BuiltinScalarFunction::ToTimestamp, + } + } + #[test] - fn to_timestamp_expr() { + fn to_timestamp_expr_folded() { let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::ScalarFunction { - args: vec![Expr::Literal(ScalarValue::Utf8(Some( - "2020-09-08T12:00:00+00:00".to_string(), - )))], - fun: BuiltinScalarFunction::ToTimestamp, - }]; + let proj = vec![to_timestamp_expr("2020-09-08T12:00:00+00:00")]; + let plan = LogicalPlanBuilder::from(table_scan) .project(proj) .unwrap() @@ -701,55 +688,30 @@ mod tests { let expected = "Projection: TimestampNanosecond(1599566400000000000)\ \n TableScan: test projection=None" .to_string(); - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); + let actual = get_optimized_plan_formatted(&plan, &Utc::now()); assert_eq!(expected, actual); } #[test] fn to_timestamp_expr_wrong_arg() { let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::ScalarFunction { - args: vec![Expr::Literal(ScalarValue::Utf8(Some( - "I'M NOT A TIMESTAMP".to_string(), - )))], - fun: BuiltinScalarFunction::ToTimestamp, - }]; - let plan = LogicalPlanBuilder::from(table_scan) - .project(proj) - .unwrap() - .build() - .unwrap(); - - let expected = "Projection: totimestamp(Utf8(\"I\'M NOT A TIMESTAMP\"))\ - \n TableScan: test projection=None"; - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); - assert_eq!(expected, actual); - } - - #[test] - fn to_timestamp_expr_no_arg() { - 
let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::ToTimestamp, - }]; + let proj = vec![to_timestamp_expr("I'M NOT A TIMESTAMP")]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj) .unwrap() .build() .unwrap(); - let expected = "Projection: totimestamp()\ - \n TableScan: test projection=None"; - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); - assert_eq!(expected, actual); + let expected = "Error parsing 'I'M NOT A TIMESTAMP' as timestamp"; + let actual = get_optimized_plan_err(&plan, &Utc::now()); + assert_contains!(actual, expected); } #[test] fn cast_expr() { let table_scan = test_table_scan().unwrap(); let proj = vec![Expr::Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("0".to_string())))), + expr: Box::new(lit("0")), data_type: DataType::Int32, }]; let plan = LogicalPlanBuilder::from(table_scan) @@ -760,7 +722,7 @@ mod tests { let expected = "Projection: Int32(0)\ \n TableScan: test projection=None"; - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); + let actual = get_optimized_plan_formatted(&plan, &Utc::now()); assert_eq!(expected, actual); } @@ -768,7 +730,7 @@ mod tests { fn cast_expr_wrong_arg() { let table_scan = test_table_scan().unwrap(); let proj = vec![Expr::Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some("".to_string())))), + expr: Box::new(lit("")), data_type: DataType::Int32, }]; let plan = LogicalPlanBuilder::from(table_scan) @@ -777,20 +739,24 @@ mod tests { .build() .unwrap(); - let expected = "Projection: Int32(NULL)\ - \n TableScan: test projection=None"; - let actual = get_optimized_plan_formatted(&plan, &chrono::Utc::now()); - assert_eq!(expected, actual); + let expected = + "Cannot cast string '' to value of arrow::datatypes::types::Int32Type type"; + let actual = get_optimized_plan_err(&plan, &Utc::now()); + assert_contains!(actual, expected); + } + + fn now_expr() -> Expr { + Expr::ScalarFunction { + args: vec![], + fun: BuiltinScalarFunction::Now, + } } #[test] fn single_now_expr() { let table_scan = test_table_scan().unwrap(); - let proj = vec![Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::Now, - }]; - let time = chrono::Utc::now(); + let proj = vec![now_expr()]; + let time = Utc::now(); let plan = LogicalPlanBuilder::from(table_scan) .project(proj) .unwrap() @@ -810,19 +776,10 @@ mod tests { #[test] fn multiple_now_expr() { let table_scan = test_table_scan().unwrap(); - let time = chrono::Utc::now(); + let time = Utc::now(); let proj = vec![ - Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::Now, - }, - Expr::Alias( - Box::new(Expr::ScalarFunction { - args: vec![], - fun: BuiltinScalarFunction::Now, - }), - "t2".to_string(), - ), + now_expr(), + Expr::Alias(Box::new(now_expr()), "t2".to_string()), ]; let plan = LogicalPlanBuilder::from(table_scan) .project(proj) @@ -830,6 +787,7 @@ mod tests { .build() .unwrap(); + // expect the same timestamp appears in both exprs let actual = get_optimized_plan_formatted(&plan, &time); let expected = format!( "Projection: TimestampNanosecond({}), TimestampNanosecond({}) AS t2\ @@ -840,4 +798,59 @@ mod tests { assert_eq!(actual, expected); } + + #[test] + fn simplify_and_eval() { + // demonstrate a case where the evaluation needs to run prior + // to the simplifier for it to work + let table_scan = test_table_scan().unwrap(); + let time = Utc::now(); + // (true or false) != col --> !col + let proj = 
vec![lit(true).or(lit(false)).not_eq(col("a"))]; + let plan = LogicalPlanBuilder::from(table_scan) + .project(proj) + .unwrap() + .build() + .unwrap(); + + let actual = get_optimized_plan_formatted(&plan, &time); + let expected = "Projection: NOT #test.a\ + \n TableScan: test projection=None"; + + assert_eq!(actual, expected); + } + + fn cast_to_int64_expr(expr: Expr) -> Expr { + Expr::Cast { + expr: expr.into(), + data_type: DataType::Int64, + } + } + + #[test] + fn now_less_than_timestamp() { + let table_scan = test_table_scan().unwrap(); + + let ts_string = "2020-09-08T12:05:00+00:00"; + let time = chrono::Utc.timestamp_nanos(1599566400000000000i64); + + // now() < cast(to_timestamp(...) as int) + 5000000000 + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + now_expr() + .lt(cast_to_int64_expr(to_timestamp_expr(ts_string)) + lit(50000)), + ) + .unwrap() + .build() + .unwrap(); + + // TODO constant folder hould be able to run again and fold + // this whole thing down + // TODO add ticket + let expected = "Filter: TimestampNanosecond(1599566400000000000) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\ + \n TableScan: test projection=None"; + let actual = get_optimized_plan_formatted(&plan, &time); + + assert_eq!(expected, actual); + } } diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 6e64bf39b2e2..47900fb0c98e 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -17,12 +17,18 @@ //! Collection of utility functions that are leveraged by the query optimizer rules +use arrow::array::new_null_array; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::record_batch::RecordBatch; + use super::optimizer::OptimizerRule; -use crate::execution::context::ExecutionProps; +use crate::execution::context::{ExecutionContextState, ExecutionProps}; use crate::logical_plan::{ - build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, Partitioning, Recursion, + build_join_schema, Column, DFSchema, DFSchemaRef, Expr, ExprRewriter, LogicalPlan, + LogicalPlanBuilder, Operator, Partitioning, Recursion, RewriteRecursion, }; +use crate::physical_plan::functions::Volatility; +use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::prelude::lit; use crate::scalar::ScalarValue; use crate::{ @@ -468,11 +474,180 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { } } +/// Partially evaluate `Expr`s so constant subtrees are evaluated at plan time. +/// +/// Note it does not handle other algebriac rewrites such as `(a and false)` --> `a` +/// +/// ``` +/// # use datafusion::prelude::*; +/// # use datafusion::optimizer::utils::ConstEvaluator; +/// let mut const_evaluator = ConstEvaluator::new(); +/// +/// // (1 + 2) + a +/// let expr = (lit(1) + lit(2)) + col("a"); +/// +/// // is rewritten to (3 + a); +/// let rewritten = expr.rewrite(&mut const_evaluator).unwrap(); +/// assert_eq!(rewritten, lit(3) + col("a")); +/// ``` +pub struct ConstEvaluator { + /// can_evaluate is used during the depth-first-search of the + /// Expr tree to track if any siblings (or their descendants) were + /// non evaluatable (e.g. had a column reference or volatile + /// function) + /// + /// Specifically, can_evaluate[N] represents the state of + /// traversal when we are N levels deep in the tree, one entry for + /// this Expr and each of its parents. 
+ /// + /// After visiting all siblings if can_evauate.top() is true, that + /// means there were no non evaluatable siblings (or their + /// descendants) so this Expr can be evaluated + can_evaluate: Vec, + + ctx_state: ExecutionContextState, + planner: DefaultPhysicalPlanner, + input_schema: DFSchema, + input_batch: RecordBatch, +} + +impl ExprRewriter for ConstEvaluator { + fn pre_visit(&mut self, expr: &Expr) -> Result { + // Default to being able to evaluate this node + self.can_evaluate.push(true); + + // if this expr is not ok to evaluate, mark entire parent + // stack as not ok (as all parents have at least one child or + // descendant that is non evaluateable + + if !Self::can_evaluate(expr) { + // walk back up stack, marking first parent that is not mutable + let parent_iter = self.can_evaluate.iter_mut().rev(); + for p in parent_iter { + if !*p { + // optimization: if we find an element on the + // stack already marked, know all elements above are also marked + break; + } + *p = false; + } + } + + // NB: do not short circuit recursion even if we find a non + // evaluatable node (so we can fold other children, args to + // functions, etc) + Ok(RewriteRecursion::Continue) + } + + fn mutate(&mut self, expr: Expr) -> Result { + if self.can_evaluate.pop().unwrap() { + let scalar = self.evaluate_to_scalar(expr)?; + Ok(Expr::Literal(scalar)) + } else { + Ok(expr) + } + } +} + +impl ConstEvaluator { + /// Create a new `ConstantEvaluator`. + pub fn new() -> Self { + let planner = DefaultPhysicalPlanner::default(); + let ctx_state = ExecutionContextState::new(); + let input_schema = DFSchema::empty(); + + // The dummy column name uis used doesn't matter as only scalar + // expressions will be evaluated + static DUMMY_COL_NAME: &str = "."; + let schema = + Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); + + let col = new_null_array(&DataType::Float64, 1); + + let input_batch = + RecordBatch::try_new(std::sync::Arc::new(schema), vec![col]).unwrap(); + + Self { + can_evaluate: vec![], + ctx_state, + planner, + input_schema, + input_batch, + } + } + + /// Can a function of the specified volatility be evaluated? + fn volatility_ok(volatility: Volatility) -> bool { + match volatility { + Volatility::Immutable => true, + // To evaluate stable functions, need ExecutionProps, see + // Simplifier for code that does that. + Volatility::Stable => false, + Volatility::Volatile => false, + } + } + + /// Can the expression be evaluated (assuming all of its children + /// can also be evaluated)? + fn can_evaluate(expr: &Expr) -> bool { + // check for reasons we can't evaluate this node + match expr { + // Has no runtime cost, but needed during planning + Expr::Alias(..) => false, + Expr::AggregateFunction { .. } => false, + Expr::AggregateUDF { .. } => false, + // TODO handle in constantant propagator pass + Expr::ScalarVariable(_) => false, + Expr::Column(_) => false, + Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), + Expr::ScalarUDF { fun, .. 
} => Self::volatility_ok(fun.signature.volatility), + _ => true, + } + } + + /// Internal helper to evaluates an Expr + fn evaluate_to_scalar(&self, expr: Expr) -> Result { + if let Expr::Literal(s) = expr { + return Ok(s); + } + + let phys_expr = self.planner.create_physical_expr( + &expr, + &self.input_schema, + &self.input_batch.schema(), + &self.ctx_state, + )?; + let col_val = phys_expr.evaluate(&self.input_batch)?; + match col_val { + crate::physical_plan::ColumnarValue::Array(a) => { + if a.len() != 1 { + Err(DataFusionError::Execution(format!( + "Could not evaluate the expressison, found a result of length {}", + a.len() + ))) + } else { + Ok(ScalarValue::try_from_array(&a, 0)?) + } + } + crate::physical_plan::ColumnarValue::Scalar(s) => Ok(s), + } + } +} + #[cfg(test)] mod tests { use super::*; - use crate::logical_plan::col; - use arrow::datatypes::DataType; + use crate::{ + logical_plan::{col, create_udf, lit_timestamp_nano}, + physical_plan::{ + functions::{make_scalar_function, BuiltinScalarFunction}, + udf::ScalarUDF, + }, + }; + use arrow::{ + array::{ArrayRef, Int32Array}, + datatypes::DataType, + }; use std::collections::HashSet; #[test] @@ -496,4 +671,200 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn test_const_evaluator() { + // true --> true + test_evaluate(lit(true), lit(true)); + // true or true --> true + test_evaluate(lit(true).or(lit(true)), lit(true)); + // true or false --> true + test_evaluate(lit(true).or(lit(false)), lit(true)); + + // "foo" == "foo" --> true + test_evaluate(lit("foo").eq(lit("foo")), lit(true)); + // "foo" != "foo" --> false + test_evaluate(lit("foo").not_eq(lit("foo")), lit(false)); + + // c = 1 --> c = 1 + test_evaluate(col("c").eq(lit(1)), col("c").eq(lit(1))); + // c = 1 + 2 --> c + 3 + test_evaluate(col("c").eq(lit(1) + lit(2)), col("c").eq(lit(3))); + // (foo != foo) OR (c = 1) --> false OR (c = 1) + test_evaluate( + (lit("foo").not_eq(lit("foo"))).or(col("c").eq(lit(1))), + lit(false).or(col("c").eq(lit(1))), + ); + } + + #[test] + fn test_const_evaluator_scalar_functions() { + // concat("foo", "bar") --> "foobar" + let expr = Expr::ScalarFunction { + args: vec![lit("foo"), lit("bar")], + fun: BuiltinScalarFunction::Concat, + }; + test_evaluate(expr, lit("foobar")); + + // ensure arguments are also constant folded + // concat("foo", concat("bar", "baz")) --> "foobarbaz" + let concat1 = Expr::ScalarFunction { + args: vec![lit("bar"), lit("baz")], + fun: BuiltinScalarFunction::Concat, + }; + let expr = Expr::ScalarFunction { + args: vec![lit("foo"), concat1], + fun: BuiltinScalarFunction::Concat, + }; + test_evaluate(expr, lit("foobarbaz")); + + // Check non string arguments + // to_timestamp("2020-09-08T12:00:00+00:00") --> timestamp(1599566400000000000i64) + let expr = Expr::ScalarFunction { + args: vec![lit("2020-09-08T12:00:00+00:00")], + fun: BuiltinScalarFunction::ToTimestamp, + }; + test_evaluate(expr, lit_timestamp_nano(1599566400000000000i64)); + + // check that non foldable arguments are folded + // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] + let expr = Expr::ScalarFunction { + args: vec![col("a")], + fun: BuiltinScalarFunction::ToTimestamp, + }; + test_evaluate(expr.clone(), expr); + + // check that non foldable arguments are folded + // to_timestamp(a) --> to_timestamp(a) [no rewrite possible] + let expr = Expr::ScalarFunction { + args: vec![col("a")], + fun: BuiltinScalarFunction::ToTimestamp, + }; + test_evaluate(expr.clone(), expr); + + // volatile / stable 
functions should not be evaluated + // rand() + (1 + 2) --> rand() + 3 + let fun = BuiltinScalarFunction::Random; + assert_eq!(fun.volatility(), Volatility::Volatile); + let rand = Expr::ScalarFunction { args: vec![], fun }; + let expr = rand.clone() + (lit(1) + lit(2)); + let expected = rand + lit(3); + test_evaluate(expr, expected); + + // parenthesization matters: can't rewrite + // (rand() + 1) + 2 --> (rand() + 1) + 2) + let fun = BuiltinScalarFunction::Random; + assert_eq!(fun.volatility(), Volatility::Volatile); + let rand = Expr::ScalarFunction { args: vec![], fun }; + let expr = (rand + lit(1)) + lit(2); + test_evaluate(expr.clone(), expr); + + // volatile / stable functions should not be evaluated + // now() + (1 + 2) --> now() + 3 + let fun = BuiltinScalarFunction::Now; + assert_eq!(fun.volatility(), Volatility::Stable); + let now = Expr::ScalarFunction { args: vec![], fun }; + let expr = now.clone() + (lit(1) + lit(2)); + let expected = now + lit(3); + test_evaluate(expr, expected); + } + + #[test] + fn test_const_evaluator_udfs() { + let args = vec![lit(1) + lit(2), lit(30) + lit(40)]; + let folded_args = vec![lit(3), lit(70)]; + + // immutable UDF should get folded + // udf_add(1+2, 30+40) --> 70 + let expr = Expr::ScalarUDF { + args: args.clone(), + fun: make_udf_add(Volatility::Immutable), + }; + test_evaluate(expr, lit(73)); + + // stable UDF should have args folded + // udf_add(1+2, 30+40) --> udf_add(3, 70) + let fun = make_udf_add(Volatility::Stable); + let expr = Expr::ScalarUDF { + args: args.clone(), + fun: Arc::clone(&fun), + }; + let expected_expr = Expr::ScalarUDF { + args: folded_args.clone(), + fun: Arc::clone(&fun), + }; + test_evaluate(expr, expected_expr); + + // volatile UDF should have args folded + // udf_add(1+2, 30+40) --> udf_add(3, 70) + let fun = make_udf_add(Volatility::Volatile); + let expr = Expr::ScalarUDF { + args, + fun: Arc::clone(&fun), + }; + let expected_expr = Expr::ScalarUDF { + args: folded_args, + fun: Arc::clone(&fun), + }; + test_evaluate(expr, expected_expr); + } + + // Make a UDF that adds its two values together, with the specified volatility + fn make_udf_add(volatility: Volatility) -> Arc { + let input_types = vec![DataType::Int32, DataType::Int32]; + let return_type = Arc::new(DataType::Int32); + + let fun = |args: &[ArrayRef]| { + let arg0 = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let arg1 = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + // 2. 
perform the computation + let array = arg0 + .iter() + .zip(arg1.iter()) + .map(|args| { + if let (Some(arg0), Some(arg1)) = args { + Some(arg0 + arg1) + } else { + // one or both args were Null + None + } + }) + .collect::(); + + Ok(Arc::new(array) as ArrayRef) + }; + + let fun = make_scalar_function(fun); + Arc::new(create_udf( + "udf_add", + input_types, + return_type, + volatility, + fun, + )) + } + + // udfs + // validate that even a volatile function's arguments will be evaluated + + fn test_evaluate(input_expr: Expr, expected_expr: Expr) { + let mut const_evaluator = ConstEvaluator::new(); + let evaluated_expr = input_expr + .clone() + .rewrite(&mut const_evaluator) + .expect("successfully evaluated"); + + assert_eq!( + evaluated_expr, expected_expr, + "Mismatch evaluating {}\n Expected:{}\n Got:{}", + input_expr, expected_expr, evaluated_expr + ); + } } diff --git a/datafusion/src/test_util.rs b/datafusion/src/test_util.rs index 0c9498acf920..03e0054b1570 100644 --- a/datafusion/src/test_util.rs +++ b/datafusion/src/test_util.rs @@ -88,6 +88,52 @@ macro_rules! assert_batches_sorted_eq { }; } +/// A macro to assert that one string is contained within another with +/// a nice error message if they are not. +/// +/// Usage: `assert_contains!(actual, expected)` +/// +/// Is a macro so test error +/// messages are on the same line as the failure; +/// +/// Both arguments must be convertable into Strings (Into) +#[macro_export] +macro_rules! assert_contains { + ($ACTUAL: expr, $EXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let expected_value: String = $EXPECTED.into(); + assert!( + actual_value.contains(&expected_value), + "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", + expected_value, + actual_value + ); + }; +} + +/// A macro to assert that one string is NOT contained within another with +/// a nice error message if they are are. +/// +/// Usage: `assert_not_contains!(actual, unexpected)` +/// +/// Is a macro so test error +/// messages are on the same line as the failure; +/// +/// Both arguments must be convertable into Strings (Into) +#[macro_export] +macro_rules! assert_not_contains { + ($ACTUAL: expr, $UNEXPECTED: expr) => { + let actual_value: String = $ACTUAL.into(); + let unexpected_value: String = $UNEXPECTED.into(); + assert!( + !actual_value.contains(&unexpected_value), + "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}", + unexpected_value, + actual_value + ); + }; +} + /// Returns the arrow test data directory, which is by default stored /// in a git submodule rooted at `testing/data`. /// diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 283033bcde4e..f7c0f82d3160 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -34,6 +34,8 @@ use arrow::{ use datafusion::assert_batches_eq; use datafusion::assert_batches_sorted_eq; +use datafusion::assert_contains; +use datafusion::assert_not_contains; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::functions::Volatility; use datafusion::physical_plan::metrics::MetricValue; @@ -47,50 +49,6 @@ use datafusion::{ }; use datafusion::{execution::context::ExecutionContext, physical_plan::displayable}; -/// A macro to assert that one string is contained within another with -/// a nice error message if they are not. 
-/// -/// Usage: `assert_contains!(actual, expected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -macro_rules! assert_contains { - ($ACTUAL: expr, $EXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let expected_value: String = $EXPECTED.into(); - assert!( - actual_value.contains(&expected_value), - "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}", - expected_value, - actual_value - ); - }; -} - -/// A macro to assert that one string is NOT contained within another with -/// a nice error message if they are are. -/// -/// Usage: `assert_not_contains!(actual, unexpected)` -/// -/// Is a macro so test error -/// messages are on the same line as the failure; -/// -/// Both arguments must be convertable into Strings (Into) -macro_rules! assert_not_contains { - ($ACTUAL: expr, $UNEXPECTED: expr) => { - let actual_value: String = $ACTUAL.into(); - let unexpected_value: String = $UNEXPECTED.into(); - assert!( - !actual_value.contains(&unexpected_value), - "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}", - unexpected_value, - actual_value - ); - }; -} - #[tokio::test] async fn nyc() -> Result<()> { // schema for nyxtaxi csv files @@ -598,7 +556,7 @@ async fn select_distinct_simple_4() { async fn select_distinct_from() { let mut ctx = ExecutionContext::new(); - let sql = "select + let sql = "select 1 IS DISTINCT FROM CAST(NULL as INT) as a, 1 IS DISTINCT FROM 1 as b, 1 IS NOT DISTINCT FROM CAST(NULL as INT) as c, @@ -621,7 +579,7 @@ async fn select_distinct_from() { async fn select_distinct_from_utf8() { let mut ctx = ExecutionContext::new(); - let sql = "select + let sql = "select 'x' IS DISTINCT FROM NULL as a, 'x' IS DISTINCT FROM 'x' as b, 'x' IS NOT DISTINCT FROM NULL as c, From da9f3d81037df619a949887a34918cfcd8665041 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 21 Oct 2021 10:39:56 -0400 Subject: [PATCH 2/4] Better list of evaluatable expressions --- datafusion/src/optimizer/utils.rs | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 47900fb0c98e..ac60a3679591 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -587,21 +587,38 @@ impl ConstEvaluator { } } - /// Can the expression be evaluated (assuming all of its children - /// can also be evaluated)? + /// Can the expression be evaluated at plan time, (assuming all of + /// its children can also be evaluated)? fn can_evaluate(expr: &Expr) -> bool { // check for reasons we can't evaluate this node + // + // NOTE all expr types are listed here so when new ones are + // added they can be checked for their ability to be evaluated + // at plan time match expr { // Has no runtime cost, but needed during planning Expr::Alias(..) => false, Expr::AggregateFunction { .. } => false, Expr::AggregateUDF { .. } => false, - // TODO handle in constantant propagator pass Expr::ScalarVariable(_) => false, Expr::Column(_) => false, Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), Expr::ScalarUDF { fun, .. } => Self::volatility_ok(fun.signature.volatility), - _ => true, + Expr::WindowFunction { .. } => false, + Expr::Sort { .. } => false, + Expr::Wildcard => false, + + Expr::Literal(_) => true, + Expr::BinaryExpr { .. 
} => true, + Expr::Not(_) => true, + Expr::IsNotNull(_) => true, + Expr::IsNull(_) => true, + Expr::Negative(_) => true, + Expr::Between { .. } => true, + Expr::Case { .. } => true, + Expr::Cast { .. } => true, + Expr::TryCast { .. } => true, + Expr::InList { .. } => true, } } From da4f1adf51e749b7228013e24dd2e387523364b5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 25 Oct 2021 14:48:03 -0400 Subject: [PATCH 3/4] Fixup comments --- datafusion/src/optimizer/constant_folding.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs index 3adbf2601969..74fdc729eb69 100644 --- a/datafusion/src/optimizer/constant_folding.rs +++ b/datafusion/src/optimizer/constant_folding.rs @@ -126,7 +126,7 @@ impl OptimizerRule for ConstantFolding { /// Simplifies [`Expr`]s by applying algebraic transformation rules /// /// For example -/// `false && col` --> `col` where `col` is a boolean types +/// `true && col` --> `col` where `col` is a boolean types struct Simplifier<'a> { /// input schemas schemas: Vec<&'a DFSchemaRef>, @@ -845,9 +845,9 @@ mod tests { .build() .unwrap(); - // TODO constant folder hould be able to run again and fold - // this whole thing down - // TODO add ticket + // Note that constant folder should be able to run again and fold + // this whole expression down to a single constant; + // https://github.com/apache/arrow-datafusion/issues/1160 let expected = "Filter: TimestampNanosecond(1599566400000000000) < CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Int64) + Int32(50000)\ \n TableScan: test projection=None"; let actual = get_optimized_plan_formatted(&plan, &time); From b2d046555fc059a33a5bb188fe43b80d607cba03 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 25 Oct 2021 14:54:10 -0400 Subject: [PATCH 4/4] Use Null type --- datafusion/src/optimizer/utils.rs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 8a2b83b1b3ba..fdc9a173ed5e 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -581,14 +581,13 @@ impl ConstEvaluator { let ctx_state = ExecutionContextState::new(); let input_schema = DFSchema::empty(); - // The dummy column name uis used doesn't matter as only scalar - // expressions will be evaluated + // The dummy column name is unused and doesn't matter as only + // expressions without column references can be evaluated static DUMMY_COL_NAME: &str = "."; - let schema = - Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Float64, true)]); - - let col = new_null_array(&DataType::Float64, 1); + let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]); + // Need a single "input" row to produce a single output row + let col = new_null_array(&DataType::Null, 1); let input_batch = RecordBatch::try_new(std::sync::Arc::new(schema), vec![col]).unwrap();
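
Editor's note: the snippet below is a minimal, hedged sketch of how the `ConstEvaluator` introduced by this patch series is meant to be used on its own, assembled from the rustdoc example and tests in the diff above. The `datafusion::optimizer::utils::ConstEvaluator` path, `ConstEvaluator::new()`, and `Expr::rewrite` are as they appear in this revision and may differ in later DataFusion releases.

```
use datafusion::prelude::*;
use datafusion::optimizer::utils::ConstEvaluator;

fn main() {
    let mut const_evaluator = ConstEvaluator::new();

    // (1 + 2) + a  -->  3 + a
    // The constant subtree is evaluated at plan time, while the column
    // reference (and anything depending on it) is left untouched.
    let expr = (lit(1) + lit(2)) + col("a");
    let rewritten = expr.rewrite(&mut const_evaluator).unwrap();
    assert_eq!(rewritten, lit(3) + col("a"));
}
```

Inside the `ConstantFolding` rule itself the evaluator runs before the (private) `Simplifier` — `e.rewrite(&mut const_evaluator)?.rewrite(&mut simplifier)?` — so that, for example, `(true OR false) != col("a")` is first folded to `true != col("a")` and then simplified to `NOT col("a")`, as exercised by the `simplify_and_eval` test above.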