Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.gluten.extension.caller.CallerInfo
import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
import org.apache.gluten.extension.columnar.FallbackTags
import org.apache.gluten.extension.columnar.offload.OffloadSingleNode
import org.apache.gluten.extension.columnar.offload.OffloadSingleNode._
import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode
import org.apache.gluten.extension.columnar.validator.Validator
import org.apache.gluten.extension.injector.Injector
Expand Down Expand Up @@ -81,7 +82,7 @@ object HeuristicTransform {
node =>
validator.validate(node) match {
case Validator.Passed =>
rule.offload(node)
rule.offloadAndPropagateTag(node)
case Validator.Failed(reason) =>
logDebug(s"Validation failed by reason: $reason on query plan: ${node.nodeName}")
if (FallbackTags.maybeOffloadable(node)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.extension.columnar.heuristic

import org.apache.gluten.extension.columnar.offload.OffloadSingleNode
import org.apache.gluten.extension.columnar.offload.OffloadSingleNode._
import org.apache.gluten.logging.LogLevelUtil

import org.apache.spark.sql.catalyst.rules.Rule
Expand All @@ -25,7 +26,9 @@ import org.apache.spark.sql.execution.SparkPlan
/**
 * Legacy offload pass: applies each [[OffloadSingleNode]] rule over the whole plan bottom-up.
 * Each node is rewritten through `offloadAndPropagateTag`, which also carries the original
 * node's LOGICAL_PLAN_TAG over to the node the rule produced.
 */
class LegacyOffload(rules: Seq[OffloadSingleNode]) extends Rule[SparkPlan] with LogLevelUtil {
  def apply(plan: SparkPlan): SparkPlan = {
    // Rules are applied in declaration order; each rule sees the output of the previous one.
    val out =
      rules.foldLeft(plan) {
        (p, rule) => p.transformUp { case node => rule.offloadAndPropagateTag(node) }
      }
    out
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,21 @@ trait OffloadSingleNode extends Logging {
object OffloadSingleNode {
implicit class OffloadSingleNodeOps(rule: OffloadSingleNode) {

/**
 * Runs `rule.offload` on `node`; when the rule produced a different node, copies the
 * LOGICAL_PLAN_TAG from the original node onto the replacement (non-recursive). The tag is
 * written with setTagValue directly rather than setLogicalLink, since setLogicalLink would
 * recursively propagate the link to children and incorrectly tag Exchange nodes.
 */
def offloadAndPropagateTag(node: SparkPlan): SparkPlan = {
  val result = rule.offload(node)
  val replaced = result ne node
  if (replaced) {
    for (logicalPlan <- node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG)) {
      result.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
    }
  }
  result
}

/**
* Converts the [[OffloadSingleNode]] rule to a strict version.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ object PushDownFilterToScan extends Rule[SparkPlan] with PredicateHelper {
scan) && scan.supportPushDownFilters =>
val newScan = scan.withNewPushdownFilters(splitConjunctivePredicates(filter.cond))
if (newScan.doValidate().ok()) {
newScan.copyTagsFrom(scan)
filter.withNewChildren(Seq(newScan))
} else {
filter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ class VeloxTestSettings extends BackendTestSettings {
enableSuite[GlutenHiveResultSuite]
// TODO: 4.x enableSuite[GlutenInsertSortForLimitAndOffsetSuite] // 6 failures
enableSuite[GlutenLocalTempViewTestSuite]
// TODO: 4.x enableSuite[GlutenLogicalPlanTagInSparkPlanSuite] // RUN ABORTED
enableSuite[GlutenLogicalPlanTagInSparkPlanSuite]
enableSuite[GlutenOptimizeMetadataOnlyQuerySuite]
enableSuite[GlutenPersistedViewTestSuite]
// TODO: 4.x enableSuite[GlutenPlannerSuite] // 1 failure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,176 @@
*/
package org.apache.spark.sql.execution

import org.apache.gluten.execution._

import org.apache.spark.sql.GlutenSQLTestsTrait
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec

import scala.reflect.ClassTag

/**
 * Gluten variant of Spark's [[LogicalPlanTagInSparkPlanSuite]]. It re-checks the
 * LOGICAL_PLAN_TAG links on the physical plan with patterns that also recognize Gluten's
 * Transformer/columnar operators, which the parent suite's Spark-only patterns would miss.
 */
class GlutenLogicalPlanTagInSparkPlanSuite
  extends LogicalPlanTagInSparkPlanSuite
  with GlutenSQLTestsTrait {

  // Override to use Gluten-aware logical plan tag checking.
  // Gluten replaces Spark physical operators with Transformer nodes that don't match
  // the original Spark pattern matching in LogicalPlanTagInSparkPlanSuite.
  override protected def checkGeneratedCode(
      plan: SparkPlan,
      checkMethodCodeSize: Boolean = true): Unit = {
    // Skip parent's codegen check (Gluten doesn't use WholeStageCodegen).
    // Only run the Gluten-aware logical plan tag check.
    checkGlutenLogicalPlanTag(plan)
  }

  /** True when there is at least one aggregate expression and all are Complete or Final mode. */
  private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
    aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final)
  }

  /**
   * Recursively asserts that each physical node carries (or deliberately lacks) the expected
   * logical plan tag. Recursion covers both children and subqueries, except below a scan plan
   * tree, which is validated as a unit and then skipped (see the early `return`).
   */
  private def checkGlutenLogicalPlanTag(plan: SparkPlan): Unit = {
    plan match {
      // Joins (Gluten + Spark)
      case _: BroadcastHashJoinExecTransformerBase | _: ShuffledHashJoinExecTransformerBase |
          _: SortMergeJoinExecTransformerBase | _: CartesianProductExecTransformer |
          _: BroadcastNestedLoopJoinExecTransformer | _: joins.BroadcastHashJoinExec |
          _: joins.ShuffledHashJoinExec | _: joins.SortMergeJoinExec |
          _: joins.BroadcastNestedLoopJoinExec | _: joins.CartesianProductExec =>
        assertLogicalPlanType[Join](plan)

      // Aggregates - only final (Gluten + Spark)
      case agg: HashAggregateExecBaseTransformer if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)
      case agg: aggregate.HashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)
      case agg: aggregate.ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)
      case agg: aggregate.SortAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
        assertLogicalPlanType[Aggregate](plan)

      // Window
      case _: WindowExecTransformer | _: window.WindowExec =>
        assertLogicalPlanType[Window](plan)

      // Union
      case _: ColumnarUnionExec | _: UnionExec =>
        assertLogicalPlanType[Union](plan)

      // Sample
      case _: SampleExec =>
        assertLogicalPlanType[Sample](plan)

      // Generate
      case _: GenerateExecTransformerBase | _: GenerateExec =>
        assertLogicalPlanType[Generate](plan)

      // Exchange nodes should NOT have logical plan tags
      case _: ColumnarShuffleExchangeExec | _: ColumnarBroadcastExchangeExec |
          _: exchange.ShuffleExchangeExec | _: exchange.BroadcastExchangeExec |
          _: ReusedExchangeExec =>
        assert(
          plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty,
          s"${plan.getClass.getSimpleName} should not have a logical plan tag")

      // Subquery exec nodes don't have logical plan tags
      case _: SubqueryExec | _: ReusedSubqueryExec =>
        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)

      // Gluten infrastructure nodes (no corresponding logical plan)
      case _: WholeStageTransformer | _: InputIteratorTransformer | _: ColumnarInputAdapter |
          _: VeloxResizeBatchesExec =>
      // These are Gluten-specific wrapper nodes without logical plan links.

      // Scan trees
      case _ if isGlutenScanPlanTree(plan) =>
        // For scan plan trees (leaf under Project/Filter), we check that the leaf node
        // has a correct logical plan link. The intermediate Project/Filter nodes may not
        // have tags if they were created by Gluten's rewrite rules.
        val physicalLeaves = plan.collectLeaves()
        assert(
          physicalLeaves.length == 1,
          s"Expected 1 physical leaf, got ${physicalLeaves.length}")

        val leafNode = physicalLeaves.head
        // Find the logical plan from the leaf or any ancestor with a tag
        val logicalPlanOpt = leafNode
          .getTagValue(SparkPlan.LOGICAL_PLAN_TAG)
          .orElse(leafNode.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG))
          .orElse(findLogicalPlanInTree(plan))

        logicalPlanOpt.foreach {
          lp =>
            val logicalPlan = lp match {
              case w: WithCTE => w.plan
              case o => o
            }
            val logicalLeaves = logicalPlan.collectLeaves()
            assert(
              logicalLeaves.length == 1,
              s"Expected 1 logical leaf, got ${logicalLeaves.length}")
            physicalLeaves.head match {
              case _: RangeExec => assert(logicalLeaves.head.isInstanceOf[Range])
              case _: DataSourceScanExec | _: BasicScanExecTransformer =>
                assert(logicalLeaves.head.isInstanceOf[LogicalRelation])
              case _: InMemoryTableScanExec =>
                assert(logicalLeaves.head.isInstanceOf[columnar.InMemoryRelation])
              case _: LocalTableScanExec => assert(logicalLeaves.head.isInstanceOf[LocalRelation])
              case _: ExternalRDDScanExec[_] =>
                assert(logicalLeaves.head.isInstanceOf[ExternalRDD[_]])
              case _: datasources.v2.BatchScanExec =>
                assert(logicalLeaves.head.isInstanceOf[DataSourceV2Relation])
              case _ =>
            }
        }
        // The scan tree was validated as a unit; skip the per-child recursion below.
        return

      case _ =>
    }

    plan.children.foreach(checkGlutenLogicalPlanTag)
    plan.subqueries.foreach(checkGlutenLogicalPlanTag)
  }

  /**
   * True when `plan` is a "scan plan tree": a single leaf node under an optional stack of
   * Project/Filter (Spark or Gluten) nodes, possibly behind a ColumnarToRowExec adapter.
   */
  @scala.annotation.tailrec
  private def isGlutenScanPlanTree(plan: SparkPlan): Boolean = plan match {
    case ColumnarToRowExec(i: InputAdapter) => isGlutenScanPlanTree(i.child)
    case p: ProjectExec => isGlutenScanPlanTree(p.child)
    case p: ProjectExecTransformer => isGlutenScanPlanTree(p.child)
    case f: FilterExec => isGlutenScanPlanTree(f.child)
    case f: FilterExecTransformerBase => isGlutenScanPlanTree(f.child)
    case _: LeafExecNode => true
    case _ => false
  }

  /** Find any node in the tree that has a LOGICAL_PLAN_TAG. */
  private def findLogicalPlanInTree(plan: SparkPlan): Option[LogicalPlan] = {
    plan
      .getTagValue(SparkPlan.LOGICAL_PLAN_TAG)
      .orElse(plan.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG))
      .orElse(plan.children.iterator.map(findLogicalPlanInTree).collectFirst {
        case Some(lp) => lp
      })
  }

  /** Returns the node's (possibly inherited) logical plan tag, failing the test if absent. */
  private def getGlutenLogicalPlan(node: SparkPlan): LogicalPlan = {
    node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse {
      node.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG).getOrElse {
        fail(node.getClass.getSimpleName + " does not have a logical plan link")
      }
    }
  }

  /** Asserts the node's tagged logical plan is exactly of type `T` (not a subtype check). */
  private def assertLogicalPlanType[T <: LogicalPlan: ClassTag](node: SparkPlan): Unit = {
    val logicalPlan = getGlutenLogicalPlan(node)
    val expectedCls = implicitly[ClassTag[T]].runtimeClass
    assert(
      expectedCls == logicalPlan.getClass,
      s"Expected ${expectedCls.getSimpleName} but got ${logicalPlan.getClass.getSimpleName}" +
        s" for ${node.getClass.getSimpleName}"
    )
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ class VeloxTestSettings extends BackendTestSettings {
// TODO: 4.x enableSuite[GlutenHiveResultSuite] // 1 failure
// TODO: 4.x enableSuite[GlutenInsertSortForLimitAndOffsetSuite] // 6 failures
enableSuite[GlutenLocalTempViewTestSuite]
// TODO: 4.x enableSuite[GlutenLogicalPlanTagInSparkPlanSuite] // RUN ABORTED
enableSuite[GlutenLogicalPlanTagInSparkPlanSuite]
enableSuite[GlutenOptimizeMetadataOnlyQuerySuite]
enableSuite[GlutenPersistedViewTestSuite]
// TODO: 4.x enableSuite[GlutenPlannerSuite] // 1 failure
Expand Down
Loading
Loading