diff --git a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs index 1f47412caf2a..c95d00340e71 100644 --- a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs @@ -20,8 +20,8 @@ use std::cmp::min; use std::sync::Arc; -use arrow::array::RecordBatch; -use arrow_schema::SchemaRef; +use arrow::array::{record_batch, RecordBatch}; +use arrow_schema::{DataType, SchemaRef}; use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{instant::Instant, Result}; @@ -31,7 +31,6 @@ use datafusion_execution::memory_pool::{ }; use datafusion_expr::display_schema; use datafusion_physical_plan::spill::get_record_batch_memory_size; -use itertools::Itertools; use std::time::Duration; use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; @@ -75,38 +74,43 @@ async fn sort_query_fuzzer_runner() { /// Reproduce the bug with specific seeds from the /// [failing test case](https://github.com/apache/datafusion/issues/16452). 
-#[tokio::test(flavor = "multi_thread")] +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_reproduce_sort_query_issue_16452() { - // Seeds from the failing test case - let init_seed = 10313160656544581998u64; - let query_seed = 15004039071976572201u64; - let config_seed_1 = 11807432710583113300u64; - let config_seed_2 = 759937414670321802u64; - - let random_seed = 1u64; // Use a fixed seed to ensure consistent behavior - - let mut test_generator = SortFuzzerTestGenerator::new( - 2000, - 3, - "sort_fuzz_table".to_string(), - get_supported_types_columns(random_seed), - false, - random_seed, - ); - - let mut results = vec![]; - - for config_seed in [config_seed_1, config_seed_2] { - let r = test_generator - .fuzzer_run(init_seed, query_seed, config_seed) - .await - .unwrap(); - - results.push(r); - } - - for (lhs, rhs) in results.iter().tuple_windows() { - check_equality_of_batches(lhs, rhs).unwrap(); + let schema = Arc::new(arrow_schema::Schema::new(vec![ + arrow_schema::Field::new("u64", DataType::UInt64, true), + arrow_schema::Field::new("u32", DataType::UInt32, true), + ])); + + // build the data manually to reproduce the bug + let data = vec![ + vec![record_batch!(("u64", UInt64, [1]), ("u32", UInt32, [2])).unwrap()], + vec![record_batch!(("u64", UInt64, [2]), ("u32", UInt32, [2])).unwrap()], + ]; + + let query = "SELECT * FROM sort_fuzz_table ORDER BY u32 LIMIT 1"; + let config = SessionConfig::new() + .with_target_partitions(2) + .with_batch_size(1); + let ctx = SessionContext::new_with_config(config); + let provider = Arc::new(MemTable::try_new(schema.clone(), data.clone()).unwrap()); + ctx.register_table("sort_fuzz_table", provider).unwrap(); + + // Failure usually happens after ~500 iterations, add a generous number of runs to make sure it reproduces + let mut previous_results = None; + for iteration in 0..4096 { + println!("Iteration {iteration}"); + let r = ctx.sql(query).await.unwrap().collect().await.unwrap(); + match &mut 
previous_results { + None => { + // Store the first run as the expected result + previous_results = Some(r.clone()); + } + Some(prev) => { + // Check that the results are consistent with the previous run + check_equality_of_batches(prev, &r).unwrap(); + *prev = r; // Update the previous results + } + } } } diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 71029662f5f5..8d06fa73ce8e 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -18,8 +18,8 @@ //! TopK: Combination of Sort / LIMIT use arrow::{ - array::Array, - compute::interleave_record_batch, + array::{Array, AsArray}, + compute::{interleave_record_batch, prep_null_mask_filter, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; use datafusion_expr::{ColumnarValue, Operator}; @@ -203,7 +203,7 @@ impl TopK { let baseline = self.metrics.baseline.clone(); let _timer = baseline.elapsed_compute().timer(); - let sort_keys: Vec = self + let mut sort_keys: Vec = self .expr .iter() .map(|expr| { @@ -212,6 +212,43 @@ impl TopK { }) .collect::>>()?; + let mut selected_rows = None; + + if let Some(filter) = self.filter.as_ref() { + // If a filter is provided, update it with the new rows + let filter = filter.current()?; + let filtered = filter.evaluate(&batch)?; + let num_rows = batch.num_rows(); + let array = filtered.into_array(num_rows)?; + let mut filter = array.as_boolean().clone(); + let true_count = filter.true_count(); + if true_count == 0 { + // nothing to filter, so no need to update + return Ok(()); + } + // only update the keys / rows if the filter does not match all rows + if true_count < num_rows { + // Indices in `set_indices` should be correct if filter contains nulls + // So we prepare the filter here. Note this is also done in the `FilterBuilder` + // so there is no overhead to do this here. 
+ if filter.nulls().is_some() { + filter = prep_null_mask_filter(&filter); + } + + let filter_predicate = FilterBuilder::new(&filter); + let filter_predicate = if sort_keys.len() > 1 { + // Optimize filter when it has multiple sort keys + filter_predicate.optimize().build() + } else { + filter_predicate.build() + }; + selected_rows = Some(filter); + sort_keys = sort_keys + .iter() + .map(|key| filter_predicate.filter(key).map_err(|x| x.into())) + .collect::>>()?; + } + }; // reuse existing `Rows` to avoid reallocations let rows = &mut self.scratch_rows; rows.clear(); @@ -219,8 +256,12 @@ impl TopK { let mut batch_entry = self.heap.register_batch(batch.clone()); - let replacements = - self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry); + let replacements = match selected_rows { + Some(filter) => { + self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry) + } + None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry), + }; if replacements > 0 { self.metrics.row_replacements.add(replacements);