Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 37 additions & 52 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,26 +171,26 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
))))
}
Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
&expr,
*expr,
&self.schema,
)?))),
Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
get_casted_expr_for_bool_op(&expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, &self.schema)?,
))),
Expr::Like(Like {
negated,
Expand Down Expand Up @@ -308,15 +308,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
ScalarFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
args,
&self.schema,
fun.signature(),
)?;
let new_expr = coerce_arguments_for_fun(
new_expr.as_slice(),
&self.schema,
&fun,
)?;
let new_expr =
coerce_arguments_for_fun(new_expr, &self.schema, &fun)?;
Ok(Transformed::yes(Expr::ScalarFunction(
ScalarFunction::new_udf(fun, new_expr),
)))
Expand All @@ -336,7 +333,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
AggregateFunctionDefinition::BuiltIn(fun) => {
let new_expr = coerce_agg_exprs_for_signature(
&fun,
&args,
args,
&self.schema,
&fun.signature(),
)?;
Expand All @@ -353,7 +350,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
AggregateFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
args,
&self.schema,
fun.signature(),
)?;
Expand Down Expand Up @@ -387,7 +384,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
expr::WindowFunctionDefinition::AggregateFunction(fun) => {
coerce_agg_exprs_for_signature(
fun,
&args,
args,
&self.schema,
&fun.signature(),
)?
Expand Down Expand Up @@ -454,12 +451,12 @@ fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result<ScalarVa
/// Downstream code uses this signal to treat these values as *unbounded*.
fn coerce_scalar_range_aware(
target_type: &DataType,
value: &ScalarValue,
value: ScalarValue,
) -> Result<ScalarValue> {
coerce_scalar(target_type, value).or_else(|err| {
coerce_scalar(target_type, &value).or_else(|err| {
// If type coercion fails, check if the largest type in family works:
if let Some(largest_type) = get_widest_type_in_family(target_type) {
coerce_scalar(largest_type, value).map_or_else(
coerce_scalar(largest_type, &value).map_or_else(
|_| exec_err!("Cannot cast {value:?} to {target_type:?}"),
|_| ScalarValue::try_from(target_type),
)
Expand All @@ -484,7 +481,7 @@ fn get_widest_type_in_family(given_type: &DataType) -> Option<&DataType> {
/// Coerces the given (window frame) `bound` to `target_type`.
fn coerce_frame_bound(
target_type: &DataType,
bound: &WindowFrameBound,
bound: WindowFrameBound,
) -> Result<WindowFrameBound> {
match bound {
WindowFrameBound::Preceding(v) => {
Expand Down Expand Up @@ -530,31 +527,30 @@ fn coerce_window_frame(
}
WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64,
};
window_frame.start_bound =
coerce_frame_bound(target_type, &window_frame.start_bound)?;
window_frame.end_bound = coerce_frame_bound(target_type, &window_frame.end_bound)?;
window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?;
window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?;
Ok(window_frame)
}

// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
// The above op will be rewrite to the binary op when creating the physical op.
fn get_casted_expr_for_bool_op(expr: &Expr, schema: &DFSchemaRef) -> Result<Expr> {
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
let left_type = expr.get_type(schema)?;
get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
cast_expr(expr, &DataType::Boolean, schema)
expr.cast_to(&DataType::Boolean, schema)
}

/// Returns `expressions` coerced to types compatible with
/// `signature`, if possible.
///
/// See the module level documentation for more detail on coercion.
fn coerce_arguments_for_signature(
expressions: &[Expr],
expressions: Vec<Expr>,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will stop all the arguments to function calls being copied, which seems good

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I always thought having less possible coercible type is preferred 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure what you mean

Perhaps you are referring to preferring &[Expr] over &Vec<Expr>? If so the difference here is that the Vec was owned (and thus we don't have to copy the contents via to_vec())

schema: &DFSchema,
signature: &Signature,
) -> Result<Vec<Expr>> {
if expressions.is_empty() {
return Ok(vec![]);
return Ok(expressions);
}

let current_types = expressions
Expand All @@ -565,58 +561,47 @@ fn coerce_arguments_for_signature(
let new_types = data_types(&current_types, signature)?;

expressions
.iter()
.into_iter()
.enumerate()
.map(|(i, expr)| cast_expr(expr, &new_types[i], schema))
.collect::<Result<Vec<_>>>()
.map(|(i, expr)| expr.cast_to(&new_types[i], schema))
.collect()
}

fn coerce_arguments_for_fun(
expressions: &[Expr],
expressions: Vec<Expr>,
schema: &DFSchema,
fun: &Arc<ScalarUDF>,
) -> Result<Vec<Expr>> {
if expressions.is_empty() {
return Ok(vec![]);
}
let mut expressions: Vec<Expr> = expressions.to_vec();

// Cast Fixedsizelist to List for array functions
if fun.name() == "make_array" {
expressions = expressions
expressions
.into_iter()
.map(|expr| {
let data_type = expr.get_type(schema).unwrap();
if let DataType::FixedSizeList(field, _) = data_type {
let field = field.as_ref().clone();
let to_type = DataType::List(Arc::new(field));
let to_type = DataType::List(field.clone());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FieldRef is an Arc<Field> -- so this clone is just an Arc::clone (rather than copying the entire filed)

expr.cast_to(&to_type, schema)
} else {
Ok(expr)
}
})
.collect::<Result<Vec<_>>>()?;
.collect()
} else {
Ok(expressions)
}

Ok(expressions)
}

/// Cast `expr` to the specified type, if possible
fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function just wrapped Expr::cast_to in an API that forced a clone, so I removed it

expr.clone().cast_to(to_type, schema)
}

/// Returns the coerced exprs for each `input_exprs`.
/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the
/// data type of `input_exprs` need to be coerced.
fn coerce_agg_exprs_for_signature(
agg_fun: &AggregateFunction,
input_exprs: &[Expr],
input_exprs: Vec<Expr>,
schema: &DFSchema,
signature: &Signature,
) -> Result<Vec<Expr>> {
if input_exprs.is_empty() {
return Ok(vec![]);
return Ok(input_exprs);
}
let current_types = input_exprs
.iter()
Expand All @@ -627,10 +612,10 @@ fn coerce_agg_exprs_for_signature(
type_coercion::aggregates::coerce_types(agg_fun, &current_types, signature)?;

input_exprs
.iter()
.into_iter()
.enumerate()
.map(|(i, expr)| cast_expr(expr, &coerced_types[i], schema))
.collect::<Result<Vec<_>>>()
.map(|(i, expr)| expr.cast_to(&coerced_types[i], schema))
.collect()
}

fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Case> {
Expand Down