diff --git a/python/pyspark/sql/internal.py b/python/pyspark/sql/internal.py
index 3007b28b00441..dd9ebbcdc1822 100644
--- a/python/pyspark/sql/internal.py
+++ b/python/pyspark/sql/internal.py
@@ -104,7 +104,9 @@ def distributed_id() -> Column:
 
     @staticmethod
     def distributed_sequence_id() -> Column:
-        return InternalFunction._invoke_internal_function_over_columns("distributed_sequence_id")
+        return InternalFunction._invoke_internal_function_over_columns(
+            "distributed_sequence_id", F.lit(True)
+        )
 
     @staticmethod
     def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index b8da376bead6f..2a2440117e401 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -441,7 +441,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
+      case oldVersion @ AttachDistributedSequence(sequenceAttr, _, _)
         if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())
         newVersion.copyTagsFrom(oldVersion)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
index bf6ab8e50616c..fe26122f3ac13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ExtractDistributedSequenceID.scala
@@ -34,8 +34,12 @@ object ExtractDistributedSequenceID extends Rule[LogicalPlan] {
     plan.resolveOperatorsUpWithPruning(_.containsPattern(DISTRIBUTED_SEQUENCE_ID)) {
       case plan: LogicalPlan if plan.resolved &&
           plan.expressions.exists(_.exists(_.isInstanceOf[DistributedSequenceID])) =>
+        val cache = plan.expressions.exists(_.exists(e =>
+          e.isInstanceOf[DistributedSequenceID] &&
+            e.asInstanceOf[DistributedSequenceID].cache.eval().asInstanceOf[Boolean]))
         val attr = AttributeReference("distributed_sequence_id", LongType, nullable = false)()
-        val newPlan = plan.withNewChildren(plan.children.map(AttachDistributedSequence(attr, _)))
+        val newPlan = plan.withNewChildren(
+          plan.children.map(AttachDistributedSequence(attr, _, cache)))
           .transformExpressions { case _: DistributedSequenceID => attr }
         Project(plan.output, newPlan)
     }
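Note (editorial, not part of the patch): a minimal pandas-on-Spark sketch of where this rule surfaces. It assumes the standard "compute.default_index_type" option; the plan text in the comments is only an approximation and varies by Spark version.

# Hedged sketch: with the 'distributed-sequence' default index, ExtractDistributedSequenceID
# rewrites DistributedSequenceID into an AttachDistributedSequence node under a Project,
# which is visible in the explained plan.
import pyspark.pandas as ps

ps.set_option("compute.default_index_type", "distributed-sequence")
psdf = ps.DataFrame({"a": [1, 2, 3]})

# Approximate expected output (illustrative only):
#   AttachDistributedSequence[__index_level_0__#..., a#...]
#   +- ...
psdf.spark.explain()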
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
index 5a0bff990e68a..cd71ee8580525 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/DistributedSequenceID.scala
@@ -26,10 +26,15 @@ import org.apache.spark.sql.types.{DataType, LongType}
  *
  * @note this expression is dedicated for Pandas API on Spark to use.
  */
-case class DistributedSequenceID() extends LeafExpression with Unevaluable with NonSQLExpression {
+case class DistributedSequenceID(cache: Expression)
+  extends LeafExpression with Unevaluable with NonSQLExpression {
+
+  // This argument indicates whether the underlying RDD should be cached
+  // according to the PS config "pandas_on_Spark.compute.default_index_cache".
+  def this() = this(Literal(false))
 
   override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
-    DistributedSequenceID()
+    DistributedSequenceID(cache)
   }
 
   override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fe15819bd44a7..125db2752b209 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1068,7 +1068,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       a.copy(child = Expand(newProjects, newOutput, grandChild))
 
     // Prune and drop AttachDistributedSequence if the produced attribute is not referred.
-    case p @ Project(_, a @ AttachDistributedSequence(_, grandChild))
+    case p @ Project(_, a @ AttachDistributedSequence(_, grandChild, _))
       if !p.references.contains(a.sequenceAttr) =>
       p.copy(child = prunedChild(grandChild, p.references))
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index bcfcae2ee16c9..db22a0781c0e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -367,12 +367,14 @@ case class ArrowEvalPythonUDTF(
 
 /**
  * A logical plan that adds a new long column with the name `name` that
- * increases one by one. This is for 'distributed-sequence' default index
- * in pandas API on Spark.
+ * increases one by one.
+ * This is used for both the 'distributed-sequence' default index in pandas API on Spark
+ * and 'DataFrame.zipWithIndex'.
  */
 case class AttachDistributedSequence(
     sequenceAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    cache: Boolean = false) extends UnaryNode {
 
   override val producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
index 088df782a541c..17d4640f22fad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala
@@ -2062,7 +2062,7 @@ class Dataset[T] private[sql](
    * This is for 'distributed-sequence' default index in pandas API on Spark.
    */
   private[sql] def withSequenceColumn(name: String) = {
-    select(Column(DistributedSequenceID()).alias(name), col("*"))
+    select(Column(DistributedSequenceID(Literal(true))).alias(name), col("*"))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5efad83bcba78..5c393b1db227e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -969,8 +969,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.python.MapInPandasExec(func, output, planLater(child), isBarrier, profile) :: Nil
       case logical.MapInArrow(func, output, child, isBarrier, profile) =>
         execution.python.MapInArrowExec(func, output, planLater(child), isBarrier, profile) :: Nil
-      case logical.AttachDistributedSequence(attr, child) =>
-        execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
+      case logical.AttachDistributedSequence(attr, child, cache) =>
+        execution.python.AttachDistributedSequenceExec(attr, planLater(child), cache) :: Nil
       case logical.PythonWorkerLogs(jsonAttr) =>
         execution.python.PythonWorkerLogsExec(jsonAttr) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
index e27bde38a6f5f..507b632f55653 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AttachDistributedSequenceExec.scala
@@ -29,12 +29,16 @@ import org.apache.spark.storage.{StorageLevel, StorageLevelMapper}
 
 /**
  * A physical plan that adds a new long column with `sequenceAttr` that
- * increases one by one. This is for 'distributed-sequence' default index
- * in pandas API on Spark.
+ * increases one by one.
+ * This is for the 'distributed-sequence' default index in pandas API on Spark,
+ * and 'DataFrame.zipWithIndex'.
+ * When `cache` is true, the underlying RDD will be cached according to
+ * the PS config "pandas_on_Spark.compute.default_index_cache".
  */
 case class AttachDistributedSequenceExec(
     sequenceAttr: Attribute,
-    child: SparkPlan)
+    child: SparkPlan,
+    cache: Boolean)
   extends UnaryExecNode {
 
   override def producedAttributes: AttributeSet = AttributeSet(sequenceAttr)
@@ -45,8 +49,9 @@ case class AttachDistributedSequenceExec(
 
   @transient private var cached: RDD[InternalRow] = _
 
-  override protected def doExecute(): RDD[InternalRow] = {
-    val childRDD = child.execute()
+  // Cache the underlying RDD according to
+  // the PS config "pandas_on_Spark.compute.default_index_cache".
+  private def cacheRDD(rdd: RDD[InternalRow]): RDD[InternalRow] = {
     // before `compute.default_index_cache` is explicitly set via
     // `ps.set_option`, `SQLConf.get` can not get its value (as well as its default value);
     // after `ps.set_option`, `SQLConf.get` can get its value:
@@ -74,22 +79,30 @@ case class AttachDistributedSequenceExec(
       StorageLevelMapper.MEMORY_AND_DISK_SER.name()
     ).stripPrefix("\"").stripSuffix("\"")
 
-    val cachedRDD = storageLevel match {
+    storageLevel match {
       // zipWithIndex launches a Spark job only if #partition > 1
-      case _ if childRDD.getNumPartitions <= 1 => childRDD
+      case _ if rdd.getNumPartitions <= 1 => rdd
 
-      case "NONE" => childRDD
+      case "NONE" => rdd
 
       case "LOCAL_CHECKPOINT" =>
         // localcheckpointing is unreliable so should not eagerly release it in 'cleanupResources'
-        childRDD.map(_.copy()).localCheckpoint()
+        rdd.map(_.copy()).localCheckpoint()
           .setName(s"Temporary RDD locally checkpointed in AttachDistributedSequenceExec($id)")
 
       case _ =>
-        cached = childRDD.map(_.copy()).persist(StorageLevel.fromString(storageLevel))
+        cached = rdd.map(_.copy()).persist(StorageLevel.fromString(storageLevel))
           .setName(s"Temporary RDD cached in AttachDistributedSequenceExec($id)")
         cached
     }
+  }
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val childRDD: RDD[InternalRow] = child.execute()
+
+    // If `cache` is true, the underlying RDD is cached according to
+    // the PS config "pandas_on_Spark.compute.default_index_cache".
+    val cachedRDD = if (cache) this.cacheRDD(childRDD) else childRDD
 
     cachedRDD.zipWithIndex().mapPartitions { iter =>
       val unsafeProj = UnsafeProjection.create(output, output)
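Note (editorial, not part of the patch): a hedged usage sketch of the caching behavior wired up above. It assumes the pandas-on-Spark options "compute.default_index_type" and "compute.default_index_cache"; the latter is the setting read here through the SQL conf key "pandas_on_Spark.compute.default_index_cache", and the values handled in the match above are StorageLevel names plus "NONE" and "LOCAL_CHECKPOINT". Since withSequenceColumn and the internal distributed_sequence_id() now pass a true literal, this path applies to the default index.

# Hedged sketch: with cache = true on AttachDistributedSequenceExec, the child RDD
# may be persisted or locally checkpointed (depending on the configured value)
# before zipWithIndex() assigns the index.
import pyspark.pandas as ps

ps.set_option("compute.default_index_type", "distributed-sequence")
ps.set_option("compute.default_index_cache", "LOCAL_CHECKPOINT")

psdf = ps.range(10)  # the default index is attached by AttachDistributedSequenceExec
print(psdf.index.to_numpy())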