Skip to content

Commit a040394

Browse files
kosiew authored and avantgardnerio committed
Use partial aggregation schema for spilling to avoid column mismatch in GroupedHashAggregateStream (apache#13995)
* Refactor spill handling in GroupedHashAggregateStream to use partial aggregate schema
* Implement aggregate functions with spill handling in tests
* Add tests for aggregate functions with and without spill handling
* Move test related imports into mod test
* Rename spill pool test functions for clarity and consistency
* Refactor aggregate function imports to use fully qualified paths
* Remove outdated comments regarding input batch schema for spilling in GroupedHashAggregateStream
* Update aggregate test to use AVG instead of MAX
* assert spill count
* Refactor partial aggregate schema creation to use create_schema function
* Refactor partial aggregation schema creation and remove redundant function
* Remove unused import of Schema from arrow::datatypes in row_hash.rs
* move spill pool testing for aggregate functions to physical-plan/src/aggregates
* Use Arc::clone for schema references in aggregate functions

(cherry picked from commit 81b50c4)
1 parent 6c90e2c commit a040394

File tree

3 files changed

+166
-11
lines changed

3 files changed

+166
-11
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@
2020
#[cfg(feature = "parquet")]
2121
mod parquet;
2222

23-
use std::any::Any;
24-
use std::borrow::Cow;
25-
use std::collections::HashMap;
26-
use std::sync::Arc;
27-
2823
use crate::arrow::record_batch::RecordBatch;
2924
use crate::arrow::util::pretty;
3025
use crate::datasource::file_format::csv::CsvFormatFactory;
@@ -43,6 +38,10 @@ use crate::physical_plan::{
4338
ExecutionPlan, SendableRecordBatchStream,
4439
};
4540
use crate::prelude::SessionContext;
41+
use std::any::Any;
42+
use std::borrow::Cow;
43+
use std::collections::HashMap;
44+
use std::sync::Arc;
4645

4746
use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
4847
use arrow::compute::{cast, concat};

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,7 @@ mod tests {
13001300
use crate::common;
13011301
use crate::expressions::col;
13021302
use crate::memory::MemoryExec;
1303+
use crate::metrics::MetricValue;
13031304
use crate::test::assert_is_pending;
13041305
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
13051306
use crate::RecordBatchStream;
@@ -2769,4 +2770,137 @@ mod tests {
27692770
assert_eq!(aggr_schema, expected_schema);
27702771
Ok(())
27712772
}
2773+
2774+
// test for https://github.com/apache/datafusion/issues/13949
2775+
async fn run_test_with_spill_pool_if_necessary(
2776+
pool_size: usize,
2777+
expect_spill: bool,
2778+
) -> Result<()> {
2779+
fn create_record_batch(
2780+
schema: &Arc<Schema>,
2781+
data: (Vec<u32>, Vec<f64>),
2782+
) -> Result<RecordBatch> {
2783+
Ok(RecordBatch::try_new(
2784+
Arc::clone(schema),
2785+
vec![
2786+
Arc::new(UInt32Array::from(data.0)),
2787+
Arc::new(Float64Array::from(data.1)),
2788+
],
2789+
)?)
2790+
}
2791+
2792+
let schema = Arc::new(Schema::new(vec![
2793+
Field::new("a", DataType::UInt32, false),
2794+
Field::new("b", DataType::Float64, false),
2795+
]));
2796+
2797+
let batches = vec![
2798+
create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
2799+
create_record_batch(&schema, (vec![2, 3, 4, 4], vec![1.0, 2.0, 3.0, 4.0]))?,
2800+
];
2801+
let plan: Arc<dyn ExecutionPlan> =
2802+
Arc::new(MemoryExec::try_new(&[batches], Arc::clone(&schema), None)?);
2803+
2804+
let grouping_set = PhysicalGroupBy::new(
2805+
vec![(col("a", &schema)?, "a".to_string())],
2806+
vec![],
2807+
vec![vec![false]],
2808+
);
2809+
2810+
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
2811+
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
2812+
Arc::new(
2813+
AggregateExprBuilder::new(
2814+
datafusion_functions_aggregate::min_max::min_udaf(),
2815+
vec![col("b", &schema)?],
2816+
)
2817+
.schema(Arc::clone(&schema))
2818+
.alias("MIN(b)")
2819+
.build()?,
2820+
),
2821+
Arc::new(
2822+
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
2823+
.schema(Arc::clone(&schema))
2824+
.alias("AVG(b)")
2825+
.build()?,
2826+
),
2827+
];
2828+
2829+
let single_aggregate = Arc::new(AggregateExec::try_new(
2830+
AggregateMode::Single,
2831+
grouping_set,
2832+
aggregates,
2833+
vec![None, None],
2834+
plan,
2835+
Arc::clone(&schema),
2836+
)?);
2837+
2838+
let batch_size = 2;
2839+
let memory_pool = Arc::new(FairSpillPool::new(pool_size));
2840+
let task_ctx = Arc::new(
2841+
TaskContext::default()
2842+
.with_session_config(SessionConfig::new().with_batch_size(batch_size))
2843+
.with_runtime(Arc::new(
2844+
RuntimeEnvBuilder::new()
2845+
.with_memory_pool(memory_pool)
2846+
.build()?,
2847+
)),
2848+
);
2849+
2850+
let result = collect(single_aggregate.execute(0, Arc::clone(&task_ctx))?).await?;
2851+
2852+
assert_spill_count_metric(expect_spill, single_aggregate);
2853+
2854+
#[rustfmt::skip]
2855+
assert_batches_sorted_eq!(
2856+
[
2857+
"+---+--------+--------+",
2858+
"| a | MIN(b) | AVG(b) |",
2859+
"+---+--------+--------+",
2860+
"| 2 | 1.0 | 1.0 |",
2861+
"| 3 | 2.0 | 2.0 |",
2862+
"| 4 | 3.0 | 3.5 |",
2863+
"+---+--------+--------+",
2864+
],
2865+
&result
2866+
);
2867+
2868+
Ok(())
2869+
}
2870+
2871+
fn assert_spill_count_metric(
2872+
expect_spill: bool,
2873+
single_aggregate: Arc<AggregateExec>,
2874+
) {
2875+
if let Some(metrics_set) = single_aggregate.metrics() {
2876+
let mut spill_count = 0;
2877+
2878+
// Inspect metrics for SpillCount
2879+
for metric in metrics_set.iter() {
2880+
if let MetricValue::SpillCount(count) = metric.value() {
2881+
spill_count = count.value();
2882+
break;
2883+
}
2884+
}
2885+
2886+
if expect_spill && spill_count == 0 {
2887+
panic!(
2888+
"Expected spill but SpillCount metric not found or SpillCount was 0."
2889+
);
2890+
} else if !expect_spill && spill_count > 0 {
2891+
panic!("Expected no spill but found SpillCount metric with value greater than 0.");
2892+
}
2893+
} else {
2894+
panic!("No metrics returned from the operator; cannot verify spilling.");
2895+
}
2896+
}
2897+
2898+
#[tokio::test]
2899+
async fn test_aggregate_with_spill_if_necessary() -> Result<()> {
2900+
// test with spill
2901+
run_test_with_spill_pool_if_necessary(2_000, true).await?;
2902+
// test without spill
2903+
run_test_with_spill_pool_if_necessary(20_000, false).await?;
2904+
Ok(())
2905+
}
27722906
}

datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ use std::vec;
2424
use crate::aggregates::group_values::{new_group_values, GroupValues};
2525
use crate::aggregates::order::GroupOrderingFull;
2626
use crate::aggregates::{
27-
evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode,
28-
PhysicalGroupBy,
27+
create_schema, evaluate_group_by, evaluate_many, evaluate_optional, group_schema,
28+
AggregateMode, PhysicalGroupBy,
2929
};
3030
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
3131
use crate::sorts::sort::sort_batch;
@@ -490,6 +490,31 @@ impl GroupedHashAggregateStream {
490490
.collect::<Result<_>>()?;
491491

492492
let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;
493+
494+
// fix https://github.com/apache/datafusion/issues/13949
495+
// Builds a **partial aggregation** schema by combining the group columns and
496+
// the accumulator state columns produced by each aggregate expression.
497+
//
498+
// # Why Partial Aggregation Schema Is Needed
499+
//
500+
// In a multi-stage (partial/final) aggregation strategy, each partial-aggregate
501+
// operator produces *intermediate* states (e.g., partial sums, counts) rather
502+
// than final scalar values. These extra columns do **not** exist in the original
503+
// input schema (which may be something like `[colA, colB, ...]`). Instead,
504+
// each aggregator adds its own internal state columns (e.g., `[acc_state_1, acc_state_2, ...]`).
505+
//
506+
// Therefore, when we spill these intermediate states or pass them to another
507+
// aggregation operator, we must use a schema that includes both the group
508+
// columns **and** the partial-state columns.
509+
let partial_agg_schema = create_schema(
510+
&agg.input().schema(),
511+
&agg_group_by,
512+
&aggregate_exprs,
513+
AggregateMode::Partial,
514+
)?;
515+
516+
let partial_agg_schema = Arc::new(partial_agg_schema);
517+
493518
let spill_expr = group_schema
494519
.fields
495520
.into_iter()
@@ -522,7 +547,7 @@ impl GroupedHashAggregateStream {
522547
let spill_state = SpillState {
523548
spills: vec![],
524549
spill_expr,
525-
spill_schema: Arc::clone(&agg_schema),
550+
spill_schema: partial_agg_schema,
526551
is_stream_merging: false,
527552
merging_aggregate_arguments,
528553
merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
@@ -953,9 +978,6 @@ impl GroupedHashAggregateStream {
953978
&& self.update_memory_reservation().is_err()
954979
{
955980
assert_ne!(self.mode, AggregateMode::Partial);
956-
// Use input batch (Partial mode) schema for spilling because
957-
// the spilled data will be merged and re-evaluated later.
958-
self.spill_state.spill_schema = batch.schema();
959981
self.spill()?;
960982
self.clear_shrink(batch);
961983
}

0 commit comments

Comments
 (0)