diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs index e84257ea67..85a863a095 100644 --- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs +++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs @@ -159,12 +159,162 @@ impl SparkBloomFilter { self.bits.to_bytes() } + /// Returns the number of hash functions this filter applies per item. + /// Exposed for tests and diagnostics; mirrors the value supplied when + /// the filter was constructed. + #[allow(dead_code)] // used in tests + pub fn num_hash_functions(&self) -> u32 { + self.num_hash_functions + } + + /// Merges another bloom filter's state into this one. Accepts both Comet's + /// raw bits format (native endian, no header) and Spark's full serialization + /// format (12-byte big-endian header + big-endian bit data). pub fn merge_filter(&mut self, other: &[u8]) { + let expected_bytes = self.bits.byte_size(); + let header_size = 12; // version (4) + num_hash_functions (4) + num_words (4) + + if other.len() == expected_bytes { + // Comet state format: raw bits in native endianness + self.bits.merge_bits(other); + } else if other.len() == expected_bytes + header_size { + // Spark serialization format: 12-byte big-endian header + big-endian bit data. + // Skip the header and convert from big-endian to native endianness. + let bits_data = &other[header_size..]; + let native_bytes: Vec<u8> = bits_data + .chunks(8) + .flat_map(|chunk| { + let be_val = u64::from_be_bytes(chunk.try_into().unwrap()); + be_val.to_ne_bytes() + }) + .collect(); + self.bits.merge_bits(&native_bytes); + } else { + panic!( + "Cannot merge SparkBloomFilter: unexpected buffer length {}. 
\ + Expected {} (raw bits) or {} (Spark serialization format).", + other.len(), + expected_bytes, + expected_bytes + header_size + ); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::bloom_filter::spark_bloom_filter::optimal_num_hash_functions; + use arrow::array::{ArrayRef, BinaryArray}; + use datafusion::physical_plan::Accumulator; + use std::sync::Arc; + + const NUM_ITEMS: i32 = 1000; + const NUM_BITS: i32 = 8192; + + fn new_bloom_filter() -> SparkBloomFilter { + let num_hash = optimal_num_hash_functions(NUM_ITEMS, NUM_BITS); + SparkBloomFilter::from((num_hash, NUM_BITS)) + } + + #[test] + fn test_merge_comet_state_format() { + // Simulates Comet partial -> Comet final: state() produces raw bits + let mut partial = new_bloom_filter(); + partial.put_long(42); + partial.put_long(100); + + let raw_bits = partial.state_as_bytes(); + + let mut final_agg = new_bloom_filter(); + final_agg.merge_filter(&raw_bits); + + assert!(final_agg.might_contain_long(42)); + assert!(final_agg.might_contain_long(100)); + assert!(!final_agg.might_contain_long(999)); + } + + #[test] + fn test_merge_spark_serialization_format() { + // Simulates Spark partial -> Comet final: evaluate() produces Spark format + // with 12-byte header. This is the scenario from issue #2889. + let mut partial = new_bloom_filter(); + partial.put_long(42); + partial.put_long(100); + + let spark_format = partial.spark_serialization(); + // Verify the Spark format has the 12-byte header + assert_eq!(spark_format.len(), partial.state_as_bytes().len() + 12); + + let mut final_agg = new_bloom_filter(); + final_agg.merge_filter(&spark_format); + + assert!(final_agg.might_contain_long(42)); + assert!(final_agg.might_contain_long(100)); + assert!(!final_agg.might_contain_long(999)); + } + + #[test] + fn test_merge_batch_with_spark_format() { + // End-to-end test using the Accumulator trait, matching what happens + // when Spark partial sends its state to Comet final via merge_batch. 
+ let mut partial = new_bloom_filter(); + partial.put_long(42); + partial.put_long(100); + + let spark_bytes = partial.spark_serialization(); + let binary_array: ArrayRef = Arc::new(BinaryArray::from_vec(vec![&spark_bytes])); + + let mut final_agg = new_bloom_filter(); + final_agg.merge_batch(&[binary_array]).unwrap(); + + assert!(final_agg.might_contain_long(42)); + assert!(final_agg.might_contain_long(100)); + assert!(!final_agg.might_contain_long(999)); + } + + #[test] + fn test_merge_batch_with_comet_state() { + // Comet partial -> Comet final via merge_batch using state() output + let mut partial = new_bloom_filter(); + partial.put_long(42); + + let state = partial.state().unwrap(); + let raw_bits = match &state[0] { + datafusion::common::ScalarValue::Binary(Some(b)) => b.clone(), + _ => panic!("expected Binary"), + }; + let binary_array: ArrayRef = Arc::new(BinaryArray::from_vec(vec![&raw_bits])); + + let mut final_agg = new_bloom_filter(); + final_agg.merge_batch(&[binary_array]).unwrap(); + + assert!(final_agg.might_contain_long(42)); + } + + #[test] + fn test_roundtrip_spark_serialization() { + let mut original = new_bloom_filter(); + for i in 0..50 { + original.put_long(i); + } + + let spark_bytes = original.spark_serialization(); + let deserialized = SparkBloomFilter::from(spark_bytes.as_slice()); + assert_eq!( - other.len(), - self.bits.byte_size(), - "Cannot merge SparkBloomFilters with different lengths." 
+ deserialized.num_hash_functions(), + original.num_hash_functions() ); - self.bits.merge_bits(other); + for i in 0..50 { + assert!(deserialized.might_contain_long(i)); + } + } + + #[test] + #[should_panic(expected = "unexpected buffer length")] + fn test_merge_invalid_length_panics() { + let mut filter = new_bloom_filter(); + filter.merge_filter(&[0u8; 37]); } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 2965e46988..8cdc734fd4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -29,7 +29,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, Expression, ExpressionSet, Generator, NamedExpression, SortOrder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, Final, Partial, PartialMerge} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode, BloomFilterAggregate, Final, Partial, PartialMerge} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -1344,8 +1344,14 @@ trait CometBaseAggregate { // In distinct aggregates there can be a combination of modes val multiMode = modes.size > 1 // For a final mode HashAggregate, we only need to transform the HashAggregate - // if there is Comet partial aggregation. - val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty + // if there is Comet partial aggregation. Some aggregate functions (e.g. 
BloomFilterAggregate) + // have compatible intermediate buffer formats between Spark and Comet, so they can safely + // run as Comet final even when Spark did the partial. + val allowMixed = + aggregate.aggregateExpressions.forall( + _.aggregateFunction.isInstanceOf[BloomFilterAggregate]) + val sparkFinalMode = + !allowMixed && modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty if (multiMode || sparkFinalMode) { return None @@ -1552,6 +1558,19 @@ object CometObjectHashAggregateExec override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_AGGREGATE_ENABLED) + override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = { + if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions + .exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) { + return Unsupported(Some("Partial aggregates disabled via test config")) + } + if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(_.mode == Final)) { + return Unsupported(Some("Final aggregates disabled via test config")) + } + Compatible() + } + override def convert( aggregate: ObjectHashAggregateExec, builder: Operator.Builder, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index be60f4aaee..196488421d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1982,4 +1982,41 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { sparkPlan.collect { case s: CometHashAggregateExec => s }.size } + test("bloom_filter_agg with Spark partial and Comet final aggregate") { + import org.apache.spark.sql.catalyst.FunctionIdentifier + import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} + import 
org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate + + val funcId = new FunctionIdentifier("bloom_filter_agg") + spark.sessionState.functionRegistry.registerFunction( + funcId, + new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"), + (children: Seq[Expression]) => + children.size match { + case 1 => new BloomFilterAggregate(children.head) + case 2 => new BloomFilterAggregate(children.head, children(1)) + case 3 => new BloomFilterAggregate(children.head, children(1), children(2)) + }) + + try { + withParquetTable((0 until 100).map(i => (i.toLong, i.toLong % 5)), "bloom_tbl") { + // Disable Comet partial aggregate so Spark does partial. BloomFilterAggregate + // is allowed to run as Comet final even with Spark partial, since the Rust + // merge_filter now handles both Spark's serialization format (12-byte header + + // big-endian bits) and Comet's raw native-endian bits format. + withSQLConf(CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false") { + val df = sql("SELECT bloom_filter_agg(cast(_1 as long)) FROM bloom_tbl") + // Verify we have the mixed execution: Spark partial + Comet final + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert( + plan.collect { case a: CometHashAggregateExec => a }.nonEmpty, + "Expected at least one CometHashAggregateExec in the plan") + checkSparkAnswer(df) + } + } + } finally { + spark.sessionState.functionRegistry.dropFunction(funcId) + } + } + }