From a89b23536b0d36a1fc1084f0a3b829e9e7c660f6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Jun 2026 08:22:06 -0600 Subject: [PATCH] feat: route higher-order functions through codegen dispatcher Register the array and map higher-order (lambda) functions that previously fell back to Spark so they stay native via the codegen dispatcher: - array: transform, exists, forall, aggregate, array_sort (comparator), zip_with - map: map_filter, transform_keys, transform_values, map_zip_with These have no native (rust) implementation and extend Spark's CodegenFallback, which the dispatcher's canHandle already admits, so the projection stays native and matches Spark exactly. When the dispatcher is disabled they fall back to Spark. Supporting fixes so higher-order functions over nested-complex element types stay correct natively: - Emit copy() on the codegen kernel's nested input views (InputArray, InputStruct, InputMap). Spark's interpreted lambda evaluation calls InternalRow.copyValue on complex elements; the views read straight off the per-batch Arrow buffers, so copy() deep-materializes into on-heap GenericArrayData / GenericInternalRow / ArrayBasedMapData, cloning strings and recursing into nested elements. - Reconcile nested operand nullability for native comparisons. DataFusion's nested comparison kernel requires identical types including nested nullability, whereas Spark comparisons ignore it. When a comparison's operands are nested types that differ only in nullability (e.g. a transform result vs a nullable-element column), cast both to their nullability-union type. Update the columnar-shuffle map-array-element test: on Spark 4.0+ the shuffle key normalizes to transform(arr, x -> mapsort(x)), which now stays native, so the shuffle runs as Comet shuffle on all versions. Add SQL file test coverage under expressions/array and expressions/map for each higher-order function (basic, column capture, nested element types, null and empty collections, and the disabled-dispatcher fallback path). --- docs/source/user-guide/latest/expressions.md | 26 ++--- native/core/src/execution/planner.rs | 57 ++++++++++ native/core/src/execution/planner/macros.rs | 14 ++- .../CometBatchKernelCodegenInput.scala | 105 ++++++++++++++++++ .../apache/comet/serde/QueryPlanSerde.scala | 14 ++- .../scala/org/apache/comet/serde/arrays.scala | 14 ++- .../scala/org/apache/comet/serde/maps.scala | 8 ++ .../sql-tests/expressions/array/aggregate.sql | 45 ++++++++ .../array/array_sort_comparator.sql | 44 ++++++++ .../sql-tests/expressions/array/exists.sql | 45 ++++++++ .../sql-tests/expressions/array/forall.sql | 45 ++++++++ .../array/higher_order_function_fallback.sql | 39 +++++++ .../sql-tests/expressions/array/transform.sql | 82 ++++++++++++++ .../sql-tests/expressions/array/zip_with.sql | 45 ++++++++ .../sql-tests/expressions/map/map_filter.sql | 44 ++++++++ .../expressions/map/map_zip_with.sql | 41 +++++++ .../expressions/map/transform_keys.sql | 40 +++++++ .../expressions/map/transform_values.sql | 44 ++++++++ .../comet/CometCodegenSourceSuite.scala | 44 ++++++++ .../exec/CometColumnarShuffleSuite.scala | 8 +- 20 files changed, 782 insertions(+), 22 deletions(-) create mode 100644 spark/src/test/resources/sql-tests/expressions/array/aggregate.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/array/array_sort_comparator.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/array/exists.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/array/forall.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/array/higher_order_function_fallback.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/array/transform.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/array/zip_with.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/map/map_filter.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/map/map_zip_with.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/map/transform_keys.sql create mode 100644 spark/src/test/resources/sql-tests/expressions/map/transform_values.sql diff --git a/docs/source/user-guide/latest/expressions.md b/docs/source/user-guide/latest/expressions.md index 86e70cb330..6e8e9407d6 100644 --- a/docs/source/user-guide/latest/expressions.md +++ b/docs/source/user-guide/latest/expressions.md @@ -348,22 +348,20 @@ expression-level). The `outer` variants are wired but marked `Incompatible`; the ## lambda_funcs -All higher-order functions are planned via [#4224](https://github.com/apache/datafusion-comet/issues/4224). - | Function | Status | Notes | | --- | --- | --- | -| `aggregate` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `array_sort` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `exists` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `filter` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `forall` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `map_filter` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `map_zip_with` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `reduce` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `transform` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `transform_keys` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `transform_values` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | -| `zip_with` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) | +| `aggregate` | ✅ | | +| `array_sort` | ✅ | | +| `exists` | ✅ | | +| `filter` | 🔜 | General lambda not yet wired; the `array_compact` form is supported ([#4224](https://github.com/apache/datafusion-comet/issues/4224)) | +| `forall` | ✅ | | +| `map_filter` | ✅ | | +| `map_zip_with` | ✅ | | +| `reduce` | ✅ | | +| `transform` | ✅ | | +| `transform_keys` | ✅ | | +| `transform_values` | ✅ | | +| `zip_with` | ✅ | | --- diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index c6160bddd4..a69027139e 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -974,6 +974,63 @@ impl PhysicalPlanner { } } + /// DataFusion's nested comparison kernel (`apply_cmp_for_nested`) requires both operands to + /// have identical data types, including nested field nullability, whereas Spark comparisons + /// ignore nullability. When a comparison's operands are nested types that differ only in + /// nullability (e.g. a higher-order `transform` produces `List(non-null Struct)` while the + /// other side is `List(nullable Struct)`), cast both to their nullability-union type so the + /// kernel accepts them. Non-comparison ops and non-nested or already-matching types are left + /// untouched. + pub fn reconcile_nested_comparison_types( + left: Arc, + right: Arc, + op: &DataFusionOperator, + input_schema: &SchemaRef, + ) -> (Arc, Arc) { + use DataFusionOperator::*; + let is_cmp = matches!( + op, + Eq | NotEq | Lt | LtEq | Gt | GtEq | IsDistinctFrom | IsNotDistinctFrom + ); + if !is_cmp { + return (left, right); + } + let (lt, rt) = match (left.data_type(input_schema), right.data_type(input_schema)) { + (Ok(lt), Ok(rt)) => (lt, rt), + _ => return (left, right), + }; + // Only nested types route through `apply_cmp_for_nested`; primitives coerce fine. + let nested = matches!( + lt, + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Struct(_) + | DataType::Map(_, _) + ); + if !nested || lt.equals_datatype(&rt) { + return (left, right); + } + // `Field::try_merge` unions nullability recursively while preserving structure (and the + // Map/list invariants). Bail out unchanged if the structures are genuinely incompatible. + let mut merged = Field::new("c", lt.clone(), true); + if merged + .try_merge(&Field::new("c", rt.clone(), true)) + .is_err() + { + return (left, right); + } + let target = merged.data_type().clone(); + let cast_to_target = |e: Arc, dt: &DataType| -> Arc { + if dt.equals_datatype(&target) { + e + } else { + Arc::new(CastExpr::new(e, target.clone(), None)) + } + }; + (cast_to_target(left, <), cast_to_target(right, &rt)) + } + /// Create a DataFusion physical plan from Spark physical plan. There is a level of /// abstraction where a tree of SparkPlan nodes is returned. There is a 1:1 mapping from a /// protobuf Operator (that represents a Spark operator) to a native SparkPlan struct. We diff --git a/native/core/src/execution/planner/macros.rs b/native/core/src/execution/planner/macros.rs index 9d9ccf35da..0ec60c0f7f 100644 --- a/native/core/src/execution/planner/macros.rs +++ b/native/core/src/execution/planner/macros.rs @@ -80,7 +80,19 @@ macro_rules! binary_expr_builder { expr.left.as_ref().unwrap(), std::sync::Arc::clone(&input_schema), )?; - let right = planner.create_expr(expr.right.as_ref().unwrap(), input_schema)?; + let right = planner.create_expr( + expr.right.as_ref().unwrap(), + std::sync::Arc::clone(&input_schema), + )?; + // Reconcile nested operand nullability for comparisons (Spark ignores nullability, + // DataFusion's nested comparison kernel does not). + let (left, right) = + $crate::execution::planner::PhysicalPlanner::reconcile_nested_comparison_types( + left, + right, + &$operator, + &input_schema, + ); Ok(std::sync::Arc::new( datafusion::physical_expr::expressions::BinaryExpr::new(left, $operator, right), )) diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala index 74e4881de0..09bfc52bd4 100644 --- a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala @@ -526,6 +526,7 @@ private[codegen] object CometBatchKernelCodegenInput { | return $elemPath.${nullCheckMethod(spec.element)}(startIndex + i); | }""".stripMargin val elementGetter = emitArrayElementGetter(path, spec) + val copy = emitArrayCopyMethod(spec) s""" private final class InputArray_$path extends $baseClassName { | private final int startIndex; | private final int length; @@ -543,10 +544,77 @@ private[codegen] object CometBatchKernelCodegenInput { |$isNullAt | |$elementGetter + | + |$copy | } |""".stripMargin } + /** + * Emit `copy()` for an `InputArray_${path}`. These views read straight off the off-heap Arrow + * buffers, which are only valid for the current batch, so Spark's `InternalRow.copyValue` + * (invoked by e.g. `ArrayTransform.nullSafeEval` when a lambda passes a complex element + * through) must deep-materialize into on-heap `GenericArrayData`. Scalars autobox, strings + * clone off the Arrow buffer, and nested array/struct/map elements recurse through their own + * `copy()`. + */ + private def emitArrayCopyMethod(spec: ArrayColumnSpec): String = { + val getter = elementGetterCall(spec.elementSparkType, "__i") + val copyExpr = copyValueExpr(getter, spec.elementSparkType) + val assign = + if (spec.element.nullable) { + s""" if (isNullAt(__i)) { + | __vals[__i] = null; + | } else { + | __vals[__i] = $copyExpr; + | }""".stripMargin + } else { + s" __vals[__i] = $copyExpr;" + } + s""" @Override + | public org.apache.spark.sql.catalyst.util.ArrayData copy() { + | int __n = numElements(); + | Object[] __vals = new Object[__n]; + | for (int __i = 0; __i < __n; __i++) { + |$assign + | } + | return new org.apache.spark.sql.catalyst.util.GenericArrayData(__vals); + | }""".stripMargin + } + + /** + * The typed getter call (`getX(idx)`) used to read a value of `dt` out of a nested array + * element or struct field. `idx` is the index/ordinal token (e.g. `"__i"` or `"3"`). + */ + private def elementGetterCall(dt: DataType, idx: String): String = dt match { + case BooleanType => s"getBoolean($idx)" + case ByteType => s"getByte($idx)" + case ShortType => s"getShort($idx)" + case IntegerType | DateType => s"getInt($idx)" + case LongType | TimestampType | TimestampNTZType => s"getLong($idx)" + case FloatType => s"getFloat($idx)" + case DoubleType => s"getDouble($idx)" + case d: DecimalType => s"getDecimal($idx, ${d.precision}, ${d.scale})" + case _: StringType => s"getUTF8String($idx)" + case BinaryType => s"getBinary($idx)" + case _: ArrayType => s"getArray($idx)" + case _: StructType => s"getStruct($idx, ${dt.asInstanceOf[StructType].fields.length})" + case _: MapType => s"getMap($idx)" + case other => + throw new UnsupportedOperationException(s"nested copy: unsupported type $other") + } + + /** + * Wrap a non-null getter expression so the produced value is detached from the Arrow buffers: + * primitives/decimals/binary are already by-value, strings clone, and nested complex values + * recurse through their own `copy()`. + */ + private def copyValueExpr(getter: String, dt: DataType): String = dt match { + case _: StringType => s"$getter.clone()" + case _: ArrayType | _: StructType | _: MapType => s"$getter.copy()" + case _ => getter + } + /** * Element-getter body for a nested array. Scalar -> direct typed read. Complex -> allocate a * fresh inner view. @@ -690,6 +758,7 @@ private[codegen] object CometBatchKernelCodegenInput { } val scalarGetters = emitStructScalarGetters(path, spec) val complexGetters = emitStructComplexGetters(path, spec) + val copy = emitStructCopyMethod(spec) s""" private final class InputStruct_$path extends $baseClassName { | private final int rowIdx; | @@ -713,10 +782,40 @@ private[codegen] object CometBatchKernelCodegenInput { | |$scalarGetters |$complexGetters + | + |$copy | } |""".stripMargin } + /** + * Emit `copy()` for an `InputStruct_${path}`. Deep-materializes into an on-heap + * `GenericInternalRow` so Spark's `InternalRow.copyValue` (e.g. a lambda passing a struct + * element through) detaches it from the per-batch Arrow buffers. Mirrors + * [[emitArrayCopyMethod]]. + */ + private def emitStructCopyMethod(spec: StructColumnSpec): String = { + val assigns = spec.fields.zipWithIndex.map { case (f, fi) => + val getter = elementGetterCall(f.sparkType, fi.toString) + val copyExpr = copyValueExpr(getter, f.sparkType) + if (f.nullable) { + s""" if (isNullAt($fi)) { + | __vals[$fi] = null; + | } else { + | __vals[$fi] = $copyExpr; + | }""".stripMargin + } else { + s" __vals[$fi] = $copyExpr;" + } + } + s""" @Override + | public org.apache.spark.sql.catalyst.InternalRow copy() { + | Object[] __vals = new Object[${spec.fields.length}]; + |${assigns.mkString("\n")} + | return new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(__vals); + | }""".stripMargin + } + // Scalar-read body templates parameterized on row-index expression (`idx`), cached buffer // addresses (`valueAddr`, `offsetAddr`) for unsafe reads, or the Arrow field for the decimal // slow path. `ind` is the per-line indent. @@ -941,6 +1040,12 @@ private[codegen] object CometBatchKernelCodegenInput { | public org.apache.spark.sql.catalyst.util.ArrayData valueArray() { | return new InputArray_$valPath(this.startIndex, this.length); | } + | + | @Override + | public org.apache.spark.sql.catalyst.util.MapData copy() { + | return new org.apache.spark.sql.catalyst.util.ArrayBasedMapData( + | keyArray().copy(), valueArray().copy()); + | } | } |""".stripMargin } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 2c49114dd8..d322af0ae1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -73,7 +73,13 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Flatten] -> CometFlatten, classOf[GetArrayItem] -> CometGetArrayItem, classOf[Size] -> CometSize, - classOf[ArraysZip] -> CometArraysZip) + classOf[ArraysZip] -> CometArraysZip, + classOf[ArrayTransform] -> CometArrayTransform, + classOf[ArrayExists] -> CometArrayExists, + classOf[ArrayForAll] -> CometArrayForAll, + classOf[ArrayAggregate] -> CometArrayAggregate, + classOf[ArraySort] -> CometArraySort, + classOf[ZipWith] -> CometZipWith) private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf) @@ -153,7 +159,11 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[MapFromArrays] -> CometMapFromArrays, classOf[MapContainsKey] -> CometMapContainsKey, classOf[MapFromEntries] -> CometMapFromEntries, - classOf[StringToMap] -> CometStrToMap) + classOf[StringToMap] -> CometStrToMap, + classOf[MapFilter] -> CometMapFilter, + classOf[TransformKeys] -> CometTransformKeys, + classOf[TransformValues] -> CometTransformValues, + classOf[MapZipWith] -> CometMapZipWith) private[comet] val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index 2f89b0e2e3..a25447dd92 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -22,7 +22,7 @@ package org.apache.comet.serde import scala.annotation.tailrec import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, Slice, SortArray} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayAggregate, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayForAll, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraySort, ArraysOverlap, ArraysZip, ArrayTransform, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, Slice, SortArray, ZipWith} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -843,3 +843,15 @@ trait ArraysBase { } } } + +object CometArrayTransform extends CometCodegenDispatch[ArrayTransform] + +object CometArrayExists extends CometCodegenDispatch[ArrayExists] + +object CometArrayForAll extends CometCodegenDispatch[ArrayForAll] + +object CometArrayAggregate extends CometCodegenDispatch[ArrayAggregate] + +object CometArraySort extends CometCodegenDispatch[ArraySort] + +object CometZipWith extends CometCodegenDispatch[ZipWith] diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala index abecbaa16d..419f21d915 100644 --- a/spark/src/main/scala/org/apache/comet/serde/maps.scala +++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala @@ -163,3 +163,11 @@ object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from } object CometStrToMap extends CometScalarFunction[StringToMap]("str_to_map") + +object CometMapFilter extends CometCodegenDispatch[MapFilter] + +object CometTransformKeys extends CometCodegenDispatch[TransformKeys] + +object CometTransformValues extends CometCodegenDispatch[TransformValues] + +object CometMapZipWith extends CometCodegenDispatch[MapZipWith] diff --git a/spark/src/test/resources/sql-tests/expressions/array/aggregate.sql b/spark/src/test/resources/sql-tests/expressions/array/aggregate.sql new file mode 100644 index 0000000000..16fde7dac2 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/aggregate.sql @@ -0,0 +1,45 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function aggregate (a.k.a. reduce). Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_aggregate(a array, start int) USING parquet + +statement +INSERT INTO test_aggregate VALUES + (array(1, 2, 3), 0), + (array(-5, 5), 100), + (array(10), 0), + (array(), 0), + (NULL, NULL) + +-- basic sum +query +SELECT aggregate(a, 0, (acc, x) -> acc + x) FROM test_aggregate + +-- column capture in the initial value +query +SELECT aggregate(a, start, (acc, x) -> acc + x) FROM test_aggregate + +-- with a finish function +query +SELECT aggregate(a, 0, (acc, x) -> acc + x, acc -> acc * 10) FROM test_aggregate + +-- all literals +query +SELECT aggregate(array(1, 2, 3, 4), 0, (acc, x) -> acc + x) diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_sort_comparator.sql b/spark/src/test/resources/sql-tests/expressions/array/array_sort_comparator.sql new file mode 100644 index 0000000000..67905bb74e --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/array_sort_comparator.sql @@ -0,0 +1,44 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- array_sort with a comparator lambda is a higher-order function and runs through Comet's codegen +-- dispatcher. (array_sort without a comparator has a separate native path.) + +statement +CREATE TABLE test_array_sort(a array) USING parquet + +statement +INSERT INTO test_array_sort VALUES + (array(3, 1, 2)), + (array(-5, 5, 0)), + (array(10)), + (array()), + (NULL) + +-- descending comparator +query +SELECT array_sort(a, (l, r) -> CASE WHEN l < r THEN 1 WHEN l > r THEN -1 ELSE 0 END) +FROM test_array_sort + +-- ascending comparator +query +SELECT array_sort(a, (l, r) -> CASE WHEN l < r THEN -1 WHEN l > r THEN 1 ELSE 0 END) +FROM test_array_sort + +-- all literals +query +SELECT array_sort(array(3, 1, 2), (l, r) -> CASE WHEN l < r THEN 1 WHEN l > r THEN -1 ELSE 0 END) diff --git a/spark/src/test/resources/sql-tests/expressions/array/exists.sql b/spark/src/test/resources/sql-tests/expressions/array/exists.sql new file mode 100644 index 0000000000..fb9782a092 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/exists.sql @@ -0,0 +1,45 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function exists. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_exists(a array, threshold int) USING parquet + +statement +INSERT INTO test_exists VALUES + (array(1, 2, 3), 2), + (array(-5, 5), 0), + (array(0), 0), + (array(), 0), + (NULL, NULL) + +-- basic +query +SELECT exists(a, x -> x > 2) FROM test_exists + +-- column capture +query +SELECT exists(a, x -> x > threshold) FROM test_exists + +-- predicate that can be null (array with nulls) +query +SELECT exists(array(1, NULL, 3), x -> x > 2) + +-- all literals +query +SELECT exists(array(1, 2, 3), x -> x < 0) diff --git a/spark/src/test/resources/sql-tests/expressions/array/forall.sql b/spark/src/test/resources/sql-tests/expressions/array/forall.sql new file mode 100644 index 0000000000..902116e9db --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/forall.sql @@ -0,0 +1,45 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function forall. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_forall(a array, threshold int) USING parquet + +statement +INSERT INTO test_forall VALUES + (array(1, 2, 3), 0), + (array(-5, 5), 0), + (array(0), 0), + (array(), 0), + (NULL, NULL) + +-- basic +query +SELECT forall(a, x -> x > 0) FROM test_forall + +-- column capture +query +SELECT forall(a, x -> x >= threshold) FROM test_forall + +-- predicate that can be null (array with nulls) +query +SELECT forall(array(2, NULL, 4), x -> x > 0) + +-- all literals +query +SELECT forall(array(1, 2, 3), x -> x > 0) diff --git a/spark/src/test/resources/sql-tests/expressions/array/higher_order_function_fallback.sql b/spark/src/test/resources/sql-tests/expressions/array/higher_order_function_fallback.sql new file mode 100644 index 0000000000..3bb5581b99 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/higher_order_function_fallback.sql @@ -0,0 +1,39 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order functions have no native rust path; they ride the codegen dispatcher. With the +-- dispatcher disabled they have no native path and the projection falls back to Spark while still +-- producing correct results. + +-- Config: spark.comet.exec.scalaUDF.codegen.enabled=false + +statement +CREATE TABLE test_hof_fallback(a array, m map) USING parquet + +statement +INSERT INTO test_hof_fallback VALUES + (array(1, 2, 3), map('a', 1, 'b', 2)), + (array(), map()), + (NULL, NULL) + +-- array higher-order function falls back to Spark +query expect_fallback(spark.comet.exec.scalaUDF.codegen.enabled) +SELECT transform(a, x -> x + 1) FROM test_hof_fallback + +-- map higher-order function falls back to Spark +query expect_fallback(spark.comet.exec.scalaUDF.codegen.enabled) +SELECT transform_values(m, (k, v) -> v + 1) FROM test_hof_fallback diff --git a/spark/src/test/resources/sql-tests/expressions/array/transform.sql b/spark/src/test/resources/sql-tests/expressions/array/transform.sql new file mode 100644 index 0000000000..4649091761 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/transform.sql @@ -0,0 +1,82 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function transform. No native (rust) implementation; runs through Comet's codegen +-- dispatcher, so the projection stays native and matches Spark. + +statement +CREATE TABLE test_transform( + a array, + s array, + factor int, + nested array>, + structs array>, + maps array>) USING parquet + +statement +INSERT INTO test_transform VALUES + (array(1, 2, 3), array('a', 'b'), 10, + array(array(1, 2), array(3)), array(struct(1, 'a'), struct(2, 'b')), array(map('k', 1))), + (array(-5, 5), array('x'), 100, + array(array()), array(struct(3, 'c')), array(map('p', 9), map('q', 8))), + (array(), array(), 0, array(), array(), array()), + (NULL, NULL, NULL, NULL, NULL, NULL) + +-- basic +query +SELECT transform(a, x -> x + 1) FROM test_transform + +-- lambda with element index +query +SELECT transform(a, (x, i) -> x + i) FROM test_transform + +-- column capture: lambda references another column from the row +query +SELECT transform(a, x -> x + factor) FROM test_transform + +-- string elements +query +SELECT transform(s, x -> concat(x, '!')) FROM test_transform + +-- nested array> +query +SELECT transform(nested, x -> x) FROM test_transform + +-- nested array>, inner transform +query +SELECT transform(nested, x -> transform(x, y -> y * 2)) FROM test_transform + +-- nested array, identity passes a complex element through (exercises element copy) +query +SELECT transform(structs, e -> e) FROM test_transform + +-- nested array, project a field into a new struct +query +SELECT transform(structs, e -> struct(e.x + 1 AS x)) FROM test_transform + +-- nested array, identity passes a map element through +query +SELECT transform(maps, e -> e) FROM test_transform + +-- compare a transform result (containsNull=false) against a nullable-element column: exercises +-- nested-comparison nullability reconciliation +query +SELECT a FROM test_transform WHERE transform(a, x -> array(x)) = nested + +-- all literals (constant folding is disabled by the test harness) +query +SELECT transform(array(1, 2, 3), x -> x * x) diff --git a/spark/src/test/resources/sql-tests/expressions/array/zip_with.sql b/spark/src/test/resources/sql-tests/expressions/array/zip_with.sql new file mode 100644 index 0000000000..9345dbdfcd --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/zip_with.sql @@ -0,0 +1,45 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function zip_with. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_zip_with(a array, b array) USING parquet + +statement +INSERT INTO test_zip_with VALUES + (array(1, 2, 3), array(10, 20, 30)), + (array(1, 2), array(5)), + (array(), array()), + (array(1), NULL), + (NULL, NULL) + +-- basic, equal lengths +query +SELECT zip_with(a, b, (x, y) -> x + y) FROM test_zip_with + +-- unequal lengths: shorter side padded with NULL +query +SELECT zip_with(a, b, (x, y) -> coalesce(x, 0) + coalesce(y, 0)) FROM test_zip_with + +-- build a struct from both elements +query +SELECT zip_with(a, b, (x, y) -> struct(x AS l, y AS r)) FROM test_zip_with + +-- all literals +query +SELECT zip_with(array(1, 2), array(3, 4), (x, y) -> x * y) diff --git a/spark/src/test/resources/sql-tests/expressions/map/map_filter.sql b/spark/src/test/resources/sql-tests/expressions/map/map_filter.sql new file mode 100644 index 0000000000..edf0484710 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/map/map_filter.sql @@ -0,0 +1,44 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function map_filter. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_map_filter(m map, threshold int) USING parquet + +statement +INSERT INTO test_map_filter VALUES + (map('a', 1, 'b', 2, 'c', 3), 1), + (map('x', -1), 0), + (map(), 0), + (NULL, NULL) + +-- filter on value +query +SELECT map_filter(m, (k, v) -> v > 1) FROM test_map_filter + +-- filter on key +query +SELECT map_filter(m, (k, v) -> k = 'a') FROM test_map_filter + +-- column capture +query +SELECT map_filter(m, (k, v) -> v > threshold) FROM test_map_filter + +-- all literals +query +SELECT map_filter(map('a', 1, 'b', 2), (k, v) -> v > 1) diff --git a/spark/src/test/resources/sql-tests/expressions/map/map_zip_with.sql b/spark/src/test/resources/sql-tests/expressions/map/map_zip_with.sql new file mode 100644 index 0000000000..181e6a7c15 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/map/map_zip_with.sql @@ -0,0 +1,41 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function map_zip_with. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_map_zip_with(m map, n map) USING parquet + +statement +INSERT INTO test_map_zip_with VALUES + (map('a', 1, 'b', 2), map('a', 10, 'c', 30)), + (map('x', -1), map('x', 5)), + (map(), map()), + (map('k', 1), NULL), + (NULL, NULL) + +-- combine values present in either map (missing side is NULL) +query +SELECT map_zip_with(m, n, (k, v1, v2) -> coalesce(v1, 0) + coalesce(v2, 0)) FROM test_map_zip_with + +-- build a struct of both values +query +SELECT map_zip_with(m, n, (k, v1, v2) -> struct(v1 AS left, v2 AS right)) FROM test_map_zip_with + +-- all literals +query +SELECT map_zip_with(map('a', 1), map('a', 2, 'b', 3), (k, v1, v2) -> coalesce(v1, 0) + coalesce(v2, 0)) diff --git a/spark/src/test/resources/sql-tests/expressions/map/transform_keys.sql b/spark/src/test/resources/sql-tests/expressions/map/transform_keys.sql new file mode 100644 index 0000000000..cc91075a72 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/map/transform_keys.sql @@ -0,0 +1,40 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function transform_keys. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_transform_keys(m map, suffix string) USING parquet + +statement +INSERT INTO test_transform_keys VALUES + (map('a', 1, 'b', 2), 'X'), + (map('x', -1), 'Y'), + (map(), 'Z'), + (NULL, NULL) + +-- rewrite keys using key and value +query +SELECT transform_keys(m, (k, v) -> concat(k, cast(v AS string))) FROM test_transform_keys + +-- column capture +query +SELECT transform_keys(m, (k, v) -> concat(k, suffix)) FROM test_transform_keys + +-- all literals +query +SELECT transform_keys(map('a', 1, 'b', 2), (k, v) -> upper(k)) diff --git a/spark/src/test/resources/sql-tests/expressions/map/transform_values.sql b/spark/src/test/resources/sql-tests/expressions/map/transform_values.sql new file mode 100644 index 0000000000..50e49d8959 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/map/transform_values.sql @@ -0,0 +1,44 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Higher-order function transform_values. Runs through Comet's codegen dispatcher. + +statement +CREATE TABLE test_transform_values(m map, delta int) USING parquet + +statement +INSERT INTO test_transform_values VALUES + (map('a', 1, 'b', 2), 10), + (map('x', -1), 5), + (map(), 0), + (NULL, NULL) + +-- rewrite values using value +query +SELECT transform_values(m, (k, v) -> v + 1) FROM test_transform_values + +-- rewrite values using key and value +query +SELECT transform_values(m, (k, v) -> concat(k, '=', cast(v AS string))) FROM test_transform_values + +-- column capture +query +SELECT transform_values(m, (k, v) -> v + delta) FROM test_transform_values + +-- all literals +query +SELECT transform_values(map('a', 1, 'b', 2), (k, v) -> v * 100) diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala index 9416861da4..87c52f153c 100644 --- a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala @@ -700,6 +700,50 @@ class CometCodegenSourceSuite extends AnyFunSuite { s"expected innermost scalar getter for IntegerType element; got:\n$src") } + test("nested input classes emit copy() that deep-materializes off the Arrow buffers") { + // Higher-order functions (e.g. ArrayTransform) evaluate Spark's interpreted lambda, which + // calls InternalRow.copyValue on complex elements. The nested input views read straight off + // the per-batch Arrow buffers, so copy() must materialize into on-heap Spark structures: + // GenericArrayData for arrays, GenericInternalRow for structs, ArrayBasedMapData for maps. + // Strings clone (they alias off-heap memory) and nested complex elements recurse via copy(). + val stringChild = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("VarCharVector"), + nullable = true) + val innerStruct = StructColumnSpec( + nullable = true, + fields = Seq(StructFieldSpec("s", StringType, nullable = true, stringChild))) + // Array>>: exercises array, map, and struct copy() together. + val mapSpec = MapColumnSpec( + nullable = true, + keySparkType = StringType, + valueSparkType = StructType(Seq(StructField("s", StringType))), + key = stringChild, + value = innerStruct) + val outerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = MapType(StringType, StructType(Seq(StructField("s", StringType)))), + element = mapSpec) + val mapType = MapType(StringType, StructType(Seq(StructField("s", StringType)))) + val expr = Size(BoundReference(0, ArrayType(mapType), nullable = true)) + val src = generate(expr, IndexedSeq(outerArray)) + + assert( + src.contains("public org.apache.spark.sql.catalyst.util.ArrayData copy()"), + s"expected array copy() override; got:\n$src") + assert( + src.contains("new org.apache.spark.sql.catalyst.util.GenericArrayData("), + s"expected array copy() to materialize a GenericArrayData; got:\n$src") + assert( + src.contains("new org.apache.spark.sql.catalyst.util.ArrayBasedMapData("), + s"expected map copy() to materialize an ArrayBasedMapData; got:\n$src") + assert( + src.contains("new org.apache.spark.sql.catalyst.expressions.GenericInternalRow("), + s"expected struct copy() to materialize a GenericInternalRow; got:\n$src") + assert( + src.contains(".clone()"), + s"expected string elements to clone off the Arrow buffer in copy(); got:\n$src") + } + test("Array> emits array class allocating fresh InputStruct_col0_e") { val innerStruct = StructColumnSpec( nullable = true, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index b0be2b90ac..17feb3ce84 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -212,10 +212,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar .sortWithinPartitions($"_2") // Spark 4.0 normalizes shuffle keys containing array via - // transform(arr, x -> mapsort(x)), which Comet doesn't yet - // support, so the shuffle falls back to Spark. - val expectedShuffles = if (isSpark40Plus) 0 else 1 - checkShuffleAnswer(df, expectedShuffles) + // transform(arr, x -> mapsort(x)). Comet routes the higher-order + // transform through the codegen dispatcher, so the partitioning + // expression stays native and the shuffle runs as Comet shuffle. + checkShuffleAnswer(df, 1) } } }