Skip to content

Commit b51c66f

Browse files
committed
feat: convert_array_to_scalar_vec returns optional arrays
1 parent 3b37ae0 commit b51c66f

File tree

6 files changed

+199
-61
lines changed

6 files changed

+199
-61
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 130 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3246,6 +3246,8 @@ impl ScalarValue {
32463246

32473247
/// Retrieve ScalarValue for each row in `array`
32483248
///
3249+
/// Elements in `array` may be NULL, in which case the corresponding element in the returned vector is None.
3250+
///
32493251
/// Example 1: Array (ScalarValue::Int32)
32503252
/// ```
32513253
/// use datafusion_common::ScalarValue;
@@ -3262,15 +3264,15 @@ impl ScalarValue {
32623264
/// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
32633265
///
32643266
/// let expected = vec![
3265-
/// vec![
3266-
/// ScalarValue::Int32(Some(1)),
3267-
/// ScalarValue::Int32(Some(2)),
3268-
/// ScalarValue::Int32(Some(3)),
3269-
/// ],
3270-
/// vec![
3271-
/// ScalarValue::Int32(Some(4)),
3272-
/// ScalarValue::Int32(Some(5)),
3273-
/// ],
3267+
/// Some(vec![
3268+
/// ScalarValue::Int32(Some(1)),
3269+
/// ScalarValue::Int32(Some(2)),
3270+
/// ScalarValue::Int32(Some(3)),
3271+
/// ]),
3272+
/// Some(vec![
3273+
/// ScalarValue::Int32(Some(4)),
3274+
/// ScalarValue::Int32(Some(5)),
3275+
/// ]),
32743276
/// ];
32753277
///
32763278
/// assert_eq!(scalar_vec, expected);
@@ -3303,28 +3305,62 @@ impl ScalarValue {
33033305
/// ]);
33043306
///
33053307
/// let expected = vec![
3306-
/// vec![
3308+
/// Some(vec![
33073309
/// ScalarValue::List(Arc::new(l1)),
33083310
/// ScalarValue::List(Arc::new(l2)),
3309-
/// ],
3311+
/// ]),
3312+
/// ];
3313+
///
3314+
/// assert_eq!(scalar_vec, expected);
3315+
/// ```
3316+
///
3317+
/// Example 3: Nullable array
3318+
/// ```
3319+
/// use datafusion_common::ScalarValue;
3320+
/// use arrow::array::ListArray;
3321+
/// use arrow::datatypes::{DataType, Int32Type};
3322+
///
3323+
/// let list_arr = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
3324+
/// Some(vec![Some(1), Some(2), Some(3)]),
3325+
/// None,
3326+
/// Some(vec![Some(4), Some(5)])
3327+
/// ]);
3328+
///
3329+
/// // Convert the array into Scalar Values for each row
3330+
/// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap();
3331+
///
3332+
/// let expected = vec![
3333+
/// Some(vec![
3334+
/// ScalarValue::Int32(Some(1)),
3335+
/// ScalarValue::Int32(Some(2)),
3336+
/// ScalarValue::Int32(Some(3)),
3337+
/// ]),
3338+
/// None,
3339+
/// Some(vec![
3340+
/// ScalarValue::Int32(Some(4)),
3341+
/// ScalarValue::Int32(Some(5)),
3342+
/// ]),
33103343
/// ];
33113344
///
33123345
/// assert_eq!(scalar_vec, expected);
33133346
/// ```
3314-
pub fn convert_array_to_scalar_vec(array: &dyn Array) -> Result<Vec<Vec<Self>>> {
3347+
pub fn convert_array_to_scalar_vec(
3348+
array: &dyn Array,
3349+
) -> Result<Vec<Option<Vec<Self>>>> {
33153350
fn generic_collect<OffsetSize: OffsetSizeTrait>(
33163351
array: &dyn Array,
3317-
) -> Result<Vec<Vec<ScalarValue>>> {
3352+
) -> Result<Vec<Option<Vec<ScalarValue>>>> {
33183353
array
33193354
.as_list::<OffsetSize>()
33203355
.iter()
3321-
.map(|nested_array| match nested_array {
3322-
Some(nested_array) => (0..nested_array.len())
3323-
.map(|i| ScalarValue::try_from_array(&nested_array, i))
3324-
.collect::<Result<Vec<_>>>(),
3325-
// TODO: what can we put for null?
3326-
// https://github.com/apache/datafusion/issues/17749
3327-
None => Ok(vec![]),
3356+
.map(|nested_array| {
3357+
nested_array
3358+
.map(|array| {
3359+
(0..array.len())
3360+
.map(|i| ScalarValue::try_from_array(&array, i))
3361+
.collect::<Result<Vec<_>>>()
3362+
})
3363+
.transpose()
33283364
})
33293365
.collect()
33303366
}
@@ -9021,7 +9057,7 @@ mod tests {
90219057

90229058
#[test]
90239059
fn test_convert_array_to_scalar_vec() {
9024-
// Regular ListArray
9060+
// 1: Regular ListArray
90259061
let list = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
90269062
Some(vec![Some(1), Some(2)]),
90279063
None,
@@ -9031,17 +9067,20 @@ mod tests {
90319067
assert_eq!(
90329068
converted,
90339069
vec![
9034-
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
9035-
vec![],
9036-
vec![
9070+
Some(vec![
9071+
ScalarValue::Int64(Some(1)),
9072+
ScalarValue::Int64(Some(2))
9073+
]),
9074+
None,
9075+
Some(vec![
90379076
ScalarValue::Int64(Some(3)),
90389077
ScalarValue::Int64(None),
90399078
ScalarValue::Int64(Some(4))
9040-
],
9079+
]),
90419080
]
90429081
);
90439082

9044-
// Regular LargeListArray
9083+
// 2: Regular LargeListArray
90459084
let large_list = LargeListArray::from_iter_primitive::<Int64Type, _, _>(vec![
90469085
Some(vec![Some(1), Some(2)]),
90479086
None,
@@ -9051,17 +9090,20 @@ mod tests {
90519090
assert_eq!(
90529091
converted,
90539092
vec![
9054-
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
9055-
vec![],
9056-
vec![
9093+
Some(vec![
9094+
ScalarValue::Int64(Some(1)),
9095+
ScalarValue::Int64(Some(2))
9096+
]),
9097+
None,
9098+
Some(vec![
90579099
ScalarValue::Int64(Some(3)),
90589100
ScalarValue::Int64(None),
90599101
ScalarValue::Int64(Some(4))
9060-
],
9102+
]),
90619103
]
90629104
);
90639105

9064-
// Funky (null slot has non-zero list offsets)
9106+
// 3: Funky (null slot has non-zero list offsets)
90659107
// Offsets + Values looks like this: [[1, 2], [3, 4], [5]]
90669108
// But with NullBuffer it's like this: [[1, 2], NULL, [5]]
90679109
let funky = ListArray::new(
@@ -9074,9 +9116,63 @@ mod tests {
90749116
assert_eq!(
90759117
converted,
90769118
vec![
9077-
vec![ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
9078-
vec![],
9079-
vec![ScalarValue::Int64(Some(5))],
9119+
Some(vec![
9120+
ScalarValue::Int64(Some(1)),
9121+
ScalarValue::Int64(Some(2))
9122+
]),
9123+
None,
9124+
Some(vec![ScalarValue::Int64(Some(5))]),
9125+
]
9126+
);
9127+
9128+
// 4: Offsets + Values looks like this: [[1, 2], [], [3, 4, 5]]
9129+
// But with NullBuffer it's like this: [[1, 2], NULL, [3, 4, 5]]
9130+
// The converted result is: [[1, 2], None, [3, 4, 5]]
9131+
let array4 = ListArray::new(
9132+
Field::new_list_field(DataType::Int64, true).into(),
9133+
OffsetBuffer::new(vec![0, 2, 2, 5].into()),
9134+
Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])),
9135+
Some(NullBuffer::from(vec![true, false, true])),
9136+
);
9137+
let converted = ScalarValue::convert_array_to_scalar_vec(&array4).unwrap();
9138+
assert_eq!(
9139+
converted,
9140+
vec![
9141+
Some(vec![
9142+
ScalarValue::Int64(Some(1)),
9143+
ScalarValue::Int64(Some(2))
9144+
]),
9145+
None,
9146+
Some(vec![
9147+
ScalarValue::Int64(Some(3)),
9148+
ScalarValue::Int64(Some(4)),
9149+
ScalarValue::Int64(Some(5)),
9150+
]),
9151+
]
9152+
);
9153+
9154+
// 5: Offsets + Values looks like this: [[1, 2], [], [3, 4, 5]]
9155+
// Same as 4, but the middle array is not null, so after conversion it's empty.
9156+
let array5 = ListArray::new(
9157+
Field::new_list_field(DataType::Int64, true).into(),
9158+
OffsetBuffer::new(vec![0, 2, 2, 5].into()),
9159+
Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5, 6])),
9160+
Some(NullBuffer::from(vec![true, true, true])),
9161+
);
9162+
let converted = ScalarValue::convert_array_to_scalar_vec(&array5).unwrap();
9163+
assert_eq!(
9164+
converted,
9165+
vec![
9166+
Some(vec![
9167+
ScalarValue::Int64(Some(1)),
9168+
ScalarValue::Int64(Some(2))
9169+
]),
9170+
Some(vec![]),
9171+
Some(vec![
9172+
ScalarValue::Int64(Some(3)),
9173+
ScalarValue::Int64(Some(4)),
9174+
ScalarValue::Int64(Some(5)),
9175+
]),
90809176
]
90819177
);
90829178
}

datafusion/core/tests/sql/aggregates/basic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> {
4848
let column = actual[0].column(0);
4949
assert_eq!(column.len(), 1);
5050
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&column)?;
51-
let mut scalars = scalar_vec[0].clone();
51+
let mut scalars = scalar_vec[0].as_ref().unwrap().clone();
5252

5353
// workaround lack of Ord of ScalarValue
5454
let cmp = |a: &ScalarValue, b: &ScalarValue| {

datafusion/functions-aggregate-common/src/merge_arrays.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ impl PartialOrd for CustomElement<'_> {
8787

8888
/// This function merges the `values` array (`&[Vec<ScalarValue>]`) into a single array `Vec<ScalarValue>`
8989
/// Merging is done according to the ordering values stored inside `ordering_values` (`&[Vec<Vec<ScalarValue>>]`)
90-
/// Inner `Vec<ScalarValue>` in the `ordering_values` can be thought as ordering information for the
90+
/// Inner `Vec<ScalarValue>` in the `ordering_values` can be thought of as ordering information for
9191
/// each `ScalarValue` in the `values` array.
9292
/// The desired ordering is specified by the `sort_options` argument (it should have the same size as the inner `Vec<ScalarValue>`
9393
/// of the `ordering_values` array).
@@ -119,17 +119,27 @@ pub fn merge_ordered_arrays(
119119
// Defines according to which ordering comparisons should be done.
120120
sort_options: &[SortOptions],
121121
) -> datafusion_common::Result<(Vec<ScalarValue>, Vec<Vec<ScalarValue>>)> {
122-
// Keep track the most recent data of each branch, in binary heap data structure.
122+
// Keep track of the most recent data of each branch, in a binary heap data structure.
123123
let mut heap = BinaryHeap::<CustomElement>::new();
124124

125-
if values.len() != ordering_values.len()
126-
|| values
127-
.iter()
128-
.zip(ordering_values.iter())
129-
.any(|(vals, ordering_vals)| vals.len() != ordering_vals.len())
125+
if values.len() != ordering_values.len() {
126+
return exec_err!(
127+
"Expects values and ordering_values to have same size but got {} and {}",
128+
values.len(),
129+
ordering_values.len()
130+
);
131+
}
132+
if let Some((idx, (values, ordering_values))) = values
133+
.iter()
134+
.zip(ordering_values.iter())
135+
.enumerate()
136+
.find(|(_, (vals, ordering_vals))| vals.len() != ordering_vals.len())
130137
{
131138
return exec_err!(
132-
"Expects values arguments and/or ordering_values arguments to have same size"
139+
"Expects values elements and ordering_values elements to have same size but got {} and {} at index {}",
140+
values.len(),
141+
ordering_values.len(),
142+
idx
133143
);
134144
}
135145
let n_branch = values.len();

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -687,13 +687,16 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
687687

688688
// Convert array to Scalars to sort them easily. Convert back to array at evaluation.
689689
let array_agg_res = ScalarValue::convert_array_to_scalar_vec(array_agg_values)?;
690-
for v in array_agg_res.into_iter() {
691-
partition_values.push(v.into());
690+
for maybe_v in array_agg_res.into_iter() {
691+
if let Some(v) = maybe_v {
692+
partition_values.push(v.into());
693+
} else {
694+
partition_values.push(vec![].into());
695+
}
692696
}
693697

694698
let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
695-
696-
for partition_ordering_rows in orderings.into_iter() {
699+
for partition_ordering_rows in orderings.into_iter().flatten() {
697700
// Extract value from struct to ordering_rows for each group/partition
698701
let ordering_value = partition_ordering_rows.into_iter().map(|ordering_row| {
699702
if let ScalarValue::Struct(s) = ordering_row {

datafusion/functions-aggregate/src/nth_value.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ impl Accumulator for TrivialNthValueAccumulator {
267267
// First entry in the state is the aggregation result.
268268
let n_required = self.n.unsigned_abs() as usize;
269269
let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
270-
for v in array_agg_res.into_iter() {
270+
for v in array_agg_res.into_iter().flatten() {
271271
self.values.extend(v);
272272
if self.values.len() > n_required {
273273
// There is enough data collected, can stop merging:
@@ -457,14 +457,14 @@ impl Accumulator for NthValueAccumulator {
457457
let mut partition_values = vec![self.values.clone()];
458458
// First entry in the state is the aggregation result.
459459
let array_agg_res = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
460-
for v in array_agg_res.into_iter() {
460+
for v in array_agg_res.into_iter().flatten() {
461461
partition_values.push(v.into());
462462
}
463463
// Stores ordering requirement expression results coming from each partition:
464464
let mut partition_ordering_values = vec![self.ordering_values.clone()];
465465
let orderings = ScalarValue::convert_array_to_scalar_vec(agg_orderings)?;
466466
// Extract value from struct to ordering_rows for each group/partition:
467-
for partition_ordering_rows in orderings.into_iter() {
467+
for partition_ordering_rows in orderings.into_iter().flatten() {
468468
let ordering_values = partition_ordering_rows.into_iter().map(|ordering_row| {
469469
let ScalarValue::Struct(s_array) = ordering_row else {
470470
return exec_err!(

datafusion/functions-nested/src/array_has.rs

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,17 +142,25 @@ impl ScalarUDFImpl for ArrayHas {
142142
ScalarValue::convert_array_to_scalar_vec(&array)
143143
{
144144
assert_eq!(scalar_values.len(), 1);
145-
let list = scalar_values
146-
.into_iter()
147-
.flatten()
148-
.map(|v| Expr::Literal(v, None))
149-
.collect();
150-
151-
return Ok(ExprSimplifyResult::Simplified(in_list(
152-
std::mem::take(needle),
153-
list,
154-
false,
155-
)));
145+
match &scalar_values[0] {
146+
// Haystack was a single list element as expected
147+
Some(list) => {
148+
let list = list
149+
.iter()
150+
.map(|v| Expr::Literal(v.clone(), None))
151+
.collect();
152+
153+
return Ok(ExprSimplifyResult::Simplified(in_list(
154+
std::mem::take(needle),
155+
list,
156+
false,
157+
)));
158+
}
159+
// Haystack was a singular null, should be handled elsewhere
160+
None => {
161+
return Ok(ExprSimplifyResult::Original(args));
162+
}
163+
};
156164
}
157165
}
158166
Expr::ScalarFunction(ScalarFunction { func, args })
@@ -786,4 +794,25 @@ mod tests {
786794

787795
Ok(())
788796
}
797+
798+
#[test]
799+
fn test_simplify_array_has_with_null_haystack() {
800+
let haystack = ListArray::new_null(
801+
Arc::new(Field::new_list_field(DataType::Int32, true)),
802+
1,
803+
);
804+
let haystack = lit(ScalarValue::List(Arc::new(haystack)));
805+
let needle = col("c");
806+
807+
let props = ExecutionProps::new();
808+
let context = datafusion_expr::simplify::SimplifyContext::new(&props);
809+
810+
let Ok(ExprSimplifyResult::Original(args)) =
811+
ArrayHas::new().simplify(vec![haystack.clone(), needle.clone()], &context)
812+
else {
813+
panic!("Expected non-simplified expression");
814+
};
815+
816+
assert_eq!(args, vec![haystack, col("c")]);
817+
}
789818
}

0 commit comments

Comments
 (0)