Skip to content

Commit 58a72b8

Browse files
committed
merge main@public
2 parents e969500 + c82b49b commit 58a72b8

File tree

3 files changed

+157
-20
lines changed

3 files changed

+157
-20
lines changed

datafusion/core/tests/dataframe/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5274,7 +5274,7 @@ async fn test_dataframe_placeholder_column_parameter() -> Result<()> {
52745274
assert_snapshot!(
52755275
actual,
52765276
@r"
5277-
Projection: Int32(3) AS $1 [$1:Null;N]
5277+
Projection: Int32(3) AS $1 [$1:Int32]
52785278
EmptyRelation: rows=1 []
52795279
"
52805280
);

datafusion/core/tests/sql/select.rs

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
use std::collections::HashMap;
1919

2020
use super::*;
21-
use datafusion::assert_batches_eq;
2221
use datafusion_common::{metadata::ScalarAndMetadata, ParamValues, ScalarValue};
2322
use insta::assert_snapshot;
2423

@@ -343,26 +342,53 @@ async fn test_query_parameters_with_metadata() -> Result<()> {
343342
]))
344343
.unwrap();
345344

346-
// df_with_params_replaced.schema() is not correct here
347-
// https://github.com/apache/datafusion/issues/18102
348-
let batches = df_with_params_replaced.clone().collect().await.unwrap();
349-
let schema = batches[0].schema();
350-
345+
let schema = df_with_params_replaced.schema();
351346
assert_eq!(schema.field(0).data_type(), &DataType::UInt32);
352347
assert_eq!(schema.field(0).metadata(), &metadata1);
353348
assert_eq!(schema.field(1).data_type(), &DataType::Utf8);
354349
assert_eq!(schema.field(1).metadata(), &metadata2);
355350

356-
assert_batches_eq!(
357-
[
358-
"+----+-----+",
359-
"| $1 | $2 |",
360-
"+----+-----+",
361-
"| 1 | two |",
362-
"+----+-----+",
363-
],
364-
&batches
365-
);
351+
let batches = df_with_params_replaced.collect().await.unwrap();
352+
assert_snapshot!(batches_to_sort_string(&batches), @r"
353+
+----+-----+
354+
| $1 | $2 |
355+
+----+-----+
356+
| 1 | two |
357+
+----+-----+
358+
");
359+
360+
Ok(())
361+
}
362+
363+
/// Test for https://github.com/apache/datafusion/issues/18102
364+
#[tokio::test]
365+
async fn test_query_parameters_in_values_list_relation() -> Result<()> {
366+
let ctx = SessionContext::new();
367+
368+
let df = ctx
369+
.sql("SELECT a, b FROM (VALUES ($1, $2)) AS t(a, b)")
370+
.await
371+
.unwrap();
372+
373+
let df_with_params_replaced = df
374+
.with_param_values(ParamValues::List(vec![
375+
ScalarAndMetadata::new(ScalarValue::UInt32(Some(1)), None),
376+
ScalarAndMetadata::new(ScalarValue::Utf8(Some("two".to_string())), None),
377+
]))
378+
.unwrap();
379+
380+
let schema = df_with_params_replaced.schema();
381+
assert_eq!(schema.field(0).data_type(), &DataType::UInt32);
382+
assert_eq!(schema.field(1).data_type(), &DataType::Utf8);
383+
384+
let batches = df_with_params_replaced.collect().await.unwrap();
385+
assert_snapshot!(batches_to_sort_string(&batches), @r"
386+
+---+-----+
387+
| a | b |
388+
+---+-----+
389+
| 1 | two |
390+
+---+-----+
391+
");
366392

367393
Ok(())
368394
}

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,36 @@ impl LogicalPlan {
633633
LogicalPlan::Dml(_) => Ok(self),
634634
LogicalPlan::Copy(_) => Ok(self),
635635
LogicalPlan::Values(Values { schema, values }) => {
636-
// todo it isn't clear why the schema is not recomputed here
637-
Ok(LogicalPlan::Values(Values { schema, values }))
636+
// Using `values` alone cannot compute correct schema for the plan. For example:
637+
// Projection: col_1, col_2
638+
// Values: (Float32(1), Float32(10)), (Float32(100), Float32(10))
639+
//
640+
// Thus, we need to recompute a new schema from `values` and retain some
641+
// information from the original schema.
642+
let new_plan = LogicalPlanBuilder::values(values.clone())?.build()?;
643+
644+
let qualified_fields = schema
645+
.iter()
646+
.zip(new_plan.schema().fields())
647+
.map(|((table_ref, old_field), new_field)| {
648+
let field = old_field
649+
.as_ref()
650+
.clone()
651+
.with_data_type(new_field.data_type().clone())
652+
.with_nullable(new_field.is_nullable());
653+
(table_ref.cloned(), Arc::new(field))
654+
})
655+
.collect::<Vec<_>>();
656+
657+
let schema = DFSchema::new_with_metadata(
658+
qualified_fields,
659+
schema.metadata().clone(),
660+
)?
661+
.with_functional_dependencies(schema.functional_dependencies().clone())?;
662+
Ok(LogicalPlan::Values(Values {
663+
schema: Arc::new(schema),
664+
values,
665+
}))
638666
}
639667
LogicalPlan::Filter(Filter { predicate, input }) => {
640668
Filter::try_new(predicate, input).map(LogicalPlan::Filter)
@@ -1471,7 +1499,10 @@ impl LogicalPlan {
14711499
// Preserve name to avoid breaking column references to this expression
14721500
Ok(transformed_expr.update_data(|expr| original_name.restore(expr)))
14731501
}
1474-
})
1502+
})?
1503+
// always recompute the schema to ensure the changed in the schema's field should be
1504+
// poplulated to the plan's parent
1505+
.map_data(|plan| plan.recompute_schema())
14751506
})
14761507
.map(|res| res.data)
14771508
}
@@ -4247,6 +4278,7 @@ mod tests {
42474278
use super::*;
42484279
use crate::builder::LogicalTableSource;
42494280
use crate::logical_plan::table_scan;
4281+
use crate::select_expr::SelectExpr;
42504282
use crate::test::function_stub::{count, count_udaf};
42514283
use crate::{
42524284
binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery,
@@ -4825,6 +4857,85 @@ mod tests {
48254857
.expect_err("prepared field metadata mismatch unexpectedly succeeded");
48264858
}
48274859

4860+
#[test]
4861+
fn test_replace_placeholder_values_relation_valid_schema() {
4862+
// SELECT a, b, c, d FROM (VALUES (1), ($1), ($2), ($3 + $4)) AS t(a, b, c, d);
4863+
let plan = LogicalPlanBuilder::values(vec![vec![
4864+
lit(1),
4865+
placeholder("$1"),
4866+
placeholder("$2"),
4867+
binary_expr(placeholder("$3"), Operator::Plus, placeholder("$4")),
4868+
]])
4869+
.unwrap()
4870+
.project(vec![
4871+
col("column1").alias("a"),
4872+
col("column2").alias("b"),
4873+
col("column3").alias("c"),
4874+
col("column4").alias("d"),
4875+
])
4876+
.unwrap()
4877+
.alias("t")
4878+
.unwrap()
4879+
.project(vec![col("a"), col("b"), col("c"), col("d")])
4880+
.unwrap()
4881+
.build()
4882+
.unwrap();
4883+
4884+
// original
4885+
assert_snapshot!(plan.display_indent_schema(), @r"
4886+
Projection: t.a, t.b, t.c, t.d [a:Int32;N, b:Null;N, c:Null;N, d:Int64;N]
4887+
SubqueryAlias: t [a:Int32;N, b:Null;N, c:Null;N, d:Int64;N]
4888+
Projection: column1 AS a, column2 AS b, column3 AS c, column4 AS d [a:Int32;N, b:Null;N, c:Null;N, d:Int64;N]
4889+
Values: (Int32(1), $1, $2, $3 + $4) [column1:Int32;N, column2:Null;N, column3:Null;N, column4:Int64;N]
4890+
");
4891+
4892+
let plan = plan
4893+
.with_param_values(vec![
4894+
ScalarValue::from(1i32),
4895+
ScalarValue::from("s"),
4896+
ScalarValue::from(3),
4897+
ScalarValue::from(4),
4898+
])
4899+
.unwrap();
4900+
4901+
// replaced
4902+
assert_snapshot!(plan.display_indent_schema(), @r#"
4903+
Projection: t.a, t.b, t.c, t.d [a:Int32;N, b:Int32;N, c:Utf8;N, d:Int32;N]
4904+
SubqueryAlias: t [a:Int32;N, b:Int32;N, c:Utf8;N, d:Int32;N]
4905+
Projection: column1 AS a, column2 AS b, column3 AS c, column4 AS d [a:Int32;N, b:Int32;N, c:Utf8;N, d:Int32;N]
4906+
Values: (Int32(1), Int32(1) AS $1, Utf8("s") AS $2, Int32(3) + Int32(4) AS $3 + $4) [column1:Int32;N, column2:Int32;N, column3:Utf8;N, column4:Int32;N]
4907+
"#);
4908+
}
4909+
4910+
#[test]
4911+
fn test_replace_placeholder_empty_relation_valid_schema() {
4912+
// SELECT $1, $2;
4913+
let plan = LogicalPlanBuilder::empty(false)
4914+
.project(vec![
4915+
SelectExpr::from(placeholder("$1")),
4916+
SelectExpr::from(placeholder("$2")),
4917+
])
4918+
.unwrap()
4919+
.build()
4920+
.unwrap();
4921+
4922+
// original
4923+
assert_snapshot!(plan.display_indent_schema(), @r"
4924+
Projection: $1, $2 [$1:Null;N, $2:Null;N]
4925+
EmptyRelation: rows=0 []
4926+
");
4927+
4928+
let plan = plan
4929+
.with_param_values(vec![ScalarValue::from(1i32), ScalarValue::from("s")])
4930+
.unwrap();
4931+
4932+
// replaced
4933+
assert_snapshot!(plan.display_indent_schema(), @r#"
4934+
Projection: Int32(1) AS $1, Utf8("s") AS $2 [$1:Int32, $2:Utf8]
4935+
EmptyRelation: rows=0 []
4936+
"#);
4937+
}
4938+
48284939
#[test]
48294940
fn test_nullable_schema_after_grouping_set() {
48304941
let schema = Schema::new(vec![

0 commit comments

Comments
 (0)