Skip to content
57 changes: 35 additions & 22 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,33 +311,43 @@ fn coerced_from<'a>(
type_from: &'a DataType,
) -> Option<DataType> {
use self::DataType::*;

match type_into {
// match Dictionary first
match (type_into, type_from) {
// coerced dictionary first
(cur_type, Dictionary(_, value_type)) | (Dictionary(_, value_type), cur_type)
if coerced_from(cur_type, value_type).is_some() =>
{
Some(type_into.clone())
}
// coerced into type_into
Int8 if matches!(type_from, Null | Int8) => Some(type_into.clone()),
Int16 if matches!(type_from, Null | Int8 | Int16 | UInt8) => {
(Int8, _) if matches!(type_from, Null | Int8) => Some(type_into.clone()),
(Int16, _) if matches!(type_from, Null | Int8 | Int16 | UInt8) => {
Some(type_into.clone())
}
Int32 if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => {
(Int32, _)
if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) =>
{
Some(type_into.clone())
}
Int64
(Int64, _)
if matches!(
type_from,
Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
) =>
{
Some(type_into.clone())
}
UInt8 if matches!(type_from, Null | UInt8) => Some(type_into.clone()),
UInt16 if matches!(type_from, Null | UInt8 | UInt16) => Some(type_into.clone()),
UInt32 if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => {
(UInt8, _) if matches!(type_from, Null | UInt8) => Some(type_into.clone()),
(UInt16, _) if matches!(type_from, Null | UInt8 | UInt16) => {
Some(type_into.clone())
}
(UInt32, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => {
Some(type_into.clone())
}
UInt64 if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => {
(UInt64, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => {
Some(type_into.clone())
}
Float32
(Float32, _)
if matches!(
type_from,
Null | Int8
Expand All @@ -353,7 +363,7 @@ fn coerced_from<'a>(
{
Some(type_into.clone())
}
Float64
(Float64, _)
if matches!(
type_from,
Null | Int8
Expand All @@ -371,31 +381,35 @@ fn coerced_from<'a>(
{
Some(type_into.clone())
}
Timestamp(TimeUnit::Nanosecond, None)
(Timestamp(TimeUnit::Nanosecond, None), _)
if matches!(
type_from,
Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8
) =>
{
Some(type_into.clone())
}
Interval(_) if matches!(type_from, Utf8 | LargeUtf8) => Some(type_into.clone()),
(Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => {
Some(type_into.clone())
}
// Any type can be coerced into strings
Utf8 | LargeUtf8 => Some(type_into.clone()),
Null if can_cast_types(type_from, type_into) => Some(type_into.clone()),
(Utf8 | LargeUtf8, _) => Some(type_into.clone()),
(Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),

List(_) if matches!(type_from, FixedSizeList(_, _)) => Some(type_into.clone()),
(List(_), _) if matches!(type_from, FixedSizeList(_, _)) => {
Some(type_into.clone())
}

// Only accept list and largelist with the same number of dimensions unless the type is Null.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this
List(_) | LargeList(_)
(List(_) | LargeList(_), _)
if datafusion_common::utils::base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Some(type_into.clone())
}
// should be able to coerce wildcard fixed size list to non wildcard fixed size list
FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD) => match type_from {
(FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), _) => match type_from {
FixedSizeList(f_from, size_from) => {
match coerced_from(f_into.data_type(), f_from.data_type()) {
Some(data_type) if &data_type != f_into.data_type() => {
Expand All @@ -410,7 +424,7 @@ fn coerced_from<'a>(
_ => None,
},

Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => {
(Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
match type_from {
Timestamp(_, Some(from_tz)) => {
Some(Timestamp(unit.clone(), Some(from_tz.clone())))
Expand All @@ -422,15 +436,14 @@ fn coerced_from<'a>(
_ => None,
}
}
Timestamp(_, Some(_))
(Timestamp(_, Some(_)), _)
if matches!(
type_from,
Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8
) =>
{
Some(type_into.clone())
}

// More coerce rules.
// Note that not all rules in `comparison_coercion` can be reused here.
// For example, all numeric types can be coerced into Utf8 for comparison,
Expand Down
129 changes: 89 additions & 40 deletions datafusion/optimizer/src/analyzer/function_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use super::AnalyzerRule;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::expr_rewriter::{rewrite_preserving_name, FunctionRewrite};
use datafusion_expr::utils::merge_schema;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{Expr, LogicalPlan, Subquery};
use std::sync::Arc;

/// Analyzer rule that invokes [`FunctionRewrite`]s on expressions
Expand All @@ -45,52 +46,66 @@ impl AnalyzerRule for ApplyFunctionRewrites {
}

fn analyze(&self, plan: LogicalPlan, options: &ConfigOptions) -> Result<LogicalPlan> {
self.analyze_internal(&plan, options)
analyze_internal(&plan, &self.function_rewrites, options)
}
}

impl ApplyFunctionRewrites {
fn analyze_internal(
&self,
plan: &LogicalPlan,
options: &ConfigOptions,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| self.analyze_internal(p, options))
.collect::<Result<Vec<_>>>()?;

// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());

if let LogicalPlan::TableScan(ts) = plan {
let source_schema =
DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?;
schema.merge(&source_schema);
}
fn analyze_internal(
plan: &LogicalPlan,
function_rewrites: &[Arc<dyn FunctionRewrite + Send + Sync>],
options: &ConfigOptions,
) -> Result<LogicalPlan> {
// optimize child plans first
let new_inputs = plan
.inputs()
.iter()
.map(|p| analyze_internal(p, function_rewrites, options))
.collect::<Result<Vec<_>>>()?;

let mut expr_rewrite = OperatorToFunctionRewriter {
function_rewrites: &self.function_rewrites,
options,
schema: &schema,
};
// get schema representing all available input fields. This is used for data type
// resolution only, so order does not matter here
let mut schema = merge_schema(new_inputs.iter().collect());

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;

plan.with_new_exprs(new_expr, new_inputs)
if let LogicalPlan::TableScan(ts) = plan {
let source_schema = DFSchema::try_from_qualified_schema(
ts.table_name.clone(),
&ts.source.schema(),
)?;
schema.merge(&source_schema);
}

let mut expr_rewrite = OperatorToFunctionRewriter {
function_rewrites,
options,
schema: &schema,
};

let new_expr = plan
.expressions()
.into_iter()
.map(|expr| {
// ensure names don't change:
// https://github.com/apache/arrow-datafusion/issues/3555
rewrite_preserving_name(expr, &mut expr_rewrite)
})
.collect::<Result<Vec<_>>>()?;

plan.with_new_exprs(new_expr, new_inputs)
}

fn rewrite_subquery(
mut subquery: Subquery,
function_rewrites: &[Arc<dyn FunctionRewrite + Send + Sync>],
options: &ConfigOptions,
) -> Result<Subquery> {
subquery.subquery = Arc::new(analyze_internal(
&subquery.subquery,
function_rewrites,
options,
)?);
Ok(subquery)
}

struct OperatorToFunctionRewriter<'a> {
function_rewrites: &'a [Arc<dyn FunctionRewrite + Send + Sync>],
options: &'a ConfigOptions,
Expand All @@ -111,6 +126,40 @@ impl<'a> TreeNodeRewriter for OperatorToFunctionRewriter<'a> {
expr = result.data
}

// recurse into subqueries if needed
let expr = match expr {
Expr::ScalarSubquery(subquery) => Expr::ScalarSubquery(rewrite_subquery(
subquery,
self.function_rewrites,
self.options,
)?),

Expr::Exists(Exists { subquery, negated }) => Expr::Exists(Exists {
subquery: rewrite_subquery(
subquery,
self.function_rewrites,
self.options,
)?,
negated,
}),

Expr::InSubquery(InSubquery {
expr,
subquery,
negated,
}) => Expr::InSubquery(InSubquery {
expr,
subquery: rewrite_subquery(
subquery,
self.function_rewrites,
self.options,
)?,
negated,
}),

expr => expr,
};

Ok(if transformed {
Transformed::yes(expr)
} else {
Expand Down
42 changes: 40 additions & 2 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,46 @@ SELECT COALESCE(NULL, 'test')
----
test


statement ok
create table test1 as values (arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (null);

# test coercion string
query ?
select coalesce(column1, 'none_set') from test1;
----
foo
none_set

# test coercion Int
query I
select coalesce(34, arrow_cast(123, 'Dictionary(Int32, Int8)'));
----
34

# test with Int
query I
select coalesce(arrow_cast(123, 'Dictionary(Int32, Int8)'),34);
----
123

# test with null
query I
select coalesce(null, 34, arrow_cast(123, 'Dictionary(Int32, Int8)'));
----
34

# test with null
query T
select coalesce(null, column1, 'none_set') from test1;
----
foo
none_set

statement ok
drop table test1


statement ok
CREATE TABLE test(
c1 INT,
Expand Down Expand Up @@ -2162,5 +2202,3 @@ query I
select strpos('joséésoj', arrow_cast(null, 'Utf8'));
----
NULL


Loading