Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,10 @@ class QueryExecution(
// Evaluates `executedPlan` and discards the result; nothing is returned to the caller.
// NOTE(review): presumably this forces the physical plan to be prepared eagerly -- confirm that
// `executedPlan` is lazily computed.
def assertExecutedPlanPrepared(): Unit = executedPlan

// Wraps the RDD produced by the physical plan in a `SQLExecutionRDD`, passing along the session's
// `SQLConf` and the plan information of the plan that generated the RDD.
val lazyToRdd = LazyTry {
  // Construct the wrapper exactly once. A second `new SQLExecutionRDD(executedPlan.execute(), ...)`
  // here would call `executedPlan.execute()` twice and throw the first wrapper away.
  new SQLExecutionRDD(
    executedPlan.execute(),
    sparkSession.sessionState.conf,
    SparkPlanInfo.fromSparkPlan(executedPlan))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,18 @@ import org.apache.spark.sql.internal.SQLConf
*
* @param sqlRDD the `RDD` generated by the SQL plan
* @param conf the `SQLConf` to apply to the execution of the SQL plan
* @param sparkPlanInfo the physical plan information for `sqlRDD`
*/
class SQLExecutionRDD(
var sqlRDD: RDD[InternalRow], @transient conf: SQLConf) extends RDD[InternalRow](sqlRDD) {
var sqlRDD: RDD[InternalRow],
@transient conf: SQLConf,
@transient val sparkPlanInfo: SparkPlanInfo = SparkPlanInfo.EMPTY)
extends RDD[InternalRow](sqlRDD) {

/**
 * Two-argument constructor: delegates to the primary constructor, supplying
 * `SparkPlanInfo.EMPTY` as the plan information.
 */
def this(sqlRDD: RDD[InternalRow], conf: SQLConf) =
  this(sqlRDD, conf, SparkPlanInfo.EMPTY)

private val sqlConfigs = conf.getAllConfs
private lazy val sqlConfExecutorSide = {
val newConf = new SQLConf()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.spark.sql.execution

import scala.util.control.NonFatal

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.{EmptyRelation, LogicalPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.LogicalQueryStage
Expand Down Expand Up @@ -74,6 +77,7 @@ private[execution] object SparkPlanInfo {
case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil
case stage: QueryStageExec => stage.plan :: Nil
case inMemTab: InMemoryTableScanExec => inMemTab.relation.cachedPlan :: Nil
case rddScan: RDDScanExec => sparkPlanInfosFromRDD(rddScan.rdd)
case EmptyRelationExec(logical) => (logical :: Nil)
case _ => plan.children ++ plan.subqueries
}
Expand All @@ -91,6 +95,8 @@ private[execution] object SparkPlanInfo {
Some(fromSparkPlan(child))
case child: LogicalPlan =>
Some(fromLogicalPlan(child))
case child: SparkPlanInfo =>
Some(child)
case _ => None
}
new SparkPlanInfo(
Expand All @@ -102,4 +108,31 @@ private[execution] object SparkPlanInfo {
}

// Placeholder plan info built from empty strings and empty collections.
final lazy val EMPTY: SparkPlanInfo = new SparkPlanInfo("", "", Nil, Map.empty, Nil)

/**
 * Collects the [[SparkPlanInfo]] attached to the `SQLExecutionRDD`s reachable through the
 * lineage of `rdd`.
 *
 * The lineage is walked breadth-first over `RDD.dependencies`. When a `SQLExecutionRDD` carrying
 * a usable plan info is found, that info is recorded and the walk does not descend below it.
 * Every other RDD is expanded into its parent RDDs.
 *
 * @param rdd the RDD whose lineage is inspected
 * @return the plan infos found, in discovery order; empty if the lineage has none
 */
private def sparkPlanInfosFromRDD(rdd: RDD[_]): Seq[SparkPlanInfo] = {
  import scala.collection.mutable

  // Walk only driver-side RDD dependency metadata. Dedupe by RDD id so shared lineage does not
  // duplicate the same internal SQL plan under a single RDDScanExec.
  val visitedRDDs = mutable.HashSet.empty[Int]
  val rddsToVisit = mutable.Queue.empty[RDD[_]]
  val planInfos = mutable.ArrayBuffer.empty[SparkPlanInfo]

  rddsToVisit.enqueue(rdd)
  while (rddsToVisit.nonEmpty) {
    val current = rddsToVisit.dequeue()
    if (visitedRDDs.add(current.id)) {
      current match {
        // `sparkPlanInfo` is a @transient field, so it can be null on an instance that went
        // through serialization. Treat null like EMPTY instead of recording a null child.
        case sqlRDD: SQLExecutionRDD
            if sqlRDD.sparkPlanInfo != null && sqlRDD.sparkPlanInfo != EMPTY =>
          planInfos += sqlRDD.sparkPlanInfo
        case _ =>
          try {
            current.dependencies.foreach(dep => rddsToVisit.enqueue(dep.rdd))
          } catch {
            // Best effort: if the dependencies of this RDD cannot be resolved, skip this branch
            // of the lineage rather than failing the whole plan-info construction.
            case NonFatal(_) =>
          }
      }
    }
  }

  planInfos.toSeq
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,49 @@

package org.apache.spark.sql.execution.ui

import scala.concurrent.duration._

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.test.SharedSparkSession

class SparkPlanInfoSuite extends SharedSparkSession {

import testImplicits._

// Flattens the plan-info tree rooted at `sparkPlanInfo` in pre-order: the node itself first,
// followed by the flattened subtree of each child in order.
private def collectSparkPlanInfo(sparkPlanInfo: SparkPlanInfo): Seq[SparkPlanInfo] = {
  val descendants = sparkPlanInfo.children.flatMap(child => collectSparkPlanInfo(child))
  Seq(sparkPlanInfo) ++ descendants
}

// Returns the first node (in pre-order) of the tree rooted at `sparkPlanInfo` whose `nodeName`
// equals `nodeName`; fails the test when no such node exists.
private def findSparkPlanInfo(sparkPlanInfo: SparkPlanInfo, nodeName: String): SparkPlanInfo = {
  val firstMatch = collectSparkPlanInfo(sparkPlanInfo).collectFirst {
    case info if info.nodeName == nodeName => info
  }
  firstMatch.getOrElse(fail(s"Could not find $nodeName in ${sparkPlanInfo.simpleString}"))
}

// Runs `df.collect()` and returns the SQL UI plan graph and the metric values of the single
// execution that the collect added to the status store.
//
// Execution ids already present before the collect are recorded first, so the new execution is
// identified as the one id that appears afterwards.
private def collectSparkPlanGraphMetrics(
    df: DataFrame): (SparkPlanGraph, Map[Long, String]) = {
  val statusStore = spark.sharedState.statusStore
  def currentExecutionIds = statusStore.executionsList().map(_.executionId).toSet

  // Drain pending listener events so the baseline below is not missing earlier executions.
  spark.sparkContext.listenerBus.waitUntilEmpty(10000)
  val previousExecutionIds = currentExecutionIds

  df.collect()
  spark.sparkContext.listenerBus.waitUntilEmpty(10000)

  // Take the id from the same snapshot that was asserted on, instead of querying the store
  // again after the check has passed.
  val executionId = eventually(timeout(10.seconds), interval(10.milliseconds)) {
    val newExecutionIds = currentExecutionIds.diff(previousExecutionIds)
    assert(newExecutionIds.size === 1)
    newExecutionIds.head
  }
  // Wait until the store reports metric values for the execution before reading them.
  eventually(timeout(10.seconds), interval(10.milliseconds)) {
    assert(statusStore.execution(executionId).exists(_.metricValues != null))
  }

  (statusStore.planGraph(executionId), statusStore.executionMetrics(executionId))
}

def validateSparkPlanInfo(sparkPlanInfo: SparkPlanInfo): Unit = {
sparkPlanInfo.nodeName match {
case "InMemoryTableScan" => assert(sparkPlanInfo.children.length == 1)
Expand All @@ -41,4 +77,48 @@ class SparkPlanInfoSuite extends SharedSparkSession {

validateSparkPlanInfo(planInfoResult)
}

test("SPARK-47017: SparkPlanInfo and SQL UI include SQL plan inside RDDScanExec") {
  // A DataFrame rebuilt from another DataFrame's RDD: the SQL plan of `source` now sits behind
  // the RDD that `rebuilt` scans.
  val source = spark.range(10).where($"id" > 3).select($"id".as("age"))
  val rebuilt = spark.createDataFrame(source.rdd, source.schema)

  // Part 1: the plan info of the RDD scan exposes the inner plan, including the Filter node
  // and its output-row metric.
  val rootInfo = SparkPlanInfo.fromSparkPlan(rebuilt.queryExecution.executedPlan)
  val scanInfo = findSparkPlanInfo(rootInfo, "Scan ExistingRDD")
  val nodesUnderScan = scanInfo.children.flatMap(collectSparkPlanInfo)
  val innerFilter = nodesUnderScan.find(_.nodeName == "Filter").getOrElse(
    fail(s"Could not find Filter under Scan ExistingRDD in ${rootInfo.simpleString}"))

  assert(scanInfo.children.nonEmpty)
  assert(innerFilter.metrics.exists(_.name == "number of output rows"))

  // Part 2: the same RDD appearing twice in a union contributes a single child plan.
  val unionOfSame = spark.sparkContext.union(source.rdd, source.rdd)
  val unionDF = spark.createDataFrame(unionOfSame, source.schema)
  val unionRootInfo = SparkPlanInfo.fromSparkPlan(unionDF.queryExecution.executedPlan)
  val unionScanInfo = findSparkPlanInfo(unionRootInfo, "Scan ExistingRDD")

  assert(unionScanInfo.children.size === 1)

  // Part 3: rebuilding a DataFrame from an already rebuilt one still yields one child plan,
  // and the original Filter is reachable below it.
  val rebuiltTwice = spark.createDataFrame(rebuilt.rdd, rebuilt.schema)
  val nestedRootInfo = SparkPlanInfo.fromSparkPlan(rebuiltTwice.queryExecution.executedPlan)
  val nestedScanInfo = findSparkPlanInfo(nestedRootInfo, "Scan ExistingRDD")

  assert(nestedScanInfo.children.size === 1)
  val nodesUnderNestedScan = collectSparkPlanInfo(nestedScanInfo.children.head)
  assert(nodesUnderNestedScan.exists(_.nodeName == "Filter"))

  // Part 4: after actually running the query, the SQL UI graph contains the Filter node and
  // its output-row metric reports the 6 rows with id > 3.
  val (uiGraph, uiMetricValues) = collectSparkPlanGraphMetrics(rebuilt)
  val uiFilterNode = uiGraph.allNodes.find(_.name == "Filter").getOrElse(
    fail(s"Could not find Filter in ${rebuilt.queryExecution.executedPlan}"))
  val outputRowsMetric = uiFilterNode.metrics.find(_.name == "number of output rows").getOrElse(
    fail("Could not find number of output rows metric for Filter"))
  val renderedValue = uiMetricValues.getOrElse(
    outputRowsMetric.accumulatorId, fail("Could not find Filter metric value"))
  // The rendered value may contain thousands separators; strip them before parsing.
  val firstNumber = "\\d+".r.findFirstIn(renderedValue.replace(",", ""))
  val outputRows = firstNumber.map(_.toLong).getOrElse(
    fail(s"Could not parse Filter metric value $renderedValue"))

  assert(outputRows === 6L)
}
}