diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicTransform.scala index b3051953740f..06f36d974547 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicTransform.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/HeuristicTransform.scala @@ -22,6 +22,7 @@ import org.apache.gluten.extension.caller.CallerInfo import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall import org.apache.gluten.extension.columnar.FallbackTags import org.apache.gluten.extension.columnar.offload.OffloadSingleNode +import org.apache.gluten.extension.columnar.offload.OffloadSingleNode._ import org.apache.gluten.extension.columnar.rewrite.RewriteSingleNode import org.apache.gluten.extension.columnar.validator.Validator import org.apache.gluten.extension.injector.Injector @@ -81,7 +82,7 @@ object HeuristicTransform { node => validator.validate(node) match { case Validator.Passed => - rule.offload(node) + rule.offloadAndPropagateTag(node) case Validator.Failed(reason) => logDebug(s"Validation failed by reason: $reason on query plan: ${node.nodeName}") if (FallbackTags.maybeOffloadable(node)) { diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/LegacyOffload.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/LegacyOffload.scala index c0c44f390d29..f5c466c2d8c5 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/LegacyOffload.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/LegacyOffload.scala @@ -17,6 +17,7 @@ package org.apache.gluten.extension.columnar.heuristic import org.apache.gluten.extension.columnar.offload.OffloadSingleNode +import org.apache.gluten.extension.columnar.offload.OffloadSingleNode._ import org.apache.gluten.logging.LogLevelUtil import org.apache.spark.sql.catalyst.rules.Rule @@ -25,7 +26,9 @@ import org.apache.spark.sql.execution.SparkPlan class LegacyOffload(rules: Seq[OffloadSingleNode]) extends Rule[SparkPlan] with LogLevelUtil { def apply(plan: SparkPlan): SparkPlan = { val out = - rules.foldLeft(plan)((p, rule) => p.transformUp { case p => rule.offload(p) }) + rules.foldLeft(plan) { + (p, rule) => p.transformUp { case node => rule.offloadAndPropagateTag(node) } + } out } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNode.scala index db20e9efe268..43d5d57a4f35 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/offload/OffloadSingleNode.scala @@ -42,6 +42,21 @@ trait OffloadSingleNode extends Logging { object OffloadSingleNode { implicit class OffloadSingleNodeOps(rule: OffloadSingleNode) { + /** + * Offloads the plan node and propagates LOGICAL_PLAN_TAG from the original node to the + * offloaded node (non-recursive). Uses setTagValue directly to avoid setLogicalLink's recursive + * propagation to children, which would incorrectly tag Exchange nodes. + */ + def offloadAndPropagateTag(node: SparkPlan): SparkPlan = { + val offloaded = rule.offload(node) + if (offloaded ne node) { + node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).foreach { + lp => offloaded.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, lp) + } + } + offloaded + } + /** * Converts the [[OffloadSingleNode]] rule to a strict version. * diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownFilterToScan.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownFilterToScan.scala index 9a6e271b35ac..b9e02461b1fb 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownFilterToScan.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/PushDownFilterToScan.scala @@ -36,6 +36,7 @@ object PushDownFilterToScan extends Rule[SparkPlan] with PredicateHelper { scan) && scan.supportPushDownFilters => val newScan = scan.withNewPushdownFilters(splitConjunctivePredicates(filter.cond)) if (newScan.doValidate().ok()) { + newScan.copyTagsFrom(scan) filter.withNewChildren(Seq(newScan)) } else { filter diff --git a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index d0716932b756..7c420a11a242 100644 --- a/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark40/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -706,7 +706,7 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenHiveResultSuite] // TODO: 4.x enableSuite[GlutenInsertSortForLimitAndOffsetSuite] // 6 failures enableSuite[GlutenLocalTempViewTestSuite] - // TODO: 4.x enableSuite[GlutenLogicalPlanTagInSparkPlanSuite] // RUN ABORTED + enableSuite[GlutenLogicalPlanTagInSparkPlanSuite] enableSuite[GlutenOptimizeMetadataOnlyQuerySuite] enableSuite[GlutenPersistedViewTestSuite] // TODO: 4.x enableSuite[GlutenPlannerSuite] // 1 failure diff --git a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenLogicalPlanTagInSparkPlanSuite.scala b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenLogicalPlanTagInSparkPlanSuite.scala index 297d3b2a3428..491961d7d092 100644 --- a/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenLogicalPlanTagInSparkPlanSuite.scala +++ b/gluten-ut/spark40/src/test/scala/org/apache/spark/sql/execution/GlutenLogicalPlanTagInSparkPlanSuite.scala @@ -16,8 +16,176 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.execution._ + import org.apache.spark.sql.GlutenSQLTestsTrait +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec + +import scala.reflect.ClassTag class GlutenLogicalPlanTagInSparkPlanSuite extends LogicalPlanTagInSparkPlanSuite - with GlutenSQLTestsTrait {} + with GlutenSQLTestsTrait { + + // Override to use Gluten-aware logical plan tag checking. + // Gluten replaces Spark physical operators with Transformer nodes that don't match + // the original Spark pattern matching in LogicalPlanTagInSparkPlanSuite. + override protected def checkGeneratedCode( + plan: SparkPlan, + checkMethodCodeSize: Boolean = true): Unit = { + // Skip parent's codegen check (Gluten doesn't use WholeStageCodegen). + // Only run the Gluten-aware logical plan tag check. + checkGlutenLogicalPlanTag(plan) + } + + private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = { + aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final) + } + + private def checkGlutenLogicalPlanTag(plan: SparkPlan): Unit = { + plan match { + // Joins (Gluten + Spark) + case _: BroadcastHashJoinExecTransformerBase | _: ShuffledHashJoinExecTransformerBase | + _: SortMergeJoinExecTransformerBase | _: CartesianProductExecTransformer | + _: BroadcastNestedLoopJoinExecTransformer | _: joins.BroadcastHashJoinExec | + _: joins.ShuffledHashJoinExec | _: joins.SortMergeJoinExec | + _: joins.BroadcastNestedLoopJoinExec | _: joins.CartesianProductExec => + assertLogicalPlanType[Join](plan) + + // Aggregates - only final (Gluten + Spark) + case agg: HashAggregateExecBaseTransformer if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: aggregate.HashAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: aggregate.ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: aggregate.SortAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + + // Window + case _: WindowExecTransformer | _: window.WindowExec => + assertLogicalPlanType[Window](plan) + + // Union + case _: ColumnarUnionExec | _: UnionExec => + assertLogicalPlanType[Union](plan) + + // Sample + case _: SampleExec => + assertLogicalPlanType[Sample](plan) + + // Generate + case _: GenerateExecTransformerBase | _: GenerateExec => + assertLogicalPlanType[Generate](plan) + + // Exchange nodes should NOT have logical plan tags + case _: ColumnarShuffleExchangeExec | _: ColumnarBroadcastExchangeExec | + _: exchange.ShuffleExchangeExec | _: exchange.BroadcastExchangeExec | + _: ReusedExchangeExec => + assert( + plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty, + s"${plan.getClass.getSimpleName} should not have a logical plan tag") + + // Subquery exec nodes don't have logical plan tags + case _: SubqueryExec | _: ReusedSubqueryExec => + assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) + + // Gluten infrastructure nodes (no corresponding logical plan) + case _: WholeStageTransformer | _: InputIteratorTransformer | _: ColumnarInputAdapter | + _: VeloxResizeBatchesExec => + // These are Gluten-specific wrapper nodes without logical plan links. + + // Scan trees + case _ if isGlutenScanPlanTree(plan) => + // For scan plan trees (leaf under Project/Filter), we check that the leaf node + // has a correct logical plan link. The intermediate Project/Filter nodes may not + // have tags if they were created by Gluten's rewrite rules. + val physicalLeaves = plan.collectLeaves() + assert( + physicalLeaves.length == 1, + s"Expected 1 physical leaf, got ${physicalLeaves.length}") + + val leafNode = physicalLeaves.head + // Find the logical plan from the leaf or any ancestor with a tag + val logicalPlanOpt = leafNode + .getTagValue(SparkPlan.LOGICAL_PLAN_TAG) + .orElse(leafNode.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)) + .orElse(findLogicalPlanInTree(plan)) + + logicalPlanOpt.foreach { + lp => + val logicalPlan = lp match { + case w: WithCTE => w.plan + case o => o + } + val logicalLeaves = logicalPlan.collectLeaves() + assert( + logicalLeaves.length == 1, + s"Expected 1 logical leaf, got ${logicalLeaves.length}") + physicalLeaves.head match { + case _: RangeExec => assert(logicalLeaves.head.isInstanceOf[Range]) + case _: DataSourceScanExec | _: BasicScanExecTransformer => + assert(logicalLeaves.head.isInstanceOf[LogicalRelation]) + case _: InMemoryTableScanExec => + assert(logicalLeaves.head.isInstanceOf[columnar.InMemoryRelation]) + case _: LocalTableScanExec => assert(logicalLeaves.head.isInstanceOf[LocalRelation]) + case _: ExternalRDDScanExec[_] => + assert(logicalLeaves.head.isInstanceOf[ExternalRDD[_]]) + case _: datasources.v2.BatchScanExec => + assert(logicalLeaves.head.isInstanceOf[DataSourceV2Relation]) + case _ => + } + } + return + + case _ => + } + + plan.children.foreach(checkGlutenLogicalPlanTag) + plan.subqueries.foreach(checkGlutenLogicalPlanTag) + } + + @scala.annotation.tailrec + private def isGlutenScanPlanTree(plan: SparkPlan): Boolean = plan match { + case ColumnarToRowExec(i: InputAdapter) => isGlutenScanPlanTree(i.child) + case p: ProjectExec => isGlutenScanPlanTree(p.child) + case p: ProjectExecTransformer => isGlutenScanPlanTree(p.child) + case f: FilterExec => isGlutenScanPlanTree(f.child) + case f: FilterExecTransformerBase => isGlutenScanPlanTree(f.child) + case _: LeafExecNode => true + case _ => false + } + + /** Find any node in the tree that has a LOGICAL_PLAN_TAG. */ + private def findLogicalPlanInTree(plan: SparkPlan): Option[LogicalPlan] = { + plan + .getTagValue(SparkPlan.LOGICAL_PLAN_TAG) + .orElse(plan.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)) + .orElse(plan.children.iterator.map(findLogicalPlanInTree).collectFirst { + case Some(lp) => lp + }) + } + + private def getGlutenLogicalPlan(node: SparkPlan): LogicalPlan = { + node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse { + node.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG).getOrElse { + fail(node.getClass.getSimpleName + " does not have a logical plan link") + } + } + } + + private def assertLogicalPlanType[T <: LogicalPlan: ClassTag](node: SparkPlan): Unit = { + val logicalPlan = getGlutenLogicalPlan(node) + val expectedCls = implicitly[ClassTag[T]].runtimeClass + assert( + expectedCls == logicalPlan.getClass, + s"Expected ${expectedCls.getSimpleName} but got ${logicalPlan.getClass.getSimpleName}" + + s" for ${node.getClass.getSimpleName}" + ) + } +} diff --git a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 47a1ff3d66e7..3bf6d60ece31 100644 --- a/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark41/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -676,7 +676,7 @@ class VeloxTestSettings extends BackendTestSettings { // TODO: 4.x enableSuite[GlutenHiveResultSuite] // 1 failure // TODO: 4.x enableSuite[GlutenInsertSortForLimitAndOffsetSuite] // 6 failures enableSuite[GlutenLocalTempViewTestSuite] - // TODO: 4.x enableSuite[GlutenLogicalPlanTagInSparkPlanSuite] // RUN ABORTED + enableSuite[GlutenLogicalPlanTagInSparkPlanSuite] enableSuite[GlutenOptimizeMetadataOnlyQuerySuite] enableSuite[GlutenPersistedViewTestSuite] // TODO: 4.x enableSuite[GlutenPlannerSuite] // 1 failure diff --git a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenLogicalPlanTagInSparkPlanSuite.scala b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenLogicalPlanTagInSparkPlanSuite.scala index 297d3b2a3428..491961d7d092 100644 --- a/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenLogicalPlanTagInSparkPlanSuite.scala +++ b/gluten-ut/spark41/src/test/scala/org/apache/spark/sql/execution/GlutenLogicalPlanTagInSparkPlanSuite.scala @@ -16,8 +16,176 @@ */ package org.apache.spark.sql.execution +import org.apache.gluten.execution._ + import org.apache.spark.sql.GlutenSQLTestsTrait +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec + +import scala.reflect.ClassTag class GlutenLogicalPlanTagInSparkPlanSuite extends LogicalPlanTagInSparkPlanSuite - with GlutenSQLTestsTrait {} + with GlutenSQLTestsTrait { + + // Override to use Gluten-aware logical plan tag checking. + // Gluten replaces Spark physical operators with Transformer nodes that don't match + // the original Spark pattern matching in LogicalPlanTagInSparkPlanSuite. + override protected def checkGeneratedCode( + plan: SparkPlan, + checkMethodCodeSize: Boolean = true): Unit = { + // Skip parent's codegen check (Gluten doesn't use WholeStageCodegen). + // Only run the Gluten-aware logical plan tag check. + checkGlutenLogicalPlanTag(plan) + } + + private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = { + aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final) + } + + private def checkGlutenLogicalPlanTag(plan: SparkPlan): Unit = { + plan match { + // Joins (Gluten + Spark) + case _: BroadcastHashJoinExecTransformerBase | _: ShuffledHashJoinExecTransformerBase | + _: SortMergeJoinExecTransformerBase | _: CartesianProductExecTransformer | + _: BroadcastNestedLoopJoinExecTransformer | _: joins.BroadcastHashJoinExec | + _: joins.ShuffledHashJoinExec | _: joins.SortMergeJoinExec | + _: joins.BroadcastNestedLoopJoinExec | _: joins.CartesianProductExec => + assertLogicalPlanType[Join](plan) + + // Aggregates - only final (Gluten + Spark) + case agg: HashAggregateExecBaseTransformer if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: aggregate.HashAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: aggregate.ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: aggregate.SortAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + + // Window + case _: WindowExecTransformer | _: window.WindowExec => + assertLogicalPlanType[Window](plan) + + // Union + case _: ColumnarUnionExec | _: UnionExec => + assertLogicalPlanType[Union](plan) + + // Sample + case _: SampleExec => + assertLogicalPlanType[Sample](plan) + + // Generate + case _: GenerateExecTransformerBase | _: GenerateExec => + assertLogicalPlanType[Generate](plan) + + // Exchange nodes should NOT have logical plan tags + case _: ColumnarShuffleExchangeExec | _: ColumnarBroadcastExchangeExec | + _: exchange.ShuffleExchangeExec | _: exchange.BroadcastExchangeExec | + _: ReusedExchangeExec => + assert( + plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty, + s"${plan.getClass.getSimpleName} should not have a logical plan tag") + + // Subquery exec nodes don't have logical plan tags + case _: SubqueryExec | _: ReusedSubqueryExec => + assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty) + + // Gluten infrastructure nodes (no corresponding logical plan) + case _: WholeStageTransformer | _: InputIteratorTransformer | _: ColumnarInputAdapter | + _: VeloxResizeBatchesExec => + // These are Gluten-specific wrapper nodes without logical plan links. + + // Scan trees + case _ if isGlutenScanPlanTree(plan) => + // For scan plan trees (leaf under Project/Filter), we check that the leaf node + // has a correct logical plan link. The intermediate Project/Filter nodes may not + // have tags if they were created by Gluten's rewrite rules. + val physicalLeaves = plan.collectLeaves() + assert( + physicalLeaves.length == 1, + s"Expected 1 physical leaf, got ${physicalLeaves.length}") + + val leafNode = physicalLeaves.head + // Find the logical plan from the leaf or any ancestor with a tag + val logicalPlanOpt = leafNode + .getTagValue(SparkPlan.LOGICAL_PLAN_TAG) + .orElse(leafNode.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)) + .orElse(findLogicalPlanInTree(plan)) + + logicalPlanOpt.foreach { + lp => + val logicalPlan = lp match { + case w: WithCTE => w.plan + case o => o + } + val logicalLeaves = logicalPlan.collectLeaves() + assert( + logicalLeaves.length == 1, + s"Expected 1 logical leaf, got ${logicalLeaves.length}") + physicalLeaves.head match { + case _: RangeExec => assert(logicalLeaves.head.isInstanceOf[Range]) + case _: DataSourceScanExec | _: BasicScanExecTransformer => + assert(logicalLeaves.head.isInstanceOf[LogicalRelation]) + case _: InMemoryTableScanExec => + assert(logicalLeaves.head.isInstanceOf[columnar.InMemoryRelation]) + case _: LocalTableScanExec => assert(logicalLeaves.head.isInstanceOf[LocalRelation]) + case _: ExternalRDDScanExec[_] => + assert(logicalLeaves.head.isInstanceOf[ExternalRDD[_]]) + case _: datasources.v2.BatchScanExec => + assert(logicalLeaves.head.isInstanceOf[DataSourceV2Relation]) + case _ => + } + } + return + + case _ => + } + + plan.children.foreach(checkGlutenLogicalPlanTag) + plan.subqueries.foreach(checkGlutenLogicalPlanTag) + } + + @scala.annotation.tailrec + private def isGlutenScanPlanTree(plan: SparkPlan): Boolean = plan match { + case ColumnarToRowExec(i: InputAdapter) => isGlutenScanPlanTree(i.child) + case p: ProjectExec => isGlutenScanPlanTree(p.child) + case p: ProjectExecTransformer => isGlutenScanPlanTree(p.child) + case f: FilterExec => isGlutenScanPlanTree(f.child) + case f: FilterExecTransformerBase => isGlutenScanPlanTree(f.child) + case _: LeafExecNode => true + case _ => false + } + + /** Find any node in the tree that has a LOGICAL_PLAN_TAG. */ + private def findLogicalPlanInTree(plan: SparkPlan): Option[LogicalPlan] = { + plan + .getTagValue(SparkPlan.LOGICAL_PLAN_TAG) + .orElse(plan.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG)) + .orElse(plan.children.iterator.map(findLogicalPlanInTree).collectFirst { + case Some(lp) => lp + }) + } + + private def getGlutenLogicalPlan(node: SparkPlan): LogicalPlan = { + node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse { + node.getTagValue(SparkPlan.LOGICAL_PLAN_INHERITED_TAG).getOrElse { + fail(node.getClass.getSimpleName + " does not have a logical plan link") + } + } + } + + private def assertLogicalPlanType[T <: LogicalPlan: ClassTag](node: SparkPlan): Unit = { + val logicalPlan = getGlutenLogicalPlan(node) + val expectedCls = implicitly[ClassTag[T]].runtimeClass + assert( + expectedCls == logicalPlan.getClass, + s"Expected ${expectedCls.getSimpleName} but got ${logicalPlan.getClass.getSimpleName}" + + s" for ${node.getClass.getSimpleName}" + ) + } +}