Skip to content

Commit 667cbdf

Browse files
committed
Support FixedSizeList RowConverter
Add `DataType::FixedSizeList` support to `RowConverter`. This is necessary to support DISTINCT and GROUP BY over fixed-sized arrays in DataFusion.
1 parent 0fddbd4 commit 667cbdf

File tree

3 files changed

+261
-10
lines changed

3 files changed

+261
-10
lines changed

arrow-array/src/array/fixed_size_list_array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ impl From<ArrayData> for FixedSizeListArray {
343343
fn from(data: ArrayData) -> Self {
344344
let value_length = match data.data_type() {
345345
DataType::FixedSizeList(_, len) => *len,
346-
_ => {
347-
panic!("FixedSizeListArray data should contain a FixedSizeList data type")
346+
data_type => {
347+
panic!("FixedSizeListArray data should contain a FixedSizeList data type, got {data_type:?}")
348348
}
349349
};
350350

arrow-row/src/lib.rs

Lines changed: 134 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ use arrow_schema::*;
144144
use variable::{decode_binary_view, decode_string_view};
145145

146146
use crate::fixed::{decode_bool, decode_fixed_size_binary, decode_primitive};
147+
use crate::list::{compute_lengths_fixed_size_list, encode_fixed_size_list};
147148
use crate::variable::{decode_binary, decode_string};
148149
use arrow_array::types::{Int16Type, Int32Type, Int64Type};
149150

@@ -433,6 +434,11 @@ impl Codec {
433434
let converter = RowConverter::new(vec![field])?;
434435
Ok(Self::List(converter))
435436
}
437+
DataType::FixedSizeList(f, _) => {
438+
let field = SortField::new_with_options(f.data_type().clone(), sort_field.options);
439+
let converter = RowConverter::new(vec![field])?;
440+
Ok(Self::List(converter))
441+
}
436442
DataType::Struct(f) => {
437443
let sort_fields = f
438444
.iter()
@@ -474,6 +480,7 @@ impl Codec {
474480
let values = match array.data_type() {
475481
DataType::List(_) => as_list_array(array).values(),
476482
DataType::LargeList(_) => as_large_list_array(array).values(),
483+
DataType::FixedSizeList(_, _) => as_fixed_size_list_array(array).values(),
477484
_ => unreachable!(),
478485
};
479486
let rows = converter.convert_columns(&[values.clone()])?;
@@ -576,9 +583,10 @@ impl RowConverter {
576583
fn supports_datatype(d: &DataType) -> bool {
577584
match d {
578585
_ if !d.is_nested() => true,
579-
DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => {
580-
Self::supports_datatype(f.data_type())
581-
}
586+
DataType::List(f)
587+
| DataType::LargeList(f)
588+
| DataType::FixedSizeList(f, _)
589+
| DataType::Map(f, _) => Self::supports_datatype(f.data_type()),
582590
DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())),
583591
DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()),
584592
_ => false,
@@ -1365,6 +1373,11 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker {
13651373
DataType::LargeList(_) => {
13661374
list::compute_lengths(tracker.materialized(), rows, as_large_list_array(array))
13671375
}
1376+
DataType::FixedSizeList(_, _) => compute_lengths_fixed_size_list(
1377+
&mut tracker,
1378+
rows,
1379+
as_fixed_size_list_array(array),
1380+
),
13681381
_ => unreachable!(),
13691382
},
13701383
Encoder::RunEndEncoded(rows) => match array.data_type() {
@@ -1482,6 +1495,9 @@ fn encode_column(
14821495
DataType::LargeList(_) => {
14831496
list::encode(data, offsets, rows, opts, as_large_list_array(column))
14841497
}
1498+
DataType::FixedSizeList(_, _) => {
1499+
encode_fixed_size_list(data, offsets, rows, opts, as_fixed_size_list_array(column))
1500+
}
14851501
_ => unreachable!(),
14861502
},
14871503
Encoder::RunEndEncoded(rows) => match column.data_type() {
@@ -1582,6 +1598,13 @@ unsafe fn decode_column(
15821598
DataType::LargeList(_) => {
15831599
Arc::new(list::decode::<i64>(converter, rows, field, validate_utf8)?)
15841600
}
1601+
DataType::FixedSizeList(_, value_length) => Arc::new(list::decode_fixed_size_list(
1602+
converter,
1603+
rows,
1604+
field,
1605+
validate_utf8,
1606+
value_length.as_usize(),
1607+
)?),
15851608
_ => unreachable!(),
15861609
},
15871610
Codec::RunEndEncoded(converter) => match &field.data_type {
@@ -2378,6 +2401,114 @@ mod tests {
23782401
test_nested_list::<i64>();
23792402
}
23802403

2404+
#[test]
2405+
fn test_fixed_size_list() {
2406+
let mut builder = FixedSizeListBuilder::new(Int32Builder::new(), 3);
2407+
builder.values().append_value(32);
2408+
builder.values().append_value(52);
2409+
builder.values().append_value(32);
2410+
builder.append(true);
2411+
builder.values().append_value(32);
2412+
builder.values().append_value(52);
2413+
builder.values().append_value(12);
2414+
builder.append(true);
2415+
builder.values().append_value(32);
2416+
builder.values().append_value(52);
2417+
builder.values().append_null();
2418+
builder.append(true);
2419+
builder.values().append_value(32); // MASKED
2420+
builder.values().append_value(52); // MASKED
2421+
builder.values().append_value(13); // MASKED
2422+
builder.append(false);
2423+
builder.values().append_value(32);
2424+
builder.values().append_null();
2425+
builder.values().append_null();
2426+
builder.append(true);
2427+
builder.values().append_null();
2428+
builder.values().append_null();
2429+
builder.values().append_null();
2430+
builder.append(true);
2431+
builder.values().append_value(17); // MASKED
2432+
builder.values().append_null(); // MASKED
2433+
builder.values().append_value(77); // MASKED
2434+
builder.append(false);
2435+
2436+
let list = Arc::new(builder.finish()) as ArrayRef;
2437+
let d = list.data_type().clone();
2438+
2439+
// Default sorting (ascending, nulls first)
2440+
let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
2441+
2442+
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
2443+
assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
2444+
assert!(rows.row(2) < rows.row(1)); // [32, 52, null] < [32, 52, 12]
2445+
assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null]
2446+
assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null]
2447+
assert!(rows.row(5) < rows.row(2)); // [null, null, null] < [32, 52, null]
2448+
assert!(rows.row(3) < rows.row(5)); // null < [null, null, null]
2449+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
2450+
2451+
let back = converter.convert_rows(&rows).unwrap();
2452+
assert_eq!(back.len(), 1);
2453+
back[0].to_data().validate_full().unwrap();
2454+
assert_eq!(&back[0], &list);
2455+
2456+
// Ascending, null last
2457+
let options = SortOptions::default().asc().with_nulls_first(false);
2458+
let field = SortField::new_with_options(d.clone(), options);
2459+
let converter = RowConverter::new(vec![field]).unwrap();
2460+
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
2461+
assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
2462+
assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12]
2463+
assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null]
2464+
assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null]
2465+
assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null]
2466+
assert!(rows.row(3) > rows.row(5)); // null > [null, null, null]
2467+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
2468+
2469+
let back = converter.convert_rows(&rows).unwrap();
2470+
assert_eq!(back.len(), 1);
2471+
back[0].to_data().validate_full().unwrap();
2472+
assert_eq!(&back[0], &list);
2473+
2474+
// Descending, nulls last
2475+
let options = SortOptions::default().desc().with_nulls_first(false);
2476+
let field = SortField::new_with_options(d.clone(), options);
2477+
let converter = RowConverter::new(vec![field]).unwrap();
2478+
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
2479+
assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
2480+
assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12]
2481+
assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null]
2482+
assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null]
2483+
assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null]
2484+
assert!(rows.row(3) > rows.row(5)); // null > [null, null, null]
2485+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
2486+
2487+
let back = converter.convert_rows(&rows).unwrap();
2488+
assert_eq!(back.len(), 1);
2489+
back[0].to_data().validate_full().unwrap();
2490+
assert_eq!(&back[0], &list);
2491+
2492+
// Descending, nulls first
2493+
let options = SortOptions::default().desc().with_nulls_first(true);
2494+
let field = SortField::new_with_options(d, options);
2495+
let converter = RowConverter::new(vec![field]).unwrap();
2496+
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
2497+
2498+
assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
2499+
assert!(rows.row(2) < rows.row(1)); // [32, 52, null] > [32, 52, 12]
2500+
assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null]
2501+
assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null]
2502+
assert!(rows.row(5) < rows.row(2)); // [null, null, null] > [32, 52, null]
2503+
assert!(rows.row(3) < rows.row(5)); // null < [null, null, null]
2504+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
2505+
2506+
let back = converter.convert_rows(&rows).unwrap();
2507+
assert_eq!(back.len(), 1);
2508+
back[0].to_data().validate_full().unwrap();
2509+
assert_eq!(&back[0], &list);
2510+
}
2511+
23812512
fn generate_primitive_array<K>(len: usize, valid_percent: f64) -> PrimitiveArray<K>
23822513
where
23832514
K: ArrowPrimitiveType,

arrow-row/src/list.rs

Lines changed: 125 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::{null_sentinel, RowConverter, Rows, SortField};
19-
use arrow_array::{Array, GenericListArray, OffsetSizeTrait};
20-
use arrow_buffer::{Buffer, MutableBuffer};
18+
use crate::{fixed, null_sentinel, LengthTracker, RowConverter, Rows, SortField};
19+
use arrow_array::{new_null_array, Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait};
20+
use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
2121
use arrow_data::ArrayDataBuilder;
22-
use arrow_schema::{ArrowError, SortOptions};
22+
use arrow_schema::{ArrowError, DataType, SortOptions};
2323
use std::ops::Range;
2424

2525
pub fn compute_lengths<O: OffsetSizeTrait>(
@@ -97,7 +97,7 @@ fn encode_one(
9797
}
9898
}
9999

100-
/// Decodes a string array from `rows` with the provided `options`
100+
/// Decodes an array from `rows` with the provided `options`
101101
///
102102
/// # Safety
103103
///
@@ -184,3 +184,123 @@ pub unsafe fn decode<O: OffsetSizeTrait>(
184184

185185
Ok(GenericListArray::from(unsafe { builder.build_unchecked() }))
186186
}
187+
188+
pub fn compute_lengths_fixed_size_list(
189+
tracker: &mut LengthTracker,
190+
rows: &Rows,
191+
array: &FixedSizeListArray,
192+
) {
193+
let value_length = array.value_length().as_usize();
194+
tracker.push_variable((0..array.len()).map(|idx| {
195+
match array.is_valid(idx) {
196+
true => {
197+
1 + ((idx * value_length)..(idx + 1) * value_length)
198+
.map(|child_idx| rows.row(child_idx).as_ref().len())
199+
.sum::<usize>()
200+
}
201+
false => 1,
202+
}
203+
}))
204+
}
205+
206+
/// Encodes the provided `FixedSizeListArray` to `out` with the provided `SortOptions`
207+
///
208+
/// `rows` should contain the encoded child elements
209+
pub fn encode_fixed_size_list(
210+
data: &mut [u8],
211+
offsets: &mut [usize],
212+
rows: &Rows,
213+
opts: SortOptions,
214+
array: &FixedSizeListArray,
215+
) {
216+
let null_sentinel = null_sentinel(opts);
217+
offsets
218+
.iter_mut()
219+
.skip(1)
220+
.enumerate()
221+
.for_each(|(idx, offset)| {
222+
let value_length = array.value_length().as_usize();
223+
match array.is_valid(idx) {
224+
true => {
225+
data[*offset] = 0x01;
226+
*offset += 1;
227+
for child_idx in (idx * value_length)..(idx + 1) * value_length {
228+
//dbg!(child_idx);
229+
let row = rows.row(child_idx);
230+
let end_offset = *offset + row.as_ref().len();
231+
data[*offset..end_offset].copy_from_slice(row.as_ref());
232+
*offset = end_offset;
233+
}
234+
}
235+
false => {
236+
let null_sentinels = 1;
237+
//+ value_length; // 1 for self + for values too
238+
for i in 0..null_sentinels {
239+
data[*offset + i] = null_sentinel;
240+
}
241+
*offset += null_sentinels;
242+
}
243+
};
244+
})
245+
}
246+
247+
/// Decodes a fixed size list array from `rows` with the provided `options`
248+
///
249+
/// # Safety
250+
///
251+
/// `rows` must contain valid data for the provided `converter`
252+
pub unsafe fn decode_fixed_size_list(
253+
converter: &RowConverter,
254+
rows: &mut [&[u8]],
255+
field: &SortField,
256+
validate_utf8: bool,
257+
value_length: usize,
258+
) -> Result<FixedSizeListArray, ArrowError> {
259+
let list_type = &field.data_type;
260+
let element_type = match list_type {
261+
DataType::FixedSizeList(element_field, _) => element_field.data_type(),
262+
_ => {
263+
return Err(ArrowError::InvalidArgumentError(format!(
264+
"Expected FixedSizeListArray, found: {:?}",
265+
list_type
266+
)))
267+
}
268+
};
269+
270+
let len = rows.len();
271+
let (null_count, nulls) = fixed::decode_nulls(rows);
272+
273+
let null_element_encoded = converter.convert_columns(&[new_null_array(element_type, 1)])?;
274+
let null_element_encoded = null_element_encoded.row(0);
275+
let null_element_slice = null_element_encoded.as_ref();
276+
277+
let mut child_rows = Vec::new();
278+
for row in rows {
279+
let valid = row[0] == 1;
280+
let mut row_offset = 1;
281+
if !valid {
282+
for _ in 0..value_length {
283+
child_rows.push(null_element_slice);
284+
}
285+
} else {
286+
for _ in 0..value_length {
287+
let mut temp_child_rows = vec![&row[row_offset..]];
288+
converter.convert_raw(&mut temp_child_rows, validate_utf8)?;
289+
let decoded_bytes = row.len() - row_offset - temp_child_rows[0].len();
290+
let next_offset = row_offset + decoded_bytes;
291+
child_rows.push(&row[row_offset..next_offset]);
292+
row_offset = next_offset;
293+
}
294+
}
295+
}
296+
297+
let children = converter.convert_raw(&mut child_rows, validate_utf8)?;
298+
let child_data = children.iter().map(|c| c.to_data()).collect();
299+
let builder = ArrayDataBuilder::new(list_type.clone())
300+
.len(len)
301+
.null_count(null_count)
302+
.null_bit_buffer(Some(nulls))
303+
.child_data(child_data);
304+
305+
Ok(FixedSizeListArray::from(builder.build_unchecked()))
306+
}

0 commit comments

Comments
 (0)