Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 152 additions & 1 deletion datafusion/core/tests/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,23 @@ use arrow::array::Int32Array;
use arrow::array::{Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion::datasource::listing::PartitionedFile;
use datafusion::datasource::memory::MemTable;
use datafusion::datasource::memory::MemorySourceConfig;
use datafusion::datasource::physical_plan::ParquetSource;
use datafusion::datasource::source::DataSourceExec;
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::Result;
use datafusion_common::assert_batches_eq;
use datafusion_common::cast::as_int64_array;
use datafusion_common::config::ConfigOptions;
use datafusion_common::stats::Precision;
use datafusion_common::{ColumnStatistics, Result, Statistics};
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
use datafusion_execution::TaskContext;
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_expr::Operator;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_physical_expr::aggregate::AggregateExprBuilder;
use datafusion_physical_expr::expressions::{self, cast};
use datafusion_physical_optimizer::PhysicalOptimizerRule;
use datafusion_physical_optimizer::aggregate_statistics::AggregateStatistics;
Expand Down Expand Up @@ -402,3 +409,147 @@ async fn utf8_grouping_min_max_limit_fallbacks() -> Result<()> {

Ok(())
}

/// Checks when `AggregateStatistics` replaces a two-phase `COUNT(DISTINCT ..)`
/// aggregate with a `ProjectionExec`, and what value that projection yields.
///
/// Each scenario builds a Parquet-backed `DataSourceExec` with hand-written
/// statistics, stacks a partial and a final `AggregateExec` on top of it, runs
/// the optimizer rule, and then inspects the type of the resulting plan root.
#[tokio::test]
async fn test_count_distinct_optimization() -> Result<()> {
    /// One table-driven scenario.
    struct TestCase {
        // Label used in assertion messages.
        name: &'static str,
        // `distinct_count` statistic attached to column `a`.
        distinct_count: Precision<usize>,
        // `true`: aggregate over `a`; `false`: aggregate over `a + b`.
        use_column_expr: bool,
        // Whether the plan root should become a `ProjectionExec`.
        expect_optimized: bool,
        // Value the optimized plan must produce, when it is checked.
        expected_value: Option<i64>,
    }

    let cases = [
        TestCase {
            name: "exact statistics",
            distinct_count: Precision::Exact(42),
            use_column_expr: true,
            expect_optimized: true,
            expected_value: Some(42),
        },
        TestCase {
            name: "absent statistics",
            distinct_count: Precision::Absent,
            use_column_expr: true,
            expect_optimized: false,
            expected_value: None,
        },
        TestCase {
            name: "inexact statistics",
            distinct_count: Precision::Inexact(42),
            use_column_expr: true,
            expect_optimized: false,
            expected_value: None,
        },
        TestCase {
            name: "non-column expression with exact statistics",
            distinct_count: Precision::Exact(42),
            use_column_expr: false,
            expect_optimized: false,
            expected_value: None,
        },
    ];

    for TestCase {
        name,
        distinct_count,
        use_column_expr,
        expect_optimized,
        expected_value,
    } in cases
    {
        // Two nullable Int32 columns; only `a` carries non-default statistics.
        let file_schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]));

        let stats_for_a = ColumnStatistics {
            distinct_count,
            null_count: Precision::Exact(10),
            ..Default::default()
        };
        let statistics = Statistics {
            num_rows: Precision::Exact(100),
            total_byte_size: Precision::Absent,
            column_statistics: vec![stats_for_a, ColumnStatistics::default()],
        };

        let parquet_source = Arc::new(ParquetSource::new(Arc::clone(&file_schema)));
        let scan_config = FileScanConfigBuilder::new(
            ObjectStoreUrl::parse("test:///").unwrap(),
            parquet_source,
        )
        .with_file(PartitionedFile::new("x".to_string(), 100))
        .with_statistics(statistics)
        .build();

        let source: Arc<dyn ExecutionPlan> =
            DataSourceExec::from_data_source(scan_config);
        let plan_schema = source.schema();

        // Either the bare column `a` or the binary expression `a + b`.
        let col_a = expressions::col("a", &plan_schema)?;
        let (agg_arg, alias): (Arc<dyn datafusion_physical_expr::PhysicalExpr>, &str) =
            if use_column_expr {
                (col_a, "COUNT(DISTINCT a)")
            } else {
                let col_b = expressions::col("b", &plan_schema)?;
                let sum = expressions::binary(col_a, Operator::Plus, col_b, &plan_schema)?;
                (sum, "COUNT(DISTINCT a + b)")
            };

        let count_distinct = Arc::new(
            AggregateExprBuilder::new(count_udaf(), vec![agg_arg])
                .schema(Arc::clone(&plan_schema))
                .alias(alias)
                .distinct()
                .build()?,
        );

        // Partial aggregate directly over the scan ...
        let partial_agg = AggregateExec::try_new(
            AggregateMode::Partial,
            PhysicalGroupBy::default(),
            vec![Arc::clone(&count_distinct)],
            vec![None],
            source,
            Arc::clone(&plan_schema),
        )?;

        // ... wrapped by the final aggregate that the rule is handed.
        let final_agg = AggregateExec::try_new(
            AggregateMode::Final,
            PhysicalGroupBy::default(),
            vec![count_distinct],
            vec![None],
            Arc::new(partial_agg),
            Arc::clone(&plan_schema),
        )?;

        let options = ConfigOptions::new();
        let optimized =
            AggregateStatistics::new().optimize(Arc::new(final_agg), &options)?;

        if !expect_optimized {
            assert!(
                optimized.as_any().is::<AggregateExec>(),
                "'{}': expected AggregateExec (not optimized)",
                name
            );
            continue;
        }

        assert!(
            optimized.as_any().is::<ProjectionExec>(),
            "'{}': expected ProjectionExec",
            name
        );

        let Some(expected_val) = expected_value else {
            continue;
        };

        // Execute the rewritten plan and compare the single output value.
        let task_ctx = Arc::new(TaskContext::default());
        let batches = common::collect(optimized.execute(0, task_ctx)?).await?;
        assert_eq!(batches.len(), 1, "'{}': expected 1 batch", name);
        let counts = as_int64_array(batches[0].column(0)).unwrap();
        assert_eq!(
            counts.values(),
            &[expected_val],
            "'{}': unexpected value",
            name
        );
    }

    Ok(())
}
49 changes: 29 additions & 20 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,31 +365,40 @@ impl AggregateUDFImpl for Count {
}

fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
let [expr] = statistics_args.exprs else {
return None;
};
let col_stats = &statistics_args.statistics.column_statistics;

if statistics_args.is_distinct {
// Only column references can be resolved from statistics;
// expressions like casts or literals are not supported.
let col_expr = expr.as_any().downcast_ref::<expressions::Column>()?;
if let Precision::Exact(dc) = col_stats[col_expr.index()].distinct_count {
let dc = i64::try_from(dc).ok()?;
return Some(ScalarValue::Int64(Some(dc)));
}
return None;
}
if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows
&& statistics_args.exprs.len() == 1
{
// TODO optimize with exprs other than Column
if let Some(col_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Column>()
{
let current_val = &statistics_args.statistics.column_statistics
[col_expr.index()]
.null_count;
if let &Precision::Exact(val) = current_val {
return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
}
} else if let Some(lit_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Literal>()
&& lit_expr.value() == &COUNT_STAR_EXPANSION
{
return Some(ScalarValue::Int64(Some(num_rows as i64)));

let Precision::Exact(num_rows) = statistics_args.statistics.num_rows else {
return None;
};

// TODO optimize with exprs other than Column
if let Some(col_expr) = expr.as_any().downcast_ref::<expressions::Column>() {
if let Precision::Exact(val) = col_stats[col_expr.index()].null_count {
let count = i64::try_from(num_rows - val).ok()?;
return Some(ScalarValue::Int64(Some(count)));
}
} else if let Some(lit_expr) =
expr.as_any().downcast_ref::<expressions::Literal>()
&& lit_expr.value() == &COUNT_STAR_EXPANSION
{
let num_rows = i64::try_from(num_rows).ok()?;
return Some(ScalarValue::Int64(Some(num_rows)));
}

None
}

Expand Down