Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -348,22 +348,20 @@ expression-level). The `outer` variants are wired but marked `Incompatible`; the

## lambda_funcs

All higher-order functions are planned via [#4224](https://github.com/apache/datafusion-comet/issues/4224).

| Function | Status | Notes |
| --- | --- | --- |
| `aggregate` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `array_sort` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `exists` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `filter` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `forall` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `map_filter` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `map_zip_with` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `reduce` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `transform` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `transform_keys` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `transform_values` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `zip_with` | 🔜 | [#4224](https://github.com/apache/datafusion-comet/issues/4224) |
| `aggregate` | | |
| `array_sort` | | |
| `exists` | | |
| `filter` | 🔜 | General lambda not yet wired; the `array_compact` form is supported ([#4224](https://github.com/apache/datafusion-comet/issues/4224)) |
| `forall` | | |
| `map_filter` | | |
| `map_zip_with` | | |
| `reduce` | | |
| `transform` | | |
| `transform_keys` | | |
| `transform_values` | | |
| `zip_with` | | |

---

Expand Down
57 changes: 57 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,63 @@ impl PhysicalPlanner {
}
}

/// DataFusion's nested comparison kernel (`apply_cmp_for_nested`) requires both operands to
/// have identical data types, including nested field nullability, whereas Spark comparisons
/// ignore nullability. When a comparison's operands are nested types that differ only in
/// nullability (e.g. a higher-order `transform` produces `List(non-null Struct)` while the
/// other side is `List(nullable Struct)`), cast both to their nullability-union type so the
/// kernel accepts them. Non-comparison ops and non-nested or already-matching types are left
/// untouched.
pub fn reconcile_nested_comparison_types(
left: Arc<dyn PhysicalExpr>,
right: Arc<dyn PhysicalExpr>,
op: &DataFusionOperator,
input_schema: &SchemaRef,
) -> (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>) {
use DataFusionOperator::*;
let is_cmp = matches!(
op,
Eq | NotEq | Lt | LtEq | Gt | GtEq | IsDistinctFrom | IsNotDistinctFrom
);
if !is_cmp {
return (left, right);
}
let (lt, rt) = match (left.data_type(input_schema), right.data_type(input_schema)) {
(Ok(lt), Ok(rt)) => (lt, rt),
_ => return (left, right),
};
// Only nested types route through `apply_cmp_for_nested`; primitives coerce fine.
let nested = matches!(
lt,
DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _)
| DataType::Struct(_)
| DataType::Map(_, _)
);
if !nested || lt.equals_datatype(&rt) {
return (left, right);
}
// `Field::try_merge` unions nullability recursively while preserving structure (and the
// Map/list invariants). Bail out unchanged if the structures are genuinely incompatible.
let mut merged = Field::new("c", lt.clone(), true);
if merged
.try_merge(&Field::new("c", rt.clone(), true))
.is_err()
{
return (left, right);
}
let target = merged.data_type().clone();
let cast_to_target = |e: Arc<dyn PhysicalExpr>, dt: &DataType| -> Arc<dyn PhysicalExpr> {
if dt.equals_datatype(&target) {
e
} else {
Arc::new(CastExpr::new(e, target.clone(), None))
}
};
(cast_to_target(left, &lt), cast_to_target(right, &rt))
}

/// Create a DataFusion physical plan from Spark physical plan. There is a level of
/// abstraction where a tree of SparkPlan nodes is returned. There is a 1:1 mapping from a
/// protobuf Operator (that represents a Spark operator) to a native SparkPlan struct. We
Expand Down
14 changes: 13 additions & 1 deletion native/core/src/execution/planner/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,19 @@ macro_rules! binary_expr_builder {
expr.left.as_ref().unwrap(),
std::sync::Arc::clone(&input_schema),
)?;
let right = planner.create_expr(expr.right.as_ref().unwrap(), input_schema)?;
let right = planner.create_expr(
expr.right.as_ref().unwrap(),
std::sync::Arc::clone(&input_schema),
)?;
// Reconcile nested operand nullability for comparisons (Spark ignores nullability,
// DataFusion's nested comparison kernel does not).
let (left, right) =
$crate::execution::planner::PhysicalPlanner::reconcile_nested_comparison_types(
left,
right,
&$operator,
&input_schema,
);
Ok(std::sync::Arc::new(
datafusion::physical_expr::expressions::BinaryExpr::new(left, $operator, right),
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ private[codegen] object CometBatchKernelCodegenInput {
| return $elemPath.${nullCheckMethod(spec.element)}(startIndex + i);
| }""".stripMargin
val elementGetter = emitArrayElementGetter(path, spec)
val copy = emitArrayCopyMethod(spec)
s""" private final class InputArray_$path extends $baseClassName {
| private final int startIndex;
| private final int length;
Expand All @@ -543,10 +544,77 @@ private[codegen] object CometBatchKernelCodegenInput {
|$isNullAt
|
|$elementGetter
|
|$copy
| }
|""".stripMargin
}

/**
* Emit `copy()` for an `InputArray_${path}`. These views read straight off the off-heap Arrow
* buffers, which are only valid for the current batch, so Spark's `InternalRow.copyValue`
* (invoked by e.g. `ArrayTransform.nullSafeEval` when a lambda passes a complex element
* through) must deep-materialize into on-heap `GenericArrayData`. Scalars autobox, strings
* clone off the Arrow buffer, and nested array/struct/map elements recurse through their own
* `copy()`.
*/
private def emitArrayCopyMethod(spec: ArrayColumnSpec): String = {
val getter = elementGetterCall(spec.elementSparkType, "__i")
val copyExpr = copyValueExpr(getter, spec.elementSparkType)
val assign =
if (spec.element.nullable) {
s""" if (isNullAt(__i)) {
| __vals[__i] = null;
| } else {
| __vals[__i] = $copyExpr;
| }""".stripMargin
} else {
s" __vals[__i] = $copyExpr;"
}
s""" @Override
| public org.apache.spark.sql.catalyst.util.ArrayData copy() {
| int __n = numElements();
| Object[] __vals = new Object[__n];
| for (int __i = 0; __i < __n; __i++) {
|$assign
| }
| return new org.apache.spark.sql.catalyst.util.GenericArrayData(__vals);
| }""".stripMargin
}

/**
* The typed getter call (`getX(idx)`) used to read a value of `dt` out of a nested array
* element or struct field. `idx` is the index/ordinal token (e.g. `"__i"` or `"3"`).
*/
private def elementGetterCall(dt: DataType, idx: String): String = dt match {
case BooleanType => s"getBoolean($idx)"
case ByteType => s"getByte($idx)"
case ShortType => s"getShort($idx)"
case IntegerType | DateType => s"getInt($idx)"
case LongType | TimestampType | TimestampNTZType => s"getLong($idx)"
case FloatType => s"getFloat($idx)"
case DoubleType => s"getDouble($idx)"
case d: DecimalType => s"getDecimal($idx, ${d.precision}, ${d.scale})"
case _: StringType => s"getUTF8String($idx)"
case BinaryType => s"getBinary($idx)"
case _: ArrayType => s"getArray($idx)"
case _: StructType => s"getStruct($idx, ${dt.asInstanceOf[StructType].fields.length})"
case _: MapType => s"getMap($idx)"
case other =>
throw new UnsupportedOperationException(s"nested copy: unsupported type $other")
}

/**
* Wrap a non-null getter expression so the produced value is detached from the Arrow buffers:
* primitives/decimals/binary are already by-value, strings clone, and nested complex values
* recurse through their own `copy()`.
*/
private def copyValueExpr(getter: String, dt: DataType): String = dt match {
case _: StringType => s"$getter.clone()"
case _: ArrayType | _: StructType | _: MapType => s"$getter.copy()"
case _ => getter
}

/**
* Element-getter body for a nested array. Scalar -> direct typed read. Complex -> allocate a
* fresh inner view.
Expand Down Expand Up @@ -690,6 +758,7 @@ private[codegen] object CometBatchKernelCodegenInput {
}
val scalarGetters = emitStructScalarGetters(path, spec)
val complexGetters = emitStructComplexGetters(path, spec)
val copy = emitStructCopyMethod(spec)
s""" private final class InputStruct_$path extends $baseClassName {
| private final int rowIdx;
|
Expand All @@ -713,10 +782,40 @@ private[codegen] object CometBatchKernelCodegenInput {
|
|$scalarGetters
|$complexGetters
|
|$copy
| }
|""".stripMargin
}

/**
* Emit `copy()` for an `InputStruct_${path}`. Deep-materializes into an on-heap
* `GenericInternalRow` so Spark's `InternalRow.copyValue` (e.g. a lambda passing a struct
* element through) detaches it from the per-batch Arrow buffers. Mirrors
* [[emitArrayCopyMethod]].
*/
private def emitStructCopyMethod(spec: StructColumnSpec): String = {
val assigns = spec.fields.zipWithIndex.map { case (f, fi) =>
val getter = elementGetterCall(f.sparkType, fi.toString)
val copyExpr = copyValueExpr(getter, f.sparkType)
if (f.nullable) {
s""" if (isNullAt($fi)) {
| __vals[$fi] = null;
| } else {
| __vals[$fi] = $copyExpr;
| }""".stripMargin
} else {
s" __vals[$fi] = $copyExpr;"
}
}
s""" @Override
| public org.apache.spark.sql.catalyst.InternalRow copy() {
| Object[] __vals = new Object[${spec.fields.length}];
|${assigns.mkString("\n")}
| return new org.apache.spark.sql.catalyst.expressions.GenericInternalRow(__vals);
| }""".stripMargin
}

// Scalar-read body templates parameterized on row-index expression (`idx`), cached buffer
// addresses (`valueAddr`, `offsetAddr`) for unsafe reads, or the Arrow field for the decimal
// slow path. `ind` is the per-line indent.
Expand Down Expand Up @@ -941,6 +1040,12 @@ private[codegen] object CometBatchKernelCodegenInput {
| public org.apache.spark.sql.catalyst.util.ArrayData valueArray() {
| return new InputArray_$valPath(this.startIndex, this.length);
| }
|
| @Override
| public org.apache.spark.sql.catalyst.util.MapData copy() {
| return new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(
| keyArray().copy(), valueArray().copy());
| }
| }
|""".stripMargin
}
Expand Down
14 changes: 12 additions & 2 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
classOf[Flatten] -> CometFlatten,
classOf[GetArrayItem] -> CometGetArrayItem,
classOf[Size] -> CometSize,
classOf[ArraysZip] -> CometArraysZip)
classOf[ArraysZip] -> CometArraysZip,
classOf[ArrayTransform] -> CometArrayTransform,
classOf[ArrayExists] -> CometArrayExists,
classOf[ArrayForAll] -> CometArrayForAll,
classOf[ArrayAggregate] -> CometArrayAggregate,
classOf[ArraySort] -> CometArraySort,
classOf[ZipWith] -> CometZipWith)

private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf)
Expand Down Expand Up @@ -153,7 +159,11 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
classOf[MapFromArrays] -> CometMapFromArrays,
classOf[MapContainsKey] -> CometMapContainsKey,
classOf[MapFromEntries] -> CometMapFromEntries,
classOf[StringToMap] -> CometStrToMap)
classOf[StringToMap] -> CometStrToMap,
classOf[MapFilter] -> CometMapFilter,
classOf[TransformKeys] -> CometTransformKeys,
classOf[TransformValues] -> CometTransformValues,
classOf[MapZipWith] -> CometMapZipWith)

private[comet] val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
Map(
Expand Down
14 changes: 13 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/arrays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ package org.apache.comet.serde
import scala.annotation.tailrec
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.expressions.{And, ArrayAppend, ArrayContains, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, Slice, SortArray}
import org.apache.spark.sql.catalyst.expressions.{And, ArrayAggregate, ArrayAppend, ArrayContains, ArrayExcept, ArrayExists, ArrayFilter, ArrayForAll, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayPosition, ArrayRemove, ArrayRepeat, ArraySort, ArraysOverlap, ArraysZip, ArrayTransform, ArrayUnion, Attribute, Cast, CreateArray, ElementAt, EmptyRow, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, Slice, SortArray, ZipWith}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -843,3 +843,15 @@ trait ArraysBase {
}
}
}

object CometArrayTransform extends CometCodegenDispatch[ArrayTransform]

object CometArrayExists extends CometCodegenDispatch[ArrayExists]

object CometArrayForAll extends CometCodegenDispatch[ArrayForAll]

object CometArrayAggregate extends CometCodegenDispatch[ArrayAggregate]

object CometArraySort extends CometCodegenDispatch[ArraySort]

object CometZipWith extends CometCodegenDispatch[ZipWith]
8 changes: 8 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/maps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,11 @@ object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from
}

object CometStrToMap extends CometScalarFunction[StringToMap]("str_to_map")

object CometMapFilter extends CometCodegenDispatch[MapFilter]

object CometTransformKeys extends CometCodegenDispatch[TransformKeys]

object CometTransformValues extends CometCodegenDispatch[TransformValues]

object CometMapZipWith extends CometCodegenDispatch[MapZipWith]
45 changes: 45 additions & 0 deletions spark/src/test/resources/sql-tests/expressions/array/aggregate.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing,
-- software distributed under the License is distributed on an
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-- KIND, either express or implied. See the License for the
-- specific language governing permissions and limitations
-- under the License.

-- Higher-order function aggregate (a.k.a. reduce). Runs through Comet's codegen dispatcher.

statement
CREATE TABLE test_aggregate(a array<int>, start int) USING parquet

statement
INSERT INTO test_aggregate VALUES
(array(1, 2, 3), 0),
(array(-5, 5), 100),
(array(10), 0),
(array(), 0),
(NULL, NULL)

-- basic sum
query
SELECT aggregate(a, 0, (acc, x) -> acc + x) FROM test_aggregate

-- column capture in the initial value
query
SELECT aggregate(a, start, (acc, x) -> acc + x) FROM test_aggregate

-- with a finish function
query
SELECT aggregate(a, 0, (acc, x) -> acc + x, acc -> acc * 10) FROM test_aggregate

-- all literals
query
SELECT aggregate(array(1, 2, 3, 4), 0, (acc, x) -> acc + x)
Loading
Loading