Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1839,7 +1839,7 @@ impl ScalarValue {
/// ScalarValue::Int32(Some(2))
/// ];
///
/// let result = ScalarValue::new_list(&scalars, &DataType::Int32);
/// let result = ScalarValue::new_list(&scalars, &DataType::Int32, true);
///
/// let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(
/// vec![
Expand All @@ -1848,13 +1848,25 @@ impl ScalarValue {
///
/// assert_eq!(*result, expected);
/// ```
pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> Arc<ListArray> {
pub fn new_list(
values: &[ScalarValue],
data_type: &DataType,
nullable: bool,
) -> Arc<ListArray> {
let values = if values.is_empty() {
new_empty_array(data_type)
} else {
Self::iter_to_array(values.iter().cloned()).unwrap()
};
Arc::new(array_into_list_array(values))
Arc::new(array_into_list_array(values, nullable))
}

/// Same as [`ScalarValue::new_list`] but with nullable set to true.
pub fn new_list_nullable(
values: &[ScalarValue],
data_type: &DataType,
) -> Arc<ListArray> {
Self::new_list(values, data_type, true)
}

/// Converts `IntoIterator<Item = ScalarValue>` where each element has type corresponding to
Expand All @@ -1873,7 +1885,7 @@ impl ScalarValue {
/// ScalarValue::Int32(Some(2))
/// ];
///
/// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32);
/// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32, true);
///
/// let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(
/// vec![
Expand All @@ -1885,13 +1897,14 @@ impl ScalarValue {
pub fn new_list_from_iter(
values: impl IntoIterator<Item = ScalarValue> + ExactSizeIterator,
data_type: &DataType,
nullable: bool,
) -> Arc<ListArray> {
let values = if values.len() == 0 {
new_empty_array(data_type)
} else {
Self::iter_to_array(values).unwrap()
};
Arc::new(array_into_list_array(values))
Arc::new(array_into_list_array(values, nullable))
}

/// Converts `Vec<ScalarValue>` where each element has type corresponding to
Expand Down Expand Up @@ -2305,7 +2318,7 @@ impl ScalarValue {
/// use datafusion_common::ScalarValue;
/// use arrow::array::ListArray;
/// use arrow::datatypes::{DataType, Int32Type};
/// use datafusion_common::utils::array_into_list_array;
/// use datafusion_common::utils::array_into_list_array_nullable;
/// use std::sync::Arc;
///
/// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
Expand All @@ -2314,7 +2327,7 @@ impl ScalarValue {
/// ]);
///
/// // Wrap into another layer of list, we got nested array as [ [[1,2,3], [4,5]] ]
/// let list_arr = array_into_list_array(Arc::new(list_arr));
/// let list_arr = array_into_list_array_nullable(Arc::new(list_arr));
///
/// // Convert the array into Scalar Values for each row, we got 1D arrays in this example
/// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
Expand Down Expand Up @@ -2400,11 +2413,12 @@ impl ScalarValue {
typed_cast!(array, index, LargeStringArray, LargeUtf8)?
}
DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?,
DataType::List(_) => {
DataType::List(field) => {
let list_array = array.as_list::<i32>();
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(array_into_list_array(nested_array));
let arr =
Arc::new(array_into_list_array(nested_array, field.is_nullable()));

ScalarValue::List(arr)
}
Expand Down Expand Up @@ -3499,6 +3513,7 @@ mod tests {
};

use crate::assert_batches_eq;
use crate::utils::array_into_list_array_nullable;
use arrow::buffer::OffsetBuffer;
use arrow::compute::{is_null, kernels};
use arrow::util::pretty::pretty_format_columns;
Expand Down Expand Up @@ -3646,9 +3661,9 @@ mod tests {
ScalarValue::from("data-fusion"),
];

let result = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8);
let result = ScalarValue::new_list_nullable(scalars.as_slice(), &DataType::Utf8);

let expected = array_into_list_array(Arc::new(StringArray::from(vec![
let expected = array_into_list_array_nullable(Arc::new(StringArray::from(vec![
"rust",
"arrow",
"data-fusion",
Expand Down Expand Up @@ -3860,10 +3875,12 @@ mod tests {

#[test]
fn iter_to_array_string_test() {
let arr1 =
array_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
let arr2 =
array_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"])));
let arr1 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![
"foo", "bar", "baz",
])));
let arr2 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![
"rust", "world",
])));

let scalars = vec![
ScalarValue::List(Arc::new(arr1)),
Expand Down Expand Up @@ -4270,7 +4287,7 @@ mod tests {

#[test]
fn scalar_list_null_to_array() {
let list_array = ScalarValue::new_list(&[], &DataType::UInt64);
let list_array = ScalarValue::new_list_nullable(&[], &DataType::UInt64);

assert_eq!(list_array.len(), 1);
assert_eq!(list_array.values().len(), 0);
Expand All @@ -4291,7 +4308,7 @@ mod tests {
ScalarValue::UInt64(None),
ScalarValue::UInt64(Some(101)),
];
let list_array = ScalarValue::new_list(&values, &DataType::UInt64);
let list_array = ScalarValue::new_list_nullable(&values, &DataType::UInt64);
assert_eq!(list_array.len(), 1);
assert_eq!(list_array.values().len(), 3);

Expand Down Expand Up @@ -5216,13 +5233,13 @@ mod tests {
// Define list-of-structs scalars

let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap();
let nl0 = ScalarValue::List(Arc::new(array_into_list_array(nl0_array)));
let nl0 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl0_array)));

let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap();
let nl1 = ScalarValue::List(Arc::new(array_into_list_array(nl1_array)));
let nl1 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl1_array)));

let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap();
let nl2 = ScalarValue::List(Arc::new(array_into_list_array(nl2_array)));
let nl2 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl2_array)));

// iter_to_array for list-of-struct
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
Expand Down Expand Up @@ -6008,7 +6025,7 @@ mod tests {
#[test]
fn test_build_timestamp_millisecond_list() {
let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)];
let arr = ScalarValue::new_list(
let arr = ScalarValue::new_list_nullable(
&values,
&DataType::Timestamp(TimeUnit::Millisecond, None),
);
Expand All @@ -6019,7 +6036,7 @@ mod tests {
fn test_newlist_timestamp_zone() {
let s: &'static str = "UTC";
let values = vec![ScalarValue::TimestampMillisecond(Some(1), Some(s.into()))];
let arr = ScalarValue::new_list(
let arr = ScalarValue::new_list_nullable(
&values,
&DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())),
);
Expand Down
13 changes: 11 additions & 2 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,10 +351,19 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(

/// Wrap an array into a single element `ListArray`.
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]`
pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
/// The field in the list array is nullable.
pub fn array_into_list_array_nullable(arr: ArrayRef) -> ListArray {
array_into_list_array(arr, true)
}

/// Array Utils

/// Wrap an array into a single element `ListArray`.
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]`
pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray {
let offsets = OffsetBuffer::from_lengths([arr.len()]);
ListArray::new(
Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
Arc::new(Field::new_list_field(arr.data_type().to_owned(), nullable)),
offsets,
arr,
None,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1386,7 +1386,7 @@ async fn unnest_with_redundant_columns() -> Result<()> {
let expected = vec![
"Projection: shapes.shape_id [shape_id:UInt32]",
" Unnest: lists[shape_id2] structs[] [shape_id:UInt32, shape_id2:UInt32;N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]",
" Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} })]",
" TableScan: shapes projection=[shape_id] [shape_id:UInt32]",
];

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
*actual[0].schema(),
Schema::new(vec![Field::new_list(
"ARRAY_AGG(DISTINCT aggregate_test_100.c2)",
Field::new("item", DataType::UInt32, true),
Field::new("item", DataType::UInt32, false),
false
),])
);
Expand Down
19 changes: 17 additions & 2 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ impl AggregateFunction {
/// Returns the datatype of the aggregate function given its argument types
///
/// This is used to get the returned data type for aggregate expr.
pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
pub fn return_type(
&self,
input_expr_types: &[DataType],
input_expr_nullable: &[bool],
) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.

Expand All @@ -113,12 +117,23 @@ impl AggregateFunction {
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
true,
input_expr_nullable[0],
)))),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
}
}

/// Returns if the return type of the aggregate function is nullable given its argument
/// nullability
pub fn nullable(&self) -> Result<bool> {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
AggregateFunction::ArrayAgg => Ok(false),
AggregateFunction::Grouping => Ok(true),
AggregateFunction::NthValue => Ok(true),
}
}
}

impl AggregateFunction {
Expand Down
36 changes: 21 additions & 15 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,10 +708,14 @@ pub enum WindowFunctionDefinition {

impl WindowFunctionDefinition {
/// Returns the datatype of the window function
pub fn return_type(&self, input_expr_types: &[DataType]) -> Result<DataType> {
pub fn return_type(
&self,
input_expr_types: &[DataType],
input_expr_nullable: &[bool],
) -> Result<DataType> {
match self {
WindowFunctionDefinition::AggregateFunction(fun) => {
fun.return_type(input_expr_types)
fun.return_type(input_expr_types, input_expr_nullable)
}
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
fun.return_type(input_expr_types)
Expand Down Expand Up @@ -2180,10 +2184,10 @@ mod test {
#[test]
fn test_first_value_return_type() -> Result<()> {
let fun = find_df_window_func("first_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::UInt64])?;
let observed = fun.return_type(&[DataType::UInt64], &[true])?;
assert_eq!(DataType::UInt64, observed);

Ok(())
Expand All @@ -2192,10 +2196,10 @@ mod test {
#[test]
fn test_last_value_return_type() -> Result<()> {
let fun = find_df_window_func("last_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64])?;
let observed = fun.return_type(&[DataType::Float64], &[true])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2204,10 +2208,10 @@ mod test {
#[test]
fn test_lead_return_type() -> Result<()> {
let fun = find_df_window_func("lead").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64])?;
let observed = fun.return_type(&[DataType::Float64], &[true])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2216,10 +2220,10 @@ mod test {
#[test]
fn test_lag_return_type() -> Result<()> {
let fun = find_df_window_func("lag").unwrap();
let observed = fun.return_type(&[DataType::Utf8])?;
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64])?;
let observed = fun.return_type(&[DataType::Float64], &[true])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2228,10 +2232,12 @@ mod test {
#[test]
fn test_nth_value_return_type() -> Result<()> {
let fun = find_df_window_func("nth_value").unwrap();
let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?;
let observed =
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true])?;
assert_eq!(DataType::Utf8, observed);

let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?;
let observed =
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2240,7 +2246,7 @@ mod test {
#[test]
fn test_percent_rank_return_type() -> Result<()> {
let fun = find_df_window_func("percent_rank").unwrap();
let observed = fun.return_type(&[])?;
let observed = fun.return_type(&[], &[])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2249,7 +2255,7 @@ mod test {
#[test]
fn test_cume_dist_return_type() -> Result<()> {
let fun = find_df_window_func("cume_dist").unwrap();
let observed = fun.return_type(&[])?;
let observed = fun.return_type(&[], &[])?;
assert_eq!(DataType::Float64, observed);

Ok(())
Expand All @@ -2258,7 +2264,7 @@ mod test {
#[test]
fn test_ntile_return_type() -> Result<()> {
let fun = find_df_window_func("ntile").unwrap();
let observed = fun.return_type(&[DataType::Int16])?;
let observed = fun.return_type(&[DataType::Int16], &[true])?;
assert_eq!(DataType::UInt64, observed);

Ok(())
Expand Down
Loading