diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 7075fbc2443d..da13389901ee 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -213,6 +213,7 @@ async fn sort_merge_join_spill() { .with_config(config) .with_disk_manager_builder(DiskManagerBuilder::default()) .with_scenario(Scenario::AccessLogStreaming) + .with_expected_success() .run() .await } diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index ee1ae3708961..da0d21f046da 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -38,7 +38,7 @@ use crate::handle_state; use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState}; use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; -use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; +use crate::joins::utils::{JoinKeyComparator, get_final_indices_from_shared_bitmap}; pub(super) enum PiecewiseMergeJoinStreamState { WaitBufferedSide, @@ -460,6 +460,14 @@ fn resolve_classic_join( let buffered_len = buffered_side.buffered_data.values().len(); let stream_values = stream_batch.compare_key_values(); + // Build comparator once for the batch pair + let cmp = JoinKeyComparator::new( + &[Arc::clone(&stream_values[0])], + &[Arc::clone(buffered_side.buffered_data.values())], + &[sort_options], + NullEquality::NullEqualsNothing, + )?; + let mut buffer_idx = batch_process_state.start_buffer_idx; let mut stream_idx = batch_process_state.start_stream_idx; @@ -475,17 +483,7 @@ fn resolve_classic_join( // in the previous stream row. for row_idx in stream_idx..stream_batch.batch.num_rows() { while buffer_idx < buffered_len { - let compare = { - let buffered_values = buffered_side.buffered_data.values(); - compare_join_arrays( - &[Arc::clone(&stream_values[0])], - row_idx, - &[Arc::clone(buffered_values)], - buffer_idx, - &[sort_options], - NullEquality::NullEqualsNothing, - )? - }; + let compare = cmp.compare(row_idx, buffer_idx); // If we find a match we append all indices and move to the next stream row index match operator { diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs index 2f7c9acb9d1b..3b409c98b2cf 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs @@ -126,7 +126,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::RecordBatchStream; -use crate::joins::utils::{JoinFilter, compare_join_arrays}; +use crate::joins::utils::{JoinFilter, JoinKeyComparator, compare_join_arrays}; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, }; @@ -162,48 +162,26 @@ fn evaluate_join_keys( } /// Find the first index in `key_arrays` starting from `from` where the key -/// differs from the key at `from`. Uses `compare_join_arrays` for zero-alloc -/// ordinal comparison. +/// differs from the key at `from`. Uses a pre-built `JoinKeyComparator` for +/// zero-alloc ordinal comparison without per-row type dispatch. /// /// Optimized for join workloads: checks adjacent and boundary keys before /// falling back to binary search, since most key groups are small (often 1). -fn find_key_group_end( - key_arrays: &[ArrayRef], - from: usize, - len: usize, - sort_options: &[SortOptions], - null_equality: NullEquality, -) -> Result { +fn find_key_group_end(cmp: &JoinKeyComparator, from: usize, len: usize) -> usize { let next = from + 1; if next >= len { - return Ok(len); + return len; } // Fast path: single-row group (common with unique keys). - if compare_join_arrays( - key_arrays, - from, - key_arrays, - next, - sort_options, - null_equality, - )? != Ordering::Equal - { - return Ok(next); + if cmp.compare(from, next) != Ordering::Equal { + return next; } // Check if the entire remaining batch shares this key. let last = len - 1; - if compare_join_arrays( - key_arrays, - from, - key_arrays, - last, - sort_options, - null_equality, - )? == Ordering::Equal - { - return Ok(len); + if cmp.compare(from, last) == Ordering::Equal { + return len; } // Binary search the interior: key at `next` matches, key at `last` doesn't. @@ -211,21 +189,13 @@ fn find_key_group_end( let mut hi = last; while lo < hi { let mid = lo + (hi - lo) / 2; - if compare_join_arrays( - key_arrays, - from, - key_arrays, - mid, - sort_options, - null_equality, - )? == Ordering::Equal - { + if cmp.compare(from, mid) == Ordering::Equal { lo = mid + 1; } else { hi = mid; } } - Ok(lo) + lo } /// When an outer key group spans a batch boundary, the boundary loop emits @@ -328,6 +298,14 @@ pub(crate) struct BitwiseSortMergeJoinStream { runtime_env: Arc, inner_buffer_size: usize, + // Cached comparators — pre-built to avoid per-row type dispatch. + /// Comparator for outer vs inner key comparison + outer_inner_cmp: Option, + /// Comparator for outer self-comparison (find_key_group_end on outer) + outer_self_cmp: Option, + /// Comparator for inner self-comparison (find_key_group_end on inner) + inner_self_cmp: Option, + // True once the current outer batch has been emitted. The Equal // branch's inner loops call emit then `ready!(poll_next_outer_batch)`. // If that poll returns Pending, poll_join re-enters from the top @@ -413,6 +391,9 @@ impl BitwiseSortMergeJoinStream { spill_manager, runtime_env, inner_buffer_size: 0, + outer_inner_cmp: None, + outer_self_cmp: None, + inner_self_cmp: None, batch_emitted: false, }) } @@ -425,6 +406,45 @@ impl BitwiseSortMergeJoinStream { Ok(()) } + /// Get or build the outer vs inner key comparator. + fn get_outer_inner_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.outer_inner_cmp.is_none() { + self.outer_inner_cmp = Some(JoinKeyComparator::new( + &self.outer_key_arrays, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.outer_inner_cmp.as_ref().unwrap()) + } + + /// Get or build the outer self-comparison comparator. + fn get_outer_self_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.outer_self_cmp.is_none() { + self.outer_self_cmp = Some(JoinKeyComparator::new( + &self.outer_key_arrays, + &self.outer_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.outer_self_cmp.as_ref().unwrap()) + } + + /// Get or build the inner self-comparison comparator. + fn get_inner_self_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.inner_self_cmp.is_none() { + self.inner_self_cmp = Some(JoinKeyComparator::new( + &self.inner_key_arrays, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.inner_self_cmp.as_ref().unwrap()) + } + /// Spill the in-memory inner key buffer to disk and clear it. fn spill_inner_key_buffer(&mut self) -> Result<()> { let spill_file = self @@ -468,6 +488,8 @@ impl BitwiseSortMergeJoinStream { self.outer_batch = Some(batch); self.outer_offset = 0; self.outer_key_arrays = keys; + self.outer_inner_cmp = None; + self.outer_self_cmp = None; self.batch_emitted = false; self.matched = BooleanBufferBuilder::new(batch_num_rows); self.matched.append_n(batch_num_rows, false); @@ -494,6 +516,8 @@ impl BitwiseSortMergeJoinStream { self.inner_batch = Some(batch); self.inner_offset = 0; self.inner_key_arrays = keys; + self.outer_inner_cmp = None; + self.inner_self_cmp = None; return Poll::Ready(Ok(true)); } } @@ -555,13 +579,12 @@ impl BitwiseSortMergeJoinStream { let outer_batch = self.outer_batch.as_ref().unwrap(); let num_outer = outer_batch.num_rows(); + self.get_outer_self_cmp()?; let outer_group_end = find_key_group_end( - &self.outer_key_arrays, + self.outer_self_cmp.as_ref().unwrap(), self.outer_offset, num_outer, - &self.sort_options, - self.null_equality, - )?; + ); for i in self.outer_offset..outer_group_end { self.matched.set_bit(i, true); @@ -584,13 +607,12 @@ impl BitwiseSortMergeJoinStream { }; let num_inner = inner_batch.num_rows(); + self.get_inner_self_cmp()?; let group_end = find_key_group_end( - &self.inner_key_arrays, + self.inner_self_cmp.as_ref().unwrap(), self.inner_offset, num_inner, - &self.sort_options, - self.null_equality, - )?; + ); if group_end < num_inner { self.inner_offset = group_end; @@ -642,20 +664,19 @@ impl BitwiseSortMergeJoinStream { } loop { - let inner_batch = match &self.inner_batch { - Some(b) => b, - None => return Poll::Ready(Ok(true)), - }; - let num_inner = inner_batch.num_rows(); + if self.inner_batch.is_none() { + return Poll::Ready(Ok(true)); + } + let num_inner = self.inner_batch.as_ref().unwrap().num_rows(); + self.get_inner_self_cmp()?; let group_end = find_key_group_end( - &self.inner_key_arrays, + self.inner_self_cmp.as_ref().unwrap(), self.inner_offset, num_inner, - &self.sort_options, - self.null_equality, - )?; + ); if !resume_from_poll { + let inner_batch = self.inner_batch.as_ref().unwrap(); let slice = inner_batch.slice(self.inner_offset, group_end - self.inner_offset); self.inner_buffer_size += slice.get_array_memory_size(); @@ -719,6 +740,7 @@ impl BitwiseSortMergeJoinStream { /// key group, evaluates the filter against the outer key group and ORs /// the results into the matched bitset using u64-chunked bitwise ops. fn process_key_match_with_filter(&mut self) -> Result<()> { + self.get_outer_self_cmp()?; let filter = self.filter.as_ref().unwrap(); let outer_batch = self.outer_batch.as_ref().unwrap(); let num_outer = outer_batch.num_rows(); @@ -738,12 +760,10 @@ impl BitwiseSortMergeJoinStream { ); let outer_group_end = find_key_group_end( - &self.outer_key_arrays, + self.outer_self_cmp.as_ref().unwrap(), self.outer_offset, num_outer, - &self.sort_options, - self.null_equality, - )?; + ); let outer_group_len = outer_group_end - self.outer_offset; let outer_slice = outer_batch.slice(self.outer_offset, outer_group_len); @@ -959,34 +979,30 @@ impl BitwiseSortMergeJoinStream { } // 4. Compare keys at current positions - let cmp = compare_join_arrays( - &self.outer_key_arrays, - self.outer_offset, - &self.inner_key_arrays, - self.inner_offset, - &self.sort_options, - self.null_equality, - )?; + self.get_outer_inner_cmp()?; + let cmp = self + .outer_inner_cmp + .as_ref() + .unwrap() + .compare(self.outer_offset, self.inner_offset); match cmp { Ordering::Less => { + self.get_outer_self_cmp()?; let group_end = find_key_group_end( - &self.outer_key_arrays, + self.outer_self_cmp.as_ref().unwrap(), self.outer_offset, num_outer, - &self.sort_options, - self.null_equality, - )?; + ); self.outer_offset = group_end; } Ordering::Greater => { + self.get_inner_self_cmp()?; let group_end = find_key_group_end( - &self.inner_key_arrays, + self.inner_self_cmp.as_ref().unwrap(), self.inner_offset, num_inner, - &self.sort_options, - self.null_equality, - )?; + ); if group_end >= num_inner { let saved_keys = slice_keys(&self.inner_key_arrays, num_inner - 1); diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index c387e05390fc..bf0117d2e364 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -38,7 +38,7 @@ use crate::joins::sort_merge_join::filter::{ get_filter_columns, needs_deferred_filtering, }; use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; -use crate::joins::utils::{JoinFilter, compare_join_arrays}; +use crate::joins::utils::{JoinFilter, JoinKeyComparator}; use crate::metrics::RecordOutput; use crate::spill::spill_manager::SpillManager; use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; @@ -48,12 +48,10 @@ use arrow::compute::{ self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch, interleave, take, take_arrays, }; -use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; +use arrow::datatypes::SchemaRef; use arrow::ipc::reader::StreamReader; use datafusion_common::cast::as_uint64_array; -use datafusion_common::{ - JoinType, NullEquality, Result, exec_err, internal_err, not_impl_err, -}; +use datafusion_common::{JoinType, NullEquality, Result, exec_err, internal_err}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::runtime_env::RuntimeEnv; @@ -351,6 +349,15 @@ pub(super) struct MaterializingSortMergeJoinStream { /// Manages the process of spilling and reading back intermediate data pub spill_manager: SpillManager, + // ======================================================================== + // CACHED COMPARATORS: + // Pre-built comparators to avoid per-row type dispatch in hot loops. + // ======================================================================== + /// Comparator for streamed vs buffered head batch key comparison + pub streamed_buffered_cmp: Option, + /// Comparator for buffered head vs tail batch equality check + pub buffered_equality_cmp: Option, + // ======================================================================== // EXECUTION RESOURCES: // Fields related to managing execution resources and monitoring performance. @@ -803,10 +810,45 @@ impl MaterializingSortMergeJoinStream { reservation, runtime_env, spill_manager, + streamed_buffered_cmp: None, + buffered_equality_cmp: None, streamed_batch_counter: AtomicUsize::new(0), }) } + /// Build a comparator for streamed vs buffered head batch keys. + fn rebuild_streamed_buffered_cmp(&mut self) -> Result<()> { + if self.streamed_batch.join_arrays.is_empty() + || !self.buffered_data.has_buffered_rows() + { + self.streamed_buffered_cmp = None; + return Ok(()); + } + self.streamed_buffered_cmp = Some(JoinKeyComparator::new( + &self.streamed_batch.join_arrays, + &self.buffered_data.head_batch().join_arrays, + &self.sort_options, + self.null_equality, + )?); + Ok(()) + } + + /// Build a comparator for buffered head vs tail batch equality. + fn rebuild_buffered_equality_cmp(&mut self) -> Result<()> { + if self.buffered_data.batches.is_empty() { + self.buffered_equality_cmp = None; + return Ok(()); + } + self.buffered_equality_cmp = Some(JoinKeyComparator::new( + &self.buffered_data.head_batch().join_arrays, + &self.buffered_data.tail_batch().join_arrays, + &self.sort_options, + // is_join_arrays_equal treats both-null as equal + NullEquality::NullEqualsNull, + )?); + Ok(()) + } + /// Number of unfrozen output pairs (used to decide when to freeze + output) fn num_unfrozen_pairs(&self) -> usize { self.streamed_batch.num_output_rows() @@ -870,6 +912,7 @@ impl MaterializingSortMergeJoinStream { self.join_metrics.input_rows().add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + self.rebuild_streamed_buffered_cmp()?; // Every incoming streaming batch should have its unique id // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation self.streamed_batch_counter @@ -937,6 +980,7 @@ impl MaterializingSortMergeJoinStream { match &self.buffered_state { BufferedState::Init => { // pop previous buffered batches + let mut head_changed = false; while !self.buffered_data.batches.is_empty() { let head_batch = self.buffered_data.head_batch(); // If the head batch is fully processed, dequeue it and produce output of it. @@ -947,6 +991,7 @@ impl MaterializingSortMergeJoinStream { { self.produce_buffered_not_matched(&mut buffered_batch)?; self.free_reservation(&buffered_batch)?; + head_changed = true; } } else { // If the head batch is not fully processed, break the loop. @@ -954,6 +999,10 @@ impl MaterializingSortMergeJoinStream { break; } } + if head_changed { + self.streamed_buffered_cmp = None; + self.buffered_equality_cmp = None; + } if self.buffered_data.batches.is_empty() { self.buffered_state = BufferedState::PollingFirst; } else { @@ -980,6 +1029,7 @@ impl MaterializingSortMergeJoinStream { BufferedBatch::new(batch, 0..1, &self.on_buffered); self.allocate_reservation(buffered_batch)?; + self.streamed_buffered_cmp = None; self.buffered_state = BufferedState::PollingRest; } } @@ -988,15 +1038,16 @@ impl MaterializingSortMergeJoinStream { if self.buffered_data.tail_batch().range.end < self.buffered_data.tail_batch().num_rows { + if self.buffered_equality_cmp.is_none() { + self.rebuild_buffered_equality_cmp()?; + } while self.buffered_data.tail_batch().range.end < self.buffered_data.tail_batch().num_rows { - if is_join_arrays_equal( - &self.buffered_data.head_batch().join_arrays, + if self.buffered_equality_cmp.as_ref().unwrap().is_equal( self.buffered_data.head_batch().range.start, - &self.buffered_data.tail_batch().join_arrays, self.buffered_data.tail_batch().range.end, - )? { + ) { self.buffered_data.tail_batch_mut().range.end += 1; } else { self.buffered_state = BufferedState::Ready; @@ -1022,6 +1073,7 @@ impl MaterializingSortMergeJoinStream { &self.on_buffered, ); self.allocate_reservation(buffered_batch)?; + self.buffered_equality_cmp = None; } } } @@ -1038,7 +1090,7 @@ impl MaterializingSortMergeJoinStream { } /// Get comparison result of streamed row and buffered batches - fn compare_streamed_buffered(&self) -> Result { + fn compare_streamed_buffered(&mut self) -> Result { if self.streamed_state == StreamedState::Exhausted { return Ok(Ordering::Greater); } @@ -1046,14 +1098,13 @@ impl MaterializingSortMergeJoinStream { return Ok(Ordering::Less); } - compare_join_arrays( - &self.streamed_batch.join_arrays, + if self.streamed_buffered_cmp.is_none() { + self.rebuild_streamed_buffered_cmp()?; + } + Ok(self.streamed_buffered_cmp.as_ref().unwrap().compare( self.streamed_batch.idx, - &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, - &self.sort_options, - self.null_equality, - ) + )) } /// Produce join and fill output buffer until reaching target batch size @@ -1810,78 +1861,3 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec Result { - let mut is_equal = true; - for (left_array, right_array) in left_arrays.iter().zip(right_arrays) { - macro_rules! compare_value { - ($T:ty) => {{ - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_array = - left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = - right_array.as_any().downcast_ref::<$T>().unwrap(); - if left_array.value(left) != right_array.value(right) { - is_equal = false; - } - } - (true, false) => is_equal = false, - (false, true) => is_equal = false, - _ => {} - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::Utf8View => compare_value!(StringViewArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Binary => compare_value!(BinaryArray), - DataType::BinaryView => compare_value!(BinaryViewArray), - DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), - DataType::LargeBinary => compare_value!(LargeBinaryArray), - DataType::Decimal32(..) => compare_value!(Decimal32Array), - DataType::Decimal64(..) => compare_value!(Decimal64Array), - DataType::Decimal128(..) => compare_value!(Decimal128Array), - DataType::Decimal256(..) => compare_value!(Decimal256Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !is_equal { - return Ok(false); - } - } - Ok(true) -} diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index d3c8ccc11bcb..39a380eed9d4 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -59,6 +59,7 @@ use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; use arrow_ord::cmp::not_distinct; +use arrow_ord::ord::{DynComparator, make_comparator}; use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit}; use datafusion_common::cast::as_boolean_array; use datafusion_common::hash_utils::RandomState; @@ -1822,6 +1823,110 @@ fn eq_dyn_null( } } +/// Pre-built comparator for join key columns that eliminates per-row type +/// dispatch. Wraps `arrow_ord::ord::DynComparator` closures built once per +/// batch pair, used for all row comparisons within those batches. +/// +/// The first key column is stored separately so that single-column joins +/// (the common case) avoid Vec iteration entirely, and multi-column joins +/// short-circuit without entering the loop when the first column is +/// selective. +/// +/// Null handling is baked into the closures at construction time: +/// - `NullEqualsNull`: `make_comparator` returns `Equal` for both-null, which +/// is the desired behavior. Closures are used as-is. +/// - `NullEqualsNothing`: columns where both sides contain nulls get a wrapper +/// that returns `Less` for both-null. Columns where one side has no nulls +/// skip the wrapper since both-null is impossible. +/// +/// Because `NullEqualsNothing` wraps comparators to return `Less` for +/// both-null, `is_equal` will return `false` for both-null rows when that +/// mode is active. Callers needing both-null == equal semantics (e.g., +/// buffered head/tail equality in SMJ) should construct with +/// `NullEqualsNull`. +pub struct JoinKeyComparator { + first: DynComparator, + rest: Vec, +} + +impl JoinKeyComparator { + /// Build comparators for each join key column pair. + pub fn new( + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], + sort_options: &[SortOptions], + null_equality: NullEquality, + ) -> Result { + debug_assert_eq!(left_arrays.len(), right_arrays.len()); + debug_assert_eq!(left_arrays.len(), sort_options.len()); + + let mut iter = left_arrays + .iter() + .zip(right_arrays.iter()) + .zip(sort_options.iter()) + .map(|((l, r), opts)| { + let inner = make_comparator(l.as_ref(), r.as_ref(), *opts)?; + if null_equality == NullEquality::NullEqualsNothing { + let ln = l.logical_nulls().filter(|n| n.null_count() > 0); + let rn = r.logical_nulls().filter(|n| n.null_count() > 0); + match (ln, rn) { + // Both sides have nulls — wrap to override both-null. + (Some(ln), Some(rn)) => Ok(Box::new(move |i, j| { + if ln.is_null(i) && rn.is_null(j) { + Ordering::Less + } else { + inner(i, j) + } + }) + as DynComparator), + // One side has no nulls — both-null impossible, no wrap. + _ => Ok(inner), + } + } else { + Ok(inner) + } + }); + + let first = iter.next().expect("join must have at least one key")?; + let rest = iter.collect::>>()?; + Ok(Self { first, rest }) + } + + /// Compare row `left` (in the left arrays) with row `right` (in the right + /// arrays). Returns the lexicographic ordering across all key columns. + #[inline] + pub fn compare(&self, left: usize, right: usize) -> Ordering { + let ord = (self.first)(left, right); + if ord != Ordering::Equal || self.rest.is_empty() { + return ord; + } + for cmp_fn in &self.rest { + let ord = cmp_fn(left, right); + if ord != Ordering::Equal { + return ord; + } + } + Ordering::Equal + } + + /// Check equality of row `left` (in the left arrays) with row `right` + /// (in the right arrays). Both-null is treated as equal when constructed + /// with `NullEqualsNull`. With `NullEqualsNothing`, both-null returns + /// `false` because the override is baked into the comparators. + #[inline] + pub fn is_equal(&self, left: usize, right: usize) -> bool { + if (self.first)(left, right) != Ordering::Equal { + return false; + } + for cmp_fn in &self.rest { + if cmp_fn(left, right) != Ordering::Equal { + return false; + } + } + true + } +} + /// Get comparison result of two rows of join arrays pub fn compare_join_arrays( left_arrays: &[ArrayRef], @@ -2949,4 +3054,120 @@ mod tests { let result = max_distinct_count(&num_rows, &stats); assert_eq!(result, Exact(0)); } + + #[test] + fn test_join_key_comparator_multi_column() { + let left_a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 2, 3])); + let left_b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])); + let right_a: ArrayRef = Arc::new(Int32Array::from(vec![2, 2, 3, 4])); + let right_b: ArrayRef = Arc::new(StringArray::from(vec!["b", "d", "a", "a"])); + + let opts = vec![SortOptions::default(), SortOptions::default()]; + let cmp = JoinKeyComparator::new( + &[left_a, left_b], + &[right_a, right_b], + &opts, + NullEquality::NullEqualsNull, + ) + .unwrap(); + + // left[0]=(1,"a") vs right[0]=(2,"b") -> Less (first column) + assert_eq!(cmp.compare(0, 0), Ordering::Less); + // left[1]=(2,"b") vs right[0]=(2,"b") -> Equal + assert_eq!(cmp.compare(1, 0), Ordering::Equal); + assert!(cmp.is_equal(1, 0)); + // left[2]=(2,"c") vs right[1]=(2,"d") -> Less (second column) + assert_eq!(cmp.compare(2, 1), Ordering::Less); + // left[3]=(3,"d") vs right[0]=(2,"b") -> Greater + assert_eq!(cmp.compare(3, 0), Ordering::Greater); + } + + #[test] + fn test_join_key_comparator_null_equals_null() { + let left: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, None, Some(2)])); + let right: ArrayRef = + Arc::new(Int32Array::from(vec![None, None, Some(1), Some(2)])); + + let opts = vec![SortOptions { + descending: false, + nulls_first: true, + }]; + let cmp = JoinKeyComparator::new( + &[left], + &[right], + &opts, + NullEquality::NullEqualsNull, + ) + .unwrap(); + + // left[1]=NULL vs right[1]=NULL -> Equal (NullEqualsNull) + assert_eq!(cmp.compare(1, 1), Ordering::Equal); + assert!(cmp.is_equal(1, 1)); + // left[0]=1 vs right[0]=NULL -> Greater (nulls_first, non-null > null) + assert_eq!(cmp.compare(0, 0), Ordering::Greater); + // left[3]=2 vs right[3]=2 -> Equal + assert_eq!(cmp.compare(3, 3), Ordering::Equal); + } + + #[test] + fn test_join_key_comparator_null_equals_nothing() { + let left: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, None, Some(2)])); + let right: ArrayRef = + Arc::new(Int32Array::from(vec![None, None, Some(1), Some(2)])); + + let opts = vec![SortOptions { + descending: false, + nulls_first: true, + }]; + let cmp = JoinKeyComparator::new( + &[left], + &[right], + &opts, + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + // left[1]=NULL vs right[1]=NULL -> Less (NullEqualsNothing) + assert_eq!(cmp.compare(1, 1), Ordering::Less); + // left[0]=1 vs right[0]=NULL -> Greater (nulls_first) + assert_eq!(cmp.compare(0, 0), Ordering::Greater); + // left[3]=2 vs right[3]=2 -> Equal + assert_eq!(cmp.compare(3, 3), Ordering::Equal); + } + + #[test] + fn test_join_key_comparator_nulls_first_ordering() { + let left: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(1)])); + let right: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None])); + + // nulls_first = true: null < non-null + let cmp_nf = JoinKeyComparator::new( + &[Arc::clone(&left)], + &[Arc::clone(&right)], + &[SortOptions { + descending: false, + nulls_first: true, + }], + NullEquality::NullEqualsNull, + ) + .unwrap(); + assert_eq!(cmp_nf.compare(0, 0), Ordering::Less); + assert_eq!(cmp_nf.compare(1, 1), Ordering::Greater); + + // nulls_first = false: null > non-null + let cmp_nl = JoinKeyComparator::new( + &[left], + &[right], + &[SortOptions { + descending: false, + nulls_first: false, + }], + NullEquality::NullEqualsNull, + ) + .unwrap(); + assert_eq!(cmp_nl.compare(0, 0), Ordering::Greater); + assert_eq!(cmp_nl.compare(1, 1), Ordering::Less); + } }