diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index b7a1e172b2c..5f1003689dd 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -162,8 +162,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging { } } - override def getDecimalArithmeticExprName(exprName: String): String = - if (!SQLConf.get.decimalOperationsAllowPrecisionLoss) { exprName + "_deny_precision_loss" } + override def getDecimalArithmeticExprName(exprName: String, allowPrecisionLoss: Boolean): String = + if (!allowPrecisionLoss) { exprName + "_deny_precision_loss" } else { exprName } /** Transform map_entries to Substrait. */ diff --git a/backends-velox/src/test/scala/org/apache/gluten/functions/MathFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/functions/MathFunctionsValidateSuite.scala index 2b2922629c7..9fa0d336a45 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/functions/MathFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/functions/MathFunctionsValidateSuite.scala @@ -413,4 +413,30 @@ abstract class MathFunctionsValidateSuite extends FunctionsValidateSuite { } } } + + test("decimal arithmetic respects allowPrecisionLoss captured at view analysis time") { + // Regression test for GLUTEN-11917: in Spark 4.1, arithmetic expressions embed + // allowPrecisionLoss in their evalContext at analysis time. Gluten must read from + // the expression rather than SQLConf.get, which can differ when querying a view + // analyzed under a different session config. + withTempView("t", "v") { + sql(""" + |SELECT + |CAST('1234567890123456789012345.12345678901' AS DECIMAL(38,11)) AS a, + |CAST('1234567890123456789012345.02345678901' AS DECIMAL(38,11)) AS b""".stripMargin) + .createOrReplaceTempView("t") + + // Analyze arithmetic with allowPrecisionLoss=false and cache it in the view's plan. + withSQLConf("spark.sql.decimalOperations.allowPrecisionLoss" -> "false") { + sql("CREATE OR REPLACE TEMP VIEW v AS SELECT a - b, a + b, a * b, a / b FROM t") + } + + // Query under the opposite setting -- Gluten must use the captured context, not SQLConf. + withSQLConf("spark.sql.decimalOperations.allowPrecisionLoss" -> "true") { + runQueryAndCompare("SELECT * FROM v") { + checkGlutenPlan[ProjectExecTransformer] + } + } + } + } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index ce0a79f0bc2..84e2d865541 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -264,7 +264,10 @@ trait SparkPlanExecApi { GenericExpressionTransformer(substraitExprName, Seq(left, right), original) } - def getDecimalArithmeticExprName(exprName: String): String = exprName + // Default: ignore allowPrecisionLoss and return exprName unchanged. Non-Velox backends + // (e.g. ClickHouse) do not use the _deny_precision_loss naming convention; they handle + // decimal precision through their own mechanisms. VeloxSparkPlanExecApi overrides this. + def getDecimalArithmeticExprName(exprName: String, allowPrecisionLoss: Boolean): String = exprName /** Transform map_entries to Substrait. */ def genMapEntriesTransformer( diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index c50bae6e77b..7969d305025 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -647,7 +647,8 @@ object ExpressionConverter extends SQLConfHelper with Logging { DecimalArithmeticUtil.isDecimalArithmetic(b) => val arithmeticExprName = BackendsApiManager.getSparkPlanExecApiInstance.getDecimalArithmeticExprName( - getAndCheckSubstraitName(b, expressionsMap)) + getAndCheckSubstraitName(b, expressionsMap), + SparkShimLoader.getSparkShims.decimalAllowPrecisionLoss(b)) val left = replaceWithExpressionTransformer0(b.left, attributeSeq, expressionsMap) val right = @@ -664,7 +665,8 @@ object ExpressionConverter extends SQLConfHelper with Logging { ) case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) => val exprName = BackendsApiManager.getSparkPlanExecApiInstance.getDecimalArithmeticExprName( - substraitExprName) + substraitExprName, + SparkShimLoader.getSparkShims.decimalAllowPrecisionLoss(b)) if (!BackendsApiManager.getSettings.transformCheckOverflow) { GenericExpressionTransformer( exprName, diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala b/gluten-substrait/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala index df5ed47e838..893cbdd0f5f 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala @@ -20,7 +20,6 @@ import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType} import org.apache.spark.sql.utils.DecimalTypeUtil @@ -33,7 +32,7 @@ object DecimalArithmeticUtil { // Returns the result decimal type of a decimal arithmetic computing. def getResultType(expr: BinaryArithmetic, type1: DecimalType, type2: DecimalType): DecimalType = { - val allowPrecisionLoss = SQLConf.get.decimalOperationsAllowPrecisionLoss + val allowPrecisionLoss = SparkShimLoader.getSparkShims.decimalAllowPrecisionLoss(expr) var resultScale = 0 var resultPrecision = 0 expr match { diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala index 4d1fd804a9c..b230c9e927b 100644 --- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala +++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RaiseError, UnBase64} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryArithmetic, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RaiseError, UnBase64} import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -250,6 +250,12 @@ trait SparkShims { def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType + // Spark 4.1+ (SPARK-53968) embeds allowDecimalPrecisionLoss in each arithmetic expression's + // evalContext at analysis time. Spark41Shims overrides this to read from the expression. + // All earlier versions have no evalContext field, so reading SQLConf.get here is correct. + def decimalAllowPrecisionLoss(expr: BinaryArithmetic): Boolean = + SQLConf.get.decimalOperationsAllowPrecisionLoss + def getRewriteCreateTableAsSelect(session: SparkSession): SparkStrategy = _ => Seq.empty /** Shim method for get the "errorMessage" value for Spark 4.0 and above */ diff --git a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala index 3eabf6b595a..44665a9db06 100644 --- a/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala +++ b/shims/spark41/src/main/scala/org/apache/gluten/sql/shims/spark41/Spark41Shims.scala @@ -610,6 +610,17 @@ class Spark41Shims extends SparkShims { DecimalPrecisionTypeCoercion.widerDecimalType(d1, d2) } + override def decimalAllowPrecisionLoss(expr: BinaryArithmetic): Boolean = expr match { + case a: Add => a.evalContext.allowDecimalPrecisionLoss + case s: Subtract => s.evalContext.allowDecimalPrecisionLoss + case m: Multiply => m.evalContext.allowDecimalPrecisionLoss + case d: Divide => d.evalContext.allowDecimalPrecisionLoss + // Remainder and Pmod do not carry evalContext in Spark 4.1. They also throw + // GlutenNotSupportException in DecimalArithmeticUtil.getResultType, so they never + // reach Velox execution; SQLConf.get is a safe fallback for the name-lookup path. + case _ => SQLConf.get.decimalOperationsAllowPrecisionLoss + } + override def getErrorMessage(raiseError: RaiseError): Option[Expression] = { raiseError.errorParms match { case CreateMap(children, _)