diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala index dc41bbc4fc01..0945343ce971 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/expression/aggregate/VeloxCollect.scala @@ -21,13 +21,13 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.types.{ArrayType, DataType} -abstract class VeloxCollect(child: Expression) +abstract class VeloxCollect(child: Expression, val ignoreNulls: Boolean) extends DeclarativeAggregate with UnaryLike[Expression] { protected lazy val buffer: AttributeReference = AttributeReference("buffer", dataType)() - override def dataType: DataType = ArrayType(child.dataType, false) + override def dataType: DataType = ArrayType(child.dataType, !ignoreNulls) override def nullable: Boolean = false @@ -35,12 +35,17 @@ abstract class VeloxCollect(child: Expression) override lazy val initialValues: Seq[Expression] = Seq(Literal.create(Array(), dataType)) - override lazy val updateExpressions: Seq[Expression] = Seq( - If( - IsNull(child), - buffer, - Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false)))) - ) + override lazy val updateExpressions: Seq[Expression] = { + val append = if (ignoreNulls) { + If( + IsNull(child), + buffer, + Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false)))) + } else { + Concat(Seq(buffer, CreateArray(Seq(child), useStringTypeWhenEmpty = false))) + } + Seq(append) + } override lazy val mergeExpressions: Seq[Expression] = Seq( Concat(Seq(buffer.left, buffer.right)) @@ -49,7 +54,8 @@ abstract class VeloxCollect(child: Expression) override def defaultResult: Option[Literal] = Option(Literal.create(Array(), dataType)) } -case class VeloxCollectSet(child: Expression) extends VeloxCollect(child) { +case class VeloxCollectSet(child: Expression, override val ignoreNulls: Boolean = true) + extends VeloxCollect(child, ignoreNulls) { override lazy val evaluateExpression: Expression = ArrayDistinct(buffer) @@ -60,7 +66,8 @@ case class VeloxCollectSet(child: Expression) extends VeloxCollect(child) { copy(child = newChild) } -case class VeloxCollectList(child: Expression) extends VeloxCollect(child) { +case class VeloxCollectList(child: Expression, override val ignoreNulls: Boolean = true) + extends VeloxCollect(child, ignoreNulls) { override val evaluateExpression: Expression = buffer diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala index e76de56374f5..72e52cf3dd62 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala @@ -67,15 +67,24 @@ object CollectRewriteRule { def unapply(expr: Expression): Option[Expression] = expr match { case aggExpr @ AggregateExpression(s: CollectSet, _, _, _, _) if has[VeloxCollectSet] => val newAggExpr = - aggExpr.copy(aggregateFunction = VeloxCollectSet(s.child)) + aggExpr.copy(aggregateFunction = VeloxCollectSet(s.child, getIgnoreNulls(s))) Some(newAggExpr) case aggExpr @ AggregateExpression(l: CollectList, _, _, _, _) if has[VeloxCollectList] => - val newAggExpr = aggExpr.copy(VeloxCollectList(l.child)) + val newAggExpr = aggExpr.copy(VeloxCollectList(l.child, getIgnoreNulls(l))) Some(newAggExpr) case _ => None } } + private def getIgnoreNulls(expr: Expression): Boolean = { + try { + val method = expr.getClass.getMethod("ignoreNulls") + method.invoke(expr).asInstanceOf[Boolean] + } catch { + case _: NoSuchMethodException => true // Default: ignore nulls + } + } + private def has[T <: Expression: ClassTag]: Boolean = ExpressionMappings.expressionsMap.contains(classTag[T].runtimeClass) } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index adb7fc5f45b6..e3e3ec3a4d3c 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -284,14 +284,30 @@ std::string SubstraitToVeloxPlanConverter::toAggregationFunctionName( // The merge_extract function is registered without suffix. return functionName; } - // The merge_extract function must be registered with suffix based on result type. - functionName += ("_" + companionFunctionSuffix(resultType)); - signatures = exec::getAggregateFunctionSignatures(functionName); - VELOX_CHECK( - signatures.has_value() && signatures.value().size() > 0, + // The merge_extract function must be registered with suffix based on + // result type. First try exact concrete type suffix. + auto suffixedName = + functionName + "_" + companionFunctionSuffix(resultType); + signatures = exec::getAggregateFunctionSignatures(suffixedName); + if (signatures.has_value() && signatures.value().size() > 0) { + return suffixedName; + } + // When companion functions are registered with generic type variables + // (e.g., "collect_set_merge_extract_array_T"), look up companion + // function names from the aggregate function registry. + auto companionSigs = exec::getCompanionFunctionSignatures(baseName); + if (companionSigs.has_value()) { + for (const auto& entry : companionSigs->mergeExtract) { + auto entrySigs = + exec::getAggregateFunctionSignatures(entry.functionName); + if (entrySigs.has_value() && entrySigs.value().size() > 0) { + return entry.functionName; + } + } + } + VELOX_FAIL( "Cannot find function signature for {} in final aggregation step.", - functionName); - return functionName; + suffixedName); } case core::AggregationNode::Step::kIntermediate: suffix = "_merge";