diff --git a/.gitignore b/.gitignore index 5eab9ab03..4093b00d1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ out/** .metals .bloop .project +.classpath +.settings +bin/ diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index 340aa5b04..6343bc904 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -461,6 +461,19 @@ public O visit(Expression.IfThen expr, C context) throws E { return visitFallback(expr, context); } + /** + * Visits a Lambda expression. + * + * @param expr the Lambda expression + * @param context the visitation context + * @return the visit result + * @throws E if visitation fails + */ + @Override + public O visit(Expression.Lambda expr, C context) throws E { + return visitFallback(expr, context); + } + /** * Visits a scalar function invocation. * diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 2197b0bd0..b116cbdc2 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -758,6 +758,30 @@ public R accept( } } + @Value.Immutable + abstract class Lambda implements Expression { + public abstract Type.Struct parameters(); + + public abstract Expression body(); + + @Override + public Type getType() { + List paramTypes = parameters().fields(); + Type returnType = body().getType(); + + // TODO: Type.Func nullability is hardcoded to false here because the spec does not allow for + // declaring otherwise. + // See: https://github.com/substrait-io/substrait/issues/976 + return Type.withNullability(false).func(paramTypes, returnType); + } + + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + /** * Base interface for user-defined literals. * diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 05540a924..147c7f8b7 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -321,6 +321,16 @@ public interface ExpressionVisitor outerReferenceStepsOut(); + public abstract Optional lambdaParameterReferenceStepsOut(); + @Override public Type getType() { return type(); @@ -38,13 +40,18 @@ public R accept( public boolean isSimpleRootReference() { return segments().size() == 1 && !inputExpression().isPresent() - && !outerReferenceStepsOut().isPresent(); + && !outerReferenceStepsOut().isPresent() + && !lambdaParameterReferenceStepsOut().isPresent(); } public boolean isOuterReference() { return outerReferenceStepsOut().orElse(0) > 0; } + public boolean isLambdaParameterReference() { + return lambdaParameterReferenceStepsOut().isPresent(); + } + public FieldReference dereferenceStruct(int index) { Type newType = StructFieldFinder.getReferencedType(type(), index); return dereference(newType, StructField.of(index)); @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List rels) { index, currentOffset)); } + public static FieldReference newLambdaParameterReference( + int paramIndex, Type.Struct lambdaParamsType, int stepsOut) { + return ImmutableFieldReference.builder() + .addSegments(StructField.of(paramIndex)) + .type(lambdaParamsType.fields().get(paramIndex)) + .lambdaParameterReferenceStepsOut(stepsOut) + .build(); + } + public interface ReferenceSegment { FieldReference apply(FieldReference reference); diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java new file mode 100644 index 000000000..8709b6aa6 --- /dev/null +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -0,0 +1,150 @@ +package io.substrait.expression; + +import io.substrait.type.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * Builds lambda expressions with build-time validation of parameter references. + * + *

Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters + * onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code + * lambda()} again on the same builder. + * + *

The callback receives a {@link Scope} handle for creating validated parameter references. The + * correct {@code stepsOut} value is computed automatically from the stack. + * + *

{@code
+ * LambdaBuilder lb = new LambdaBuilder();
+ *
+ * // Simple: (x: i32) -> x
+ * Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0));
+ *
+ * // Nested: (x: i32) -> (y: i64) -> add(x, y)
+ * Expression.Lambda nested = lb.lambda(List.of(R.I32), outer ->
+ *     lb.lambda(List.of(R.I64), inner ->
+ *         add(outer.ref(0), inner.ref(0))
+ *     )
+ * );
+ * }
+ */ +public class LambdaBuilder { + private final List lambdaContext = new ArrayList<>(); + + /** + * Builds a lambda expression. The body function receives a {@link Scope} for creating validated + * parameter references. Nested lambdas are built by calling this method again inside the + * callback. + * + * @param paramTypes the lambda's parameter types + * @param bodyFn function that builds the lambda body given a scope handle + * @return the constructed lambda expression + */ + public Expression.Lambda lambda(List paramTypes, Function bodyFn) { + Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + pushLambdaContext(params); + try { + Scope scope = new Scope(params); + Expression body = bodyFn.apply(scope); + return ImmutableExpression.Lambda.builder().parameters(params).body(body).build(); + } finally { + popLambdaContext(); + } + } + + /** + * Builds a lambda expression from a pre-built parameter struct. Used by internal converters that + * already have a Type.Struct (e.g., during protobuf deserialization). + * + * @param params the lambda's parameter struct + * @param bodyFn function that builds the lambda body + * @return the constructed lambda expression + */ + public Expression.Lambda lambdaFromStruct( + Type.Struct params, java.util.function.Supplier bodyFn) { + pushLambdaContext(params); + try { + Expression body = bodyFn.get(); + return ImmutableExpression.Lambda.builder().parameters(params).body(body).build(); + } finally { + popLambdaContext(); + } + } + + /** + * Resolves the parameter struct for a lambda at the given stepsOut from the current innermost + * scope. Used by internal converters to validate lambda parameter references during + * deserialization. + * + * @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost) + * @return the parameter struct at the target scope level + * @throws IllegalArgumentException if stepsOut exceeds the current nesting depth + */ + public Type.Struct resolveParams(int stepsOut) { + int targetDepth = lambdaContext.size() - stepsOut; + if (targetDepth <= 0 || targetDepth > lambdaContext.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, lambdaContext.size())); + } + return lambdaContext.get(targetDepth - 1); + } + + /** + * Pushes a lambda's parameters onto the context stack. This makes the parameters available for + * validation when building the lambda's body, and allows nested lambda parameter references to + * correctly compute their stepsOut values. + */ + private void pushLambdaContext(Type.Struct params) { + lambdaContext.add(params); + } + + /** + * Pops the most recently pushed lambda parameters from the context stack. Called after a lambda's + * body has been built, restoring the context to the enclosing lambda's scope. + */ + private void popLambdaContext() { + lambdaContext.remove(lambdaContext.size() - 1); + } + + /** + * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated + * parameter references. + * + *

Each Scope captures the depth of the lambdaContext stack at the time it was created. When + * {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference + * between the current stack depth and the captured depth. This means the same Scope produces + * different stepsOut values depending on the nesting level at the time of the call, which is what + * allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda. + */ + public class Scope { + private final Type.Struct params; + private final int depth; + + private Scope(Type.Struct params) { + this.params = params; + this.depth = lambdaContext.size(); + } + + /** + * Computes the number of lambda boundaries between this scope and the current innermost scope. + * This value changes dynamically as nested lambdas are built. + */ + private int stepsOut() { + return lambdaContext.size() - depth; + } + + /** + * Creates a validated reference to a parameter of this lambda. + * + * @param paramIndex index of the parameter within this lambda's parameter struct + * @return a {@link FieldReference} pointing to the specified parameter + * @throws IndexOutOfBoundsException if paramIndex is out of bounds + */ + public FieldReference ref(int paramIndex) { + return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut()); + } + } +} diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index eb2e45784..869aa6cc4 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -387,6 +387,18 @@ public Expression visit( }); } + @Override + public Expression visit( + io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return io.substrait.proto.Expression.newBuilder() + .setLambda( + io.substrait.proto.Expression.Lambda.newBuilder() + .setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct()) + .setBody(expr.body().accept(this, context))) + .build(); + } + @Override public Expression visit( io.substrait.expression.Expression.UserDefinedAnyLiteral expr, @@ -617,6 +629,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { out.setOuterReference( io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder() .setStepsOut(expr.outerReferenceStepsOut().get())); + } else if (expr.lambdaParameterReferenceStepsOut().isPresent()) { + out.setLambdaParameterReference( + io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder() + .setStepsOut(expr.lambdaParameterReferenceStepsOut().get())); } else { out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index e4a9fffea..470c13c6a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,6 +6,7 @@ import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; +import io.substrait.expression.LambdaBuilder; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -37,6 +38,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; + private final LambdaBuilder lambdaBuilder = new LambdaBuilder(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -75,6 +77,25 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getDirectReference().getStructField().getField(), rootType, reference.getOuterReference().getStepsOut()); + case LAMBDA_PARAMETER_REFERENCE: + { + io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef = + reference.getLambdaParameterReference(); + + int stepsOut = lambdaParamRef.getStepsOut(); + Type.Struct lambdaParameters = lambdaBuilder.resolveParams(stepsOut); + + // Check for unsupported nested field access + if (reference.getDirectReference().getStructField().hasChild()) { + throw new UnsupportedOperationException( + "Nested field access in lambda parameters is not yet supported"); + } + + return FieldReference.newLambdaParameterReference( + reference.getDirectReference().getStructField().getField(), + lambdaParameters, + stepsOut); + } case ROOTTYPE_NOT_SET: default: throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase()); @@ -260,6 +281,18 @@ public Type visit(Type.Struct type) throws RuntimeException { } } + case LAMBDA: + { + io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); + Type.Struct parameters = + (Type.Struct) + protoTypeConverter.from( + io.substrait.proto.Type.newBuilder() + .setStruct(protoLambda.getParameters()) + .build()); + + return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody())); + } // TODO enum. case ENUM: throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 7b316d4be..d8b6b1fa6 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog { /** Extension identifier for set functions. */ public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; + /** Extension identifier for list functions. */ + public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list"; + /** Extension identifier for string functions. */ public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; @@ -75,6 +78,7 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { "arithmetic", "comparison", "datetime", + "list", "logarithmic", "rounding", "rounding_decimal", diff --git a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java index c3360093d..b802d93e0 100644 --- a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator { T listE(T type); T mapE(T key, T value); + + T funcE(Iterable parameterTypes, T returnType); } diff --git a/core/src/main/java/io/substrait/function/ParameterizedType.java b/core/src/main/java/io/substrait/function/ParameterizedType.java index e514fb975..f35c21f7a 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedType.java +++ b/core/src/main/java/io/substrait/function/ParameterizedType.java @@ -200,6 +200,23 @@ R accept(final ParameterizedTypeVisitor parameter } } + @Value.Immutable + abstract class Func extends BaseParameterizedType implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract ParameterizedType returnType(); + + public static ImmutableParameterizedType.Func.Builder builder() { + return ImmutableParameterizedType.Func.builder(); + } + + @Override + R accept(final ParameterizedTypeVisitor parameterizedTypeVisitor) + throws E { + return parameterizedTypeVisitor.visit(this); + } + } + @Value.Immutable abstract class ListType extends BaseParameterizedType implements NullableType { public abstract ParameterizedType name(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java index 4c3f314e5..b35d8a781 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java @@ -103,6 +103,16 @@ public ParameterizedType listE(ParameterizedType type) { return ParameterizedType.ListType.builder().nullable(nullable).name(type).build(); } + @Override + public ParameterizedType funcE( + Iterable parameterTypes, ParameterizedType returnType) { + return ParameterizedType.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + @Override public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) { return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java index 9ff42f549..755c99777 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java @@ -29,6 +29,8 @@ public interface ParameterizedTypeVisitor extends TypeVi R visit(ParameterizedType.StringLiteral stringLiteral) throws E; + R visit(ParameterizedType.Func expr) throws E; + abstract class ParameterizedTypeThrowsVisitor extends TypeVisitor.TypeThrowsVisitor implements ParameterizedTypeVisitor { @@ -100,5 +102,10 @@ public R visit(ParameterizedType.Map expr) throws E { public R visit(ParameterizedType.StringLiteral stringLiteral) throws E { throw t(); } + + @Override + public R visit(ParameterizedType.Func expr) throws E { + throw t(); + } } } diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java index d6fc1bdb8..5c942f46a 100644 --- a/core/src/main/java/io/substrait/function/ToTypeString.java +++ b/core/src/main/java/io/substrait/function/ToTypeString.java @@ -150,6 +150,11 @@ public String visit(final Type.Map expr) { return "map"; } + @Override + public String visit(Type.Func type) throws RuntimeException { + return "func"; + } + @Override public String visit(final Type.UserDefined expr) { return String.format("u!%s", expr.name()); @@ -210,6 +215,11 @@ public String visit(ParameterizedType.Map expr) throws RuntimeException { return "map"; } + @Override + public String visit(ParameterizedType.Func expr) throws RuntimeException { + return "func"; + } + @Override public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException { if (expr.value().toLowerCase().startsWith("any")) { diff --git a/core/src/main/java/io/substrait/function/TypeExpression.java b/core/src/main/java/io/substrait/function/TypeExpression.java index cc9bc068c..1ee1cae56 100644 --- a/core/src/main/java/io/substrait/function/TypeExpression.java +++ b/core/src/main/java/io/substrait/function/TypeExpression.java @@ -206,6 +206,22 @@ R acceptE(final TypeExpressionVisitor visitor) th } } + @Value.Immutable + abstract class Func extends BaseTypeExpression implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract TypeExpression returnType(); + + public static ImmutableTypeExpression.Func.Builder builder() { + return ImmutableTypeExpression.Func.builder(); + } + + @Override + R acceptE(final TypeExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract class BinaryOperation extends BaseTypeExpression { public enum OpType { diff --git a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java index 808a0dd5a..62e1daee2 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java @@ -86,6 +86,16 @@ public TypeExpression mapE(TypeExpression key, TypeExpression value) { return TypeExpression.Map.builder().nullable(nullable).key(key).value(value).build(); } + @Override + public TypeExpression funcE( + Iterable parameterTypes, TypeExpression returnType) { + return TypeExpression.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public static class Assign { String name; TypeExpression expr; diff --git a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java index 2ef76b50f..e1bef4398 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java @@ -26,6 +26,8 @@ public interface TypeExpressionVisitor R visit(TypeExpression.Map expr) throws E; + R visit(TypeExpression.Func expr) throws E; + R visit(TypeExpression.BinaryOperation expr) throws E; R visit(TypeExpression.NotOperation expr) throws E; @@ -104,6 +106,11 @@ public R visit(TypeExpression.Map expr) throws E { throw t(); } + @Override + public R visit(TypeExpression.Func expr) throws E { + throw t(); + } + @Override public R visit(TypeExpression.BinaryOperation expr) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 1e9254716..b097614a0 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -8,6 +8,7 @@ import io.substrait.expression.ExpressionVisitor; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; +import io.substrait.expression.ImmutableExpression; import io.substrait.util.EmptyVisitationContext; import java.util.List; import java.util.Optional; @@ -439,6 +440,21 @@ public Optional visit( .build()); } + @Override + public Optional visit(Expression.Lambda lambda, EmptyVisitationContext context) + throws E { + Optional newBody = lambda.body().accept(this, context); + + if (allEmpty(newBody)) { + return Optional.empty(); + } + return Optional.of( + ImmutableExpression.Lambda.builder() + .from(lambda) + .body(newBody.orElse(lambda.body())) + .build()); + } + // utilities protected Optional> visitExprList( diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index e71b0b00c..36e5b6dcc 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -242,5 +242,10 @@ public Integer visit(Type.Map type) throws RuntimeException { public Integer visit(Type.UserDefined type) throws RuntimeException { return 0; } + + @Override + public Integer visit(Type.Func type) throws RuntimeException { + return 0; + } } } diff --git a/core/src/main/java/io/substrait/type/StringTypeVisitor.java b/core/src/main/java/io/substrait/type/StringTypeVisitor.java index d7c196148..e9d711d96 100644 --- a/core/src/main/java/io/substrait/type/StringTypeVisitor.java +++ b/core/src/main/java/io/substrait/type/StringTypeVisitor.java @@ -150,4 +150,13 @@ public String visit(Type.Map type) throws RuntimeException { public String visit(Type.UserDefined type) throws RuntimeException { return String.format("u!%s%s", type.name(), n(type)); } + + @Override + public String visit(Type.Func type) throws RuntimeException { + return String.format( + "func%s<%s -> %s>", + n(type), + type.parameterTypes().stream().map(t -> t.accept(this)).collect(Collectors.joining(", ")), + type.returnType().accept(this)); + } } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 86eaa733c..e82f8c1cf 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -361,6 +361,22 @@ public R accept(final TypeVisitor typeVisitor) th } } + @Value.Immutable + abstract class Func implements Type { + public abstract java.util.List parameterTypes(); + + public abstract Type returnType(); + + public static ImmutableType.Func.Builder builder() { + return ImmutableType.Func.builder(); + } + + @Override + public R accept(TypeVisitor typeVisitor) throws E { + return typeVisitor.visit(this); + } + } + @Value.Immutable abstract class Struct implements Type { public abstract java.util.List fields(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 999769cd9..6a897417e 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -84,6 +84,14 @@ public final Type intervalCompound(int precision) { return Type.IntervalCompound.builder().nullable(nullable).precision(precision).build(); } + public Type.Func func(java.util.List parameterTypes, Type returnType) { + return Type.Func.builder() + .nullable(nullable) + .parameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public Type.Struct struct(Iterable types) { return Type.Struct.builder().nullable(nullable).addAllFields(types).build(); } diff --git a/core/src/main/java/io/substrait/type/TypeVisitor.java b/core/src/main/java/io/substrait/type/TypeVisitor.java index 9cf772232..d76b7a1f9 100644 --- a/core/src/main/java/io/substrait/type/TypeVisitor.java +++ b/core/src/main/java/io/substrait/type/TypeVisitor.java @@ -52,6 +52,8 @@ public interface TypeVisitor { R visit(Type.Decimal type) throws E; + R visit(Type.Func type) throws E; + R visit(Type.Struct type) throws E; R visit(Type.ListType type) throws E; @@ -192,6 +194,11 @@ public R visit(Type.PrecisionTimestampTZ type) throws E { throw t(); } + @Override + public R visit(Type.Func type) throws E { + throw t(); + } + @Override public R visit(Type.Struct type) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index 8555ab219..ddad886b1 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -1,5 +1,6 @@ package io.substrait.type.parser; +import io.substrait.function.ImmutableParameterizedType; import io.substrait.function.ImmutableTypeExpression; import io.substrait.function.ParameterizedType; import io.substrait.function.ParameterizedTypeCreator; @@ -411,19 +412,61 @@ public TypeExpression visitNStruct(final SubstraitTypeParser.NStructContext ctx) @Override public TypeExpression visitFunc(final SubstraitTypeParser.FuncContext ctx) { - throw new UnsupportedOperationException(); + boolean nullable = ctx.isnull != null; + + // Process function parameters + List paramExprs; + if (ctx.params instanceof SubstraitTypeParser.SingleFuncParamContext) { + paramExprs = + java.util.Collections.singletonList( + ((SubstraitTypeParser.SingleFuncParamContext) ctx.params).expr().accept(this)); + } else if (ctx.params instanceof SubstraitTypeParser.FuncParamsWithParensContext) { + paramExprs = + ((SubstraitTypeParser.FuncParamsWithParensContext) ctx.params) + .expr().stream() + .map(e -> e.accept(this)) + .collect(java.util.stream.Collectors.toList()); + } else { + throw new UnsupportedOperationException( + "Unknown funcParams type: " + ctx.params.getClass()); + } + + // Process return type + TypeExpression returnExpr = ctx.returnType.accept(this); + + // If all types are instances of Type, we return a Type + if (paramExprs.stream().allMatch(p -> p instanceof Type) && returnExpr instanceof Type) { + ImmutableType.Func.Builder builder = ImmutableType.Func.builder().nullable(nullable); + paramExprs.forEach(p -> builder.addParameterTypes((Type) p)); + return builder.returnType((Type) returnExpr).build(); + } + + // If all types are instances of ParameterizedType, we return a ParameterizedType + if (paramExprs.stream().allMatch(p -> p instanceof ParameterizedType) + && returnExpr instanceof ParameterizedType) { + checkParameterizedOrExpression(); + ImmutableParameterizedType.Func.Builder builder = + ParameterizedType.Func.builder().nullable(nullable); + paramExprs.forEach(p -> builder.addParameterTypes((ParameterizedType) p)); + return builder.returnType((ParameterizedType) returnExpr).build(); + } + + throw new UnsupportedOperationException( + "func type with TypeExpression-level parameter or return types are not yet supported"); } @Override public TypeExpression visitSingleFuncParam( final SubstraitTypeParser.SingleFuncParamContext ctx) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException( + "visitSingleFuncParam is handled in visitFunc directly"); } @Override public TypeExpression visitFuncParamsWithParens( final SubstraitTypeParser.FuncParamsWithParensContext ctx) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException( + "visitFuncParamsWithParens is handled in visitFunc directly"); } @Override diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 67d7bc9b5..3909d4033 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -142,6 +142,16 @@ public final T visit(final Type.PrecisionTimestampTZ expr) { return typeContainer(expr).precisionTimestampTZ(expr.precision()); } + @Override + public final T visit(final Type.Func expr) { + return typeContainer(expr) + .func( + expr.parameterTypes().stream() + .map(t -> t.accept(this)) + .collect(java.util.stream.Collectors.toList()), + expr.returnType().accept(this)); + } + @Override public final T visit(final Type.Struct expr) { return typeContainer(expr) diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 57b1f26b5..47842382f 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -119,6 +119,8 @@ public final T precisionTimestampTZ(int precision) { public abstract T intervalCompound(I precision); + public abstract T func(Iterable parameterTypes, T returnType); + public final T struct(T... types) { return struct(Arrays.asList(types)); } diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index bdb600c1c..24231aebc 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -77,6 +77,13 @@ public Type from(io.substrait.proto.Type type) { case PRECISION_TIMESTAMP_TZ: return n(type.getPrecisionTimestampTz().getNullability()) .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); + case FUNC: + return n(type.getFunc().getNullability()) + .func( + type.getFunc().getParameterTypesList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList()), + from(type.getFunc().getReturnType())); case STRUCT: return n(type.getStruct().getNullability()) .struct( diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 6422904c4..c0e785db4 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -154,6 +154,16 @@ public Type precisionTimestampTZ(Integer precision) { .build()); } + @Override + public Type func(Iterable parameterTypes, Type returnType) { + return wrap( + Type.Func.newBuilder() + .addAllParameterTypes(parameterTypes) + .setReturnType(returnType) + .setNullability(nullability) + .build()); + } + @Override public Type struct(Iterable types) { return wrap(Type.Struct.newBuilder().addAllTypes(types).setNullability(nullability).build()); @@ -237,6 +247,8 @@ protected Type wrap(final Object o) { return bldr.setPrecisionTimestamp((Type.PrecisionTimestamp) o).build(); } else if (o instanceof Type.PrecisionTimestampTZ) { return bldr.setPrecisionTimestampTz((Type.PrecisionTimestampTZ) o).build(); + } else if (o instanceof Type.Func) { + return bldr.setFunc((Type.Func) o).build(); } else if (o instanceof Type.Struct) { return bldr.setStruct((Type.Struct) o).build(); } else if (o instanceof Type.List) { diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java new file mode 100644 index 000000000..01303d932 --- /dev/null +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -0,0 +1,107 @@ +package io.substrait.expression; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** Tests for {@link LambdaBuilder}. */ +class LambdaBuilderTest { + + static final TypeCreator R = TypeCreator.REQUIRED; + + final LambdaBuilder lb = new LambdaBuilder(); + + // (x: i32)@p -> p[0] + @Test + void simpleLambda() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); + + Expression.Lambda expected = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) + .body( + FieldReference.newLambdaParameterReference( + 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 0)) + .build(); + + assertEquals(expected, lambda); + } + + // (x: i32)@outer -> (y: i64)@inner -> outer[0] + @Test + void nestedLambda() { + Expression.Lambda lambda = + lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(0))); + + Expression.Lambda expectedInner = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I64).build()) + .body( + FieldReference.newLambdaParameterReference( + 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 1)) + .build(); + + Expression.Lambda expected = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) + .body(expectedInner) + .build(); + + assertEquals(expected, lambda); + } + + // Verify that the same scope handle produces different stepsOut values depending on nesting. + // outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda. + @Test + void scopeStepsOutChangesDynamically() { + lb.lambda( + List.of(R.I32), + outer -> { + FieldReference atTopLevel = outer.ref(0); + assertEquals(0, atTopLevel.lambdaParameterReferenceStepsOut().orElse(-1)); + + lb.lambda( + List.of(R.I64), + inner -> { + FieldReference atNestedLevel = outer.ref(0); + assertEquals(1, atNestedLevel.lambdaParameterReferenceStepsOut().orElse(-1)); + return inner.ref(0); + }); + + return atTopLevel; + }); + } + + // (x: i32)@p -> p[5] — only 1 param, index 5 is out of bounds + @Test + void invalidFieldIndex_outOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(5))); + } + + // (x: i32)@p -> p[-1] — negative index + @Test + void negativeFieldIndex() { + assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); + } + + // (x: i32)@outer -> (y: i64)@inner -> outer[5] — outer only has 1 param + @Test + void nestedOuterFieldIndexOutOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, + () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5)))); + } + + // (x: i32)@outer -> (y: i64)@inner -> inner[3] — inner only has 1 param + @Test + void nestedInnerFieldIndexOutOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, + () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(3)))); + } +} diff --git a/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java new file mode 100644 index 000000000..7aed4e961 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java @@ -0,0 +1,13 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import org.junit.jupiter.api.Test; + +class DefaultExtensionCatalogTest { + + @Test + void defaultCollectionLoads() { + assertNotNull(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } +} diff --git a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java index 010ad123a..92989795b 100644 --- a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java +++ b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java @@ -8,6 +8,7 @@ import io.substrait.type.ImmutableType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.List; import org.junit.jupiter.api.Test; class TestTypeParser { @@ -120,6 +121,11 @@ private void compoundTests(ParseToPojo.Visitor v) { test(v, n.precisionTimestamp(9), "PRECISION_TIMESTAMP?<9>"); test(v, r.precisionTimestampTZ(6), "PRECISION_TIMESTAMP_TZ<6>"); test(v, n.precisionTimestampTZ(9), "PRECISION_TIMESTAMP_TZ?<9>"); + + test(v, r.func(List.of(r.I8), r.I32), "func i32>"); + test(v, r.func(List.of(r.I8, r.I8), r.I32), "func<(i8, i8) -> i32>"); + test(v, n.func(List.of(r.I8), r.I32), "func? i32>"); + test(v, r.func(List.of(n.I8), n.I32), "func i32?>"); } private void parameterizedTests(ParseToPojo.Visitor v) { @@ -142,6 +148,16 @@ private void parameterizedTests(ParseToPojo.Visitor v) { test(v, pr.precisionTimeE("P"), "PRECISION_TIME

"); test(v, pr.precisionTimestampE("P"), "PRECISION_TIMESTAMP

"); test(v, pr.precisionTimestampTZE("P"), "PRECISION_TIMESTAMP_TZ

"); + + test(v, pr.funcE(List.of(pr.parameter("any")), r.I64), "func i64>"); + test(v, pr.funcE(List.of(pr.parameter("any"), r.I64), r.I64), "func<(any, i64) -> i64>"); + test(v, pr.funcE(List.of(pr.parameter("any1")), pr.parameter("any1")), "func any1>"); + test(v, pn.funcE(List.of(pr.parameter("any")), n.I64), "func? i64?>"); + test(v, pn.funcE(List.of(pr.parameter("any1")), pr.parameter("any1")), "func? any1>"); + test( + v, + pr.funcE(List.of(pr.parameter("any1"), r.I8), pr.parameter("any1")), + "func<(any1, i8) -> any1>"); } @Test diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java new file mode 100644 index 000000000..1589c35ba --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -0,0 +1,57 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +class LambdaExpressionRoundtripTest extends TestBase { + + static Stream validLambdaExpressions() throws IOException { + return listJsonResources("expressions/lambda/valid"); + } + + static Stream invalidLambdaExpressions() throws IOException { + return listJsonResources("expressions/lambda/invalid"); + } + + @ParameterizedTest + @MethodSource("validLambdaExpressions") + void validLambdaExpressionRoundtrip(String resourcePath) throws IOException { + Expression deserialized = deserializeExpression(resourcePath); + assertInstanceOf(Expression.Lambda.class, deserialized); + verifyRoundTrip(deserialized); + } + + @ParameterizedTest + @MethodSource("invalidLambdaExpressions") + void invalidLambdaExpressionRejected(String resourcePath) { + assertThrows(Exception.class, () -> deserializeExpression(resourcePath)); + } + + private static Stream listJsonResources(String dirPath) throws IOException { + Path dir = + Paths.get( + LambdaExpressionRoundtripTest.class.getClassLoader().getResource(dirPath).getPath()); + return Files.list(dir) + .filter(p -> p.toString().endsWith(".json")) + .map(p -> dirPath + "/" + p.getFileName().toString()) + .sorted(); + } + + private Expression deserializeExpression(String resourcePath) throws IOException { + String json = asString(resourcePath); + io.substrait.proto.Expression.Builder builder = io.substrait.proto.Expression.newBuilder(); + JsonFormat.parser().merge(json, builder); + return protoExpressionConverter.from(builder.build()); + } +} diff --git a/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json b/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json new file mode 100644 index 000000000..54ec7321b --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json @@ -0,0 +1,30 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 2 + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/invalid/steps_out.json b/core/src/test/resources/expressions/lambda/invalid/steps_out.json new file mode 100644 index 000000000..148173e8b --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/steps_out.json @@ -0,0 +1,21 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 1 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/identity.json b/core/src/test/resources/expressions/lambda/valid/identity.json new file mode 100644 index 000000000..a984334ef --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/identity.json @@ -0,0 +1,21 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/literal_body.json b/core/src/test/resources/expressions/lambda/valid/literal_body.json new file mode 100644 index 000000000..c36e2df5e --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/literal_body.json @@ -0,0 +1,14 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "literal": { + "i32": 42 + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/multi_param.json b/core/src/test/resources/expressions/lambda/valid/multi_param.json new file mode 100644 index 000000000..1c31bc1e7 --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/multi_param.json @@ -0,0 +1,23 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } }, + { "i64": { "nullability": "NULLABILITY_REQUIRED" } }, + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/nested.json b/core/src/test/resources/expressions/lambda/valid/nested.json new file mode 100644 index 000000000..abf716b7a --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/nested.json @@ -0,0 +1,30 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 1 + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/triple_nested.json b/core/src/test/resources/expressions/lambda/valid/triple_nested.json new file mode 100644 index 000000000..e1e692521 --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/triple_nested.json @@ -0,0 +1,39 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 2 + } + } + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/zero_params.json b/core/src/test/resources/expressions/lambda/valid/zero_params.json new file mode 100644 index 000000000..df801c9cd --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/zero_params.json @@ -0,0 +1,12 @@ +{ + "lambda": { + "parameters": { + "types": [] + }, + "body": { + "literal": { + "i32": 42 + } + } + } +} diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index a22756cc9..2d345642c 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -202,6 +202,12 @@ public String visit(Expression.NestedStruct expr, EmptyVisitationContext context return ""; } + @Override + public String visit(Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; + } + @Override public String visit(UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws RuntimeException { diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java index 0e13c3e2e..f5f8d93e9 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java @@ -149,6 +149,11 @@ public String visit(Decimal type) throws RuntimeException { return type.getClass().getSimpleName(); } + @Override + public String visit(Type.Func type) throws RuntimeException { + return type.getClass().getSimpleName(); + } + @Override public String visit(Struct type) throws RuntimeException { StringBuffer sb = new StringBuffer(type.getClass().getSimpleName()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 5fd9efdcb..d7147a594 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -42,6 +42,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexSubQuery; @@ -423,6 +424,23 @@ public RexNode visit(Expression.IfThen expr, Context context) throws RuntimeExce return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } + @Override + public RexNode visit(Expression.Lambda expr, Context context) throws RuntimeException { + List parameters = + IntStream.range(0, expr.parameters().fields().size()) + .mapToObj( + i -> + new RexLambdaRef( + i, + "p" + i, + typeConverter.toCalcite(typeFactory, expr.parameters().fields().get(i)))) + .collect(Collectors.toList()); + + RexNode body = expr.body().accept(this, context); + + return rexBuilder.makeLambdaCall(body, parameters); + } + @Override public RexNode visit(Switch expr, Context context) throws RuntimeException { RexNode match = expr.match().accept(this, context); @@ -697,6 +715,23 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti } return rexInputRef; + } else if (expr.isLambdaParameterReference()) { + // as of now calcite doesn't support nested lambda functions + // https://github.com/substrait-io/substrait-java/issues/711 + int stepsOut = expr.lambdaParameterReferenceStepsOut().get(); + if (stepsOut != 0) { + throw new UnsupportedOperationException( + "Calcite does not support nested lambdas (stepsOut=" + stepsOut + ")"); + } + + final ReferenceSegment segment = expr.segments().get(0); + if (segment instanceof FieldReference.StructField) { + final FieldReference.StructField field = (FieldReference.StructField) segment; + RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); + return new RexLambdaRef(field.offset(), "p" + field.offset(), calciteType); + } else { + throw new IllegalArgumentException("Unhandled type: " + segment); + } } return visitFallback(expr, context); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 90f04a326..cc2b098e2 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -5,14 +5,32 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import org.apache.calcite.sql.SqlBasicFunction; +import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; public class FunctionMappings { // Static list of signature mapping between Calcite SQL operators and Substrait base function // names. + /** The transform:list_func function; applies a lambda to each element of an array. */ + public static final SqlFunction TRANSFORM = + SqlBasicFunction.create( + "transform", + opBinding -> opBinding.getTypeFactory().createArrayType(opBinding.getOperandType(1), -1), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + + /** The filter:list_func function; filters elements of an array using a predicate lambda. */ + public static final SqlFunction FILTER = + SqlBasicFunction.create( + "filter", + opBinding -> opBinding.getOperandType(0), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + public static final ImmutableList SCALAR_SIGS = ImmutableList.builder() .add( @@ -100,7 +118,9 @@ public class FunctionMappings { s(SqlLibraryOperators.RPAD, "rpad"), s(SqlLibraryOperators.PARSE_TIME, "strptime_time"), s(SqlLibraryOperators.PARSE_TIMESTAMP, "strptime_timestamp"), - s(SqlLibraryOperators.PARSE_DATE, "strptime_date")) + s(SqlLibraryOperators.PARSE_DATE, "strptime_date"), + s(TRANSFORM, "transform"), + s(FILTER, "filter")) .build(); public static final ImmutableList AGGREGATE_SIGS = diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java index f8b4be1dd..b56fa4fd3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java @@ -128,6 +128,11 @@ public Boolean visit(Type.Decimal type) { return typeToMatch instanceof Type.Decimal || typeToMatch instanceof ParameterizedType.Decimal; } + @Override + public Boolean visit(Type.Func type) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } + @Override public Boolean visit(Type.PrecisionTime type) { return typeToMatch instanceof Type.PrecisionTime @@ -234,4 +239,9 @@ public Boolean visit(ParameterizedType.Map expr) throws RuntimeException { public Boolean visit(ParameterizedType.StringLiteral stringLiteral) throws RuntimeException { return false; } + + @Override + public Boolean visit(ParameterizedType.Func expr) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 6993c8451..d5330681d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -2,11 +2,13 @@ import io.substrait.expression.Expression; import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableExpression; import io.substrait.isthmus.CallConverter; import io.substrait.isthmus.SubstraitRelVisitor; import io.substrait.isthmus.TypeConverter; import io.substrait.relation.Rel; import io.substrait.type.StringTypeVisitor; +import io.substrait.type.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -202,12 +204,30 @@ public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) { @Override public Expression visitLambda(RexLambda rexLambda) { - throw new UnsupportedOperationException("RexLambda not supported"); + List paramTypes = + rexLambda.getParameters().stream() + .map(param -> typeConverter.toSubstrait(param.getType())) + .collect(Collectors.toList()); + + Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + + Expression body = rexLambda.getExpression().accept(this); + + return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); } @Override public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) { - throw new UnsupportedOperationException("RexLambdaRef not supported"); + int fieldIndex = rexLambdaRef.getIndex(); + Type paramType = typeConverter.toSubstrait(rexLambdaRef.getType()); + + return FieldReference.builder() + .addSegments(FieldReference.StructField.of(fieldIndex)) + .type(paramType) + .lambdaParameterReferenceStepsOut( + 0) // Always 0 since Calcite doesn't support nested Lambda expressions for now + // https://github.com/substrait-io/substrait-java/issues/711 + .build(); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java new file mode 100644 index 000000000..fb33407b8 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -0,0 +1,77 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.LambdaBuilder; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +class LambdaExpressionTest extends PlanTestBase { + + final Rel emptyTable = sb.emptyVirtualTableScan(); + final LambdaBuilder lb = new LambdaBuilder(); + + // () -> 42 + @Test + void lambdaExpressionZeroParameters() { + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42))); + + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + // (x: i32, y: i64, z: string) -> x + @Test + void validFieldIndex() { + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(0))); + + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + // (x: i32) -> 42 + @Test + void lambdaWithLiteralBody() { + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42))); + + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + // (x: i64) -> (y: i32) -> x — Calcite doesn't support nested lambdas + @Test + void nestedLambdaThrowsUnsupportedOperation() { + Expression.Lambda outerLambda = + lb.lambda(List.of(R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); + + List exprs = new ArrayList<>(); + exprs.add(outerLambda); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); + assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); + } + + // (x: i64)@p -> add(p[0], p[0]) + @Test + void lambdaWithArithmeticBody() { + String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; + + Expression.Lambda lambda = + lb.lambda( + List.of(R.I64), + params -> sb.scalarFn(ARITH, "add:i64_i64", R.I64, params.ref(0), params.ref(0))); + + List exprs = new ArrayList<>(); + exprs.add(lambda); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); + assertFullRoundTrip(project); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java new file mode 100644 index 000000000..427dc67ba --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java @@ -0,0 +1,38 @@ +package io.substrait.isthmus; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.plan.Plan; +import io.substrait.plan.ProtoPlanConverter; +import java.io.IOException; +import org.junit.jupiter.api.Test; + +class LambdaRoundtripTest extends PlanTestBase { + + public static io.substrait.proto.Plan readJsonPlan(String resourcePath) throws IOException { + String json = asString(resourcePath); + io.substrait.proto.Plan.Builder builder = io.substrait.proto.Plan.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + @Test + void testBasicLambdaRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/basic-lambda.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } + + @Test + void testLambdaWithFieldRefRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/lambda-field-ref.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } + + @Test + void testLambdaWithFunctionRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/lambda-with-function.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } +} diff --git a/isthmus/src/test/resources/lambdas/basic-lambda.json b/isthmus/src/test/resources/lambdas/basic-lambda.json new file mode 100644 index 000000000..114e3ad6d --- /dev/null +++ b/isthmus/src/test/resources/lambdas/basic-lambda.json @@ -0,0 +1,132 @@ +{ + "version": { + "majorNumber": 0, + "minorNumber": 79 + }, + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "literal": { + "i32": 42 + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/isthmus/src/test/resources/lambdas/lambda-field-ref.json b/isthmus/src/test/resources/lambdas/lambda-field-ref.json new file mode 100644 index 000000000..58c041582 --- /dev/null +++ b/isthmus/src/test/resources/lambdas/lambda-field-ref.json @@ -0,0 +1,135 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/isthmus/src/test/resources/lambdas/lambda-with-function.json b/isthmus/src/test/resources/lambdas/lambda-with-function.json new file mode 100644 index 000000000..9c0a1a55b --- /dev/null +++ b/isthmus/src/test/resources/lambdas/lambda-with-function.json @@ -0,0 +1,172 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + }, + { + "extensionUrnAnchor": 2, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "multiply:i32_i32" + } + }, + { + "extensionFunction": { + "extensionUrnReference": 2, + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + }, + { + "value": { + "literal": { + "i32": 2 + } + } + } + ] + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala index c280e1fb1..0db03388e 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala @@ -158,4 +158,12 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) @throws[RuntimeException] override def visit(precisionTimestampTZ: Type.PrecisionTimestampTZ): Boolean = typeToMatch.isInstanceOf[Type.PrecisionTimestampTZ] + + @throws[RuntimeException] + override def visit(`type`: Type.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] }