From b6f57cfc640dbe7451699a144787bad36b197886 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 26 Mar 2026 16:53:31 -0400 Subject: [PATCH 1/8] stash with dyncomparator changes --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 80 +++---- .../piecewise_merge_join/classic_join.rs | 22 +- .../joins/semi_anti_sort_merge_join/stream.rs | 172 +++++++------- .../src/joins/sort_merge_join/stream.rs | 156 ++++++------- datafusion/physical-plan/src/joins/utils.rs | 215 ++++++++++++++++++ 5 files changed, 426 insertions(+), 219 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 669b98e39fec..3b61418aab73 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -98,7 +98,7 @@ async fn test_inner_join_1k_filtered() { JoinType::Inner, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -112,7 +112,7 @@ async fn test_inner_join_1k() { JoinType::Inner, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -126,7 +126,7 @@ async fn test_left_join_1k() { JoinType::Left, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -140,7 +140,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -154,7 +154,7 @@ async fn test_right_join_1k() { JoinType::Right, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -168,7 +168,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -182,7 +182,7 @@ async fn test_full_join_1k() { JoinType::Full, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -196,7 +196,7 @@ async fn test_full_join_1k_filtered() { JoinType::Full, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[NljHj, HjSmj], false) + .run_test(&[HjSmj], false) .await } } @@ -210,7 +210,7 @@ async fn test_left_semi_join_1k() { JoinType::LeftSemi, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -224,7 +224,7 @@ async fn test_left_semi_join_1k_filtered() { JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -238,7 +238,7 @@ async fn test_right_semi_join_1k() { JoinType::RightSemi, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -252,7 +252,7 @@ async fn test_right_semi_join_1k_filtered() { JoinType::RightSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -266,7 +266,7 @@ async fn test_left_anti_join_1k() { JoinType::LeftAnti, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -280,7 +280,7 @@ async fn test_left_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -294,7 +294,7 @@ async fn test_right_anti_join_1k() { JoinType::RightAnti, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -308,7 +308,7 @@ async fn test_right_anti_join_1k_filtered() { JoinType::RightAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -322,7 +322,7 @@ async fn test_left_mark_join_1k() { JoinType::LeftMark, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -336,7 +336,7 @@ async fn test_left_mark_join_1k_filtered() { JoinType::LeftMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -351,7 +351,7 @@ async fn test_right_mark_join_1k() { JoinType::RightMark, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -365,7 +365,7 @@ async fn test_right_mark_join_1k_filtered() { JoinType::RightMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -379,7 +379,7 @@ async fn test_inner_join_1k_binary_filtered() { JoinType::Inner, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -393,7 +393,7 @@ async fn test_inner_join_1k_binary() { JoinType::Inner, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -407,7 +407,7 @@ async fn test_left_join_1k_binary() { JoinType::Left, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -421,7 +421,7 @@ async fn test_left_join_1k_binary_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -435,7 +435,7 @@ async fn test_right_join_1k_binary() { JoinType::Right, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -449,7 +449,7 @@ async fn test_right_join_1k_binary_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -463,7 +463,7 @@ async fn test_full_join_1k_binary() { JoinType::Full, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -477,7 +477,7 @@ async fn test_full_join_1k_binary_filtered() { JoinType::Full, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[NljHj, HjSmj], false) + .run_test(&[HjSmj], false) .await } } @@ -491,7 +491,7 @@ async fn test_left_semi_join_1k_binary() { JoinType::LeftSemi, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -505,7 +505,7 @@ async fn test_left_semi_join_1k_binary_filtered() { JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -519,7 +519,7 @@ async fn test_right_semi_join_1k_binary() { JoinType::RightSemi, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -533,7 +533,7 @@ async fn test_right_semi_join_1k_binary_filtered() { JoinType::RightSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -547,7 +547,7 @@ async fn test_left_anti_join_1k_binary() { JoinType::LeftAnti, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -561,7 +561,7 @@ async fn test_left_anti_join_1k_binary_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -575,7 +575,7 @@ async fn test_right_anti_join_1k_binary() { JoinType::RightAnti, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -589,7 +589,7 @@ async fn test_right_anti_join_1k_binary_filtered() { JoinType::RightAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -603,7 +603,7 @@ async fn test_left_mark_join_1k_binary() { JoinType::LeftMark, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -617,7 +617,7 @@ async fn test_left_mark_join_1k_binary_filtered() { JoinType::LeftMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -632,7 +632,7 @@ async fn test_right_mark_join_1k_binary() { JoinType::RightMark, None, ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } @@ -646,7 +646,7 @@ async fn test_right_mark_join_1k_binary_filtered() { JoinType::RightMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj, NljHj], false) + .run_test(&[HjSmj], false) .await } } diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs index ee1ae3708961..da0d21f046da 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join/classic_join.rs @@ -38,7 +38,7 @@ use crate::handle_state; use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState}; use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; -use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; +use crate::joins::utils::{JoinKeyComparator, get_final_indices_from_shared_bitmap}; pub(super) enum PiecewiseMergeJoinStreamState { WaitBufferedSide, @@ -460,6 +460,14 @@ fn resolve_classic_join( let buffered_len = buffered_side.buffered_data.values().len(); let stream_values = stream_batch.compare_key_values(); + // Build comparator once for the batch pair + let cmp = JoinKeyComparator::new( + &[Arc::clone(&stream_values[0])], + &[Arc::clone(buffered_side.buffered_data.values())], + &[sort_options], + NullEquality::NullEqualsNothing, + )?; + let mut buffer_idx = batch_process_state.start_buffer_idx; let mut stream_idx = batch_process_state.start_stream_idx; @@ -475,17 +483,7 @@ fn resolve_classic_join( // in the previous stream row. for row_idx in stream_idx..stream_batch.batch.num_rows() { while buffer_idx < buffered_len { - let compare = { - let buffered_values = buffered_side.buffered_data.values(); - compare_join_arrays( - &[Arc::clone(&stream_values[0])], - row_idx, - &[Arc::clone(buffered_values)], - buffer_idx, - &[sort_options], - NullEquality::NullEqualsNothing, - )? - }; + let compare = cmp.compare(row_idx, buffer_idx); // If we find a match we append all indices and move to the next stream row index match operator { diff --git a/datafusion/physical-plan/src/joins/semi_anti_sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/semi_anti_sort_merge_join/stream.rs index 40e2022d4154..dc75651ee5b6 100644 --- a/datafusion/physical-plan/src/joins/semi_anti_sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/semi_anti_sort_merge_join/stream.rs @@ -123,7 +123,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::RecordBatchStream; -use crate::joins::utils::{JoinFilter, compare_join_arrays}; +use crate::joins::utils::{JoinFilter, JoinKeyComparator, compare_join_arrays}; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder, }; @@ -159,48 +159,26 @@ fn evaluate_join_keys( } /// Find the first index in `key_arrays` starting from `from` where the key -/// differs from the key at `from`. Uses `compare_join_arrays` for zero-alloc -/// ordinal comparison. +/// differs from the key at `from`. Uses a pre-built `JoinKeyComparator` for +/// zero-alloc ordinal comparison without per-row type dispatch. /// /// Optimized for join workloads: checks adjacent and boundary keys before /// falling back to binary search, since most key groups are small (often 1). -fn find_key_group_end( - key_arrays: &[ArrayRef], - from: usize, - len: usize, - sort_options: &[SortOptions], - null_equality: NullEquality, -) -> Result { +fn find_key_group_end(cmp: &JoinKeyComparator, from: usize, len: usize) -> usize { let next = from + 1; if next >= len { - return Ok(len); + return len; } // Fast path: single-row group (common with unique keys). - if compare_join_arrays( - key_arrays, - from, - key_arrays, - next, - sort_options, - null_equality, - )? != Ordering::Equal - { - return Ok(next); + if cmp.compare(from, next) != Ordering::Equal { + return next; } // Check if the entire remaining batch shares this key. let last = len - 1; - if compare_join_arrays( - key_arrays, - from, - key_arrays, - last, - sort_options, - null_equality, - )? == Ordering::Equal - { - return Ok(len); + if cmp.compare(from, last) == Ordering::Equal { + return len; } // Binary search the interior: key at `next` matches, key at `last` doesn't. @@ -208,21 +186,13 @@ fn find_key_group_end( let mut hi = last; while lo < hi { let mid = lo + (hi - lo) / 2; - if compare_join_arrays( - key_arrays, - from, - key_arrays, - mid, - sort_options, - null_equality, - )? == Ordering::Equal - { + if cmp.compare(from, mid) == Ordering::Equal { lo = mid + 1; } else { hi = mid; } } - Ok(lo) + lo } /// When an outer key group spans a batch boundary, the boundary loop emits @@ -321,6 +291,14 @@ pub(crate) struct SemiAntiSortMergeJoinStream { runtime_env: Arc, inner_buffer_size: usize, + // Cached comparators — pre-built to avoid per-row type dispatch. + /// Comparator for outer vs inner key comparison + outer_inner_cmp: Option, + /// Comparator for outer self-comparison (find_key_group_end on outer) + outer_self_cmp: Option, + /// Comparator for inner self-comparison (find_key_group_end on inner) + inner_self_cmp: Option, + // True once the current outer batch has been emitted. The Equal // branch's inner loops call emit then `ready!(poll_next_outer_batch)`. // If that poll returns Pending, poll_join re-enters from the top @@ -392,6 +370,9 @@ impl SemiAntiSortMergeJoinStream { spill_manager, runtime_env, inner_buffer_size: 0, + outer_inner_cmp: None, + outer_self_cmp: None, + inner_self_cmp: None, batch_emitted: false, }) } @@ -404,6 +385,45 @@ impl SemiAntiSortMergeJoinStream { Ok(()) } + /// Get or build the outer vs inner key comparator. + fn get_outer_inner_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.outer_inner_cmp.is_none() { + self.outer_inner_cmp = Some(JoinKeyComparator::new( + &self.outer_key_arrays, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.outer_inner_cmp.as_ref().unwrap()) + } + + /// Get or build the outer self-comparison comparator. + fn get_outer_self_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.outer_self_cmp.is_none() { + self.outer_self_cmp = Some(JoinKeyComparator::new( + &self.outer_key_arrays, + &self.outer_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.outer_self_cmp.as_ref().unwrap()) + } + + /// Get or build the inner self-comparison comparator. + fn get_inner_self_cmp(&mut self) -> Result<&JoinKeyComparator> { + if self.inner_self_cmp.is_none() { + self.inner_self_cmp = Some(JoinKeyComparator::new( + &self.inner_key_arrays, + &self.inner_key_arrays, + &self.sort_options, + self.null_equality, + )?); + } + Ok(self.inner_self_cmp.as_ref().unwrap()) + } + /// Spill the in-memory inner key buffer to disk and clear it. fn spill_inner_key_buffer(&mut self) -> Result<()> { let spill_file = self @@ -447,6 +467,8 @@ impl SemiAntiSortMergeJoinStream { self.outer_batch = Some(batch); self.outer_offset = 0; self.outer_key_arrays = keys; + self.outer_inner_cmp = None; + self.outer_self_cmp = None; self.batch_emitted = false; self.matched = BooleanBufferBuilder::new(batch_num_rows); self.matched.append_n(batch_num_rows, false); @@ -473,6 +495,8 @@ impl SemiAntiSortMergeJoinStream { self.inner_batch = Some(batch); self.inner_offset = 0; self.inner_key_arrays = keys; + self.outer_inner_cmp = None; + self.inner_self_cmp = None; return Poll::Ready(Ok(true)); } } @@ -513,13 +537,12 @@ impl SemiAntiSortMergeJoinStream { let outer_batch = self.outer_batch.as_ref().unwrap(); let num_outer = outer_batch.num_rows(); + self.get_outer_self_cmp()?; let outer_group_end = find_key_group_end( - &self.outer_key_arrays, + self.outer_self_cmp.as_ref().unwrap(), self.outer_offset, num_outer, - &self.sort_options, - self.null_equality, - )?; + ); for i in self.outer_offset..outer_group_end { self.matched.set_bit(i, true); @@ -542,13 +565,12 @@ impl SemiAntiSortMergeJoinStream { }; let num_inner = inner_batch.num_rows(); + self.get_inner_self_cmp()?; let group_end = find_key_group_end( - &self.inner_key_arrays, + self.inner_self_cmp.as_ref().unwrap(), self.inner_offset, num_inner, - &self.sort_options, - self.null_equality, - )?; + ); if group_end < num_inner { self.inner_offset = group_end; @@ -600,20 +622,19 @@ impl SemiAntiSortMergeJoinStream { } loop { - let inner_batch = match &self.inner_batch { - Some(b) => b, - None => return Poll::Ready(Ok(true)), - }; - let num_inner = inner_batch.num_rows(); + if self.inner_batch.is_none() { + return Poll::Ready(Ok(true)); + } + let num_inner = self.inner_batch.as_ref().unwrap().num_rows(); + self.get_inner_self_cmp()?; let group_end = find_key_group_end( - &self.inner_key_arrays, + self.inner_self_cmp.as_ref().unwrap(), self.inner_offset, num_inner, - &self.sort_options, - self.null_equality, - )?; + ); if !resume_from_poll { + let inner_batch = self.inner_batch.as_ref().unwrap(); let slice = inner_batch.slice(self.inner_offset, group_end - self.inner_offset); self.inner_buffer_size += slice.get_array_memory_size(); @@ -677,6 +698,7 @@ impl SemiAntiSortMergeJoinStream { /// key group, evaluates the filter against the outer key group and ORs /// the results into the matched bitset using u64-chunked bitwise ops. fn process_key_match_with_filter(&mut self) -> Result<()> { + self.get_outer_self_cmp()?; let filter = self.filter.as_ref().unwrap(); let outer_batch = self.outer_batch.as_ref().unwrap(); let num_outer = outer_batch.num_rows(); @@ -696,12 +718,10 @@ impl SemiAntiSortMergeJoinStream { ); let outer_group_end = find_key_group_end( - &self.outer_key_arrays, + self.outer_self_cmp.as_ref().unwrap(), self.outer_offset, num_outer, - &self.sort_options, - self.null_equality, - )?; + ); let outer_group_len = outer_group_end - self.outer_offset; let outer_slice = outer_batch.slice(self.outer_offset, outer_group_len); @@ -917,34 +937,30 @@ impl SemiAntiSortMergeJoinStream { } // 4. Compare keys at current positions - let cmp = compare_join_arrays( - &self.outer_key_arrays, - self.outer_offset, - &self.inner_key_arrays, - self.inner_offset, - &self.sort_options, - self.null_equality, - )?; + self.get_outer_inner_cmp()?; + let cmp = self + .outer_inner_cmp + .as_ref() + .unwrap() + .compare(self.outer_offset, self.inner_offset); match cmp { Ordering::Less => { + self.get_outer_self_cmp()?; let group_end = find_key_group_end( - &self.outer_key_arrays, + self.outer_self_cmp.as_ref().unwrap(), self.outer_offset, num_outer, - &self.sort_options, - self.null_equality, - )?; + ); self.outer_offset = group_end; } Ordering::Greater => { + self.get_inner_self_cmp()?; let group_end = find_key_group_end( - &self.inner_key_arrays, + self.inner_self_cmp.as_ref().unwrap(), self.inner_offset, num_inner, - &self.sort_options, - self.null_equality, - )?; + ); if group_end >= num_inner { let saved_keys = slice_keys(&self.inner_key_arrays, num_inner - 1); diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index 4dcbe1f64799..bd5ef6bdb3ed 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -38,7 +38,7 @@ use crate::joins::sort_merge_join::filter::{ get_filter_columns, needs_deferred_filtering, }; use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; -use crate::joins::utils::{JoinFilter, compare_join_arrays}; +use crate::joins::utils::{JoinFilter, JoinKeyComparator}; use crate::metrics::RecordOutput; use crate::spill::spill_manager::SpillManager; use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream}; @@ -48,11 +48,11 @@ use arrow::compute::{ self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch, is_not_null, take, take_arrays, }; -use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; +use arrow::datatypes::SchemaRef; use arrow::ipc::reader::StreamReader; use datafusion_common::config::SpillCompression; use datafusion_common::{ - HashSet, JoinType, NullEquality, Result, exec_err, internal_err, not_impl_err, + HashSet, JoinType, NullEquality, Result, exec_err, internal_err, }; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::MemoryReservation; @@ -354,6 +354,15 @@ pub(super) struct SortMergeJoinStream { /// Manages the process of spilling and reading back intermediate data pub spill_manager: SpillManager, + // ======================================================================== + // CACHED COMPARATORS: + // Pre-built comparators to avoid per-row type dispatch in hot loops. + // ======================================================================== + /// Comparator for streamed vs buffered head batch key comparison + pub streamed_buffered_cmp: Option, + /// Comparator for buffered head vs tail batch equality check + pub buffered_equality_cmp: Option, + // ======================================================================== // EXECUTION RESOURCES: // Fields related to managing execution resources and monitoring performance. @@ -793,10 +802,45 @@ impl SortMergeJoinStream { reservation, runtime_env, spill_manager, + streamed_buffered_cmp: None, + buffered_equality_cmp: None, streamed_batch_counter: AtomicUsize::new(0), }) } + /// Build a comparator for streamed vs buffered head batch keys. + fn rebuild_streamed_buffered_cmp(&mut self) -> Result<()> { + if self.streamed_batch.join_arrays.is_empty() + || !self.buffered_data.has_buffered_rows() + { + self.streamed_buffered_cmp = None; + return Ok(()); + } + self.streamed_buffered_cmp = Some(JoinKeyComparator::new( + &self.streamed_batch.join_arrays, + &self.buffered_data.head_batch().join_arrays, + &self.sort_options, + self.null_equality, + )?); + Ok(()) + } + + /// Build a comparator for buffered head vs tail batch equality. + fn rebuild_buffered_equality_cmp(&mut self) -> Result<()> { + if self.buffered_data.batches.is_empty() { + self.buffered_equality_cmp = None; + return Ok(()); + } + self.buffered_equality_cmp = Some(JoinKeyComparator::new( + &self.buffered_data.head_batch().join_arrays, + &self.buffered_data.tail_batch().join_arrays, + &self.sort_options, + // is_join_arrays_equal treats both-null as equal + NullEquality::NullEqualsNull, + )?); + Ok(()) + } + /// Number of unfrozen output pairs (used to decide when to freeze + output) fn num_unfrozen_pairs(&self) -> usize { self.streamed_batch.num_output_rows() @@ -860,6 +904,7 @@ impl SortMergeJoinStream { self.join_metrics.input_rows().add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + self.rebuild_streamed_buffered_cmp()?; // Every incoming streaming batch should have its unique id // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation self.streamed_batch_counter @@ -927,6 +972,7 @@ impl SortMergeJoinStream { match &self.buffered_state { BufferedState::Init => { // pop previous buffered batches + let mut head_changed = false; while !self.buffered_data.batches.is_empty() { let head_batch = self.buffered_data.head_batch(); // If the head batch is fully processed, dequeue it and produce output of it. @@ -937,6 +983,7 @@ impl SortMergeJoinStream { { self.produce_buffered_not_matched(&mut buffered_batch)?; self.free_reservation(&buffered_batch)?; + head_changed = true; } } else { // If the head batch is not fully processed, break the loop. @@ -944,6 +991,10 @@ impl SortMergeJoinStream { break; } } + if head_changed { + self.streamed_buffered_cmp = None; + self.buffered_equality_cmp = None; + } if self.buffered_data.batches.is_empty() { self.buffered_state = BufferedState::PollingFirst; } else { @@ -970,6 +1021,7 @@ impl SortMergeJoinStream { BufferedBatch::new(batch, 0..1, &self.on_buffered); self.allocate_reservation(buffered_batch)?; + self.streamed_buffered_cmp = None; self.buffered_state = BufferedState::PollingRest; } } @@ -978,15 +1030,16 @@ impl SortMergeJoinStream { if self.buffered_data.tail_batch().range.end < self.buffered_data.tail_batch().num_rows { + if self.buffered_equality_cmp.is_none() { + self.rebuild_buffered_equality_cmp()?; + } while self.buffered_data.tail_batch().range.end < self.buffered_data.tail_batch().num_rows { - if is_join_arrays_equal( - &self.buffered_data.head_batch().join_arrays, + if self.buffered_equality_cmp.as_ref().unwrap().is_equal( self.buffered_data.head_batch().range.start, - &self.buffered_data.tail_batch().join_arrays, self.buffered_data.tail_batch().range.end, - )? { + ) { self.buffered_data.tail_batch_mut().range.end += 1; } else { self.buffered_state = BufferedState::Ready; @@ -1012,6 +1065,7 @@ impl SortMergeJoinStream { &self.on_buffered, ); self.allocate_reservation(buffered_batch)?; + self.buffered_equality_cmp = None; } } } @@ -1028,7 +1082,7 @@ impl SortMergeJoinStream { } /// Get comparison result of streamed row and buffered batches - fn compare_streamed_buffered(&self) -> Result { + fn compare_streamed_buffered(&mut self) -> Result { if self.streamed_state == StreamedState::Exhausted { return Ok(Ordering::Greater); } @@ -1036,14 +1090,13 @@ impl SortMergeJoinStream { return Ok(Ordering::Less); } - compare_join_arrays( - &self.streamed_batch.join_arrays, + if self.streamed_buffered_cmp.is_none() { + self.rebuild_streamed_buffered_cmp()?; + } + Ok(self.streamed_buffered_cmp.as_ref().unwrap().compare( self.streamed_batch.idx, - &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, - &self.sort_options, - self.null_equality, - ) + )) } /// Produce join and fill output buffer until reaching target batch size @@ -1763,78 +1816,3 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec Result { - let mut is_equal = true; - for (left_array, right_array) in left_arrays.iter().zip(right_arrays) { - macro_rules! compare_value { - ($T:ty) => {{ - match (left_array.is_null(left), right_array.is_null(right)) { - (false, false) => { - let left_array = - left_array.as_any().downcast_ref::<$T>().unwrap(); - let right_array = - right_array.as_any().downcast_ref::<$T>().unwrap(); - if left_array.value(left) != right_array.value(right) { - is_equal = false; - } - } - (true, false) => is_equal = false, - (false, true) => is_equal = false, - _ => {} - } - }}; - } - - match left_array.data_type() { - DataType::Null => {} - DataType::Boolean => compare_value!(BooleanArray), - DataType::Int8 => compare_value!(Int8Array), - DataType::Int16 => compare_value!(Int16Array), - DataType::Int32 => compare_value!(Int32Array), - DataType::Int64 => compare_value!(Int64Array), - DataType::UInt8 => compare_value!(UInt8Array), - DataType::UInt16 => compare_value!(UInt16Array), - DataType::UInt32 => compare_value!(UInt32Array), - DataType::UInt64 => compare_value!(UInt64Array), - DataType::Float32 => compare_value!(Float32Array), - DataType::Float64 => compare_value!(Float64Array), - DataType::Utf8 => compare_value!(StringArray), - DataType::Utf8View => compare_value!(StringViewArray), - DataType::LargeUtf8 => compare_value!(LargeStringArray), - DataType::Binary => compare_value!(BinaryArray), - DataType::BinaryView => compare_value!(BinaryViewArray), - DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), - DataType::LargeBinary => compare_value!(LargeBinaryArray), - DataType::Decimal32(..) => compare_value!(Decimal32Array), - DataType::Decimal64(..) => compare_value!(Decimal64Array), - DataType::Decimal128(..) => compare_value!(Decimal128Array), - DataType::Decimal256(..) => compare_value!(Decimal256Array), - DataType::Timestamp(time_unit, None) => match time_unit { - TimeUnit::Second => compare_value!(TimestampSecondArray), - TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), - TimeUnit::Microsecond => compare_value!(TimestampMicrosecondArray), - TimeUnit::Nanosecond => compare_value!(TimestampNanosecondArray), - }, - DataType::Date32 => compare_value!(Date32Array), - DataType::Date64 => compare_value!(Date64Array), - dt => { - return not_impl_err!( - "Unsupported data type in sort merge join comparator: {}", - dt - ); - } - } - if !is_equal { - return Ok(false); - } - } - Ok(true) -} diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 47cb118aee2b..055177d33cf3 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -58,6 +58,7 @@ use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; use arrow_ord::cmp::not_distinct; +use arrow_ord::ord::{DynComparator, make_comparator}; use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit}; use datafusion_common::cast::as_boolean_array; use datafusion_common::hash_utils::RandomState; @@ -1827,6 +1828,101 @@ fn eq_dyn_null( } } +/// Pre-built comparator for join key columns that eliminates per-row type +/// dispatch. Wraps `arrow_ord::ord::DynComparator` closures built once per +/// batch pair, used for all row comparisons within those batches. +pub struct JoinKeyComparator { + comparators: Vec, + /// Only populated when null_equality == NullEqualsNothing + left_nulls: Vec>, + right_nulls: Vec>, + null_equality: NullEquality, +} + +impl JoinKeyComparator { + /// Build comparators for each join key column pair. The `sort_options` + /// slice must have the same length as the array slices. + pub fn new( + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], + sort_options: &[SortOptions], + null_equality: NullEquality, + ) -> Result { + let comparators = left_arrays + .iter() + .zip(right_arrays.iter()) + .zip(sort_options.iter()) + .map(|((l, r), opts)| Ok(make_comparator(l.as_ref(), r.as_ref(), *opts)?)) + .collect::>>()?; + + let (left_nulls, right_nulls) = + if null_equality == NullEquality::NullEqualsNothing { + let ln = left_arrays + .iter() + .map(|a| a.logical_nulls().filter(|n| n.null_count() > 0)) + .collect(); + let rn = right_arrays + .iter() + .map(|a| a.logical_nulls().filter(|n| n.null_count() > 0)) + .collect(); + (ln, rn) + } else { + (vec![], vec![]) + }; + + Ok(Self { + comparators, + left_nulls, + right_nulls, + null_equality, + }) + } + + /// Compare row `left` (in the left arrays) with row `right` (in the right + /// arrays). Returns the lexicographic ordering across all key columns. + #[inline] + pub fn compare(&self, left: usize, right: usize) -> Ordering { + if self.null_equality == NullEquality::NullEqualsNothing { + for (idx, cmp_fn) in self.comparators.iter().enumerate() { + // Override both-null: make_comparator returns Equal but + // NullEqualsNothing semantics require Less. + if let (Some(ln), Some(rn)) = + (&self.left_nulls[idx], &self.right_nulls[idx]) + && ln.is_null(left) + && rn.is_null(right) + { + return Ordering::Less; + } + let ord = cmp_fn(left, right); + if ord != Ordering::Equal { + return ord; + } + } + } else { + for cmp_fn in &self.comparators { + let ord = cmp_fn(left, right); + if ord != Ordering::Equal { + return ord; + } + } + } + Ordering::Equal + } + + /// Check equality of row `left` (in the left arrays) with row `right` + /// (in the right arrays). Both-null is treated as equal regardless of + /// `null_equality` — this matches `is_join_arrays_equal` semantics. + #[inline] + pub fn is_equal(&self, left: usize, right: usize) -> bool { + for cmp_fn in &self.comparators { + if cmp_fn(left, right) != Ordering::Equal { + return false; + } + } + true + } +} + /// Get comparison result of two rows of join arrays pub fn compare_join_arrays( left_arrays: &[ArrayRef], @@ -2954,4 +3050,123 @@ mod tests { let result = max_distinct_count(&num_rows, &stats); assert_eq!(result, Exact(0)); } + + #[test] + fn test_join_key_comparator_multi_column() { + let left_a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 2, 3])); + let left_b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d"])); + let right_a: ArrayRef = Arc::new(Int32Array::from(vec![2, 2, 3, 4])); + let right_b: ArrayRef = Arc::new(StringArray::from(vec!["b", "d", "a", "a"])); + + let opts = vec![SortOptions::default(), SortOptions::default()]; + let cmp = JoinKeyComparator::new( + &[left_a, left_b], + &[right_a, right_b], + &opts, + NullEquality::NullEqualsNull, + ) + .unwrap(); + + // left[0]=(1,"a") vs right[0]=(2,"b") → Less (first column) + assert_eq!(cmp.compare(0, 0), Ordering::Less); + // left[1]=(2,"b") vs right[0]=(2,"b") → Equal + assert_eq!(cmp.compare(1, 0), Ordering::Equal); + assert!(cmp.is_equal(1, 0)); + // left[2]=(2,"c") vs right[1]=(2,"d") → Less (second column) + assert_eq!(cmp.compare(2, 1), Ordering::Less); + // left[3]=(3,"d") vs right[0]=(2,"b") → Greater + assert_eq!(cmp.compare(3, 0), Ordering::Greater); + } + + #[test] + fn test_join_key_comparator_null_equals_null() { + let left: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, None, Some(2)])); + let right: ArrayRef = + Arc::new(Int32Array::from(vec![None, None, Some(1), Some(2)])); + + let opts = vec![SortOptions { + descending: false, + nulls_first: true, + }]; + let cmp = JoinKeyComparator::new( + &[left], + &[right], + &opts, + NullEquality::NullEqualsNull, + ) + .unwrap(); + + // left[1]=NULL vs right[1]=NULL → Equal (NullEqualsNull) + assert_eq!(cmp.compare(1, 1), Ordering::Equal); + assert!(cmp.is_equal(1, 1)); + // left[0]=1 vs right[0]=NULL → Greater (nulls_first, non-null > null) + assert_eq!(cmp.compare(0, 0), Ordering::Greater); + // left[3]=2 vs right[3]=2 → Equal + assert_eq!(cmp.compare(3, 3), Ordering::Equal); + } + + #[test] + fn test_join_key_comparator_null_equals_nothing() { + let left: ArrayRef = + Arc::new(Int32Array::from(vec![Some(1), None, None, Some(2)])); + let right: ArrayRef = + Arc::new(Int32Array::from(vec![None, None, Some(1), Some(2)])); + + let opts = vec![SortOptions { + descending: false, + nulls_first: true, + }]; + let cmp = JoinKeyComparator::new( + &[left], + &[right], + &opts, + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + // left[1]=NULL vs right[1]=NULL → Less (NullEqualsNothing) + assert_eq!(cmp.compare(1, 1), Ordering::Less); + // left[0]=1 vs right[0]=NULL → Greater (nulls_first) + assert_eq!(cmp.compare(0, 0), Ordering::Greater); + // left[3]=2 vs right[3]=2 → Equal + assert_eq!(cmp.compare(3, 3), Ordering::Equal); + + // is_equal always treats both-null as equal + assert!(cmp.is_equal(1, 1)); + } + + #[test] + fn test_join_key_comparator_nulls_first_ordering() { + let left: ArrayRef = Arc::new(Int32Array::from(vec![None, Some(1)])); + let right: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None])); + + // nulls_first = true: null < non-null + let cmp_nf = JoinKeyComparator::new( + &[Arc::clone(&left)], + &[Arc::clone(&right)], + &[SortOptions { + descending: false, + nulls_first: true, + }], + NullEquality::NullEqualsNull, + ) + .unwrap(); + assert_eq!(cmp_nf.compare(0, 0), Ordering::Less); + assert_eq!(cmp_nf.compare(1, 1), Ordering::Greater); + + // nulls_first = false: null > non-null + let cmp_nl = JoinKeyComparator::new( + &[left], + &[right], + &[SortOptions { + descending: false, + nulls_first: false, + }], + NullEquality::NullEqualsNull, + ) + .unwrap(); + assert_eq!(cmp_nl.compare(0, 0), Ordering::Greater); + assert_eq!(cmp_nl.compare(1, 1), Ordering::Less); + } } From b6ac26e79979429bd96c25b3ad20c93e6d95bbd9 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 26 Mar 2026 21:34:06 -0400 Subject: [PATCH 2/8] fix inadvertent commit --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 80 +++++++++---------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 3b61418aab73..669b98e39fec 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -98,7 +98,7 @@ async fn test_inner_join_1k_filtered() { JoinType::Inner, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -112,7 +112,7 @@ async fn test_inner_join_1k() { JoinType::Inner, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -126,7 +126,7 @@ async fn test_left_join_1k() { JoinType::Left, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -140,7 +140,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -154,7 +154,7 @@ async fn test_right_join_1k() { JoinType::Right, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -168,7 +168,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -182,7 +182,7 @@ async fn test_full_join_1k() { JoinType::Full, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -196,7 +196,7 @@ async fn test_full_join_1k_filtered() { JoinType::Full, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[NljHj, HjSmj], false) .await } } @@ -210,7 +210,7 @@ async fn test_left_semi_join_1k() { JoinType::LeftSemi, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -224,7 +224,7 @@ async fn test_left_semi_join_1k_filtered() { JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -238,7 +238,7 @@ async fn test_right_semi_join_1k() { JoinType::RightSemi, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -252,7 +252,7 @@ async fn test_right_semi_join_1k_filtered() { JoinType::RightSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -266,7 +266,7 @@ async fn test_left_anti_join_1k() { JoinType::LeftAnti, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -280,7 +280,7 @@ async fn test_left_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -294,7 +294,7 @@ async fn test_right_anti_join_1k() { JoinType::RightAnti, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -308,7 +308,7 @@ async fn test_right_anti_join_1k_filtered() { JoinType::RightAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -322,7 +322,7 @@ async fn test_left_mark_join_1k() { JoinType::LeftMark, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -336,7 +336,7 @@ async fn test_left_mark_join_1k_filtered() { JoinType::LeftMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -351,7 +351,7 @@ async fn test_right_mark_join_1k() { JoinType::RightMark, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -365,7 +365,7 @@ async fn test_right_mark_join_1k_filtered() { JoinType::RightMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -379,7 +379,7 @@ async fn test_inner_join_1k_binary_filtered() { JoinType::Inner, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -393,7 +393,7 @@ async fn test_inner_join_1k_binary() { JoinType::Inner, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -407,7 +407,7 @@ async fn test_left_join_1k_binary() { JoinType::Left, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -421,7 +421,7 @@ async fn test_left_join_1k_binary_filtered() { JoinType::Left, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -435,7 +435,7 @@ async fn test_right_join_1k_binary() { JoinType::Right, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -449,7 +449,7 @@ async fn test_right_join_1k_binary_filtered() { JoinType::Right, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -463,7 +463,7 @@ async fn test_full_join_1k_binary() { JoinType::Full, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -477,7 +477,7 @@ async fn test_full_join_1k_binary_filtered() { JoinType::Full, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[NljHj, HjSmj], false) .await } } @@ -491,7 +491,7 @@ async fn test_left_semi_join_1k_binary() { JoinType::LeftSemi, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -505,7 +505,7 @@ async fn test_left_semi_join_1k_binary_filtered() { JoinType::LeftSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -519,7 +519,7 @@ async fn test_right_semi_join_1k_binary() { JoinType::RightSemi, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -533,7 +533,7 @@ async fn test_right_semi_join_1k_binary_filtered() { JoinType::RightSemi, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -547,7 +547,7 @@ async fn test_left_anti_join_1k_binary() { JoinType::LeftAnti, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -561,7 +561,7 @@ async fn test_left_anti_join_1k_binary_filtered() { JoinType::LeftAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -575,7 +575,7 @@ async fn test_right_anti_join_1k_binary() { JoinType::RightAnti, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -589,7 +589,7 @@ async fn test_right_anti_join_1k_binary_filtered() { JoinType::RightAnti, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -603,7 +603,7 @@ async fn test_left_mark_join_1k_binary() { JoinType::LeftMark, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -617,7 +617,7 @@ async fn test_left_mark_join_1k_binary_filtered() { JoinType::LeftMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -632,7 +632,7 @@ async fn test_right_mark_join_1k_binary() { JoinType::RightMark, None, ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } @@ -646,7 +646,7 @@ async fn test_right_mark_join_1k_binary_filtered() { JoinType::RightMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[HjSmj], false) + .run_test(&[HjSmj, NljHj], false) .await } } From a3a832c96e752ec84f20380bf6062e62ac562abf Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 27 Mar 2026 08:39:43 -0400 Subject: [PATCH 3/8] scale benchmarks --- benchmarks/src/smj.rs | 56 +++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/benchmarks/src/smj.rs b/benchmarks/src/smj.rs index d782762a1be4..c6f56ffcc997 100644 --- a/benchmarks/src/smj.rs +++ b/benchmarks/src/smj.rs @@ -60,13 +60,13 @@ pub struct RunOpt { /// - Key cardinality (rows per key) /// - Filter selectivity (if applicable) const SMJ_QUERIES: &[&str] = &[ - // Q1: INNER 100K x 100K | 1:1 + // Q1: INNER 1M x 1M | 1:1 r#" WITH t1_sorted AS ( - SELECT value as key FROM range(100000) ORDER BY value + SELECT value as key FROM range(1000000) ORDER BY value ), t2_sorted AS ( - SELECT value as key FROM range(100000) ORDER BY value + SELECT value as key FROM range(1000000) ORDER BY value ) SELECT t1_sorted.key as k1, t2_sorted.key as k2 FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key @@ -164,16 +164,16 @@ const SMJ_QUERIES: &[&str] = &[ FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key WHERE t2_sorted.data IS NULL OR t2_sorted.data % 2 = 0 "#, - // Q8: FULL 100K x 100K | 1:10 + // Q8: FULL 1M x 1M | 1:10 r#" WITH t1_sorted AS ( - SELECT value % 10000 as key, value as data - FROM range(100000) + SELECT value % 100000 as key, value as data + FROM range(1000000) ORDER BY key, data ), t2_sorted AS ( - SELECT value % 12500 as key, value as data - FROM range(100000) + SELECT value % 125000 as key, value as data + FROM range(1000000) ORDER BY key, data ) SELECT t1_sorted.key as k1, t1_sorted.data as d1, @@ -199,16 +199,16 @@ const SMJ_QUERIES: &[&str] = &[ OR t1_sorted.data <> t2_sorted.data) AND (t1_sorted.data IS NULL OR t1_sorted.data % 10 = 0) "#, - // Q10: LEFT SEMI 100K x 1M | 1:10 + // Q10: LEFT SEMI 1M x 10M | 1:10 r#" WITH t1_sorted AS ( - SELECT value % 10000 as key, value as data - FROM range(100000) + SELECT value % 100000 as key, value as data + FROM range(1000000) ORDER BY key, data ), t2_sorted AS ( - SELECT value % 10000 as key - FROM range(1000000) + SELECT value % 100000 as key + FROM range(10000000) ORDER BY key ) SELECT t1_sorted.key, t1_sorted.data @@ -281,16 +281,16 @@ const SMJ_QUERIES: &[&str] = &[ AND t2_sorted.data % 10 <> 0 ) "#, - // Q14: LEFT ANTI 100K x 1M | 1:10 + // Q14: LEFT ANTI 1M x 10M | 1:10 r#" WITH t1_sorted AS ( - SELECT value % 10500 as key, value as data - FROM range(100000) + SELECT value % 105000 as key, value as data + FROM range(1000000) ORDER BY key, data ), t2_sorted AS ( - SELECT value % 10000 as key - FROM range(1000000) + SELECT value % 100000 as key + FROM range(10000000) ORDER BY key ) SELECT t1_sorted.key, t1_sorted.data @@ -300,16 +300,16 @@ const SMJ_QUERIES: &[&str] = &[ WHERE t2_sorted.key = t1_sorted.key ) "#, - // Q15: LEFT ANTI 100K x 1M | 1:10 | partial match + // Q15: LEFT ANTI 1M x 10M | 1:10 | partial match r#" WITH t1_sorted AS ( - SELECT value % 12000 as key, value as data - FROM range(100000) + SELECT value % 120000 as key, value as data + FROM range(1000000) ORDER BY key, data ), t2_sorted AS ( - SELECT value % 10000 as key - FROM range(1000000) + SELECT value % 100000 as key + FROM range(10000000) ORDER BY key ) SELECT t1_sorted.key, t1_sorted.data @@ -319,16 +319,16 @@ const SMJ_QUERIES: &[&str] = &[ WHERE t2_sorted.key = t1_sorted.key ) "#, - // Q16: LEFT ANTI 100K x 100K | 1:1 | stress + // Q16: LEFT ANTI 1M x 1M | 1:1 | stress r#" WITH t1_sorted AS ( - SELECT value % 11000 as key, value as data - FROM range(100000) + SELECT value % 110000 as key, value as data + FROM range(1000000) ORDER BY key, data ), t2_sorted AS ( - SELECT value % 10000 as key - FROM range(100000) + SELECT value % 100000 as key + FROM range(1000000) ORDER BY key ) SELECT t1_sorted.key, t1_sorted.data From 6a0351dd7edcfc6447c7f7df2615f1827546e59f Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 8 Apr 2026 15:05:51 -0400 Subject: [PATCH 4/8] cargo fmt --- .../src/joins/sort_merge_join/materializing_stream.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index 4a47ee31f7b3..bf0117d2e364 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -51,9 +51,7 @@ use arrow::compute::{ use arrow::datatypes::SchemaRef; use arrow::ipc::reader::StreamReader; use datafusion_common::cast::as_uint64_array; -use datafusion_common::{ - JoinType, NullEquality, Result, exec_err, internal_err, -}; +use datafusion_common::{JoinType, NullEquality, Result, exec_err, internal_err}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::runtime_env::RuntimeEnv; From d32d5ae888e9244a263bded3fbb9ed58bd0e304e Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 8 Apr 2026 15:37:51 -0400 Subject: [PATCH 5/8] Remove branch in compare() by instantiating closures based on NullEquality at construction. --- datafusion/physical-plan/src/joins/utils.rs | 96 +++++++++------------ 1 file changed, 43 insertions(+), 53 deletions(-) diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index a7023fcd7418..df03f8bd1655 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1826,12 +1826,21 @@ fn eq_dyn_null( /// Pre-built comparator for join key columns that eliminates per-row type /// dispatch. Wraps `arrow_ord::ord::DynComparator` closures built once per /// batch pair, used for all row comparisons within those batches. +/// +/// Null handling is baked into the closures at construction time: +/// - `NullEqualsNull`: `make_comparator` returns `Equal` for both-null, which +/// is the desired behavior. Closures are used as-is. +/// - `NullEqualsNothing`: columns where both sides contain nulls get a wrapper +/// that returns `Less` for both-null. Columns where one side has no nulls +/// skip the wrapper since both-null is impossible. +/// +/// Because `NullEqualsNothing` wraps comparators to return `Less` for +/// both-null, `is_equal` will return `false` for both-null rows when that +/// mode is active. Callers needing both-null == equal semantics (e.g., +/// buffered head/tail equality in SMJ) should construct with +/// `NullEqualsNull`. pub struct JoinKeyComparator { comparators: Vec, - /// Only populated when null_equality == NullEqualsNothing - left_nulls: Vec>, - right_nulls: Vec>, - null_equality: NullEquality, } impl JoinKeyComparator { @@ -1847,66 +1856,50 @@ impl JoinKeyComparator { .iter() .zip(right_arrays.iter()) .zip(sort_options.iter()) - .map(|((l, r), opts)| Ok(make_comparator(l.as_ref(), r.as_ref(), *opts)?)) + .map(|((l, r), opts)| { + let inner = make_comparator(l.as_ref(), r.as_ref(), *opts)?; + if null_equality == NullEquality::NullEqualsNothing { + let ln = l.logical_nulls().filter(|n| n.null_count() > 0); + let rn = r.logical_nulls().filter(|n| n.null_count() > 0); + match (ln, rn) { + // Both sides have nulls — wrap to override both-null. + (Some(ln), Some(rn)) => Ok(Box::new(move |i, j| { + if ln.is_null(i) && rn.is_null(j) { + Ordering::Less + } else { + inner(i, j) + } + }) + as DynComparator), + // One side has no nulls — both-null impossible, no wrap. + _ => Ok(inner), + } + } else { + Ok(inner) + } + }) .collect::>>()?; - let (left_nulls, right_nulls) = - if null_equality == NullEquality::NullEqualsNothing { - let ln = left_arrays - .iter() - .map(|a| a.logical_nulls().filter(|n| n.null_count() > 0)) - .collect(); - let rn = right_arrays - .iter() - .map(|a| a.logical_nulls().filter(|n| n.null_count() > 0)) - .collect(); - (ln, rn) - } else { - (vec![], vec![]) - }; - - Ok(Self { - comparators, - left_nulls, - right_nulls, - null_equality, - }) + Ok(Self { comparators }) } /// Compare row `left` (in the left arrays) with row `right` (in the right /// arrays). Returns the lexicographic ordering across all key columns. #[inline] pub fn compare(&self, left: usize, right: usize) -> Ordering { - if self.null_equality == NullEquality::NullEqualsNothing { - for (idx, cmp_fn) in self.comparators.iter().enumerate() { - // Override both-null: make_comparator returns Equal but - // NullEqualsNothing semantics require Less. - if let (Some(ln), Some(rn)) = - (&self.left_nulls[idx], &self.right_nulls[idx]) - && ln.is_null(left) - && rn.is_null(right) - { - return Ordering::Less; - } - let ord = cmp_fn(left, right); - if ord != Ordering::Equal { - return ord; - } - } - } else { - for cmp_fn in &self.comparators { - let ord = cmp_fn(left, right); - if ord != Ordering::Equal { - return ord; - } + for cmp_fn in &self.comparators { + let ord = cmp_fn(left, right); + if ord != Ordering::Equal { + return ord; } } Ordering::Equal } /// Check equality of row `left` (in the left arrays) with row `right` - /// (in the right arrays). Both-null is treated as equal regardless of - /// `null_equality` — this matches `is_join_arrays_equal` semantics. + /// (in the right arrays). Both-null is treated as equal when constructed + /// with `NullEqualsNull`. With `NullEqualsNothing`, both-null returns + /// `false` because the override is baked into the comparators. #[inline] pub fn is_equal(&self, left: usize, right: usize) -> bool { for cmp_fn in &self.comparators { @@ -3126,9 +3119,6 @@ mod tests { assert_eq!(cmp.compare(0, 0), Ordering::Greater); // left[3]=2 vs right[3]=2 → Equal assert_eq!(cmp.compare(3, 3), Ordering::Equal); - - // is_equal always treats both-null as equal - assert!(cmp.is_equal(1, 1)); } #[test] From 3947c7ec220b0d75b93967e8ed769e7a172f038a Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 8 Apr 2026 15:40:22 -0400 Subject: [PATCH 6/8] consistent comments --- datafusion/physical-plan/src/joins/utils.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index df03f8bd1655..28358305abbc 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -3055,14 +3055,14 @@ mod tests { ) .unwrap(); - // left[0]=(1,"a") vs right[0]=(2,"b") → Less (first column) + // left[0]=(1,"a") vs right[0]=(2,"b") -> Less (first column) assert_eq!(cmp.compare(0, 0), Ordering::Less); - // left[1]=(2,"b") vs right[0]=(2,"b") → Equal + // left[1]=(2,"b") vs right[0]=(2,"b") -> Equal assert_eq!(cmp.compare(1, 0), Ordering::Equal); assert!(cmp.is_equal(1, 0)); - // left[2]=(2,"c") vs right[1]=(2,"d") → Less (second column) + // left[2]=(2,"c") vs right[1]=(2,"d") -> Less (second column) assert_eq!(cmp.compare(2, 1), Ordering::Less); - // left[3]=(3,"d") vs right[0]=(2,"b") → Greater + // left[3]=(3,"d") vs right[0]=(2,"b") -> Greater assert_eq!(cmp.compare(3, 0), Ordering::Greater); } @@ -3085,12 +3085,12 @@ mod tests { ) .unwrap(); - // left[1]=NULL vs right[1]=NULL → Equal (NullEqualsNull) + // left[1]=NULL vs right[1]=NULL -> Equal (NullEqualsNull) assert_eq!(cmp.compare(1, 1), Ordering::Equal); assert!(cmp.is_equal(1, 1)); - // left[0]=1 vs right[0]=NULL → Greater (nulls_first, non-null > null) + // left[0]=1 vs right[0]=NULL -> Greater (nulls_first, non-null > null) assert_eq!(cmp.compare(0, 0), Ordering::Greater); - // left[3]=2 vs right[3]=2 → Equal + // left[3]=2 vs right[3]=2 -> Equal assert_eq!(cmp.compare(3, 3), Ordering::Equal); } @@ -3113,11 +3113,11 @@ mod tests { ) .unwrap(); - // left[1]=NULL vs right[1]=NULL → Less (NullEqualsNothing) + // left[1]=NULL vs right[1]=NULL -> Less (NullEqualsNothing) assert_eq!(cmp.compare(1, 1), Ordering::Less); - // left[0]=1 vs right[0]=NULL → Greater (nulls_first) + // left[0]=1 vs right[0]=NULL -> Greater (nulls_first) assert_eq!(cmp.compare(0, 0), Ordering::Greater); - // left[3]=2 vs right[3]=2 → Equal + // left[3]=2 vs right[3]=2 -> Equal assert_eq!(cmp.compare(3, 3), Ordering::Equal); } From e0d06eea91e4649d853305c0224350dd2503d5f1 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 8 Apr 2026 16:52:55 -0400 Subject: [PATCH 7/8] fix extended test --- datafusion/core/tests/memory_limit/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 7075fbc2443d..da13389901ee 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -213,6 +213,7 @@ async fn sort_merge_join_spill() { .with_config(config) .with_disk_manager_builder(DiskManagerBuilder::default()) .with_scenario(Scenario::AccessLogStreaming) + .with_expected_success() .run() .await } From d8658065055c864dbf9c0c3c6aa26755f4ee50c2 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 9 Apr 2026 11:04:12 -0400 Subject: [PATCH 8/8] Optimize for single-column key joins. --- datafusion/physical-plan/src/joins/utils.rs | 34 +++++++++++++++------ 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 28358305abbc..39a380eed9d4 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1827,6 +1827,11 @@ fn eq_dyn_null( /// dispatch. Wraps `arrow_ord::ord::DynComparator` closures built once per /// batch pair, used for all row comparisons within those batches. /// +/// The first key column is stored separately so that single-column joins +/// (the common case) avoid Vec iteration entirely, and multi-column joins +/// short-circuit without entering the loop when the first column is +/// selective. +/// /// Null handling is baked into the closures at construction time: /// - `NullEqualsNull`: `make_comparator` returns `Equal` for both-null, which /// is the desired behavior. Closures are used as-is. @@ -1840,19 +1845,22 @@ fn eq_dyn_null( /// buffered head/tail equality in SMJ) should construct with /// `NullEqualsNull`. pub struct JoinKeyComparator { - comparators: Vec, + first: DynComparator, + rest: Vec, } impl JoinKeyComparator { - /// Build comparators for each join key column pair. The `sort_options` - /// slice must have the same length as the array slices. + /// Build comparators for each join key column pair. pub fn new( left_arrays: &[ArrayRef], right_arrays: &[ArrayRef], sort_options: &[SortOptions], null_equality: NullEquality, ) -> Result { - let comparators = left_arrays + debug_assert_eq!(left_arrays.len(), right_arrays.len()); + debug_assert_eq!(left_arrays.len(), sort_options.len()); + + let mut iter = left_arrays .iter() .zip(right_arrays.iter()) .zip(sort_options.iter()) @@ -1877,17 +1885,22 @@ impl JoinKeyComparator { } else { Ok(inner) } - }) - .collect::>>()?; + }); - Ok(Self { comparators }) + let first = iter.next().expect("join must have at least one key")?; + let rest = iter.collect::>>()?; + Ok(Self { first, rest }) } /// Compare row `left` (in the left arrays) with row `right` (in the right /// arrays). Returns the lexicographic ordering across all key columns. #[inline] pub fn compare(&self, left: usize, right: usize) -> Ordering { - for cmp_fn in &self.comparators { + let ord = (self.first)(left, right); + if ord != Ordering::Equal || self.rest.is_empty() { + return ord; + } + for cmp_fn in &self.rest { let ord = cmp_fn(left, right); if ord != Ordering::Equal { return ord; @@ -1902,7 +1915,10 @@ impl JoinKeyComparator { /// `false` because the override is baked into the comparators. #[inline] pub fn is_equal(&self, left: usize, right: usize) -> bool { - for cmp_fn in &self.comparators { + if (self.first)(left, right) != Ordering::Equal { + return false; + } + for cmp_fn in &self.rest { if cmp_fn(left, right) != Ordering::Equal { return false; }