Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi with Logging {
}
}

override def getDecimalArithmeticExprName(exprName: String): String =
if (!SQLConf.get.decimalOperationsAllowPrecisionLoss) { exprName + "_deny_precision_loss" }
override def getDecimalArithmeticExprName(exprName: String, allowPrecisionLoss: Boolean): String =
if (!allowPrecisionLoss) { exprName + "_deny_precision_loss" }
else { exprName }

/** Transform map_entries to Substrait. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,4 +413,30 @@ abstract class MathFunctionsValidateSuite extends FunctionsValidateSuite {
}
}
}

test("decimal arithmetic respects allowPrecisionLoss captured at view analysis time") {
// Regression test for GLUTEN-11917: in Spark 4.1, arithmetic expressions embed
// allowPrecisionLoss in their evalContext at analysis time. Gluten must read from
// the expression rather than SQLConf.get, which can differ when querying a view
// analyzed under a different session config.
withTempView("t", "v") {
sql("""
|SELECT
|CAST('1234567890123456789012345.12345678901' AS DECIMAL(38,11)) AS a,
|CAST('1234567890123456789012345.02345678901' AS DECIMAL(38,11)) AS b""".stripMargin)
.createOrReplaceTempView("t")

// Analyze arithmetic with allowPrecisionLoss=false and cache it in the view's plan.
withSQLConf("spark.sql.decimalOperations.allowPrecisionLoss" -> "false") {
sql("CREATE OR REPLACE TEMP VIEW v AS SELECT a - b, a + b, a * b, a / b FROM t")
}

// Query under the opposite setting -- Gluten must use the captured context, not SQLConf.
withSQLConf("spark.sql.decimalOperations.allowPrecisionLoss" -> "true") {
runQueryAndCompare("SELECT * FROM v") {
checkGlutenPlan[ProjectExecTransformer]
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,10 @@ trait SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(left, right), original)
}

def getDecimalArithmeticExprName(exprName: String): String = exprName
// Default: ignore allowPrecisionLoss and return exprName unchanged. Non-Velox backends
// (e.g. ClickHouse) do not use the _deny_precision_loss naming convention; they handle
// decimal precision through their own mechanisms. VeloxSparkPlanExecApi overrides this.
def getDecimalArithmeticExprName(exprName: String, allowPrecisionLoss: Boolean): String = exprName

/** Transform map_entries to Substrait. */
def genMapEntriesTransformer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,8 @@ object ExpressionConverter extends SQLConfHelper with Logging {
DecimalArithmeticUtil.isDecimalArithmetic(b) =>
val arithmeticExprName =
BackendsApiManager.getSparkPlanExecApiInstance.getDecimalArithmeticExprName(
getAndCheckSubstraitName(b, expressionsMap))
getAndCheckSubstraitName(b, expressionsMap),
SparkShimLoader.getSparkShims.decimalAllowPrecisionLoss(b))
val left =
replaceWithExpressionTransformer0(b.left, attributeSeq, expressionsMap)
val right =
Expand All @@ -664,7 +665,8 @@ object ExpressionConverter extends SQLConfHelper with Logging {
)
case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
val exprName = BackendsApiManager.getSparkPlanExecApiInstance.getDecimalArithmeticExprName(
substraitExprName)
substraitExprName,
SparkShimLoader.getSparkShims.decimalAllowPrecisionLoss(b))
if (!BackendsApiManager.getSettings.transformCheckOverflow) {
GenericExpressionTransformer(
exprName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType}
import org.apache.spark.sql.utils.DecimalTypeUtil

Expand All @@ -33,7 +32,7 @@ object DecimalArithmeticUtil {
// Returns the result decimal type of a decimal arithmetic computing.
def getResultType(expr: BinaryArithmetic, type1: DecimalType, type2: DecimalType): DecimalType = {

val allowPrecisionLoss = SQLConf.get.decimalOperationsAllowPrecisionLoss
val allowPrecisionLoss = SparkShimLoader.getSparkShims.decimalAllowPrecisionLoss(expr)
var resultScale = 0
var resultPrecision = 0
expr match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RaiseError, UnBase64}
import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryArithmetic, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RaiseError, UnBase64}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
Expand Down Expand Up @@ -250,6 +250,12 @@ trait SparkShims {

def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType

// Spark 4.1+ (SPARK-53968) embeds allowDecimalPrecisionLoss in each arithmetic expression's
// evalContext at analysis time. Spark41Shims overrides this to read from the expression.
// All earlier versions have no evalContext field, so reading SQLConf.get here is correct.
def decimalAllowPrecisionLoss(expr: BinaryArithmetic): Boolean =
SQLConf.get.decimalOperationsAllowPrecisionLoss

def getRewriteCreateTableAsSelect(session: SparkSession): SparkStrategy = _ => Seq.empty

/** Shim method for get the "errorMessage" value for Spark 4.0 and above */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,17 @@ class Spark41Shims extends SparkShims {
DecimalPrecisionTypeCoercion.widerDecimalType(d1, d2)
}

override def decimalAllowPrecisionLoss(expr: BinaryArithmetic): Boolean = expr match {
case a: Add => a.evalContext.allowDecimalPrecisionLoss
case s: Subtract => s.evalContext.allowDecimalPrecisionLoss
case m: Multiply => m.evalContext.allowDecimalPrecisionLoss
case d: Divide => d.evalContext.allowDecimalPrecisionLoss
// Remainder and Pmod do not carry evalContext in Spark 4.1. They also throw
// GlutenNotSupportException in DecimalArithmeticUtil.getResultType, so they never
// reach Velox execution; SQLConf.get is a safe fallback for the name-lookup path.
case _ => SQLConf.get.decimalOperationsAllowPrecisionLoss
}

override def getErrorMessage(raiseError: RaiseError): Option[Expression] = {
raiseError.errorParms match {
case CreateMap(children, _)
Expand Down
Loading