Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 34 additions & 65 deletions datafusion/optimizer/src/analyzer/function_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

use super::AnalyzerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::tree_node::{Transformed, TreeNode};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};

use crate::utils::NamePreserver;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::LogicalPlan;
use std::sync::Arc;

/// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
Expand All @@ -37,86 +39,53 @@ impl ApplyFunctionRewrites {
pub fn new(function_rewrites: Vec<Arc<dyn FunctionRewrite + Send + Sync>>) -> Self {
Self { function_rewrites }
}
}

impl AnalyzerRule for ApplyFunctionRewrites {
fn name(&self) -> &str {
"apply_function_rewrites"
}

fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
self.analyze_internal(&plan, options)
}
}

impl ApplyFunctionRewrites {
fn analyze_internal(
/// Rewrite a single plan, and all its expressions using the provided rewriters
fn rewrite_plan(
&self,
plan: &LogicalPlan,
plan: LogicalPlan,
options: &ConfigOptions,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| self.analyze_internal(p, options))
.collect::<Result<Vec<_>>>()?;

) -> Result<Transformed<LogicalPlan>> {
// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());
let mut schema = merge_schema(plan.inputs());

if let LogicalPlan::TableScan(ts) = plan {
if let LogicalPlan::TableScan(ts) = &plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
)?;
schema.merge(&source_schema);
}

let mut expr_rewrite = OperatorToFunctionRewriter {
function_rewrites: &self.function_rewrites,
options,
schema: &schema,
};
let name_preserver = NamePreserver::new(&plan);

plan.map_expressions(|expr| {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using the new map_expressions API it is quite straightforward to rewrite these expressions (and it doesn't copy them!)

let original_name = name_preserver.save(&expr)?;

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;
// recursively transform the expression, applying the rewrites at each step
let result = expr.transform_up(&|expr| {
let mut result = Transformed::no(expr);
for rewriter in self.function_rewrites.iter() {
result = result.transform_data(|expr| {
rewriter.rewrite(expr, &schema, options)
})?;
}
Ok(result)
})?;

plan.with_new_exprs(new_expr, new_inputs)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this copies the plan + expressions

result.map_data(|expr| original_name.restore(expr))
})
}
}
struct OperatorToFunctionRewriter<'a> {
function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
options: &'a ConfigOptions,
schema: &'a DFSchema,
}

impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
type Node = Expr;

fn f_up(&mut self, mut expr: Expr) -> Result<Transformed<Expr>> {
// apply transforms one by one
let mut transformed = false;
for rewriter in self.function_rewrites.iter() {
let result = rewriter.rewrite(expr, self.schema, self.options)?;
if result.transformed {
transformed = true;
}
expr = result.data
}
impl AnalyzerRule for ApplyFunctionRewrites {
fn name(&self) -> &str {
"apply_function_rewrites"
}

Ok(if transformed {
Transformed::yes(expr)
} else {
Transformed::no(expr)
})
fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
plan.transform_up_with_subqueries(&|plan| self.rewrite_plan(plan, options))
.map(|res| res.data)
}
}
44 changes: 44 additions & 0 deletions datafusion/optimizer/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,47 @@ pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema {
expr_utils::merge_schema(inputs)
}

/// Handles ensuring the name of rewritten expressions is not changed.
///
/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the
/// expression should be preserved: `3 as "1 + 2"`
///
/// See <https://github.com/apache/arrow-datafusion/issues/3555> for details
pub struct NamePreserver {
use_alias: bool,
}

/// If the name of an expression is remembered, it will be preserved when
/// rewriting the expression
pub struct SavedName(Option<String>);

impl NamePreserver {
/// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan
pub fn new(plan: &LogicalPlan) -> Self {
Self {
use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)),
}
}

pub fn save(&self, expr: &Expr) -> Result<SavedName> {
let original_name = if self.use_alias {
Some(expr.name_for_alias()?)
} else {
None
};

Ok(SavedName(original_name))
}
}

impl SavedName {
/// Ensures the name of the rewritten expression is preserved
pub fn restore(self, expr: Expr) -> Result<Expr> {
let Self(original_name) = self;
match original_name {
Some(name) => expr.alias_if_changed(name),
None => Ok(expr),
}
}
}
55 changes: 55 additions & 0 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1060,3 +1060,58 @@ logical_plan
Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1)
--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a
----TableScan: t projection=[a]

###
## Ensure that operators are rewritten in subqueries
###

statement ok
create table foo(x int) as values (1);

# Show input data
query ?
select struct(1, 'b')
----
{c0: 1, c1: b}


query T
select (select struct(1, 'b')['c1']);
----
b

query T
select 'foo' || (select struct(1, 'b')['c1']);
----
foob

query I
SELECT * FROM (VALUES (1), (2))
WHERE column1 IN (SELECT struct(1, 'b')['c0']);
----
1

# also add an expression so the subquery is the output expr
query I
SELECT * FROM (VALUES (1), (2))
WHERE 1+2 = 3 AND column1 IN (SELECT struct(1, 'b')['c0']);
----
1


query I
SELECT * FROM foo
WHERE EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
----
1

# also add an expression so the subquery is the output expr
query I
SELECT * FROM foo
WHERE 1+2 = 3 AND EXISTS (SELECT * FROM (values (1)) WHERE column1 = foo.x AND struct(1, 'b')['c0'] = 1);
----
1


statement ok
drop table foo;