diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index ff75de763b..1eaf0b2a97 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -18,6 +18,7 @@
 use crate::hash_funcs::*;
 use crate::math_funcs::abs::abs;
 use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
+use crate::math_funcs::log::spark_log;
 use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
@@ -177,6 +178,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(abs);
             make_comet_scalar_udf!("abs", func, without data_type)
         }
+        "spark_log" => {
+            let func = Arc::new(spark_log);
+            make_comet_scalar_udf!("spark_log", func, without data_type)
+        }
         "split" => {
             let func = Arc::new(crate::string_funcs::spark_split);
             make_comet_scalar_udf!("split", func, without data_type)
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index a7711d642d..342ef73619 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -79,8 +79,8 @@ pub use hash_funcs::*;
 pub use json_funcs::{FromJson, ToJson};
 pub use math_funcs::{
     create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
-    spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
-    spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
+    spark_decimal_integral_div, spark_floor, spark_log, spark_make_decimal, spark_round,
+    spark_unhex, spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
     NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
 };
 pub use query_context::{create_query_context_map, QueryContext, QueryContextMap};
diff --git a/native/spark-expr/src/math_funcs/log.rs b/native/spark-expr/src/math_funcs/log.rs
new file mode 100644
index 0000000000..499d4f33ed
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/log.rs
@@ -0,0 +1,227 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{Array, Float64Array};
+use datafusion::common::{DataFusionError, ScalarValue};
+use datafusion::physical_plan::ColumnarValue;
+use std::sync::Arc;
+
+/// Spark-compatible two-argument logarithm: `log(base, value)`.
+///
+/// Returns `log(value) / log(base)`, matching Spark's `Logarithm` expression.
+/// Returns null when `base <= 0` or `value <= 0`, matching Spark's `nullSafeEval`.
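+///
+/// Illustrative values (for the reader, not a doctest): `log(10.0, 100.0)`
+/// evaluates to `ln(100.0) / ln(10.0) = 2.0`, `log(2.0, 8.0)` to `3.0`, and
+/// `log(0.0, 10.0)` to null.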
+pub fn spark_log(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    if args.len() != 2 {
+        return Err(DataFusionError::Internal(format!(
+            "spark_log requires 2 arguments, got {}",
+            args.len()
+        )));
+    }
+
+    // Spark's Logarithm: log(base, value) = ln(value) / ln(base)
+    // Returns null when base <= 0 or value <= 0
+    fn compute(base: f64, value: f64) -> Option<f64> {
+        if base <= 0.0 || value <= 0.0 {
+            None
+        } else {
+            Some(value.ln() / base.ln())
+        }
+    }
+
+    match (&args[0], &args[1]) {
+        (ColumnarValue::Array(base_arr), ColumnarValue::Array(val_arr)) => {
+            let bases = base_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for base, got {:?}",
+                        base_arr.data_type()
+                    ))
+                })?;
+            let values = val_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for value, got {:?}",
+                        val_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = bases
+                .iter()
+                .zip(values.iter())
+                .map(|(b, v)| match (b, v) {
+                    (Some(base), Some(value)) => compute(base, value),
+                    _ => None,
+                })
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Scalar(base_scalar), ColumnarValue::Array(val_arr)) => {
+            let base = match base_scalar {
+                ScalarValue::Float64(Some(b)) => *b,
+                ScalarValue::Float64(None) => {
+                    let result = Float64Array::new_null(val_arr.len());
+                    return Ok(ColumnarValue::Array(Arc::new(result)));
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalar for base, got {base_scalar:?}",
+                    )));
+                }
+            };
+            let values = val_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for value, got {:?}",
+                        val_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = values
+                .iter()
+                .map(|v| v.and_then(|value| compute(base, value)))
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Array(base_arr), ColumnarValue::Scalar(val_scalar)) => {
+            let value = match val_scalar {
+                ScalarValue::Float64(Some(v)) => *v,
+                ScalarValue::Float64(None) => {
+                    let result = Float64Array::new_null(base_arr.len());
+                    return Ok(ColumnarValue::Array(Arc::new(result)));
+                }
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalar for value, got {val_scalar:?}",
+                    )));
+                }
+            };
+            let bases = base_arr
+                .as_any()
+                .downcast_ref::<Float64Array>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal(format!(
+                        "spark_log expected Float64 for base, got {:?}",
+                        base_arr.data_type()
+                    ))
+                })?;
+            let result: Float64Array = bases
+                .iter()
+                .map(|b| b.and_then(|base| compute(base, value)))
+                .collect();
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+        (ColumnarValue::Scalar(base_scalar), ColumnarValue::Scalar(val_scalar)) => {
+            let result = match (base_scalar, val_scalar) {
+                (ScalarValue::Float64(Some(base)), ScalarValue::Float64(Some(value))) => {
+                    ScalarValue::Float64(compute(*base, *value))
+                }
+                (ScalarValue::Float64(_), ScalarValue::Float64(_)) => ScalarValue::Float64(None),
+                _ => {
+                    return Err(DataFusionError::Internal(format!(
+                        "spark_log expected Float64 scalars, got {base_scalar:?} and {val_scalar:?}",
+                    )));
+                }
+            };
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use arrow::array::Array;
+
+    #[test]
+    fn test_spark_log_basic() {
+        let bases = Float64Array::from(vec![10.0, 2.0, 10.0]);
+        let values = Float64Array::from(vec![100.0, 8.0, 1.0]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!((arr.value(0) - 2.0).abs() < 1e-10);
+            assert!((arr.value(1) - 3.0).abs() < 1e-10);
+            assert!((arr.value(2) - 0.0).abs() < 1e-10);
+        } else {
+            panic!("expected array result");
+        }
+    }
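+
+    // Illustrative sketch: exercises the scalar-base broadcast path, which the
+    // surrounding tests do not cover; expected values follow from
+    // ln(value) / ln(base), so log(2, 8) = 3 and log(2, 2) = 1.
+    #[test]
+    fn test_spark_log_scalar_base_broadcast() {
+        let values = Float64Array::from(vec![Some(8.0), Some(2.0), None]);
+        let result = spark_log(&[
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!((arr.value(0) - 3.0).abs() < 1e-10);
+            assert!((arr.value(1) - 1.0).abs() < 1e-10);
+            assert!(arr.is_null(2));
+        } else {
+            panic!("expected array result");
+        }
+    }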
+
+    #[test]
+    fn test_spark_log_non_positive_returns_null() {
+        let bases = Float64Array::from(vec![Some(0.0), Some(-1.0), Some(10.0), Some(10.0)]);
+        let values = Float64Array::from(vec![Some(10.0), Some(10.0), Some(0.0), Some(-1.0)]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.is_null(0));
+            assert!(arr.is_null(1));
+            assert!(arr.is_null(2));
+            assert!(arr.is_null(3));
+        } else {
+            panic!("expected array result");
+        }
+    }
+
+    #[test]
+    fn test_spark_log_null_propagation() {
+        let bases = Float64Array::from(vec![Some(10.0), None]);
+        let values = Float64Array::from(vec![None, Some(10.0)]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.is_null(0));
+            assert!(arr.is_null(1));
+        } else {
+            panic!("expected array result");
+        }
+    }
+
+    #[test]
+    fn test_spark_log_base_one_returns_nan() {
+        // log(1, 1) = ln(1) / ln(1) = 0/0 = NaN
+        let bases = Float64Array::from(vec![1.0]);
+        let values = Float64Array::from(vec![1.0]);
+        let result = spark_log(&[
+            ColumnarValue::Array(Arc::new(bases)),
+            ColumnarValue::Array(Arc::new(values)),
+        ])
+        .unwrap();
+        if let ColumnarValue::Array(arr) = result {
+            let arr = arr.as_any().downcast_ref::<Float64Array>().unwrap();
+            assert!(arr.value(0).is_nan());
+        } else {
+            panic!("expected array result");
+        }
+    }
+}
diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs
index 1219bc7208..f66c584e2f 100644
--- a/native/spark-expr/src/math_funcs/mod.rs
+++ b/native/spark-expr/src/math_funcs/mod.rs
@@ -21,6 +21,7 @@ pub(crate) mod checked_arithmetic;
 mod div;
 mod floor;
 pub mod internal;
+pub(crate) mod log;
 pub mod modulo_expr;
 mod negative;
 mod round;
@@ -33,6 +34,7 @@ pub use div::spark_decimal_div;
 pub use div::spark_decimal_integral_div;
 pub use floor::spark_floor;
 pub use internal::*;
+pub use log::spark_log;
 pub use modulo_expr::create_modulo_expr;
 pub use negative::{create_negate_expr, NegativeExpr};
 pub use round::spark_round;
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 02a76f69f0..9322149e46 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -105,6 +105,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[Log] -> CometLog,
     classOf[Log2] -> CometLog2,
     classOf[Log10] -> CometLog10,
+    classOf[Logarithm] -> CometLogarithm,
     classOf[Multiply] -> CometMultiply,
     classOf[Pow] -> CometScalarFunction("pow"),
     classOf[Rand] -> CometRand,
diff --git a/spark/src/main/scala/org/apache/comet/serde/math.scala b/spark/src/main/scala/org/apache/comet/serde/math.scala
index 5a0393142a..45c60b8226 100644
--- a/spark/src/main/scala/org/apache/comet/serde/math.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/math.scala
@@ -19,8 +19,8 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Tan, Unhex}
-import org.apache.spark.sql.types.{DecimalType, NumericType}
+import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Tan, Unhex}
+import org.apache.spark.sql.types.{DecimalType, DoubleType, NumericType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
@@ -138,6 +138,21 @@ object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase {
   }
 }
 
+object CometLogarithm extends CometExpressionSerde[Logarithm] {
+  override def convert(
+      expr: Logarithm,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    // Uses custom spark_log UDF that returns null when base <= 0 or value <= 0,
+    // matching Spark's Logarithm.nullSafeEval behavior.
+    val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
+    val rightExpr = exprToProtoInternal(expr.right, inputs, binding)
+    val optExpr =
+      scalarFunctionExprToProtoWithReturnType("spark_log", DoubleType, false, leftExpr, rightExpr)
+    optExprWithInfo(optExpr, expr, expr.left, expr.right)
+  }
+}
+
 object CometHex extends CometExpressionSerde[Hex] with MathExprBase {
   override def convert(
       expr: Hex,
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log.sql b/spark/src/test/resources/sql-tests/expressions/math/log.sql
index e7420954cf..8ee5282cd6 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log(d double) USING parquet
 
 statement
-INSERT INTO test_log VALUES (1.0), (2.718281828459045), (10.0), (0.5), (NULL), (cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log VALUES (1.0), (2.718281828459045), (10.0), (0.5), (NULL), (cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT ln(d) FROM test_log
@@ -40,3 +40,11 @@ SELECT ln(1.0), ln(2.718281828459045), ln(10.0), ln(NULL)
 -- literal + literal (2-arg form)
 query tolerance=1e-6
 SELECT log(10.0, 100.0), log(2.0, 8.0), log(10.0, 1.0), log(NULL, 10.0)
+
+-- edge cases: base or value <= 0 should return null
+query tolerance=1e-6
+SELECT log(0.0, 10.0), log(-1.0, 10.0), log(10.0, 0.0), log(10.0, -1.0), log(0.0, 0.0), log(-1.0, -1.0)
+
+-- edge case: log(1, 1) produces NaN (0/0) which Spark preserves as NaN
+query tolerance=1e-6
+SELECT log(1.0, 1.0)
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log10.sql b/spark/src/test/resources/sql-tests/expressions/math/log10.sql
index 1b3c9417f1..77019c9b66 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log10.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log10.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log10(d double) USING parquet
 
 statement
-INSERT INTO test_log10 VALUES (1.0), (10.0), (100.0), (0.1), (NULL), (cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log10 VALUES (1.0), (10.0), (100.0), (0.1), (NULL), (cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT log10(d) FROM test_log10
diff --git a/spark/src/test/resources/sql-tests/expressions/math/log2.sql b/spark/src/test/resources/sql-tests/expressions/math/log2.sql
index 5db0ca484b..01ff6f75b5 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/log2.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/log2.sql
@@ -21,7 +21,7 @@ statement
 CREATE TABLE test_log2(d double) USING parquet
 
 statement
-INSERT INTO test_log2 VALUES (1.0), (2.0), (4.0), (8.0), (0.5), (NULL), (cast('NaN' as double)), (cast('Infinity' as double))
+INSERT INTO test_log2 VALUES (1.0), (2.0), (4.0), (8.0), (0.5), (NULL), (cast('NaN' as double)), (cast('Infinity' as double)), (0.0), (-1.0)
 
 query tolerance=1e-6
 SELECT log2(d) FROM test_log2
diff --git a/spark/src/test/resources/sql-tests/expressions/math/tan.sql b/spark/src/test/resources/sql-tests/expressions/math/tan.sql
index 21bd44f907..9496844804 100644
--- a/spark/src/test/resources/sql-tests/expressions/math/tan.sql
+++ b/spark/src/test/resources/sql-tests/expressions/math/tan.sql
@@ -16,6 +16,7 @@
 -- under the License.
 
 -- ConfigMatrix: parquet.enable.dictionary=false,true
+-- Config: spark.comet.expression.Tan.allowIncompatible=true
 
 statement
 CREATE TABLE test_tan(d double) USING parquet
diff --git a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
index 020759a7a6..5a0b34e056 100644
--- a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala
@@ -102,7 +102,7 @@ class CometSqlFileTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
           case SparkAnswerOnly =>
             checkSparkAnswer(sql)
           case WithTolerance(tol) =>
-            checkSparkAnswerWithTolerance(sql, tol)
+            checkSparkAnswerAndOperatorWithTolerance(sql, tol)
           case ExpectFallback(reason) =>
             checkSparkAnswerAndFallbackReason(sql, reason)
           case Ignore(reason) =>
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 33c1d444b9..a540c61d3e 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -190,6 +190,17 @@ abstract class CometTestBase
     internalCheckSparkAnswer(df, assertCometNative = false, withTol = Some(absTol))
   }
 
+  /**
+   * Check that the query returns the correct results when Comet is enabled and that Comet
+   * replaced all possible operators. Use the provided `absTol` when comparing floating-point
+   * results.
+   */
+  protected def checkSparkAnswerAndOperatorWithTolerance(
+      query: String,
+      absTol: Double = 1e-6): (SparkPlan, SparkPlan) = {
+    checkSparkAnswerAndOperatorWithTol(sql(query), absTol)
+  }
+
   /**
    * Check that the query returns the correct results when Comet is enabled and that Comet
    * replaced all possible operators except for those specified in the excluded list.
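As a usage sketch (hypothetical suite and table names, not part of this diff), a test extending CometTestBase could exercise the two-argument log end-to-end through the new helper, which checks both result equality within the tolerance and full native execution:

    class CometLogarithmExampleSuite extends CometTestBase {
      test("two-arg log matches Spark within tolerance") {
        withTable("t") {
          sql("CREATE TABLE t(d double) USING parquet")
          sql("INSERT INTO t VALUES (8.0), (1.0), (0.0), (-1.0), (NULL)")
          // Rows with d <= 0 should come back null from both Spark and Comet.
          checkSparkAnswerAndOperatorWithTolerance("SELECT log(2.0, d) FROM t", 1e-6)
        }
      }
    }

Here `withTable` and `sql` are assumed to be available from the Spark test utilities mixed into CometTestBase.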