Merged
183 changes: 183 additions & 0 deletions datafusion/core/benches/topk_aggregate.rs
@@ -17,19 +17,63 @@

mod data_utils;

use arrow::array::Int64Builder;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use criterion::{Criterion, criterion_group, criterion_main};
use data_utils::make_data;
use datafusion::physical_plan::{collect, displayable};
use datafusion::prelude::SessionContext;
use datafusion::{datasource::MemTable, error::Result};
use datafusion_execution::config::SessionConfig;
use rand::SeedableRng;
use rand::seq::SliceRandom;
use std::hint::black_box;
use std::sync::Arc;
use tokio::runtime::Runtime;

const LIMIT: usize = 10;

/// Create deterministic data for DISTINCT benchmarks with predictable trace_ids
/// This ensures consistent results across benchmark runs
fn make_distinct_data(
    partition_cnt: i32,
    sample_cnt: i32,
) -> Result<(Arc<Schema>, Vec<Vec<RecordBatch>>)> {
    let mut rng = rand::rngs::SmallRng::from_seed([42; 32]);
    let total_samples = partition_cnt as usize * sample_cnt as usize;
    let mut ids = Vec::new();
    for i in 0..total_samples {
        ids.push(i as i64);
    }
    ids.shuffle(&mut rng);

    let mut global_idx = 0;
    let schema = test_distinct_schema();
    let mut partitions = vec![];
    for _ in 0..partition_cnt {
        let mut id_builder = Int64Builder::new();

        for _ in 0..sample_cnt {
            let id = ids[global_idx];
            id_builder.append_value(id);
            global_idx += 1;
        }

        let id_col = Arc::new(id_builder.finish());
        let batch = RecordBatch::try_new(schema.clone(), vec![id_col])?;
        partitions.push(vec![batch]);
    }

    Ok((schema, partitions))
}

/// Returns a Schema for distinct benchmarks with i64 trace_id
fn test_distinct_schema() -> SchemaRef {
    Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]))
}
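
The fixed [42; 32] seed is what makes make_distinct_data deterministic: every run shuffles the ids into the same permutation, so results and timings are comparable across benchmark runs. A minimal standalone sketch of that property, using the same rand APIs as the code above (it assumes the rand crate's small_rng feature is enabled, which the imports above already require):

use rand::SeedableRng;
use rand::seq::SliceRandom;

fn main() {
    // Two RNGs built from the same 32-byte seed produce the same stream,
    // so both shuffles yield the identical permutation.
    let mut a: Vec<i64> = (0..10).collect();
    let mut b: Vec<i64> = (0..10).collect();
    a.shuffle(&mut rand::rngs::SmallRng::from_seed([42; 32]));
    b.shuffle(&mut rand::rngs::SmallRng::from_seed([42; 32]));
    assert_eq!(a, b);
}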

async fn create_context(
    partition_cnt: i32,
    sample_cnt: i32,
@@ -50,6 +94,25 @@ async fn create_context(
    Ok(ctx)
}

async fn create_context_distinct(
    partition_cnt: i32,
    sample_cnt: i32,
    use_topk: bool,
) -> Result<SessionContext> {
    // Use deterministic data generation for DISTINCT queries to ensure consistent results
    let (schema, parts) = make_distinct_data(partition_cnt, sample_cnt).unwrap();
    let mem_table = Arc::new(MemTable::try_new(schema, parts).unwrap());

    // Create the session context
    let mut cfg = SessionConfig::new();
    let opts = cfg.options_mut();
    opts.optimizer.enable_topk_aggregation = use_topk;
    let ctx = SessionContext::new_with_config(cfg);
    let _ = ctx.register_table("traces", mem_table)?;

    Ok(ctx)
}

fn run(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool, asc: bool) {
    black_box(rt.block_on(async { aggregate(ctx, limit, use_topk, asc).await })).unwrap();
}
@@ -59,6 +122,17 @@ fn run_string(rt: &Runtime, ctx: SessionContext, limit: usize, use_topk: bool) {
        .unwrap();
}

fn run_distinct(
    rt: &Runtime,
    ctx: SessionContext,
    limit: usize,
    use_topk: bool,
    asc: bool,
) {
    black_box(rt.block_on(async { aggregate_distinct(ctx, limit, use_topk, asc).await }))
        .unwrap();
}
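
The std::hint::black_box wrapper keeps the compiler from treating the awaited result as dead code, so each Criterion iteration pays for the full query rather than an optimized-out stub.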

async fn aggregate(
    ctx: SessionContext,
    limit: usize,
@@ -133,6 +207,84 @@ async fn aggregate_string(
    Ok(())
}

async fn aggregate_distinct(
    ctx: SessionContext,
    limit: usize,
    use_topk: bool,
    asc: bool,
) -> Result<()> {
    let order_direction = if asc { "asc" } else { "desc" };
    let sql = format!(
        "select id from traces group by id order by id {order_direction} limit {limit};"
    );
    let df = ctx.sql(sql.as_str()).await?;
    let plan = df.create_physical_plan().await?;
    let actual_phys_plan = displayable(plan.as_ref()).indent(true).to_string();
    assert_eq!(
        actual_phys_plan.contains(&format!("lim=[{limit}]")),
        use_topk
    );
    let batches = collect(plan, ctx.task_ctx()).await?;
    assert_eq!(batches.len(), 1);
    let batch = batches.first().unwrap();
    assert_eq!(batch.num_rows(), LIMIT);

    let actual = format!("{}", pretty_format_batches(&batches)?).to_lowercase();

    let expected_asc = r#"
+----+
| id |
+----+
| 0  |
| 1  |
| 2  |
| 3  |
| 4  |
| 5  |
| 6  |
| 7  |
| 8  |
| 9  |
+----+
"#
    .trim();

    let expected_desc = r#"
+---------+
| id      |
+---------+
| 9999999 |
| 9999998 |
| 9999997 |
| 9999996 |
| 9999995 |
| 9999994 |
| 9999993 |
| 9999992 |
| 9999991 |
| 9999990 |
+---------+
"#
    .trim();

    // Verify exact results match expected values
    if asc {
        assert_eq!(
            actual.trim(),
            expected_asc,
            "Ascending DISTINCT results do not match expected values"
        );
    } else {
        assert_eq!(
            actual.trim(),
            expected_desc,
            "Descending DISTINCT results do not match expected values"
        );
    }

    Ok(())
}
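
Because the query groups by id with no aggregate expressions, it is semantically a SELECT DISTINCT id, which is exactly the shape the new TopK branch in topk_aggregation.rs recognizes. The assertion on lim=[{limit}] works because AggregateExec includes any pushed-down limit in its display output. A small sketch of checking a plan the same way (hypothetical helper; it reuses the displayable API from the code above):

use datafusion::error::Result;
use datafusion::physical_plan::displayable;
use datafusion::prelude::SessionContext;

// Hypothetical helper: true when the displayed physical plan shows a
// limit pushed into an aggregation node.
async fn topk_fired(ctx: &SessionContext, sql: &str, limit: usize) -> Result<bool> {
    let plan = ctx.sql(sql).await?.create_physical_plan().await?;
    let text = displayable(plan.as_ref()).indent(true).to_string();
    Ok(text.contains(&format!("lim=[{limit}]")))
}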

fn criterion_benchmark(c: &mut Criterion) {
    let rt = Runtime::new().unwrap();
    let limit = LIMIT;
@@ -253,6 +405,37 @@ fn criterion_benchmark(c: &mut Criterion) {
        .as_str(),
        |b| b.iter(|| run_string(&rt, ctx.clone(), limit, true)),
    );

    // DISTINCT benchmarks
    let ctx = rt.block_on(async {
        create_context_distinct(partitions, samples, false)
            .await
            .unwrap()
    });
    c.bench_function(
        format!("distinct {} rows desc [no TopK]", partitions * samples).as_str(),
        |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, false)),
    );

    c.bench_function(
        format!("distinct {} rows asc [no TopK]", partitions * samples).as_str(),
        |b| b.iter(|| run_distinct(&rt, ctx.clone(), limit, false, true)),
    );

    let ctx_topk = rt.block_on(async {
        create_context_distinct(partitions, samples, true)
            .await
            .unwrap()
    });
    c.bench_function(
        format!("distinct {} rows desc [TopK]", partitions * samples).as_str(),
        |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, false)),
    );

    c.bench_function(
        format!("distinct {} rows asc [TopK]", partitions * samples).as_str(),
        |b| b.iter(|| run_distinct(&rt, ctx_topk.clone(), limit, true, true)),
    );
}
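
These DISTINCT cases can be run on their own with Criterion's name filter, for example: cargo bench --bench topk_aggregate -- distinct (assuming the bench target is named after this file, as the path above suggests).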

criterion_group!(benches, criterion_benchmark);
5 changes: 3 additions & 2 deletions datafusion/core/tests/execution/coop.rs
@@ -24,7 +24,7 @@ use datafusion::physical_expr::aggregate::AggregateExprBuilder;
use datafusion::physical_plan;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::aggregates::{
    AggregateExec, AggregateMode, PhysicalGroupBy,
    AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy,
};
use datafusion::physical_plan::execution_plan::Boundedness;
use datafusion::prelude::SessionContext;
@@ -233,6 +233,7 @@ async fn agg_grouped_topk_yields(
    #[values(false, true)] pretend_infinite: bool,
) -> Result<(), Box<dyn Error>> {
    // build session

    let session_ctx = SessionContext::new();

    // set up a top-k aggregation
@@ -260,7 +261,7 @@
            inf.clone(),
            inf.schema(),
        )?
        .with_limit(Some(100)),
        .with_limit_options(Some(LimitOptions::new(100))),
    );

    query_yields(aggr, session_ctx.task_ctx()).await
@@ -37,7 +37,7 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::aggregates::{
    AggregateExec, AggregateMode, PhysicalGroupBy,
    AggregateExec, AggregateMode, LimitOptions, PhysicalGroupBy,
};
use datafusion_physical_plan::displayable;
use datafusion_physical_plan::repartition::RepartitionExec;
@@ -260,7 +260,7 @@ fn aggregations_with_limit_combined() -> datafusion_common::Result<()> {
            schema,
        )
        .unwrap()
        .with_limit(Some(5)),
        .with_limit_options(Some(LimitOptions::new(5))),
    );
    let plan: Arc<dyn ExecutionPlan> = final_agg;
    // should combine the Partial/Final AggregateExecs to a Single AggregateExec
@@ -98,7 +98,9 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
                    Arc::clone(input_agg_exec.input()),
                    input_agg_exec.input_schema(),
                )
                .map(|combined_agg| combined_agg.with_limit(agg_exec.limit()))
                .map(|combined_agg| {
                    combined_agg.with_limit_options(agg_exec.limit_options())
                })
                .ok()
                .map(Arc::new)
            } else {
@@ -20,7 +20,7 @@

use std::sync::Arc;

use datafusion_physical_plan::aggregates::AggregateExec;
use datafusion_physical_plan::aggregates::{AggregateExec, LimitOptions};
use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};

@@ -63,7 +63,7 @@ impl LimitedDistinctAggregation {
            aggr.input_schema(),
        )
        .expect("Unable to copy Aggregate!")
        .with_limit(Some(limit));
        .with_limit_options(Some(LimitOptions::new(limit)));
        Some(Arc::new(new_aggr))
    }

44 changes: 32 additions & 12 deletions datafusion/physical-optimizer/src/topk_aggregation.rs
@@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_plan::ExecutionPlan;
use datafusion_physical_plan::aggregates::LimitOptions;
use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported};
use datafusion_physical_plan::execution_plan::CardinalityEffect;
use datafusion_physical_plan::projection::ProjectionExec;
@@ -47,28 +48,47 @@
        order_desc: bool,
        limit: usize,
    ) -> Option<Arc<dyn ExecutionPlan>> {
        // ensure the sort direction matches the aggregate function
        let (field, desc) = aggr.get_minmax_desc()?;
        if desc != order_desc {
            return None;
        }
        let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
        let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
        let vt = field.data_type();
        if !topk_types_supported(&kt, vt) {
        // Currently only a single group key is supported
        let (group_key, group_key_alias) =
            aggr.group_expr().expr().iter().exactly_one().ok()?;
        let kt = group_key.data_type(&aggr.input().schema()).ok()?;
        let vt = if let Some((field, _)) = aggr.get_minmax_desc() {
            field.data_type().clone()
        } else {
            kt.clone()
        };
        if !topk_types_supported(&kt, &vt) {
            return None;
        }
        if aggr.filter_expr().iter().any(|e| e.is_some()) {
            return None;
        }

        // ensure the sort is on the same field as the aggregate output
        if order_by != field.name() {
        // Check if this is ordering by an aggregate function (MIN/MAX)
        if let Some((field, desc)) = aggr.get_minmax_desc() {
            // ensure the sort direction matches the aggregate function
            if desc != order_desc {
                return None;
            }
            // ensure the sort is on the same field as the aggregate output
            if order_by != field.name() {
                return None;
            }
        } else if aggr.aggr_expr().is_empty() {
            // This is a GROUP BY without aggregates; check whether the ordering is on the group key itself
            if order_by != group_key_alias {
                return None;
            }
        } else {
            // Has aggregates other than MIN/MAX, so neither the MIN/MAX nor the DISTINCT case applies
            return None;
        }

        // We found what we want: clone, copy the limit down, and return the modified node
        let new_aggr = AggregateExec::with_new_limit_options(
            aggr,
            Some(LimitOptions::new_with_order(limit, order_desc)),
        );
        Some(Arc::new(new_aggr))
    }

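The definition of LimitOptions itself is not part of this diff; only its two constructors appear, LimitOptions::new(limit) in the rewritten optimizer rules and LimitOptions::new_with_order(limit, order_desc) in the TopK path above. A plausible shape inferred purely from those call sites (the real type lives in datafusion_physical_plan::aggregates and may well differ):

// Hypothetical reconstruction from the call sites in this PR, not the
// actual DataFusion definition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LimitOptions {
    /// Maximum number of groups the aggregation needs to produce.
    limit: usize,
    /// Sort direction recorded with the pushed-down limit; None when the
    /// caller (e.g. LimitedDistinctAggregation) has no ordering to record.
    order_desc: Option<bool>,
}

impl LimitOptions {
    pub fn new(limit: usize) -> Self {
        Self { limit, order_desc: None }
    }

    pub fn new_with_order(limit: usize, order_desc: bool) -> Self {
        Self { limit, order_desc: Some(order_desc) }
    }
}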