Skip to content

Commit a7af57e

Browse files
committed
Support FixedSizeList RowConverter (apache#7705)
# Which issue does this PR close? none # Rationale for this change This is necessary to support DISTINCT and GROUP BY over fixed-sized arrays in DataFusion. # What changes are included in this PR? Add `DataType::FixedSizeList` support to `RowConverter`. # Are there any user-facing changes? No (cherry picked from commit d7fc416)
1 parent bb844d2 commit a7af57e

File tree

3 files changed

+324
-26
lines changed

3 files changed

+324
-26
lines changed

arrow-array/src/array/fixed_size_list_array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ impl From<ArrayData> for FixedSizeListArray {
343343
fn from(data: ArrayData) -> Self {
344344
let value_length = match data.data_type() {
345345
DataType::FixedSizeList(_, len) => *len,
346-
_ => {
347-
panic!("FixedSizeListArray data should contain a FixedSizeList data type")
346+
data_type => {
347+
panic!("FixedSizeListArray data should contain a FixedSizeList data type, got {data_type:?}")
348348
}
349349
};
350350

arrow-row/src/lib.rs

Lines changed: 197 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ use arrow_schema::*;
139139
use variable::{decode_binary_view, decode_string_view};
140140

141141
use crate::fixed::{decode_bool, decode_fixed_size_binary, decode_primitive};
142+
use crate::list::{compute_lengths_fixed_size_list, encode_fixed_size_list};
142143
use crate::variable::{decode_binary, decode_string};
143144

144145
mod fixed;
@@ -335,6 +336,46 @@ mod variable;
335336
///
336337
/// With `[]` represented by an empty byte array, and `null` a null byte array.
337338
///
339+
/// ## Fixed Size List Encoding
340+
///
341+
/// Fixed Size Lists are encoded by first encoding all child elements to the row format.
342+
///
343+
/// A non-null list value is then encoded as 0x01 followed by the concatenation of each
344+
/// of the child elements. A null list value is encoded as a null marker.
345+
///
346+
/// For example given:
347+
///
348+
/// ```text
349+
/// [1_u8, 2_u8]
350+
/// [3_u8, null]
351+
/// null
352+
/// ```
353+
///
354+
/// The elements would be converted to:
355+
///
356+
/// ```text
357+
/// ┌──┬──┐ ┌──┬──┐ ┌──┬──┐ ┌──┬──┐
358+
/// 1 │01│01│ 2 │01│02│ 3 │01│03│ null │00│00│
359+
/// └──┴──┘ └──┴──┘ └──┴──┘ └──┴──┘
360+
///```
361+
///
362+
/// Which would be encoded as
363+
///
364+
/// ```text
365+
/// ┌──┬──┬──┬──┬──┐
366+
/// [1_u8, 2_u8] │01│01│01│01│02│
367+
/// └──┴──┴──┴──┴──┘
368+
/// └ 1 ┘ └ 2 ┘
369+
/// ┌──┬──┬──┬──┬──┐
370+
/// [3_u8, null] │01│01│03│00│00│
371+
/// └──┴──┴──┴──┴──┘
372+
/// └ 3 ┘ └null┘
373+
/// ┌──┐
374+
/// null │00│
375+
/// └──┘
376+
///
377+
///```
378+
///
338379
/// # Ordering
339380
///
340381
/// ## Float Ordering
@@ -409,6 +450,11 @@ impl Codec {
409450
let converter = RowConverter::new(vec![field])?;
410451
Ok(Self::List(converter))
411452
}
453+
DataType::FixedSizeList(f, _) => {
454+
let field = SortField::new_with_options(f.data_type().clone(), sort_field.options);
455+
let converter = RowConverter::new(vec![field])?;
456+
Ok(Self::List(converter))
457+
}
412458
DataType::Struct(f) => {
413459
let sort_fields = f
414460
.iter()
@@ -450,6 +496,7 @@ impl Codec {
450496
let values = match array.data_type() {
451497
DataType::List(_) => as_list_array(array).values(),
452498
DataType::LargeList(_) => as_large_list_array(array).values(),
499+
DataType::FixedSizeList(_, _) => as_fixed_size_list_array(array).values(),
453500
_ => unreachable!(),
454501
};
455502
let rows = converter.convert_columns(&[values.clone()])?;
@@ -536,9 +583,10 @@ impl RowConverter {
536583
fn supports_datatype(d: &DataType) -> bool {
537584
match d {
538585
_ if !d.is_nested() => true,
539-
DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => {
540-
Self::supports_datatype(f.data_type())
541-
}
586+
DataType::List(f)
587+
| DataType::LargeList(f)
588+
| DataType::FixedSizeList(f, _)
589+
| DataType::Map(f, _) => Self::supports_datatype(f.data_type()),
542590
DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())),
543591
_ => false,
544592
}
@@ -1244,6 +1292,11 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> Vec<usize> {
12441292
DataType::LargeList(_) => {
12451293
list::compute_lengths(&mut lengths, rows, as_large_list_array(array))
12461294
}
1295+
DataType::FixedSizeList(_, _) => compute_lengths_fixed_size_list(
1296+
&mut tracker,
1297+
rows,
1298+
as_fixed_size_list_array(array),
1299+
),
12471300
_ => unreachable!(),
12481301
},
12491302
}
@@ -1340,6 +1393,9 @@ fn encode_column(
13401393
DataType::LargeList(_) => {
13411394
list::encode(data, offsets, rows, opts, as_large_list_array(column))
13421395
}
1396+
DataType::FixedSizeList(_, _) => {
1397+
encode_fixed_size_list(data, offsets, rows, opts, as_fixed_size_list_array(column))
1398+
}
13431399
_ => unreachable!(),
13441400
},
13451401
}
@@ -1425,6 +1481,13 @@ unsafe fn decode_column(
14251481
DataType::LargeList(_) => {
14261482
Arc::new(list::decode::<i64>(converter, rows, field, validate_utf8)?)
14271483
}
1484+
DataType::FixedSizeList(_, value_length) => Arc::new(list::decode_fixed_size_list(
1485+
converter,
1486+
rows,
1487+
field,
1488+
validate_utf8,
1489+
value_length.as_usize(),
1490+
)?),
14281491
_ => unreachable!(),
14291492
},
14301493
};
@@ -2016,6 +2079,9 @@ mod tests {
20162079
builder.values().append_null();
20172080
builder.append(true);
20182081
builder.append(true);
2082+
builder.values().append_value(17); // MASKED
2083+
builder.values().append_null(); // MASKED
2084+
builder.append(false);
20192085

20202086
let list = Arc::new(builder.finish()) as ArrayRef;
20212087
let d = list.data_type().clone();
@@ -2024,11 +2090,12 @@ mod tests {
20242090

20252091
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
20262092
assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
2027-
assert!(rows.row(2) < rows.row(1)); // [32, 42] < [32, 52, 12]
2028-
assert!(rows.row(3) < rows.row(2)); // null < [32, 42]
2029-
assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 42]
2030-
assert!(rows.row(5) < rows.row(2)); // [] < [32, 42]
2093+
assert!(rows.row(2) < rows.row(1)); // [32, 52] < [32, 52, 12]
2094+
assert!(rows.row(3) < rows.row(2)); // null < [32, 52]
2095+
assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 52]
2096+
assert!(rows.row(5) < rows.row(2)); // [] < [32, 52]
20312097
assert!(rows.row(3) < rows.row(5)); // null < []
2098+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
20322099

20332100
let back = converter.convert_rows(&rows).unwrap();
20342101
assert_eq!(back.len(), 1);
@@ -2041,11 +2108,12 @@ mod tests {
20412108
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
20422109

20432110
assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
2044-
assert!(rows.row(2) < rows.row(1)); // [32, 42] < [32, 52, 12]
2045-
assert!(rows.row(3) > rows.row(2)); // null > [32, 42]
2046-
assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 42]
2047-
assert!(rows.row(5) < rows.row(2)); // [] < [32, 42]
2111+
assert!(rows.row(2) < rows.row(1)); // [32, 52] < [32, 52, 12]
2112+
assert!(rows.row(3) > rows.row(2)); // null > [32, 52]
2113+
assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 52]
2114+
assert!(rows.row(5) < rows.row(2)); // [] < [32, 52]
20482115
assert!(rows.row(3) > rows.row(5)); // null > []
2116+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
20492117

20502118
let back = converter.convert_rows(&rows).unwrap();
20512119
assert_eq!(back.len(), 1);
@@ -2058,11 +2126,12 @@ mod tests {
20582126
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
20592127

20602128
assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
2061-
assert!(rows.row(2) > rows.row(1)); // [32, 42] > [32, 52, 12]
2062-
assert!(rows.row(3) > rows.row(2)); // null > [32, 42]
2063-
assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 42]
2064-
assert!(rows.row(5) > rows.row(2)); // [] > [32, 42]
2129+
assert!(rows.row(2) > rows.row(1)); // [32, 52] > [32, 52, 12]
2130+
assert!(rows.row(3) > rows.row(2)); // null > [32, 52]
2131+
assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 52]
2132+
assert!(rows.row(5) > rows.row(2)); // [] > [32, 52]
20652133
assert!(rows.row(3) > rows.row(5)); // null > []
2134+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
20662135

20672136
let back = converter.convert_rows(&rows).unwrap();
20682137
assert_eq!(back.len(), 1);
@@ -2075,11 +2144,12 @@ mod tests {
20752144
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
20762145

20772146
assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
2078-
assert!(rows.row(2) > rows.row(1)); // [32, 42] > [32, 52, 12]
2079-
assert!(rows.row(3) < rows.row(2)); // null < [32, 42]
2080-
assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 42]
2081-
assert!(rows.row(5) > rows.row(2)); // [] > [32, 42]
2147+
assert!(rows.row(2) > rows.row(1)); // [32, 52] > [32, 52, 12]
2148+
assert!(rows.row(3) < rows.row(2)); // null < [32, 52]
2149+
assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 52]
2150+
assert!(rows.row(5) > rows.row(2)); // [] > [32, 52]
20822151
assert!(rows.row(3) < rows.row(5)); // null < []
2152+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
20832153

20842154
let back = converter.convert_rows(&rows).unwrap();
20852155
assert_eq!(back.len(), 1);
@@ -2190,6 +2260,114 @@ mod tests {
21902260
test_nested_list::<i64>();
21912261
}
21922262

2263+
#[test]
2264+
fn test_fixed_size_list() {
2265+
let mut builder = FixedSizeListBuilder::new(Int32Builder::new(), 3);
2266+
builder.values().append_value(32);
2267+
builder.values().append_value(52);
2268+
builder.values().append_value(32);
2269+
builder.append(true);
2270+
builder.values().append_value(32);
2271+
builder.values().append_value(52);
2272+
builder.values().append_value(12);
2273+
builder.append(true);
2274+
builder.values().append_value(32);
2275+
builder.values().append_value(52);
2276+
builder.values().append_null();
2277+
builder.append(true);
2278+
builder.values().append_value(32); // MASKED
2279+
builder.values().append_value(52); // MASKED
2280+
builder.values().append_value(13); // MASKED
2281+
builder.append(false);
2282+
builder.values().append_value(32);
2283+
builder.values().append_null();
2284+
builder.values().append_null();
2285+
builder.append(true);
2286+
builder.values().append_null();
2287+
builder.values().append_null();
2288+
builder.values().append_null();
2289+
builder.append(true);
2290+
builder.values().append_value(17); // MASKED
2291+
builder.values().append_null(); // MASKED
2292+
builder.values().append_value(77); // MASKED
2293+
builder.append(false);
2294+
2295+
let list = Arc::new(builder.finish()) as ArrayRef;
2296+
let d = list.data_type().clone();
2297+
2298+
// Default sorting (ascending, nulls first)
2299+
let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
2300+
2301+
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
2302+
assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
2303+
assert!(rows.row(2) < rows.row(1)); // [32, 52, null] < [32, 52, 12]
2304+
assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null]
2305+
assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null]
2306+
assert!(rows.row(5) < rows.row(2)); // [null, null, null] < [32, 52, null]
2307+
assert!(rows.row(3) < rows.row(5)); // null < [null, null, null]
2308+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
2309+
2310+
let back = converter.convert_rows(&rows).unwrap();
2311+
assert_eq!(back.len(), 1);
2312+
back[0].to_data().validate_full().unwrap();
2313+
assert_eq!(&back[0], &list);
2314+
2315+
// Ascending, null last
2316+
let options = SortOptions::default().asc().with_nulls_first(false);
2317+
let field = SortField::new_with_options(d.clone(), options);
2318+
let converter = RowConverter::new(vec![field]).unwrap();
2319+
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
2320+
assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
2321+
assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12]
2322+
assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null]
2323+
assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null]
2324+
assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null]
2325+
assert!(rows.row(3) > rows.row(5)); // null > [null, null, null]
2326+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
2327+
2328+
let back = converter.convert_rows(&rows).unwrap();
2329+
assert_eq!(back.len(), 1);
2330+
back[0].to_data().validate_full().unwrap();
2331+
assert_eq!(&back[0], &list);
2332+
2333+
// Descending, nulls last
2334+
let options = SortOptions::default().desc().with_nulls_first(false);
2335+
let field = SortField::new_with_options(d.clone(), options);
2336+
let converter = RowConverter::new(vec![field]).unwrap();
2337+
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
2338+
assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
2339+
assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12]
2340+
assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null]
2341+
assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null]
2342+
assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null]
2343+
assert!(rows.row(3) > rows.row(5)); // null > [null, null, null]
2344+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
2345+
2346+
let back = converter.convert_rows(&rows).unwrap();
2347+
assert_eq!(back.len(), 1);
2348+
back[0].to_data().validate_full().unwrap();
2349+
assert_eq!(&back[0], &list);
2350+
2351+
// Descending, nulls first
2352+
let options = SortOptions::default().desc().with_nulls_first(true);
2353+
let field = SortField::new_with_options(d, options);
2354+
let converter = RowConverter::new(vec![field]).unwrap();
2355+
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
2356+
2357+
assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
2358+
assert!(rows.row(2) < rows.row(1)); // [32, 52, null] < [32, 52, 12]
2359+
assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null]
2360+
assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null]
2361+
assert!(rows.row(5) < rows.row(2)); // [null, null, null] < [32, 52, null]
2362+
assert!(rows.row(3) < rows.row(5)); // null < [null, null, null]
2363+
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)
2364+
2365+
let back = converter.convert_rows(&rows).unwrap();
2366+
assert_eq!(back.len(), 1);
2367+
back[0].to_data().validate_full().unwrap();
2368+
assert_eq!(&back[0], &list);
2369+
}
2370+
21932371
fn generate_primitive_array<K>(len: usize, valid_percent: f64) -> PrimitiveArray<K>
21942372
where
21952373
K: ArrowPrimitiveType,

0 commit comments

Comments
 (0)