diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 089bf43ebeed..44f52fd69f43 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -145,9 +145,11 @@ use variable::{decode_binary_view, decode_string_view}; use crate::fixed::{decode_bool, decode_fixed_size_binary, decode_primitive}; use crate::variable::{decode_binary, decode_string}; +use arrow_array::types::{Int16Type, Int32Type, Int64Type}; mod fixed; mod list; +mod run; mod variable; /// Converts [`ArrayRef`] columns into a [row-oriented](self) format. @@ -381,6 +383,8 @@ enum Codec { Struct(RowConverter, OwnedRow), /// A row converter for the child field List(RowConverter), + /// A row converter for the values array of a run-end encoded array + RunEndEncoded(RowConverter), } impl Codec { @@ -400,6 +404,17 @@ impl Codec { }; Ok(Self::Dictionary(converter, owned)) } + DataType::RunEndEncoded(_, values) => { + // Similar to List implementation + let options = SortOptions { + descending: false, + nulls_first: sort_field.options.nulls_first != sort_field.options.descending, + }; + + let field = SortField::new_with_options(values.data_type().clone(), options); + let converter = RowConverter::new(vec![field])?; + Ok(Self::RunEndEncoded(converter)) + } d if !d.is_nested() => Ok(Self::Stateless), DataType::List(f) | DataType::LargeList(f) => { // The encoded contents will be inverted if descending is set to true @@ -460,6 +475,19 @@ impl Codec { let rows = converter.convert_columns(&[values.clone()])?; Ok(Encoder::List(rows)) } + Codec::RunEndEncoded(converter) => { + let values = match array.data_type() { + DataType::RunEndEncoded(r, _) => match r.data_type() { + DataType::Int16 => array.as_run::().values(), + DataType::Int32 => array.as_run::().values(), + DataType::Int64 => array.as_run::().values(), + _ => unreachable!("Unsupported run end index type: {r:?}"), + }, + _ => unreachable!(), + }; + let rows = converter.convert_columns(&[values.clone()])?; + Ok(Encoder::RunEndEncoded(rows)) + } } } @@ -469,6 +497,7 @@ impl Codec { Codec::Dictionary(converter, nulls) => converter.size() + nulls.data.len(), Codec::Struct(converter, nulls) => converter.size() + nulls.data.len(), Codec::List(converter) => converter.size(), + Codec::RunEndEncoded(converter) => converter.size(), } } } @@ -487,6 +516,8 @@ enum Encoder<'a> { Struct(Rows, Row<'a>), /// The row encoding of the child array List(Rows), + /// The row encoding of the values array + RunEndEncoded(Rows), } /// Configure the data type and sort order for a given column @@ -545,6 +576,7 @@ impl RowConverter { Self::supports_datatype(f.data_type()) } DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())), + DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()), _ => false, } } @@ -1331,6 +1363,27 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker { } _ => unreachable!(), }, + Encoder::RunEndEncoded(rows) => match array.data_type() { + DataType::RunEndEncoded(r, _) => match r.data_type() { + DataType::Int16 => run::compute_lengths( + tracker.materialized(), + rows, + array.as_run::(), + ), + DataType::Int32 => run::compute_lengths( + tracker.materialized(), + rows, + array.as_run::(), + ), + DataType::Int64 => run::compute_lengths( + tracker.materialized(), + rows, + array.as_run::(), + ), + _ => unreachable!("Unsupported run end index type: {r:?}"), + }, + _ => unreachable!(), + }, } } @@ -1427,6 +1480,21 @@ fn encode_column( } _ => unreachable!(), }, + Encoder::RunEndEncoded(rows) => match column.data_type() { + DataType::RunEndEncoded(r, _) => match r.data_type() { + DataType::Int16 => { + run::encode(data, offsets, rows, opts, column.as_run::()) + } + DataType::Int32 => { + run::encode(data, offsets, rows, opts, column.as_run::()) + } + DataType::Int64 => { + run::encode(data, offsets, rows, opts, column.as_run::()) + } + _ => unreachable!("Unsupported run end index type: {r:?}"), + }, + _ => unreachable!(), + }, } } @@ -1512,6 +1580,30 @@ unsafe fn decode_column( } _ => unreachable!(), }, + Codec::RunEndEncoded(converter) => match &field.data_type { + DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { + DataType::Int16 => Arc::new(run::decode::( + converter, + rows, + field, + validate_utf8, + )?), + DataType::Int32 => Arc::new(run::decode::( + converter, + rows, + field, + validate_utf8, + )?), + DataType::Int64 => Arc::new(run::decode::( + converter, + rows, + field, + validate_utf8, + )?), + _ => unreachable!(), + }, + _ => unreachable!(), + }, }; Ok(array) } diff --git a/arrow-row/src/run.rs b/arrow-row/src/run.rs new file mode 100644 index 000000000000..15ef59df3904 --- /dev/null +++ b/arrow-row/src/run.rs @@ -0,0 +1,695 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{variable, RowConverter, Rows, SortField}; +use arrow_array::types::RunEndIndexType; +use arrow_array::{PrimitiveArray, RunArray}; +use arrow_buffer::{ArrowNativeType, ScalarBuffer}; +use arrow_schema::{ArrowError, SortOptions}; + +/// Computes the lengths of each row for a RunEndEncodedArray +pub fn compute_lengths( + lengths: &mut [usize], + rows: &Rows, + array: &RunArray, +) { + let run_ends = array.run_ends().values(); + let mut logical_start = 0; + + // Iterate over each run and apply the same length to all logical positions in the run + for (physical_idx, &run_end) in run_ends.iter().enumerate() { + let logical_end = run_end.as_usize(); + let row = rows.row(physical_idx); + let encoded_len = variable::encoded_len(Some(row.data)); + + // Add the same length for all logical positions in this run + for length in &mut lengths[logical_start..logical_end] { + *length += encoded_len; + } + + logical_start = logical_end; + } +} + +/// Encodes the provided `RunEndEncodedArray` to `out` with the provided `SortOptions` +/// +/// `rows` should contain the encoded values +pub fn encode( + data: &mut [u8], + offsets: &mut [usize], + rows: &Rows, + opts: SortOptions, + array: &RunArray, +) { + let run_ends = array.run_ends(); + + let mut logical_idx = 0; + let mut offset_idx = 1; // Skip first offset + + // Iterate over each run + for physical_idx in 0..run_ends.values().len() { + let run_end = run_ends.values()[physical_idx].as_usize(); + + // Process all elements in this run + while logical_idx < run_end && offset_idx < offsets.len() { + let offset = &mut offsets[offset_idx]; + let out = &mut data[*offset..]; + + // Use variable-length encoding to make the data self-describing + let row = rows.row(physical_idx); + let bytes_written = variable::encode_one(out, Some(row.data), opts); + *offset += bytes_written; + + logical_idx += 1; + offset_idx += 1; + } + + // Break if we've processed all offsets + if offset_idx >= offsets.len() { + break; + } + } +} + +/// Decodes a RunEndEncodedArray from `rows` with the provided `options` +/// +/// # Safety +/// +/// `rows` must contain valid data for the provided `converter` +pub unsafe fn decode( + converter: &RowConverter, + rows: &mut [&[u8]], + field: &SortField, + validate_utf8: bool, +) -> Result, ArrowError> { + if rows.is_empty() { + let values = converter.convert_raw(&mut [], validate_utf8)?; + let run_ends_array = PrimitiveArray::::new(ScalarBuffer::from(vec![]), None); + return RunArray::::try_new(&run_ends_array, &values[0]); + } + + // Decode each row's REE data and collect the decoded values + let mut decoded_values = Vec::new(); + let mut run_ends = Vec::new(); + let mut unique_row_indices = Vec::new(); + + // Process each row to extract its REE data (following decode_binary pattern) + let mut decoded_data = Vec::new(); + for (idx, row) in rows.iter_mut().enumerate() { + decoded_data.clear(); + // Extract the decoded value data from this row + let consumed = variable::decode_blocks(row, field.options, |block| { + decoded_data.extend_from_slice(block); + }); + + // Handle bit inversion for descending sort (following decode_binary pattern) + if field.options.descending { + decoded_data.iter_mut().for_each(|b| *b = !*b); + } + + // Update the row to point past the consumed REE data + *row = &row[consumed..]; + + // Check if this decoded value is the same as the previous one to identify runs + let is_new_run = + idx == 0 || decoded_data != decoded_values[*unique_row_indices.last().unwrap()]; + + if is_new_run { + // This is a new unique value - end the previous run if any + if idx > 0 { + run_ends.push(R::Native::usize_as(idx)); + } + unique_row_indices.push(decoded_values.len()); + decoded_values.push(decoded_data.clone()); + } + } + // Add the final run end + run_ends.push(R::Native::usize_as(rows.len())); + + // Convert the unique decoded values using the row converter + let mut unique_rows: Vec<&[u8]> = decoded_values.iter().map(|v| v.as_slice()).collect(); + let values = if unique_rows.is_empty() { + converter.convert_raw(&mut [], validate_utf8)? + } else { + converter.convert_raw(&mut unique_rows, validate_utf8)? + }; + + // Create run ends array + let run_ends_array = PrimitiveArray::::new(ScalarBuffer::from(run_ends), None); + + // Create the RunEndEncodedArray + RunArray::::try_new(&run_ends_array, &values[0]) +} + +#[cfg(test)] +mod tests { + use crate::{RowConverter, SortField}; + use arrow_array::types::Int32Type; + use arrow_array::{Array, Int64Array, RunArray, StringArray}; + use arrow_schema::{DataType, SortOptions}; + use std::sync::Arc; + + #[test] + fn test_run_end_encoded_supports_datatype() { + // Test that the RowConverter correctly supports run-end encoded arrays + assert!(RowConverter::supports_datatype(&DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ))); + } + + #[test] + fn test_run_end_encoded_round_trip_int64s() { + // Test round-trip correctness for RunEndEncodedArray with Int64 values making sure it + // doesn't just work with eg. strings (which are all the other tests). + + let values = Int64Array::from(vec![100, 200, 100, 300]); + let run_ends = vec![2, 3, 5, 6]; + let array: RunArray = + RunArray::try_new(&arrow_array::PrimitiveArray::from(run_ends), &values).unwrap(); + + let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Int64, true)), + ))]) + .unwrap(); + + let rows = converter + .convert_columns(&[Arc::new(array.clone())]) + .unwrap(); + + let arrays = converter.convert_rows(&rows).unwrap(); + let result = arrays[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!(array.run_ends().values(), result.run_ends().values()); + assert_eq!(array.values().as_ref(), result.values().as_ref()); + } + + #[test] + fn test_run_end_encoded_round_trip_strings() { + // Test round-trip correctness for RunEndEncodedArray with strings + + let array: RunArray = vec!["b", "b", "a"].into_iter().collect(); + + let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ))]) + .unwrap(); + + let rows = converter + .convert_columns(&[Arc::new(array.clone())]) + .unwrap(); + + let arrays = converter.convert_rows(&rows).unwrap(); + let result = arrays[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!(array.run_ends().values(), result.run_ends().values()); + assert_eq!(array.values().as_ref(), result.values().as_ref()); + } + + #[test] + fn test_run_end_encoded_round_trip_strings_with_nulls() { + // Test round-trip correctness for RunEndEncodedArray with nulls + + let array: RunArray = vec![Some("b"), Some("b"), None, Some("a")] + .into_iter() + .collect(); + + let converter = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ))]) + .unwrap(); + + let rows = converter + .convert_columns(&[Arc::new(array.clone())]) + .unwrap(); + + let arrays = converter.convert_rows(&rows).unwrap(); + let result = arrays[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!(array.run_ends().values(), result.run_ends().values()); + assert_eq!(array.values().as_ref(), result.values().as_ref()); + } + + #[test] + fn test_run_end_encoded_ascending_descending_round_trip() { + // Test round-trip correctness for ascending vs descending sort options + + let values_asc = + arrow_array::StringArray::from(vec![Some("apple"), Some("banana"), Some("cherry")]); + let run_ends_asc = vec![2, 4, 6]; + let run_array_asc: RunArray = RunArray::try_new( + &arrow_array::PrimitiveArray::from(run_ends_asc), + &values_asc, + ) + .unwrap(); + + // Test ascending order + let converter_asc = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + SortOptions { + descending: false, + nulls_first: true, + }, + )]) + .unwrap(); + + let rows_asc = converter_asc + .convert_columns(&[Arc::new(run_array_asc.clone())]) + .unwrap(); + let arrays_asc = converter_asc.convert_rows(&rows_asc).unwrap(); + let result_asc = arrays_asc[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // Verify round-trip correctness for ascending + assert_eq!(run_array_asc.len(), result_asc.len()); + for i in 0..run_array_asc.len() { + let orig_physical = run_array_asc.get_physical_index(i); + let result_physical = result_asc.get_physical_index(i); + + let orig_values = run_array_asc + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let result_values = result_asc + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + orig_values.value(orig_physical), + result_values.value(result_physical), + "Ascending sort value mismatch at index {}", + i + ); + } + + // Test descending order + let converter_desc = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + SortOptions { + descending: true, + nulls_first: true, + }, + )]) + .unwrap(); + + let rows_desc = converter_desc + .convert_columns(&[Arc::new(run_array_asc.clone())]) + .unwrap(); + let arrays_desc = converter_desc.convert_rows(&rows_desc).unwrap(); + let result_desc = arrays_desc[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // Verify round-trip correctness for descending + assert_eq!(run_array_asc.len(), result_desc.len()); + for i in 0..run_array_asc.len() { + let orig_physical = run_array_asc.get_physical_index(i); + let result_physical = result_desc.get_physical_index(i); + + let orig_values = run_array_asc + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let result_values = result_desc + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + orig_values.value(orig_physical), + result_values.value(result_physical), + "Descending sort value mismatch at index {}", + i + ); + } + } + + #[test] + fn test_run_end_encoded_sort_configurations_basic() { + // Test that different sort configurations work and can round-trip successfully + + let test_array: RunArray = vec!["test"].into_iter().collect(); + + let converter_asc = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + SortOptions { + descending: false, + nulls_first: true, + }, + )]) + .unwrap(); + + let converter_desc = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + SortOptions { + descending: true, + nulls_first: true, + }, + )]) + .unwrap(); + + let rows_test_asc = converter_asc + .convert_columns(&[Arc::new(test_array.clone())]) + .unwrap(); + let rows_test_desc = converter_desc + .convert_columns(&[Arc::new(test_array.clone())]) + .unwrap(); + + // Convert back to verify both configurations work + let result_test_asc = converter_asc.convert_rows(&rows_test_asc).unwrap(); + let result_test_desc = converter_desc.convert_rows(&rows_test_desc).unwrap(); + + // Both should successfully reconstruct the original + assert_eq!(result_test_asc.len(), 1); + assert_eq!(result_test_desc.len(), 1); + } + + #[test] + fn test_run_end_encoded_nulls_first_last_configurations() { + // Test that nulls_first vs nulls_last configurations work + + let simple_array: RunArray = vec!["simple"].into_iter().collect(); + + let converter_nulls_first = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + SortOptions { + descending: false, + nulls_first: true, + }, + )]) + .unwrap(); + + let converter_nulls_last = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + SortOptions { + descending: false, + nulls_first: false, + }, + )]) + .unwrap(); + + // Test that both configurations can handle simple arrays + let rows_nulls_first = converter_nulls_first + .convert_columns(&[Arc::new(simple_array.clone())]) + .unwrap(); + let arrays_nulls_first = converter_nulls_first + .convert_rows(&rows_nulls_first) + .unwrap(); + let result_nulls_first = arrays_nulls_first[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let rows_nulls_last = converter_nulls_last + .convert_columns(&[Arc::new(simple_array.clone())]) + .unwrap(); + let arrays_nulls_last = converter_nulls_last.convert_rows(&rows_nulls_last).unwrap(); + let result_nulls_last = arrays_nulls_last[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + // Both should successfully convert the simple array + assert_eq!(simple_array.len(), result_nulls_first.len()); + assert_eq!(simple_array.len(), result_nulls_last.len()); + } + + #[test] + fn test_run_end_encoded_row_consumption() { + // This test verifies that ALL rows are properly consumed during decoding, + // not just the unique values. We test this by ensuring multi-column conversion + // works correctly - if rows aren't consumed properly, the second column would fail. + + // Create a REE array with multiple runs + let array: RunArray = vec!["a", "a", "b", "b", "b", "c"].into_iter().collect(); + let string_array = StringArray::from(vec!["x", "y", "z", "w", "u", "v"]); + + let multi_converter = RowConverter::new(vec![ + SortField::new(DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + )), + SortField::new(DataType::Utf8), + ]) + .unwrap(); + + let multi_rows = multi_converter + .convert_columns(&[Arc::new(array.clone()), Arc::new(string_array.clone())]) + .unwrap(); + + // Convert back - this will test that all rows are consumed properly + let arrays = multi_converter.convert_rows(&multi_rows).unwrap(); + + // Verify both columns round-trip correctly + let result_ree = arrays[0] + .as_any() + .downcast_ref::>() + .unwrap(); + + let result_string = arrays[1].as_any().downcast_ref::().unwrap(); + + // This should pass - both arrays should be identical to originals + assert_eq!(result_ree.values().as_ref(), array.values().as_ref()); + assert_eq!(result_ree.run_ends().values(), array.run_ends().values()); + assert_eq!(*result_string, string_array); + } + + #[test] + fn test_run_end_encoded_sorting_behavior() { + // Test that the binary row encoding actually produces the correct sort order + + // Create REE arrays with different values to test sorting + let array1: RunArray = vec!["apple", "apple"].into_iter().collect(); + let array2: RunArray = vec!["banana", "banana"].into_iter().collect(); + let array3: RunArray = vec!["cherry", "cherry"].into_iter().collect(); + + // Test ascending sort + let converter_asc = RowConverter::new(vec![SortField::new(DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ))]) + .unwrap(); + + let rows1_asc = converter_asc + .convert_columns(&[Arc::new(array1.clone())]) + .unwrap(); + let rows2_asc = converter_asc + .convert_columns(&[Arc::new(array2.clone())]) + .unwrap(); + let rows3_asc = converter_asc + .convert_columns(&[Arc::new(array3.clone())]) + .unwrap(); + + // For ascending: apple < banana < cherry + // So row bytes should sort: rows1 < rows2 < rows3 + assert!( + rows1_asc.row(0) < rows2_asc.row(0), + "apple should come before banana in ascending order" + ); + assert!( + rows2_asc.row(0) < rows3_asc.row(0), + "banana should come before cherry in ascending order" + ); + assert!( + rows1_asc.row(0) < rows3_asc.row(0), + "apple should come before cherry in ascending order" + ); + + // Test descending sort + let converter_desc = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + arrow_schema::SortOptions { + descending: true, + nulls_first: true, + }, + )]) + .unwrap(); + + let rows1_desc = converter_desc + .convert_columns(&[Arc::new(array1.clone())]) + .unwrap(); + let rows2_desc = converter_desc + .convert_columns(&[Arc::new(array2.clone())]) + .unwrap(); + let rows3_desc = converter_desc + .convert_columns(&[Arc::new(array3.clone())]) + .unwrap(); + + // For descending: cherry > banana > apple + // So row bytes should sort: rows3 < rows2 < rows1 (because byte comparison is ascending) + assert!( + rows3_desc.row(0) < rows2_desc.row(0), + "cherry should come before banana in descending order (byte-wise)" + ); + assert!( + rows2_desc.row(0) < rows1_desc.row(0), + "banana should come before apple in descending order (byte-wise)" + ); + assert!( + rows3_desc.row(0) < rows1_desc.row(0), + "cherry should come before apple in descending order (byte-wise)" + ); + } + + #[test] + fn test_run_end_encoded_null_sorting() { + // Test null handling in sort order + + let array_with_nulls: RunArray = vec![None, None].into_iter().collect(); + let array_with_values: RunArray = vec!["apple", "apple"].into_iter().collect(); + + // Test nulls_first = true + let converter_nulls_first = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + arrow_schema::SortOptions { + descending: false, + nulls_first: true, + }, + )]) + .unwrap(); + + let rows_nulls = converter_nulls_first + .convert_columns(&[Arc::new(array_with_nulls.clone())]) + .unwrap(); + let rows_values = converter_nulls_first + .convert_columns(&[Arc::new(array_with_values.clone())]) + .unwrap(); + + // nulls should come before values when nulls_first = true + assert!( + rows_nulls.row(0) < rows_values.row(0), + "nulls should come before values when nulls_first=true" + ); + + // Test nulls_first = false + let converter_nulls_last = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + arrow_schema::SortOptions { + descending: false, + nulls_first: false, + }, + )]) + .unwrap(); + + let rows_nulls_last = converter_nulls_last + .convert_columns(&[Arc::new(array_with_nulls.clone())]) + .unwrap(); + let rows_values_last = converter_nulls_last + .convert_columns(&[Arc::new(array_with_values.clone())]) + .unwrap(); + + // values should come before nulls when nulls_first = false + assert!( + rows_values_last.row(0) < rows_nulls_last.row(0), + "values should come before nulls when nulls_first=false" + ); + } + + #[test] + fn test_run_end_encoded_mixed_sorting() { + // Test sorting with mixed values and nulls to ensure complex scenarios work + + let array1: RunArray = vec![Some("apple"), None].into_iter().collect(); + let array2: RunArray = vec![None, Some("banana")].into_iter().collect(); + let array3: RunArray = + vec![Some("cherry"), Some("cherry")].into_iter().collect(); + + let converter = RowConverter::new(vec![SortField::new_with_options( + DataType::RunEndEncoded( + Arc::new(arrow_schema::Field::new("run_ends", DataType::Int32, false)), + Arc::new(arrow_schema::Field::new("values", DataType::Utf8, true)), + ), + arrow_schema::SortOptions { + descending: false, + nulls_first: true, + }, + )]) + .unwrap(); + + let rows1 = converter.convert_columns(&[Arc::new(array1)]).unwrap(); + let rows2 = converter.convert_columns(&[Arc::new(array2)]).unwrap(); + let rows3 = converter.convert_columns(&[Arc::new(array3)]).unwrap(); + + // With nulls_first=true, ascending: + // Row 0: array1[0]="apple", array2[0]=null, array3[0]="cherry" -> null < apple < cherry + // Row 1: array1[1]=null, array2[1]="banana", array3[1]="cherry" -> null < banana < cherry + + // Compare first rows: null < apple < cherry + assert!(rows2.row(0) < rows1.row(0), "null should come before apple"); + assert!( + rows1.row(0) < rows3.row(0), + "apple should come before cherry" + ); + + // Compare second rows: null < banana < cherry + assert!( + rows1.row(1) < rows2.row(1), + "null should come before banana" + ); + assert!( + rows2.row(1) < rows3.row(1), + "banana should come before cherry" + ); + } +} diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index 98b692a04357..f22b6c52ba34 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -592,6 +592,7 @@ impl DataType { use DataType::*; match self { Dictionary(_, v) => DataType::is_nested(v.as_ref()), + RunEndEncoded(_, v) => DataType::is_nested(v.data_type()), List(_) | FixedSizeList(_, _) | LargeList(_)