Skip to content

Commit 2f91005

Browse files
committed
fix error on Count(Expr::Wildcard) with DataFrame API
1 parent 146a949 commit 2f91005

File tree

2 files changed

+122
-3
lines changed

2 files changed

+122
-3
lines changed

datafusion/core/tests/dataframe.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use arrow::datatypes::{DataType, Field, Schema};
19+
use arrow::util::pretty::pretty_format_batches;
1920
use arrow::{
2021
array::{
2122
ArrayRef, Int32Array, Int32Builder, ListBuilder, StringArray, StringBuilder,
@@ -35,6 +36,59 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
3536
use datafusion_expr::expr::{GroupingSet, Sort};
3637
use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
3738

39+
#[tokio::test]
40+
async fn count_wildcard() -> Result<()> {
41+
let ctx = SessionContext::new();
42+
let testdata = datafusion::test_util::parquet_test_data();
43+
44+
ctx.register_parquet(
45+
"alltypes_tiny_pages",
46+
&format!("{testdata}/alltypes_tiny_pages.parquet"),
47+
ParquetReadOptions::default(),
48+
)
49+
.await?;
50+
51+
let sql_results = ctx
52+
.sql("select count(*) from alltypes_tiny_pages")
53+
.await?
54+
.explain(false, false)?
55+
.collect()
56+
.await?;
57+
58+
let df_results = ctx
59+
.table("alltypes_tiny_pages")
60+
.await?
61+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
62+
.explain(false, false)
63+
.unwrap()
64+
.collect()
65+
.await?;
66+
67+
//make sure sql plan same with df plan
68+
assert_eq!(
69+
pretty_format_batches(&sql_results)?.to_string(),
70+
pretty_format_batches(&df_results)?.to_string()
71+
);
72+
73+
let results = ctx
74+
.table("alltypes_tiny_pages")
75+
.await?
76+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
77+
.collect()
78+
.await?;
79+
80+
let expected = vec![
81+
"+-----------------+",
82+
"| COUNT(UInt8(1)) |",
83+
"+-----------------+",
84+
"| 7300 |",
85+
"+-----------------+",
86+
];
87+
assert_batches_sorted_eq!(expected, &results);
88+
89+
Ok(())
90+
}
91+
3892
#[tokio::test]
3993
async fn describe() -> Result<()> {
4094
let ctx = SessionContext::new();

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,18 @@
1717

1818
//! This module provides a builder for creating LogicalPlans
1919
20+
use crate::expr::AggregateFunction;
2021
use crate::expr_rewriter::{
2122
coerce_plan_expr_for_schema, normalize_col,
2223
normalize_col_with_schemas_and_ambiguity_check, normalize_cols,
2324
rewrite_sort_cols_by_aggs,
2425
};
2526
use crate::type_coercion::binary::comparison_coercion;
26-
use crate::utils::{columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan};
27-
use crate::{and, binary_expr, Operator};
27+
use crate::utils::{
28+
columnize_expr, compare_sort_expr, exprlist_to_fields, from_plan,
29+
COUNT_STAR_EXPANSION,
30+
};
31+
use crate::{aggregate_function, and, binary_expr, lit, Operator};
2832
use crate::{
2933
logical_plan::{
3034
Aggregate, Analyze, CrossJoin, Distinct, EmptyRelation, Explain, Filter, Join,
@@ -47,6 +51,7 @@ use std::any::Any;
4751
use std::cmp::Ordering;
4852
use std::collections::{HashMap, HashSet};
4953
use std::convert::TryFrom;
54+
use std::env::args;
5055
use std::sync::Arc;
5156

5257
/// Default table name for unnamed table
@@ -778,6 +783,9 @@ impl LogicalPlanBuilder {
778783
window_expr: impl IntoIterator<Item = impl Into<Expr>>,
779784
) -> Result<Self> {
780785
let window_expr = normalize_cols(window_expr, &self.plan)?;
786+
// Handle count(Expr::Wildcard) coming from the DataFrame API
787+
let window_expr = handle_wildcard(window_expr)?;
788+
781789
let all_expr = window_expr.iter();
782790
validate_unique_names("Windows", all_expr.clone())?;
783791
let mut window_fields: Vec<DFField> = self.plan.schema().fields().clone();
@@ -801,6 +809,10 @@ impl LogicalPlanBuilder {
801809
) -> Result<Self> {
802810
let group_expr = normalize_cols(group_expr, &self.plan)?;
803811
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
812+
813+
// Handle count(Expr::Wildcard) coming from the DataFrame API
814+
let aggr_expr = handle_wildcard(aggr_expr)?;
815+
804816
Ok(Self::from(LogicalPlan::Aggregate(Aggregate::try_new(
805817
Arc::new(self.plan),
806818
group_expr,
@@ -986,6 +998,30 @@ impl LogicalPlanBuilder {
986998
}
987999
}
9881000

1001+
//handle Count(Expr:Wildcard) with DataFrame API
1002+
pub fn handle_wildcard(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
1003+
let exprs: Vec<Expr> = exprs
1004+
.iter()
1005+
.map(|expr| match expr {
1006+
Expr::AggregateFunction(AggregateFunction {
1007+
fun: aggregate_function::AggregateFunction::Count,
1008+
args,
1009+
distinct,
1010+
filter,
1011+
}) if args.len() == 1 => match args[0] {
1012+
Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
1013+
fun: aggregate_function::AggregateFunction::Count,
1014+
args: vec![lit(COUNT_STAR_EXPANSION)],
1015+
distinct: *distinct,
1016+
filter: filter.clone(),
1017+
}),
1018+
_ => expr.clone(),
1019+
},
1020+
_ => expr.clone(),
1021+
})
1022+
.collect();
1023+
Ok(exprs)
1024+
}
9891025
/// Creates a schema for a join operation.
9901026
/// The fields from the left side are first
9911027
pub fn build_join_schema(
@@ -1315,7 +1351,7 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result<LogicalPlan> {
13151351

13161352
#[cfg(test)]
13171353
mod tests {
1318-
use crate::{expr, expr_fn::exists};
1354+
use crate::{count, expr, expr_fn::exists};
13191355
use arrow::datatypes::{DataType, Field};
13201356
use datafusion_common::SchemaError;
13211357

@@ -1324,6 +1360,35 @@ mod tests {
13241360
use super::*;
13251361
use crate::{col, in_subquery, lit, scalar_subquery, sum};
13261362

1363+
// `count(Expr::Wildcard)` passed to `window` must show up in the plan as
// `COUNT(UInt8(1))`.
#[test]
fn window_wildcard() -> Result<()> {
    let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![]))?
        .window(vec![count(Expr::Wildcard)])?
        .build()?;

    // Child plan nodes are indented by two spaces per level in the debug output.
    let expected = "WindowAggr: windowExpr=[[COUNT(UInt8(1))]]\
    \n  TableScan: employee_csv projection=[]";

    assert_eq!(expected, format!("{plan:?}"));

    Ok(())
}
1376+
1377+
// `count(Expr::Wildcard)` passed to `aggregate` with no grouping must show up
// in the plan as `COUNT(UInt8(1))`.
#[test]
fn count_wildcard() -> Result<()> {
    let group_expr: Vec<Expr> = Vec::new();
    let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![]))?
        .aggregate(group_expr, vec![count(Expr::Wildcard)])?
        .build()?;

    // Child plan nodes are indented by two spaces per level in the debug output.
    let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
    \n  TableScan: employee_csv projection=[]";

    assert_eq!(expected, format!("{plan:?}"));

    Ok(())
}
1391+
13271392
#[test]
13281393
fn plan_builder_simple() -> Result<()> {
13291394
let plan =

0 commit comments

Comments
 (0)