Skip to content

Commit 4dfbca6

Browse files
authored
sort_primitive result is capped to the min of limit or values.len (#236)
* sort_primitive result is capped to the min of limit or values.len (fixes #235)
* Fixed length calculation of nulls to include
* Add more sort_primitive tests for sorts w/ limit
1 parent 4865247 commit 4dfbca6

File tree

1 file changed

+52
-7
lines changed

1 file changed

+52
-7
lines changed

arrow/src/compute/kernels/sort.rs

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -487,24 +487,27 @@ where
487487
len = limit.min(len);
488488
}
489489
if !descending {
490-
sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1));
490+
sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
491+
cmp(a.1, b.1)
492+
});
491493
} else {
492-
sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse());
494+
sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| {
495+
cmp(a.1, b.1).reverse()
496+
});
493497
// reverse to keep a stable ordering
494498
nulls.reverse();
495499
}
496500

497501
// collect results directly into a buffer instead of a vec to avoid another aligned allocation
498-
let mut result = MutableBuffer::new(values.len() * std::mem::size_of::<u32>());
502+
let result_capacity = len * std::mem::size_of::<u32>();
503+
let mut result = MutableBuffer::new(result_capacity);
499504
// sets len to capacity so we can access the whole buffer as a typed slice
500-
result.resize(values.len() * std::mem::size_of::<u32>(), 0);
505+
result.resize(result_capacity, 0);
501506
let result_slice: &mut [u32] = result.typed_data_mut();
502507

503-
debug_assert_eq!(result_slice.len(), nulls_len + valids_len);
504-
505508
if options.nulls_first {
506509
let size = nulls_len.min(len);
507-
result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls);
510+
result_slice[0..size].copy_from_slice(&nulls[0..size]);
508511
if nulls_len < len {
509512
insert_valid_values(result_slice, nulls_len, &valids[0..len - size]);
510513
}
@@ -1556,6 +1559,48 @@ mod tests {
15561559
Some(3),
15571560
vec![Some(1.0), Some(2.0), Some(3.0)],
15581561
);
1562+
1563+
// fewer valid values than limit, with extra nulls
1564+
test_sort_primitive_arrays::<Float64Type>(
1565+
vec![Some(2.0), None, None, Some(1.0)],
1566+
Some(SortOptions {
1567+
descending: false,
1568+
nulls_first: false,
1569+
}),
1570+
Some(3),
1571+
vec![Some(1.0), Some(2.0), None],
1572+
);
1573+
1574+
test_sort_primitive_arrays::<Float64Type>(
1575+
vec![Some(2.0), None, None, Some(1.0)],
1576+
Some(SortOptions {
1577+
descending: false,
1578+
nulls_first: true,
1579+
}),
1580+
Some(3),
1581+
vec![None, None, Some(1.0)],
1582+
);
1583+
1584+
// more nulls than limit
1585+
test_sort_primitive_arrays::<Float64Type>(
1586+
vec![Some(2.0), None, None, None],
1587+
Some(SortOptions {
1588+
descending: false,
1589+
nulls_first: true,
1590+
}),
1591+
Some(2),
1592+
vec![None, None],
1593+
);
1594+
1595+
test_sort_primitive_arrays::<Float64Type>(
1596+
vec![Some(2.0), None, None, None],
1597+
Some(SortOptions {
1598+
descending: false,
1599+
nulls_first: false,
1600+
}),
1601+
Some(2),
1602+
vec![Some(2.0), None],
1603+
);
15591604
}
15601605

15611606
#[test]

0 commit comments

Comments
 (0)