diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index c8e5475d0d..1bd36a3c5a 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -233,7 +233,7 @@ Comet supports using the following aggregate functions within window contexts wi | Expression | Spark-Compatible? | Compatibility Notes | | -------------- | ----------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ArrayAppend | No | | -| ArrayCompact | No | | +| ArrayCompact | Yes | | | ArrayContains | No | Returns null instead of false for empty arrays with literal values ([#3346](https://github.com/apache/datafusion-comet/issues/3346)) | | ArrayDistinct | No | Behaves differently than spark. Comet first sorts then removes duplicates while Spark preserves the original order. | | ArrayExcept | No | | diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 298c473087..f140e415fa 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -305,7 +305,7 @@ object CometArrayRepeat extends CometExpressionSerde[ArrayRepeat] { object CometArrayCompact extends CometExpressionSerde[Expression] { - override def getSupportLevel(expr: Expression): SupportLevel = Incompatible(None) + override def getSupportLevel(expr: Expression): SupportLevel = Compatible() override def convert( expr: Expression, @@ -317,9 +317,13 @@ object CometArrayCompact extends CometExpressionSerde[Expression] { val arrayExprProto = exprToProto(child, inputs, binding) val nullLiteralProto = exprToProto(Literal(null, elementType), Seq.empty) + // Pass containsNull=true because DataFusion's array_remove_all always returns + // a list type with 
nullable elements; the containsNull=false from Spark's expr.dataType + // would cause a runtime type-mismatch assertion in DataFusion's ScalarFunctionExpr. + val returnType = ArrayType(elementType, containsNull = true) val arrayCompactScalarExpr = scalarFunctionExprToProtoWithReturnType( "array_remove_all", - ArrayType(elementType = elementType), + returnType, false, arrayExprProto, nullLiteralProto) diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala index 2c5cebd166..e97f371cb9 100644 --- a/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeWithCollation -import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, StringType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, StringType} import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode} import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible} import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr} -import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} /** * `CometExprShim` acts as a shim for parsing expressions from different Spark versions. 
@@ -55,6 +55,34 @@ trait CometExprShim extends CommonStringExprs { inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { expr match { + case knc: KnownNotContainsNull => + // On Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)). + // Strip the wrapper and serialize the inner ArrayFilter as array_remove_all. + knc.child match { + case filter: ArrayFilter => + filter.function.children.headOption match { + case Some(_: IsNotNull) => + val arrayChild = filter.left + val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType + val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding) + val nullLiteralProto = + exprToProtoInternal(Literal(null, elementType), Seq.empty, false) + // Pass containsNull=true because DataFusion's array_remove_all always returns + // a list type with nullable elements; knc.dataType has containsNull=false + // which would cause a runtime type-mismatch assertion in DataFusion. + val returnType = ArrayType(elementType, containsNull = true) + val scalarExpr = scalarFunctionExprToProtoWithReturnType( + "array_remove_all", + returnType, + false, + arrayExprProto, + nullLiteralProto) + optExprWithInfo(scalarExpr, knc, arrayChild) + case _ => None + } + case _ => None + } + case s: StaticInvoke if s.staticObject == classOf[StringDecode] && s.dataType.isInstanceOf[StringType] && diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_compact.sql index 9b834a4dbd..f8130f4b56 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_compact.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_compact.sql @@ -18,14 +18,45 @@ -- ConfigMatrix: parquet.enable.dictionary=false,true statement -CREATE TABLE test_array_compact(arr array<int>) USING parquet +CREATE TABLE test_array_compact( + ints array<int>, + strs array<string>, + dbls array<double>, + nested array<array<int>> +) USING parquet statement -INSERT INTO 
test_array_compact VALUES (array(1, NULL, 2, NULL, 3)), (array()), (NULL), (array(NULL, NULL)), (array(1, 2, 3)) +INSERT INTO test_array_compact VALUES + (array(1, NULL, 2, NULL, 3), array('a', NULL, 'b', NULL, 'c'), array(1.0, NULL, 2.0), array(array(1, NULL, 3), NULL, array(4, NULL, 6))), + (array(), array(), array(), array()), + (NULL, NULL, NULL, NULL), + (array(NULL, NULL), array(NULL, NULL), array(NULL, NULL), array(NULL, NULL)), + (array(1, 2, 3), array('x', 'y', 'z'), array(1.5, 2.5), array(array(1, 2), array(3, 4))) -query spark_answer_only -SELECT array_compact(arr) FROM test_array_compact +-- integer column +query +SELECT array_compact(ints) FROM test_array_compact + +-- string column +query +SELECT array_compact(strs) FROM test_array_compact + +-- double column +query +SELECT array_compact(dbls) FROM test_array_compact + +-- nested array column: outer nulls removed, inner nulls preserved +query +SELECT array_compact(nested) FROM test_array_compact -- literal arguments -query spark_answer_only +query SELECT array_compact(array(1, NULL, 2, NULL, 3)) + +-- literal string array +query +SELECT array_compact(array('a', NULL, 'b')) + +-- all-null literal array +query +SELECT array_compact(array(NULL, NULL, NULL)) diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index fb5531a573..d716039a2f 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -567,8 +567,6 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp } test("array_compact") { - // TODO fix for Spark 4.0.0 - assume(!isSpark40Plus) Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir => withTempView("t1") {