diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 9290d725165e..87b16b90f30d 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -18,14 +18,17 @@ //! Defines the merge plan for executing partitions in parallel and then merging the results //! into a single partition +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::stream::{ObservedStream, RecordBatchReceiverStream}; use super::{ - DisplayAs, ExecutionPlanProperties, PlanProperties, SendableRecordBatchStream, - Statistics, + DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, Statistics, }; +use crate::coalesce::{LimitedBatchCoalescer, PushBatchStatus}; use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase}; use crate::projection::{ProjectionExec, make_with_child}; @@ -33,11 +36,15 @@ use crate::sort_pushdown::SortOrderPushdownResult; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, check_if_same_properties}; use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{Result, assert_eq_or_internal_err, internal_err}; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalExpr; +use futures::ready; +use futures::stream::{Stream, StreamExt}; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. 
@@ -209,6 +216,8 @@ impl ExecutionPlan for CoalescePartitionsExec { let elapsed_compute = baseline_metrics.elapsed_compute().clone(); let _timer = elapsed_compute.timer(); + let batch_size = context.session_config().batch_size(); + // use a stream that allows each sender to put in at // least one result in an attempt to maximize // parallelism. @@ -226,11 +235,23 @@ impl ExecutionPlan for CoalescePartitionsExec { } let stream = builder.build(); - Ok(Box::pin(ObservedStream::new( - stream, - baseline_metrics, - self.fetch, - ))) + // Coalesce small batches from multiple partitions into + // larger batches of target_batch_size. This improves + // downstream performance (e.g. hash join build side + // benefits from fewer, larger batches). + Ok(Box::pin(CoalescedStream { + input: Box::pin(ObservedStream::new( + stream, + baseline_metrics, + self.fetch, + )), + coalescer: LimitedBatchCoalescer::new( + self.schema(), + batch_size, + None, // fetch is already handled by ObservedStream + ), + completed: false, + })) } } } @@ -347,6 +368,53 @@ impl ExecutionPlan for CoalescePartitionsExec { } } +/// Stream that coalesces small batches into larger ones using +/// [`LimitedBatchCoalescer`]. +struct CoalescedStream { + input: SendableRecordBatchStream, + coalescer: LimitedBatchCoalescer, + completed: bool, +} + +impl Stream for CoalescedStream { + type Item = Result<RecordBatch>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + loop { + if let Some(batch) = self.coalescer.next_completed_batch() { + return Poll::Ready(Some(Ok(batch))); + } + if self.completed { + return Poll::Ready(None); + } + let input_batch = ready!(self.input.poll_next_unpin(cx)); + match input_batch { + None => { + self.completed = true; + self.coalescer.finish()?; + } + Some(Ok(batch)) => match self.coalescer.push_batch(batch)?
{ + PushBatchStatus::Continue => {} + PushBatchStatus::LimitReached => { + self.completed = true; + self.coalescer.finish()?; + } + }, + other => return Poll::Ready(other), + } + } + } +} + +impl RecordBatchStream for CoalescedStream { + fn schema(&self) -> SchemaRef { + self.coalescer.schema() + } +} + #[cfg(test)] mod tests { use super::*; @@ -378,10 +446,9 @@ mod tests { 1 ); - // the result should contain 4 batches (one per input partition) + // the result should contain all rows (coalesced into fewer batches) let iter = merge.execute(0, task_ctx)?; let batches = common::collect(iter).await?; - assert_eq!(batches.len(), num_partitions); // there should be a total of 400 rows (100 per each partition) let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();