Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,42 +21,30 @@ import org.apache.gluten.execution._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.EXCHANGE
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
import org.apache.spark.sql.types.{DataType, DoubleType, FloatType}

import scala.collection.mutable

/**
* To transform regular aggregation to intermediate aggregation that internally enables
* optimizations such as flushing and abandoning.
*/
case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] {
import FlushableHashAggregateRule._
override def apply(plan: SparkPlan): SparkPlan = {
if (!VeloxConfig.get.enableVeloxFlushablePartialAggregation) {
return plan
}
val protectedAggs = collectProtectedOneDistinctPartialMergeAggs(plan)
plan.transformUpWithPruning(_.containsPattern(EXCHANGE)) {
case s: ShuffleExchangeLike =>
// If an exchange follows a hash aggregate in which all functions are in partial mode,
// then it's safe to convert the hash aggregate to flushable hash aggregate.
val out = s.withNewChildren(
List(
replaceEligibleAggregates(s.child) {
agg =>
FlushableHashAggregateExecTransformer(
agg.requiredChildDistributionExpressions,
agg.groupingExpressions,
agg.aggregateExpressions,
agg.aggregateAttributes,
agg.initialInputBufferOffset,
agg.resultExpressions,
agg.child
)
}
)
List(replaceEligibleAggregates(s.child, protectedAggs))
)
out
}
Expand Down Expand Up @@ -85,79 +73,119 @@ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkP
/**
* Walks the plan downward, applying func to each RegularHashAggregateExecTransformer or
* SortHashAggregateExecTransformer that is eligible for flushable conversion. An aggregate is
* eligible when all expressions are Partial/PartialMerge, input is not already partitioned by the
* grouping keys, and no aggregate function disallows flushing.
* eligible when all expressions are Partial/PartialMerge, it is not the protected PartialMerge
* aggregate directly below a distinct-partial aggregate, and no aggregate function disallows
* flushing.
*/
private def replaceEligibleAggregates(plan: SparkPlan)(
func: HashAggregateExecTransformer => SparkPlan): SparkPlan = {
private def replaceEligibleAggregates(
plan: SparkPlan,
protectedAggs: mutable.Map[Int, HashAggregateExecTransformer]): SparkPlan = {
def toFlushableAgg(agg: HashAggregateExecTransformer): FlushableHashAggregateExecTransformer = {
FlushableHashAggregateExecTransformer(
agg.requiredChildDistributionExpressions,
agg.groupingExpressions,
agg.aggregateExpressions,
agg.aggregateAttributes,
agg.initialInputBufferOffset,
agg.resultExpressions,
agg.child
)
}

def transformDown: SparkPlan => SparkPlan = {
case agg: RegularHashAggregateExecTransformer
if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) =>
// Not an intermediate agg. Skip.
agg
case agg: RegularHashAggregateExecTransformer
if isAggInputAlreadyDistributedWithAggKeys(agg) =>
// Data already grouped by aggregate keys. Skip.
if protectedAggs.contains(agg.id) =>
// This is the PartialMerge aggregate directly below a distinct-partial aggregate in
// Spark's one-distinct pipeline. Keep it non-flushable so the distinct step continues to
// see globally de-duplicated (grouping + distinct) keys.
agg
case agg: RegularHashAggregateExecTransformer
if aggregatesNotSupportFlush(agg.aggregateExpressions) =>
// Aggregate uses a function that is unsafe to flush. Skip.
agg
case agg: RegularHashAggregateExecTransformer =>
// All guards passed; replace with the flushable variant.
func(agg)
toFlushableAgg(agg)
case agg: SortHashAggregateExecTransformer
if !agg.aggregateExpressions.forall(p => p.mode == Partial || p.mode == PartialMerge) =>
// Not an intermediate agg. Skip.
agg
case agg: SortHashAggregateExecTransformer if isAggInputAlreadyDistributedWithAggKeys(agg) =>
// Data already grouped by aggregate keys. Skip.
case agg: SortHashAggregateExecTransformer if protectedAggs.contains(agg.id) =>
// See the RegularHashAggregateExecTransformer branch above.
agg
case agg: SortHashAggregateExecTransformer
if aggregatesNotSupportFlush(agg.aggregateExpressions) =>
// Aggregate uses a function that is unsafe to flush. Skip.
agg
case agg: SortHashAggregateExecTransformer =>
// All guards passed; replace with the flushable variant.
func(agg)
case p if !canPropagate(p) => p
toFlushableAgg(agg)
case exchange: ShuffleExchangeLike =>
// Stop at the next exchange. This rule is applied from an exchange boundary and should not
// continue rewriting into a different shuffle region.
exchange
case other => other.withNewChildren(other.children.map(transformDown))
}

val out = transformDown(plan)
out
}

/**
 * Tells whether the given plan node is one of the two pass-through operator types.
 *
 * @param plan
 *   the physical plan node to classify.
 * @return
 *   true only when the node is a ProjectExecTransformer or a VeloxResizeBatchesExec; false for
 *   every other node type.
 */
private def canPropagate(plan: SparkPlan): Boolean = plan match {
  case _: ProjectExecTransformer | _: VeloxResizeBatchesExec => true
  case _ => false
}
}

object FlushableHashAggregateRule {

/**
* If child output already partitioned by aggregation keys (this function returns true), we
* usually avoid the optimization converting to flushable aggregation.
* Collect the PartialMerge aggregates that must stay regular in Spark's one-distinct aggregation
* pipeline.
*
* Example plan shape:
*
* {{{
* RegularHashAggregateExecTransformer [k] [count(distinct v)] // finalAggregate
* +- RegularHashAggregateExecTransformer [k] [count(distinct v)] // partialDistinctAggregate
*    +- RegularHashAggregateExecTransformer [k, v] [count(...)] // partialMergeAggregate
*       +- ColumnarExchange hashpartitioning(k, v, 200)
*          +- RegularHashAggregateExecTransformer [k, v] [count(...)] // partialAggregate
*             +- ...
* }}}
*
* For example, if input is hash-partitioned by keys (a, b) and aggregate node requests "group by
* a, b, c", then the aggregate should NOT flush as the grouping set (a, b, c) will be created
* only on a single partition among the whole cluster. Spark's planner may use this information to
* perform optimizations like doing "partial_count(a, b, c)" directly on the output data.
* We walk every aggregate node and, when we encounter the `partialDistinctAggregate`, we record
* its child `partialMergeAggregate` as protected.
*
* That `partialMergeAggregate` must stay regular. It is the step that materializes the
* de-duplicated `(k, v)` stream consumed by the distinct-partial aggregate above it. If it
* flushes, duplicate `(k, v)` keys may be reintroduced within one partition and the distinct
* aggregation pipeline would no longer see the shape Spark planned for.
*/
private def isAggInputAlreadyDistributedWithAggKeys(
agg: HashAggregateExecTransformer): Boolean = {
if (agg.groupingExpressions.isEmpty) {
// Empty grouping set () should not be satisfied by any partitioning patterns.
// E.g.,
// (a, b) satisfies (a, b, c)
// (a, b) satisfies (a, b)
// (a, b) doesn't satisfy (a)
// (a, b) doesn't satisfy ()
return false
private def collectProtectedOneDistinctPartialMergeAggs(
    plan: SparkPlan): mutable.Map[Int, HashAggregateExecTransformer] = {
  // Keyed by the protected aggregate's plan id so callers can test membership by id.
  val protectedById = mutable.HashMap.empty[Int, HashAggregateExecTransformer]
  plan
    .collect { case candidate: HashAggregateExecTransformer => candidate }
    .flatMap(candidate => findProtectedPartialMergeAgg(candidate))
    .foreach(partialMerge => protectedById.update(partialMerge.id, partialMerge))
  protectedById
}

/**
* If this aggregate is the distinct-partial stage, return its child PartialMerge aggregate.
*
* The input counts as the distinct-partial stage when at least one of its aggregate expressions
* is both distinct and in Partial mode. The result is defined only when the direct child is
* itself a HashAggregateExecTransformer whose aggregate expressions are all in PartialMerge
* mode; a child with no aggregate expressions passes that check vacuously.
*
* @param distinctPartialAgg
*   the aggregate to test for the distinct-partial shape.
* @return
*   the direct child aggregate when both conditions hold, otherwise None.
*/
private def findProtectedPartialMergeAgg(
distinctPartialAgg: HashAggregateExecTransformer): Option[HashAggregateExecTransformer] = {
// Bail out unless some expression is a distinct aggregate evaluated in Partial mode.
if (
!distinctPartialAgg.aggregateExpressions.exists(
expr => expr.isDistinct && expr.mode == Partial)
) {
return None
}
// NOTE(review): the next two statements reference `agg`, which is not a parameter or local of
// this method, and their result is unused. They look like leftovers from a removed
// partitioning check - confirm and drop them.
val distribution = ClusteredDistribution(agg.groupingExpressions)
agg.child.outputPartitioning.satisfies(distribution)

// Only the direct child is inspected; intermediate operators between the two aggregates are
// not looked through.
for {
partialMergeAgg <- asAggregate(distinctPartialAgg.child)
if partialMergeAgg.aggregateExpressions.forall(_.mode == PartialMerge)
} yield partialMergeAgg
}

/** Narrows a plan node to a hash aggregate transformer; yields None for any other node. */
private def asAggregate(plan: SparkPlan): Option[HashAggregateExecTransformer] =
  Option(plan).collect { case hashAgg: HashAggregateExecTransformer => hashAgg }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1258,20 +1258,26 @@ class VeloxAggregateFunctionsFlushSuite extends VeloxAggregateFunctionsSuite {
}
}

test("flushable aggregate rule - agg input already distributed by keys") {
test("flushable aggregate rule - count distinct keeps partial merge regular") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
SQLConf.FILES_MAX_PARTITION_BYTES.key -> "1k") {
runQueryAndCompare(
"select * from (select distinct l_orderkey,l_partkey from lineitem) a" +
" inner join (select l_orderkey from lineitem limit 10) b" +
" on a.l_orderkey = b.l_orderkey limit 10") {
runQueryAndCompare("select count(distinct l_partkey) from lineitem group by l_orderkey") {
df =>
val executedPlan = getExecutedPlan(df)
val regularAggCount = executedPlan.count {
plan => plan.isInstanceOf[RegularHashAggregateExecTransformer]
}
val flushableAggCount = executedPlan.count {
plan => plan.isInstanceOf[FlushableHashAggregateExecTransformer]
}
assert(
executedPlan.exists(plan => plan.isInstanceOf[RegularHashAggregateExecTransformer]))
regularAggCount == 2,
s"expected 2 regular hash aggregates in one-distinct pipeline, got $regularAggCount")
assert(
executedPlan.exists(plan => plan.isInstanceOf[FlushableHashAggregateExecTransformer]))
flushableAggCount == 2,
s"expected 2 flushable hash aggregates in one-distinct pipeline, got" +
s" $flushableAggCount")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ VeloxColumnarToRow (187)
: : : : +- ColumnarExchange (58)
: : : : +- VeloxResizeBatches (57)
: : : : +- ^ ProjectExecTransformer (55)
: : : : +- ^ RegularHashAggregateExecTransformer (54)
: : : : +- ^ FlushableHashAggregateExecTransformer (54)
: : : : +- ^ InputIteratorTransformer (53)
: : : : +- ColumnarExchange (51)
: : : : +- VeloxResizeBatches (50)
Expand Down Expand Up @@ -373,7 +373,7 @@ Input [3]: [brand_id#27, class_id#28, category_id#29]
(53) InputIteratorTransformer
Input [3]: [brand_id#27, class_id#28, category_id#29]

(54) RegularHashAggregateExecTransformer
(54) FlushableHashAggregateExecTransformer
Input [3]: [brand_id#27, class_id#28, category_id#29]
Keys [3]: [brand_id#27, class_id#28, category_id#29]
Functions: []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ VeloxColumnarToRow
VeloxResizeBatches
WholeStageCodegenTransformer (22)
ProjectExecTransformer [brand_id,class_id,category_id]
RegularHashAggregateExecTransformer [brand_id,class_id,category_id]
FlushableHashAggregateExecTransformer [brand_id,class_id,category_id]
InputIteratorTransformer
InputAdapter
ColumnarExchange [brand_id,class_id,category_id] #7
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ VeloxColumnarToRow (150)
: : : : +- ColumnarExchange (58)
: : : : +- VeloxResizeBatches (57)
: : : : +- ^ ProjectExecTransformer (55)
: : : : +- ^ RegularHashAggregateExecTransformer (54)
: : : : +- ^ FlushableHashAggregateExecTransformer (54)
: : : : +- ^ InputIteratorTransformer (53)
: : : : +- ColumnarExchange (51)
: : : : +- VeloxResizeBatches (50)
Expand Down Expand Up @@ -345,7 +345,7 @@ Input [3]: [brand_id#27, class_id#28, category_id#29]
(53) InputIteratorTransformer
Input [3]: [brand_id#27, class_id#28, category_id#29]

(54) RegularHashAggregateExecTransformer
(54) FlushableHashAggregateExecTransformer
Input [3]: [brand_id#27, class_id#28, category_id#29]
Keys [3]: [brand_id#27, class_id#28, category_id#29]
Functions: []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ VeloxColumnarToRow
VeloxResizeBatches
WholeStageCodegenTransformer (24)
ProjectExecTransformer [brand_id,class_id,category_id]
RegularHashAggregateExecTransformer [brand_id,class_id,category_id]
FlushableHashAggregateExecTransformer [brand_id,class_id,category_id]
InputIteratorTransformer
InputAdapter
ColumnarExchange [brand_id,class_id,category_id] #6
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ VeloxColumnarToRow (104)
: : +- ColumnarExchange (34)
: : +- VeloxResizeBatches (33)
: : +- ^ ProjectExecTransformer (31)
: : +- ^ RegularHashAggregateExecTransformer (30)
: : +- ^ FlushableHashAggregateExecTransformer (30)
: : +- ^ InputIteratorTransformer (29)
: : +- ColumnarExchange (27)
: : +- VeloxResizeBatches (26)
Expand All @@ -35,7 +35,7 @@ VeloxColumnarToRow (104)
: +- ColumnarExchange (65)
: +- VeloxResizeBatches (64)
: +- ^ ProjectExecTransformer (62)
: +- ^ RegularHashAggregateExecTransformer (61)
: +- ^ FlushableHashAggregateExecTransformer (61)
: +- ^ InputIteratorTransformer (60)
: +- ColumnarExchange (58)
: +- VeloxResizeBatches (57)
Expand All @@ -58,7 +58,7 @@ VeloxColumnarToRow (104)
+- ColumnarExchange (97)
+- VeloxResizeBatches (96)
+- ^ ProjectExecTransformer (94)
+- ^ RegularHashAggregateExecTransformer (93)
+- ^ FlushableHashAggregateExecTransformer (93)
+- ^ InputIteratorTransformer (92)
+- ColumnarExchange (90)
+- VeloxResizeBatches (89)
Expand Down Expand Up @@ -200,7 +200,7 @@ Input [3]: [c_last_name#9, c_first_name#8, d_date#5]
(29) InputIteratorTransformer
Input [3]: [c_last_name#9, c_first_name#8, d_date#5]

(30) RegularHashAggregateExecTransformer
(30) FlushableHashAggregateExecTransformer
Input [3]: [c_last_name#9, c_first_name#8, d_date#5]
Keys [3]: [c_last_name#9, c_first_name#8, d_date#5]
Functions: []
Expand Down Expand Up @@ -326,7 +326,7 @@ Input [3]: [c_last_name#20, c_first_name#19, d_date#16]
(60) InputIteratorTransformer
Input [3]: [c_last_name#20, c_first_name#19, d_date#16]

(61) RegularHashAggregateExecTransformer
(61) FlushableHashAggregateExecTransformer
Input [3]: [c_last_name#20, c_first_name#19, d_date#16]
Keys [3]: [c_last_name#20, c_first_name#19, d_date#16]
Functions: []
Expand Down Expand Up @@ -458,7 +458,7 @@ Input [3]: [c_last_name#30, c_first_name#29, d_date#26]
(92) InputIteratorTransformer
Input [3]: [c_last_name#30, c_first_name#29, d_date#26]

(93) RegularHashAggregateExecTransformer
(93) FlushableHashAggregateExecTransformer
Input [3]: [c_last_name#30, c_first_name#29, d_date#26]
Keys [3]: [c_last_name#30, c_first_name#29, d_date#26]
Functions: []
Expand Down
Loading
Loading