Skip to content
1 change: 1 addition & 0 deletions datafusion/core/tests/memory_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ async fn sort_merge_join_spill() {
.with_config(config)
.with_disk_manager_builder(DiskManagerBuilder::default())
.with_scenario(Scenario::AccessLogStreaming)
Copy link
Copy Markdown
Contributor Author

@mbutrovich mbutrovich Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_join_arrays_equal didn't support Dictionary keys — it hit the not_impl_err! fallthrough. This test was passing because that error counted as the expected "failure." JoinKeyComparator uses make_comparator which handles all Arrow types, so the query now correctly spills and succeeds.

If you add this on main:

.with_expected_errors(vec!["Unsupported data type in sort merge join comparator"])

the test passes, confirming it was failing for the wrong reason on main all along.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude's assessment:

The test's original intent was to verify SMJ fails with OOM when memory is too tight even for spilling. But at 1000 bytes with this data, spilling works fine — there's no "too tight to spill" threshold we can easily hit because the spill path writes to disk before the reservation grows.

The reservation system works like: try_grow fails → spill to disk → no reservation needed. As long as the disk manager works, there's no minimum memory floor that would cause OOM. Even at 1 byte, the first try_grow would fail and spill immediately.

So the test as-is can't be made to fail with OOM under SMJ spilling — it was only "working" because of the Dictionary bug.

I'm not sure if the test is worth keeping, but maybe that's a different PR.

.with_expected_success()
.run()
.await
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::handle_state;
use crate::joins::piecewise_merge_join::exec::{BufferedSide, BufferedSideReadyState};
use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final;
use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult};
use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap};
use crate::joins::utils::{JoinKeyComparator, get_final_indices_from_shared_bitmap};

pub(super) enum PiecewiseMergeJoinStreamState {
WaitBufferedSide,
Expand Down Expand Up @@ -460,6 +460,14 @@ fn resolve_classic_join(
let buffered_len = buffered_side.buffered_data.values().len();
let stream_values = stream_batch.compare_key_values();

// Build comparator once for the batch pair
let cmp = JoinKeyComparator::new(
Copy link
Copy Markdown
Contributor Author

@mbutrovich mbutrovich Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: piecewise merge join rebuilds the JoinKeyComparator on each resolve_classic_join re-entry (when output exceeds batch size mid-scan), even though the batch pair hasn't changed. Could lift it to ClassicPWMJStream and rebuild only in fetch_stream_batch, but the cost is one make_comparator call per ~8192 rows of output, so probably not worth the added lifecycle coupling. Left as-is for now.

&[Arc::clone(&stream_values[0])],
&[Arc::clone(buffered_side.buffered_data.values())],
&[sort_options],
NullEquality::NullEqualsNothing,
)?;

let mut buffer_idx = batch_process_state.start_buffer_idx;
let mut stream_idx = batch_process_state.start_stream_idx;

Expand All @@ -475,17 +483,7 @@ fn resolve_classic_join(
// in the previous stream row.
for row_idx in stream_idx..stream_batch.batch.num_rows() {
while buffer_idx < buffered_len {
let compare = {
let buffered_values = buffered_side.buffered_data.values();
compare_join_arrays(
&[Arc::clone(&stream_values[0])],
row_idx,
&[Arc::clone(buffered_values)],
buffer_idx,
&[sort_options],
NullEquality::NullEqualsNothing,
)?
};
let compare = cmp.compare(row_idx, buffer_idx);

// If we find a match we append all indices and move to the next stream row index
match operator {
Expand Down
172 changes: 94 additions & 78 deletions datafusion/physical-plan/src/joins/sort_merge_join/bitwise_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};

use crate::RecordBatchStream;
use crate::joins::utils::{JoinFilter, compare_join_arrays};
use crate::joins::utils::{JoinFilter, JoinKeyComparator, compare_join_arrays};
use crate::metrics::{
BaselineMetrics, Count, ExecutionPlanMetricsSet, Gauge, MetricBuilder,
};
Expand Down Expand Up @@ -162,70 +162,40 @@ fn evaluate_join_keys(
}

/// Find the first index in `key_arrays` starting from `from` where the key
/// differs from the key at `from`. Uses `compare_join_arrays` for zero-alloc
/// ordinal comparison.
/// differs from the key at `from`. Uses a pre-built `JoinKeyComparator` for
/// zero-alloc ordinal comparison without per-row type dispatch.
///
/// Optimized for join workloads: checks adjacent and boundary keys before
/// falling back to binary search, since most key groups are small (often 1).
fn find_key_group_end(
key_arrays: &[ArrayRef],
from: usize,
len: usize,
sort_options: &[SortOptions],
null_equality: NullEquality,
) -> Result<usize> {
fn find_key_group_end(cmp: &JoinKeyComparator, from: usize, len: usize) -> usize {
let next = from + 1;
if next >= len {
return Ok(len);
return len;
}

// Fast path: single-row group (common with unique keys).
if compare_join_arrays(
key_arrays,
from,
key_arrays,
next,
sort_options,
null_equality,
)? != Ordering::Equal
{
return Ok(next);
if cmp.compare(from, next) != Ordering::Equal {
return next;
}

// Check if the entire remaining batch shares this key.
let last = len - 1;
if compare_join_arrays(
key_arrays,
from,
key_arrays,
last,
sort_options,
null_equality,
)? == Ordering::Equal
{
return Ok(len);
if cmp.compare(from, last) == Ordering::Equal {
return len;
}

// Binary search the interior: key at `next` matches, key at `last` doesn't.
let mut lo = next + 1;
let mut hi = last;
while lo < hi {
let mid = lo + (hi - lo) / 2;
if compare_join_arrays(
key_arrays,
from,
key_arrays,
mid,
sort_options,
null_equality,
)? == Ordering::Equal
{
if cmp.compare(from, mid) == Ordering::Equal {
lo = mid + 1;
} else {
hi = mid;
}
}
Ok(lo)
lo
}

/// When an outer key group spans a batch boundary, the boundary loop emits
Expand Down Expand Up @@ -328,6 +298,14 @@ pub(crate) struct BitwiseSortMergeJoinStream {
runtime_env: Arc<datafusion_execution::runtime_env::RuntimeEnv>,
inner_buffer_size: usize,

// Cached comparators — pre-built to avoid per-row type dispatch.
/// Comparator for outer vs inner key comparison
outer_inner_cmp: Option<JoinKeyComparator>,
/// Comparator for outer self-comparison (find_key_group_end on outer)
outer_self_cmp: Option<JoinKeyComparator>,
/// Comparator for inner self-comparison (find_key_group_end on inner)
inner_self_cmp: Option<JoinKeyComparator>,

// True once the current outer batch has been emitted. The Equal
// branch's inner loops call emit then `ready!(poll_next_outer_batch)`.
// If that poll returns Pending, poll_join re-enters from the top
Expand Down Expand Up @@ -413,6 +391,9 @@ impl BitwiseSortMergeJoinStream {
spill_manager,
runtime_env,
inner_buffer_size: 0,
outer_inner_cmp: None,
outer_self_cmp: None,
inner_self_cmp: None,
batch_emitted: false,
})
}
Expand All @@ -425,6 +406,45 @@ impl BitwiseSortMergeJoinStream {
Ok(())
}

/// Get or build the outer vs inner key comparator.
///
/// Lazily constructed and cached; callers reset the cache to `None` when a
/// new outer or inner batch replaces the key arrays it was built over.
fn get_outer_inner_cmp(&mut self) -> Result<&JoinKeyComparator> {
    if self.outer_inner_cmp.is_some() {
        return Ok(self.outer_inner_cmp.as_ref().unwrap());
    }
    let built = JoinKeyComparator::new(
        &self.outer_key_arrays,
        &self.inner_key_arrays,
        &self.sort_options,
        self.null_equality,
    )?;
    // `Option::insert` stores the comparator and hands back a reference to it.
    Ok(self.outer_inner_cmp.insert(built))
}

/// Get or build the outer self-comparison comparator.
///
/// Compares outer key rows against each other (used by `find_key_group_end`
/// on the outer side). Cached until the outer batch is replaced.
fn get_outer_self_cmp(&mut self) -> Result<&JoinKeyComparator> {
    if self.outer_self_cmp.is_some() {
        return Ok(self.outer_self_cmp.as_ref().unwrap());
    }
    let built = JoinKeyComparator::new(
        &self.outer_key_arrays,
        &self.outer_key_arrays,
        &self.sort_options,
        self.null_equality,
    )?;
    // `Option::insert` stores the comparator and hands back a reference to it.
    Ok(self.outer_self_cmp.insert(built))
}

/// Get or build the inner self-comparison comparator.
///
/// Compares inner key rows against each other (used by `find_key_group_end`
/// on the inner side). Cached until the inner batch is replaced.
fn get_inner_self_cmp(&mut self) -> Result<&JoinKeyComparator> {
    if self.inner_self_cmp.is_some() {
        return Ok(self.inner_self_cmp.as_ref().unwrap());
    }
    let built = JoinKeyComparator::new(
        &self.inner_key_arrays,
        &self.inner_key_arrays,
        &self.sort_options,
        self.null_equality,
    )?;
    // `Option::insert` stores the comparator and hands back a reference to it.
    Ok(self.inner_self_cmp.insert(built))
}

/// Spill the in-memory inner key buffer to disk and clear it.
fn spill_inner_key_buffer(&mut self) -> Result<()> {
let spill_file = self
Expand Down Expand Up @@ -468,6 +488,8 @@ impl BitwiseSortMergeJoinStream {
self.outer_batch = Some(batch);
self.outer_offset = 0;
self.outer_key_arrays = keys;
self.outer_inner_cmp = None;
self.outer_self_cmp = None;
self.batch_emitted = false;
self.matched = BooleanBufferBuilder::new(batch_num_rows);
self.matched.append_n(batch_num_rows, false);
Expand All @@ -494,6 +516,8 @@ impl BitwiseSortMergeJoinStream {
self.inner_batch = Some(batch);
self.inner_offset = 0;
self.inner_key_arrays = keys;
self.outer_inner_cmp = None;
self.inner_self_cmp = None;
return Poll::Ready(Ok(true));
}
}
Expand Down Expand Up @@ -555,13 +579,12 @@ impl BitwiseSortMergeJoinStream {
let outer_batch = self.outer_batch.as_ref().unwrap();
let num_outer = outer_batch.num_rows();

self.get_outer_self_cmp()?;
let outer_group_end = find_key_group_end(
&self.outer_key_arrays,
self.outer_self_cmp.as_ref().unwrap(),
self.outer_offset,
num_outer,
&self.sort_options,
self.null_equality,
)?;
);

for i in self.outer_offset..outer_group_end {
self.matched.set_bit(i, true);
Expand All @@ -584,13 +607,12 @@ impl BitwiseSortMergeJoinStream {
};
let num_inner = inner_batch.num_rows();

self.get_inner_self_cmp()?;
let group_end = find_key_group_end(
&self.inner_key_arrays,
self.inner_self_cmp.as_ref().unwrap(),
self.inner_offset,
num_inner,
&self.sort_options,
self.null_equality,
)?;
);

if group_end < num_inner {
self.inner_offset = group_end;
Expand Down Expand Up @@ -642,20 +664,19 @@ impl BitwiseSortMergeJoinStream {
}

loop {
let inner_batch = match &self.inner_batch {
Some(b) => b,
None => return Poll::Ready(Ok(true)),
};
let num_inner = inner_batch.num_rows();
if self.inner_batch.is_none() {
return Poll::Ready(Ok(true));
}
let num_inner = self.inner_batch.as_ref().unwrap().num_rows();
self.get_inner_self_cmp()?;
let group_end = find_key_group_end(
&self.inner_key_arrays,
self.inner_self_cmp.as_ref().unwrap(),
self.inner_offset,
num_inner,
&self.sort_options,
self.null_equality,
)?;
);

if !resume_from_poll {
let inner_batch = self.inner_batch.as_ref().unwrap();
let slice =
inner_batch.slice(self.inner_offset, group_end - self.inner_offset);
self.inner_buffer_size += slice.get_array_memory_size();
Expand Down Expand Up @@ -719,6 +740,7 @@ impl BitwiseSortMergeJoinStream {
/// key group, evaluates the filter against the outer key group and ORs
/// the results into the matched bitset using u64-chunked bitwise ops.
fn process_key_match_with_filter(&mut self) -> Result<()> {
self.get_outer_self_cmp()?;
let filter = self.filter.as_ref().unwrap();
let outer_batch = self.outer_batch.as_ref().unwrap();
let num_outer = outer_batch.num_rows();
Expand All @@ -738,12 +760,10 @@ impl BitwiseSortMergeJoinStream {
);

let outer_group_end = find_key_group_end(
&self.outer_key_arrays,
self.outer_self_cmp.as_ref().unwrap(),
self.outer_offset,
num_outer,
&self.sort_options,
self.null_equality,
)?;
);
let outer_group_len = outer_group_end - self.outer_offset;
let outer_slice = outer_batch.slice(self.outer_offset, outer_group_len);

Expand Down Expand Up @@ -959,34 +979,30 @@ impl BitwiseSortMergeJoinStream {
}

// 4. Compare keys at current positions
let cmp = compare_join_arrays(
&self.outer_key_arrays,
self.outer_offset,
&self.inner_key_arrays,
self.inner_offset,
&self.sort_options,
self.null_equality,
)?;
self.get_outer_inner_cmp()?;
let cmp = self
.outer_inner_cmp
.as_ref()
.unwrap()
.compare(self.outer_offset, self.inner_offset);

match cmp {
Ordering::Less => {
self.get_outer_self_cmp()?;
let group_end = find_key_group_end(
&self.outer_key_arrays,
self.outer_self_cmp.as_ref().unwrap(),
self.outer_offset,
num_outer,
&self.sort_options,
self.null_equality,
)?;
);
self.outer_offset = group_end;
}
Ordering::Greater => {
self.get_inner_self_cmp()?;
let group_end = find_key_group_end(
&self.inner_key_arrays,
self.inner_self_cmp.as_ref().unwrap(),
self.inner_offset,
num_inner,
&self.sort_options,
self.null_equality,
)?;
);
if group_end >= num_inner {
let saved_keys =
slice_keys(&self.inner_key_arrays, num_inner - 1);
Expand Down
Loading
Loading