Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ Comet supports using the following aggregate functions within window contexts wi
| Expression | Spark-Compatible? | Compatibility Notes |
| -------------- | ----------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ArrayAppend | No | |
| ArrayCompact | No | |
| ArrayCompact | Yes | |
| ArrayContains | No | Returns null instead of false for empty arrays with literal values ([#3346](https://github.com/apache/datafusion-comet/issues/3346)) |
| ArrayDistinct | No | Behaves differently than spark. Comet first sorts then removes duplicates while Spark preserves the original order. |
| ArrayExcept | No | |
Expand Down
8 changes: 6 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ object CometArrayRepeat extends CometExpressionSerde[ArrayRepeat] {

object CometArrayCompact extends CometExpressionSerde[Expression] {

override def getSupportLevel(expr: Expression): SupportLevel = Incompatible(None)
// array_compact is Spark-compatible: it is serialized as DataFusion's
// array_remove_all(arr, null), which removes all null elements (see convert below).
override def getSupportLevel(expr: Expression): SupportLevel = Compatible()

override def convert(
expr: Expression,
Expand All @@ -317,9 +317,13 @@ object CometArrayCompact extends CometExpressionSerde[Expression] {
val arrayExprProto = exprToProto(child, inputs, binding)
val nullLiteralProto = exprToProto(Literal(null, elementType), Seq.empty)

// Pass containsNull=true because DataFusion's array_remove_all always returns
// a list type with nullable elements; the containsNull=false from Spark's expr.dataType
// would cause a runtime type-mismatch assertion in DataFusion's ScalarFunctionExpr.
val returnType = ArrayType(elementType, containsNull = true)
val arrayCompactScalarExpr = scalarFunctionExprToProtoWithReturnType(
"array_remove_all",
ArrayType(elementType = elementType),
returnType,
false,
arrayExprProto,
nullLiteralProto)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{BinaryType, BooleanType, DataTypes, StringType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataTypes, StringType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.{CommonStringExprs, Compatible, ExprOuterClass, Incompatible}
import org.apache.comet.serde.ExprOuterClass.{BinaryOutputStyle, Expr}
import org.apache.comet.serde.QueryPlanSerde.exprToProtoInternal
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType}

/**
* `CometExprShim` acts as a shim for parsing expressions from different Spark versions.
Expand All @@ -55,6 +55,34 @@ trait CometExprShim extends CommonStringExprs {
inputs: Seq[Attribute],
binding: Boolean): Option[Expr] = {
expr match {
case knc: KnownNotContainsNull =>
// On Spark 4.0, array_compact rewrites to KnownNotContainsNull(ArrayFilter(IsNotNull)).
// Strip the wrapper and serialize the inner ArrayFilter as array_remove_all.
knc.child match {
case filter: ArrayFilter =>
filter.function.children.headOption match {
case Some(_: IsNotNull) =>
val arrayChild = filter.left
val elementType = arrayChild.dataType.asInstanceOf[ArrayType].elementType
val arrayExprProto = exprToProtoInternal(arrayChild, inputs, binding)
val nullLiteralProto =
exprToProtoInternal(Literal(null, elementType), Seq.empty, false)
// Pass containsNull=true because DataFusion's array_remove_all always returns
// a list type with nullable elements; knc.dataType has containsNull=false
// which would cause a runtime type-mismatch assertion in DataFusion.
val returnType = ArrayType(elementType, containsNull = true)
val scalarExpr = scalarFunctionExprToProtoWithReturnType(
"array_remove_all",
returnType,
false,
arrayExprProto,
nullLiteralProto)
optExprWithInfo(scalarExpr, knc, arrayChild)
case _ => None
}
case _ => None
}

case s: StaticInvoke
if s.staticObject == classOf[StringDecode] &&
s.dataType.isInstanceOf[StringType] &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,45 @@
-- ConfigMatrix: parquet.enable.dictionary=false,true

statement
CREATE TABLE test_array_compact(arr array<int>) USING parquet
CREATE TABLE test_array_compact(
ints array<int>,
strs array<string>,
dbls array<double>,
nested array<array<int>>
) USING parquet

statement
INSERT INTO test_array_compact VALUES (array(1, NULL, 2, NULL, 3)), (array()), (NULL), (array(NULL, NULL)), (array(1, 2, 3))
INSERT INTO test_array_compact VALUES
(array(1, NULL, 2, NULL, 3), array('a', NULL, 'b', NULL, 'c'), array(1.0, NULL, 2.0), array(array(1, NULL, 3), NULL, array(4, NULL, 6))),
(array(), array(), array(), array()),
(NULL, NULL, NULL, NULL),
(array(NULL, NULL), array(NULL, NULL), array(NULL, NULL), array(NULL, NULL)),
(array(1, 2, 3), array('x', 'y', 'z'), array(1.5, 2.5), array(array(1, 2), array(3, 4)))

query spark_answer_only
SELECT array_compact(arr) FROM test_array_compact
-- integer column
query
SELECT array_compact(ints) FROM test_array_compact

-- string column
query
SELECT array_compact(strs) FROM test_array_compact

-- double column
query
SELECT array_compact(dbls) FROM test_array_compact

-- nested array column: outer nulls removed, inner nulls preserved
query
SELECT array_compact(nested) FROM test_array_compact

-- literal arguments
query spark_answer_only
query
SELECT array_compact(array(1, NULL, 2, NULL, 3))

-- literal string array
query
SELECT array_compact(array('a', NULL, 'b'))

-- all-null literal array
query
SELECT array_compact(array(NULL, NULL, NULL))
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,6 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
}

test("array_compact") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
withTempView("t1") {
Expand Down
Loading