Skip to content
47 changes: 7 additions & 40 deletions datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use datafusion_execution::memory_pool::{
};
use datafusion_expr::display_schema;
use datafusion_physical_plan::spill::get_record_batch_memory_size;
use itertools::Itertools;
use std::time::Duration;

use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder};
Expand Down Expand Up @@ -73,43 +72,6 @@ async fn sort_query_fuzzer_runner() {
fuzzer.run().await.unwrap();
}

/// Reproduce the bug with specific seeds from the
/// [failing test case](https://github.com/apache/datafusion/issues/16452).
#[tokio::test(flavor = "multi_thread")]
async fn test_reproduce_sort_query_issue_16452() {
// Seeds from the failing test case
let init_seed = 10313160656544581998u64;
let query_seed = 15004039071976572201u64;
let config_seed_1 = 11807432710583113300u64;
let config_seed_2 = 759937414670321802u64;

let random_seed = 1u64; // Use a fixed seed to ensure consistent behavior

let mut test_generator = SortFuzzerTestGenerator::new(
2000,
3,
"sort_fuzz_table".to_string(),
get_supported_types_columns(random_seed),
false,
random_seed,
);

let mut results = vec![];

for config_seed in [config_seed_1, config_seed_2] {
let r = test_generator
.fuzzer_run(init_seed, query_seed, config_seed)
.await
.unwrap();

results.push(r);
}

for (lhs, rhs) in results.iter().tuple_windows() {
check_equality_of_batches(lhs, rhs).unwrap();
}
}

/// SortQueryFuzzer holds the runner configuration for executing sort query fuzz tests. The fuzzing details are managed inside `SortFuzzerTestGenerator`.
///
/// It defines:
Expand Down Expand Up @@ -466,7 +428,7 @@ impl SortFuzzerTestGenerator {
.collect();

let mut order_by_clauses = Vec::new();
for col in selected_columns {
for col in &selected_columns {
let mut clause = col.name.clone();
if rng.random_bool(0.5) {
let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" };
Expand Down Expand Up @@ -501,7 +463,12 @@ impl SortFuzzerTestGenerator {
let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {l}"));

let query = format!(
"SELECT * FROM {} ORDER BY {}{}",
"SELECT {} FROM {} ORDER BY {}{}",
selected_columns
.iter()
.map(|col| col.name.clone())
.collect::<Vec<_>>()
.join(", "),
self.table_name,
order_by_clauses.join(", "),
limit_clause
Expand Down
51 changes: 46 additions & 5 deletions datafusion/physical-plan/src/topk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! TopK: Combination of Sort / LIMIT

use arrow::{
array::Array,
compute::interleave_record_batch,
array::{Array, AsArray},
compute::{interleave_record_batch, prep_null_mask_filter, FilterBuilder},
row::{RowConverter, Rows, SortField},
};
use datafusion_expr::{ColumnarValue, Operator};
Expand Down Expand Up @@ -203,7 +203,7 @@ impl TopK {
let baseline = self.metrics.baseline.clone();
let _timer = baseline.elapsed_compute().timer();

let sort_keys: Vec<ArrayRef> = self
let mut sort_keys: Vec<ArrayRef> = self
.expr
.iter()
.map(|expr| {
Expand All @@ -212,15 +212,56 @@ impl TopK {
})
.collect::<Result<Vec<_>>>()?;

let mut selected_rows = None;

if let Some(filter) = self.filter.as_ref() {
// If a filter is provided, update it with the new rows
let filter = filter.current()?;
let filtered = filter.evaluate(&batch)?;
let num_rows = batch.num_rows();
let array = filtered.into_array(num_rows)?;
let mut filter = array.as_boolean().clone();
let true_count = filter.true_count();
if true_count == 0 {
// nothing to filter, so no need to update
return Ok(());
}
// only update the keys / rows if the filter does not match all rows
if true_count < num_rows {
// Indices in `set_indices` should be correct if filter contains nulls
// So we prepare the filter here. Note this is also done in the `FilterBuilder`
// so there is no overhead to do this here.
if filter.nulls().is_some() {
filter = prep_null_mask_filter(&filter);
}

let filter_predicate = FilterBuilder::new(&filter);
let filter_predicate = if sort_keys.len() > 1 {
// Optimize filter when it has multiple sort keys
filter_predicate.optimize().build()
} else {
filter_predicate.build()
};
selected_rows = Some(filter);
sort_keys = sort_keys
.iter()
.map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
.collect::<Result<Vec<_>>>()?;
}
};
// reuse existing `Rows` to avoid reallocations
let rows = &mut self.scratch_rows;
rows.clear();
self.row_converter.append(rows, &sort_keys)?;

let mut batch_entry = self.heap.register_batch(batch.clone());

let replacements =
self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry);
let replacements = match selected_rows {
Some(filter) => {
self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry)
}
None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry),
};

if replacements > 0 {
self.metrics.row_replacements.add(replacements);
Expand Down