Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SubqueryExpression}
import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, ResolvedHint, View}
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
Expand All @@ -36,11 +36,12 @@ import org.apache.spark.sql.connector.catalog.CatalogPlugin
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.columnar.{InMemoryCacheTable, InMemoryRelation}
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK

Expand Down Expand Up @@ -332,8 +333,17 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
cachedData: CachedData,
column: Seq[Attribute]): Unit = {
val relation = cachedData.cachedRepresentation
// Wrap in DataSourceV2Relation so the DSv2 planning path is used consistently
// (DataSourceV2Strategy handles InMemoryTableScanExec via InMemoryCacheScan).
val dsv2Relation = DataSourceV2Relation(
table = new InMemoryCacheTable(relation),
output = relation.output.map(_.asInstanceOf[AttributeReference]),
catalog = None,
identifier = None,
options = CaseInsensitiveStringMap.empty()
)
val (rowCount, newColStats) =
CommandUtils.computeColumnStats(sparkSession, relation, column)
CommandUtils.computeColumnStats(sparkSession, dsv2Relation, column)
relation.updateStats(rowCount, newColStats)
}

Expand Down Expand Up @@ -502,9 +512,19 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
// After cache lookup, we should still keep the hints from the input plan.
val hints = EliminateResolvedHint.extractHintsFromPlan(currentFragment)._2
val cachedPlan = cached.cachedRepresentation.withOutput(currentFragment.output)
// Wrap the InMemoryRelation in a DataSourceV2Relation so that V2ScanRelationPushDown
// optimizer rules can apply column pruning, filter pushdown, and ordering/statistics
// reporting. Physical execution is still routed to InMemoryTableScanExec.
val dsv2Relation = DataSourceV2Relation(
table = new InMemoryCacheTable(cachedPlan),
output = cachedPlan.output.map(_.asInstanceOf[AttributeReference]),
catalog = None,
identifier = None,
options = CaseInsensitiveStringMap.empty()
)
// The returned hint list is in top-down order, we should create the hint nodes from
// right to left.
hints.foldRight[LogicalPlan](cachedPlan) { case (hint, p) =>
hints.foldRight[LogicalPlan](dsv2Relation) { case (hint, p) =>
ResolvedHint(p, hint)
}
}.getOrElse(currentFragment)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class SparkPlanner(val session: SparkSession, val experimentalMethods: Experimen
Window ::
WindowGroupLimit ::
JoinSelection ::
InMemoryScans ::
SparkScripts ::
Pipelines ::
BasicOperators ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{SparkStrategy => Strategy}
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.{LogicalRelation, WriteFiles, WriteFilesExec}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
Expand Down Expand Up @@ -703,18 +702,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}

object InMemoryScans extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
pruneFilterProject(
projectList,
filters,
identity[Seq[Expression]], // All filters still need to be evaluated.
InMemoryTableScanExec(_, filters, mem)) :: Nil
case _ => Nil
}
}

/**
* This strategy is just for explaining `Dataset/DataFrame` created by `spark.readStream`.
* It won't affect the execution, because `StreamingRelation` will be replaced with
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.columnar

import java.util
import java.util.OptionalLong

import org.apache.spark.sql.catalyst.expressions.{
Ascending, Attribute, AttributeReference, Descending, NullsFirst, NullsLast,
SortOrder => CatalystSortOrder
}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.{
FieldReference, NamedReference, NullOrdering => V2NullOrdering,
SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue
}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{
Scan, ScanBuilder, Statistics => V2Statistics, SupportsPushDownRequiredColumns,
SupportsPushDownV2Filters, SupportsReportOrdering, SupportsReportStatistics
}
import org.apache.spark.sql.connector.read.colstats.ColumnStatistics
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
 * A DSv2 [[Table]] facade over an [[InMemoryRelation]], letting cached DataFrames flow
 * through the [[V2ScanRelationPushDown]] optimizer rules and so benefit from column pruning,
 * filter pushdown, and ordering/statistics reporting.
 */
private[sql] class InMemoryCacheTable(val relation: InMemoryRelation)
  extends Table with SupportsRead {

  override def name(): String = relation.cacheBuilder.cachedName

  override def schema(): StructType = DataTypeUtils.fromAttributes(relation.output)

  override def capabilities(): util.Set[TableCapability] =
    util.EnumSet.of(TableCapability.BATCH_READ)

  override def newScanBuilder(options: CaseInsensitiveStringMap): InMemoryScanBuilder =
    new InMemoryScanBuilder(relation)

  // Identity is keyed on the shared CachedRDDBuilder: every InMemoryRelation copy derived
  // from the same CachedData holds the very same cacheBuilder instance, so reference
  // equality on it means "same cache entry".
  override def equals(other: Any): Boolean = other match {
    case that: InMemoryCacheTable => this.relation.cacheBuilder eq that.relation.cacheBuilder
    case _ => false
  }

  override def hashCode(): Int = System.identityHashCode(relation.cacheBuilder)
}

/**
 * DSv2 [[ScanBuilder]] for [[InMemoryRelation]].
 *
 * - Column pruning via [[SupportsPushDownRequiredColumns]]: only the requested columns are
 *   handed to [[InMemoryTableScanExec]], cutting down deserialization work.
 * - Filter pushdown via [[SupportsPushDownV2Filters]]: predicates are remembered for
 *   batch-level pruning against per-batch min/max statistics, yet all of them are handed
 *   back (category-2 pushdown: rows still need post-scan re-evaluation).
 */
private[sql] class InMemoryScanBuilder(relation: InMemoryRelation)
  extends ScanBuilder
  with SupportsPushDownRequiredColumns
  with SupportsPushDownV2Filters {

  // Starts as the full cached schema; narrowed by pruneColumns().
  private var neededSchema: StructType = DataTypeUtils.fromAttributes(relation.output)
  private var acceptedPredicates: Array[Predicate] = Array.empty

  override def pruneColumns(required: StructType): Unit = {
    neededSchema = required
  }

  /**
   * Remembers every predicate so [[CachedBatchSerializer.buildFilter]] can skip whole
   * batches via min/max stats, while returning the full array unchanged so Spark adds a
   * post-scan [[FilterExec]] for row-level evaluation.
   */
  override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
    acceptedPredicates = predicates
    predicates
  }

  override def pushedPredicates(): Array[Predicate] = acceptedPredicates

  override def build(): InMemoryCacheScan = {
    val wanted = neededSchema.fieldNames.toSet
    val prunedAttrs = if (wanted == relation.output.map(_.name).toSet) {
      relation.output
    } else {
      relation.output.collect { case a if wanted.contains(a.name) => a }
    }
    new InMemoryCacheScan(relation, prunedAttrs, acceptedPredicates)
  }
}

/**
 * DSv2 [[Scan]] for [[InMemoryRelation]].
 *
 * Physical execution is handled by [[InMemoryTableScanExec]] via [[DataSourceV2Strategy]]
 * rather than [[Batch]]/[[InputPartition]] to preserve the existing efficient columnar path.
 *
 * Reports:
 * - Ordering ([[SupportsReportOrdering]]): propagates the ordering of the original cached plan
 *   so the optimizer can eliminate redundant sorts on top of the cache.
 * - Statistics ([[SupportsReportStatistics]]): exposes row count and size from the cached
 *   relation's Catalyst statistics, feeding AQE decisions.
 *
 * @param relation the cached relation this scan reads from
 * @param prunedAttrs output attributes remaining after column pruning
 * @param pushedPredicates predicates accepted for batch-level min/max pruning
 */
private[sql] class InMemoryCacheScan(
    val relation: InMemoryRelation,
    val prunedAttrs: Seq[Attribute],
    val pushedPredicates: Array[Predicate])
  extends Scan
  with SupportsReportOrdering
  with SupportsReportStatistics {

  override def readSchema(): StructType = DataTypeUtils.fromAttributes(prunedAttrs)

  /**
   * Converts the Catalyst sort ordering of the cached plan to V2 [[SortOrder]]s.
   * Only attribute-reference based orderings are converted; complex expressions are skipped.
   */
  override def outputOrdering(): Array[V2SortOrder] =
    relation.outputOrdering.flatMap {
      case CatalystSortOrder(attr: AttributeReference, direction, nullOrdering, _) =>
        val v2Dir = direction match {
          case Ascending => V2SortDirection.ASCENDING
          case Descending => V2SortDirection.DESCENDING
        }
        val v2Nulls = nullOrdering match {
          case NullsFirst => V2NullOrdering.NULLS_FIRST
          case NullsLast => V2NullOrdering.NULLS_LAST
        }
        Some(SortValue(FieldReference.column(attr.name), v2Dir, v2Nulls))
      case _ => None
    }.toArray

  // Catalyst statistics are BigInt-valued; BigInt.toLong silently wraps around for values
  // beyond the Long range, which would surface as negative sizes/counts and mislead AQE.
  // Saturate at the Long bounds instead of wrapping.
  private def saturatedLong(v: BigInt): Long =
    if (v > Long.MaxValue) Long.MaxValue
    else if (v < Long.MinValue) Long.MinValue
    else v.toLong

  override def estimateStatistics(): V2Statistics = {
    val stats = relation.computeStats()
    val v2ColStats = new util.HashMap[NamedReference, ColumnStatistics]()
    stats.attributeStats.foreach { case (attr, colStat) =>
      val cs = new ColumnStatistics {
        override def distinctCount(): OptionalLong =
          colStat.distinctCount.map(v => OptionalLong.of(saturatedLong(v)))
            .getOrElse(OptionalLong.empty())
        override def min(): util.Optional[Object] =
          colStat.min.map(v => util.Optional.of(v.asInstanceOf[Object]))
            .getOrElse(util.Optional.empty[Object]())
        override def max(): util.Optional[Object] =
          colStat.max.map(v => util.Optional.of(v.asInstanceOf[Object]))
            .getOrElse(util.Optional.empty[Object]())
        override def nullCount(): OptionalLong =
          colStat.nullCount.map(v => OptionalLong.of(saturatedLong(v)))
            .getOrElse(OptionalLong.empty())
        override def avgLen(): OptionalLong =
          colStat.avgLen.map(OptionalLong.of).getOrElse(OptionalLong.empty())
        override def maxLen(): OptionalLong =
          colStat.maxLen.map(OptionalLong.of).getOrElse(OptionalLong.empty())
      }
      v2ColStats.put(FieldReference.column(attr.name), cs)
    }
    new V2Statistics {
      override def sizeInBytes(): OptionalLong =
        OptionalLong.of(saturatedLong(stats.sizeInBytes))
      override def numRows(): OptionalLong =
        stats.rowCount.map(c => OptionalLong.of(saturatedLong(c))).getOrElse(OptionalLong.empty())
      override def columnStats(): util.Map[NamedReference, ColumnStatistics] = v2ColStats
    }
  }
}

/**
 * Extractor matching any in-plan representation of a cached DataFrame and yielding its
 * underlying [[InMemoryRelation]].
 *
 * Depending on the query stage, a cached DataFrame shows up as one of three nodes:
 * - a [[DataSourceV2Relation]] backed by [[InMemoryCacheTable]], produced by [[CacheManager]]
 *   in `useCachedData` and visible in `QueryExecution.withCachedData`;
 * - a [[DataSourceV2ScanRelation]] backed by [[InMemoryCacheScan]], the result of
 *   [[V2ScanRelationPushDown]] on the above, visible in `QueryExecution.optimizedPlan`;
 * - a bare [[InMemoryRelation]], e.g. as stored in [[CachedData]].
 */
object CachedRelation {
  def unapply(plan: LogicalPlan): Option[InMemoryRelation] = plan match {
    case DataSourceV2Relation(t: InMemoryCacheTable, _, _, _, _, _) => Some(t.relation)
    case DataSourceV2ScanRelation(_, s: InMemoryCacheScan, _, _, _) => Some(s.relation)
    case imr: InMemoryRelation => Some(imr)
    case _ => None
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBat
import org.apache.spark.sql.connector.write.V1Write
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan, SparkStrategy => Strategy}
import org.apache.spark.sql.execution.columnar.{InMemoryCacheScan, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelationWithTable, PushableColumnAndNestedColumn}
import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
Expand Down Expand Up @@ -151,6 +152,16 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
DataSourceV2Strategy.withProjectAndFilter(
project, filters, localScanExec, needsUnsafeConversion = false) :: Nil

case PhysicalOperation(project, filters,
DataSourceV2ScanRelation(_, scan: InMemoryCacheScan, output, _, _)) =>
// Route cached DataFrames back to InMemoryTableScanExec, preserving the optimized
// columnar path. Filters are passed for batch-level min/max pruning and a post-scan
// FilterExec is added by withProjectAndFilter for row-level re-evaluation.
DataSourceV2Strategy.withProjectAndFilter(
project, filters,
InMemoryTableScanExec(output, filters, scan.relation),
needsUnsafeConversion = false) :: Nil

case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) =>
// projection and filters were already pushed down in the optimizer.
// this uses PhysicalOperation to get the projection and ensure that if the batch scan does
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.columnar.CachedRelation
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -183,17 +183,17 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
*/
private def calculatePlanOverhead(plan: LogicalPlan): Float = {
val (cached, notCached) = plan.collectLeaves().partition(p => p match {
case _: InMemoryRelation => true
case CachedRelation(_) => true
case _ => false
})
val scanOverhead = notCached.map(_.stats.sizeInBytes).sum.toFloat
val cachedOverhead = cached.map {
case m: InMemoryRelation if m.cacheBuilder.storageLevel.useDisk &&
case CachedRelation(m) if m.cacheBuilder.storageLevel.useDisk &&
!m.cacheBuilder.storageLevel.useMemory =>
m.stats.sizeInBytes.toFloat
case m: InMemoryRelation if m.cacheBuilder.storageLevel.useDisk =>
case CachedRelation(m) if m.cacheBuilder.storageLevel.useDisk =>
m.stats.sizeInBytes.toFloat * 0.2
case m: InMemoryRelation if m.cacheBuilder.storageLevel.useMemory =>
case CachedRelation(_) =>
0.0
}.sum.toFloat
scanOverhead + cachedOverhead
Expand Down
Loading