From 591d3fb223f5ba3627a6ca973fe323d8567a9cb8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 25 Mar 2026 14:34:25 -0700 Subject: [PATCH 1/5] fix: query tolerance= in SQL file tests now also asserts Comet native execution The WithTolerance mode in CometSqlFileTestSuite called checkSparkAnswerWithTolerance, which hardcodes assertCometNative=false. This meant that all 50+ SQL test queries using 'query tolerance=...' (sin, cos, tan, stddev, avg, etc.) silently skipped the check that Comet actually executed the expression natively. Add checkSparkAnswerAndOperatorWithTolerance that combines tolerance- based result comparison with the native operator assertion, and use it in the WithTolerance case. --- .../org/apache/comet/CometSqlFileTestSuite.scala | 2 +- .../scala/org/apache/spark/sql/CometTestBase.scala | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala index 020759a7a6..5a0b34e056 100644 --- a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala @@ -102,7 +102,7 @@ class CometSqlFileTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { case SparkAnswerOnly => checkSparkAnswer(sql) case WithTolerance(tol) => - checkSparkAnswerWithTolerance(sql, tol) + checkSparkAnswerAndOperatorWithTolerance(sql, tol) case ExpectFallback(reason) => checkSparkAnswerAndFallbackReason(sql, reason) case Ignore(reason) => diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index 33c1d444b9..138b073b87 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -190,6 +190,17 @@ abstract class CometTestBase internalCheckSparkAnswer(df, assertCometNative = false, withTol = Some(absTol)) } + /** + * Check that the query returns the correct results when Comet is enabled and that Comet + * replaced all possible operators. Use the provided `tol` when comparing floating-point + * results. + */ + protected def checkSparkAnswerAndOperatorWithTolerance( + query: String, + absTol: Double = 1e-6): (SparkPlan, SparkPlan) = { + internalCheckSparkAnswer(sql(query), assertCometNative = true, withTol = Some(absTol)) + } + /** * Check that the query returns the correct results when Comet is enabled and that Comet * replaced all possible operators except for those specified in the excluded list. From e4ba9065cccf8f46be396eb74422fe3db80bbb28 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 25 Mar 2026 15:08:18 -0700 Subject: [PATCH 2/5] fix: add allowIncompatible config to tan.sql so tolerance test runs natively --- spark/src/test/resources/sql-tests/expressions/math/tan.sql | 1 + 1 file changed, 1 insertion(+) diff --git a/spark/src/test/resources/sql-tests/expressions/math/tan.sql b/spark/src/test/resources/sql-tests/expressions/math/tan.sql index 21bd44f907..9496844804 100644 --- a/spark/src/test/resources/sql-tests/expressions/math/tan.sql +++ b/spark/src/test/resources/sql-tests/expressions/math/tan.sql @@ -16,6 +16,7 @@ -- under the License. 
-- ConfigMatrix: parquet.enable.dictionary=false,true +-- Config: spark.comet.expression.Tan.allowIncompatible=true statement CREATE TABLE test_tan(d double) USING parquet From e07dfd872111ca7f6ad99a639d72d0f114dbb021 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 25 Mar 2026 21:28:59 -0700 Subject: [PATCH 3/5] feat: add Comet native support for Logarithm (two-arg log(base, value)) Maps Spark's Logarithm expression to DataFusion's log(base, value) function. Applies nullIfNegative to both base and value to match Spark's behavior of returning null when the result would be NaN (inputs <= 0). --- .../org/apache/comet/serde/QueryPlanSerde.scala | 1 + .../main/scala/org/apache/comet/serde/math.scala | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 02a76f69f0..9322149e46 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -105,6 +105,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Log] -> CometLog, classOf[Log2] -> CometLog2, classOf[Log10] -> CometLog10, + classOf[Logarithm] -> CometLogarithm, classOf[Multiply] -> CometMultiply, classOf[Pow] -> CometScalarFunction("pow"), classOf[Rand] -> CometRand, diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala index 5a0393142a..b258ca6fae 100644 --- a/spark/src/main/scala/org/apache/comet/serde/math.scala +++ b/spark/src/main/scala/org/apache/comet/serde/math.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Tan, Unhex} +import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Tan, Unhex} import org.apache.spark.sql.types.{DecimalType, NumericType} import org.apache.comet.CometSparkSessionExtensions.withInfo @@ -138,6 +138,20 @@ object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase { } } +object CometLogarithm extends CometExpressionSerde[Logarithm] with MathExprBase { + override def convert( + expr: Logarithm, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + // Spark's Logarithm(left=base, right=value) returns null when result is NaN, + // which happens when base <= 0 or value <= 0. Apply nullIfNegative to both. 
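+    // For example, log(-2.0, 8.0) becomes log(null, 8.0) and evaluates to null,
+    // matching Spark (nullIfNegative(e) is roughly If(e <= 0, null, e)).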
+    val leftExpr = exprToProtoInternal(nullIfNegative(expr.left), inputs, binding)
+    val rightExpr = exprToProtoInternal(nullIfNegative(expr.right), inputs, binding)
+    val optExpr = scalarFunctionExprToProto("log", leftExpr, rightExpr)
+    optExprWithInfo(optExpr, expr, expr.left, expr.right)
+  }
+}
+
 object CometHex extends CometExpressionSerde[Hex] with MathExprBase {
   override def convert(
       expr: Hex,

From f081e186663fe5d455b9a5db64a567d78139b17d Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Sun, 29 Mar 2026 12:21:18 -0600
Subject: [PATCH 4/5] fix: address review feedback for SQL file tolerance tests

- Fix inaccurate comment on CometLogarithm (base/value <= 0 returns null,
  not "returns null when result is NaN")
- Fix doc reference tol -> absTol in checkSparkAnswerAndOperatorWithTolerance
- Delegate to existing checkSparkAnswerAndOperatorWithTol method
- Add <= 0 edge case test data to log.sql, log10.sql, log2.sql
- Add custom spark_log Rust UDF to handle <= 0 -> null internally, fixing
  DataFusion's broken null propagation for two-arg log
---
 native/spark-expr/src/comet_scalar_funcs.rs  |   5 +
 native/spark-expr/src/lib.rs                 |   3 +-
 native/spark-expr/src/math_funcs/log.rs      | 229 ++++++++++++++++++
 native/spark-expr/src/math_funcs/mod.rs      |   2 +
 .../scala/org/apache/comet/serde/math.scala  |  15 +-
 .../sql-tests/expressions/math/log.sql       |  10 +-
 .../sql-tests/expressions/math/log10.sql     |   2 +-
 .../sql-tests/expressions/math/log2.sql      |   2 +-
 .../org/apache/spark/sql/CometTestBase.scala |   4 +-
 9 files changed, 259 insertions(+), 13 deletions(-)
 create mode 100644 native/spark-expr/src/math_funcs/log.rs

diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index ff75de763b..1eaf0b2a97 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -18,6 +18,7 @@
 use crate::hash_funcs::*;
 use crate::math_funcs::abs::abs;
 use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
+use crate::math_funcs::log::spark_log;
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
@@ -177,6 +178,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(abs);
             make_comet_scalar_udf!("abs", func, without data_type)
         }
+        "spark_log" => {
+            let func = Arc::new(spark_log);
+            make_comet_scalar_udf!("spark_log", func, without data_type)
+        }
         "split" => {
             let func = Arc::new(crate::string_funcs::spark_split);
             make_comet_scalar_udf!("split", func, without data_type)
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index a7711d642d..159c252bc2 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -79,7 +79,8 @@ pub use hash_funcs::*;
 pub use json_funcs::{FromJson, ToJson};
 pub use math_funcs::{
     create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
+    spark_decimal_integral_div, spark_floor, spark_log, spark_make_decimal, spark_round,
+    spark_unhex,
     spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
     NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
 };
diff --git a/native/spark-expr/src/math_funcs/log.rs b/native/spark-expr/src/math_funcs/log.rs
new file mode 100644
index 0000000000..5e37a1c843
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/log.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, Float64Array};
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::physical_plan::ColumnarValue;
+use std::sync::Arc;
+
+/// Spark-compatible two-argument logarithm: `log(base, value)`.
+///
+/// Returns `log(value) / log(base)`, matching Spark's `Logarithm` expression.
+/// Returns null when `base <= 0` or `value <= 0`, matching Spark's `nullSafeEval`.
+pub fn spark_log(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    if args.len() != 2 {
+        return Err(DataFusionError::Internal(format!(
+            "spark_log requires 2 arguments, got {}",
+            args.len()
+        )));
+    }
+
+    // Spark's Logarithm: log(base, value) = ln(value) / ln(base)
+    // Returns null when base <= 0 or value <= 0
+    fn compute(base: f64, value: f64) -> Option<f64> {
+        if base <= 0.0 || value <= 0.0 {
+            None
+        } else {
+            Some(value.ln() / base.ln())
+        }
+    }
+
+    match (&args[0], &args[1]) {
+        (ColumnarValue::Array(base_arr), ColumnarValue::Array(val_arr)) => {
+            let bases = base_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for base, got {:?}",
+                        base_arr.data_type()
+                    ))
+                })?;
+            let values = val_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for value, got {:?}",
+                        val_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = bases
+                .iter()
+                .zip(values.iter())
+                .map(|(b, v)| match (b, v) {
+                    (Some(base), Some(value)) => compute(base, value),
+                    _ => None,
+                })
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Scalar(base_scalar), ColumnarValue::Array(val_arr)) => {
+            let base = match base_scalar {
+                ScalarValue::Float64(Some(b)) => *b,
+                ScalarValue::Float64(None) => {
+                    let result = Float64Array::new_null(val_arr.len());
+                    return Ok(ColumnarValue::Array(Arc::new(result)));
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalar for base, got {base_scalar:?}",
+                    )));
+                }
+            };
+            let values = val_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for value, got {:?}",
+                        val_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = values
+                .iter()
+                .map(|v| v.and_then(|value| compute(base, value)))
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Array(base_arr), ColumnarValue::Scalar(val_scalar)) => {
+            let value = match val_scalar {
+                ScalarValue::Float64(Some(v)) => *v,
+                ScalarValue::Float64(None) => {
+                    let result = Float64Array::new_null(base_arr.len());
+                    return Ok(ColumnarValue::Array(Arc::new(result)));
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalar for value, got {val_scalar:?}",
+                    )));
+                }
+            };
+            let bases = base_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for base, got {:?}",
+                        base_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = bases
+                .iter()
+                .map(|b| b.and_then(|base| compute(base, value)))
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Scalar(base_scalar), ColumnarValue::Scalar(val_scalar)) => {
+            let result = match (base_scalar, val_scalar) {
+                (ScalarValue::Float64(Some(base)), ScalarValue::Float64(Some(value))) => {
+                    ScalarValue::Float64(compute(*base, *value))
+                }
+                (ScalarValue::Float64(_), ScalarValue::Float64(_)) => {
+                    ScalarValue::Float64(None)
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalars, got {base_scalar:?} and {val_scalar:?}",
+                    )));
+                }
+            };
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
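+
+// Quick reference for the semantics above (values mirror the tests below;
+// shorthand only, since the real arguments are ColumnarValues):
+//   spark_log(10.0, 100.0) -> Some(2.0)
+//   spark_log(10.0, -1.0)  -> None       (Spark returns null rather than NaN)
+//   spark_log(1.0, 1.0)    -> Some(NaN)  (ln(1)/ln(1) = 0/0; Spark preserves NaN)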
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use arrow::array::Array;
+
+    #[test]
+    fn test_spark_log_basic() {
+        let bases = Float64Array::from(vec![10.0, 2.0, 10.0]);
+        let values = Float64Array::from(vec![100.0, 8.0, 1.0]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!((arr.value(0) - 2.0).abs() < 1e-10);
+            assert!((arr.value(1) - 3.0).abs() < 1e-10);
+            assert!((arr.value(2) - 0.0).abs() < 1e-10);
+        } else {
+            panic!("expected array result");
+        }
+    }
+
+    #[test]
+    fn test_spark_log_non_positive_returns_null() {
+        let bases = Float64Array::from(vec![Some(0.0), Some(-1.0), Some(10.0), Some(10.0)]);
+        let values = Float64Array::from(vec![Some(10.0), Some(10.0), Some(0.0), Some(-1.0)]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.is_null(0));
+            assert!(arr.is_null(1));
+            assert!(arr.is_null(2));
+            assert!(arr.is_null(3));
+        } else {
+            panic!("expected array result");
+        }
+    }
+
+    #[test]
+    fn test_spark_log_null_propagation() {
+        let bases = Float64Array::from(vec![Some(10.0), None]);
+        let values = Float64Array::from(vec![None, Some(10.0)]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.is_null(0));
+            assert!(arr.is_null(1));
+        } else {
+            panic!("expected array result");
+        }
+    }
+
+    #[test]
+    fn test_spark_log_base_one_returns_nan() {
+        // log(1, 1) = ln(1) / ln(1) = 0/0 = NaN
+        let bases = Float64Array::from(vec![1.0]);
+        let values = Float64Array::from(vec![1.0]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.value(0).is_nan());
+        } else {
+            panic!("expected array result");
+        }
+    }
+}
diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs
index 1219bc7208..f66c584e2f 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -21,6 +21,7 @@ pub(crate) mod checked_arithmetic;
 mod div;
 mod floor;
 pub mod internal;
+pub(crate) mod log;
 pub mod modulo_expr;
 mod negative;
 mod round;
@@ -33,6 +34,7 @@ pub use div::spark_decimal_div;
 pub use div::spark_decimal_integral_div;
 pub use floor::spark_floor;
 pub use internal::*;
+pub use log::spark_log;
 pub use modulo_expr::create_modulo_expr;
 pub use negative::{create_negate_expr, NegativeExpr};
 pub use round::spark_round;
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala
index b258ca6fae..45c60b8226 100644
--- a/spark/src/main/scala/org/apache/comet/serde/math.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -20,7 +20,7 @@
 package org.apache.comet.serde
 
 import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Tan, Unhex}
-import org.apache.spark.sql.types.{DecimalType, NumericType}
+import org.apache.spark.sql.types.{DecimalType, DoubleType, NumericType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
@@ -138,16 +138,17 @@ object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase {
   }
 }
 
-object CometLogarithm extends CometExpressionSerde[Logarithm] with MathExprBase {
+object CometLogarithm extends CometExpressionSerde[Logarithm] {
   override def convert(
       expr: Logarithm,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr] = {
-    // Spark's Logarithm(left=base, right=value) returns null when result is NaN,
-    // which happens when base <= 0 or value <= 0. Apply nullIfNegative to both.
-    val leftExpr = exprToProtoInternal(nullIfNegative(expr.left), inputs, binding)
-    val rightExpr = exprToProtoInternal(nullIfNegative(expr.right), inputs, binding)
-    val optExpr = scalarFunctionExprToProto("log", leftExpr, rightExpr)
+    // Uses custom spark_log UDF that returns null when base <= 0 or value <= 0,
+    // matching Spark's Logarithm.nullSafeEval behavior.
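+    // For example, log(0.0, 10.0) and log(10.0, -1.0) return null natively,
+    // while log(1.0, 1.0) still yields NaN (0/0), as exercised in log.sql.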
+    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
+    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
+    val optExpr =
+      scalarFunctionExprToProtoWithReturnType("spark_log", DoubleType, false, leftExpr, rightExpr)
     optExprWithInfo(optExpr, expr, expr.left, expr.right)
   }
 }
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log.sql b/spark/src/test/resources/sql-tests/expressions/math/log.sql
index e7420954cf..8ee5282cd6 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log(d double) USING parquet
 
 statement
-INSERT INTO test_log VALUES (1.0), (2.718281828459045), (10.0), (0.5), (NULL), (cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log VALUES (1.0), (2.718281828459045), (10.0), (0.5), (NULL), (cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT ln(d) FROM test_log
@@ -40,3 +40,11 @@ SELECT ln(1.0), ln(2.718281828459045), ln(10.0), ln(NULL)
 -- literal + literal (2-arg form)
 query tolerance=1e-6
 SELECT log(10.0, 100.0), log(2.0, 8.0), log(10.0, 1.0), log(NULL, 10.0)
+
+-- edge cases: base or value <= 0 should return null
+query tolerance=1e-6
+SELECT log(0.0, 10.0), log(-1.0, 10.0), log(10.0, 0.0), log(10.0, -1.0), log(0.0, 0.0), log(-1.0, -1.0)
+
+-- edge case: log(1, 1) produces NaN (0/0) which Spark preserves as NaN
+query tolerance=1e-6
+SELECT log(1.0, 1.0)
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log10.sql b/spark/src/test/resources/sql-tests/expressions/math/log10.sql
index 1b3c9417f1..77019c9b66 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log10.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log10.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log10(d double) USING parquet
 
 statement
-INSERT INTO test_log10 VALUES (1.0), (10.0), (100.0), (0.1), (NULL), (cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log10 VALUES (1.0), (10.0), (100.0), (0.1), (NULL), (cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT log10(d) FROM test_log10
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log2.sql b/spark/src/test/resources/sql-tests/expressions/math/log2.sql
index 5db0ca484b..01ff6f75b5 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log2.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log2.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log2(d double) USING parquet
 
 statement
-INSERT INTO test_log2 VALUES (1.0), (2.0), (4.0), (8.0), (0.5), (NULL), (cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log2 VALUES (1.0), (2.0), (4.0), (8.0), (0.5), (NULL), (cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT log2(d) FROM test_log2
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 138b073b87..a540c61d3e 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -192,13 +192,13 @@ abstract class CometTestBase
 
   /**
    * Check that the query returns the correct results when Comet is enabled and that Comet
-   * replaced all possible operators. Use the provided `tol` when comparing floating-point
+   * replaced all possible operators. Use the provided `absTol` when comparing floating-point
    * results.
    */
   protected def checkSparkAnswerAndOperatorWithTolerance(
       query: String,
       absTol: Double = 1e-6): (SparkPlan, SparkPlan) = {
-    internalCheckSparkAnswer(sql(query), assertCometNative = true, withTol = Some(absTol))
+    checkSparkAnswerAndOperatorWithTol(sql(query), absTol)
   }
 
   /**

From 2fa36873df130ce9d7fe2b6bfa7634ab1d6d3d48 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 30 Mar 2026 08:02:53 -0600
Subject: [PATCH 5/5] cargo fmt

---
 native/spark-expr/src/lib.rs            | 3 +--
 native/spark-expr/src/math_funcs/log.rs | 4 +---
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 159c252bc2..342ef73619 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -80,8 +80,7 @@ pub use json_funcs::{FromJson, ToJson};
 pub use math_funcs::{
     create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
     spark_decimal_integral_div, spark_floor, spark_log, spark_make_decimal, spark_round,
-    spark_unhex,
-    spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
+    spark_unhex, spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
     NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
 };
 pub use query_context::{create_query_context_map, QueryContext, QueryContextMap};
diff --git a/native/spark-expr/src/math_funcs/log.rs b/native/spark-expr/src/math_funcs/log.rs
index 5e37a1c843..499d4f33ed 100644
--- a/native/spark-expr/src/math_funcs/log.rs
+++ b/native/spark-expr/src/math_funcs/log.rs
@@ -133,9 +133,7 @@ pub fn spark_log(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
                 (ScalarValue::Float64(Some(base)), ScalarValue::Float64(Some(value))) => {
                     ScalarValue::Float64(compute(*base, *value))
                 }
-                (ScalarValue::Float64(_), ScalarValue::Float64(_)) => {
-                    ScalarValue::Float64(None)
-                }
+                (ScalarValue::Float64(_), ScalarValue::Float64(_)) => ScalarValue::Float64(None),
                 _ => {
                     return Err(DataFusionError::Internal(format!(
                         "spark_log expected Float64 scalars, got {base_scalar:?} and {val_scalar:?}",
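
---
Reviewer note (illustrative sketch, not part of the patch series): an
end-to-end spot check of the new native two-arg log path could look like the
snippet below in a CometTestBase-derived suite. `withParquetTable` and the
`_1` column name follow existing Comet test conventions and are assumed here,
not introduced by these patches.

    test("two-arg log runs natively with tolerance") {
      withParquetTable(Seq(2.0, 8.0, 0.0, -1.0).map(Tuple1(_)), "tbl") {
        // fails if Logarithm falls back to Spark instead of mapping to spark_log
        checkSparkAnswerAndOperatorWithTolerance("SELECT log(2.0, _1) FROM tbl")
      }
    }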