Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,31 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.types.{ArrayType, DataType}

abstract class VeloxCollect(child: Expression)
abstract class VeloxCollect(child: Expression, val ignoreNulls: Boolean)
extends DeclarativeAggregate
with UnaryLike[Expression] {

protected lazy val buffer: AttributeReference = AttributeReference("buffer", dataType)()

override def dataType: DataType = ArrayType(child.dataType, false)
override def dataType: DataType = ArrayType(child.dataType, !ignoreNulls)

override def nullable: Boolean = false

override def aggBufferAttributes: Seq[AttributeReference] = Seq(buffer)

override lazy val initialValues: Seq[Expression] = Seq(Literal.create(Array(), dataType))

override lazy val updateExpressions: Seq[Expression] = Seq(
If(
IsNull(child),
buffer,
Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false))))
)
override lazy val updateExpressions: Seq[Expression] = {
val append = if (ignoreNulls) {
If(
IsNull(child),
buffer,
Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false))))
} else {
Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false)))
}
Seq(append)
}

override lazy val mergeExpressions: Seq[Expression] = Seq(
Concat(Seq(buffer.left, buffer.right))
Expand All @@ -49,7 +54,8 @@ abstract class VeloxCollect(child: Expression)
override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType))
}

case class VeloxCollectSet(child: Expression) extends VeloxCollect(child) {
case class VeloxCollectSet(child: Expression, override val ignoreNulls: Boolean = true)
extends VeloxCollect(child, ignoreNulls) {

override lazy val evaluateExpression: Expression =
ArrayDistinct(buffer)
Expand All @@ -60,7 +66,8 @@ case class VeloxCollectSet(child: Expression) extends VeloxCollect(child) {
copy(child = newChild)
}

case class VeloxCollectList(child: Expression) extends VeloxCollect(child) {
case class VeloxCollectList(child: Expression, override val ignoreNulls: Boolean = true)
extends VeloxCollect(child, ignoreNulls) {

override val evaluateExpression: Expression = buffer

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,24 @@ object CollectRewriteRule {
def unapply(expr: Expression): Option[Expression] = expr match {
case aggExpr @ AggregateExpression(s: CollectSet, _, _, _, _) if has[VeloxCollectSet] =>
val newAggExpr =
aggExpr.copy(aggregateFunction = VeloxCollectSet(s.child))
aggExpr.copy(aggregateFunction = VeloxCollectSet(s.child, getIgnoreNulls(s)))
Some(newAggExpr)
case aggExpr @ AggregateExpression(l: CollectList, _, _, _, _) if has[VeloxCollectList] =>
val newAggExpr = aggExpr.copy(VeloxCollectList(l.child))
val newAggExpr = aggExpr.copy(VeloxCollectList(l.child, getIgnoreNulls(l)))
Some(newAggExpr)
case _ => None
}
}

/**
 * Reads the `ignoreNulls` flag of `expr` through reflection.
 *
 * The member is looked up by name at runtime rather than called directly, so the lookup also
 * works for expression classes that do not declare `ignoreNulls` (presumably Spark versions
 * whose CollectList / CollectSet lack the field -- TODO confirm against the supported shims).
 *
 * @param expr the aggregate function to inspect
 * @return the value of `ignoreNulls` when the class exposes a public, Boolean-valued member of
 *         that name; `true` (ignore nulls) when there is no such member or it does not yield
 *         a Boolean
 */
private def getIgnoreNulls(expr: Expression): Boolean = {
  try {
    expr.getClass.getMethod("ignoreNulls").invoke(expr) match {
      // Method.invoke boxes primitive results, so a Boolean arrives as java.lang.Boolean.
      case flag: java.lang.Boolean => flag.booleanValue()
      // A null or non-Boolean result is not a usable flag: keep the same default as the
      // "no such method" case instead of silently unboxing null to false.
      case _ => true
    }
  } catch {
    case _: NoSuchMethodException => true // Default: ignore nulls
  }
}

/**
 * Tells whether the expression class `T` has an entry in `ExpressionMappings.expressionsMap`.
 *
 * @tparam T the expression type whose runtime class is looked up
 * @return `true` when the runtime class of `T` is a key of the expressions map
 */
private def has[T <: Expression: ClassTag]: Boolean = {
  val expressionClass = classTag[T].runtimeClass
  ExpressionMappings.expressionsMap.contains(expressionClass)
}
}
30 changes: 23 additions & 7 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,30 @@ std::string SubstraitToVeloxPlanConverter::toAggregationFunctionName(
// The merge_extract function is registered without suffix.
return functionName;
}
// The merge_extract function must be registered with suffix based on result type.
functionName += ("_" + companionFunctionSuffix(resultType));
signatures = exec::getAggregateFunctionSignatures(functionName);
VELOX_CHECK(
signatures.has_value() && signatures.value().size() > 0,
// The merge_extract function must be registered with suffix based on
// result type. First try exact concrete type suffix.
auto suffixedName =
functionName + "_" + companionFunctionSuffix(resultType);
signatures = exec::getAggregateFunctionSignatures(suffixedName);
if (signatures.has_value() && signatures.value().size() > 0) {
return suffixedName;
}
// When companion functions are registered with generic type variables
// (e.g., "collect_set_merge_extract_array_T"), look up companion
// function names from the aggregate function registry.
auto companionSigs = exec::getCompanionFunctionSignatures(baseName);
if (companionSigs.has_value()) {
for (const auto& entry : companionSigs->mergeExtract) {
auto entrySigs =
exec::getAggregateFunctionSignatures(entry.functionName);
if (entrySigs.has_value() && entrySigs.value().size() > 0) {
return entry.functionName;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When two merge-extract functions share the same intermediate type, this approach cannot guarantee that the correct function will be returned. We need to fix companionFunctionSuffix so that it is fully compatible with Velox’s std::string toSuffixString(const TypeSignature& type) function. Currently, we are encountering a fallback issue because Gluten appears to generate a function name that differs from the one Velox uses. It looks like Gluten is generating collect_set_merge_extract_array_row_VARCHAR_BIGINT_BIGINT_endrow, while Velox uses collect_set_merge_extract_array_T.

}
}
}
VELOX_FAIL(
"Cannot find function signature for {} in final aggregation step.",
functionName);
return functionName;
suffixedName);
}
case core::AggregationNode::Step::kIntermediate:
suffix = "_merge";
Expand Down
Loading