diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index 8174f71c31af..7194fb0c3267 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -97,6 +97,8 @@ struct PartitionChannels {
     reservation: SharedMemoryReservation,
     /// Spill manager for handling disk spills for this output partition
     spill_manager: Arc<SpillManager>,
+    /// Baseline metrics shared by this output partition
+    baseline_metrics: BaselineMetrics,
 }
 
 #[derive(Debug)]
@@ -236,6 +238,7 @@ impl RepartitionExecState {
                 spill_metrics,
                 input.schema(),
             ));
+            let baseline_metrics = BaselineMetrics::new(&metrics, partition);
             channels.insert(
                 partition,
                 PartitionChannels {
@@ -243,6 +246,7 @@ impl RepartitionExecState {
                     rx,
                     reservation,
                     spill_manager,
+                    baseline_metrics,
                 },
             );
         }
@@ -266,11 +270,22 @@ impl RepartitionExecState {
                 })
                 .collect();
 
+            let baseline_metrics: Vec<_> = (0..partitioning.partition_count())
+                .map(|partition| {
+                    channels
+                        .get(&partition)
+                        .expect("baseline for partition should exist")
+                        .baseline_metrics
+                        .intermediate()
+                })
+                .collect();
+
             let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
                 stream,
                 txs.clone(),
                 partitioning.clone(),
                 metrics,
+                baseline_metrics,
             ));
 
             // In a separate task, wait for each input to be done
@@ -349,14 +364,20 @@ impl BatchPartitioner {
     ///
     /// The time spent repartitioning, not including time spent in `f` will be recorded
     /// to the [`metrics::Time`] provided on construction
-    pub fn partition<F>(&mut self, batch: RecordBatch, mut f: F) -> Result<()>
+    pub fn partition<F>(
+        &mut self,
+        batch: RecordBatch,
+        baselines: Option<&[BaselineMetrics]>,
+        mut f: F,
+    ) -> Result<()>
     where
         F: FnMut(usize, RecordBatch) -> Result<()>,
     {
-        self.partition_iter(batch)?.try_for_each(|res| match res {
-            Ok((partition, batch)) => f(partition, batch),
-            Err(e) => Err(e),
-        })
+        self.partition_iter(batch, baselines)?
+            .try_for_each(|res| match res {
+                Ok((partition, batch)) => f(partition, batch),
+                Err(e) => Err(e),
+            })
     }
 
     /// Actual implementation of [`partition`](Self::partition).
@@ -364,10 +385,11 @@ impl BatchPartitioner {
     /// The reason this was pulled out is that we need to have a variant of `partition` that works w/ sync functions,
     /// and one that works w/ async. Using an iterator as an intermediate representation was the best way to achieve
     /// this (so we don't need to clone the entire implementation).
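+    ///
+    /// When `baselines` is provided, the time spent building each output batch is
+    /// additionally recorded against that output partition's `elapsed_compute` metric.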
-    fn partition_iter(
-        &mut self,
+    fn partition_iter<'a>(
+        &'a mut self,
         batch: RecordBatch,
-    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + '_> {
+        baselines: Option<&'a [BaselineMetrics]>,
+    ) -> Result<impl Iterator<Item = Result<(usize, RecordBatch)>> + Send + 'a> {
         let it: Box<dyn Iterator<Item = Result<(usize, RecordBatch)>> + Send> =
             match &mut self.state {
                 BatchPartitionerState::RoundRobin {
@@ -410,6 +432,7 @@ impl BatchPartitioner {
 
                 // Borrowing partitioner timer to prevent moving `self` to closure
                 let partitioner_timer = &self.timer;
+                let baselines_copy = baselines;
                 let it = indices
                     .into_iter()
                     .enumerate()
@@ -418,6 +441,9 @@ impl BatchPartitioner {
                         (!indices.is_empty()).then_some((partition, indices))
                     })
                     .map(move |(partition, indices)| {
+                        let baseline_timer = baselines_copy.map(|baselines| {
+                            baselines[partition].elapsed_compute().timer()
+                        });
                         // Tracking time required for repartitioned batches construction
                         let _timer = partitioner_timer.timer();
 
@@ -433,7 +459,9 @@ impl BatchPartitioner {
                         )
                         .unwrap();
 
-                        Ok((partition, batch))
+                        let result = Ok((partition, batch));
+                        drop(baseline_timer);
+                        result
                     });
 
                 Box::new(it)
@@ -721,7 +749,7 @@ impl ExecutionPlan for RepartitionExec {
         let num_input_partitions = input.output_partitioning().partition_count();
 
         // lock scope
-        let (mut rx, reservation, spill_manager, abort_helper) = {
+        let (mut rx, reservation, spill_manager, baseline_metrics, abort_helper) = {
             // lock mutexes
             let mut state = state.lock();
             let state = state.consume_input_streams(
@@ -739,6 +767,7 @@ impl ExecutionPlan for RepartitionExec {
                 rx,
                 reservation,
                 spill_manager,
+                baseline_metrics,
                 ..
             } = state
                 .channels
@@ -749,6 +778,7 @@ impl ExecutionPlan for RepartitionExec {
                 rx,
                 reservation,
                 spill_manager,
+                baseline_metrics,
                 Arc::clone(&state.abort_helper),
             )
         };
@@ -784,7 +814,7 @@ impl ExecutionPlan for RepartitionExec {
                 .with_streams(input_streams)
                 .with_schema(schema_captured)
                 .with_expressions(&sort_exprs.unwrap())
-                .with_metrics(BaselineMetrics::new(&metrics, partition))
+                .with_metrics(baseline_metrics.clone())
                 .with_batch_size(context.session_config().batch_size())
                 .with_fetch(fetch)
                 .with_reservation(merge_reservation)
@@ -798,6 +828,7 @@ impl ExecutionPlan for RepartitionExec {
                 _drop_helper: abort_helper,
                 reservation,
                 spill_manager,
+                baseline_metrics,
                 state: RepartitionStreamState::ReceivingFromChannel,
             }) as SendableRecordBatchStream)
         }
@@ -1047,6 +1078,7 @@ impl RepartitionExec {
         >,
         partitioning: Partitioning,
         metrics: RepartitionMetrics,
+        baseline_metrics: Vec<BaselineMetrics>,
     ) -> Result<()> {
         let mut partitioner =
             BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
@@ -1070,7 +1102,7 @@ impl RepartitionExec {
                 continue;
             }
 
-            for res in partitioner.partition_iter(batch)? {
+            for res in partitioner.partition_iter(batch, Some(&baseline_metrics))? {
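+                // Passing `Some(&baseline_metrics)` lets `partition_iter` attribute
+                // batch construction time to each output partition's `elapsed_compute`.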
                 let (partition, batch) = res?;
                 let size = batch.get_array_memory_size();
@@ -1079,6 +1111,8 @@ impl RepartitionExec {
                 if let Some((tx, reservation, spill_manager)) =
                     output_channels.get_mut(&partition)
                 {
+                    let _baseline_timer =
+                        baseline_metrics[partition].elapsed_compute().timer();
                     let (batch_to_send, is_memory_batch) =
                         match reservation.lock().try_grow(size) {
                             Ok(_) => {
@@ -1217,6 +1251,9 @@ struct RepartitionStream {
     /// Spill manager for reading spilled batches
     spill_manager: Arc<SpillManager>,
 
+    /// Baseline metrics tracker
+    baseline_metrics: BaselineMetrics,
+
     /// Current state of the stream
     state: RepartitionStreamState,
 }
@@ -1228,70 +1265,96 @@ impl Stream for RepartitionStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        loop {
-            match &mut self.state {
+        let this = self.as_mut().get_mut();
+        let poll = loop {
+            match &mut this.state {
                 RepartitionStreamState::ReceivingFromChannel => {
-                    let value = futures::ready!(self.input.recv().poll_unpin(cx));
-                    match value {
-                        Some(Some(v)) => match v {
-                            Ok(RepartitionBatch::Memory(batch)) => {
-                                // Release memory and return
-                                self.reservation
-                                    .lock()
-                                    .shrink(batch.get_array_memory_size());
-                                return Poll::Ready(Some(Ok(batch)));
-                            }
-                            Ok(RepartitionBatch::Spilled { spill_file, size }) => {
-                                // Read from disk - SpillReaderStream uses tokio::fs internally
-                                // Pass the original size for validation
-                                let stream = self
-                                    .spill_manager
-                                    .read_spill_as_stream(spill_file, Some(size))?;
-                                self.state =
-                                    RepartitionStreamState::ReadingSpilledBatch(stream);
-                                // Continue loop to poll the stream immediately
-                            }
-                            Err(e) => {
-                                return Poll::Ready(Some(Err(e)));
+                    let recv = match this.input.recv().poll_unpin(cx) {
+                        Poll::Pending => break Poll::Pending,
+                        Poll::Ready(value) => value,
+                    };
+
+                    match recv {
+                        Some(Some(result)) => {
+                            match result {
+                                Ok(RepartitionBatch::Memory(batch)) => {
+                                    let bytes = batch.get_array_memory_size();
+                                    break this.ready_with_batch(batch, Some(bytes));
+                                }
+                                Ok(RepartitionBatch::Spilled { spill_file, size }) => {
+                                    match this
+                                        .spill_manager
+                                        .read_spill_as_stream(spill_file, Some(size))
+                                    {
+                                        Ok(stream) => {
+                                            // Read from disk - SpillReaderStream uses tokio::fs internally
+                                            this.state = RepartitionStreamState::ReadingSpilledBatch(stream);
+                                            continue;
+                                        }
+                                        Err(err) => {
+                                            let _timer = this
+                                                .baseline_metrics
+                                                .elapsed_compute()
+                                                .timer();
+                                            break Poll::Ready(Some(Err(err)));
+                                        }
+                                    }
+                                }
+                                Err(err) => {
+                                    let _timer =
+                                        this.baseline_metrics.elapsed_compute().timer();
+                                    break Poll::Ready(Some(Err(err)));
+                                }
                             }
-                        },
+                        }
                         Some(None) => {
-                            self.num_input_partitions_processed += 1;
-
-                            if self.num_input_partitions
-                                == self.num_input_partitions_processed
+                            this.num_input_partitions_processed += 1;
+                            if this.num_input_partitions
+                                == this.num_input_partitions_processed
                             {
-                                // all input partitions have finished sending batches
-                                return Poll::Ready(None);
-                            } else {
-                                // other partitions still have data to send
-                                continue;
+                                break Poll::Ready(None);
                             }
+                            continue;
                         }
-                        None => {
-                            return Poll::Ready(None);
-                        }
+                        None => break Poll::Ready(None),
                     }
                 }
                 RepartitionStreamState::ReadingSpilledBatch(stream) => {
-                    match futures::ready!(stream.poll_next_unpin(cx)) {
-                        Some(Ok(batch)) => {
-                            // Return batch and stay in ReadingSpilledBatch state to read more batches
-                            return Poll::Ready(Some(Ok(batch)));
+                    match stream.poll_next_unpin(cx) {
+                        Poll::Pending => break Poll::Pending,
+                        Poll::Ready(Some(Ok(batch))) => {
+                            break this.ready_with_batch(batch, None)
                         }
-                        Some(Err(e)) => {
-                            self.state = RepartitionStreamState::ReceivingFromChannel;
-                            return Poll::Ready(Some(Err(e)));
+                        Poll::Ready(Some(Err(err))) => {
+                            this.state = RepartitionStreamState::ReceivingFromChannel;
+                            let _timer = this.baseline_metrics.elapsed_compute().timer();
+                            break Poll::Ready(Some(Err(err)));
                         }
-                        None => {
-                            // Spill stream ended - go back to receiving from channel
-                            self.state = RepartitionStreamState::ReceivingFromChannel;
+                        Poll::Ready(None) => {
+                            this.state = RepartitionStreamState::ReceivingFromChannel;
                             continue;
                         }
                     }
                 }
             }
+        };
+
+        this.baseline_metrics.record_poll(poll)
+    }
+}
+
+impl RepartitionStream {
+    fn ready_with_batch(
+        &mut self,
+        batch: RecordBatch,
+        released_bytes: Option<usize>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        if let Some(bytes) = released_bytes {
+            // Release reserved memory before returning the batch
+            self.reservation.lock().shrink(bytes);
         }
+        let _timer = self.baseline_metrics.elapsed_compute().timer();
+        Poll::Ready(Some(Ok(batch)))
     }
 }
 
@@ -1411,14 +1474,15 @@ mod tests {
         },
         {collect, expressions::col},
     };
-    use arrow::array::{ArrayRef, StringArray, UInt32Array};
+    use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_common::cast::as_string_array;
     use datafusion_common::exec_err;
     use datafusion_common::test_util::batches_to_sort_string;
     use datafusion_common_runtime::JoinSet;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+    use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
     use insta::assert_snapshot;
     use itertools::Itertools;
 
@@ -1897,6 +1961,76 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn repartition_reports_baseline_metrics() -> Result<()> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let schema = test_schema();
+        let partition = create_vec_batches(1);
+        let input_partitions = vec![partition];
+
+        let exec =
+            TestMemoryExec::try_new_exec(&input_partitions, Arc::clone(&schema), None)?;
+        let exec = RepartitionExec::try_new(exec, Partitioning::RoundRobinBatch(2))?;
+
+        let mut total_rows = 0;
+        for partition in 0..2 {
+            let mut stream = exec.execute(partition, Arc::clone(&task_ctx))?;
+            while let Some(batch) = stream.next().await {
+                total_rows += batch?.num_rows();
+            }
+        }
+
+        let metrics = exec.metrics().expect("repartition metrics");
+        assert_eq!(metrics.output_rows(), Some(total_rows));
+        assert!(
+            metrics.elapsed_compute().is_some(),
+            "expected elapsed_compute metric"
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn repartition_preserve_order_reports_baseline_metrics() -> Result<()> {
+        let task_ctx = Arc::new(TaskContext::default());
+        let schema = test_schema();
+        let partition = create_vec_batches(1);
+        let input_partitions = vec![partition.clone(), partition];
+
+        let sort_exprs: LexOrdering = [PhysicalSortExpr {
+            expr: col("c0", schema.as_ref()).unwrap(),
+            options: SortOptions::default(),
+        }]
+        .into();
+
+        let memory =
+            TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?;
+        let memory = memory
+            .try_with_sort_information(vec![sort_exprs.clone(), sort_exprs.clone()])?;
+        let memory = TestMemoryExec::update_cache(Arc::new(memory));
+        let input: Arc<dyn ExecutionPlan> = Arc::new(memory);
+
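+        // With preserve_order, each output partition is produced by a streaming
+        // merge that records into the shared per-partition BaselineMetrics
+        // (see `.with_metrics(baseline_metrics.clone())` above).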
+        let exec = RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(2))?
+            .with_preserve_order();
+
+        let mut total_rows = 0;
+        for partition in 0..2 {
+            let mut stream = exec.execute(partition, Arc::clone(&task_ctx))?;
+            while let Some(batch) = stream.next().await {
+                total_rows += batch?.num_rows();
+            }
+        }
+
+        let metrics = exec.metrics().expect("repartition metrics");
+        assert_eq!(metrics.output_rows(), Some(total_rows));
+        assert!(
+            metrics.elapsed_compute().is_some(),
+            "expected elapsed_compute metric"
+        );
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn repartition_with_spilling() -> Result<()> {
         // Test that repartition successfully spills to disk when memory is constrained
@@ -1933,6 +2067,14 @@ mod tests {
 
         // Verify spilling metrics to confirm spilling actually happened
         let metrics = exec.metrics().unwrap();
+        assert_eq!(metrics.output_rows(), Some(total_rows));
+        let elapsed = metrics
+            .elapsed_compute()
+            .expect("expected elapsed_compute metric for spilling case");
+        assert!(
+            elapsed > 0,
+            "expected elapsed_compute > 0 but was {elapsed}"
+        );
         assert!(
             metrics.spill_count().unwrap() > 0,
             "Expected spill_count > 0, but got {:?}",
@@ -1996,6 +2138,14 @@ mod tests {
 
         // Verify partial spilling metrics
         let metrics = exec.metrics().unwrap();
+        assert_eq!(metrics.output_rows(), Some(total_rows));
+        let elapsed = metrics
+            .elapsed_compute()
+            .expect("expected elapsed_compute metric for partial spilling case");
+        assert!(
+            elapsed > 0,
+            "expected elapsed_compute > 0 but was {elapsed}"
+        );
         let spill_count = metrics.spill_count().unwrap();
         let spilled_rows = metrics.spilled_rows().unwrap();
         let spilled_bytes = metrics.spilled_bytes().unwrap();