Skip to content

Commit 5c7c816

Browse files
committed
Naive deduplicating recursive CTE implementation
- Introduce a new "result table" intermediate table storing all already emitted results.
- Use existing physical operators to deduplicate the output of both the static and recursive terms, and to remove the already emitted results from the recursive term's output.
- Add a simple test of a transitive closure on a cyclic graph.
1 parent a4acec3 commit 5c7c816

File tree

7 files changed

+411
-41
lines changed

7 files changed

+411
-41
lines changed

datafusion/core/src/physical_planner.rs

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,9 @@ use datafusion_catalog::ScanArgs;
6464
use datafusion_common::display::ToStringifiedPlan;
6565
use datafusion_common::format::ExplainAnalyzeLevel;
6666
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
67-
use datafusion_common::TableReference;
6867
use datafusion_common::{
6968
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
70-
ScalarValue,
69+
NullEquality, ScalarValue, TableReference,
7170
};
7271
use datafusion_datasource::file_groups::FileGroup;
7372
use datafusion_datasource::memory::MemorySourceConfig;
@@ -84,7 +83,7 @@ use datafusion_expr::{
8483
WindowFrameBound, WriteOp,
8584
};
8685
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
87-
use datafusion_physical_expr::expressions::Literal;
86+
use datafusion_physical_expr::expressions::{Column, Literal};
8887
use datafusion_physical_expr::{
8988
create_physical_sort_exprs, LexOrdering, PhysicalSortExpr,
9089
};
@@ -98,6 +97,7 @@ use datafusion_physical_plan::unnest::ListUnnest;
9897

9998
use async_trait::async_trait;
10099
use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper};
100+
use datafusion_physical_plan::result_table::ResultTableExec;
101101
use futures::{StreamExt, TryStreamExt};
102102
use itertools::{multiunzip, Itertools};
103103
use log::debug;
@@ -1280,12 +1280,68 @@ impl DefaultPhysicalPlanner {
12801280
name, is_distinct, ..
12811281
}) => {
12821282
let [static_term, recursive_term] = children.two()?;
1283-
Arc::new(RecursiveQueryExec::try_new(
1284-
name.clone(),
1285-
static_term,
1286-
recursive_term,
1287-
*is_distinct,
1288-
)?)
1283+
let inner_schema = static_term.schema();
1284+
let group_by = PhysicalGroupBy::new_single(
1285+
inner_schema
1286+
.fields()
1287+
.iter()
1288+
.enumerate()
1289+
.map(|(i, f)| {
1290+
(Arc::new(Column::new(&f.name(), i)) as _, f.name().clone())
1291+
})
1292+
.collect(),
1293+
);
1294+
if *is_distinct {
1295+
// We deduplicate each input to avoid duplicated values
1296+
// And we remove from the recursive term the already emitted values, i.e. the contents of the result table.
1297+
Arc::new(RecursiveQueryExec::try_new(
1298+
name.clone(),
1299+
Arc::new(AggregateExec::try_new(
1300+
AggregateMode::Final,
1301+
group_by.clone(),
1302+
Vec::new(),
1303+
Vec::new(),
1304+
static_term,
1305+
Arc::clone(&inner_schema),
1306+
)?),
1307+
Arc::new(HashJoinExec::try_new(
1308+
Arc::new(AggregateExec::try_new(
1309+
AggregateMode::Final,
1310+
group_by,
1311+
Vec::new(),
1312+
Vec::new(),
1313+
recursive_term,
1314+
Arc::clone(&inner_schema),
1315+
)?),
1316+
Arc::new(ResultTableExec::new(
1317+
"union".into(),
1318+
Arc::clone(&inner_schema),
1319+
)),
1320+
inner_schema
1321+
.fields()
1322+
.iter()
1323+
.enumerate()
1324+
.map(|(i, f)| {
1325+
let col = Arc::new(Column::new(&f.name(), i)) as _;
1326+
(Arc::clone(&col), col)
1327+
})
1328+
.collect(),
1329+
None,
1330+
&JoinType::LeftAnti,
1331+
None,
1332+
PartitionMode::CollectLeft,
1333+
NullEquality::NullEqualsNull,
1334+
)?),
1335+
true,
1336+
)?)
1337+
} else {
1338+
Arc::new(RecursiveQueryExec::try_new(
1339+
name.clone(),
1340+
static_term,
1341+
recursive_term,
1342+
false,
1343+
)?)
1344+
}
12891345
}
12901346

12911347
// N Children
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
start,end
2+
1,2
3+
2,3
4+
2,4
5+
4,1

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
5454
use datafusion_common::display::ToStringifiedPlan;
5555
use datafusion_common::file_options::file_type::FileType;
5656
use datafusion_common::{
57-
exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
57+
exec_err, get_target_functional_dependencies, internal_datafusion_err,
5858
plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
5959
NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
6060
};
@@ -178,12 +178,6 @@ impl LogicalPlanBuilder {
178178
recursive_term: LogicalPlan,
179179
is_distinct: bool,
180180
) -> Result<Self> {
181-
// TODO: we need to do a bunch of validation here. Maybe more.
182-
if is_distinct {
183-
return not_impl_err!(
184-
"Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
185-
);
186-
}
187181
// Ensure that the static term and the recursive term have the same number of fields
188182
let static_fields_len = self.plan.schema().fields().len();
189183
let recursive_fields_len = recursive_term.schema().fields().len();

datafusion/physical-plan/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ pub mod placeholder_row;
8080
pub mod projection;
8181
pub mod recursive_query;
8282
pub mod repartition;
83+
pub mod result_table;
8384
pub mod sorts;
8485
pub mod spill;
8586
pub mod stream;

datafusion/physical-plan/src/recursive_query.rs

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
//! Defines the recursive query plan
1919
2020
use std::any::Any;
21+
use std::mem::take;
2122
use std::sync::Arc;
2223
use std::task::{Context, Poll};
2324

2425
use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
2526
use crate::execution_plan::{Boundedness, EmissionType};
27+
use crate::result_table::ResultTable;
2628
use crate::{
2729
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
2830
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
@@ -64,8 +66,9 @@ pub struct RecursiveQueryExec {
6466
static_term: Arc<dyn ExecutionPlan>,
6567
/// The dynamic part (recursive term)
6668
recursive_term: Arc<dyn ExecutionPlan>,
67-
/// Distinction
6869
is_distinct: bool,
70+
/// If is_distinct is true, holds the result table that saves all previous results
71+
result_table: Option<Arc<ResultTable>>,
6972
/// Execution metrics
7073
metrics: ExecutionPlanMetricsSet,
7174
/// Cache holding plan properties like equivalences, output partitioning etc.
@@ -77,13 +80,22 @@ impl RecursiveQueryExec {
7780
pub fn try_new(
7881
name: String,
7982
static_term: Arc<dyn ExecutionPlan>,
80-
recursive_term: Arc<dyn ExecutionPlan>,
83+
mut recursive_term: Arc<dyn ExecutionPlan>,
8184
is_distinct: bool,
8285
) -> Result<Self> {
8386
// Each recursive query needs its own work table
8487
let work_table = Arc::new(WorkTable::new());
8588
// Use the same work table for both the WorkTableExec and the recursive term
86-
let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
89+
recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
90+
let result_table = if is_distinct {
91+
let result_table = Arc::new(ResultTable::new());
92+
// Use the same result table for both the ResultTableExec and the recursive term
93+
recursive_term =
94+
assign_work_table(recursive_term, Arc::clone(&result_table))?;
95+
Some(result_table)
96+
} else {
97+
None
98+
};
8799
let cache = Self::compute_properties(static_term.schema());
88100
Ok(RecursiveQueryExec {
89101
name,
@@ -93,6 +105,7 @@ impl RecursiveQueryExec {
93105
work_table,
94106
metrics: ExecutionPlanMetricsSet::new(),
95107
cache,
108+
result_table,
96109
})
97110
}
98111

@@ -193,6 +206,7 @@ impl ExecutionPlan for RecursiveQueryExec {
193206
Ok(Box::pin(RecursiveQueryStream::new(
194207
context,
195208
Arc::clone(&self.work_table),
209+
self.result_table.as_ref().map(Arc::clone),
196210
Arc::clone(&self.recursive_term),
197211
static_stream,
198212
baseline_metrics,
@@ -237,16 +251,16 @@ impl DisplayAs for RecursiveQueryExec {
237251
///
238252
/// while batch := static_stream.next():
239253
/// buffer.push(batch)
240-
/// yield buffer
254+
/// yield batch
241255
///
242256
/// while buffer.len() > 0:
243257
/// sender, receiver = Channel()
244-
/// register_continuation(handle_name, receiver)
258+
/// register_work_table(handle_name, receiver)
245259
/// sender.send(buffer.drain())
246260
/// recursive_stream = recursive_term.execute()
247261
/// while batch := recursive_stream.next():
248262
/// buffer.append(batch)
249-
/// yield buffer
263+
/// yield batch
250264
///
251265
struct RecursiveQueryStream {
252266
/// The context to be used for managing handlers & executing new tasks
@@ -268,6 +282,8 @@ struct RecursiveQueryStream {
268282
buffer: Vec<RecordBatch>,
269283
/// Tracks the memory used by the buffer
270284
reservation: MemoryReservation,
285+
/// The result table state, representing the table used for deduplication in case it is enabled
286+
results_table: Option<Arc<ResultTable>>,
271287
// /// Metrics.
272288
_baseline_metrics: BaselineMetrics,
273289
}
@@ -277,6 +293,7 @@ impl RecursiveQueryStream {
277293
fn new(
278294
task_context: Arc<TaskContext>,
279295
work_table: Arc<WorkTable>,
296+
results_table: Option<Arc<ResultTable>>,
280297
recursive_term: Arc<dyn ExecutionPlan>,
281298
static_stream: SendableRecordBatchStream,
282299
baseline_metrics: BaselineMetrics,
@@ -294,6 +311,7 @@ impl RecursiveQueryStream {
294311
buffer: vec![],
295312
reservation,
296313
_baseline_metrics: baseline_metrics,
314+
results_table,
297315
}
298316
}
299317

@@ -327,11 +345,21 @@ impl RecursiveQueryStream {
327345
return Poll::Ready(None);
328346
}
329347

348+
// Update the result table with the current buffer
349+
if self.results_table.is_some() {
350+
// Note it's fine to take the memory reservation here, we are not cloning the underlying data,
351+
// and the result table is going to outlive the work table.
352+
let buffer = self.buffer.clone();
353+
let reservation = self.reservation.take();
354+
self.results_table
355+
.as_mut()
356+
.unwrap()
357+
.append(buffer, reservation);
358+
}
359+
330360
// Update the work table with the current buffer
331-
let reserved_batches = ReservedBatches::new(
332-
std::mem::take(&mut self.buffer),
333-
self.reservation.take(),
334-
);
361+
let reserved_batches =
362+
ReservedBatches::new(take(&mut self.buffer), self.reservation.take());
335363
self.work_table.update(reserved_batches);
336364

337365
// We always execute (and re-execute iteratively) the first partition.
@@ -345,9 +373,9 @@ impl RecursiveQueryStream {
345373
}
346374
}
347375

348-
fn assign_work_table(
376+
fn assign_work_table<T: Any + Send + Sync>(
349377
plan: Arc<dyn ExecutionPlan>,
350-
work_table: Arc<WorkTable>,
378+
work_table: Arc<T>,
351379
) -> Result<Arc<dyn ExecutionPlan>> {
352380
let mut work_table_refs = 0;
353381
plan.transform_down(|plan| {
@@ -380,7 +408,7 @@ fn assign_work_table(
380408
fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
381409
plan.transform_up(|plan| {
382410
// WorkTableExec's states have already been updated correctly.
383-
if plan.as_any().is::<WorkTableExec>() {
411+
if plan.as_any().is::<WorkTableExec>() || plan.as_any().is::<ResultTable>() {
384412
Ok(Transformed::no(plan))
385413
} else {
386414
let new_plan = Arc::clone(&plan).reset_state()?;

0 commit comments

Comments
 (0)