diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 86e70cb330..6e8e9407d6 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -348,22 +348,20 @@ expression-level). The `outer` variants are wired but marked `Incompatible`; the ## lambda_funcs -All higher-order functions are planned via [#4224](https://github.com/apache/datafusion-comet/issues/4224). - | Function | Status | Notes | | --- | --- | --- | -| `aggregate` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `array_sort` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `exists` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `filter` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `forall` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `map_filter` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `map_zip_with` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `reduce` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `transform` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `transform_keys` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `transform_values` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `zip_with` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | +| `aggregate` | ✅ | | +| `array_sort` | ✅ | | +| `exists` | ✅ | | +| `filter` | 🔜 | General lambda not yet wired; the `array_compact` form is supported ([#4224](https://github.com/apache/datafusion-comet/issues/4224)) | +| `forall` | ✅ | | +| `map_filter` | ✅ | | +| `map_zip_with` | ✅ | | +| `reduce` | ✅ | | +| `transform` | ✅ | | +| `transform_keys` | ✅ | | +| `transform_values` | ✅ | | +| `zip_with` | ✅ | | --- diff --git 
a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index c6160bddd4..a69027139e 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -974,6 +974,63 @@ impl PhysicalPlanner { } } + /// DataFusion's nested comparison kernel (`apply_cmp_for_nested`) requires both operands to + /// have identical data types, including nested field nullability, whereas Spark comparisons + /// ignore nullability. When a comparison's operands are nested types that differ only in + /// nullability (e.g. a higher-order `transform` produces `List(non-null Struct)` while the + /// other side is `List(nullable Struct)`), cast both to their nullability-union type so the + /// kernel accepts them. Non-comparison ops and non-nested or already-matching types are left + /// untouched. + pub fn reconcile_nested_comparison_types( + left: Arc<dyn PhysicalExpr>, + right: Arc<dyn PhysicalExpr>, + op: &DataFusionOperator, + input_schema: &SchemaRef, + ) -> (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>) { + use DataFusionOperator::*; + let is_cmp = matches!( + op, + Eq | NotEq | Lt | LtEq | Gt | GtEq | IsDistinctFrom | IsNotDistinctFrom + ); + if !is_cmp { + return (left, right); + } + let (lt, rt) = match (left.data_type(input_schema), right.data_type(input_schema)) { + (Ok(lt), Ok(rt)) => (lt, rt), + _ => return (left, right), + }; + // Only nested types route through `apply_cmp_for_nested`; primitives coerce fine. + let nested = matches!( + lt, + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Struct(_) + | DataType::Map(_, _) + ); + if !nested || lt.equals_datatype(&rt) { + return (left, right); + } + // `Field::try_merge` unions nullability recursively while preserving structure (and the + // Map/list invariants). Bail out unchanged if the structures are genuinely incompatible. 
+ let mut merged = Field::new("c", lt.clone(), true); + if merged + .try_merge(&Field::new("c", rt.clone(), true)) + .is_err() + { + return (left, right); + } + let target = merged.data_type().clone(); + let cast_to_target = |e: Arc<dyn PhysicalExpr>, dt: &DataType| -> Arc<dyn PhysicalExpr> { + if dt.equals_datatype(&target) { + e + } else { + Arc::new(CastExpr::new(e, target.clone(), None)) + } + }; + (cast_to_target(left, &lt), cast_to_target(right, &rt)) + } + /// Create a DataFusion physical plan from Spark physical plan. There is a level of /// abstraction where a tree of SparkPlan nodes is returned. There is a 1:1 mapping from a /// protobuf Operator (that represents a Spark operator) to a native SparkPlan struct. We diff --git a/native/core/src/execution/planner/macros.rs b/native/core/src/execution/planner/macros.rs index 9d9ccf35da..0ec60c0f7f 100644 --- a/native/core/src/execution/planner/macros.rs +++ b/native/core/src/execution/planner/macros.rs @@ -80,7 +80,19 @@ macro_rules! binary_expr_builder { expr.left.as_ref().unwrap(), std::sync::Arc::clone(&input_schema), )?; - let right = planner.create_expr(expr.right.as_ref().unwrap(), input_schema)?; + let right = planner.create_expr( + expr.right.as_ref().unwrap(), + std::sync::Arc::clone(&input_schema), + )?; + // Reconcile nested operand nullability for comparisons (Spark ignores nullability, + // DataFusion's nested comparison kernel does not). 
+ let (left, right) = + $crate::execution::planner::PhysicalPlanner::reconcile_nested_comparison_types( + left, + right, + &$operator, + &input_schema, + ); Ok(std::sync::Arc::new( datafusion::physical_expr::expressions::BinaryExpr::new(left, $operator, right), )) diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala index 74e4881de0..09bfc52bd4 100644 --- a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala @@ -526,6 +526,7 @@ private[codegen] object CometBatchKernelCodegenInput { | return $elemPath.${nullCheckMethod(spec.element)}(startIndex + i); | }""".stripMargin val elementGetter = emitArrayElementGetter(path, spec) + val copy = emitArrayCopyMethod(spec) s""" private final class InputArray_$path extends $baseClassName { | private final int startIndex; | private final int length; @@ -543,10 +544,77 @@ private[codegen] object CometBatchKernelCodegenInput { |$isNullAt | |$elementGetter + | + |$copy | } |""".stripMargin } + /** + * Emit `copy()` for an `InputArray_${path}`. These views read straight off the off-heap Arrow + * buffers, which are only valid for the current batch, so Spark's `InternalRow.copyValue` + * (invoked by e.g. `ArrayTransform.nullSafeEval` when a lambda passes a complex element + * through) must deep-materialize into on-heap `GenericArrayData`. Scalars autobox, strings + * clone off the Arrow buffer, and nested array/struct/map elements recurse through their own + * `copy()`. 
+ */ + private def emitArrayCopyMethod(spec: ArrayColumnSpec): String = { + val getter = elementGetterCall(spec.elementSparkType, "__i") + val copyExpr = copyValueExpr(getter, spec.elementSparkType) + val assign = + if (spec.element.nullable) { + s""" if (isNullAt(__i)) { + | __vals[__i] = null; + | } else { + | __vals[__i] = $copyExpr; + | }""".stripMargin + } else { + s" __vals[__i] = $copyExpr;" + } + s""" @Override + | public org.apache.spark.sql.catalyst.util.ArrayData copy() { + | int __n = numElements(); + | Object[] __vals = new Object[__n]; + | for (int __i = 0; __i < __n; __i++) { + |$assign + | } + | return new org.apache.spark.sql.catalyst.util.GenericArrayData(__vals); + | }""".stripMargin + } + + /** + * The typed getter call (`getX(idx)`) used to read a value of `dt` out of a nested array + * element or struct field. `idx` is the index/ordinal token (e.g. `"__i"` or `"3"`). + */ + private def elementGetterCall(dt: DataType, idx: String): String = dt match { + case BooleanType => s"getBoolean($idx)" + case ByteType => s"getByte($idx)" + case ShortType => s"getShort($idx)" + case IntegerType | DateType => s"getInt($idx)" + case LongType | TimestampType | TimestampNTZType => s"getLong($idx)" + case FloatType => s"getFloat($idx)" + case DoubleType => s"getDouble($idx)" + case d: DecimalType => s"getDecimal($idx, ${d.precision}, ${d.scale})" + case _: StringType => s"getUTF8String($idx)" + case BinaryType => s"getBinary($idx)" + case _: ArrayType => s"getArray($idx)" + case _: StructType => s"getStruct($idx, ${dt.asInstanceOf[StructType].fields.length})" + case _: MapType => s"getMap($idx)" + case other => + throw new UnsupportedOperationException(s"nested copy: unsupported type $other") + } + + /** + * Wrap a non-null getter expression so the produced value is detached from the Arrow buffers: + * primitives/decimals/binary are already by-value, strings clone, and nested complex values + * recurse through their own `copy()`. 
+ */ + private def copyValueExpr(getter: String, dt: DataType): String = dt match { + case _: StringType => s"$getter.clone()" + case _: ArrayType | _: StructType | _: MapType => s"$getter.copy()" + case _ => getter + } + /** * Element-getter body for a nested array. Scalar -> direct typed read. Complex -> allocate a * fresh inner view. @@ -690,6 +758,7 @@ private[codegen] object CometBatchKernelCodegenInput { } val scalarGetters = emitStructScalarGetters(path, spec) val complexGetters = emitStructComplexGetters(path, spec) + val copy = emitStructCopyMethod(spec) s""" private final class InputStruct_$path extends $baseClassName { | private final int rowIdx; | @@ -713,10 +782,40 @@ private[codegen] object CometBatchKernelCodegenInput { | |$scalarGetters |$complexGetters + | + |$copy | } |""".stripMargin } + /** + * Emit `copy()` for an `InputStruct_${path}`. Deep-materializes into an on-heap + * `GenericInternalRow` so Spark's `InternalRow.copyValue` (e.g. a lambda passing a struct + * element through) detaches it from the per-batch Arrow buffers. Mirrors + * [[emitArrayCopyMethod]]. 
+ */ + private def emitStructCopyMethod(spec: StructColumnSpec): String = { + val assigns = spec.fields.zipWithIndex.map { case (f, fi) => + val getter = elementGetterCall(f.sparkType, fi.toString) + val copyExpr = copyValueExpr(getter, f.sparkType) + if (f.nullable) { + s""" if (isNullAt($fi)) { + | __vals[$fi] = null; + | } else { + | __vals[$fi] = $copyExpr; + | }""".stripMargin + } else { + s" __vals[$fi] = $copyExpr;" + } + } + s""" @Override + | public org.apache.spark.sql.catalyst.InternalRow copy() { + | Object[] __vals = new Object[${spec.fields.length}]; + |${assigns.mkString("\n")} + | return new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(__vals); + | }""".stripMargin + } + // Scalar-read body templates parameterized on row-index expression (`idx`), cached buffer // addresses (`valueAddr`, `offsetAddr`) for unsafe reads, or the Arrow field for the decimal // slow path. `ind` is the per-line indent. @@ -941,6 +1040,12 @@ private[codegen] object CometBatchKernelCodegenInput { | public org.apache.spark.sql.catalyst.util.ArrayData valueArray() { | return new InputArray_$valPath(this.startIndex, this.length); | } + | + | @Override + | public org.apache.spark.sql.catalyst.util.MapData copy() { + | return new org.apache.spark.sql.catalyst.util.ArrayBasedMapData( + | keyArray().copy(), valueArray().copy()); + | } | } |""".stripMargin } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 2c49114dd8..d322af0ae1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -73,7 +73,13 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Flatten] -> CometFlatten, classOf[GetArrayItem] -> CometGetArrayItem, classOf[Size] -> CometSize, - classOf[ArraysZip] -> CometArraysZip) + classOf[ArraysZip] -> CometArraysZip, + 
classOf[ArrayTransform] -> CometArrayTransform, + classOf[ArrayExists] -> CometArrayExists, + classOf[ArrayForAll] -> CometArrayForAll, + classOf[ArrayAggregate] -> CometArrayAggregate, + classOf[ArraySort] -> CometArraySort, + classOf[ZipWith] -> CometZipWith) private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf) @@ -153,7 +159,11 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[MapFromArrays] -> CometMapFromArrays, classOf[MapContainsKey] -> CometMapContainsKey, classOf[MapFromEntries] -> CometMapFromEntries, - classOf[StringToMap] -> CometStrToMap) + classOf[StringToMap] -> CometStrToMap, + classOf[MapFilter] -> CometMapFilter, + classOf[TransformKeys] -> CometTransformKeys, + classOf[TransformValues] -> CometTransformValues, + classOf[MapZipWith] -> CometMapZipWith) private[comet] val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 2f89b0e2e3..a25447dd92 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.annotation.tailrec import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, Slice, SortArray} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayAggregate, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayForAll, ArrayInsert, 
ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraySort, ArraysOverlap, ArraysZip, ArrayTransform, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, Slice, SortArray, ZipWith} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -843,3 +843,15 @@ trait ArraysBase { } } } + +object CometArrayTransform extends CometCodegenDispatch[ArrayTransform] + +object CometArrayExists extends CometCodegenDispatch[ArrayExists] + +object CometArrayForAll extends CometCodegenDispatch[ArrayForAll] + +object CometArrayAggregate extends CometCodegenDispatch[ArrayAggregate] + +object CometArraySort extends CometCodegenDispatch[ArraySort] + +object CometZipWith extends CometCodegenDispatch[ZipWith] diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala index abecbaa16d..419f21d915 100644 --- a/spark/src/main/scala/org/apache/comet/serde/maps.scala +++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala @@ -163,3 +163,11 @@ object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from } object CometStrToMap extends CometScalarFunction[StringToMap]("str_to_map") + +object CometMapFilter extends CometCodegenDispatch[MapFilter] + +object CometTransformKeys extends CometCodegenDispatch[TransformKeys] + +object CometTransformValues extends CometCodegenDispatch[TransformValues] + +object CometMapZipWith extends CometCodegenDispatch[MapZipWith] diff --git a/spark/src/test/resources/sql-tests/expressions/array/aggregate.sql b/spark/src/test/resources/sql-tests/expressions/array/aggregate.sql new file mode 100644 index 0000000000..16fde7dac2 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/aggregate.sql @@ -0,0 +1,45 @@ +-- Licensed to the Apache Software Foundation 
(ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function aggregate (a.k.a. reduce). Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_aggregate(a array<int>, start int) USING parquet + +statement +INSERT INTO test_aggregate VALUES + (array(1, 2, 3), 0), + (array(-5, 5), 100), + (array(10), 0), + (array(), 0), + (NULL, NULL) + +-- basic sum +query +SELECT aggregate(a, 0, (acc, x) -> acc + x) FROM test_aggregate + +-- column capture in the initial value +query +SELECT aggregate(a, start, (acc, x) -> acc + x) FROM test_aggregate + +-- with a finish function +query +SELECT aggregate(a, 0, (acc, x) -> acc + x, acc -> acc * 10) FROM test_aggregate + +-- all literals +query +SELECT aggregate(array(1, 2, 3, 4), 0, (acc, x) -> acc + x) diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_sort_comparator.sql b/spark/src/test/resources/sql-tests/expressions/array/array_sort_comparator.sql new file mode 100644 index 0000000000..67905bb74e --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_sort_comparator.sql @@ -0,0 +1,44 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. 
See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- array_sort with a comparator lambda is a higher-order function and runs through Comet's codegen +-- dispatcher. (array_sort without a comparator has a separate native path.) + +statement +CREATE TABLE test_array_sort(a array<int>) USING parquet + +statement +INSERT INTO test_array_sort VALUES + (array(3, 1, 2)), + (array(-5, 5, 0)), + (array(10)), + (array()), + (NULL) + +-- descending comparator +query +SELECT array_sort(a, (l, r) -> CASE WHEN l < r THEN 1 WHEN l > r THEN -1 ELSE 0 END) +FROM test_array_sort + +-- ascending comparator +query +SELECT array_sort(a, (l, r) -> CASE WHEN l < r THEN -1 WHEN l > r THEN 1 ELSE 0 END) +FROM test_array_sort + +-- all literals +query +SELECT array_sort(array(3, 1, 2), (l, r) -> CASE WHEN l < r THEN 1 WHEN l > r THEN -1 ELSE 0 END) diff --git a/spark/src/test/resources/sql-tests/expressions/array/exists.sql b/spark/src/test/resources/sql-tests/expressions/array/exists.sql new file mode 100644 index 0000000000..fb9782a092 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/exists.sql @@ -0,0 +1,45 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. 
See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function exists. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_exists(a array<int>, threshold int) USING parquet + +statement +INSERT INTO test_exists VALUES + (array(1, 2, 3), 2), + (array(-5, 5), 0), + (array(0), 0), + (array(), 0), + (NULL, NULL) + +-- basic +query +SELECT exists(a, x -> x > 2) FROM test_exists + +-- column capture +query +SELECT exists(a, x -> x > threshold) FROM test_exists + +-- predicate that can be null (array with nulls) +query +SELECT exists(array(1, NULL, 3), x -> x > 2) + +-- all literals +query +SELECT exists(array(1, 2, 3), x -> x < 0) diff --git a/spark/src/test/resources/sql-tests/expressions/array/forall.sql b/spark/src/test/resources/sql-tests/expressions/array/forall.sql new file mode 100644 index 0000000000..902116e9db --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/forall.sql @@ -0,0 +1,45 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function forall. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_forall(a array<int>, threshold int) USING parquet + +statement +INSERT INTO test_forall VALUES + (array(1, 2, 3), 0), + (array(-5, 5), 0), + (array(0), 0), + (array(), 0), + (NULL, NULL) + +-- basic +query +SELECT forall(a, x -> x > 0) FROM test_forall + +-- column capture +query +SELECT forall(a, x -> x >= threshold) FROM test_forall + +-- predicate that can be null (array with nulls) +query +SELECT forall(array(2, NULL, 4), x -> x > 0) + +-- all literals +query +SELECT forall(array(1, 2, 3), x -> x > 0) diff --git a/spark/src/test/resources/sql-tests/expressions/array/higher_order_function_fallback.sql b/spark/src/test/resources/sql-tests/expressions/array/higher_order_function_fallback.sql new file mode 100644 index 0000000000..3bb5581b99 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/higher_order_function_fallback.sql @@ -0,0 +1,39 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order functions have no native rust path; they ride the codegen dispatcher. With the +-- dispatcher disabled they have no native path and the projection falls back to Spark while still +-- producing correct results. + +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=false + +statement +CREATE TABLE test_hof_fallback(a array<int>, m map<string, int>) USING parquet + +statement +INSERT INTO test_hof_fallback VALUES + (array(1, 2, 3), map('a', 1, 'b', 2)), + (array(), map()), + (NULL, NULL) + +-- array higher-order function falls back to Spark +query expect_fallback(spark.comet.exec.scalaUDF.codegen.enabled) +SELECT transform(a, x -> x + 1) FROM test_hof_fallback + +-- map higher-order function falls back to Spark +query expect_fallback(spark.comet.exec.scalaUDF.codegen.enabled) +SELECT transform_values(m, (k, v) -> v + 1) FROM test_hof_fallback diff --git a/spark/src/test/resources/sql-tests/expressions/array/transform.sql b/spark/src/test/resources/sql-tests/expressions/array/transform.sql new file mode 100644 index 0000000000..4649091761 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/transform.sql @@ -0,0 +1,82 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. 
You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function transform. No native (rust) implementation; runs through Comet's codegen +-- dispatcher, so the projection stays native and matches Spark. + +statement +CREATE TABLE test_transform( + a array, + s array, + factor int, + nested array>, + structs array>, + maps array>) USING parquet + +statement +INSERT INTO test_transform VALUES + (array(1, 2, 3), array('a', 'b'), 10, + array(array(1, 2), array(3)), array(struct(1, 'a'), struct(2, 'b')), array(map('k', 1))), + (array(-5, 5), array('x'), 100, + array(array()), array(struct(3, 'c')), array(map('p', 9), map('q', 8))), + (array(), array(), 0, array(), array(), array()), + (NULL, NULL, NULL, NULL, NULL, NULL) + +-- basic +query +SELECT transform(a, x -> x + 1) FROM test_transform + +-- lambda with element index +query +SELECT transform(a, (x, i) -> x + i) FROM test_transform + +-- column capture: lambda references another column from the row +query +SELECT transform(a, x -> x + factor) FROM test_transform + +-- string elements +query +SELECT transform(s, x -> concat(x, '!')) FROM test_transform + +-- nested array> +query +SELECT transform(nested, x -> x) FROM test_transform + +-- nested array>, inner transform +query +SELECT transform(nested, x -> transform(x, y -> y * 2)) FROM test_transform + +-- nested array, identity passes a complex element through (exercises element copy) +query +SELECT transform(structs, e -> e) FROM test_transform + +-- nested array, project a field into a new struct +query +SELECT transform(structs, e -> struct(e.x + 1 AS x)) FROM test_transform + +-- 
nested array, identity passes a map element through +query +SELECT transform(maps, e -> e) FROM test_transform + +-- compare a transform result (containsNull=false) against a nullable-element column: exercises +-- nested-comparison nullability reconciliation +query +SELECT a FROM test_transform WHERE transform(a, x -> array(x)) = nested + +-- all literals (constant folding is disabled by the test harness) +query +SELECT transform(array(1, 2, 3), x -> x * x) diff --git a/spark/src/test/resources/sql-tests/expressions/array/zip_with.sql b/spark/src/test/resources/sql-tests/expressions/array/zip_with.sql new file mode 100644 index 0000000000..9345dbdfcd --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/zip_with.sql @@ -0,0 +1,45 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function zip_with. Runs through Comet's codegen dispatcher. 
+ +statement +CREATE TABLE test_zip_with(a array<int>, b array<int>) USING parquet + +statement +INSERT INTO test_zip_with VALUES + (array(1, 2, 3), array(10, 20, 30)), + (array(1, 2), array(5)), + (array(), array()), + (array(1), NULL), + (NULL, NULL) + +-- basic, equal lengths +query +SELECT zip_with(a, b, (x, y) -> x + y) FROM test_zip_with + +-- unequal lengths: shorter side padded with NULL +query +SELECT zip_with(a, b, (x, y) -> coalesce(x, 0) + coalesce(y, 0)) FROM test_zip_with + +-- build a struct from both elements +query +SELECT zip_with(a, b, (x, y) -> struct(x AS l, y AS r)) FROM test_zip_with + +-- all literals +query +SELECT zip_with(array(1, 2), array(3, 4), (x, y) -> x * y) diff --git a/spark/src/test/resources/sql-tests/expressions/map/map_filter.sql b/spark/src/test/resources/sql-tests/expressions/map/map_filter.sql new file mode 100644 index 0000000000..edf0484710 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/map/map_filter.sql @@ -0,0 +1,44 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function map_filter. Runs through Comet's codegen dispatcher. 
+ +statement +CREATE TABLE test_map_filter(m map<string, int>, threshold int) USING parquet + +statement +INSERT INTO test_map_filter VALUES + (map('a', 1, 'b', 2, 'c', 3), 1), + (map('x', -1), 0), + (map(), 0), + (NULL, NULL) + +-- filter on value +query +SELECT map_filter(m, (k, v) -> v > 1) FROM test_map_filter + +-- filter on key +query +SELECT map_filter(m, (k, v) -> k = 'a') FROM test_map_filter + +-- column capture +query +SELECT map_filter(m, (k, v) -> v > threshold) FROM test_map_filter + +-- all literals +query +SELECT map_filter(map('a', 1, 'b', 2), (k, v) -> v > 1) diff --git a/spark/src/test/resources/sql-tests/expressions/map/map_zip_with.sql b/spark/src/test/resources/sql-tests/expressions/map/map_zip_with.sql new file mode 100644 index 0000000000..181e6a7c15 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/map/map_zip_with.sql @@ -0,0 +1,41 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function map_zip_with. Runs through Comet's codegen dispatcher. 
+ +statement +CREATE TABLE test_map_zip_with(m map<string, int>, n map<string, int>) USING parquet + +statement +INSERT INTO test_map_zip_with VALUES + (map('a', 1, 'b', 2), map('a', 10, 'c', 30)), + (map('x', -1), map('x', 5)), + (map(), map()), + (map('k', 1), NULL), + (NULL, NULL) + +-- combine values present in either map (missing side is NULL) +query +SELECT map_zip_with(m, n, (k, v1, v2) -> coalesce(v1, 0) + coalesce(v2, 0)) FROM test_map_zip_with + +-- build a struct of both values +query +SELECT map_zip_with(m, n, (k, v1, v2) -> struct(v1 AS left, v2 AS right)) FROM test_map_zip_with + +-- all literals +query +SELECT map_zip_with(map('a', 1), map('a', 2, 'b', 3), (k, v1, v2) -> coalesce(v1, 0) + coalesce(v2, 0)) diff --git a/spark/src/test/resources/sql-tests/expressions/map/transform_keys.sql b/spark/src/test/resources/sql-tests/expressions/map/transform_keys.sql new file mode 100644 index 0000000000..cc91075a72 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/map/transform_keys.sql @@ -0,0 +1,40 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function transform_keys. Runs through Comet's codegen dispatcher. 
+ +statement +CREATE TABLE test_transform_keys(m map<string, int>, suffix string) USING parquet + +statement +INSERT INTO test_transform_keys VALUES + (map('a', 1, 'b', 2), 'X'), + (map('x', -1), 'Y'), + (map(), 'Z'), + (NULL, NULL) + +-- rewrite keys using key and value +query +SELECT transform_keys(m, (k, v) -> concat(k, cast(v AS string))) FROM test_transform_keys + +-- column capture +query +SELECT transform_keys(m, (k, v) -> concat(k, suffix)) FROM test_transform_keys + +-- all literals +query +SELECT transform_keys(map('a', 1, 'b', 2), (k, v) -> upper(k)) diff --git a/spark/src/test/resources/sql-tests/expressions/map/transform_values.sql b/spark/src/test/resources/sql-tests/expressions/map/transform_values.sql new file mode 100644 index 0000000000..50e49d8959 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/map/transform_values.sql @@ -0,0 +1,44 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function transform_values. Runs through Comet's codegen dispatcher. 
+ +statement +CREATE TABLE test_transform_values(m map<string, int>, delta int) USING parquet + +statement +INSERT INTO test_transform_values VALUES + (map('a', 1, 'b', 2), 10), + (map('x', -1), 5), + (map(), 0), + (NULL, NULL) + +-- rewrite values using value +query +SELECT transform_values(m, (k, v) -> v + 1) FROM test_transform_values + +-- rewrite values using key and value +query +SELECT transform_values(m, (k, v) -> concat(k, '=', cast(v AS string))) FROM test_transform_values + +-- column capture +query +SELECT transform_values(m, (k, v) -> v + delta) FROM test_transform_values + +-- all literals +query +SELECT transform_values(map('a', 1, 'b', 2), (k, v) -> v * 100) diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala index 9416861da4..87c52f153c 100644 --- a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala @@ -700,6 +700,50 @@ class CometCodegenSourceSuite extends AnyFunSuite { s"expected innermost scalar getter for IntegerType element; got:\n$src") } + test("nested input classes emit copy() that deep-materializes off the Arrow buffers") { + // Higher-order functions (e.g. ArrayTransform) evaluate Spark's interpreted lambda, which + // calls InternalRow.copyValue on complex elements. The nested input views read straight off + // the per-batch Arrow buffers, so copy() must materialize into on-heap Spark structures: + // GenericArrayData for arrays, GenericInternalRow for structs, ArrayBasedMapData for maps. + // Strings clone (they alias off-heap memory) and nested complex elements recurse via copy(). 
+ val stringChild = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("VarCharVector"), + nullable = true) + val innerStruct = StructColumnSpec( + nullable = true, + fields = Seq(StructFieldSpec("s", StringType, nullable = true, stringChild))) + // Array<Map<String, Struct<s: String>>>: exercises array, map, and struct copy() together. + val mapSpec = MapColumnSpec( + nullable = true, + keySparkType = StringType, + valueSparkType = StructType(Seq(StructField("s", StringType))), + key = stringChild, + value = innerStruct) + val outerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = MapType(StringType, StructType(Seq(StructField("s", StringType)))), + element = mapSpec) + val mapType = MapType(StringType, StructType(Seq(StructField("s", StringType)))) + val expr = Size(BoundReference(0, ArrayType(mapType), nullable = true)) + val src = generate(expr, IndexedSeq(outerArray)) + + assert( + src.contains("public org.apache.spark.sql.catalyst.util.ArrayData copy()"), + s"expected array copy() override; got:\n$src") + assert( + src.contains("new org.apache.spark.sql.catalyst.util.GenericArrayData("), + s"expected array copy() to materialize a GenericArrayData; got:\n$src") + assert( + src.contains("new org.apache.spark.sql.catalyst.util.ArrayBasedMapData("), + s"expected map copy() to materialize an ArrayBasedMapData; got:\n$src") + assert( + src.contains("new org.apache.spark.sql.catalyst.expressions.GenericInternalRow("), + s"expected struct copy() to materialize a GenericInternalRow; got:\n$src") + assert( + src.contains(".clone()"), + s"expected string elements to clone off the Arrow buffer in copy(); got:\n$src") + } + test("Array> emits array class allocating fresh InputStruct_col0_e") { val innerStruct = StructColumnSpec( nullable = true, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index b0be2b90ac..17feb3ce84 100644 --- 
a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -212,10 +212,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .sortWithinPartitions($"_2") // Spark 4.0 normalizes shuffle keys containing array via - // transform(arr, x -> mapsort(x)), which Comet doesn't yet - // support, so the shuffle falls back to Spark. - val expectedShuffles = if (isSpark40Plus) 0 else 1 - checkShuffleAnswer(df, expectedShuffles) + // transform(arr, x -> mapsort(x)). Comet routes the higher-order + // transform through the codegen dispatcher, so the partitioning + // expression stays native and the shuffle runs as Comet shuffle. + checkShuffleAnswer(df, 1) } } }