diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 9290d725165e..339955461243 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -351,11 +351,14 @@ impl ExecutionPlan for CoalescePartitionsExec { mod tests { use super::*; use crate::test::exec::{ - BlockingExec, PanicExec, assert_strong_count_converges_to_zero, + BarrierExec, BlockingExec, PanicExec, assert_strong_count_converges_to_zero, }; use crate::test::{self, assert_is_pending}; use crate::{collect, common}; + use std::time::Duration; + + use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use futures::FutureExt; @@ -390,6 +393,45 @@ mod tests { Ok(()) } + #[tokio::test] + async fn drops_input_plan_after_input_streams_start() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + let input_partitions = 2; + let batch = RecordBatch::new_empty(Arc::clone(&schema)); + let input = Arc::new( + BarrierExec::new(vec![vec![batch]; input_partitions], schema) + .without_start_barrier() + .with_finish_barrier() + .with_log(false), + ); + let refs = Arc::downgrade(&input); + + let input_plan: Arc<dyn ExecutionPlan> = Arc::clone(&input); + let coalesce = CoalescePartitionsExec::new(input_plan); + let stream = coalesce.execute(0, task_ctx)?; + drop(coalesce); + + tokio::time::timeout(Duration::from_secs(5), async { + // Why not `wait_finish` here: that releases the barrier which lets the input tasks + // finish, which drops the input Arcs and hides the bug. 
+ while !input.is_finish_barrier_reached() { + tokio::task::yield_now().await; + } + }) + .await + .expect("input streams should reach pending"); + + drop(input); + + assert_strong_count_converges_to_zero(refs).await; + + drop(stream); + + Ok(()) + } + #[tokio::test] async fn test_drop_cancel() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 4b7e707fcced..9139a6dd0479 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -330,6 +330,11 @@ impl RecordBatchReceiverStreamBuilder { context: Arc, ) { let output = self.tx(); + let input_display = if log::log_enabled!(log::Level::Debug) { + displayable(input.as_ref()).one_line().to_string() + } else { + String::new() + }; self.inner.spawn(async move { let mut stream = match input.execute(partition, context) { @@ -338,14 +343,18 @@ impl RecordBatchReceiverStreamBuilder { // is no place to send the error and no reason to continue. output.send(Err(e)).await.ok(); debug!( - "Stopping execution: error executing input: {}", - displayable(input.as_ref()).one_line() + "Stopping execution: error executing input: {input_display}", ); return Ok(()); } Ok(stream) => stream, }; + // Drop the input early, as soon as we're done with it. + // Holding on to it can cause delays in cancelling the child plan when the query is + // cancelled. + drop(input); + // Transfer batches from inner stream to the output tx // immediately. while let Some(item) = stream.next().await { @@ -355,8 +364,7 @@ impl RecordBatchReceiverStreamBuilder { // place to send the error and no reason to continue. 
if output.send(item).await.is_err() { debug!( - "Stopping execution: output is gone, plan cancelling: {}", - displayable(input.as_ref()).one_line() + "Stopping execution: output is gone, plan cancelling: {input_display}", ); return Ok(()); } @@ -364,10 +372,7 @@ impl RecordBatchReceiverStreamBuilder { // Stop after the first error is encountered (Don't // drive all streams to completion) if is_err { - debug!( - "Stopping execution: plan returned error: {}", - displayable(input.as_ref()).one_line() - ); + debug!("Stopping execution: plan returned error: {input_display}"); return Ok(()); } }