diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index fecab8835e50..2d38ca21829b 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -342,6 +342,8 @@ pub fn longest_consecutive_prefix>( count } +/// Array Utils + /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` pub fn array_into_list_array(arr: ArrayRef) -> ListArray { @@ -429,6 +431,42 @@ pub fn base_type(data_type: &DataType) -> DataType { } } +/// A helper function to coerce base type in List. +/// +/// Example +/// ``` +/// use arrow::datatypes::{DataType, Field}; +/// use datafusion_common::utils::coerced_type_with_base_type_only; +/// use std::sync::Arc; +/// +/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let base_type = DataType::Float64; +/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +pub fn coerced_type_with_base_type_only( + data_type: &DataType, + base_type: &DataType, +) -> DataType { + match data_type { + DataType::List(field) => { + let data_type = match field.data_type() { + DataType::List(_) => { + coerced_type_with_base_type_only(field.data_type(), base_type) + } + _ => base_type.to_owned(), + }; + + DataType::List(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } + + _ => base_type.clone(), + } +} + /// Compute the number of dimensions in a list data type. pub fn list_ndims(data_type: &DataType) -> u64 { if let DataType::List(field) = data_type { diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fd899289ac82..289704ed98f8 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -915,10 +915,17 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArraySort => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayAppend => Signature { + type_signature: ArrayAndElement, + volatility: self.volatility(), + }, + BuiltinScalarFunction::MakeArray => { + // 0 or more arguments of arbitrary type + Signature::one_of(vec![VariadicEqual, Any(0)], self.volatility()) + } BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { @@ -958,10 +965,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayUnion => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), - BuiltinScalarFunction::MakeArray => { - // 0 or more arguments of arbitrary type - Signature::one_of(vec![VariadicAny, Any(0)], self.volatility()) - } BuiltinScalarFunction::Range => Signature::one_of( vec![ Exact(vec![Int64]), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 685601523f9b..3f07c300e196 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -91,11 +91,14 @@ pub enum TypeSignature { /// DataFusion attempts to coerce all argument types to match the first argument's type /// /// # Examples - /// A function such as `array` is `VariadicEqual` + /// Given types in signature should be coericible to the same final type. + /// A function such as `make_array` is `VariadicEqual`. + /// + /// `make_array(i32, i64) -> make_array(i64, i64)` VariadicEqual, /// One or more arguments with arbitrary types VariadicAny, - /// fixed number of arguments of an arbitrary but equal type out of a list of valid types. + /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. /// /// # Examples /// 1. A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])` @@ -113,6 +116,12 @@ pub enum TypeSignature { /// Function `make_array` takes 0 or more arguments with arbitrary types, its `TypeSignature` /// is `OneOf(vec![Any(0), VariadicAny])`. OneOf(Vec), + /// Specialized Signature for ArrayAppend and similar functions + /// The first argument should be List/LargeList, and the second argument should be non-list or list. + /// The second argument's list dimension should be one dimension less than the first argument's list dimension. + /// List dimension of the List/LargeList is equivalent to the number of List. + /// List dimension of the non-list is 0. + ArrayAndElement, } impl TypeSignature { @@ -136,11 +145,16 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::VariadicEqual => vec!["T, .., T".to_string()], + TypeSignature::VariadicEqual => { + vec!["CoercibleT, .., CoercibleT".to_string()] + } TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { sigs.iter().flat_map(|s| s.to_string_repr()).collect() } + TypeSignature::ArrayAndElement => { + vec!["ArrayAndElement(List, T)".to_string()] + } } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 79b574238495..f95a30e025b4 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,7 +21,10 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::utils::list_ndims; +use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; + +use super::binary::comparison_coercion; /// Performs type coercion for function arguments. /// @@ -86,16 +89,66 @@ fn get_valid_types( .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), TypeSignature::VariadicEqual => { - // one entry with the same len as current_types, whose type is `current_types[0]`. - vec![current_types - .iter() - .map(|_| current_types[0].clone()) - .collect()] + let new_type = current_types.iter().skip(1).try_fold( + current_types.first().unwrap().clone(), + |acc, x| { + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + ); + + match new_type { + Ok(new_type) => vec![vec![new_type; current_types.len()]], + Err(e) => return Err(e), + } } TypeSignature::VariadicAny => { vec![current_types.to_vec()] } + TypeSignature::Exact(valid_types) => vec![valid_types.clone()], + TypeSignature::ArrayAndElement => { + if current_types.len() != 2 { + return Ok(vec![vec![]]); + } + + let array_type = ¤t_types[0]; + let elem_type = ¤t_types[1]; + + // We follow Postgres on `array_append(Null, T)`, which is not valid. + if array_type.eq(&DataType::Null) { + return Ok(vec![vec![]]); + } + + // We need to find the coerced base type, mainly for cases like: + // `array_append(List(null), i64)` -> `List(i64)` + let array_base_type = datafusion_common::utils::base_type(array_type); + let elem_base_type = datafusion_common::utils::base_type(elem_type); + let new_base_type = comparison_coercion(&array_base_type, &elem_base_type); + + if new_base_type.is_none() { + return internal_err!( + "Coercion from {array_base_type:?} to {elem_base_type:?} not supported." + ); + } + let new_base_type = new_base_type.unwrap(); + + let array_type = datafusion_common::utils::coerced_type_with_base_type_only( + array_type, + &new_base_type, + ); + + if let DataType::List(ref field) = array_type { + let elem_type = field.data_type(); + return Ok(vec![vec![array_type.clone(), elem_type.to_owned()]]); + } else { + return Ok(vec![vec![]]); + } + } TypeSignature::Any(number) => { if current_types.len() != *number { return plan_err!( @@ -241,6 +294,15 @@ fn coerced_from<'a>( Utf8 | LargeUtf8 => Some(type_into.clone()), Null if can_cast_types(type_from, type_into) => Some(type_into.clone()), + // Only accept list with the same number of dimensions unless the type is Null. + // List with different dimensions should be handled in TypeSignature or other places before this. + List(_) + if datafusion_common::utils::base_type(type_from).eq(&Null) + || list_ndims(type_from) == list_ndims(type_into) => + { + Some(type_into.clone()) + } + Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 91611251d9dd..c5e1180b9f97 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -590,26 +590,6 @@ fn coerce_arguments_for_fun( .collect::>>()?; } - if *fun == BuiltinScalarFunction::MakeArray { - // Find the final data type for the function arguments - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - - let new_type = current_types - .iter() - .skip(1) - .fold(current_types.first().unwrap().clone(), |acc, x| { - comparison_coercion(&acc, x).unwrap_or(acc) - }); - - return expressions - .iter() - .zip(current_types) - .map(|(expr, from_type)| cast_array_expr(expr, &from_type, &new_type, schema)) - .collect(); - } Ok(expressions) } @@ -618,20 +598,6 @@ fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result expr.clone().cast_to(to_type, schema) } -/// Cast array `expr` to the specified type, if possible -fn cast_array_expr( - expr: &Expr, - from_type: &DataType, - to_type: &DataType, - schema: &DFSchema, -) -> Result { - if from_type.equals_datatype(&DataType::Null) { - Ok(expr.clone()) - } else { - cast_expr(expr, to_type, schema) - } -} - /// Returns the coerced exprs for each `input_exprs`. /// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the /// data type of `input_exprs` need to be coerced. diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7ccf58af832d..98c9aee8940f 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -361,7 +361,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { match data_type { // Either an empty array or all nulls: DataType::Null => { - let array = new_null_array(&DataType::Null, arrays.len()); + let array = + new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); Ok(Arc::new(array_into_list_array(array))) } DataType::LargeList(..) => array_array::(arrays, data_type), @@ -827,10 +828,14 @@ pub fn array_append(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; let element_array = &args[1]; - check_datatypes("array_append", &[list_array.values(), element_array])?; let res = match list_array.value_type() { DataType::List(_) => concat_internal(args)?, - DataType::Null => return make_array(&[element_array.to_owned()]), + DataType::Null => { + return make_array(&[ + list_array.values().to_owned(), + element_array.to_owned(), + ]); + } data_type => { return general_append_and_prepend( list_array, @@ -2284,18 +2289,4 @@ mod tests { expected_dim ); } - - #[test] - fn test_check_invalid_datatypes() { - let data = vec![Some(vec![Some(1), Some(2), Some(3)])]; - let list_array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - let int64_array = Arc::new(StringArray::from(vec![Some("string")])) as ArrayRef; - - let args = [list_array.clone(), int64_array.clone()]; - - let array = array_append(&args); - - assert_eq!(array.unwrap_err().strip_backtrace(), "Error during planning: array_append received incompatible types: '[Int64, Utf8]'."); - } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 210739aa51da..640f5064eae6 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -297,10 +297,8 @@ AS VALUES (make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), [28, 29, 30], [37, 38, 39], 10) ; -query ? +query error select [1, true, null] ----- -[1, 1, ] query error DataFusion error: This feature is not implemented: ScalarFunctions without MakeArray are not supported: now() SELECT [now()] @@ -1253,18 +1251,43 @@ select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3 ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) -# TODO: array_append with NULLs -# array_append scalar function #1 -# query ? -# select array_append(make_array(), 4); -# ---- -# [4] +# array_append with NULLs -# array_append scalar function #2 -# query ?? -# select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); -# ---- -# [[]] [[4]] +query error +select array_append(null, 1); + +query error +select array_append(null, [2, 3]); + +query error +select array_append(null, [[4]]); + +query ???? +select + array_append(make_array(), 4), + array_append(make_array(), null), + array_append(make_array(1, null, 3), 4), + array_append(make_array(null, null), 1) +; +---- +[4] [] [1, , 3, 4] [, , 1] + +# test invalid (non-null) +query error +select array_append(1, 2); + +query error +select array_append(1, [2]); + +query error +select array_append([1], [2]); + +query ?? +select + array_append(make_array(make_array(1, null, 3)), make_array(null)), + array_append(make_array(make_array(1, null, 3)), null); +---- +[[1, , 3], []] [[1, , 3], ] # array_append scalar function #3 query ???