17
17
18
18
//! This module provides a builder for creating LogicalPlans
19
19
20
+ use crate :: expr:: AggregateFunction ;
20
21
use crate :: expr_rewriter:: {
21
22
coerce_plan_expr_for_schema, normalize_col,
22
23
normalize_col_with_schemas_and_ambiguity_check, normalize_cols,
23
24
rewrite_sort_cols_by_aggs,
24
25
} ;
25
26
use crate :: type_coercion:: binary:: comparison_coercion;
26
- use crate :: utils:: { columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan} ;
27
- use crate :: { and, binary_expr, Operator } ;
27
+ use crate :: utils:: {
28
+ columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan,
29
+ COUNT_STAR_EXPANSION ,
30
+ } ;
31
+ use crate :: { aggregate_function, and, binary_expr, lit, Operator } ;
28
32
use crate :: {
29
33
logical_plan:: {
30
34
Aggregate , Analyze , CrossJoin , Distinct , EmptyRelation , Explain , Filter , Join ,
@@ -47,6 +51,7 @@ use std::any::Any;
47
51
use std:: cmp:: Ordering ;
48
52
use std:: collections:: { HashMap , HashSet } ;
49
53
use std:: convert:: TryFrom ;
54
+ use std:: env:: args;
50
55
use std:: sync:: Arc ;
51
56
52
57
/// Default table name for unnamed table
@@ -778,6 +783,9 @@ impl LogicalPlanBuilder {
778
783
window_expr : impl IntoIterator < Item = impl Into < Expr > > ,
779
784
) -> Result < Self > {
780
785
let window_expr = normalize_cols ( window_expr, & self . plan ) ?;
786
+ // Rewrite COUNT(Expr::Wildcard) expressions built through the DataFrame API
787
+ let window_expr = handle_wildcard ( window_expr) ?;
788
+
781
789
let all_expr = window_expr. iter ( ) ;
782
790
validate_unique_names ( "Windows" , all_expr. clone ( ) ) ?;
783
791
let mut window_fields: Vec < DFField > = self . plan . schema ( ) . fields ( ) . clone ( ) ;
@@ -801,6 +809,10 @@ impl LogicalPlanBuilder {
801
809
) -> Result < Self > {
802
810
let group_expr = normalize_cols ( group_expr, & self . plan ) ?;
803
811
let aggr_expr = normalize_cols ( aggr_expr, & self . plan ) ?;
812
+
813
+ // Rewrite COUNT(Expr::Wildcard) expressions built through the DataFrame API
814
+ let aggr_expr = handle_wildcard ( aggr_expr) ?;
815
+
804
816
Ok ( Self :: from ( LogicalPlan :: Aggregate ( Aggregate :: try_new (
805
817
Arc :: new ( self . plan ) ,
806
818
group_expr,
@@ -986,6 +998,30 @@ impl LogicalPlanBuilder {
986
998
}
987
999
}
988
1000
1001
+ //handle Count(Expr:Wildcard) with DataFrame API
1002
+ pub fn handle_wildcard ( exprs : Vec < Expr > ) -> Result < Vec < Expr > > {
1003
+ let exprs: Vec < Expr > = exprs
1004
+ . iter ( )
1005
+ . map ( |expr| match expr {
1006
+ Expr :: AggregateFunction ( AggregateFunction {
1007
+ fun : aggregate_function:: AggregateFunction :: Count ,
1008
+ args,
1009
+ distinct,
1010
+ filter,
1011
+ } ) if args. len ( ) == 1 => match args[ 0 ] {
1012
+ Expr :: Wildcard => Expr :: AggregateFunction ( AggregateFunction {
1013
+ fun : aggregate_function:: AggregateFunction :: Count ,
1014
+ args : vec ! [ lit( COUNT_STAR_EXPANSION ) ] ,
1015
+ distinct : * distinct,
1016
+ filter : filter. clone ( ) ,
1017
+ } ) ,
1018
+ _ => expr. clone ( ) ,
1019
+ } ,
1020
+ _ => expr. clone ( ) ,
1021
+ } )
1022
+ . collect ( ) ;
1023
+ Ok ( exprs)
1024
+ }
989
1025
/// Creates a schema for a join operation.
990
1026
/// The fields from the left side are first
991
1027
pub fn build_join_schema (
@@ -1315,7 +1351,7 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result<LogicalPlan> {
1315
1351
1316
1352
#[ cfg( test) ]
1317
1353
mod tests {
1318
- use crate :: { expr, expr_fn:: exists} ;
1354
+ use crate :: { count , expr, expr_fn:: exists} ;
1319
1355
use arrow:: datatypes:: { DataType , Field } ;
1320
1356
use datafusion_common:: SchemaError ;
1321
1357
@@ -1324,6 +1360,35 @@ mod tests {
1324
1360
use super :: * ;
1325
1361
use crate :: { col, in_subquery, lit, scalar_subquery, sum} ;
1326
1362
1363
/// A `count(Expr::Wildcard)` window expression must show up in the built plan
/// as `COUNT(UInt8(1))`.
#[test]
fn window_wildcard() -> Result<()> {
    let builder = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![]))?;
    let window_exprs = vec![count(Expr::Wildcard)];
    let plan = builder.window(window_exprs)?.build()?;

    let actual = format!("{plan:?}");
    let expected = "WindowAggr: windowExpr=[[COUNT(UInt8(1))]]\
    \n TableScan: employee_csv projection=[]";

    assert_eq!(expected, actual);

    Ok(())
}
1376
+
1377
/// A `count(Expr::Wildcard)` aggregate with no grouping must show up in the
/// built plan as `COUNT(UInt8(1))`.
#[test]
fn count_wildcard() -> Result<()> {
    let no_groups: Vec<Expr> = Vec::new();
    let builder = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![]))?;
    let aggr_exprs = vec![count(Expr::Wildcard)];
    let plan = builder.aggregate(no_groups, aggr_exprs)?.build()?;

    let actual = format!("{plan:?}");
    let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
    \n TableScan: employee_csv projection=[]";

    assert_eq!(expected, actual);

    Ok(())
}
1391
+
1327
1392
#[ test]
1328
1393
fn plan_builder_simple ( ) -> Result < ( ) > {
1329
1394
let plan =
0 commit comments