Skip to content

Commit 26e1b20

Browse files
mingmwangalamb
andauthored
Add OuterReferenceColumn to Expr to represent correlated expression (#5593)
* Add OuterReferenceColumn * fix intg test * resolve review comments, avoid unncessary clone to the outer_query_schema * add more UT to cover more complex subqueryies, fix order by correlated columns in subquery * fix cargo fmt * fix: logical merge conflict --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent f4c0edb commit 26e1b20

30 files changed

+678
-179
lines changed

datafusion/core/src/datasource/listing/helpers.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
8181
}
8282
Expr::Literal(_)
8383
| Expr::Alias(_, _)
84+
| Expr::OuterReferenceColumn(_, _)
8485
| Expr::ScalarVariable(_, _)
8586
| Expr::Not(_)
8687
| Expr::IsNotNull(_)

datafusion/core/src/physical_plan/planner.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
346346
Expr::Placeholder { .. } => Err(DataFusionError::Internal(
347347
"Create physical name does not support placeholder".to_string(),
348348
)),
349+
Expr::OuterReferenceColumn(_, _) => Err(DataFusionError::Internal(
350+
"Create physical name does not support OuterReferenceColumn".to_string(),
351+
)),
349352
}
350353
}
351354

datafusion/core/tests/sql/mod.rs

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,82 @@ fn create_join_context(
231231
Ok(ctx)
232232
}
233233

234+
fn create_sub_query_join_context(
235+
column_outer: &str,
236+
column_inner_left: &str,
237+
column_inner_right: &str,
238+
repartition_joins: bool,
239+
) -> Result<SessionContext> {
240+
let ctx = SessionContext::with_config(
241+
SessionConfig::new()
242+
.with_repartition_joins(repartition_joins)
243+
.with_target_partitions(2)
244+
.with_batch_size(4096),
245+
);
246+
247+
let t0_schema = Arc::new(Schema::new(vec![
248+
Field::new(column_outer, DataType::UInt32, true),
249+
Field::new("t0_name", DataType::Utf8, true),
250+
Field::new("t0_int", DataType::UInt32, true),
251+
]));
252+
let t0_data = RecordBatch::try_new(
253+
t0_schema,
254+
vec![
255+
Arc::new(UInt32Array::from_slice([11, 22, 33, 44])),
256+
Arc::new(StringArray::from(vec![
257+
Some("a"),
258+
Some("b"),
259+
Some("c"),
260+
Some("d"),
261+
])),
262+
Arc::new(UInt32Array::from_slice([1, 2, 3, 4])),
263+
],
264+
)?;
265+
ctx.register_batch("t0", t0_data)?;
266+
267+
let t1_schema = Arc::new(Schema::new(vec![
268+
Field::new(column_inner_left, DataType::UInt32, true),
269+
Field::new("t1_name", DataType::Utf8, true),
270+
Field::new("t1_int", DataType::UInt32, true),
271+
]));
272+
let t1_data = RecordBatch::try_new(
273+
t1_schema,
274+
vec![
275+
Arc::new(UInt32Array::from_slice([11, 22, 33, 44])),
276+
Arc::new(StringArray::from(vec![
277+
Some("a"),
278+
Some("b"),
279+
Some("c"),
280+
Some("d"),
281+
])),
282+
Arc::new(UInt32Array::from_slice([1, 2, 3, 4])),
283+
],
284+
)?;
285+
ctx.register_batch("t1", t1_data)?;
286+
287+
let t2_schema = Arc::new(Schema::new(vec![
288+
Field::new(column_inner_right, DataType::UInt32, true),
289+
Field::new("t2_name", DataType::Utf8, true),
290+
Field::new("t2_int", DataType::UInt32, true),
291+
]));
292+
let t2_data = RecordBatch::try_new(
293+
t2_schema,
294+
vec![
295+
Arc::new(UInt32Array::from_slice([11, 22, 44, 55])),
296+
Arc::new(StringArray::from(vec![
297+
Some("z"),
298+
Some("y"),
299+
Some("x"),
300+
Some("w"),
301+
])),
302+
Arc::new(UInt32Array::from_slice([3, 1, 3, 3])),
303+
],
304+
)?;
305+
ctx.register_batch("t2", t2_data)?;
306+
307+
Ok(ctx)
308+
}
309+
234310
fn create_left_semi_anti_join_context_with_null_ids(
235311
column_left: &str,
236312
column_right: &str,

datafusion/core/tests/sql/subqueries.rs

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,228 @@ async fn subquery_not_allowed() -> Result<()> {
213213

214214
Ok(())
215215
}
216+
217+
#[tokio::test]
218+
async fn support_agg_correlated_columns() -> Result<()> {
219+
let ctx = create_join_context("t1_id", "t2_id", true)?;
220+
221+
let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT sum(t1.t1_int + t2.t2_id) FROM t2 WHERE t1.t1_name = t2.t2_name)";
222+
let msg = format!("Creating logical plan for '{sql}'");
223+
let dataframe = ctx.sql(sql).await.expect(&msg);
224+
let plan = dataframe.into_optimized_plan()?;
225+
226+
let expected = vec![
227+
"Filter: EXISTS (<subquery>) [t1_id:UInt32;N, t1_name:Utf8;N]",
228+
" Subquery: [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]",
229+
" Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]",
230+
" Aggregate: groupBy=[[]], aggr=[[SUM(outer_ref(t1.t1_int) + t2.t2_id)]] [SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]",
231+
" Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
232+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
233+
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
234+
];
235+
let formatted = plan.display_indent_schema().to_string();
236+
let actual: Vec<&str> = formatted.trim().lines().collect();
237+
assert_eq!(
238+
expected, actual,
239+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
240+
);
241+
242+
Ok(())
243+
}
244+
245+
#[tokio::test]
246+
async fn support_agg_correlated_columns2() -> Result<()> {
247+
let ctx = create_join_context("t1_id", "t2_id", true)?;
248+
249+
let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT count(*) FROM t2 WHERE t1.t1_name = t2.t2_name having sum(t1_int + t2_id) >0)";
250+
let msg = format!("Creating logical plan for '{sql}'");
251+
let dataframe = ctx.sql(sql).await.expect(&msg);
252+
let plan = dataframe.into_optimized_plan()?;
253+
254+
let expected = vec![
255+
"Filter: EXISTS (<subquery>) [t1_id:UInt32;N, t1_name:Utf8;N]",
256+
" Subquery: [COUNT(UInt8(1)):Int64;N]",
257+
" Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]",
258+
" Filter: CAST(SUM(outer_ref(t1.t1_int) + t2.t2_id) AS Int64) > Int64(0) [COUNT(UInt8(1)):Int64;N, SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]",
259+
" Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)), SUM(outer_ref(t1.t1_int) + t2.t2_id)]] [COUNT(UInt8(1)):Int64;N, SUM(outer_ref(t1.t1_int) + t2.t2_id):UInt64;N]",
260+
" Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
261+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
262+
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
263+
];
264+
let formatted = plan.display_indent_schema().to_string();
265+
let actual: Vec<&str> = formatted.trim().lines().collect();
266+
assert_eq!(
267+
expected, actual,
268+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
269+
);
270+
271+
Ok(())
272+
}
273+
274+
#[tokio::test]
275+
async fn support_join_correlated_columns() -> Result<()> {
276+
let ctx = create_sub_query_join_context("t0_id", "t1_id", "t2_id", true)?;
277+
let sql = "SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN t2 ON(t1.t1_id = t2.t2_id and t1.t1_name = t0.t0_name))";
278+
let msg = format!("Creating logical plan for '{sql}'");
279+
let dataframe = ctx.sql(sql).await.expect(&msg);
280+
let plan = dataframe.into_optimized_plan()?;
281+
282+
let expected = vec![
283+
"Filter: EXISTS (<subquery>) [t0_id:UInt32;N, t0_name:Utf8;N]",
284+
" Subquery: [Int64(1):Int64]",
285+
" Projection: Int64(1) [Int64(1):Int64]",
286+
" Inner Join: Filter: t1.t1_id = t2.t2_id AND t1.t1_name = outer_ref(t0.t0_name) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
287+
" TableScan: t1 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
288+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
289+
" TableScan: t0 projection=[t0_id, t0_name] [t0_id:UInt32;N, t0_name:Utf8;N]",
290+
];
291+
let formatted = plan.display_indent_schema().to_string();
292+
let actual: Vec<&str> = formatted.trim().lines().collect();
293+
assert_eq!(
294+
expected, actual,
295+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
296+
);
297+
298+
Ok(())
299+
}
300+
301+
#[tokio::test]
302+
async fn support_join_correlated_columns2() -> Result<()> {
303+
let ctx = create_sub_query_join_context("t0_id", "t1_id", "t2_id", true)?;
304+
let sql = "SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN (select * from t2 where t2.t2_name = t0.t0_name) as t2 ON(t1.t1_id = t2.t2_id ))";
305+
let msg = format!("Creating logical plan for '{sql}'");
306+
let dataframe = ctx.sql(sql).await.expect(&msg);
307+
let plan = dataframe.into_optimized_plan()?;
308+
309+
let expected = vec![
310+
"Filter: EXISTS (<subquery>) [t0_id:UInt32;N, t0_name:Utf8;N]",
311+
" Subquery: [Int64(1):Int64]",
312+
" Projection: Int64(1) [Int64(1):Int64]",
313+
" Inner Join: Filter: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
314+
" TableScan: t1 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
315+
" SubqueryAlias: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
316+
" Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
317+
" Filter: t2.t2_name = outer_ref(t0.t0_name) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
318+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
319+
" TableScan: t0 projection=[t0_id, t0_name] [t0_id:UInt32;N, t0_name:Utf8;N]",
320+
];
321+
let formatted = plan.display_indent_schema().to_string();
322+
let actual: Vec<&str> = formatted.trim().lines().collect();
323+
assert_eq!(
324+
expected, actual,
325+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
326+
);
327+
328+
Ok(())
329+
}
330+
331+
#[tokio::test]
332+
async fn support_order_by_correlated_columns() -> Result<()> {
333+
let ctx = create_join_context("t1_id", "t2_id", true)?;
334+
335+
let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id >= t1_id order by t1_id)";
336+
let msg = format!("Creating logical plan for '{sql}'");
337+
let dataframe = ctx.sql(sql).await.expect(&msg);
338+
let plan = dataframe.into_optimized_plan()?;
339+
340+
let expected = vec![
341+
"Filter: EXISTS (<subquery>) [t1_id:UInt32;N, t1_name:Utf8;N]",
342+
" Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
343+
" Sort: outer_ref(t1.t1_id) ASC NULLS LAST [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
344+
" Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
345+
" Filter: t2.t2_id >= outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
346+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
347+
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
348+
];
349+
let formatted = plan.display_indent_schema().to_string();
350+
let actual: Vec<&str> = formatted.trim().lines().collect();
351+
assert_eq!(
352+
expected, actual,
353+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
354+
);
355+
356+
Ok(())
357+
}
358+
359+
#[tokio::test]
360+
async fn support_limit_subquery() -> Result<()> {
361+
let ctx = create_join_context("t1_id", "t2_id", true)?;
362+
363+
let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 1)";
364+
let msg = format!("Creating logical plan for '{sql}'");
365+
let dataframe = ctx.sql(sql).await.expect(&msg);
366+
let plan = dataframe.into_optimized_plan()?;
367+
368+
let expected = vec![
369+
"Filter: EXISTS (<subquery>) [t1_id:UInt32;N, t1_name:Utf8;N]",
370+
" Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
371+
" Limit: skip=0, fetch=1 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
372+
" Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
373+
" Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
374+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
375+
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
376+
];
377+
let formatted = plan.display_indent_schema().to_string();
378+
let actual: Vec<&str> = formatted.trim().lines().collect();
379+
assert_eq!(
380+
expected, actual,
381+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
382+
);
383+
384+
let sql = "SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 where t1_name = t2_name limit 10)";
385+
let msg = format!("Creating logical plan for '{sql}'");
386+
let dataframe = ctx.sql(sql).await.expect(&msg);
387+
let plan = dataframe.into_optimized_plan()?;
388+
389+
let expected = vec![
390+
"Filter: t1.t1_id IN (<subquery>) [t1_id:UInt32;N, t1_name:Utf8;N]",
391+
" Subquery: [t2_id:UInt32;N]",
392+
" Limit: skip=0, fetch=10 [t2_id:UInt32;N]",
393+
" Projection: t2.t2_id [t2_id:UInt32;N]",
394+
" Filter: outer_ref(t1.t1_name) = t2.t2_name [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
395+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
396+
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
397+
];
398+
let formatted = plan.display_indent_schema().to_string();
399+
let actual: Vec<&str> = formatted.trim().lines().collect();
400+
assert_eq!(
401+
expected, actual,
402+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
403+
);
404+
405+
Ok(())
406+
}
407+
408+
#[tokio::test]
409+
async fn support_union_subquery() -> Result<()> {
410+
let ctx = create_join_context("t1_id", "t2_id", true)?;
411+
412+
let sql = "SELECT t1_id, t1_name FROM t1 WHERE EXISTS \
413+
(SELECT * FROM t2 WHERE t2_id = t1_id UNION ALL \
414+
SELECT * FROM t2 WHERE upper(t2_name) = upper(t1.t1_name))";
415+
416+
let msg = format!("Creating logical plan for '{sql}'");
417+
let dataframe = ctx.sql(sql).await.expect(&msg);
418+
let plan = dataframe.into_optimized_plan()?;
419+
420+
let expected = vec![
421+
"Filter: EXISTS (<subquery>) [t1_id:UInt32;N, t1_name:Utf8;N]",
422+
" Subquery: [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
423+
" Union [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
424+
" Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
425+
" Filter: t2.t2_id = outer_ref(t1.t1_id) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
426+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
427+
" Projection: t2.t2_id, t2.t2_name, t2.t2_int [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
428+
" Filter: upper(t2.t2_name) = upper(outer_ref(t1.t1_name)) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
429+
" TableScan: t2 [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
430+
" TableScan: t1 projection=[t1_id, t1_name] [t1_id:UInt32;N, t1_name:Utf8;N]",
431+
];
432+
let formatted = plan.display_indent_schema().to_string();
433+
let actual: Vec<&str> = formatted.trim().lines().collect();
434+
assert_eq!(
435+
expected, actual,
436+
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
437+
);
438+
439+
Ok(())
440+
}

datafusion/expr/src/expr.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::aggregate_function;
2121
use crate::built_in_function;
2222
use crate::expr_fn::binary_expr;
2323
use crate::logical_plan::Subquery;
24-
use crate::utils::expr_to_columns;
24+
use crate::utils::{expr_to_columns, find_out_reference_exprs};
2525
use crate::window_frame;
2626
use crate::window_function;
2727
use crate::AggregateUDF;
@@ -220,6 +220,9 @@ pub enum Expr {
220220
/// The type the parameter will be filled in with
221221
data_type: Option<DataType>,
222222
},
223+
/// A place holder which hold a reference to a qualified field
224+
/// in the outer query, used for correlated sub queries.
225+
OuterReferenceColumn(DataType, Column),
223226
}
224227

225228
/// Binary expression
@@ -567,6 +570,7 @@ impl Expr {
567570
Expr::Case { .. } => "Case",
568571
Expr::Cast { .. } => "Cast",
569572
Expr::Column(..) => "Column",
573+
Expr::OuterReferenceColumn(_, _) => "Outer",
570574
Expr::Exists { .. } => "Exists",
571575
Expr::GetIndexedField { .. } => "GetIndexedField",
572576
Expr::GroupingSet(..) => "GroupingSet",
@@ -785,6 +789,11 @@ impl Expr {
785789

786790
Ok(using_columns)
787791
}
792+
793+
/// Return true when the expression contains out reference(correlated) expressions.
794+
pub fn contains_outer(&self) -> bool {
795+
!find_out_reference_exprs(self).is_empty()
796+
}
788797
}
789798

790799
impl Not for Expr {
@@ -830,6 +839,7 @@ impl fmt::Debug for Expr {
830839
match self {
831840
Expr::Alias(expr, alias) => write!(f, "{expr:?} AS {alias}"),
832841
Expr::Column(c) => write!(f, "{c}"),
842+
Expr::OuterReferenceColumn(_, c) => write!(f, "outer_ref({})", c),
833843
Expr::ScalarVariable(_, var_names) => write!(f, "{}", var_names.join(".")),
834844
Expr::Literal(v) => write!(f, "{v:?}"),
835845
Expr::Case(case) => {
@@ -1110,6 +1120,7 @@ fn create_name(e: &Expr) -> Result<String> {
11101120
match e {
11111121
Expr::Alias(_, name) => Ok(name.clone()),
11121122
Expr::Column(c) => Ok(c.flat_name()),
1123+
Expr::OuterReferenceColumn(_, c) => Ok(format!("outer_ref({})", c.flat_name())),
11131124
Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")),
11141125
Expr::Literal(value) => Ok(format!("{value:?}")),
11151126
Expr::BinaryExpr(binary_expr) => {

0 commit comments

Comments
 (0)