|
16 | 16 | */ |
17 | 17 | package org.apache.spark.sql.execution |
18 | 18 |
|
| 19 | +import org.apache.gluten.execution._ |
| 20 | + |
19 | 21 | import org.apache.spark.sql.GlutenSQLTestsTrait |
| 22 | +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} |
| 23 | +import org.apache.spark.sql.catalyst.plans.logical._ |
| 24 | +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec |
| 25 | +import org.apache.spark.sql.execution.datasources.LogicalRelation |
| 26 | +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation |
| 27 | +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec |
| 28 | + |
| 29 | +import scala.reflect.ClassTag |
20 | 30 |
|
21 | 31 | class GlutenLogicalPlanTagInSparkPlanSuite |
22 | 32 | extends LogicalPlanTagInSparkPlanSuite |
23 | | - with GlutenSQLTestsTrait {} |
| 33 | + with GlutenSQLTestsTrait { |
| 34 | + |
| 35 | + // Override to use Gluten-aware logical plan tag checking. |
| 36 | + // Gluten replaces Spark physical operators with Transformer nodes that don't match |
| 37 | + // the original Spark pattern matching in LogicalPlanTagInSparkPlanSuite. |
| 38 | + override protected def checkGeneratedCode( |
| 39 | + plan: SparkPlan, |
| 40 | + checkMethodCodeSize: Boolean = true): Unit = { |
| 41 | + // Skip parent's codegen check (Gluten doesn't use WholeStageCodegen). |
| 42 | + // Only run the Gluten-aware logical plan tag check. |
| 43 | + checkGlutenLogicalPlanTag(plan) |
| 44 | + } |
| 45 | + |
| 46 | + private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = { |
| 47 | + aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final) |
| 48 | + } |
| 49 | + |
| 50 | + private def checkGlutenLogicalPlanTag(plan: SparkPlan): Unit = { |
| 51 | + plan match { |
| 52 | + // Joins (Gluten + Spark) |
| 53 | + case _: BroadcastHashJoinExecTransformerBase | _: ShuffledHashJoinExecTransformerBase | |
| 54 | + _: SortMergeJoinExecTransformerBase | _: CartesianProductExecTransformer | |
| 55 | + _: BroadcastNestedLoopJoinExecTransformer | _: joins.BroadcastHashJoinExec | |
| 56 | + _: joins.ShuffledHashJoinExec | _: joins.SortMergeJoinExec | |
| 57 | + _: joins.BroadcastNestedLoopJoinExec | _: joins.CartesianProductExec => |
| 58 | + assertLogicalPlanType[Join](plan) |
| 59 | + |
| 60 | + // Aggregates - only final (Gluten + Spark) |
| 61 | + case agg: HashAggregateExecBaseTransformer if isFinalAgg(agg.aggregateExpressions) => |
| 62 | + assertLogicalPlanType[Aggregate](plan) |
| 63 | + case agg: aggregate.HashAggregateExec if isFinalAgg(agg.aggregateExpressions) => |
| 64 | + assertLogicalPlanType[Aggregate](plan) |
| 65 | + case agg: aggregate.ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) => |
| 66 | + assertLogicalPlanType[Aggregate](plan) |
| 67 | + case agg: aggregate.SortAggregateExec if isFinalAgg(agg.aggregateExpressions) => |
| 68 | + assertLogicalPlanType[Aggregate](plan) |
| 69 | + |
| 70 | + // Window |
| 71 | + case _: WindowExecTransformer | _: window.WindowExec => |
| 72 | + assertLogicalPlanType[Window](plan) |
| 73 | + |
| 74 | + // Union |
| 75 | + case _: ColumnarUnionExec | _: UnionExec => |
| 76 | + assertLogicalPlanType[Union](plan) |
| 77 | + |
| 78 | + // Sample |
| 79 | + case _: SampleExec => |
| 80 | + assertLogicalPlanType[Sample](plan) |
| 81 | + |
| 82 | + // Generate |
| 83 | + case _: GenerateExecTransformerBase | _: GenerateExec => |
| 84 | + assertLogicalPlanType[Generate](plan) |
| 85 | + |
| 86 | + // Exchange nodes should NOT have logical plan tags |
| 87 | + case _: ColumnarShuffleExchangeExec | _: ColumnarBroadcastExchangeExec | |
| 88 | + _: exchange.ShuffleExchangeExec | _: exchange.BroadcastExchangeExec | |
| 89 | + _: ReusedExchangeExec => |
| 90 | + assert( |
| 91 | + plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty, |
| 92 | + s"${plan.getClass.getSimpleName} should not have a logical plan tag") |
| 93 | + |
| 94 | + // Subquery exec nodes don't have logical plan tags |
| 95 | + case _: SubqueryExec | _: ReusedSubqueryExec => |
| 96 | + assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) |
| 97 | + |
| 98 | + // Gluten infrastructure nodes (no corresponding logical plan) |
| 99 | + case _: WholeStageTransformer | _: InputIteratorTransformer | _: ColumnarInputAdapter | |
| 100 | + _: VeloxResizeBatchesExec => |
| 101 | + // These are Gluten-specific wrapper nodes without logical plan links. |
| 102 | + |
| 103 | + // Scan trees |
| 104 | + case _ if isGlutenScanPlanTree(plan) => |
| 105 | + // For scan plan trees (leaf under Project/Filter), we check that the leaf node |
| 106 | + // has a correct logical plan link. The intermediate Project/Filter nodes may not |
| 107 | + // have tags if they were created by Gluten's rewrite rules. |
| 108 | + val physicalLeaves = plan.collectLeaves() |
| 109 | + assert( |
| 110 | + physicalLeaves.length == 1, |
| 111 | + s"Expected 1 physical leaf, got ${physicalLeaves.length}") |
| 112 | + |
| 113 | + val leafNode = physicalLeaves.head |
| 114 | + // Find the logical plan from the leaf or any ancestor with a tag |
| 115 | + val logicalPlanOpt = leafNode |
| 116 | + .getTagValue(SparkPlan.LOGICAL_PLAN_TAG) |
| 117 | + .orElse(leafNode.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)) |
| 118 | + .orElse(findLogicalPlanInTree(plan)) |
| 119 | + |
| 120 | + logicalPlanOpt.foreach { |
| 121 | + lp => |
| 122 | + val logicalPlan = lp match { |
| 123 | + case w: WithCTE => w.plan |
| 124 | + case o => o |
| 125 | + } |
| 126 | + val logicalLeaves = logicalPlan.collectLeaves() |
| 127 | + assert( |
| 128 | + logicalLeaves.length == 1, |
| 129 | + s"Expected 1 logical leaf, got ${logicalLeaves.length}") |
| 130 | + physicalLeaves.head match { |
| 131 | + case _: RangeExec => assert(logicalLeaves.head.isInstanceOf[Range]) |
| 132 | + case _: DataSourceScanExec | _: BasicScanExecTransformer => |
| 133 | + assert(logicalLeaves.head.isInstanceOf[LogicalRelation]) |
| 134 | + case _: InMemoryTableScanExec => |
| 135 | + assert(logicalLeaves.head.isInstanceOf[columnar.InMemoryRelation]) |
| 136 | + case _: LocalTableScanExec => assert(logicalLeaves.head.isInstanceOf[LocalRelation]) |
| 137 | + case _: ExternalRDDScanExec[_] => |
| 138 | + assert(logicalLeaves.head.isInstanceOf[ExternalRDD[_]]) |
| 139 | + case _: datasources.v2.BatchScanExec => |
| 140 | + assert(logicalLeaves.head.isInstanceOf[DataSourceV2Relation]) |
| 141 | + case _ => |
| 142 | + } |
| 143 | + } |
| 144 | + return |
| 145 | + |
| 146 | + case _ => |
| 147 | + } |
| 148 | + |
| 149 | + plan.children.foreach(checkGlutenLogicalPlanTag) |
| 150 | + plan.subqueries.foreach(checkGlutenLogicalPlanTag) |
| 151 | + } |
| 152 | + |
| 153 | + @scala.annotation.tailrec |
| 154 | + private def isGlutenScanPlanTree(plan: SparkPlan): Boolean = plan match { |
| 155 | + case ColumnarToRowExec(i: InputAdapter) => isGlutenScanPlanTree(i.child) |
| 156 | + case p: ProjectExec => isGlutenScanPlanTree(p.child) |
| 157 | + case p: ProjectExecTransformer => isGlutenScanPlanTree(p.child) |
| 158 | + case f: FilterExec => isGlutenScanPlanTree(f.child) |
| 159 | + case f: FilterExecTransformerBase => isGlutenScanPlanTree(f.child) |
| 160 | + case _: LeafExecNode => true |
| 161 | + case _ => false |
| 162 | + } |
| 163 | + |
| 164 | + /** Find any node in the tree that has a LOGICAL_PLAN_TAG. */ |
| 165 | + private def findLogicalPlanInTree(plan: SparkPlan): Option[LogicalPlan] = { |
| 166 | + plan |
| 167 | + .getTagValue(SparkPlan.LOGICAL_PLAN_TAG) |
| 168 | + .orElse(plan.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)) |
| 169 | + .orElse(plan.children.collectFirst { |
| 170 | + case child if findLogicalPlanInTree(child).isDefined => findLogicalPlanInTree(child).get |
| 171 | + }) |
| 172 | + } |
| 173 | + |
| 174 | + private def getGlutenLogicalPlan(node: SparkPlan): LogicalPlan = { |
| 175 | + node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse { |
| 176 | + node.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG).getOrElse { |
| 177 | + fail(node.getClass.getSimpleName + " does not have a logical plan link") |
| 178 | + } |
| 179 | + } |
| 180 | + } |
| 181 | + |
| 182 | + private def assertLogicalPlanType[T <: LogicalPlan: ClassTag](node: SparkPlan): Unit = { |
| 183 | + val logicalPlan = getGlutenLogicalPlan(node) |
| 184 | + val expectedCls = implicitly[ClassTag[T]].runtimeClass |
| 185 | + assert( |
| 186 | + expectedCls == logicalPlan.getClass, |
| 187 | + s"Expected ${expectedCls.getSimpleName} but got ${logicalPlan.getClass.getSimpleName}" + |
| 188 | + s" for ${node.getClass.getSimpleName}" |
| 189 | + ) |
| 190 | + } |
| 191 | +} |
0 commit comments