From d09cde5d73670b2e829d16f351d0d6cc8214dbe5 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 25 Mar 2026 14:19:45 -0700 Subject: [PATCH 1/3] dsv2 cache --- .../spark/sql/execution/CacheManager.scala | 17 +- .../columnar/InMemoryCacheTable.scala | 195 ++++++++++++++++++ .../datasources/v2/DataSourceV2Strategy.scala | 11 + .../dynamicpruning/PartitionPruning.scala | 10 +- .../apache/spark/sql/CachedTableSuite.scala | 20 +- .../apache/spark/sql/DatasetCacheSuite.scala | 12 +- .../org/apache/spark/sql/QueryTest.scala | 6 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 6 +- .../sql/connector/DataSourceV2SQLSuite.scala | 4 +- .../LogicalPlanTagInSparkPlanSuite.scala | 4 +- .../columnar/InMemoryCacheDSv2Benchmark.scala | 175 ++++++++++++++++ .../columnar/InMemoryColumnarQuerySuite.scala | 7 +- 12 files changed, 430 insertions(+), 37 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheTable.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheDSv2Benchmark.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 3f92f24156d3c..fc512ffba2d5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.HiveTableRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SubqueryExpression} import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint import 
org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, ResolvedHint, View} import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION @@ -36,11 +36,12 @@ import org.apache.spark.sql.connector.catalog.CatalogPlugin import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryCacheTable, InMemoryRelation} import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable} import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK @@ -502,9 +503,19 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { // After cache lookup, we should still keep the hints from the input plan. val hints = EliminateResolvedHint.extractHintsFromPlan(currentFragment)._2 val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output) + // Wrap the InMemoryRelation in a DataSourceV2Relation so that V2ScanRelationPushDown + // optimizer rules can apply column pruning, filter pushdown, and ordering/statistics + // reporting. Physical execution is still routed to InMemoryTableScanExec. 
+ val dsv2Relation = DataSourceV2Relation( + table = new InMemoryCacheTable(cachedPlan), + output = cachedPlan.output.map(_.asInstanceOf[AttributeReference]), + catalog = None, + identifier = None, + options = CaseInsensitiveStringMap.empty() + ) // The returned hint list is in top-down order, we should create the hint nodes from // right to left. - hints.foldRight[LogicalPlan](cachedPlan) { case (hint, p) => + hints.foldRight[LogicalPlan](dsv2Relation) { case (hint, p) => ResolvedHint(p, hint) } }.getOrElse(currentFragment) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheTable.scala new file mode 100644 index 0000000000000..87b91da2decad --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheTable.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.columnar + +import java.util +import java.util.OptionalLong + +import org.apache.spark.sql.catalyst.expressions.{ + Ascending, Attribute, AttributeReference, Descending, NullsFirst, NullsLast, + SortOrder => CatalystSortOrder +} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} +import org.apache.spark.sql.connector.expressions.{ + FieldReference, NamedReference, NullOrdering => V2NullOrdering, + SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue +} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{ + Scan, ScanBuilder, Statistics => V2Statistics, SupportsPushDownRequiredColumns, + SupportsPushDownV2Filters, SupportsReportOrdering, SupportsReportStatistics +} +import org.apache.spark.sql.connector.read.colstats.ColumnStatistics +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * A DSv2 [[Table]] wrapper around [[InMemoryRelation]], enabling [[V2ScanRelationPushDown]] + * optimizer rules to apply column pruning, filter pushdown, and ordering/statistics reporting + * to cached DataFrames. 
+ */ +private[sql] class InMemoryCacheTable(val relation: InMemoryRelation) + extends Table with SupportsRead { + + override def name(): String = relation.cacheBuilder.cachedName + + override def schema(): StructType = DataTypeUtils.fromAttributes(relation.output) + + override def capabilities(): util.Set[TableCapability] = + util.EnumSet.of(TableCapability.BATCH_READ) + + override def newScanBuilder(options: CaseInsensitiveStringMap): InMemoryScanBuilder = + new InMemoryScanBuilder(relation) +} + +/** + * DSv2 [[ScanBuilder]] for [[InMemoryRelation]]. + * + * - Column pruning via [[SupportsPushDownRequiredColumns]]: only requested columns are + * passed to [[InMemoryTableScanExec]], reducing deserialization work. + * - Filter pushdown via [[SupportsPushDownV2Filters]]: predicates are recorded for + * batch-level pruning using per-batch min/max statistics, but all predicates are + * returned (category-2: still need post-scan row-level re-evaluation). + */ +private[sql] class InMemoryScanBuilder(relation: InMemoryRelation) + extends ScanBuilder + with SupportsPushDownRequiredColumns + with SupportsPushDownV2Filters { + + private var requiredSchema: StructType = DataTypeUtils.fromAttributes(relation.output) + private var _pushedPredicates: Array[Predicate] = Array.empty + + override def pruneColumns(required: StructType): Unit = { + requiredSchema = required + } + + /** + * Accepts all predicates for batch-level min/max pruning via + * [[CachedBatchSerializer.buildFilter]], but returns them unchanged so Spark + * adds a post-scan [[FilterExec]] for row-level evaluation. 
+ */ + override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { + _pushedPredicates = predicates + predicates + } + + override def pushedPredicates(): Array[Predicate] = _pushedPredicates + + override def build(): InMemoryCacheScan = { + val requiredFieldNames = requiredSchema.fieldNames.toSet + val prunedAttrs = + if (requiredFieldNames == relation.output.map(_.name).toSet) relation.output + else relation.output.filter(a => requiredFieldNames.contains(a.name)) + new InMemoryCacheScan(relation, prunedAttrs, _pushedPredicates) + } +} + +/** + * DSv2 [[Scan]] for [[InMemoryRelation]]. + * + * Physical execution is handled by [[InMemoryTableScanExec]] via [[DataSourceV2Strategy]] + * rather than [[Batch]]/[[InputPartition]] to preserve the existing efficient columnar path. + * + * Reports: + * - Ordering ([[SupportsReportOrdering]]): propagates the ordering of the original cached plan + * so the optimizer can eliminate redundant sorts on top of the cache. + * - Statistics ([[SupportsReportStatistics]]): exposes accurate row count and size from + * accumulated scan metrics once the cache is materialized, feeding AQE decisions. + */ +private[sql] class InMemoryCacheScan( + val relation: InMemoryRelation, + val prunedAttrs: Seq[Attribute], + val pushedPredicates: Array[Predicate]) + extends Scan + with SupportsReportOrdering + with SupportsReportStatistics { + + override def readSchema(): StructType = DataTypeUtils.fromAttributes(prunedAttrs) + + /** + * Converts the Catalyst sort ordering of the cached plan to V2 [[SortOrder]]s. + * Only attribute-reference based orderings are converted; complex expressions are skipped. 
+ */ + override def outputOrdering(): Array[V2SortOrder] = + relation.outputOrdering.flatMap { + case CatalystSortOrder(attr: AttributeReference, direction, nullOrdering, _) => + val v2Dir = direction match { + case Ascending => V2SortDirection.ASCENDING + case Descending => V2SortDirection.DESCENDING + } + val v2Nulls = nullOrdering match { + case NullsFirst => V2NullOrdering.NULLS_FIRST + case NullsLast => V2NullOrdering.NULLS_LAST + } + Some(SortValue(FieldReference.column(attr.name), v2Dir, v2Nulls)) + case _ => None + }.toArray + + override def estimateStatistics(): V2Statistics = { + val stats = relation.computeStats() + val v2ColStats = new util.HashMap[NamedReference, ColumnStatistics]() + stats.attributeStats.foreach { case (attr, colStat) => + val cs = new ColumnStatistics { + override def distinctCount(): OptionalLong = + colStat.distinctCount.map(v => OptionalLong.of(v.toLong)).getOrElse(OptionalLong.empty()) + override def min(): util.Optional[Object] = + colStat.min.map(v => util.Optional.of(v.asInstanceOf[Object])) + .getOrElse(util.Optional.empty[Object]()) + override def max(): util.Optional[Object] = + colStat.max.map(v => util.Optional.of(v.asInstanceOf[Object])) + .getOrElse(util.Optional.empty[Object]()) + override def nullCount(): OptionalLong = + colStat.nullCount.map(v => OptionalLong.of(v.toLong)).getOrElse(OptionalLong.empty()) + override def avgLen(): OptionalLong = + colStat.avgLen.map(OptionalLong.of).getOrElse(OptionalLong.empty()) + override def maxLen(): OptionalLong = + colStat.maxLen.map(OptionalLong.of).getOrElse(OptionalLong.empty()) + } + v2ColStats.put(FieldReference.column(attr.name), cs) + } + new V2Statistics { + override def sizeInBytes(): OptionalLong = OptionalLong.of(stats.sizeInBytes.toLong) + override def numRows(): OptionalLong = + stats.rowCount.map(c => OptionalLong.of(c.toLong)).getOrElse(OptionalLong.empty()) + override def columnStats(): util.Map[NamedReference, ColumnStatistics] = v2ColStats + } + } +} + +/** + 
* Extractor that matches any in-plan representation of a cached DataFrame and returns its + * underlying [[InMemoryRelation]]. + * + * Three forms appear depending on the query stage: + * - [[InMemoryRelation]] - the direct node (e.g. as stored in [[CachedData]]). + * - [[DataSourceV2Relation]] backed by [[InMemoryCacheTable]] - produced by [[CacheManager]] + * in `useCachedData`, visible in `QueryExecution.withCachedData`. + * - [[DataSourceV2ScanRelation]] backed by [[InMemoryCacheScan]] - after + * [[V2ScanRelationPushDown]] optimizes the above, visible in `QueryExecution.optimizedPlan`. + */ +object CachedRelation { + def unapply(plan: LogicalPlan): Option[InMemoryRelation] = plan match { + case mem: InMemoryRelation => Some(mem) + case DataSourceV2Relation(table: InMemoryCacheTable, _, _, _, _, _) => Some(table.relation) + case DataSourceV2ScanRelation(_, scan: InMemoryCacheScan, _, _, _) => Some(scan.relation) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 91e753096a238..a493ac07a32ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -43,6 +43,7 @@ import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBat import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan, SparkStrategy => Strategy} +import org.apache.spark.sql.execution.columnar.{InMemoryCacheScan, InMemoryTableScanExec} import org.apache.spark.sql.execution.command.CommandUtils import 
org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelationWithTable, PushableColumnAndNestedColumn} import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} @@ -151,6 +152,16 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat DataSourceV2Strategy.withProjectAndFilter( project, filters, localScanExec, needsUnsafeConversion = false) :: Nil + case PhysicalOperation(project, filters, + DataSourceV2ScanRelation(_, scan: InMemoryCacheScan, output, _, _)) => + // Route cached DataFrames back to InMemoryTableScanExec, preserving the optimized + // columnar path. Filters are passed for batch-level min/max pruning and a post-scan + // FilterExec is added by withProjectAndFilter for row-level re-evaluation. + DataSourceV2Strategy.withProjectAndFilter( + project, filters, + InMemoryTableScanExec(output, filters, scan.relation), + needsUnsafeConversion = false) :: Nil + case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) => // projection and filters were already pushed down in the optimizer. 
// this uses PhysicalOperation to get the projection and ensure that if the batch scan does diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala index ef22c0ab44e4d..9d8d2d4b77466 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PartitionPruning.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.CachedRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.util.ArrayImplicits._ @@ -183,17 +183,17 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join */ private def calculatePlanOverhead(plan: LogicalPlan): Float = { val (cached, notCached) = plan.collectLeaves().partition(p => p match { - case _: InMemoryRelation => true + case CachedRelation(_) => true case _ => false }) val scanOverhead = notCached.map(_.stats.sizeInBytes).sum.toFloat val cachedOverhead = cached.map { - case m: InMemoryRelation if m.cacheBuilder.storageLevel.useDisk && + case CachedRelation(m) if m.cacheBuilder.storageLevel.useDisk && !m.cacheBuilder.storageLevel.useMemory => m.stats.sizeInBytes.toFloat - case m: InMemoryRelation if m.cacheBuilder.storageLevel.useDisk => + case CachedRelation(m) if m.cacheBuilder.storageLevel.useDisk => m.stats.sizeInBytes.toFloat * 0.2 - case m: InMemoryRelation if m.cacheBuilder.storageLevel.useMemory 
=> + case CachedRelation(_) => 0.0 }.sum.toFloat scanOverhead + cachedOverhead diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index a928a9131d476..05b042da2ba29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -108,7 +108,7 @@ class CachedTableSuite extends QueryTest private def getNumInMemoryRelations(ds: classic.Dataset[_]): Int = { val plan = ds.queryExecution.withCachedData - var sum = plan.collect { case _: InMemoryRelation => 1 }.sum + var sum = plan.collect { case CachedRelation(_) => 1 }.sum plan.transformAllExpressions { case e: SubqueryExpression => sum += getNumInMemoryRelations(e.plan) @@ -223,20 +223,20 @@ class CachedTableSuite extends QueryTest assertCached(spark.table("testData")) assert(spark.table("testData").queryExecution.withCachedData match { - case _: InMemoryRelation => true + case CachedRelation(_) => true case _ => false }) uncacheTable("testData") assert(!spark.catalog.isCached("testData")) assert(spark.table("testData").queryExecution.withCachedData match { - case _: InMemoryRelation => false + case CachedRelation(_) => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assert(!spark.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assert(!CachedRelation.unapply(spark.table("testData").logicalPlan).isDefined) spark.catalog.cacheTable("testData") assertCached(spark.table("testData")) @@ -248,7 +248,7 @@ class CachedTableSuite extends QueryTest spark.catalog.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { spark.table("testData").queryExecution.withCachedData.collect { - case r: InMemoryRelation if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] => r + case CachedRelation(r) if r.cachedPlan.isInstanceOf[InMemoryTableScanExec] 
=> r }.size } @@ -411,7 +411,7 @@ class CachedTableSuite extends QueryTest test("InMemoryRelation statistics") { sql("CACHE TABLE testData") spark.table("testData").queryExecution.withCachedData.collect { - case cached: InMemoryRelation => + case CachedRelation(cached) => val actualSizeInBytes = (1 to 100).map(i => 4 + i.toString.length + 4).sum assert(cached.stats.sizeInBytes === actualSizeInBytes) } @@ -475,12 +475,12 @@ class CachedTableSuite extends QueryTest val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id + case CachedRelation(i) => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id + case CachedRelation(i) => i.cacheBuilder.sizeInBytesStats.id }.head toBeCleanedAccIds += accId2 @@ -1601,7 +1601,7 @@ class CachedTableSuite extends QueryTest sql(s"CACHE TABLE $tableName AS SELECT TIMESTAMP_NTZ'2021-01-01 00:00:00'") checkAnswer(spark.table(tableName), Row(LocalDateTime.parse("2021-01-01T00:00:00"))) spark.table(tableName).queryExecution.withCachedData.collect { - case cached: InMemoryRelation => + case CachedRelation(cached) => assert(cached.stats.sizeInBytes === 8) } sql(s"UNCACHE TABLE $tableName") @@ -1812,7 +1812,7 @@ class CachedTableSuite extends QueryTest sql(s"CACHE TABLE $tableName AS SELECT TIME'22:00:00'") checkAnswer(spark.table(tableName), Row(LocalTime.parse("22:00:00"))) spark.table(tableName).queryExecution.withCachedData.collect { - case cached: InMemoryRelation => + case CachedRelation(cached) => assert(cached.stats.sizeInBytes === 8) } sql(s"UNCACHE TABLE $tableName") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index 627811eaecf8d..e772896d9492c 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.execution.ColumnarToRowExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} +import org.apache.spark.sql.execution.columnar.{CachedRelation, InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -44,8 +44,8 @@ class DatasetCacheSuite extends QueryTest */ private def assertCacheDependency(df: DataFrame, numOfCachesDependedUpon: Int = 1): Unit = { val plan = df.queryExecution.withCachedData - assert(plan.isInstanceOf[InMemoryRelation]) - val internalPlan = plan.asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + assert(CachedRelation.unapply(plan).isDefined) + val internalPlan = CachedRelation.unapply(plan).get.cacheBuilder.cachedPlan assert(find(internalPlan)(_.isInstanceOf[InMemoryTableScanExec]).size == numOfCachesDependedUpon) } @@ -245,7 +245,7 @@ class DatasetCacheSuite extends QueryTest // before df.unpersist(). val df1Limit = df1.limit(2) val df1LimitInnerPlan = df1Limit.queryExecution.withCachedData.collectFirst { - case i: InMemoryRelation => i.cacheBuilder.cachedPlan + case CachedRelation(i) => i.cacheBuilder.cachedPlan } assert(df1LimitInnerPlan.isDefined && df1LimitInnerPlan.get == df1InnerPlan) @@ -253,7 +253,7 @@ class DatasetCacheSuite extends QueryTest // on df, since df2's cache had not been loaded before df.unpersist(). 
val df2Limit = df2.limit(2) val df2LimitInnerPlan = df2Limit.queryExecution.withCachedData.collectFirst { - case i: InMemoryRelation => i.cacheBuilder.cachedPlan + case CachedRelation(i) => i.cacheBuilder.cachedPlan } assert(df2LimitInnerPlan.isDefined && !df2LimitInnerPlan.get.exists(_.isInstanceOf[InMemoryTableScanExec])) @@ -271,7 +271,7 @@ class DatasetCacheSuite extends QueryTest df.cache() df.count() df.queryExecution.withCachedData match { - case i: InMemoryRelation => + case CachedRelation(i) => // Optimized plan has non-default size in bytes assert(i.statsOfPlanToCache.sizeInBytes !== df.sparkSession.sessionState.conf.defaultSizeInBytes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ac406a9fa694e..246fb727e27fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.CachedRelation import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.storage.StorageLevel import org.apache.spark.util.ArrayImplicits._ @@ -204,7 +204,7 @@ abstract class QueryTest extends PlanTest with SparkSessionProvider { def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { - case cached: InMemoryRelation => cached + case CachedRelation(cached) => cached } assert( @@ -220,7 +220,7 @@ abstract class QueryTest extends PlanTest with SparkSessionProvider { def assertCached(query: Dataset[_], cachedName: String, storageLevel: 
StorageLevel): Unit = { val planWithCaching = query.queryExecution.withCachedData val matched = planWithCaching.exists { - case cached: InMemoryRelation => + case CachedRelation(cached) => val cacheBuilder = cached.cacheBuilder cachedName == cacheBuilder.tableName.get && (storageLevel == cacheBuilder.storageLevel) case _ => false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 31ed0f26d9b95..5d2dee85e07a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{QueryExecution, SimpleMode} import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF} -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.CachedRelation import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, SparkUserDefinedFunction, UserDefinedAggregateFunction} @@ -423,10 +423,10 @@ class UDFSuite extends QueryTest with SharedSparkSession { override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { qe.withCachedData match { case c: CreateDataSourceTableAsSelectCommand - if c.query.isInstanceOf[InMemoryRelation] => + if CachedRelation.unapply(c.query).isDefined => numTotalCachedHit += 1 case i: InsertIntoHadoopFsRelationCommand - if i.query.isInstanceOf[InMemoryRelation] => + if CachedRelation.unapply(i.query).isDefined => numTotalCachedHit += 1 case _ => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 826c23ccb08a8..08a5d9ce9f875 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.connector.expressions.{LiteralValue, Transform} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.CachedRelation import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelationWithTable} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.execution.streaming.runtime.MemoryStream @@ -2525,7 +2525,7 @@ class DataSourceV2SQLSuiteV1Filter val t = "testcat.ns1.ns2.tbl" withTable(t) { def isCached(table: String): Boolean = { - spark.table(table).queryExecution.withCachedData.isInstanceOf[InMemoryRelation] + CachedRelation.unapply(spark.table(table).queryExecution.withCachedData).isDefined } spark.sql(s"CREATE TABLE $t (id bigint, data string) USING foo") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala index 743ec41dbe7cd..da804deef5bf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window, WithCTE} 
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} +import org.apache.spark.sql.execution.columnar.{CachedRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} @@ -119,7 +119,7 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite with DisableAdaptiv physicalLeaves.head match { case _: RangeExec => logicalLeaves.head.isInstanceOf[Range] case _: DataSourceScanExec => logicalLeaves.head.isInstanceOf[LogicalRelation] - case _: InMemoryTableScanExec => logicalLeaves.head.isInstanceOf[InMemoryRelation] + case _: InMemoryTableScanExec => CachedRelation.unapply(logicalLeaves.head).isDefined case _: LocalTableScanExec => logicalLeaves.head.isInstanceOf[LocalRelation] case _: ExternalRDDScanExec[_] => logicalLeaves.head.isInstanceOf[ExternalRDD[_]] case _: BatchScanExec => logicalLeaves.head.isInstanceOf[DataSourceV2Relation] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheDSv2Benchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheDSv2Benchmark.scala new file mode 100644 index 0000000000000..806c597c70356 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheDSv2Benchmark.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.columnar + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf + +/** + * Benchmarks for the DSv2-backed in-memory cache path, measuring the impact of + * column pruning, filter pushdown, and planning overhead compared with the pre-DSv2 + * InMemoryRelation approach. + * + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class <this class> + * --jars <spark core test jar>,<spark catalyst test jar> <spark sql test jar> + * 2. build/sbt "sql/Test/runMain <this class>" + * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>" + * Results will be written to "benchmarks/InMemoryCacheDSv2Benchmark-results.txt". + * }}} + */ +object InMemoryCacheDSv2Benchmark extends SqlBasedBenchmark { + + private val numRows = 1000000 + private val numIters = 5 + + /** + * Benchmarks column pruning: reading 2 of 10 columns from a cached wide table. + * Under the DSv2 path, column pruning is applied via SupportsPushDownRequiredColumns, + * so InMemoryTableScanExec only deserializes the 2 requested columns. + * The "no pruning" case reads all 10 columns, simulating the pre-DSv2 behaviour.
+ */ + def columnPruningBenchmark(): Unit = { + val df = spark.range(numRows).select( + (0 until 10).map(i => (col("id") + i).alias(s"c$i")): _* + ).cache() + df.count() // materialize the cache + + val benchmark = new Benchmark( + s"Column pruning - $numRows rows, 10 cols, select 2", + numRows, output = output) + + // Use sum() to force actual column deserialization (count() gets optimized away). + // "pruning" case: DSv2 column pruning deserializes only 2 of 10 columns. + // "baseline" case: all 10 columns are needed and deserialized (simulates pre-DSv2 behaviour + // where the full row is always deserialized even when only some columns are needed). + benchmark.addCase("sum 2 of 10 cols (column pruning via DSv2)") { _ => + df.select("c0", "c1").agg(sum("c0") + sum("c1")).collect() + } + + benchmark.addCase("sum all 10 cols (no pruning - pre-DSv2 baseline)") { _ => + df.agg(sum("c0") + sum("c1") + sum("c2") + sum("c3") + sum("c4") + + sum("c5") + sum("c6") + sum("c7") + sum("c8") + sum("c9")).collect() + } + + benchmark.run() + df.unpersist() + } + + /** + * Benchmarks filter pushdown: a selective predicate on a cached table. + * Under the DSv2 path, filters are pushed via SupportsPushDownV2Filters, enabling + * per-batch min/max pruning inside InMemoryTableScanExec (category-2 push-down). + * The "no push" case applies the filter outside the scan via a post-scan FilterExec, + * but must still read all batches - this is the same behaviour as the pre-DSv2 path. + * + * Note: both cases produce identical results; the difference is how many columnar + * batches are inspected before row-level filtering. + */ + def filterPushdownBenchmark(): Unit = { + // Use sorted data so that batch-level min/max pruning is maximally effective. 
+ val df = spark.range(numRows).select(col("id").alias("c0")).cache() + df.count() // materialize the cache + + val benchmark = new Benchmark( + s"Filter pushdown - $numRows rows, selective filter (c0 < 1000)", + numRows, output = output) + + benchmark.addCase("filter c0 < 1000 (pushed to scan, batch pruning)") { _ => + df.filter(col("c0") < 1000).count() + } + + benchmark.addCase("filter c0 < 1000 (count with full scan for comparison)") { _ => + df.count() + } + + benchmark.run() + df.unpersist() + } + + /** + * Benchmarks planning overhead: how long the optimizer takes for a simple cached scan. + * The DSv2 path runs additional optimizer rules (V2ScanRelationPushDown batch) compared + * with the pre-DSv2 InMemoryRelation path. This case measures total plan->execute time + * without caching queryExecution results. + */ + def planningOverheadBenchmark(): Unit = { + val numPlanIters = 1000 + val df = spark.range(numRows).select(col("id").alias("c0")).cache() + df.count() // materialize the cache + + val benchmark = new Benchmark( + s"Planning overhead - $numPlanIters plan-only iterations", + numPlanIters, output = output) + + benchmark.addCase("optimizedPlan (DSv2 path, V2ScanRelationPushDown)") { _ => + var i = 0 + while (i < numPlanIters) { + df.filter(col("c0") > 0).queryExecution.optimizedPlan + i += 1 + } + } + + benchmark.run() + df.unpersist() + } + + /** + * Benchmarks a full aggregate query end-to-end on a cached multi-column table to + * measure real-world combined overhead of planning + execution. 
+ */ + def endToEndAggregateBenchmark(): Unit = { + val df = spark.range(numRows).select( + (col("id") % 100).alias("key"), + col("id").alias("val") + ).cache() + df.count() + + val benchmark = new Benchmark( + s"End-to-end aggregate (groupBy + sum) on $numRows rows", + numRows, output = output) + + benchmark.addCase("groupBy(key).sum(val) - DSv2 path") { _ => + df.groupBy("key").agg(sum("val")).count() + } + + benchmark.run() + df.unpersist() + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + // AQE off for deterministic planning + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + runBenchmark("In-memory cache: column pruning") { + columnPruningBenchmark() + } + runBenchmark("In-memory cache: filter pushdown") { + filterPushdownBenchmark() + } + runBenchmark("In-memory cache: planning overhead") { + planningOverheadBenchmark() + } + runBenchmark("In-memory cache: end-to-end aggregate") { + endToEndAggregateBenchmark() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 4f07d3d1c0300..30cabd11d7f38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.classic.DataFrame import org.apache.spark.sql.columnar.CachedBatch import org.apache.spark.sql.execution.{FilterExec, InputAdapter, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.columnar.CachedRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -537,7 +538,7 @@ class InMemoryColumnarQuerySuite extends QueryTest 
data.write.orc(workDirPath) val dfFromFile = spark.read.orc(workDirPath).cache() val inMemoryRelation = dfFromFile.queryExecution.optimizedPlan.collect { - case plan: InMemoryRelation => plan + case CachedRelation(plan) => plan }.head // InMemoryRelation's stats is file size before the underlying RDD is materialized assert(inMemoryRelation.computeStats().sizeInBytes === getLocalDirSize(workDir)) @@ -549,7 +550,7 @@ class InMemoryColumnarQuerySuite extends QueryTest // test of catalog table val dfFromTable = spark.catalog.createTable("table1", workDirPath).cache() val inMemoryRelation2 = dfFromTable.queryExecution.optimizedPlan. - collect { case plan: InMemoryRelation => plan }.head + collect { case CachedRelation(plan) => plan }.head // Even CBO enabled, InMemoryRelation's stats keeps as the file size before table's // stats is calculated @@ -560,7 +561,7 @@ class InMemoryColumnarQuerySuite extends QueryTest dfFromTable.unpersist(blocking = true) spark.sql("ANALYZE TABLE table1 COMPUTE STATISTICS") val inMemoryRelation3 = spark.read.table("table1").cache().queryExecution.optimizedPlan. 
- collect { case plan: InMemoryRelation => plan }.head + collect { case CachedRelation(plan) => plan }.head assert(inMemoryRelation3.computeStats().sizeInBytes === 48) } } From d53044e151977c72e45cc633ed7dd05128168496 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 26 Mar 2026 10:14:37 -0700 Subject: [PATCH 2/3] fix tests --- .../spark/sql/execution/columnar/InMemoryCacheTable.scala | 8 ++++++++ .../scala/org/apache/spark/sql/DatasetCacheSuite.scala | 6 +++--- .../org/apache/spark/sql/execution/PlannerSuite.scala | 4 ++-- .../execution/columnar/InMemoryColumnarQuerySuite.scala | 6 ++++-- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheTable.scala index 87b91da2decad..e734b16d6e75e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryCacheTable.scala @@ -49,6 +49,14 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap private[sql] class InMemoryCacheTable(val relation: InMemoryRelation) extends Table with SupportsRead { + // Two InMemoryCacheTable instances wrapping the same CachedRDDBuilder are equal. + // All InMemoryRelation copies from the same CachedData share the same cacheBuilder by reference. 
+ override def equals(other: Any): Boolean = other match { + case t: InMemoryCacheTable => relation.cacheBuilder eq t.relation.cacheBuilder + case _ => false + } + override def hashCode(): Int = System.identityHashCode(relation.cacheBuilder) + override def name(): String = relation.cacheBuilder.cachedName override def schema(): StructType = DataTypeUtils.fromAttributes(relation.output) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala index e772896d9492c..ec289d565a4fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.execution.ColumnarToRowExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.columnar.{CachedRelation, InMemoryRelation, InMemoryTableScanExec} +import org.apache.spark.sql.execution.columnar.{CachedRelation, InMemoryTableScanExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -234,8 +234,8 @@ class DatasetCacheSuite extends QueryTest // Verify that df1 is a InMemoryRelation plan with dependency on another cached plan. assertCacheDependency(df1) - val df1InnerPlan = df1.queryExecution.withCachedData - .asInstanceOf[InMemoryRelation].cacheBuilder.cachedPlan + val df1InnerPlan = CachedRelation.unapply(df1.queryExecution.withCachedData).get + .cacheBuilder.cachedPlan // Verify that df2 is a InMemoryRelation plan with dependency on another cached plan. 
assertCacheDependency(df2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index b926cc192bd62..4042d82949d7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution} import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} +import org.apache.spark.sql.execution.columnar.{CachedRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, EnsureRequirements, REPARTITION_BY_COL, ReusedExchangeExec, ShuffleExchangeExec, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery @@ -232,7 +232,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { test("CollectLimit can appear in the middle of a plan when caching is used") { val query = testData.select($"key", $"value").limit(2).cache() - val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] + val planned = CachedRelation.unapply(query.queryExecution.optimizedPlan).get assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 30cabd11d7f38..7e5c04d27f86d 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -520,8 +520,10 @@ class InMemoryColumnarQuerySuite extends QueryTest test("SPARK-25727 - otherCopyArgs in InMemoryRelation does not include outputOrdering") { val data = Seq(100).toDF("count").cache() - val json = data.queryExecution.optimizedPlan.toJSON - assert(json.contains("outputOrdering")) + // withCachedData contains DataSourceV2Relation(InMemoryCacheTable(InMemoryRelation)); + // extract the InMemoryRelation to verify its outputOrdering field is serialized correctly. + val mem = CachedRelation.unapply(data.queryExecution.withCachedData).get + assert(mem.toJSON.contains("outputOrdering")) } test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") { From fa6878e4260c39af20a16b70122161c9d710933c Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 26 Mar 2026 21:35:02 -0700 Subject: [PATCH 3/3] more work --- .../spark/sql/execution/CacheManager.scala | 11 +++++++- .../spark/sql/execution/SparkPlanner.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 13 --------- .../columnar/InMemoryColumnarQuerySuite.scala | 28 +++++++++++++------ 4 files changed, 30 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index fc512ffba2d5c..3ee90ff320949 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -333,8 +333,17 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { cachedData: CachedData, column: Seq[Attribute]): Unit = { val relation = cachedData.cachedRepresentation + // Wrap in DataSourceV2Relation so the DSv2 planning path is used consistently + // 
(DataSourceV2Strategy handles InMemoryTableScanExec via InMemoryCacheScan). + val dsv2Relation = DataSourceV2Relation( + table = new InMemoryCacheTable(relation), + output = relation.output.map(_.asInstanceOf[AttributeReference]), + catalog = None, + identifier = None, + options = CaseInsensitiveStringMap.empty() + ) val (rowCount, newColStats) = - CommandUtils.computeColumnStats(sparkSession, relation, column) + CommandUtils.computeColumnStats(sparkSession, dsv2Relation, column) relation.updateStats(rowCount, newColStats) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 7e7f839037175..fcc2e81cc8038 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -49,7 +49,6 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen Window :: WindowGroupLimit :: JoinSelection :: - InMemoryScans :: SparkScripts :: Pipelines :: BasicOperators :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5c393b1db227e..4bfa1047f9cf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{SparkStrategy => Strategy} import org.apache.spark.sql.execution.aggregate.AggUtils -import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{LogicalRelation, WriteFiles, 
WriteFilesExec} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -703,18 +702,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object InMemoryScans extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => - pruneFilterProject( - projectList, - filters, - identity[Seq[Expression]], // All filters still need to be evaluated. - InMemoryTableScanExec(_, filters, mem)) :: Nil - case _ => Nil - } - } - /** * This strategy is just for explaining `Dataset/DataFrame` created by `spark.readStream`. * It won't affect the execution, because `StreamingRelation` will be replaced with diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 7e5c04d27f86d..2be6d5b5ee9ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -26,16 +26,18 @@ import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.classic.DataFrame +import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.columnar.CachedBatch import org.apache.spark.sql.execution.{FilterExec, InputAdapter, WholeStageCodegenExec} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.columnar.CachedRelation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.functions._ import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel._ @@ -55,6 +57,16 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper { import testImplicits._ + /** Wraps a bare [[InMemoryRelation]] in a [[DataSourceV2Relation]] so it goes through the + * DSv2 planning path (DataSourceV2Strategy) instead of the legacy InMemoryScans strategy. */ + private def toDF(relation: InMemoryRelation): DataFrame = + Dataset.ofRows(spark, DataSourceV2Relation( + new InMemoryCacheTable(relation), + relation.output.map(_.asInstanceOf[AttributeReference]), + None, + None, + CaseInsensitiveStringMap.empty())) + setupTestData() private def cachePrimitiveTest(data: DataFrame, dataType: String): Unit = { @@ -69,7 +81,7 @@ class InMemoryColumnarQuerySuite extends QueryTest case _: DefaultCachedBatch => case other => fail(s"Unexpected cached batch type: ${other.getClass.getName}") } - checkAnswer(inMemoryRelation, data.collect().toSeq) + checkAnswer(toDF(inMemoryRelation), data.collect().toSeq) } private def testPrimitiveType(nullability: Boolean): Unit = { @@ -141,7 +153,7 @@ class InMemoryColumnarQuerySuite extends QueryTest val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5), MEMORY_ONLY, plan, None, testData.logicalPlan) - checkAnswer(scan, testData.collect().toSeq) + checkAnswer(toDF(scan), testData.collect().toSeq) } test("default size avoids broadcast") { @@ -162,7 +174,7 @@ class InMemoryColumnarQuerySuite extends QueryTest val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5), MEMORY_ONLY, plan, None, logicalPlan) - checkAnswer(scan, testData.collect().map { + checkAnswer(toDF(scan), testData.collect().map { 
case Row(key: Int, value: String) => value -> key }.map(Row.fromTuple)) } @@ -179,8 +191,8 @@ class InMemoryColumnarQuerySuite extends QueryTest val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5), MEMORY_ONLY, plan, None, testData.logicalPlan) - checkAnswer(scan, testData.collect().toSeq) - checkAnswer(scan, testData.collect().toSeq) + checkAnswer(toDF(scan), testData.collect().toSeq) + checkAnswer(toDF(scan), testData.collect().toSeq) } test("SPARK-1678 regression: compression must not lose repeated values") { @@ -360,7 +372,7 @@ class InMemoryColumnarQuerySuite extends QueryTest // Materialize the data. val expectedAnswer = data.collect() - checkAnswer(cached, expectedAnswer) + checkAnswer(toDF(cached), expectedAnswer) // Check that the right size was calculated. assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.length * INT.defaultSize) @@ -374,7 +386,7 @@ class InMemoryColumnarQuerySuite extends QueryTest // Materialize the data. val expectedAnswer = data.collect() - checkAnswer(cached, expectedAnswer) + checkAnswer(toDF(cached), expectedAnswer) // Check that the right row count was calculated. assert(cached.cacheBuilder.rowCountStats.value === 6)