Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 154 additions & 4 deletions native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,162 @@ impl SparkBloomFilter {
self.bits.to_bytes()
}

/// Returns the number of hash functions this filter applies per inserted item.
///
/// Exposed so tests can verify that deserialization round-trips preserve the
/// filter's configuration.
#[allow(dead_code)] // used in tests
pub fn num_hash_functions(&self) -> u32 {
    self.num_hash_functions
}

/// Merges another bloom filter's state into this one. Accepts both Comet's
/// raw bits format (native endian, no header) and Spark's full serialization
/// format (12-byte big-endian header + big-endian bit data).
///
/// # Panics
///
/// Panics if `other` matches neither expected length — merging a buffer of a
/// different size would silently corrupt the filter.
pub fn merge_filter(&mut self, other: &[u8]) {
    let expected_bytes = self.bits.byte_size();
    let header_size = 12; // version (4) + num_hash_functions (4) + num_words (4)

    if other.len() == expected_bytes {
        // Comet state format: raw bits in native endianness, merge directly.
        self.bits.merge_bits(other);
    } else if other.len() == expected_bytes + header_size {
        // Spark serialization format: 12-byte big-endian header + big-endian
        // bit data. Skip the header and convert each 64-bit word from
        // big-endian to native endianness.
        let bits_data = &other[header_size..];
        // `expected_bytes` is a whole number of u64 words, so there is no
        // remainder; chunks_exact makes that invariant explicit and lets the
        // compiler hoist per-chunk length checks.
        let mut native_bytes: Vec<u8> = Vec::with_capacity(expected_bytes);
        for chunk in bits_data.chunks_exact(8) {
            let be_val = u64::from_be_bytes(
                chunk.try_into().expect("chunks_exact yields 8-byte chunks"),
            );
            native_bytes.extend_from_slice(&be_val.to_ne_bytes());
        }
        self.bits.merge_bits(&native_bytes);
    } else {
        panic!(
            "Cannot merge SparkBloomFilter: unexpected buffer length {}. \
            Expected {} (raw bits) or {} (Spark serialization format).",
            other.len(),
            expected_bytes,
            expected_bytes + header_size
        );
    }
}
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::bloom_filter::spark_bloom_filter::optimal_num_hash_functions;
    use arrow::array::{ArrayRef, BinaryArray};
    use datafusion::physical_plan::Accumulator;
    use std::sync::Arc;

    const NUM_ITEMS: i32 = 1000;
    const NUM_BITS: i32 = 8192;

    /// Builds a filter with the standard test sizing used by every test below.
    fn new_bloom_filter() -> SparkBloomFilter {
        let num_hash = optimal_num_hash_functions(NUM_ITEMS, NUM_BITS);
        SparkBloomFilter::from((num_hash, NUM_BITS))
    }

    #[test]
    fn test_merge_comet_state_format() {
        // Simulates Comet partial -> Comet final: state() produces raw bits
        let mut partial = new_bloom_filter();
        partial.put_long(42);
        partial.put_long(100);

        let raw_bits = partial.state_as_bytes();

        let mut final_agg = new_bloom_filter();
        final_agg.merge_filter(&raw_bits);

        assert!(final_agg.might_contain_long(42));
        assert!(final_agg.might_contain_long(100));
        assert!(!final_agg.might_contain_long(999));
    }

    #[test]
    fn test_merge_spark_serialization_format() {
        // Simulates Spark partial -> Comet final: evaluate() produces Spark format
        // with 12-byte header. This is the scenario from issue #2889.
        let mut partial = new_bloom_filter();
        partial.put_long(42);
        partial.put_long(100);

        let spark_format = partial.spark_serialization();
        // Verify the Spark format has the 12-byte header
        assert_eq!(spark_format.len(), partial.state_as_bytes().len() + 12);

        let mut final_agg = new_bloom_filter();
        final_agg.merge_filter(&spark_format);

        assert!(final_agg.might_contain_long(42));
        assert!(final_agg.might_contain_long(100));
        assert!(!final_agg.might_contain_long(999));
    }

    #[test]
    fn test_merge_batch_with_spark_format() {
        // End-to-end test using the Accumulator trait, matching what happens
        // when Spark partial sends its state to Comet final via merge_batch.
        let mut partial = new_bloom_filter();
        partial.put_long(42);
        partial.put_long(100);

        let spark_bytes = partial.spark_serialization();
        let binary_array: ArrayRef = Arc::new(BinaryArray::from_vec(vec![&spark_bytes]));

        let mut final_agg = new_bloom_filter();
        final_agg.merge_batch(&[binary_array]).unwrap();

        assert!(final_agg.might_contain_long(42));
        assert!(final_agg.might_contain_long(100));
        assert!(!final_agg.might_contain_long(999));
    }

    #[test]
    fn test_merge_batch_with_comet_state() {
        // Comet partial -> Comet final via merge_batch using state() output
        let mut partial = new_bloom_filter();
        partial.put_long(42);

        let state = partial.state().unwrap();
        let raw_bits = match &state[0] {
            datafusion::common::ScalarValue::Binary(Some(b)) => b.clone(),
            _ => panic!("expected Binary"),
        };
        let binary_array: ArrayRef = Arc::new(BinaryArray::from_vec(vec![&raw_bits]));

        let mut final_agg = new_bloom_filter();
        final_agg.merge_batch(&[binary_array]).unwrap();

        assert!(final_agg.might_contain_long(42));
    }

    #[test]
    fn test_roundtrip_spark_serialization() {
        // Serializing to Spark's format and deserializing back must preserve
        // both the filter configuration and every inserted item.
        let mut original = new_bloom_filter();
        for i in 0..50 {
            original.put_long(i);
        }

        let spark_bytes = original.spark_serialization();
        let deserialized = SparkBloomFilter::from(spark_bytes.as_slice());

        assert_eq!(
            deserialized.num_hash_functions(),
            original.num_hash_functions()
        );
        for i in 0..50 {
            assert!(deserialized.might_contain_long(i));
        }
    }

    #[test]
    #[should_panic(expected = "unexpected buffer length")]
    fn test_merge_invalid_length_panics() {
        // 37 bytes matches neither the raw-bits length nor raw-bits + 12-byte
        // header, so merge_filter must reject it.
        let mut filter = new_bloom_filter();
        filter.merge_filter(&[0u8; 37]);
    }
}
25 changes: 22 additions & 3 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, BloomFilterAggregate, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
Expand Down Expand Up @@ -1344,8 +1344,14 @@ trait CometBaseAggregate {
// In distinct aggregates there can be a combination of modes
val multiMode = modes.size > 1
// For a final mode HashAggregate, we only need to transform the HashAggregate
// if there is Comet partial aggregation.
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
// if there is Comet partial aggregation. Some aggregate functions (e.g. BloomFilterAggregate)
// have compatible intermediate buffer formats between Spark and Comet, so they can safely
// run as Comet final even when Spark did the partial.
val allowMixed =
aggregate.aggregateExpressions.forall(
_.aggregateFunction.isInstanceOf[BloomFilterAggregate])
val sparkFinalMode =
!allowMixed && modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty

if (multiMode || sparkFinalMode) {
return None
Expand Down Expand Up @@ -1552,6 +1558,19 @@ object CometObjectHashAggregateExec
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)

override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = {
  // Test-only configs can force partial/final aggregation back to Spark so
  // that mixed Spark/Comet aggregate plans can be exercised in tests.
  val exprs = op.aggregateExpressions
  val partialDisabled = !CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf)
  if (partialDisabled && exprs.exists(e => e.mode == Partial || e.mode == PartialMerge)) {
    Unsupported(Some("Partial aggregates disabled via test config"))
  } else if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) &&
    exprs.exists(_.mode == Final)) {
    Unsupported(Some("Final aggregates disabled via test config"))
  } else {
    Compatible()
  }
}

override def convert(
aggregate: ObjectHashAggregateExec,
builder: Operator.Builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1982,4 +1982,41 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
sparkPlan.collect { case s: CometHashAggregateExec => s }.size
}

test("bloom_filter_agg with Spark partial and Comet final aggregate") {
  import org.apache.spark.sql.catalyst.FunctionIdentifier
  import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
  import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate

  // bloom_filter_agg is not exposed through SQL by default, so register it
  // manually; arity selects among the one/two/three-argument constructors.
  val funcId = new FunctionIdentifier("bloom_filter_agg")
  val builder = (children: Seq[Expression]) =>
    children match {
      case Seq(child) => new BloomFilterAggregate(child)
      case Seq(child, items) => new BloomFilterAggregate(child, items)
      case Seq(child, items, bits) => new BloomFilterAggregate(child, items, bits)
    }
  spark.sessionState.functionRegistry.registerFunction(
    funcId,
    new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
    builder)

  try {
    withParquetTable((0 until 100).map(i => (i.toLong, i.toLong % 5)), "bloom_tbl") {
      // Disable Comet partial aggregate so Spark does partial. BloomFilterAggregate
      // is allowed to run as Comet final even with Spark partial, since the Rust
      // merge_filter now handles both Spark's serialization format (12-byte header +
      // big-endian bits) and Comet's raw native-endian bits format.
      withSQLConf(CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false") {
        val df = sql("SELECT bloom_filter_agg(cast(_1 as long)) FROM bloom_tbl")
        // Verify we have the mixed execution: Spark partial + Comet final
        val plan = stripAQEPlan(df.queryExecution.executedPlan)
        assert(
          plan.collect { case a: CometHashAggregateExec => a }.nonEmpty,
          "Expected at least one CometHashAggregateExec in the plan")
        checkSparkAnswer(df)
      }
    }
  } finally {
    spark.sessionState.functionRegistry.dropFunction(funcId)
  }
}

}
Loading