From 7afeb8b8078018a167e7bce6421024b95517b234 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Wed, 8 Nov 2023 17:04:53 +0800 Subject: [PATCH 01/16] Minor: Improve the document format of JoinHashMap --- .../src/joins/hash_join_utils.rs | 115 ++++++++++-------- 1 file changed, 62 insertions(+), 53 deletions(-) diff --git a/datafusion/physical-plan/src/joins/hash_join_utils.rs b/datafusion/physical-plan/src/joins/hash_join_utils.rs index 3a2a85c72722..3ea0331ab4fe 100644 --- a/datafusion/physical-plan/src/joins/hash_join_utils.rs +++ b/datafusion/physical-plan/src/joins/hash_join_utils.rs @@ -40,59 +40,68 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use hashbrown::raw::RawTable; use hashbrown::HashSet; -// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. -// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, -// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. -// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 -// As the key is a hash value, we need to check possible hash collisions in the probe stage -// During this stage it might be the case that a row is contained the same hashmap value, -// but the values don't match. Those are checked in the [equal_rows] macro -// The indices (values) are stored in a separate chained list stored in the `Vec`. -// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. -// The chain can be followed until the value "0" has been reached, meaning the end of the list. -// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) -// See the example below: -// Insert (1,1) -// map: -// --------- -// | 1 | 2 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 0 | 0 | -// --------------------- -// Insert (2,2) -// map: -// --------- -// | 1 | 2 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 0 | 0 | -// --------------------- -// Insert (1,3) -// map: -// --------- -// | 1 | 4 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 2 | 0 | <--- hash value 1 maps to 4,2 (which means indices values 3,1) -// --------------------- -// Insert (1,4) -// map: -// --------- -// | 1 | 5 | -// | 2 | 3 | -// --------- -// next: -// --------------------- -// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) -// --------------------- -// TODO: speed up collision checks -// https://github.com/apache/arrow-datafusion/issues/50 +/// Maps a `u64` hash value based on the build side ["on" values] to a list of indices with this key's value. +/// +/// By allocating a `HashMap` with capacity for *at least* the number of rows for entries at the build side, +/// we make sure that we don't have to re-hash the hashmap, which needs access to the key (the hash in this case) value. +/// +/// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 +/// As the key is a hash value, we need to check possible hash collisions in the probe stage +/// During this stage it might be the case that a row is contained the same hashmap value, +/// but the values don't match. Those are checked in the [equal_rows] macro +/// The indices (values) are stored in a separate chained list stored in the `Vec`. +/// +/// The first value (+1) is stored in the hashmap, whereas the next value is stored in array at the position value. +/// +/// The chain can be followed until the value "0" has been reached, meaning the end of the list. +/// Also see chapter 5.3 of [Balancing vectorized query execution with bandwidth-optimized storage](https://dare.uva.nl/search?identifier=5ccbb60a-38b8-4eeb-858a-e7735dd37487) +/// +/// # Example +/// +/// ``` text +/// See the example below: +/// Insert (1,1) +/// map: +/// --------- +/// | 1 | 2 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (2,2) +/// map: +/// --------- +/// | 1 | 2 | +/// | 2 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 0 | 0 | +/// --------------------- +/// Insert (1,3) +/// map: +/// --------- +/// | 1 | 4 | +/// | 2 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 0 | <--- hash value 1 maps to 4,2 (which means indices values 3,1) +/// --------------------- +/// Insert (1,4) +/// map: +/// --------- +/// | 1 | 5 | +/// | 2 | 3 | +/// --------- +/// next: +/// --------------------- +/// | 0 | 0 | 0 | 2 | 4 | <--- hash value 1 maps to 5,4,2 (which means indices values 4,3,1) +/// --------------------- +/// ``` +/// +///TODO: [speed up collision checks](https://github.com/apache/arrow-datafusion/issues/50) pub struct JoinHashMap { // Stores hash value to last row index pub map: RawTable<(u64, u64)>, From 9becd0492b20d88d57182474ae498f1a7dbf90e8 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Mon, 20 Nov 2023 17:37:01 +0800 Subject: [PATCH 02/16] list sort --- datafusion/expr/src/built_in_function.rs | 8 ++ datafusion/expr/src/expr_fn.rs | 3 + .../physical-expr/src/array_expressions.rs | 77 ++++++++++++++- datafusion/physical-expr/src/functions.rs | 3 + datafusion/proto/src/generated/pbjson.rs | 99 ++++++++++--------- datafusion/proto/src/generated/prost.rs | 2 + .../proto/src/logical_plan/from_proto.rs | 15 ++- datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 26 +++-- .../source/user-guide/sql/scalar_functions.md | 37 +++++++ 10 files changed, 211 insertions(+), 60 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fc6f9c28e105..69c8c0b8baf8 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -130,6 +130,8 @@ pub enum BuiltinScalarFunction { // array functions /// array_append ArrayAppend, + /// array_sort + ArraySort, /// array_concat ArrayConcat, /// array_has @@ -387,6 +389,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, + BuiltinScalarFunction::ArraySort => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, @@ -531,6 +534,7 @@ impl BuiltinScalarFunction { Ok(data_type) } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayConcat => { let mut expr_type = Null; let mut max_dims = 0; @@ -879,6 +883,9 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArraySort => { + Signature::variadic_any(self.volatility()) + } BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { @@ -1510,6 +1517,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { "array_push_back", "list_push_back", ], + BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"], BuiltinScalarFunction::ArrayConcat => { &["array_concat", "array_cat", "list_concat", "list_cat"] } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 75b762804427..71773c1b3854 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -583,6 +583,8 @@ scalar_expr!( "appends an element to the end of an array." ); +nary_scalar_expr!(ArraySort, array_sort, "returns sorted array."); + scalar_expr!( ArrayPopBack, array_pop_back, @@ -1174,6 +1176,7 @@ mod test { test_scalar_expr!(FromUnixtime, from_unixtime, unixtime); test_scalar_expr!(ArrayAppend, array_append, array, element); + test_scalar_expr!(ArraySort, array_sort, array, element); test_scalar_expr!(ArrayPopFront, array_pop_front, array); test_scalar_expr!(ArrayPopBack, array_pop_back, array); test_unary_scalar_expr!(ArrayDims, array_dims); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index ded606c3b705..11d05be49bc2 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -27,7 +27,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; -use arrow_schema::FieldRef; +use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_list_array, as_string_array, }; @@ -743,7 +743,7 @@ fn general_append_and_prepend( /// # Arguments /// /// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. -/// +/// /// # Examples /// /// gen_range(3) => [0, 1, 2] @@ -821,6 +821,37 @@ pub fn array_append(args: &[ArrayRef]) -> Result { Ok(res) } +/// Array_sort SQL function +pub fn array_sort(args: &[ArrayRef]) -> Result { + + let sort_option = match args.len() { + 1 => None, + 2 => { + let sort = as_boolean_array(&args[1]); + Some(SortOptions { + descending: sort.value(0), + nulls_first: true, + }) + } + 3 => { + let sort = as_boolean_array(&args[1]); + let nulls_first = as_boolean_array(&args[2]); + Some(SortOptions { + descending: sort.value(0), + nulls_first: nulls_first.value(0), + }) + } + _ => return internal_err!("array_sort expects 1 to 3 arguments"), + }; + + let list_array = as_list_array(&args[0])?; + + let sorted_array = + arrow_ord::sort::sort(list_array.values(), sort_option.clone()).unwrap(); + + Ok(Arc::new(array_into_list_array(sorted_array))) +} + /// Array_prepend SQL function pub fn array_prepend(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[1])?; @@ -2697,6 +2728,48 @@ mod tests { ); } + #[test] + fn test_array_sort() { + // array_sort([2, 3, NULL, 4]) = [NULL, 2, 3, 4] + let data = vec![Some(vec![Some(2), Some(3), None, Some(4)])]; + let array = + Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; + + let result = array_sort(&[array.clone()]) + .expect("failed to initialize function array_sort"); + + let result = + as_list_array(&result).expect("failed to initialize function array_sort"); + + assert_eq!( + &[0, 2, 3, 4], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + + // array_sort([2, 3, NULL, 4], true, false) = [4, 3, 2, NULL] + let desc = Arc::new(BooleanArray::from(vec![true])) as ArrayRef; + let null_first = Arc::new(BooleanArray::from(vec![false])) as ArrayRef; + let result = array_sort(&[array.clone(), desc, null_first]) + .expect("failed to initialize function array_sort"); + + let result = as_list_array(&result) + .expect("failed to initialize function array_to_string"); + assert_eq!( + &[4, 3, 2, 0], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + #[test] fn test_array_prepend() { // array_prepend(1, [2, 3, 4]) = [1, 2, 3, 4] diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index b46249d26dde..4f71936b1619 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -329,6 +329,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayAppend => { Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) } + BuiltinScalarFunction::ArraySort => { + Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args)) + } BuiltinScalarFunction::ArrayConcat => { Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args)) } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 3faacca18c60..23b9a3d94025 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -2875,7 +2875,7 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode { if target_batch_size__.is_some() { return Err(serde::de::Error::duplicate_field("targetBatchSize")); } - target_batch_size__ = + target_batch_size__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -3180,7 +3180,7 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { if index__.is_some() { return Err(serde::de::Error::duplicate_field("index")); } - index__ = + index__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5201,7 +5201,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { if custom_table_data__.is_some() { return Err(serde::de::Error::duplicate_field("customTableData")); } - custom_table_data__ = + custom_table_data__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -5379,7 +5379,7 @@ impl<'de> serde::Deserialize<'de> for Decimal { if precision__.is_some() { return Err(serde::de::Error::duplicate_field("precision")); } - precision__ = + precision__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5387,7 +5387,7 @@ impl<'de> serde::Deserialize<'de> for Decimal { if scale__.is_some() { return Err(serde::de::Error::duplicate_field("scale")); } - scale__ = + scale__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5504,7 +5504,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { if value__.is_some() { return Err(serde::de::Error::duplicate_field("value")); } - value__ = + value__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -5512,7 +5512,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { if p__.is_some() { return Err(serde::de::Error::duplicate_field("p")); } - p__ = + p__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5520,7 +5520,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { if s__.is_some() { return Err(serde::de::Error::duplicate_field("s")); } - s__ = + s__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5638,7 +5638,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { if value__.is_some() { return Err(serde::de::Error::duplicate_field("value")); } - value__ = + value__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -5646,7 +5646,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { if p__.is_some() { return Err(serde::de::Error::duplicate_field("p")); } - p__ = + p__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5654,7 +5654,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { if s__.is_some() { return Err(serde::de::Error::duplicate_field("s")); } - s__ = + s__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -7208,7 +7208,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { if start__.is_some() { return Err(serde::de::Error::duplicate_field("start")); } - start__ = + start__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -7216,7 +7216,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { if end__.is_some() { return Err(serde::de::Error::duplicate_field("end")); } - end__ = + end__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -7396,7 +7396,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { if projection__.is_some() { return Err(serde::de::Error::duplicate_field("projection")); } - projection__ = + projection__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -8061,7 +8061,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { if length__.is_some() { return Err(serde::de::Error::duplicate_field("length")); } - length__ = + length__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -8172,7 +8172,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { if list_size__.is_some() { return Err(serde::de::Error::duplicate_field("listSize")); } - list_size__ = + list_size__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -8560,7 +8560,7 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { if skip__.is_some() { return Err(serde::de::Error::duplicate_field("skip")); } - skip__ = + skip__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -8568,7 +8568,7 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -8973,7 +8973,7 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { if partition_count__.is_some() { return Err(serde::de::Error::duplicate_field("partitionCount")); } - partition_count__ = + partition_count__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -9356,7 +9356,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { if months__.is_some() { return Err(serde::de::Error::duplicate_field("months")); } - months__ = + months__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -9364,7 +9364,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { if days__.is_some() { return Err(serde::de::Error::duplicate_field("days")); } - days__ = + days__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -9372,7 +9372,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { if nanos__.is_some() { return Err(serde::de::Error::duplicate_field("nanos")); } - nanos__ = + nanos__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -11452,7 +11452,7 @@ impl<'de> serde::Deserialize<'de> for LimitNode { if skip__.is_some() { return Err(serde::de::Error::duplicate_field("skip")); } - skip__ = + skip__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -11460,7 +11460,7 @@ impl<'de> serde::Deserialize<'de> for LimitNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -12215,7 +12215,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { if target_partitions__.is_some() { return Err(serde::de::Error::duplicate_field("targetPartitions")); } - target_partitions__ = + target_partitions__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -12361,7 +12361,7 @@ impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -13195,7 +13195,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { if node__.is_some() { return Err(serde::de::Error::duplicate_field("node")); } - node__ = + node__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -15137,7 +15137,7 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { if columns__.is_some() { return Err(serde::de::Error::duplicate_field("columns")); } - columns__ = + columns__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -15451,7 +15451,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { if num_rows__.is_some() { return Err(serde::de::Error::duplicate_field("numRows")); } - num_rows__ = + num_rows__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -15459,7 +15459,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { if num_batches__.is_some() { return Err(serde::de::Error::duplicate_field("numBatches")); } - num_batches__ = + num_batches__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -15467,7 +15467,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { if num_bytes__.is_some() { return Err(serde::de::Error::duplicate_field("numBytes")); } - num_bytes__ = + num_bytes__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -15619,7 +15619,7 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { if size__.is_some() { return Err(serde::de::Error::duplicate_field("size")); } - size__ = + size__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -15627,7 +15627,7 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { if last_modified_ns__.is_some() { return Err(serde::de::Error::duplicate_field("lastModifiedNs")); } - last_modified_ns__ = + last_modified_ns__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -16384,7 +16384,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { if index__.is_some() { return Err(serde::de::Error::duplicate_field("index")); } - index__ = + index__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -16944,7 +16944,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { if node__.is_some() { return Err(serde::de::Error::duplicate_field("node")); } - node__ = + node__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -17205,7 +17205,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { if partition_count__.is_some() { return Err(serde::de::Error::duplicate_field("partitionCount")); } - partition_count__ = + partition_count__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -19866,7 +19866,7 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { if indices__.is_some() { return Err(serde::de::Error::duplicate_field("indices")); } - indices__ = + indices__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -20784,7 +20784,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { if values__.is_some() { return Err(serde::de::Error::duplicate_field("values")); } - values__ = + values__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -20792,7 +20792,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { if length__.is_some() { return Err(serde::de::Error::duplicate_field("length")); } - length__ = + length__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -20901,6 +20901,7 @@ impl serde::Serialize for ScalarFunction { Self::Lcm => "Lcm", Self::Gcd => "Gcd", Self::ArrayAppend => "ArrayAppend", + Self::ArraySort => "ArraySort", Self::ArrayConcat => "ArrayConcat", Self::ArrayDims => "ArrayDims", Self::ArrayRepeat => "ArrayRepeat", @@ -21037,6 +21038,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Lcm", "Gcd", "ArrayAppend", + "ArraySort", "ArrayConcat", "ArrayDims", "ArrayRepeat", @@ -21202,6 +21204,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Lcm" => Ok(ScalarFunction::Lcm), "Gcd" => Ok(ScalarFunction::Gcd), "ArrayAppend" => Ok(ScalarFunction::ArrayAppend), + "ArraySort" => Ok(ScalarFunction::ArraySort), "ArrayConcat" => Ok(ScalarFunction::ArrayConcat), "ArrayDims" => Ok(ScalarFunction::ArrayDims), "ArrayRepeat" => Ok(ScalarFunction::ArrayRepeat), @@ -21460,7 +21463,7 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { if ipc_message__.is_some() { return Err(serde::de::Error::duplicate_field("ipcMessage")); } - ipc_message__ = + ipc_message__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -21468,7 +21471,7 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { if arrow_data__.is_some() { return Err(serde::de::Error::duplicate_field("arrowData")); } - arrow_data__ = + arrow_data__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -22600,7 +22603,7 @@ impl<'de> serde::Deserialize<'de> for ScanLimit { if limit__.is_some() { return Err(serde::de::Error::duplicate_field("limit")); } - limit__ = + limit__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -23189,7 +23192,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -23450,7 +23453,7 @@ impl<'de> serde::Deserialize<'de> for SortNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -23578,7 +23581,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -24553,7 +24556,7 @@ impl<'de> serde::Deserialize<'de> for Union { if type_ids__.is_some() { return Err(serde::de::Error::duplicate_field("typeIds")); } - type_ids__ = + type_ids__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -24902,7 +24905,7 @@ impl<'de> serde::Deserialize<'de> for UniqueConstraint { if indices__.is_some() { return Err(serde::de::Error::duplicate_field("indices")); } - indices__ = + indices__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -25009,7 +25012,7 @@ impl<'de> serde::Deserialize<'de> for ValuesNode { if n_cols__.is_some() { return Err(serde::de::Error::duplicate_field("nCols")); } - n_cols__ = + n_cols__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2555a31f6fe2..05a02bac770d 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2571,6 +2571,7 @@ pub enum ScalarFunction { Range = 122, ArrayPopFront = 123, Levenshtein = 124, + ArraySort = 125, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2796,6 +2797,7 @@ impl ScalarFunction { "Lcm" => Some(Self::Lcm), "Gcd" => Some(Self::Gcd), "ArrayAppend" => Some(Self::ArrayAppend), + "ArraySort" => Some(Self::ArraySort), "ArrayConcat" => Some(Self::ArrayConcat), "ArrayDims" => Some(Self::ArrayDims), "ArrayRepeat" => Some(Self::ArrayRepeat), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index f14da70485ab..c7e9af9f7356 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -44,10 +44,10 @@ use datafusion_expr::{ array_has, array_has_all, array_has_any, array_intersect, array_length, array_ndims, array_position, array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all, array_replace_n, - array_slice, array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, - bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, - concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, - date_part, date_trunc, decode, degrees, digest, encode, exp, + array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin, asinh, atan, + atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, + coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, + date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -463,6 +463,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::ToTimestamp => Self::ToTimestamp, ScalarFunction::ArrayAppend => Self::ArrayAppend, + ScalarFunction::ArraySort => Self::ArraySort, ScalarFunction::ArrayConcat => Self::ArrayConcat, ScalarFunction::ArrayEmpty => Self::ArrayEmpty, ScalarFunction::ArrayHasAll => Self::ArrayHasAll, @@ -1332,6 +1333,12 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::ArraySort => Ok(array_sort( + args.to_owned() + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, _>>()?, + )), ScalarFunction::ArrayPopFront => { Ok(array_pop_front(parse_expr(&args[0], registry)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index de81a1f4caef..15e174487f79 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1470,6 +1470,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, + BuiltinScalarFunction::ArraySort => Self::ArraySort, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty, BuiltinScalarFunction::ArrayHasAll => Self::ArrayHasAll, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 99ed94883629..a6c39e38a6d1 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1044,6 +1044,20 @@ select make_array(['a','b'], null); ---- [[a, b], ] +## array_sort (aliases: `list_sort`) +query ??? +select array_sort(make_array(4, 2, 3, 1)), array_sort(make_array(1.0, 4.0, null, 3.0), true, true), array_sort(make_array('a', 'd', 'c', null, 'b'), false, false); +---- +[1, 2, 3, 4] [, 4.0, 3.0, 1.0] [a, b, c, d, ] + + +## list_sort (aliases: `array_sort`) +query ??? +select list_sort(make_array(4, 2, 3, 1)), list_sort(make_array(1.0, 4.0, null, 3.0), true, true), list_sort(make_array('a', 'd', 'c', null, 'b'), false, false); +---- +[1, 2, 3, 4] [, 4.0, 3.0, 1.0] [a, b, c, d, ] + + ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) # TODO: array_append with NULLs @@ -1216,7 +1230,7 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma # array_repeat scalar function #1 query ???????? -select +select array_repeat(1, 5), array_repeat(3.14, 3), array_repeat('l', 4), @@ -1249,7 +1263,7 @@ AS VALUES (0, 3, 3.3, 'datafusion', make_array(8, 9)); query ?????? -select +select array_repeat(column2, column1), array_repeat(column3, column1), array_repeat(column4, column1), @@ -1264,7 +1278,7 @@ from array_repeat_table; [] [] [] [] [3, 3, 3] [] statement ok -drop table array_repeat_table; +drop table array_repeat_table; ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) @@ -2180,7 +2194,7 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0, [1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o] query ??? -select +select array_remove(make_array(1, null, 2, 3), 2), array_remove(make_array(1.1, null, 2.2, 3.3), 1.1), array_remove(make_array('a', null, 'bc'), 'a'); @@ -2679,7 +2693,7 @@ from array_intersect_table_3D; query ?????? SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), array_intersect(make_array(1,3,5), make_array(2,4,6)), - array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), array_intersect(make_array(true, false), make_array(true)), array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) @@ -2690,7 +2704,7 @@ SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), query ?????? SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), list_intersect(make_array(1,3,5), make_array(2,4,6)), - list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), list_intersect(make_array(true, false), make_array(true)), list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index f9f45a1b0a97..5b215105b0e6 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1519,6 +1519,7 @@ from_unixtime(expression) ## Array Functions - [array_append](#array_append) +- [array_sort](#array_sort) - [array_cat](#array_cat) - [array_concat](#array_concat) - [array_contains](#array_contains) @@ -1548,6 +1549,7 @@ from_unixtime(expression) - [cardinality](#cardinality) - [empty](#empty) - [list_append](#list_append) +- [list_sort](#list_sort) - [list_cat](#list_cat) - [list_concat](#list_concat) - [list_dims](#list_dims) @@ -1609,6 +1611,37 @@ array_append(array, element) - list_append - list_push_back +### `array_sort` +Sort array. + +``` +array_sort(array) +array_sort(array, false) +array_sort(array, true, false) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **desc**: Whether to sort in descending order. +- **nulls_first**: Whether to sort nulls first. + +#### Example + +``` +❯ select array_sort([3, 1, null, 2], true, false); ++--------------------------------------+ +| array_sort(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [3, 2, 1, ] | ++--------------------------------------+ +``` + +#### Aliases + +- list_sort + ### `array_cat` _Alias of [array_concat](#array_concat)._ @@ -2359,6 +2392,10 @@ empty(array) _Alias of [array_append](#array_append)._ +### `list_sort` + +_Alias of [array_sort](#array_sort)._ + ### `list_cat` _Alias of [array_concat](#array_concat)._ From 32b406e6c81b17cf075a3a8a434f915d51ca5c2e Mon Sep 17 00:00:00 2001 From: asura7969 <1402357969@qq.com> Date: Mon, 20 Nov 2023 21:13:09 +0800 Subject: [PATCH 03/16] fix: example doc --- datafusion/proto/src/generated/prost.rs | 1 + docs/source/user-guide/sql/scalar_functions.md | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 3560e25b4d0f..d8bc486d0fa7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2668,6 +2668,7 @@ impl ScalarFunction { ScalarFunction::Lcm => "Lcm", ScalarFunction::Gcd => "Gcd", ScalarFunction::ArrayAppend => "ArrayAppend", + ScalarFunction::ArraySort => "ArraySort", ScalarFunction::ArrayConcat => "ArrayConcat", ScalarFunction::ArrayDims => "ArrayDims", ScalarFunction::ArrayRepeat => "ArrayRepeat", diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 67b127e9e82e..513d337ae8df 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1630,12 +1630,12 @@ array_sort(array, true, false) #### Example ``` -❯ select array_sort([3, 1, null, 2], true, false); -+--------------------------------------+ -| array_sort(List([1,2,3]),Int64(4)) | -+--------------------------------------+ -| [3, 2, 1, ] | -+--------------------------------------+ +❯ select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [3, 2, 1] | ++-----------------------------+ ``` #### Aliases From e87930271751975c3bc24e301568e99f6d457e6c Mon Sep 17 00:00:00 2001 From: asura7969 <1402357969@qq.com> Date: Mon, 20 Nov 2023 22:49:58 +0800 Subject: [PATCH 04/16] fix: ci --- datafusion/expr/src/expr_fn.rs | 4 ++-- .../physical-expr/src/array_expressions.rs | 1 - datafusion/proto/src/logical_plan/from_proto.rs | 16 ++++++++-------- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 5e685f003294..89dd038b24ee 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -583,7 +583,7 @@ scalar_expr!( "appends an element to the end of an array." ); -nary_scalar_expr!(ArraySort, array_sort, "returns sorted array."); +scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array."); scalar_expr!( ArrayPopBack, @@ -1182,7 +1182,7 @@ mod test { test_scalar_expr!(FromUnixtime, from_unixtime, unixtime); test_scalar_expr!(ArrayAppend, array_append, array, element); - test_scalar_expr!(ArraySort, array_sort, array, element); + test_scalar_expr!(ArraySort, array_sort, array, desc, null_first); test_scalar_expr!(ArrayPopFront, array_pop_front, array); test_scalar_expr!(ArrayPopBack, array_pop_back, array); test_unary_scalar_expr!(ArrayDims, array_dims); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index f88cbd9037ce..ba78a81ce92b 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -795,7 +795,6 @@ pub fn array_append(args: &[ArrayRef]) -> Result { /// Array_sort SQL function pub fn array_sort(args: &[ArrayRef]) -> Result { - let sort_option = match args.len() { 1 => None, 2 => { diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2509ca6446d1..9c017b1687aa 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -44,10 +44,11 @@ use datafusion_expr::{ array_except, array_has, array_has_all, array_has_any, array_intersect, array_length, array_ndims, array_position, array_positions, array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all, - array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin, asinh, - atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, - chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, - current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp, + array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin, + asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, + character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, + current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, + encode, exp, expr::{self, InList, Sort, WindowFunction}, factorial, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -1340,10 +1341,9 @@ pub fn parse_expr( parse_expr(&args[1], registry)?, )), ScalarFunction::ArraySort => Ok(array_sort( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry)) - .collect::, _>>()?, + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + parse_expr(&args[2], registry)?, )), ScalarFunction::ArrayPopFront => { Ok(array_pop_front(parse_expr(&args[0], registry)?)) From 116134a779ffcfe556302944c009731f92d0a6d0 Mon Sep 17 00:00:00 2001 From: asura7969 <1402357969@qq.com> Date: Mon, 20 Nov 2023 23:24:02 +0800 Subject: [PATCH 05/16] fix: doc error --- datafusion/physical-expr/src/array_expressions.rs | 3 +-- docs/source/user-guide/sql/scalar_functions.md | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index ba78a81ce92b..f9feb1b2a903 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -817,8 +817,7 @@ pub fn array_sort(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; - let sorted_array = - arrow_ord::sort::sort(list_array.values(), sort_option.clone()).unwrap(); + let sorted_array = arrow_ord::sort::sort(list_array.values(), sort_option).unwrap(); Ok(Arc::new(array_into_list_array(sorted_array))) } diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 513d337ae8df..696a30ff2fb6 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1612,6 +1612,7 @@ array_append(array, element) - list_push_back ### `array_sort` + Sort array. ``` From 52f0afeab96f5f425930e06040a3c774c366e2e6 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Tue, 21 Nov 2023 08:45:48 +0800 Subject: [PATCH 06/16] fix pb --- datafusion/proto/src/generated/pbjson.rs | 3 --- datafusion/proto/src/generated/prost.rs | 3 --- 2 files changed, 6 deletions(-) diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 1ad2de2b9e7a..dc7d1fb16147 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20824,7 +20824,6 @@ impl serde::Serialize for ScalarFunction { Self::Lcm => "Lcm", Self::Gcd => "Gcd", Self::ArrayAppend => "ArrayAppend", - Self::ArraySort => "ArraySort", Self::ArrayConcat => "ArrayConcat", Self::ArrayDims => "ArrayDims", Self::ArrayRepeat => "ArrayRepeat", @@ -20962,7 +20961,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Lcm", "Gcd", "ArrayAppend", - "ArraySort", "ArrayConcat", "ArrayDims", "ArrayRepeat", @@ -21129,7 +21127,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Lcm" => Ok(ScalarFunction::Lcm), "Gcd" => Ok(ScalarFunction::Gcd), "ArrayAppend" => Ok(ScalarFunction::ArrayAppend), - "ArraySort" => Ok(ScalarFunction::ArraySort), "ArrayConcat" => Ok(ScalarFunction::ArrayConcat), "ArrayDims" => Ok(ScalarFunction::ArrayDims), "ArrayRepeat" => Ok(ScalarFunction::ArrayRepeat), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e476ff7dedf2..4fb8e1599e4b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2592,7 +2592,6 @@ pub enum ScalarFunction { ArrayExcept = 123, ArrayPopFront = 124, Levenshtein = 125, - ArraySort = 126, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2688,7 +2687,6 @@ impl ScalarFunction { ScalarFunction::Lcm => "Lcm", ScalarFunction::Gcd => "Gcd", ScalarFunction::ArrayAppend => "ArrayAppend", - ScalarFunction::ArraySort => "ArraySort", ScalarFunction::ArrayConcat => "ArrayConcat", ScalarFunction::ArrayDims => "ArrayDims", ScalarFunction::ArrayRepeat => "ArrayRepeat", @@ -2820,7 +2818,6 @@ impl ScalarFunction { "Lcm" => Some(Self::Lcm), "Gcd" => Some(Self::Gcd), "ArrayAppend" => Some(Self::ArrayAppend), - "ArraySort" => Some(Self::ArraySort), "ArrayConcat" => Some(Self::ArrayConcat), "ArrayDims" => Some(Self::ArrayDims), "ArrayRepeat" => Some(Self::ArrayRepeat), From 15b41b13ef8a55bc3223749ec6a338aba778673f Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Tue, 21 Nov 2023 10:21:22 +0800 Subject: [PATCH 07/16] like DuckDB function semantics --- .../physical-expr/src/array_expressions.rs | 72 +++++++------------ datafusion/proto/src/generated/prost.rs | 2 + datafusion/sqllogictest/test_files/array.slt | 13 ++-- 3 files changed, 35 insertions(+), 52 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index e1539f150f90..921b46a7bb49 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -798,18 +798,18 @@ pub fn array_sort(args: &[ArrayRef]) -> Result { let sort_option = match args.len() { 1 => None, 2 => { - let sort = as_boolean_array(&args[1]); + let sort = as_string_array(&args[1])?.value(0); Some(SortOptions { - descending: sort.value(0), + descending: order_desc(sort)?, nulls_first: true, }) } 3 => { - let sort = as_boolean_array(&args[1]); - let nulls_first = as_boolean_array(&args[2]); + let sort = as_string_array(&args[1])?.value(0); + let nulls_first = as_string_array(&args[2])?.value(0); Some(SortOptions { - descending: sort.value(0), - nulls_first: nulls_first.value(0), + descending: order_desc(sort)?, + nulls_first: order_nulls_first(nulls_first)?, }) } _ => return internal_err!("array_sort expects 1 to 3 arguments"), @@ -822,6 +822,24 @@ pub fn array_sort(args: &[ArrayRef]) -> Result { Ok(Arc::new(array_into_list_array(sorted_array))) } +fn order_desc(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "DESC" => Ok(true), + "ASC" => Ok(false), + _ => internal_err!("the second parameter of array_sort expects DESC or ASC"), + } +} + +fn order_nulls_first(modifier: &str) -> Result { + match modifier.to_uppercase().as_str() { + "NULLS FIRST" => Ok(true), + "NULLS LAST" => Ok(false), + _ => internal_err!( + "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" + ), + } +} + /// Array_prepend SQL function pub fn array_prepend(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[1])?; @@ -2619,48 +2637,6 @@ mod tests { ); } - #[test] - fn test_array_sort() { - // array_sort([2, 3, NULL, 4]) = [NULL, 2, 3, 4] - let data = vec![Some(vec![Some(2), Some(3), None, Some(4)])]; - let array = - Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; - - let result = array_sort(&[array.clone()]) - .expect("failed to initialize function array_sort"); - - let result = - as_list_array(&result).expect("failed to initialize function array_sort"); - - assert_eq!( - &[0, 2, 3, 4], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - - // array_sort([2, 3, NULL, 4], true, false) = [4, 3, 2, NULL] - let desc = Arc::new(BooleanArray::from(vec![true])) as ArrayRef; - let null_first = Arc::new(BooleanArray::from(vec![false])) as ArrayRef; - let result = array_sort(&[array.clone(), desc, null_first]) - .expect("failed to initialize function array_sort"); - - let result = as_list_array(&result) - .expect("failed to initialize function array_to_string"); - assert_eq!( - &[4, 3, 2, 0], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - #[test] fn test_array_prepend() { // array_prepend(1, [2, 3, 4]) = [1, 2, 3, 4] diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 4fb8e1599e4b..32a36bd2cdb5 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2592,6 +2592,7 @@ pub enum ScalarFunction { ArrayExcept = 123, ArrayPopFront = 124, Levenshtein = 125, + ArraySort = 126, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2687,6 +2688,7 @@ impl ScalarFunction { ScalarFunction::Lcm => "Lcm", ScalarFunction::Gcd => "Gcd", ScalarFunction::ArrayAppend => "ArrayAppend", + ScalarFunction::ArraySort => "ArraySort", ScalarFunction::ArrayConcat => "ArrayConcat", ScalarFunction::ArrayDims => "ArrayDims", ScalarFunction::ArrayRepeat => "ArrayRepeat", diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 0d6c0aa2d8f4..37a41cdb56fb 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1046,16 +1046,21 @@ select make_array(['a','b'], null); ## array_sort (aliases: `list_sort`) query ??? -select array_sort(make_array(4, 2, 3, 1)), array_sort(make_array(1.0, 4.0, null, 3.0), true, true), array_sort(make_array('a', 'd', 'c', null, 'b'), false, false); +select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, 3, null, 2), 'ASC'), array_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); ---- -[1, 2, 3, 4] [, 4.0, 3.0, 1.0] [a, b, c, d, ] +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] + +query ? +select array_sort(make_array(a,d,e), 'DESC', 'NULLS LAST') from values; +---- +[sit, ipsum, dolor, consectetur, amet, adipiscing, Lorem, 8.8, 8, 7.7, 7, 6.6, 6, 5.5, 5, 4.4, 4, 3.3, 3, 2.2, 2, 1.1, 1, ,, , , ] ## list_sort (aliases: `array_sort`) query ??? -select list_sort(make_array(4, 2, 3, 1)), list_sort(make_array(1.0, 4.0, null, 3.0), true, true), list_sort(make_array('a', 'd', 'c', null, 'b'), false, false); +select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3, null, 2), 'ASC'), list_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST'); ---- -[1, 2, 3, 4] [, 4.0, 3.0, 1.0] [a, b, c, d, ] +[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] ## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`) From b7ad7f0349c54f24b611de2ebca697325222d376 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Tue, 21 Nov 2023 10:48:46 +0800 Subject: [PATCH 08/16] fix ci --- datafusion/proto/src/generated/pbjson.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index dc7d1fb16147..a539b8c34328 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20863,6 +20863,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayExcept => "ArrayExcept", Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", + Self::ArraySort => "ArraySort", }; serializer.serialize_str(variant) } From 504655da956e78358b11d15dab77f6454ed1f905 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Tue, 21 Nov 2023 10:55:23 +0800 Subject: [PATCH 09/16] fix pb --- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 98 ++++++++++++------------ datafusion/proto/src/generated/prost.rs | 3 +- 3 files changed, 53 insertions(+), 49 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9197343d749e..3e867c4a87f0 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -641,6 +641,7 @@ enum ScalarFunction { ArrayExcept = 123; ArrayPopFront = 124; Levenshtein = 125; + ArraySort = 126; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index a539b8c34328..5d5ec0fa5556 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -2878,7 +2878,7 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode { if target_batch_size__.is_some() { return Err(serde::de::Error::duplicate_field("targetBatchSize")); } - target_batch_size__ = + target_batch_size__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -3183,7 +3183,7 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { if index__.is_some() { return Err(serde::de::Error::duplicate_field("index")); } - index__ = + index__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5204,7 +5204,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { if custom_table_data__.is_some() { return Err(serde::de::Error::duplicate_field("customTableData")); } - custom_table_data__ = + custom_table_data__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -5382,7 +5382,7 @@ impl<'de> serde::Deserialize<'de> for Decimal { if precision__.is_some() { return Err(serde::de::Error::duplicate_field("precision")); } - precision__ = + precision__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5390,7 +5390,7 @@ impl<'de> serde::Deserialize<'de> for Decimal { if scale__.is_some() { return Err(serde::de::Error::duplicate_field("scale")); } - scale__ = + scale__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5507,7 +5507,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { if value__.is_some() { return Err(serde::de::Error::duplicate_field("value")); } - value__ = + value__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -5515,7 +5515,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { if p__.is_some() { return Err(serde::de::Error::duplicate_field("p")); } - p__ = + p__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5523,7 +5523,7 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { if s__.is_some() { return Err(serde::de::Error::duplicate_field("s")); } - s__ = + s__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5641,7 +5641,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { if value__.is_some() { return Err(serde::de::Error::duplicate_field("value")); } - value__ = + value__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -5649,7 +5649,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { if p__.is_some() { return Err(serde::de::Error::duplicate_field("p")); } - p__ = + p__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -5657,7 +5657,7 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { if s__.is_some() { return Err(serde::de::Error::duplicate_field("s")); } - s__ = + s__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -7211,7 +7211,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { if start__.is_some() { return Err(serde::de::Error::duplicate_field("start")); } - start__ = + start__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -7219,7 +7219,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { if end__.is_some() { return Err(serde::de::Error::duplicate_field("end")); } - end__ = + end__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -7399,7 +7399,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { if projection__.is_some() { return Err(serde::de::Error::duplicate_field("projection")); } - projection__ = + projection__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -7970,7 +7970,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { if length__.is_some() { return Err(serde::de::Error::duplicate_field("length")); } - length__ = + length__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -8081,7 +8081,7 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { if list_size__.is_some() { return Err(serde::de::Error::duplicate_field("listSize")); } - list_size__ = + list_size__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -8469,7 +8469,7 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { if skip__.is_some() { return Err(serde::de::Error::duplicate_field("skip")); } - skip__ = + skip__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -8477,7 +8477,7 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -8882,7 +8882,7 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { if partition_count__.is_some() { return Err(serde::de::Error::duplicate_field("partitionCount")); } - partition_count__ = + partition_count__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -9265,7 +9265,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { if months__.is_some() { return Err(serde::de::Error::duplicate_field("months")); } - months__ = + months__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -9273,7 +9273,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { if days__.is_some() { return Err(serde::de::Error::duplicate_field("days")); } - days__ = + days__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -9281,7 +9281,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { if nanos__.is_some() { return Err(serde::de::Error::duplicate_field("nanos")); } - nanos__ = + nanos__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -11361,7 +11361,7 @@ impl<'de> serde::Deserialize<'de> for LimitNode { if skip__.is_some() { return Err(serde::de::Error::duplicate_field("skip")); } - skip__ = + skip__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -11369,7 +11369,7 @@ impl<'de> serde::Deserialize<'de> for LimitNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -12124,7 +12124,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { if target_partitions__.is_some() { return Err(serde::de::Error::duplicate_field("targetPartitions")); } - target_partitions__ = + target_partitions__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -12270,7 +12270,7 @@ impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -13104,7 +13104,7 @@ impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { if node__.is_some() { return Err(serde::de::Error::duplicate_field("node")); } - node__ = + node__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -15046,7 +15046,7 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedPartitionSearchMode { if columns__.is_some() { return Err(serde::de::Error::duplicate_field("columns")); } - columns__ = + columns__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -15360,7 +15360,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { if num_rows__.is_some() { return Err(serde::de::Error::duplicate_field("numRows")); } - num_rows__ = + num_rows__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -15368,7 +15368,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { if num_batches__.is_some() { return Err(serde::de::Error::duplicate_field("numBatches")); } - num_batches__ = + num_batches__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -15376,7 +15376,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { if num_bytes__.is_some() { return Err(serde::de::Error::duplicate_field("numBytes")); } - num_bytes__ = + num_bytes__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -15528,7 +15528,7 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { if size__.is_some() { return Err(serde::de::Error::duplicate_field("size")); } - size__ = + size__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -15536,7 +15536,7 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { if last_modified_ns__.is_some() { return Err(serde::de::Error::duplicate_field("lastModifiedNs")); } - last_modified_ns__ = + last_modified_ns__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -16293,7 +16293,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { if index__.is_some() { return Err(serde::de::Error::duplicate_field("index")); } - index__ = + index__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -16853,7 +16853,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { if node__.is_some() { return Err(serde::de::Error::duplicate_field("node")); } - node__ = + node__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -17114,7 +17114,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { if partition_count__.is_some() { return Err(serde::de::Error::duplicate_field("partitionCount")); } - partition_count__ = + partition_count__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -19789,7 +19789,7 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { if indices__.is_some() { return Err(serde::de::Error::duplicate_field("indices")); } - indices__ = + indices__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -20707,7 +20707,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { if values__.is_some() { return Err(serde::de::Error::duplicate_field("values")); } - values__ = + values__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -20715,7 +20715,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { if length__.is_some() { return Err(serde::de::Error::duplicate_field("length")); } - length__ = + length__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -21001,6 +21001,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept", "ArrayPopFront", "Levenshtein", + "ArraySort", ]; struct GeneratedVisitor; @@ -21167,6 +21168,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), + "ArraySort" => Ok(ScalarFunction::ArraySort), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } @@ -21387,7 +21389,7 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { if ipc_message__.is_some() { return Err(serde::de::Error::duplicate_field("ipcMessage")); } - ipc_message__ = + ipc_message__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -21395,7 +21397,7 @@ impl<'de> serde::Deserialize<'de> for ScalarListValue { if arrow_data__.is_some() { return Err(serde::de::Error::duplicate_field("arrowData")); } - arrow_data__ = + arrow_data__ = Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } @@ -22541,7 +22543,7 @@ impl<'de> serde::Deserialize<'de> for ScanLimit { if limit__.is_some() { return Err(serde::de::Error::duplicate_field("limit")); } - limit__ = + limit__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -23130,7 +23132,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -23391,7 +23393,7 @@ impl<'de> serde::Deserialize<'de> for SortNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -23519,7 +23521,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { if fetch__.is_some() { return Err(serde::de::Error::duplicate_field("fetch")); } - fetch__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } @@ -24765,7 +24767,7 @@ impl<'de> serde::Deserialize<'de> for Union { if type_ids__.is_some() { return Err(serde::de::Error::duplicate_field("typeIds")); } - type_ids__ = + type_ids__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -25114,7 +25116,7 @@ impl<'de> serde::Deserialize<'de> for UniqueConstraint { if indices__.is_some() { return Err(serde::de::Error::duplicate_field("indices")); } - indices__ = + indices__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; @@ -25221,7 +25223,7 @@ impl<'de> serde::Deserialize<'de> for ValuesNode { if n_cols__.is_some() { return Err(serde::de::Error::duplicate_field("nCols")); } - n_cols__ = + n_cols__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 32a36bd2cdb5..390d7b726e8c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2688,7 +2688,6 @@ impl ScalarFunction { ScalarFunction::Lcm => "Lcm", ScalarFunction::Gcd => "Gcd", ScalarFunction::ArrayAppend => "ArrayAppend", - ScalarFunction::ArraySort => "ArraySort", ScalarFunction::ArrayConcat => "ArrayConcat", ScalarFunction::ArrayDims => "ArrayDims", ScalarFunction::ArrayRepeat => "ArrayRepeat", @@ -2728,6 +2727,7 @@ impl ScalarFunction { ScalarFunction::ArrayExcept => "ArrayExcept", ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", + ScalarFunction::ArraySort => "ArraySort", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2859,6 +2859,7 @@ impl ScalarFunction { "ArrayExcept" => Some(Self::ArrayExcept), "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), + "ArraySort" => Some(Self::ArraySort), _ => None, } } From 85c25d93ced7b5cb27ebbf5221b027a33b30a44a Mon Sep 17 00:00:00 2001 From: asura7969 <1402357969@qq.com> Date: Tue, 21 Nov 2023 21:01:17 +0800 Subject: [PATCH 10/16] fix: doc --- docs/source/user-guide/sql/scalar_functions.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 696a30ff2fb6..ac3c9d2b50a9 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1617,16 +1617,16 @@ Sort array. ``` array_sort(array) -array_sort(array, false) -array_sort(array, true, false) +array_sort(array, 'ASC') +array_sort(array, 'DESC', 'NULLS FIRST') ``` #### Arguments - **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- **desc**: Whether to sort in descending order. -- **nulls_first**: Whether to sort nulls first. +- **desc**: Whether to sort in descending order(`ASC` or `DESC`). +- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). #### Example @@ -1635,7 +1635,7 @@ array_sort(array, true, false) +-----------------------------+ | array_sort(List([3,1,2])) | +-----------------------------+ -| [3, 2, 1] | +| [1, 2, 3] | +-----------------------------+ ``` From a7e51e1da479e41a743504f9444df9d889fc1cff Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Wed, 22 Nov 2023 09:00:15 +0800 Subject: [PATCH 11/16] add table test --- datafusion/sqllogictest/test_files/array.slt | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 37a41cdb56fb..860b7526e3f2 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1051,9 +1051,14 @@ select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, [, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1] query ? -select array_sort(make_array(a,d,e), 'DESC', 'NULLS LAST') from values; +select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; ---- -[sit, ipsum, dolor, consectetur, amet, adipiscing, Lorem, 8.8, 8, 7.7, 7, 6.6, 6, 5.5, 5, 4.4, 4, 3.3, 3, 2.2, 2, 1.1, 1, ,, , , ] +[70, 69, 68, 67, 66, 65, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 23, 22, 21, 20, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, , , , , ] + +query ? +select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values; +---- +[, , , , , 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] ## list_sort (aliases: `array_sort`) From 29f2d6256703e92b3cf67746a20f692845378732 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Thu, 23 Nov 2023 11:09:09 +0800 Subject: [PATCH 12/16] fix: not as expected --- .../physical-expr/src/array_expressions.rs | 15 +++++++++++---- datafusion/sqllogictest/test_files/array.slt | 18 ++++++++++++++++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 921b46a7bb49..c18eb026d6d0 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -32,7 +32,7 @@ use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_list_array, as_string_array, }; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::{array_into_list_array, arrays_into_list_array}; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, @@ -816,10 +816,17 @@ pub fn array_sort(args: &[ArrayRef]) -> Result { }; let list_array = as_list_array(&args[0])?; + let default_empty = ArrayData::new_empty(&list_array.value_type()); + let sorted = list_array + .iter() + .map(|array| { + array.map_or(arrow::array::make_array(default_empty.clone()), |arr_ref| { + arrow_ord::sort::sort(&arr_ref, sort_option).unwrap() + }) + }) + .collect::>(); - let sorted_array = arrow_ord::sort::sort(list_array.values(), sort_option).unwrap(); - - Ok(Arc::new(array_into_list_array(sorted_array))) + Ok(Arc::new(arrays_into_list_array(sorted)?)) } fn order_desc(modifier: &str) -> Result { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 860b7526e3f2..32134e51857e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1053,12 +1053,26 @@ select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, query ? select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; ---- -[70, 69, 68, 67, 66, 65, 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 23, 22, 21, 20, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, , , , , ] +[10, 9, 8, 7, 6, 5, 4, 3, 2, ] +[20, 18, 17, 16, 15, 14, 13, 12, 11, ] +[30, 29, 28, 27, 26, 25, 23, 22, 21, ] +[40, 39, 38, 37, 35, 34, 33, 32, 31, ] +[] +[50, 49, 48, 47, 46, 45, 44, 43, 42, 41] +[60, 59, 58, 57, 56, 55, 54, 52, 51, ] +[70, 69, 68, 67, 66, 65, 64, 63, 62, 61] query ? select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values; ---- -[, , , , , 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21, 22, 23, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +[, 2, 3, 4, 5, 6, 7, 8, 9, 10] +[, 11, 12, 13, 14, 15, 16, 17, 18, 20] +[, 21, 22, 23, 25, 26, 27, 28, 29, 30] +[, 31, 32, 33, 34, 35, 37, 38, 39, 40] +[] +[41, 42, 43, 44, 45, 46, 47, 48, 49, 50] +[, 51, 52, 54, 55, 56, 57, 58, 59, 60] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70] ## list_sort (aliases: `array_sort`) From d369caa708fd59fef1f11699f73ed11a124b5f81 Mon Sep 17 00:00:00 2001 From: asura7969 <1402357969@qq.com> Date: Thu, 23 Nov 2023 23:42:28 +0800 Subject: [PATCH 13/16] fix: return null --- .../physical-expr/src/array_expressions.rs | 45 ++++++++++++++----- datafusion/sqllogictest/test_files/array.slt | 4 +- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index c18eb026d6d0..7d142e2262bb 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -32,7 +32,7 @@ use arrow_schema::{FieldRef, SortOptions}; use datafusion_common::cast::{ as_generic_string_array, as_int64_array, as_list_array, as_string_array, }; -use datafusion_common::utils::{array_into_list_array, arrays_into_list_array}; +use datafusion_common::utils::array_into_list_array; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DataFusionError, Result, @@ -816,17 +816,42 @@ pub fn array_sort(args: &[ArrayRef]) -> Result { }; let list_array = as_list_array(&args[0])?; - let default_empty = ArrayData::new_empty(&list_array.value_type()); - let sorted = list_array + let row_count = list_array.len(); + + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + if list_array.is_null(i) { + array_lengths.push(0); + valid.append(false); + } else { + let arr_ref = list_array.value(i); + let arr_ref = arr_ref.as_ref(); + + let sorted_array = compute::sort(arr_ref, sort_option)?; + array_lengths.push(sorted_array.len()); + arrays.push(sorted_array); + valid.append(true); + } + } + + // Assume all arrays have the same data type + let data_type = list_array.value_type(); + let buffer = valid.finish(); + + let elements = arrays .iter() - .map(|array| { - array.map_or(arrow::array::make_array(default_empty.clone()), |arr_ref| { - arrow_ord::sort::sort(&arr_ref, sort_option).unwrap() - }) - }) - .collect::>(); + .map(|a| a.as_ref()) + .collect::>(); - Ok(Arc::new(arrays_into_list_array(sorted)?)) + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + Ok(Arc::new(list_arr)) } fn order_desc(modifier: &str) -> Result { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 32134e51857e..6ebaa35f7e97 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1057,7 +1057,7 @@ select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values; [20, 18, 17, 16, 15, 14, 13, 12, 11, ] [30, 29, 28, 27, 26, 25, 23, 22, 21, ] [40, 39, 38, 37, 35, 34, 33, 32, 31, ] -[] +NULL [50, 49, 48, 47, 46, 45, 44, 43, 42, 41] [60, 59, 58, 57, 56, 55, 54, 52, 51, ] [70, 69, 68, 67, 66, 65, 64, 63, 62, 61] @@ -1069,7 +1069,7 @@ select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values; [, 11, 12, 13, 14, 15, 16, 17, 18, 20] [, 21, 22, 23, 25, 26, 27, 28, 29, 30] [, 31, 32, 33, 34, 35, 37, 38, 39, 40] -[] +NULL [41, 42, 43, 44, 45, 46, 47, 48, 49, 50] [, 51, 52, 54, 55, 56, 57, 58, 59, 60] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] From dea30c3a5184d8e41fce5a26e9bb828fdf49d07e Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Tue, 28 Nov 2023 15:32:54 +0800 Subject: [PATCH 14/16] resolve conflicts --- datafusion/proto/src/generated/pbjson.rs | 3 +++ datafusion/proto/src/generated/prost.rs | 3 +++ 2 files changed, 6 insertions(+) diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 598719dc8ac6..d614337c9fe3 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20864,6 +20864,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", Self::SubstrIndex => "SubstrIndex", + Self::ArraySort => "ArraySort", }; serializer.serialize_str(variant) } @@ -21002,6 +21003,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopFront", "Levenshtein", "SubstrIndex", + "ArraySort", ]; struct GeneratedVisitor; @@ -21169,6 +21171,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), + "ArraySort" => Ok(ScalarFunction::ArraySort), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e79a17fc5c9c..01790ba60f10 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2595,6 +2595,7 @@ pub enum ScalarFunction { ArrayPopFront = 124, Levenshtein = 125, SubstrIndex = 126, + ArraySort = 127, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2730,6 +2731,7 @@ impl ScalarFunction { ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", ScalarFunction::SubstrIndex => "SubstrIndex", + ScalarFunction::ArraySort => "ArraySort", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2862,6 +2864,7 @@ impl ScalarFunction { "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), "SubstrIndex" => Some(Self::SubstrIndex), + "ArraySort" => Some(Self::ArraySort), _ => None, } } From 052f184b99208e23a184829be92f469394a0fe4c Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Tue, 28 Nov 2023 16:58:53 +0800 Subject: [PATCH 15/16] doc --- docs/source/user-guide/sql/scalar_functions.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index dc666d18a009..896f54e4555e 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1634,9 +1634,7 @@ array_append(array, element) Sort array. ``` -array_sort(array) -array_sort(array, 'ASC') -array_sort(array, 'DESC', 'NULLS FIRST') +array_sort(array, desc, nulls_first) ``` #### Arguments From 1e1cc77d075b3e37e6bee82ad1b480f5d8603382 Mon Sep 17 00:00:00 2001 From: Asura7969 <1402357969@qq.com> Date: Wed, 6 Dec 2023 11:02:27 +0800 Subject: [PATCH 16/16] merge --- datafusion/proto/src/generated/pbjson.rs | 3 +++ datafusion/proto/src/generated/prost.rs | 3 +++ 2 files changed, 6 insertions(+) diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b8c5f6a4aae8..4d2ba26020e7 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20865,6 +20865,7 @@ impl serde::Serialize for ScalarFunction { Self::Levenshtein => "Levenshtein", Self::SubstrIndex => "SubstrIndex", Self::FindInSet => "FindInSet", + Self::ArraySort => "ArraySort", }; serializer.serialize_str(variant) } @@ -21004,6 +21005,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Levenshtein", "SubstrIndex", "FindInSet", + "ArraySort", ]; struct GeneratedVisitor; @@ -21172,6 +21174,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Levenshtein" => Ok(ScalarFunction::Levenshtein), "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), "FindInSet" => Ok(ScalarFunction::FindInSet), + "ArraySort" => Ok(ScalarFunction::ArraySort), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c31bc4ab5948..a987d24fca8d 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2596,6 +2596,7 @@ pub enum ScalarFunction { Levenshtein = 125, SubstrIndex = 126, FindInSet = 127, + ArraySort = 128, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2732,6 +2733,7 @@ impl ScalarFunction { ScalarFunction::Levenshtein => "Levenshtein", ScalarFunction::SubstrIndex => "SubstrIndex", ScalarFunction::FindInSet => "FindInSet", + ScalarFunction::ArraySort => "ArraySort", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2865,6 +2867,7 @@ impl ScalarFunction { "Levenshtein" => Some(Self::Levenshtein), "SubstrIndex" => Some(Self::SubstrIndex), "FindInSet" => Some(Self::FindInSet), + "ArraySort" => Some(Self::ArraySort), _ => None, } }