diff --git a/Cargo.toml b/Cargo.toml index 742e2b8a19f9..1cd76982dd4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,7 +99,6 @@ arrow-flight = { version = "55.2.0", features = [ ] } arrow-ipc = { version = "55.2.0", default-features = false, features = [ "lz4", - "zstd", ] } arrow-ord = { version = "55.2.0", default-features = false } arrow-schema = { version = "55.2.0", default-features = false } diff --git a/ci.test b/ci.test new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index f7316ddc1bec..e0fab7ee9f31 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -199,6 +199,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf { order_by: window_function.params.order_by, window_frame: window_function.params.window_frame, null_treatment: window_function.params.null_treatment, + distinct: window_function.params.distinct, }, })) }; diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 31159d4a8588..6758ed4799d5 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -389,7 +389,7 @@ config_namespace! { /// /// This is used to workaround bugs in the planner that are now caught by /// the new schema verification step. - pub skip_physical_aggregate_schema_check: bool, default = false + pub skip_physical_aggregate_schema_check: bool, default = true /// Sets the compression codec used when spilling data to disk. /// diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index c4455e271c84..1a6a66923e55 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -47,6 +47,7 @@ compression = [ "bzip2", "flate2", "zstd", + "arrow-ipc/zstd", "datafusion-datasource/compression", ] crypto_expressions = ["datafusion-functions/crypto_expressions"] diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ab123dcceada..6d2393c99d2b 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -712,6 +712,9 @@ impl DefaultPhysicalPlanner { differences.push(format!("field nullability at index {} [{}]: (physical) {} vs (logical) {}", i, physical_field.name(), physical_field.is_nullable(), logical_field.is_nullable())); } } + + log::warn!("Physical input schema should be the same as the one converted from logical input schema, but did not match for logical plan:\n{}", input.display_indent()); + return internal_err!("Physical input schema should be the same as the one converted from logical input schema. Differences: {}", differences .iter() .map(|s| format!("\n\t- {s}")) @@ -1646,6 +1649,7 @@ pub fn create_window_expr_with_name( order_by, window_frame, null_treatment, + distinct, }, } = window_fun.as_ref(); let physical_args = @@ -1674,6 +1678,7 @@ pub fn create_window_expr_with_name( window_frame, physical_schema, ignore_nulls, + *distinct, ) } other => plan_err!("Invalid window expression '{other:?}'"), diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 316d3ba5a926..23e3281cf386 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -288,6 +288,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { Arc::new(window_frame), &extended_schema, false, + false, )?; let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![window_expr], @@ -660,6 +661,7 @@ async fn run_window_test( Arc::new(window_frame.clone()), &extended_schema, false, + false, )?], exec1, false, @@ -678,6 +680,7 @@ async fn run_window_test( Arc::new(window_frame.clone()), &extended_schema, false, + false, )?], exec2, search_mode.clone(), diff --git a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs index fd847763124a..2dce87de00ed 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_distribution.rs @@ -23,7 +23,7 @@ use crate::physical_optimizer::test_utils::{ check_integrity, coalesce_partitions_exec, parquet_exec_with_sort, parquet_exec_with_stats, repartition_exec, schema, sort_exec, sort_exec_with_preserve_partitioning, sort_merge_join_exec, - sort_preserving_merge_exec, union_exec, + sort_preserving_merge_exec, trim_plan_display, union_exec, }; use arrow::array::{RecordBatch, UInt64Array, UInt8Array}; @@ -39,10 +39,12 @@ use datafusion::datasource::MemTable; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::ScalarValue; +use datafusion_common::{assert_contains, ScalarValue}; use datafusion_datasource::file_groups::FileGroup; use datafusion_datasource::file_scan_config::FileScanConfigBuilder; -use datafusion_expr::{JoinType, Operator}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::{AggregateUDF, JoinType, Operator}; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; use datafusion_physical_expr::expressions::{binary, lit, BinaryExpr, Column, Literal}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{ @@ -51,6 +53,7 @@ use datafusion_physical_expr_common::sort_expr::{ use datafusion_physical_optimizer::enforce_distribution::*; use datafusion_physical_optimizer::enforce_sorting::EnforceSorting; use datafusion_physical_optimizer::output_requirements::OutputRequirements; +use datafusion_physical_optimizer::sanity_checker::check_plan_sanity; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -66,7 +69,7 @@ use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::union::UnionExec; use datafusion_physical_plan::{ - get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, + displayable, get_plan_string, DisplayAs, DisplayFormatType, ExecutionPlanProperties, PlanProperties, Statistics, }; @@ -162,8 +165,8 @@ impl ExecutionPlan for SortRequiredExec { fn execute( &self, _partition: usize, - _context: Arc, - ) -> Result { + _context: Arc, + ) -> Result { unreachable!(); } @@ -237,7 +240,7 @@ fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc, alias_pairs: Vec<(String, String)>, ) -> Arc { @@ -251,6 +254,15 @@ fn projection_exec_with_alias( fn aggregate_exec_with_alias( input: Arc, alias_pairs: Vec<(String, String)>, +) -> Arc { + aggregate_exec_with_aggr_expr_and_alias(input, vec![], alias_pairs) +} + +#[expect(clippy::type_complexity)] +fn aggregate_exec_with_aggr_expr_and_alias( + input: Arc, + aggr_expr: Vec<(Arc, Vec>)>, + alias_pairs: Vec<(String, String)>, ) -> Arc { let schema = schema(); let mut group_by_expr: Vec<(Arc, String)> = vec![]; @@ -271,18 +283,31 @@ fn aggregate_exec_with_alias( .collect::>(); let final_grouping = PhysicalGroupBy::new_single(final_group_by_expr); + let aggr_expr = aggr_expr + .into_iter() + .map(|(udaf, exprs)| { + AggregateExprBuilder::new(udaf.clone(), exprs) + .alias(udaf.name()) + .schema(Arc::clone(&schema)) + .build() + .map(Arc::new) + .unwrap() + }) + .collect::>(); + let filter_exprs = std::iter::repeat_n(None, aggr_expr.len()).collect::>(); + Arc::new( AggregateExec::try_new( AggregateMode::FinalPartitioned, final_grouping, - vec![], - vec![], + aggr_expr.clone(), + filter_exprs.clone(), Arc::new( AggregateExec::try_new( AggregateMode::Partial, group_by, - vec![], - vec![], + aggr_expr, + filter_exprs, input, schema.clone(), ) @@ -439,6 +464,12 @@ impl TestConfig { self } + /// Set batch size. + fn with_batch_size(mut self, batch_size: usize) -> Self { + self.config.execution.batch_size = batch_size; + self + } + /// Perform a series of runs using the current [`TestConfig`], /// assert the expected plan result, /// and return the result plan (for potentional subsequent runs). @@ -2027,6 +2058,285 @@ fn repartition_ignores_union() -> Result<()> { Ok(()) } +fn aggregate_over_union(input: Vec>) -> Arc { + let union = union_exec(input); + let plan = + aggregate_exec_with_alias(union, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +// Aggregate over a union, +// with current testing setup. +// +// It will repartiton twice for an aggregate over a union. +// * repartitions before the partial aggregate. +// * repartitions before the final aggregation. +#[test] +fn repartitions_twice_for_aggregate_after_union() -> Result<()> { + let plan = aggregate_over_union(vec![parquet_exec(); 2]); + + // We get a distribution error without repartitioning. + let err = check_plan_sanity(plan.clone(), &Default::default()).unwrap_err(); + assert_contains!( + err.message(), + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet\"] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" + ); + + // Updated plan (post optimization) will have added RepartitionExecs (btwn union and aggregation). + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default(); + test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +// Aggregate over a union, +// but make the test setup more realistic. +// +// It will repartiton once for an aggregate over a union. +// * repartitions btwn partial & final aggregations. +#[test] +fn repartitions_once_for_aggregate_after_union() -> Result<()> { + // use parquet exec with stats + let plan: Arc = + aggregate_over_union(vec![parquet_exec_with_stats(10000); 2]); + + // We get a distribution error without repartitioning. + let err = check_plan_sanity(plan.clone(), &Default::default()).unwrap_err(); + assert_contains!( + err.message(), + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet\"] does not satisfy distribution requirements: HashPartitioned[[a1@0]]). Child-0 output partitioning: UnknownPartitioning(2)" + ); + + // This removes the forced round-robin repartitioning, + // by no longer hard-coding batch_size=1. + // + // Updated plan (post optimization) will have added only 1 RepartitionExec. + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run(expected, plan.clone(), &DISTRIB_DISTRIB_SORT)?; + test_config.run(expected, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +/// Same as [`aggregate_over_union`], but with a sort btwn the union and aggregation. +fn aggregate_over_sorted_union( + input: Vec>, +) -> Arc { + let union = union_exec(input); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let sort = sort_exec(sort_key, union); + let plan = aggregate_exec_with_alias(sort, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + // Notice the `ordering_mode=Sorted` on the aggregations. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +/// Same as [`repartitions_once_for_aggregate_after_union`], but adds a sort btwn +/// the union and the aggregate. This changes the outcome: +/// +/// * we no longer get a distribution error. +/// * but we still get repartitioning? +#[test] +fn repartitions_for_aggregate_after_sorted_union() -> Result<()> { + let plan = aggregate_over_sorted_union(vec![parquet_exec_with_stats(10000); 2]); + + // With the sort, there is no distribution error. + let checker = check_plan_sanity(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // It does not repartition on the first run + let expected_after_first_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " UnionExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run( + expected_after_first_run, + plan.clone(), + &DISTRIB_DISTRIB_SORT, + )?; + + // But does repartition on the second run. + let expected_after_second_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " UnionExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected_after_second_run, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + +/// Same as [`aggregate_over_sorted_union`], but with a sort btwn the union and aggregation. +fn aggregate_over_sorted_union_projection( + input: Vec>, +) -> Arc { + let union = union_exec(input); + let union_projection = projection_exec_with_alias( + union, + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "value".to_string()), + ], + ); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let sort = sort_exec(sort_key, union_projection); + let plan = aggregate_exec_with_alias(sort, vec![("a".to_string(), "a1".to_string())]); + + // Demonstrate starting plan. + // Notice the `ordering_mode=Sorted` on the aggregations. + let before = displayable(plan.as_ref()).indent(true).to_string(); + let before = trim_plan_display(&before); + assert_eq!( + before, + vec![ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + "ProjectionExec: expr=[a@0 as a, b@1 as value]", + "UnionExec", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + "DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + plan +} + +/// Same as [`repartitions_for_aggregate_after_sorted_union`], but adds a projection +/// as well between the union and aggregate. This change the outcome: +/// +/// * we no longer get repartitioning, and instead get coalescing. +#[test] +fn coalesces_for_aggregate_after_sorted_union_projection() -> Result<()> { + let plan = + aggregate_over_sorted_union_projection(vec![parquet_exec_with_stats(10000); 2]); + + // Same as `repartitions_for_aggregate_after_sorted_union`. No error. + let checker = check_plan_sanity(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // It no longer does a repartition on the first run. + // Instead adds a SPM. + let expected_after_first_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + let test_config = TestConfig::default().with_batch_size(100); + test_config.run( + expected_after_first_run, + plan.clone(), + &DISTRIB_DISTRIB_SORT, + )?; + + // Then it removes the SPM, and inserts a coalesace on the second run. + let expected_after_second_run = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[], ordering_mode=Sorted", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + " SortExec: expr=[a1@0 ASC NULLS LAST], preserve_partitioning=[true]", + " AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[], ordering_mode=Sorted", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " CoalescePartitionsExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ]; + test_config.run(expected_after_second_run, plan, &SORT_DISTRIB_DISTRIB)?; + + Ok(()) +} + #[test] fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index e31a30cc0883..ef29d51e5d37 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -17,15 +17,16 @@ use std::sync::Arc; +use crate::physical_optimizer::enforce_distribution::projection_exec_with_alias; use crate::physical_optimizer::test_utils::{ aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition, check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema, create_test_schema2, create_test_schema3, filter_exec, global_limit_exec, hash_join_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_with_sort, - projection_exec, repartition_exec, sort_exec, sort_exec_with_fetch, sort_expr, - sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, - sort_preserving_merge_exec_with_fetch, spr_repartition_exec, stream_exec_ordered, - union_exec, RequirementsTestExec, + parquet_exec_with_stats, projection_exec, repartition_exec, schema, sort_exec, + sort_exec_with_fetch, sort_expr, sort_expr_options, sort_merge_join_exec, + sort_preserving_merge_exec, sort_preserving_merge_exec_with_fetch, + spr_repartition_exec, stream_exec_ordered, union_exec, RequirementsTestExec, }; use arrow::compute::SortOptions; @@ -47,6 +48,9 @@ use datafusion_physical_expr_common::sort_expr::{ }; use datafusion_physical_expr::{Distribution, Partitioning}; use datafusion_physical_expr::expressions::{col, BinaryExpr, Column, NotExpr}; +use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; +use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -2292,6 +2296,93 @@ async fn test_commutativity() -> Result<()> { Ok(()) } +fn single_partition_aggregate( + input: Arc, + alias_pairs: Vec<(String, String)>, +) -> Arc { + let schema = schema(); + let group_by = alias_pairs + .iter() + .map(|(column, alias)| (col(column, &input.schema()).unwrap(), alias.to_string())) + .collect::>(); + let group_by = PhysicalGroupBy::new_single(group_by); + + Arc::new( + AggregateExec::try_new( + AggregateMode::SinglePartitioned, + group_by, + vec![], + vec![], + input, + schema, + ) + .unwrap(), + ) +} + +#[tokio::test] +async fn test_preserve_needed_coalesce() -> Result<()> { + // Input to EnforceSorting, from our test case. + let plan = projection_exec_with_alias( + union_exec(vec![parquet_exec_with_stats(10000); 2]), + vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "value".to_string()), + ], + ); + let plan = Arc::new(CoalescePartitionsExec::new(plan)); + let schema = schema(); + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema).unwrap(), + options: SortOptions::default(), + }]) + .unwrap(); + let plan: Arc = + single_partition_aggregate(plan, vec![("a".to_string(), "a1".to_string())]); + let plan = sort_exec(sort_key, plan); + + // Starting plan: as in our test case. + assert_eq!( + get_plan_string(&plan), + vec![ + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " AggregateExec: mode=SinglePartitioned, gby=[a@0 as a1], aggr=[]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + let checker = SanityCheckPlan::new().optimize(plan.clone(), &Default::default()); + assert!(checker.is_ok()); + + // EnforceSorting will remove the coalesce, and add an SPM further up (above the aggregate). + let optimizer = EnforceSorting::new(); + let optimized = optimizer.optimize(plan, &Default::default())?; + assert_eq!( + get_plan_string(&optimized), + vec![ + "SortPreservingMergeExec: [a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " AggregateExec: mode=SinglePartitioned, gby=[a@0 as a1], aggr=[]", + " CoalescePartitionsExec", + " ProjectionExec: expr=[a@0 as a, b@1 as value]", + " UnionExec", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + " DataSourceExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], file_type=parquet", + ], + ); + + // Plan is valid. + let checker = SanityCheckPlan::new(); + let checker = checker.optimize(optimized, &Default::default()); + assert!(checker.is_ok()); + + Ok(()) +} + #[tokio::test] async fn test_coalesce_propagate() -> Result<()> { let schema = create_test_schema()?; @@ -3675,6 +3766,7 @@ async fn test_window_partial_constant_and_set_monotonicity() -> Result<()> { case.window_frame, input_schema.as_ref(), false, + false, )?; let window_exec = if window_expr.uses_bounded_memory() { Arc::new(BoundedWindowAggExec::try_new( diff --git a/datafusion/core/tests/physical_optimizer/test_utils.rs b/datafusion/core/tests/physical_optimizer/test_utils.rs index 7fb0f795f294..5e2d61e68f8d 100644 --- a/datafusion/core/tests/physical_optimizer/test_utils.rs +++ b/datafusion/core/tests/physical_optimizer/test_utils.rs @@ -265,6 +265,7 @@ pub fn bounded_window_exec_with_partition( Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), false, + false, ) .unwrap(); @@ -509,6 +510,13 @@ pub fn check_integrity(context: PlanContext) -> Result Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() +} + // construct a stream partition for test purposes #[derive(Debug)] pub struct TestStreamPartition { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0749ff0e98b7..efe8a639087a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1131,6 +1131,8 @@ pub struct WindowFunctionParams { pub window_frame: WindowFrame, /// Specifies how NULL value is treated: ignore or respect pub null_treatment: Option, + /// Distinct flag + pub distinct: bool, } impl WindowFunction { @@ -1145,6 +1147,7 @@ impl WindowFunction { order_by: Vec::default(), window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, }, } } @@ -2291,6 +2294,7 @@ impl NormalizeEq for Expr { partition_by: self_partition_by, order_by: self_order_by, null_treatment: self_null_treatment, + distinct: self_distinct, }, } = left.as_ref(); let WindowFunction { @@ -2302,6 +2306,7 @@ impl NormalizeEq for Expr { partition_by: other_partition_by, order_by: other_order_by, null_treatment: other_null_treatment, + distinct: other_distinct, }, } = other.as_ref(); @@ -2325,6 +2330,7 @@ impl NormalizeEq for Expr { && a.nulls_first == b.nulls_first && a.expr.normalize_eq(&b.expr) }) + && self_distinct == other_distinct } ( Expr::Exists(Exists { @@ -2558,11 +2564,13 @@ impl HashNode for Expr { order_by: _, window_frame, null_treatment, + distinct, }, } = window_fun.as_ref(); fun.hash(state); window_frame.hash(state); null_treatment.hash(state); + distinct.hash(state); } Expr::InList(InList { expr: _expr, @@ -2865,15 +2873,27 @@ impl Display for SchemaDisplay<'_> { order_by, window_frame, null_treatment, + distinct, } = params; + // Write function name and open parenthesis + write!(f, "{fun}(")?; + + // If DISTINCT, emit the keyword + if *distinct { + write!(f, "DISTINCT ")?; + } + + // Write the comma‑separated argument list write!( f, - "{}({})", - fun, + "{}", schema_name_from_exprs_comma_separated_without_space(args)? )?; + // **Close the argument parenthesis** + write!(f, ")")?; + if let Some(null_treatment) = null_treatment { write!(f, " {null_treatment}")?; } @@ -3260,9 +3280,10 @@ impl Display for Expr { order_by, window_frame, null_treatment, + distinct, } = params; - fmt_function(f, &fun.to_string(), false, args, true)?; + fmt_function(f, &fun.to_string(), *distinct, args, true)?; if let Some(nt) = null_treatment { write!(f, "{nt}")?; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c0351a9dcaca..fab86fe7663d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -945,6 +945,7 @@ impl ExprFuncBuilder { window_frame: window_frame .unwrap_or_else(|| WindowFrame::new(has_order_by)), null_treatment, + distinct, }, }) } diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs index 067c7a94279f..b04fe32d376e 100644 --- a/datafusion/expr/src/planner.rs +++ b/datafusion/expr/src/planner.rs @@ -308,6 +308,7 @@ pub struct RawWindowExpr { pub order_by: Vec, pub window_frame: WindowFrame, pub null_treatment: Option, + pub distinct: bool, } /// Result of planning a raw expr with [`ExprPlanner`] diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index f953aec5a1e3..b6f583ca4c74 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -242,10 +242,22 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; (args, partition_by, order_by).map_elements(f)?.update_data( |(new_args, new_partition_by, new_order_by)| { + if distinct { + return Expr::from(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build() + .unwrap(); + } + Expr::from(WindowFunction::new(fun, new_args)) .partition_by(new_partition_by) .order_by(new_order_by) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b6c8eb627c77..15c0dd57ad2c 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -554,14 +554,25 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { order_by, window_frame, null_treatment, + distinct, } = params; let mut schema_name = String::new(); - schema_name.write_fmt(format_args!( - "{}({})", - self.name(), - schema_name_from_exprs(args)? - ))?; + + // Inject DISTINCT into the schema name when requested + if *distinct { + schema_name.write_fmt(format_args!( + "{}(DISTINCT {})", + self.name(), + schema_name_from_exprs(args)? + ))?; + } else { + schema_name.write_fmt(format_args!( + "{}({})", + self.name(), + schema_name_from_exprs(args)? + ))?; + } if let Some(null_treatment) = null_treatment { schema_name.write_fmt(format_args!(" {null_treatment}"))?; @@ -579,7 +590,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { " ORDER BY [{}]", schema_name_from_sorts(order_by)? ))?; - }; + } schema_name.write_fmt(format_args!(" {window_frame}"))?; @@ -648,15 +659,24 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { order_by, window_frame, null_treatment, + distinct, } = params; let mut display_name = String::new(); - display_name.write_fmt(format_args!( - "{}({})", - self.name(), - expr_vec_fmt!(args) - ))?; + if *distinct { + display_name.write_fmt(format_args!( + "{}(DISTINCT {})", + self.name(), + expr_vec_fmt!(args) + ))?; + } else { + display_name.write_fmt(format_args!( + "{}({})", + self.name(), + expr_vec_fmt!(args) + ))?; + } if let Some(null_treatment) = null_treatment { display_name.write_fmt(format_args!(" {null_treatment}"))?; diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 55c8c847ad0a..fce300e79bea 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -77,15 +77,38 @@ pub fn approx_percentile_cont( #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression)", + syntax_example = "approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; +-----------------------------------------------------------------------+ | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | +-----------------------------------------------------------------------+ | 65.0 | +-----------------------------------------------------------------------+ -```"#, +``` +An alternate syntax is also supported: +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` +"#, standard_argument(name = "expression",), argument( name = "percentile", @@ -313,7 +336,7 @@ impl AggregateUDFImpl for ApproxPercentileCont { } if arg_types.len() == 3 && !arg_types[2].is_integer() { return plan_err!( - "approx_percentile_cont requires integer max_size input types" + "approx_percentile_cont requires integer centroids input types" ); } Ok(arg_types[0].clone()) @@ -360,6 +383,11 @@ impl ApproxPercentileAccumulator { } } + // public for approx_percentile_cont_with_weight + pub(crate) fn max_size(&self) -> usize { + self.digest.max_size() + } + // public for approx_percentile_cont_with_weight pub fn merge_digests(&mut self, digests: &[TDigest]) { let digests = digests.iter().chain(std::iter::once(&self.digest)); diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index ab847e838869..f70d751a8cb9 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -25,32 +25,53 @@ use arrow::datatypes::FieldRef; use arrow::{array::ArrayRef, datatypes::DataType}; use datafusion_common::ScalarValue; use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::expr::{AggregateFunction, Sort}; use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; use datafusion_expr::Volatility::Immutable; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, -}; -use datafusion_functions_aggregate_common::tdigest::{ - Centroid, TDigest, DEFAULT_MAX_SIZE, + Accumulator, AggregateUDFImpl, Documentation, Expr, Signature, TypeSignature, }; +use datafusion_functions_aggregate_common::tdigest::{Centroid, TDigest}; use datafusion_macros::user_doc; use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; -make_udaf_expr_and_func!( +create_func!( ApproxPercentileContWithWeight, - approx_percentile_cont_with_weight, - expression weight percentile, - "Computes the approximate percentile continuous with weight of a set of numbers", approx_percentile_cont_with_weight_udaf ); +/// Computes the approximate percentile continuous with weight of a set of numbers +pub fn approx_percentile_cont_with_weight( + order_by: Sort, + weight: Expr, + percentile: Expr, + centroids: Option, +) -> Expr { + let expr = order_by.expr.clone(); + + let args = if let Some(centroids) = centroids { + vec![expr, weight, percentile, centroids] + } else { + vec![expr, weight, percentile] + }; + + Expr::AggregateFunction(AggregateFunction::new_udf( + approx_percentile_cont_with_weight_udaf(), + args, + false, + None, + vec![order_by], + None, + )) +} + /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression #[user_doc( doc_section(label = "Approximate Functions"), description = "Returns the weighted approximate percentile of input values using the t-digest algorithm.", - syntax_example = "approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression)", + syntax_example = "approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression)", sql_example = r#"```sql > SELECT approx_percentile_cont_with_weight(weight_column, 0.90) WITHIN GROUP (ORDER BY column_name) FROM table_name; +---------------------------------------------------------------------------------------------+ @@ -58,6 +79,22 @@ make_udaf_expr_and_func!( +---------------------------------------------------------------------------------------------+ | 78.5 | +---------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ +``` +An alternative syntax is also supported: + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++--------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ ```"#, standard_argument(name = "expression", prefix = "The"), argument( @@ -67,6 +104,10 @@ make_udaf_expr_and_func!( argument( name = "percentile", description = "Percentile to compute. Must be a float value between 0 and 1 (inclusive)." + ), + argument( + name = "centroids", + description = "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory." ) )] pub struct ApproxPercentileContWithWeight { @@ -91,21 +132,26 @@ impl Default for ApproxPercentileContWithWeight { impl ApproxPercentileContWithWeight { /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); + // Accept any numeric value paired with weight and float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + ])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + num.clone(), + DataType::Float64, + int.clone(), + ])); + } + } Self { - signature: Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Immutable, - ), + signature: Signature::one_of(variants, Immutable), approx_percentile_cont: ApproxPercentileCont::new(), } } @@ -138,6 +184,11 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { if arg_types[2] != DataType::Float64 { return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); } + if arg_types.len() == 4 && !arg_types[3].is_integer() { + return plan_err!( + "approx_percentile_cont_with_weight requires integer centroids input types" + ); + } Ok(arg_types[0].clone()) } @@ -148,17 +199,25 @@ impl AggregateUDFImpl for ApproxPercentileContWithWeight { ); } - if acc_args.exprs.len() != 3 { + if acc_args.exprs.len() != 3 && acc_args.exprs.len() != 4 { return plan_err!( - "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + "approx_percentile_cont_with_weight requires three or four arguments: value, weight, percentile[, centroids]" ); } let sub_args = AccumulatorArgs { - exprs: &[ - Arc::clone(&acc_args.exprs[0]), - Arc::clone(&acc_args.exprs[2]), - ], + exprs: if acc_args.exprs.len() == 4 { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + Arc::clone(&acc_args.exprs[3]), // centroids + ] + } else { + &[ + Arc::clone(&acc_args.exprs[0]), // value + Arc::clone(&acc_args.exprs[2]), // percentile + ] + }, ..acc_args }; let approx_percentile_cont_accumulator = @@ -244,7 +303,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator { let mut digests: Vec = vec![]; for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) { digests.push(TDigest::new_with_centroid( - DEFAULT_MAX_SIZE, + self.approx_percentile_cont_accumulator.max_size(), Centroid::new(*mean, *weight), )) } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 09904bbad6ec..7a7c2879aa79 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -31,7 +31,7 @@ use arrow::{ }; use datafusion_common::{ downcast_value, internal_err, not_impl_err, stats::Precision, - utils::expr::COUNT_STAR_EXPANSION, Result, ScalarValue, + utils::expr::COUNT_STAR_EXPANSION, HashMap, Result, ScalarValue, }; use datafusion_expr::{ expr::WindowFunction, @@ -59,6 +59,7 @@ use std::{ ops::BitAnd, sync::Arc, }; + make_udaf_expr_and_func!( Count, count, @@ -406,6 +407,98 @@ impl AggregateUDFImpl for Count { // the same as new values are seen. SetMonotonicity::Increasing } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + if args.is_distinct { + let acc = + SlidingDistinctCountAccumulator::try_new(args.return_field.data_type())?; + Ok(Box::new(acc)) + } else { + let acc = CountAccumulator::new(); + Ok(Box::new(acc)) + } + } +} + +// DistinctCountAccumulator does not support retract_batch and sliding window +// this is a specialized accumulator for distinct count that supports retract_batch +// and sliding window. +#[derive(Debug)] +pub struct SlidingDistinctCountAccumulator { + counts: HashMap, + data_type: DataType, +} + +impl SlidingDistinctCountAccumulator { + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + counts: HashMap::default(), + data_type: data_type.clone(), + }) + } +} + +impl Accumulator for SlidingDistinctCountAccumulator { + fn state(&mut self) -> Result> { + let keys = self.counts.keys().cloned().collect::>(); + Ok(vec![ScalarValue::List(ScalarValue::new_list_nullable( + keys.as_slice(), + &self.data_type, + ))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = &values[0]; + for i in 0..arr.len() { + let v = ScalarValue::try_from_array(arr, i)?; + if !v.is_null() { + if let Some(cnt) = self.counts.get_mut(&v) { + *cnt -= 1; + if *cnt == 0 { + self.counts.remove(&v); + } + } + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let list_arr = states[0].as_list::(); + for inner in list_arr.iter().flatten() { + for j in 0..inner.len() { + let v = ScalarValue::try_from_array(&*inner, j)?; + *self.counts.entry(v).or_default() += 1; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.counts.len() as i64))) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + size_of_val(self) + } } #[derive(Debug)] @@ -878,4 +971,72 @@ mod tests { assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); Ok(()) } + + #[test] + fn sliding_distinct_count_accumulator_basic() -> Result<()> { + // Basic update_batch + evaluate functionality + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // Create an Int32Array: [1, 2, 2, 3, null] + let values: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + Some(2), + Some(3), + None, + ])); + acc.update_batch(&[values])?; + // Expect distinct values {1,2,3} → count = 3 + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_retract() -> Result<()> { + // Test that retract_batch properly decrements counts + let mut acc = SlidingDistinctCountAccumulator::try_new(&DataType::Utf8)?; + // Initial batch: ["a", "b", "a"] + let arr1 = Arc::new(StringArray::from(vec![Some("a"), Some("b"), Some("a")])) + as ArrayRef; + acc.update_batch(&[arr1])?; + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(2))); // {"a","b"} + + // Retract batch: ["a", null, "b"] + let arr2 = + Arc::new(StringArray::from(vec![Some("a"), None, Some("b")])) as ArrayRef; + acc.retract_batch(&[arr2])?; + // Before: a→2, b→1; after retract a→1, b→0 → b removed; remaining {"a"} + assert_eq!(acc.evaluate()?, ScalarValue::Int64(Some(1))); + Ok(()) + } + + #[test] + fn sliding_distinct_count_accumulator_merge_states() -> Result<()> { + // Test merging multiple accumulator states with merge_batch + let mut acc1 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + let mut acc2 = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + // acc1 sees [1, 2] + acc1.update_batch(&[Arc::new(Int32Array::from(vec![Some(1), Some(2)]))])?; + // acc2 sees [2, 3] + acc2.update_batch(&[Arc::new(Int32Array::from(vec![Some(2), Some(3)]))])?; + // Extract their states as Vec + let state_sv1 = acc1.state()?; + let state_sv2 = acc2.state()?; + // Convert ScalarValue states into Vec, propagating errors + // NOTE we pass `1` because each ScalarValue.to_array produces a 1‑row ListArray + let state_arr1: Vec = state_sv1 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + let state_arr2: Vec = state_sv2 + .into_iter() + .map(|sv| sv.to_array()) + .collect::>()?; + // Merge both states into a fresh accumulator + let mut merged = SlidingDistinctCountAccumulator::try_new(&DataType::Int32)?; + merged.merge_batch(&state_arr1)?; + merged.merge_batch(&state_arr2)?; + // Expect distinct {1,2,3} → count = 3 + assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3))); + Ok(()) + } } diff --git a/datafusion/functions-window/src/planner.rs b/datafusion/functions-window/src/planner.rs index 091737bb9c15..5e3a6bc6336c 100644 --- a/datafusion/functions-window/src/planner.rs +++ b/datafusion/functions-window/src/planner.rs @@ -41,6 +41,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, } = raw_expr; let origin_expr = Expr::from(WindowFunction { @@ -51,6 +52,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }, }); @@ -68,6 +70,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; let raw_expr = RawWindowExpr { @@ -77,6 +80,7 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, }; // TODO: remove the next line after `Expr::Wildcard` is removed @@ -93,18 +97,23 @@ impl ExprPlanner for WindowFunctionPlanner { order_by, window_frame, null_treatment, + distinct, } = raw_expr; - let new_expr = Expr::from(WindowFunction::new( + let mut new_expr_before_build = Expr::from(WindowFunction::new( func_def, vec![Expr::Literal(COUNT_STAR_EXPANSION, None)], )) .partition_by(partition_by) .order_by(order_by) .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?; + .null_treatment(null_treatment); + if distinct { + new_expr_before_build = new_expr_before_build.distinct(); + } + + let new_expr = new_expr_before_build.build()?; let new_expr = saved_name.restore(new_expr); return Ok(PlannerResult::Planned(new_expr)); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index a98b0fdcc3d3..e6fc006cb2ff 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -549,6 +549,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { order_by, window_frame, null_treatment, + distinct, }, } = *window_fun; let window_frame = @@ -565,14 +566,26 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { _ => args, }; - Ok(Transformed::yes( - Expr::from(WindowFunction::new(fun, args)) - .partition_by(partition_by) - .order_by(order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build()?, - )) + if distinct { + Ok(Transformed::yes( + Expr::from(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build()?, + )) + } else { + Ok(Transformed::yes( + Expr::from(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, + )) + } } // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 8a71b28486a2..dae0edcfb171 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -48,8 +48,8 @@ use crate::enforce_sorting::sort_pushdown::{ }; use crate::output_requirements::OutputRequirementExec; use crate::utils::{ - add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, - is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, + add_sort_above, add_sort_above_with_check, is_aggregation, is_coalesce_partitions, + is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, }; use crate::PhysicalOptimizerRule; @@ -678,7 +678,7 @@ fn remove_bottleneck_in_subplan( ) -> Result { let plan = &requirements.plan; let children = &mut requirements.children; - if is_coalesce_partitions(&children[0].plan) { + if is_coalesce_partitions(&children[0].plan) && !is_aggregation(plan) { // We can safely use the 0th index since we have a `CoalescePartitionsExec`. let mut new_child_node = children[0].children.swap_remove(0); while new_child_node.plan.output_partitioning() == plan.output_partitioning() diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 3655e555a744..d3207d4880a7 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use datafusion_common::Result; use datafusion_physical_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -113,3 +114,8 @@ pub fn is_repartition(plan: &Arc) -> bool { pub fn is_limit(plan: &Arc) -> bool { plan.as_any().is::() || plan.as_any().is::() } + +/// Checks whether the given operator is a [`AggregateExec`]. +pub fn is_aggregation(plan: &Arc) -> bool { + plan.as_any().is::() +} diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index d3335c0e7fe1..4c991544f877 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1377,6 +1377,7 @@ mod tests { Arc::new(window_frame), &input.schema(), false, + false, )?], input, input_order_mode, diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 5583abfd72a2..085b17cab9bc 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -103,21 +103,38 @@ pub fn create_window_expr( window_frame: Arc, input_schema: &Schema, ignore_nulls: bool, + distinct: bool, ) -> Result> { Ok(match fun { WindowFunctionDefinition::AggregateUDF(fun) => { - let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) - .schema(Arc::new(input_schema.clone())) - .alias(name) - .with_ignore_nulls(ignore_nulls) - .build() - .map(Arc::new)?; - window_expr_from_aggregate_expr( - partition_by, - order_by, - window_frame, - aggregate, - ) + if distinct { + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .distinct() + .build() + .map(Arc::new)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } else { + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .build() + .map(Arc::new)?; + window_expr_from_aggregate_expr( + partition_by, + order_by, + window_frame, + aggregate, + ) + } } WindowFunctionDefinition::WindowUDF(fun) => Arc::new(StandardWindowExpr::new( create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, @@ -805,6 +822,7 @@ mod tests { Arc::new(WindowFrame::new(None)), schema.as_ref(), false, + false, )?], blocking_exec, false, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 43afaa0fbe65..f59e97df0d46 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -316,6 +316,7 @@ pub fn serialize_expr( ref window_frame, // TODO: support null treatment in proto null_treatment: _, + distinct: _, }, } = window_fun.as_ref(); let mut buf = Vec::new(); diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 1c60470b2218..2ed6ec037fc8 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -179,6 +179,7 @@ pub fn parse_physical_window_expr( Arc::new(window_frame), &extended_schema, false, + false, ) } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6c51d553fe16..b56fdc0fede6 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -981,7 +981,18 @@ async fn roundtrip_expr_api() -> Result<()> { approx_median(lit(2)), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), None), approx_percentile_cont(lit(2).sort(true, false), lit(0.5), Some(lit(50))), - approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + None, + ), + approx_percentile_cont_with_weight( + lit(2).sort(true, false), + lit(1), + lit(0.5), + Some(lit(50)), + ), grouping(lit(1)), bit_and(lit(2)), bit_or(lit(2)), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e63ca75d019d..fd0e7dc6e3b9 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -352,6 +352,7 @@ impl SqlToRel<'_, S> { order_by, window_frame, null_treatment, + distinct: function_args.distinct, }; for planner in self.context_provider.get_expr_planners().iter() { @@ -368,8 +369,19 @@ impl SqlToRel<'_, S> { order_by, window_frame, null_treatment, + distinct, } = window_expr; + if distinct { + return Expr::from(expr::WindowFunction::new(func_def, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .distinct() + .build(); + } + return Expr::from(expr::WindowFunction::new(func_def, args)) .partition_by(partition_by) .order_by(order_by) @@ -380,10 +392,6 @@ impl SqlToRel<'_, S> { } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { - if fm.is_ordered_set_aggregate() && within_group.is_empty() { - return plan_err!("WITHIN GROUP clause is required when calling ordered set aggregate function({})", fm.name()); - } - if null_treatment.is_some() && !fm.supports_null_handling_clause() { return plan_err!( "[IGNORE | RESPECT] NULLS are not permitted for {}", @@ -403,7 +411,8 @@ impl SqlToRel<'_, S> { None, )?; - // add target column expression in within group clause to function arguments + // Add the WITHIN GROUP ordering expressions to the front of the argument list + // So function(arg) WITHIN GROUP (ORDER BY x) becomes function(x, arg) if !within_group.is_empty() { args = within_group .iter() diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 4ddd5ccccbbd..4c0dc316615c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -18,8 +18,9 @@ use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, Expr as AstExpr, Function, Ident, Interval, - ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, ValueWithSpan, + self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, + Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, + ValueWithSpan, }; use std::sync::Arc; use std::vec; @@ -198,6 +199,7 @@ impl Unparser<'_> { partition_by, order_by, window_frame, + distinct, .. }, } = window_fun.as_ref(); @@ -256,7 +258,8 @@ impl Unparser<'_> { span: Span::empty(), }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, + duplicate_treatment: distinct + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), @@ -339,7 +342,7 @@ impl Unparser<'_> { }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { duplicate_treatment: distinct - .then_some(ast::DuplicateTreatment::Distinct), + .then_some(DuplicateTreatment::Distinct), args, clauses: vec![], }), @@ -2051,6 +2054,7 @@ mod tests { order_by: vec![], window_frame: WindowFrame::new(None), null_treatment: None, + distinct: false, }, }), r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, @@ -2076,6 +2080,7 @@ mod tests { ), ), null_treatment: None, + distinct: false, }, }), r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 753820b6b619..4671408349e2 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -1287,7 +1287,7 @@ SELECT approx_distinct(c9) AS a, approx_distinct(c9) AS b FROM aggregate_test_10 ## Column `c12` is omitted due to a large relative error (~10%) due to the small ## float values. -#csv_query_approx_percentile_cont (c2) +# csv_query_approx_percentile_cont (c2) query B SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c2) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 ---- @@ -1303,6 +1303,23 @@ SELECT (ABS(1 - CAST(approx_percentile_cont(0.9) WITHIN GROUP (ORDER BY c2) AS D ---- true + +# csv_query_approx_percentile_cont (c2, alternate syntax, should be the same as above) +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.1) AS DOUBLE) / 1.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.5) AS DOUBLE) / 3.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + +query B +SELECT (ABS(1 - CAST(approx_percentile_cont(c2, 0.9) AS DOUBLE) / 5.0) < 0.05) AS q FROM aggregate_test_100 +---- +true + # csv_query_approx_percentile_cont (c3) query B SELECT (ABS(1 - CAST(approx_percentile_cont(0.1) WITHIN GROUP (ORDER BY c3) AS DOUBLE) / -95.3) < 0.05) AS q FROM aggregate_test_100 @@ -1743,6 +1760,17 @@ c 122 d 124 e 115 + +# csv_query_approx_percentile_cont_with_weight (should be the same as above) +query TI +SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + query TI SELECT c1, approx_percentile_cont(0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- @@ -1762,6 +1790,17 @@ c 122 d 124 e 115 +# csv_query_approx_percentile_cont_with_weight alternate syntax +query TI +SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 73 +b 68 +c 122 +d 124 +e 115 + + query TI SELECT c1, approx_percentile_cont_with_weight(1, 0.95) WITHIN GROUP (ORDER BY c3 DESC) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 ---- @@ -1790,6 +1829,16 @@ c 123 d 124 e 115 +# approx_percentile_cont_with_weight with centroids +query TI +SELECT c1, approx_percentile_cont_with_weight(c2, 0.95, 200) WITHIN GROUP (ORDER BY c3) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 74 +b 68 +c 123 +d 124 +e 115 + # csv_query_sum_crossjoin query TTI SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1 diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index f76e436e0ad3..0f1fb892c746 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -259,7 +259,7 @@ datafusion.execution.parquet.writer_version 1.0 datafusion.execution.planning_concurrency 13 datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 -datafusion.execution.skip_physical_aggregate_schema_check false +datafusion.execution.skip_physical_aggregate_schema_check true datafusion.execution.soft_max_rows_per_output_file 50000000 datafusion.execution.sort_in_place_threshold_bytes 1048576 datafusion.execution.sort_spill_reservation_bytes 10485760 @@ -371,7 +371,7 @@ datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer ve datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode -datafusion.execution.skip_physical_aggregate_schema_check false When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. +datafusion.execution.skip_physical_aggregate_schema_check true When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 82de11302857..bed9121eec3f 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -5650,3 +5650,82 @@ WINDOW 3 7 4 11 5 16 + + +# window with distinct operation +statement ok +CREATE TABLE table_test_distinct_count ( + k VARCHAR, + v Int, + time TIMESTAMP WITH TIME ZONE +); + +statement ok +INSERT INTO table_test_distinct_count (k, v, time) VALUES + ('a', 1, '1970-01-01T00:01:00.00Z'), + ('a', 1, '1970-01-01T00:02:00.00Z'), + ('a', 1, '1970-01-01T00:03:00.00Z'), + ('a', 2, '1970-01-01T00:03:00.00Z'), + ('a', 1, '1970-01-01T00:04:00.00Z'), + ('b', 3, '1970-01-01T00:01:00.00Z'), + ('b', 3, '1970-01-01T00:02:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'), + ('b', 4, '1970-01-01T00:03:00.00Z'); + +query TPII +SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ORDER BY k, time; +---- +a 1970-01-01T00:01:00Z 1 1 +a 1970-01-01T00:02:00Z 2 1 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:03:00Z 4 2 +a 1970-01-01T00:04:00Z 4 2 +b 1970-01-01T00:01:00Z 1 1 +b 1970-01-01T00:02:00Z 2 1 +b 1970-01-01T00:03:00Z 4 2 +b 1970-01-01T00:03:00Z 4 2 + + +query TT +EXPLAIN SELECT + k, + time, + COUNT(v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS normal_count, + COUNT(DISTINCT v) OVER ( + PARTITION BY k + ORDER BY time + RANGE BETWEEN INTERVAL '2 minutes' PRECEDING AND CURRENT ROW + ) AS distinct_count +FROM table_test_distinct_count +ODER BY k, time; +---- +logical_plan +01)Projection: oder.k, oder.time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW AS distinct_count +02)--WindowAggr: windowExpr=[[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW AS count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW]] +03)----SubqueryAlias: oder +04)------TableScan: table_test_distinct_count projection=[k, v, time] +physical_plan +01)ProjectionExec: expr=[k@0 as k, time@2 as time, count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@3 as normal_count, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW@4 as distinct_count] +02)--BoundedWindowAggExec: wdw=[count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW, count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW: Field { name: "count(DISTINCT oder.v) PARTITION BY [oder.k] ORDER BY [oder.time ASC NULLS LAST] RANGE BETWEEN 2 minutes PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, frame: RANGE BETWEEN IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 } PRECEDING AND CURRENT ROW], mode=[Sorted] +03)----SortExec: expr=[k@0 ASC NULLS LAST, time@2 ASC NULLS LAST], preserve_partitioning=[true] +04)------CoalesceBatchesExec: target_batch_size=1 +05)--------RepartitionExec: partitioning=Hash([k@0], 2), input_partitions=2 +06)----------DataSourceExec: partitions=2, partition_sizes=[5, 4] diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs index 80b643a547ee..27f0de84b7a0 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/window_function.rs @@ -112,6 +112,7 @@ pub async fn from_window_function( order_by, window_frame, null_treatment: None, + distinct: false, }, })) } diff --git a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs index 17e71f2d7c14..94a39e930f1c 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/window_function.rs @@ -42,6 +42,7 @@ pub fn from_window_function( order_by, window_frame, null_treatment: _, + distinct: _, }, } = window_fn; // function reference diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index dc950de01f1b..7ea22bc0ba30 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -82,7 +82,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | | datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | -| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | +| datafusion.execution.skip_physical_aggregate_schema_check | true | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | | datafusion.execution.spill_compression | uncompressed | Sets the compression codec used when spilling data to disk. Since datafusion writes spill files using the Arrow IPC Stream format, only codecs supported by the Arrow IPC Stream Writer are allowed. Valid values are: uncompressed, lz4_frame, zstd. Note: lz4_frame offers faster (de)compression, but typically results in larger spill files. In contrast, zstd achieves higher compression ratios at the cost of slower (de)compression speed. | | datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | | datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 03ab86eeb813..abf0286fa85b 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -285,27 +285,27 @@ select log(-1), log(0), sqrt(-1); ## Aggregate Functions -| Syntax | Description | -| ----------------------------------------------------------------- | --------------------------------------------------------------------------------------- | -| avg(expr) | Сalculates the average value for `expr`. | -| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | -| approx_median(expr) | Calculates an approximation of the median for `expr`. | -| approx_percentile_cont(expr, percentile) | Calculates an approximation of the specified `percentile` for `expr`. | -| approx_percentile_cont_with_weight(expr, weight_expr, percentile) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. | -| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | -| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | -| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | -| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | -| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | -| count(expr) | Returns the number of rows for `expr`. | -| count_distinct | Creates an expression to represent the count(distinct) aggregate function | -| cube(exprs) | Creates a grouping set for all combination of `exprs` | -| grouping_set(exprs) | Create a grouping set. | -| max(expr) | Finds the maximum value of `expr`. | -| median(expr) | Сalculates the median of `expr`. | -| min(expr) | Finds the minimum value of `expr`. | -| rollup(exprs) | Creates a grouping set for rollup sets. | -| sum(expr) | Сalculates the sum of `expr`. | +| Syntax | Description | +| ------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| avg(expr) | Сalculates the average value for `expr`. | +| approx_distinct(expr) | Calculates an approximate count of the number of distinct values for `expr`. | +| approx_median(expr) | Calculates an approximation of the median for `expr`. | +| approx_percentile_cont(expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| approx_percentile_cont_with_weight(expr, weight_expr, percentile [, centroids]) | Calculates an approximation of the specified `percentile` for `expr` and `weight_expr`. Optional `centroids` parameter controls accuracy (default: 100). | +| bit_and(expr) | Computes the bitwise AND of all non-null input values for `expr`. | +| bit_or(expr) | Computes the bitwise OR of all non-null input values for `expr`. | +| bit_xor(expr) | Computes the bitwise exclusive OR of all non-null input values for `expr`. | +| bool_and(expr) | Returns true if all non-null input values (`expr`) are true, otherwise false. | +| bool_or(expr) | Returns true if any non-null input value (`expr`) is true, otherwise false. | +| count(expr) | Returns the number of rows for `expr`. | +| count_distinct | Creates an expression to represent the count(distinct) aggregate function | +| cube(exprs) | Creates a grouping set for all combination of `exprs` | +| grouping_set(exprs) | Create a grouping set. | +| max(expr) | Finds the maximum value of `expr`. | +| median(expr) | Сalculates the median of `expr`. | +| min(expr) | Finds the minimum value of `expr`. | +| rollup(exprs) | Creates a grouping set for rollup sets. | +| sum(expr) | Сalculates the sum of `expr`. | ## Aggregate Function Builder diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 774a4fae6bf3..4f2f0abe55c9 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -834,7 +834,7 @@ approx_median(expression) Returns the approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) +approx_percentile_cont(percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -846,6 +846,12 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) #### Example ```sql +> SELECT approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++------------------------------------------------------------------+ +| approx_percentile_cont(0.75) WITHIN GROUP (ORDER BY column_name) | ++------------------------------------------------------------------+ +| 65.0 | ++------------------------------------------------------------------+ > SELECT approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; +-----------------------------------------------------------------------+ | approx_percentile_cont(0.75, 100) WITHIN GROUP (ORDER BY column_name) | @@ -854,12 +860,30 @@ approx_percentile_cont(percentile, centroids) WITHIN GROUP (ORDER BY expression) +-----------------------------------------------------------------------+ ``` +An alternate syntax is also supported: + +```sql +> SELECT approx_percentile_cont(column_name, 0.75) FROM table_name; ++-----------------------------------------------+ +| approx_percentile_cont(column_name, 0.75) | ++-----------------------------------------------+ +| 65.0 | ++-----------------------------------------------+ + +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++----------------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++----------------------------------------------------------+ +| 65.0 | ++----------------------------------------------------------+ +``` + ### `approx_percentile_cont_with_weight` Returns the weighted approximate percentile of input values using the t-digest algorithm. ```sql -approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY expression) +approx_percentile_cont_with_weight(weight, percentile [, centroids]) WITHIN GROUP (ORDER BY expression) ``` #### Arguments @@ -867,6 +891,7 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex - **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators. - **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). +- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory. #### Example @@ -877,4 +902,21 @@ approx_percentile_cont_with_weight(weight, percentile) WITHIN GROUP (ORDER BY ex +---------------------------------------------------------------------------------------------+ | 78.5 | +---------------------------------------------------------------------------------------------+ +> SELECT approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) FROM table_name; ++--------------------------------------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(weight_column, 0.90, 100) WITHIN GROUP (ORDER BY column_name) | ++--------------------------------------------------------------------------------------------------+ +| 78.5 | ++--------------------------------------------------------------------------------------------------+ +``` + +An alternative syntax is also supported: + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++--------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++--------------------------------------------------+ +| 78.5 | ++--------------------------------------------------+ ```