From 273769b9563d068ddb4b6a149fe9678ba94cd992 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 25 Mar 2026 14:02:24 -0700 Subject: [PATCH 1/3] feat: fix array_compact for Spark 4.0 and correct return type metadata - Handle KnownNotContainsNull wrapper in Spark 4.0 shim so that array_compact runs natively on Spark 4.0 (previously always fell back) - Fix return type passed to DataFusion: use expr.dataType instead of hardcoded ArrayType(elementType) so that containsNull=false is correctly propagated on Spark 4.0 - Mark CometArrayCompact as Compatible() instead of Incompatible(None) - Expand SQL test coverage: add string, double, and nested array types; change spark_answer_only to query to verify native execution - Remove assume(!isSpark40Plus) skip from Scala test - Update expressions.md: ArrayCompact is now supported --- docs/source/user-guide/latest/expressions.md | 2 +- .../scala/org/apache/comet/serde/arrays.scala | 4 +- .../apache/comet/shims/CometExprShim.scala | 28 ++++++++++++- .../expressions/array/array_compact.sql | 41 ++++++++++++++++--- .../comet/CometArrayExpressionSuite.scala | 2 - 5 files changed, 65 insertions(+), 12 deletions(-) diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index c8e5475d0d..1bd36a3c5a 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -233,7 +233,7 @@ Comet supports using the following aggregate functions within window contexts wi | Expression | Spark-Compatible? 
| Compatibility Notes | | -------------- | ----------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ArrayAppend | No | | -| ArrayCompact | No | | +| ArrayCompact | Yes | | | ArrayContains | No | Returns null instead of false for empty arrays with literal values ([#3346](https://github.com/apache/datafusion-comet/issues/3346)) | | ArrayDistinct | No | Behaves differently than spark. Comet first sorts then removes duplicates while Spark preserves the original order. | | ArrayExcept | No | | diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 298c473087..674f94e988 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -305,7 +305,7 @@ object CometArrayRepeat extends CometExpressionSerde[ArrayRepeat] { object CometArrayCompact extends CometExpressionSerde[Expression] { - override def getSupportLevel(expr: Expression): SupportLevel = Incompatible(None) + override def getSupportLevel(expr: Expression): SupportLevel = Compatible() override def convert( expr: Expression, @@ -319,7 +319,7 @@ object CometArrayCompact extends CometExpressionSerde[Expression] { val arrayCompactScalarExpr = scalarFunctionExprToProtoWithReturnType( "array_remove_all", - ArrayType(elementType = elementType), + expr.dataType, false, arrayExprProto, nullLiteralProto) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 2c5cebd166..8b2be5da4c 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, StringType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, StringType} import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. @@ -55,6 +55,30 @@ trait CometExprShim extends CommonStringExprs { inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { expr match { + case knc: KnownNotContainsNull => + // On Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)). + // Strip the wrapper and serialize the inner ArrayFilter using the outer node's dataType + // so that containsNull=false is propagated correctly to the output schema. 
+ knc.child match { + case filter: ArrayFilter => + filter.function.children.headOption match { + case Some(_: IsNotNull) => + val arrayChild = filter.left + val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType + val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding) + val nullLiteralProto = exprToProtoInternal(Literal(null, elementType), Seq.empty) + val scalarExpr = scalarFunctionExprToProtoWithReturnType( + "array_remove_all", + knc.dataType, + false, + arrayExprProto, + nullLiteralProto) + optExprWithInfo(scalarExpr, knc, arrayChild) + case _ => None + } + case _ => None + } + case s: StaticInvoke if s.staticObject == classOf[StringDecode] && s.dataType.isInstanceOf[StringType] && diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_compact.sql b/spark/src/test/resources/sql-tests/expressions/array/array_compact.sql index 9b834a4dbd..f8130f4b56 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_compact.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_compact.sql @@ -18,14 +18,45 @@ -- ConfigMatrix: parquet.enable.dictionary=false,true statement -CREATE TABLE test_array_compact(arr array<int>) USING parquet +CREATE TABLE test_array_compact( + ints array<int>, + strs array<string>, + dbls array<double>, + nested array<array<int>> +) USING parquet statement -INSERT INTO test_array_compact VALUES (array(1, NULL, 2, NULL, 3)), (array()), (NULL), (array(NULL, NULL)), (array(1, 2, 3)) +INSERT INTO test_array_compact VALUES + (array(1, NULL, 2, NULL, 3), array('a', NULL, 'b', NULL, 'c'), array(1.0, NULL, 2.0), array(array(1, NULL, 3), NULL, array(4, NULL, 6))), + (array(), array(), array(), array()), + (NULL, NULL, NULL, NULL), + (array(NULL, NULL), array(NULL, NULL), array(NULL, NULL), array(NULL, NULL)), + (array(1, 2, 3), array('x', 'y', 'z'), array(1.5, 2.5), array(array(1, 2), array(3, 4))) -query spark_answer_only -SELECT array_compact(arr) FROM test_array_compact +-- integer column +query +SELECT 
array_compact(ints) FROM test_array_compact + +-- string column +query +SELECT array_compact(strs) FROM test_array_compact + +-- double column +query +SELECT array_compact(dbls) FROM test_array_compact + +-- nested array column: outer nulls removed, inner nulls preserved +query +SELECT array_compact(nested) FROM test_array_compact -- literal arguments -query spark_answer_only +query SELECT array_compact(array(1, NULL, 2, NULL, 3)) + +-- literal string array +query +SELECT array_compact(array('a', NULL, 'b')) + +-- all-null literal array +query +SELECT array_compact(array(NULL, NULL, NULL)) diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index fb5531a573..d716039a2f 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -567,8 +567,6 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } test("array_compact") { - // TODO fix for Spark 4.0.0 - assume(!isSpark40Plus) Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => withTempView("t1") { From 4884d08b4ae80d23f7e29c511fa254a004853640 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 25 Mar 2026 14:15:46 -0700 Subject: [PATCH 2/3] fix: add missing binding argument to exprToProtoInternal in Spark 4.0 shim --- .../main/spark-4.0/org/apache/comet/shims/CometExprShim.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 8b2be5da4c..217e03e940 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -66,7 +66,8 @@ trait CometExprShim extends CommonStringExprs { val arrayChild = filter.left val 
elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding) - val nullLiteralProto = exprToProtoInternal(Literal(null, elementType), Seq.empty) + val nullLiteralProto = + exprToProtoInternal(Literal(null, elementType), Seq.empty, false) val scalarExpr = scalarFunctionExprToProtoWithReturnType( "array_remove_all", knc.dataType, From d18c7843bd7401696db3369ca249c95deb594547 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 25 Mar 2026 15:17:40 -0700 Subject: [PATCH 3/3] fix: pass containsNull=true as array_remove_all return type to match DataFusion DataFusion's array_remove_all function always returns a list with nullable elements (nullable=true), but array_compact's Spark dataType has containsNull=false. Passing the Spark type as the promised return type caused a runtime type-mismatch assertion in DataFusion's ScalarFunctionExpr. Fix by passing ArrayType(elementType, containsNull=true) in both the Spark 3.x CometArrayCompact serde and the Spark 4.0 KnownNotContainsNull shim. 
--- spark/src/main/scala/org/apache/comet/serde/arrays.scala | 6 +++++- .../spark-4.0/org/apache/comet/shims/CometExprShim.scala | 9 ++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 674f94e988..f140e415fa 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -317,9 +317,13 @@ object CometArrayCompact extends CometExpressionSerde[Expression] { val arrayExprProto = exprToProto(child, inputs, binding) val nullLiteralProto = exprToProto(Literal(null, elementType), Seq.empty) + // Pass containsNull=true because DataFusion's array_remove_all always returns + // a list type with nullable elements; the containsNull=false from Spark's expr.dataType + // would cause a runtime type-mismatch assertion in DataFusion's ScalarFunctionExpr. + val returnType = ArrayType(elementType, containsNull = true) val arrayCompactScalarExpr = scalarFunctionExprToProtoWithReturnType( "array_remove_all", - expr.dataType, + returnType, false, arrayExprProto, nullLiteralProto) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 217e03e940..e97f371cb9 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -57,8 +57,7 @@ trait CometExprShim extends CommonStringExprs { expr match { case knc: KnownNotContainsNull => // On Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)). - // Strip the wrapper and serialize the inner ArrayFilter using the outer node's dataType - // so that containsNull=false is propagated correctly to the output schema. + // Strip the wrapper and serialize the inner ArrayFilter as array_remove_all. 
knc.child match { case filter: ArrayFilter => filter.function.children.headOption match { @@ -68,9 +67,13 @@ trait CometExprShim extends CommonStringExprs { val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding) val nullLiteralProto = exprToProtoInternal(Literal(null, elementType), Seq.empty, false) + // Pass containsNull=true because DataFusion's array_remove_all always returns + // a list type with nullable elements; knc.dataType has containsNull=false + // which would cause a runtime type-mismatch assertion in DataFusion. + val returnType = ArrayType(elementType, containsNull = true) val scalarExpr = scalarFunctionExprToProtoWithReturnType( "array_remove_all", - knc.dataType, + returnType, false, arrayExprProto, nullLiteralProto)