python/pyspark/sql/internal.py (3 additions, 1 deletion)
@@ -104,7 +104,9 @@ def distributed_id() -> Column:

     @staticmethod
     def distributed_sequence_id() -> Column:
-        return InternalFunction._invoke_internal_function_over_columns("distributed_sequence_id")
+        return InternalFunction._invoke_internal_function_over_columns(
+            "distributed_sequence_id", F.lit(True)
+        )
 
     @staticmethod
     def collect_top_k(col: Column, num: int, reverse: bool) -> Column:

Comment (Contributor Author): This is only used in PS, hence the hardcoded F.lit(True).

@@ -441,7 +441,7 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))
 
-      case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
+      case oldVersion @ AttachDistributedSequence(sequenceAttr, _, _)
         if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())
         newVersion.copyTagsFrom(oldVersion)

@@ -34,8 +34,12 @@ object ExtractDistributedSequenceID extends Rule[LogicalPlan] {
     plan.resolveOperatorsUpWithPruning(_.containsPattern(DISTRIBUTED_SEQUENCE_ID)) {
       case plan: LogicalPlan if plan.resolved &&
           plan.expressions.exists(_.exists(_.isInstanceOf[DistributedSequenceID])) =>
+        val cache = plan.expressions.exists(_.exists(e =>
+          e.isInstanceOf[DistributedSequenceID] &&
+            e.asInstanceOf[DistributedSequenceID].cache.eval().asInstanceOf[Boolean]))
         val attr = AttributeReference("distributed_sequence_id", LongType, nullable = false)()
-        val newPlan = plan.withNewChildren(plan.children.map(AttachDistributedSequence(attr, _)))
+        val newPlan = plan.withNewChildren(
+          plan.children.map(AttachDistributedSequence(attr, _, cache)))
           .transformExpressions { case _: DistributedSequenceID => attr }
         Project(plan.output, newPlan)
     }
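The rule ORs the request across all occurrences: if any DistributedSequenceID in the plan's expressions carries cache = Literal(true), the single AttachDistributedSequence node it introduces is built with cache = true. A minimal sketch of that flag extraction, assuming the Catalyst classes above are on the classpath:

    import org.apache.spark.sql.catalyst.expressions.{DistributedSequenceID, Expression, Literal}

    // One of the two occurrences requests caching, so the extracted flag is true.
    val exprs: Seq[Expression] =
      Seq(new DistributedSequenceID(), DistributedSequenceID(Literal(true)))
    val cache = exprs.exists(_.exists {
      case d: DistributedSequenceID => d.cache.eval().asInstanceOf[Boolean]
      case _ => false
    })
    assert(cache)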

@@ -26,10 +26,15 @@ import org.apache.spark.sql.types.{DataType, LongType}
  *
  * @note this expression is dedicated for Pandas API on Spark to use.
  */
-case class DistributedSequenceID() extends LeafExpression with Unevaluable with NonSQLExpression {
+case class DistributedSequenceID(cache: Expression)
+  extends LeafExpression with Unevaluable with NonSQLExpression {
+
+  // This argument indicates whether the underlying RDD should be cached,
+  // according to the PS config "pandas_on_Spark.compute.default_index_cache".
+  def this() = this(Literal(false))
 
   override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
-    DistributedSequenceID()
+    DistributedSequenceID(cache)
   }
 
   override def nullable: Boolean = false
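A small sketch of the two construction paths, again assuming Catalyst on the classpath; the auxiliary constructor keeps existing no-arg call sites source-compatible:

    import org.apache.spark.sql.catalyst.expressions.{DistributedSequenceID, Literal}

    // Existing call sites keep compiling and get cache = Literal(false).
    val uncached = new DistributedSequenceID()
    // New call sites request caching explicitly.
    val cached = DistributedSequenceID(Literal(true))

    assert(!uncached.cache.eval().asInstanceOf[Boolean])
    assert(cached.cache.eval().asInstanceOf[Boolean])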

@@ -1068,7 +1068,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
       a.copy(child = Expand(newProjects, newOutput, grandChild))
 
     // Prune and drop AttachDistributedSequence if the produced attribute is not referred.
-    case p @ Project(_, a @ AttachDistributedSequence(_, grandChild))
+    case p @ Project(_, a @ AttachDistributedSequence(_, grandChild, _))
       if !p.references.contains(a.sequenceAttr) =>
       p.copy(child = prunedChild(grandChild, p.references))

@@ -367,12 +367,14 @@ case class ArrowEvalPythonUDTF(

 /**
  * A logical plan that adds a new long column with the name `name` that
- * increases one by one. This is for 'distributed-sequence' default index
- * in pandas API on Spark.
+ * increases one by one.
+ * This is used for both the 'distributed-sequence' default index in
+ * pandas API on Spark and 'DataFrame.zipWithIndex'.
  */
 case class AttachDistributedSequence(
     sequenceAttr: Attribute,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    cache: Boolean = false) extends UnaryNode {
 
   override val producedAttributes: AttributeSet = AttributeSet(sequenceAttr)

@@ -2062,7 +2062,7 @@ class Dataset[T] private[sql](
    * This is for 'distributed-sequence' default index in pandas API on Spark.
    */
   private[sql] def withSequenceColumn(name: String) = {
-    select(Column(DistributedSequenceID()).alias(name), col("*"))
+    select(Column(DistributedSequenceID(Literal(true))).alias(name), col("*"))
   }
 
   /**

Comment (Contributor Author): This should also be the place for PS on PySpark Classic.
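Since withSequenceColumn is private[sql], callers must live in the org.apache.spark.sql package; a hedged sketch of such a caller (the object and column name are illustrative, not part of this PR):

    package org.apache.spark.sql

    object WithSequenceColumnExample {
      // Prepends a 0-based long index column; caching is hardcoded on this
      // path because it serves the pandas-on-Spark default index.
      def attachIndex(df: DataFrame): DataFrame =
        df.withSequenceColumn("__index__")
    }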

@@ -969,8 +969,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       execution.python.MapInPandasExec(func, output, planLater(child), isBarrier, profile) :: Nil
     case logical.MapInArrow(func, output, child, isBarrier, profile) =>
       execution.python.MapInArrowExec(func, output, planLater(child), isBarrier, profile) :: Nil
-    case logical.AttachDistributedSequence(attr, child) =>
-      execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
+    case logical.AttachDistributedSequence(attr, child, cache) =>
+      execution.python.AttachDistributedSequenceExec(attr, planLater(child), cache) :: Nil
     case logical.PythonWorkerLogs(jsonAttr) =>
       execution.python.PythonWorkerLogsExec(jsonAttr) :: Nil
     case logical.MapElements(f, _, _, objAttr, child) =>

@@ -29,12 +29,16 @@ import org.apache.spark.storage.{StorageLevel, StorageLevelMapper}

 /**
  * A physical plan that adds a new long column with `sequenceAttr` that
- * increases one by one. This is for 'distributed-sequence' default index
- * in pandas API on Spark.
+ * increases one by one.
+ * This is for the 'distributed-sequence' default index in pandas API on Spark,
+ * and for 'DataFrame.zipWithIndex'.
+ * When `cache` is true, the underlying RDD will be cached according to
+ * the PS config "pandas_on_Spark.compute.default_index_cache".
  */
 case class AttachDistributedSequenceExec(
     sequenceAttr: Attribute,
-    child: SparkPlan)
+    child: SparkPlan,
+    cache: Boolean)
   extends UnaryExecNode {
 
   override def producedAttributes: AttributeSet = AttributeSet(sequenceAttr)

@@ -45,8 +49,9 @@ case class AttachDistributedSequenceExec(

   @transient private var cached: RDD[InternalRow] = _
 
-  override protected def doExecute(): RDD[InternalRow] = {
-    val childRDD = child.execute()
+  // Cache the underlying RDD according to
+  // the PS config "pandas_on_Spark.compute.default_index_cache".
+  private def cacheRDD(rdd: RDD[InternalRow]): RDD[InternalRow] = {
     // before `compute.default_index_cache` is explicitly set via
     // `ps.set_option`, `SQLConf.get` can not get its value (as well as its default value);
     // after `ps.set_option`, `SQLConf.get` can get its value:

@@ -74,22 +79,30 @@
       StorageLevelMapper.MEMORY_AND_DISK_SER.name()
     ).stripPrefix("\"").stripSuffix("\"")
 
-    val cachedRDD = storageLevel match {
+    storageLevel match {
       // zipWithIndex launches a Spark job only if #partition > 1
-      case _ if childRDD.getNumPartitions <= 1 => childRDD
+      case _ if rdd.getNumPartitions <= 1 => rdd
 
-      case "NONE" => childRDD
+      case "NONE" => rdd
 
       case "LOCAL_CHECKPOINT" =>
         // localcheckpointing is unreliable so should not eagerly release it in 'cleanupResources'
-        childRDD.map(_.copy()).localCheckpoint()
+        rdd.map(_.copy()).localCheckpoint()
           .setName(s"Temporary RDD locally checkpointed in AttachDistributedSequenceExec($id)")
 
       case _ =>
-        cached = childRDD.map(_.copy()).persist(StorageLevel.fromString(storageLevel))
+        cached = rdd.map(_.copy()).persist(StorageLevel.fromString(storageLevel))
           .setName(s"Temporary RDD cached in AttachDistributedSequenceExec($id)")
         cached
     }
+  }

+  override protected def doExecute(): RDD[InternalRow] = {
+    val childRDD: RDD[InternalRow] = child.execute()
+
+    // If `cache` is true, the underlying RDD is cached according to
+    // the PS config "pandas_on_Spark.compute.default_index_cache".
+    val cachedRDD = if (cache) this.cacheRDD(childRDD) else childRDD
 
     cachedRDD.zipWithIndex().mapPartitions { iter =>
       val unsafeProj = UnsafeProjection.create(output, output)
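As context for the final zipWithIndex call: a minimal sketch, assuming an active local SparkSession; the conf key is the one quoted in the comments above and is an assumption here.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[2]").getOrCreate()

    // RDD.zipWithIndex assigns contiguous 0-based longs across partitions. With
    // more than one partition it first runs a job to compute per-partition
    // offsets, so the input RDD is computed twice unless it is cached, which is
    // what the `cache` flag above avoids.
    val indexed = spark.sparkContext
      .parallelize(Seq("a", "b", "c"), numSlices = 2)
      .zipWithIndex()
    assert(indexed.collect().toSeq == Seq(("a", 0L), ("b", 1L), ("c", 2L)))

    // Assumption: pandas-on-Spark stores its options JSON-encoded in SQLConf
    // (hence the quote-stripping in cacheRDD above).
    spark.conf.set("pandas_on_Spark.compute.default_index_cache", "\"LOCAL_CHECKPOINT\"")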