diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 3daf347ae4ff..1bbb7b0206a5 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1839,7 +1839,7 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let result = ScalarValue::new_list(&scalars, &DataType::Int32); + /// let result = ScalarValue::new_list(&scalars, &DataType::Int32, true); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ @@ -1848,13 +1848,25 @@ impl ScalarValue { /// /// assert_eq!(*result, expected); /// ``` - pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> Arc { + pub fn new_list( + values: &[ScalarValue], + data_type: &DataType, + nullable: bool, + ) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { Self::iter_to_array(values.iter().cloned()).unwrap() }; - Arc::new(array_into_list_array(values)) + Arc::new(array_into_list_array(values, nullable)) + } + + /// Same as [`ScalarValue::new_list`] but with nullable set to true. + pub fn new_list_nullable( + values: &[ScalarValue], + data_type: &DataType, + ) -> Arc { + Self::new_list(values, data_type, true) } /// Converts `IntoIterator` where each element has type corresponding to @@ -1873,7 +1885,7 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32); + /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32, true); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ @@ -1885,13 +1897,14 @@ impl ScalarValue { pub fn new_list_from_iter( values: impl IntoIterator + ExactSizeIterator, data_type: &DataType, + nullable: bool, ) -> Arc { let values = if values.len() == 0 { new_empty_array(data_type) } else { Self::iter_to_array(values).unwrap() }; - Arc::new(array_into_list_array(values)) + Arc::new(array_into_list_array(values, nullable)) } /// Converts `Vec` where each element has type corresponding to @@ -2305,7 +2318,7 @@ impl ScalarValue { /// use datafusion_common::ScalarValue; /// use arrow::array::ListArray; /// use arrow::datatypes::{DataType, Int32Type}; - /// use datafusion_common::utils::array_into_list_array; + /// use datafusion_common::utils::array_into_list_array_nullable; /// use std::sync::Arc; /// /// let list_arr = ListArray::from_iter_primitive::(vec![ @@ -2314,7 +2327,7 @@ impl ScalarValue { /// ]); /// /// // Wrap into another layer of list, we got nested array as [ [[1,2,3], [4,5]] ] - /// let list_arr = array_into_list_array(Arc::new(list_arr)); + /// let list_arr = array_into_list_array_nullable(Arc::new(list_arr)); /// /// // Convert the array into Scalar Values for each row, we got 1D arrays in this example /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); @@ -2400,11 +2413,12 @@ impl ScalarValue { typed_cast!(array, index, LargeStringArray, LargeUtf8)? } DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?, - DataType::List(_) => { + DataType::List(field) => { let list_array = array.as_list::(); let nested_array = list_array.value(index); // Produces a single element `ListArray` with the value at `index`. - let arr = Arc::new(array_into_list_array(nested_array)); + let arr = + Arc::new(array_into_list_array(nested_array, field.is_nullable())); ScalarValue::List(arr) } @@ -3499,6 +3513,7 @@ mod tests { }; use crate::assert_batches_eq; + use crate::utils::array_into_list_array_nullable; use arrow::buffer::OffsetBuffer; use arrow::compute::{is_null, kernels}; use arrow::util::pretty::pretty_format_columns; @@ -3646,9 +3661,9 @@ mod tests { ScalarValue::from("data-fusion"), ]; - let result = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + let result = ScalarValue::new_list_nullable(scalars.as_slice(), &DataType::Utf8); - let expected = array_into_list_array(Arc::new(StringArray::from(vec![ + let expected = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ "rust", "arrow", "data-fusion", @@ -3860,10 +3875,12 @@ mod tests { #[test] fn iter_to_array_string_test() { - let arr1 = - array_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let arr2 = - array_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"]))); + let arr1 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ + "foo", "bar", "baz", + ]))); + let arr2 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ + "rust", "world", + ]))); let scalars = vec![ ScalarValue::List(Arc::new(arr1)), @@ -4270,7 +4287,7 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array = ScalarValue::new_list(&[], &DataType::UInt64); + let list_array = ScalarValue::new_list_nullable(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -4291,7 +4308,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array = ScalarValue::new_list(&values, &DataType::UInt64); + let list_array = ScalarValue::new_list_nullable(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -5216,13 +5233,13 @@ mod tests { // Define list-of-structs scalars let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap(); - let nl0 = ScalarValue::List(Arc::new(array_into_list_array(nl0_array))); + let nl0 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl0_array))); let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap(); - let nl1 = ScalarValue::List(Arc::new(array_into_list_array(nl1_array))); + let nl1 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl1_array))); let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap(); - let nl2 = ScalarValue::List(Arc::new(array_into_list_array(nl2_array))); + let nl2 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl2_array))); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); @@ -6008,7 +6025,7 @@ mod tests { #[test] fn test_build_timestamp_millisecond_list() { let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; - let arr = ScalarValue::new_list( + let arr = ScalarValue::new_list_nullable( &values, &DataType::Timestamp(TimeUnit::Millisecond, None), ); @@ -6019,7 +6036,7 @@ mod tests { fn test_newlist_timestamp_zone() { let s: &'static str = "UTC"; let values = vec![ScalarValue::TimestampMillisecond(Some(1), Some(s.into()))]; - let arr = ScalarValue::new_list( + let arr = ScalarValue::new_list_nullable( &values, &DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), ); diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index a0e4d1a76c03..dd7b80333cf8 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -351,10 +351,19 @@ pub fn longest_consecutive_prefix>( /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -pub fn array_into_list_array(arr: ArrayRef) -> ListArray { +/// The field in the list array is nullable. +pub fn array_into_list_array_nullable(arr: ArrayRef) -> ListArray { + array_into_list_array(arr, true) +} + +/// Array Utils + +/// Wrap an array into a single element `ListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); ListArray::new( - Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), nullable)), offsets, arr, None, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 700bc53af72e..c3bc2fcca2b5 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -1386,7 +1386,7 @@ async fn unnest_with_redundant_columns() -> Result<()> { let expected = vec![ "Projection: shapes.shape_id [shape_id:UInt32]", " Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]", - " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]", " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", ]; diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index c3139f6fcdfb..e503b74992c3 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -36,7 +36,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { *actual[0].schema(), Schema::new(vec![Field::new_list( "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", - Field::new("item", DataType::UInt32, true), + Field::new("item", DataType::UInt32, false), false ),]) ); diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 2997aef58e55..b17e4294a1ef 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -86,7 +86,11 @@ impl AggregateFunction { /// Returns the datatype of the aggregate function given its argument types /// /// This is used to get the returned data type for aggregate expr. - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + pub fn return_type( + &self, + input_expr_types: &[DataType], + input_expr_nullable: &[bool], + ) -> Result { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. @@ -113,12 +117,23 @@ impl AggregateFunction { AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", coerced_data_types[0].clone(), - true, + input_expr_nullable[0], )))), AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), } } + + /// Returns if the return type of the aggregate function is nullable given its argument + /// nullability + pub fn nullable(&self) -> Result { + match self { + AggregateFunction::Max | AggregateFunction::Min => Ok(true), + AggregateFunction::ArrayAgg => Ok(false), + AggregateFunction::Grouping => Ok(true), + AggregateFunction::NthValue => Ok(true), + } + } } impl AggregateFunction { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 409ccf7b47c7..6d0202ec3548 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -708,10 +708,14 @@ pub enum WindowFunctionDefinition { impl WindowFunctionDefinition { /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + pub fn return_type( + &self, + input_expr_types: &[DataType], + input_expr_nullable: &[bool], + ) -> Result { match self { WindowFunctionDefinition::AggregateFunction(fun) => { - fun.return_type(input_expr_types) + fun.return_type(input_expr_types, input_expr_nullable) } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { fun.return_type(input_expr_types) @@ -2180,10 +2184,10 @@ mod test { #[test] fn test_first_value_return_type() -> Result<()> { let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::UInt64])?; + let observed = fun.return_type(&[DataType::UInt64], &[true])?; assert_eq!(DataType::UInt64, observed); Ok(()) @@ -2192,10 +2196,10 @@ mod test { #[test] fn test_last_value_return_type() -> Result<()> { let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &[true])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2204,10 +2208,10 @@ mod test { #[test] fn test_lead_return_type() -> Result<()> { let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &[true])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2216,10 +2220,10 @@ mod test { #[test] fn test_lag_return_type() -> Result<()> { let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &[true])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2228,10 +2232,12 @@ mod test { #[test] fn test_nth_value_return_type() -> Result<()> { let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + let observed = + fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true])?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; + let observed = + fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2240,7 +2246,7 @@ mod test { #[test] fn test_percent_rank_return_type() -> Result<()> { let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; + let observed = fun.return_type(&[], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2249,7 +2255,7 @@ mod test { #[test] fn test_cume_dist_return_type() -> Result<()> { let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; + let observed = fun.return_type(&[], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2258,7 +2264,7 @@ mod test { #[test] fn test_ntile_return_type() -> Result<()> { let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16])?; + let observed = fun.return_type(&[DataType::Int16], &[true])?; assert_eq!(DataType::UInt64, observed); Ok(()) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 986f85adebaa..d5a04ad4ae1f 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -160,6 +160,10 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; + let nullability = args + .iter() + .map(|e| e.nullable(schema)) + .collect::>>()?; match fun { WindowFunctionDefinition::AggregateUDF(udf) => { let new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| { @@ -173,10 +177,10 @@ impl ExprSchemable for Expr { ) ) })?; - Ok(fun.return_type(&new_types)?) + Ok(fun.return_type(&new_types, &nullability)?) } _ => { - fun.return_type(&data_types) + fun.return_type(&data_types, &nullability) } } } @@ -185,9 +189,13 @@ impl ExprSchemable for Expr { .iter() .map(|e| e.get_type(schema)) .collect::>>()?; + let nullability = args + .iter() + .map(|e| e.nullable(schema)) + .collect::>>()?; match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { - fun.return_type(&data_types) + fun.return_type(&data_types, &nullability) } AggregateFunctionDefinition::UDF(fun) => { let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| { @@ -314,11 +322,17 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), + Expr::AggregateFunction(AggregateFunction { func_def, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => fun.nullable(), + // TODO: UDF should be able to customize nullability + AggregateFunctionDefinition::UDF(_) => Ok(true), + } + } Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } - | Expr::AggregateFunction { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) @@ -343,9 +357,12 @@ impl ExprSchemable for Expr { | Expr::SimilarTo(Like { expr, pattern, .. }) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) } - Expr::Wildcard { .. } => internal_err!( - "Wildcard expressions are not valid in a logical query plan" - ), + Expr::Wildcard { qualifier } => match qualifier { + Some(_) => internal_err!( + "QualifiedWildcard expressions are not valid in a logical query plan" + ), + None => Ok(false), + }, Expr::GroupingSet(_) => { // grouping sets do not really have the concept of nullable and do not appear // in projections diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs index 19e24f547d8a..ba9964270443 100644 --- a/datafusion/functions-aggregate/src/bit_and_or_xor.rs +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -440,7 +440,7 @@ where .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) .collect::>>()?; - let arr = ScalarValue::new_list(&values, &T::DATA_TYPE); + let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE); vec![ScalarValue::List(arr)] }; Ok(state_out) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 84abc0d73098..0fc8e32d7240 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -500,7 +500,8 @@ impl Accumulator for DistinctCountAccumulator { /// Returns the distinct values seen so far as (one element) ListArray. fn state(&mut self) -> Result> { let scalars = self.values.iter().cloned().collect::>(); - let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); + let arr = + ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type); Ok(vec![ScalarValue::List(arr)]) } diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs index c8bc78ac2dcd..bb926b8da271 100644 --- a/datafusion/functions-aggregate/src/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -180,7 +180,7 @@ impl Accumulator for MedianAccumulator { .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) .collect::>>()?; - let arr = ScalarValue::new_list(&all_values, &self.data_type); + let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type); Ok(vec![ScalarValue::List(arr)]) } @@ -237,7 +237,7 @@ impl Accumulator for DistinctMedianAccumulator { .map(|x| ScalarValue::new_primitive::(Some(x.0), &self.data_type)) .collect::>>()?; - let arr = ScalarValue::new_list(&all_values, &self.data_type); + let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type); Ok(vec![ScalarValue::List(arr)]) } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index b9293bc2ca28..a9f31dc05be9 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -384,7 +384,7 @@ impl Accumulator for DistinctSumAccumulator { }) .collect::>>()?; - vec![ScalarValue::List(ScalarValue::new_list( + vec![ScalarValue::List(ScalarValue::new_list_nullable( &distinct_values, &self.data_type, ))] diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-array/src/make_array.rs index 0159d4ac0829..79858041d3ca 100644 --- a/datafusion/functions-array/src/make_array.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -27,7 +27,7 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; use datafusion_common::internal_err; -use datafusion_common::{plan_err, utils::array_into_list_array, Result}; +use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::TypeSignature; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -155,7 +155,7 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { let length = arrays.iter().map(|a| a.len()).sum(); // By default Int64 let array = new_null_array(&DataType::Int64, length); - Ok(Arc::new(array_into_list_array(array))) + Ok(Arc::new(array_into_list_array_nullable(array))) } LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index 00a6a68f7aac..3ecccf3c8713 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -262,7 +262,7 @@ pub(super) fn get_arg_name(args: &[Expr], i: usize) -> String { mod tests { use super::*; use arrow::datatypes::Int64Type; - use datafusion_common::utils::array_into_list_array; + use datafusion_common::utils::array_into_list_array_nullable; /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] @@ -277,8 +277,10 @@ mod tests { Some(vec![Some(6), Some(7), Some(8)]), ])); - let array2d_1 = Arc::new(array_into_list_array(array1d_1.clone())) as ArrayRef; - let array2d_2 = Arc::new(array_into_list_array(array1d_2.clone())) as ArrayRef; + let array2d_1 = + Arc::new(array_into_list_array_nullable(array1d_1.clone())) as ArrayRef; + let array2d_2 = + Arc::new(array_into_list_array_nullable(array1d_2.clone())) as ArrayRef; let res = align_array_dimensions::(vec![ array1d_1.to_owned(), @@ -294,8 +296,8 @@ mod tests { expected_dim ); - let array3d_1 = Arc::new(array_into_list_array(array2d_1)) as ArrayRef; - let array3d_2 = array_into_list_array(array2d_2.to_owned()); + let array3d_1 = Arc::new(array_into_list_array_nullable(array2d_1)) as ArrayRef; + let array3d_2 = array_into_list_array_nullable(array2d_2.to_owned()); let res = align_array_dimensions::(vec![array1d_1, Arc::new(array3d_2.clone())]) .unwrap(); diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs index 5c888ca66caa..27094b0c819a 100644 --- a/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs @@ -20,7 +20,7 @@ use crate::binary_map::{ArrowBytesSet, OutputType}; use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; use std::fmt::Debug; @@ -47,7 +47,7 @@ impl Accumulator for BytesDistinctCountAccumulator { fn state(&mut self) -> datafusion_common::Result> { let set = self.0.take(); let arr = set.into_state(); - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs index 72b83676e81d..e525118b9a17 100644 --- a/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs @@ -32,7 +32,7 @@ use arrow::array::PrimitiveArray; use arrow::datatypes::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; @@ -72,7 +72,7 @@ where PrimitiveArray::::from_iter_values(self.values.iter().cloned()) .with_data_type(self.data_type.clone()), ); - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } @@ -160,7 +160,7 @@ where let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().map(|v| v.0), )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } diff --git a/datafusion/physical-expr-common/src/aggregate/tdigest.rs b/datafusion/physical-expr-common/src/aggregate/tdigest.rs index 5107d0ab8e52..1da3d7180d84 100644 --- a/datafusion/physical-expr-common/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs @@ -576,7 +576,7 @@ impl TDigest { .map(|v| ScalarValue::Float64(Some(v))) .collect(); - let arr = ScalarValue::new_list(¢roids, &DataType::Float64); + let arr = ScalarValue::new_list_nullable(¢roids, &DataType::Float64); vec![ ScalarValue::UInt64(Some(self.max_size as u64)), diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs index a23ba07de44a..c5a0662a2283 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg.rs @@ -70,22 +70,23 @@ impl AggregateExpr for ArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )) } fn create_accumulator(&self) -> Result> { Ok(Box::new(ArrayAggAccumulator::try_new( &self.input_data_type, + self.nullable, )?)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )]) } @@ -115,14 +116,16 @@ impl PartialEq for ArrayAgg { pub(crate) struct ArrayAggAccumulator { values: Vec, datatype: DataType, + nullable: bool, } impl ArrayAggAccumulator { /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new(datatype: &DataType, nullable: bool) -> Result { Ok(Self { values: vec![], datatype: datatype.clone(), + nullable, }) } } @@ -164,12 +167,12 @@ impl Accumulator for ArrayAggAccumulator { self.values.iter().map(|a| a.as_ref()).collect(); if element_arrays.is_empty() { - let arr = ScalarValue::new_list(&[], &self.datatype); + let arr = ScalarValue::new_list(&[], &self.datatype, self.nullable); return Ok(ScalarValue::List(arr)); } let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array(concated_array); + let list_array = array_into_list_array(concated_array, self.nullable); Ok(ScalarValue::List(Arc::new(list_array))) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index 244a44acdcb5..fc838196de20 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -74,22 +74,23 @@ impl AggregateExpr for DistinctArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )) } fn create_accumulator(&self) -> Result> { Ok(Box::new(DistinctArrayAggAccumulator::try_new( &self.input_data_type, + self.nullable, )?)) } fn state_fields(&self) -> Result> { Ok(vec![Field::new_list( format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )]) } @@ -119,13 +120,15 @@ impl PartialEq for DistinctArrayAgg { struct DistinctArrayAggAccumulator { values: HashSet, datatype: DataType, + nullable: bool, } impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType) -> Result { + pub fn try_new(datatype: &DataType, nullable: bool) -> Result { Ok(Self { values: HashSet::new(), datatype: datatype.clone(), + nullable, }) } } @@ -162,7 +165,7 @@ impl Accumulator for DistinctArrayAggAccumulator { fn evaluate(&mut self) -> Result { let values: Vec = self.values.iter().cloned().collect(); - let arr = ScalarValue::new_list(&values, &self.datatype); + let arr = ScalarValue::new_list(&values, &self.datatype, self.nullable); Ok(ScalarValue::List(arr)) } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 837a9d551153..1234ab40c188 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -91,8 +91,8 @@ impl AggregateExpr for OrderSensitiveArrayAgg { Ok(Field::new_list( &self.name, // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, + Field::new("item", self.input_data_type.clone(), self.nullable), + false, )) } @@ -102,6 +102,7 @@ impl AggregateExpr for OrderSensitiveArrayAgg { &self.order_by_data_types, self.ordering_req.clone(), self.reverse, + self.nullable, ) .map(|acc| Box::new(acc) as _) } @@ -109,14 +110,18 @@ impl AggregateExpr for OrderSensitiveArrayAgg { fn state_fields(&self) -> Result> { let mut fields = vec![Field::new_list( format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, // This should be the same as field() + Field::new("item", self.input_data_type.clone(), self.nullable), + false, // This should be the same as field() )]; let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); fields.push(Field::new_list( format_state_name(&self.name, "array_agg_orderings"), - Field::new("item", DataType::Struct(Fields::from(orderings)), true), - self.nullable, + Field::new( + "item", + DataType::Struct(Fields::from(orderings)), + self.nullable, + ), + false, )); Ok(fields) } @@ -181,6 +186,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator { ordering_req: LexOrdering, /// Whether the aggregation is running in reverse. reverse: bool, + /// Whether the input expr is nullable + nullable: bool, } impl OrderSensitiveArrayAggAccumulator { @@ -191,6 +198,7 @@ impl OrderSensitiveArrayAggAccumulator { ordering_dtypes: &[DataType], ordering_req: LexOrdering, reverse: bool, + nullable: bool, ) -> Result { let mut datatypes = vec![datatype.clone()]; datatypes.extend(ordering_dtypes.iter().cloned()); @@ -200,6 +208,7 @@ impl OrderSensitiveArrayAggAccumulator { datatypes, ordering_req, reverse, + nullable, }) } } @@ -302,9 +311,17 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { fn evaluate(&mut self) -> Result { let values = self.values.clone(); let array = if self.reverse { - ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0]) + ScalarValue::new_list_from_iter( + values.into_iter().rev(), + &self.datatypes[0], + self.nullable, + ) } else { - ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0]) + ScalarValue::new_list_from_iter( + values.into_iter(), + &self.datatypes[0], + self.nullable, + ) }; Ok(ScalarValue::List(array)) } @@ -362,6 +379,7 @@ impl OrderSensitiveArrayAggAccumulator { )?; Ok(ScalarValue::List(Arc::new(array_into_list_array( Arc::new(ordering_array), + self.nullable, )))) } } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 19b54daa2ef9..169418d2daa0 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -172,7 +172,7 @@ mod tests { Field::new_list( "c1", Field::new("item", data_type.clone(), true), - true, + false, ), result_agg_phy_exprs.field().unwrap() ); @@ -192,7 +192,7 @@ mod tests { Field::new_list( "c1", Field::new("item", data_type.clone(), true), - true, + false, ), result_agg_phy_exprs.field().unwrap() ); @@ -253,20 +253,20 @@ mod tests { #[test] fn test_min_max() -> Result<()> { - let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?; + let observed = AggregateFunction::Min.return_type(&[DataType::Utf8], &[true])?; assert_eq!(DataType::Utf8, observed); - let observed = AggregateFunction::Max.return_type(&[DataType::Int32])?; + let observed = AggregateFunction::Max.return_type(&[DataType::Int32], &[true])?; assert_eq!(DataType::Int32, observed); // test decimal for min - let observed = - AggregateFunction::Min.return_type(&[DataType::Decimal128(10, 6)])?; + let observed = AggregateFunction::Min + .return_type(&[DataType::Decimal128(10, 6)], &[true])?; assert_eq!(DataType::Decimal128(10, 6), observed); // test decimal for max - let observed = - AggregateFunction::Max.return_type(&[DataType::Decimal128(28, 13)])?; + let observed = AggregateFunction::Max + .return_type(&[DataType::Decimal128(28, 13)], &[true])?; assert_eq!(DataType::Decimal128(28, 13), observed); Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/physical-expr/src/aggregate/nth_value.rs index ee7426a897b3..f6d25348f222 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/physical-expr/src/aggregate/nth_value.rs @@ -32,7 +32,7 @@ use crate::{ use arrow_array::cast::AsArray; use arrow_array::{new_empty_array, ArrayRef, StructArray}; use arrow_schema::{DataType, Field, Fields}; -use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; +use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::utils::AggregateOrderSensitivity; use datafusion_expr::Accumulator; @@ -393,7 +393,7 @@ impl NthValueAccumulator { None, )?; - Ok(ScalarValue::List(Arc::new(array_into_list_array( + Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( Arc::new(ordering_array), )))) } @@ -401,7 +401,10 @@ impl NthValueAccumulator { fn evaluate_values(&self) -> ScalarValue { let mut values_cloned = self.values.clone(); let values_slice = values_cloned.make_contiguous(); - ScalarValue::List(ScalarValue::new_list(values_slice, &self.datatypes[0])) + ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatypes[0], + )) } /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index ecfe123a43af..181c30800434 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -65,7 +65,11 @@ pub fn schema_add_window_field( .iter() .map(|e| e.clone().as_ref().data_type(schema)) .collect::>>()?; - let window_expr_return_type = window_fn.return_type(&data_types)?; + let nullability = args + .iter() + .map(|e| e.clone().as_ref().nullable(schema)) + .collect::>>()?; + let window_expr_return_type = window_fn.return_type(&data_types, &nullability)?; let mut window_fields = schema .fields() .iter() diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 486a911f92c4..510ebe9a9801 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -984,7 +984,7 @@ fn round_trip_scalar_values() { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), + ScalarValue::List(ScalarValue::new_list_nullable(&[], &DataType::Boolean)), ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), @@ -1076,7 +1076,7 @@ fn round_trip_scalar_values() { i64::MAX, ))), ScalarValue::IntervalMonthDayNano(None), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), @@ -1096,10 +1096,13 @@ fn round_trip_scalar_values() { ], &DataType::Float32, )), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( &[ - ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( + &[], + &DataType::Float32, + )), + ScalarValue::List(ScalarValue::new_list_nullable( &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a48da7b9a1a8..5d430285538d 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1686,7 +1686,7 @@ fn from_substrait_literal( let element_type = elements[0].data_type(); match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( - ScalarValue::new_list(elements.as_slice(), &element_type), + ScalarValue::new_list_nullable(elements.as_slice(), &element_type), ), LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( ScalarValue::new_large_list(elements.as_slice(), &element_type), @@ -1704,7 +1704,7 @@ fn from_substrait_literal( )?; match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::List(ScalarValue::new_list(&[], &element_type)) + ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) } LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( ScalarValue::new_large_list(&[], &element_type), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index c0469d333164..b59843ac468d 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2231,11 +2231,11 @@ mod test { round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; - round_trip_literal(ScalarValue::List(ScalarValue::new_list( + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( &[ScalarValue::Float32(Some(1.0))], &DataType::Float32, )))?; - round_trip_literal(ScalarValue::List(ScalarValue::new_list( + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( &[], &DataType::Float32, )))?;