From 5ca6ff8bfa94229ac22546fbc41c5ce6d493dabb Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Tue, 24 Feb 2026 15:03:53 +0100 Subject: [PATCH 01/26] feat: add lambda expression support --- .../expression/AbstractExpressionVisitor.java | 26 + .../io/substrait/expression/Expression.java | 47 ++ .../expression/ExpressionVisitor.java | 20 + .../substrait/expression/FieldReference.java | 18 +- .../proto/ExpressionProtoConverter.java | 28 ++ .../proto/ProtoExpressionConverter.java | 59 +++ .../extension/DefaultExtensionCatalog.java | 3 + .../function/ExtendedTypeCreator.java | 2 + .../substrait/function/ParameterizedType.java | 17 + .../function/ParameterizedTypeCreator.java | 10 + .../function/ParameterizedTypeVisitor.java | 7 + .../io/substrait/function/TypeExpression.java | 16 + .../function/TypeExpressionCreator.java | 10 + .../function/TypeExpressionVisitor.java | 7 + .../ExpressionCopyOnWriteVisitor.java | 29 ++ .../substrait/relation/VirtualTableScan.java | 6 + .../io/substrait/type/StringTypeVisitor.java | 9 + .../src/main/java/io/substrait/type/Type.java | 16 + .../java/io/substrait/type/TypeCreator.java | 13 + .../java/io/substrait/type/TypeVisitor.java | 7 + .../type/proto/BaseProtoConverter.java | 10 + .../substrait/type/proto/BaseProtoTypes.java | 2 + .../proto/ParameterizedProtoConverter.java | 7 + .../type/proto/ProtoTypeConverter.java | 7 + .../proto/TypeExpressionProtoVisitor.java | 7 + .../type/proto/TypeProtoConverter.java | 12 + .../proto/LambdaExpressionRoundtripTest.java | 471 ++++++++++++++++++ .../examples/util/ExpressionStringify.java | 12 + .../examples/util/TypeStringify.java | 5 + .../expression/ExpressionRexConverter.java | 33 ++ .../IgnoreNullableAndParameters.java | 10 + .../expression/RexExpressionConverter.java | 23 +- .../isthmus/LambdaExpressionTest.java | 191 +++++++ .../IgnoreNullableAndParameters.scala | 8 + 34 files changed, 1145 insertions(+), 3 deletions(-) create mode 100644 core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java create mode 100644 isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index e6ebb5782..a61f5d824 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -448,6 +448,32 @@ public O visit(Expression.IfThen expr, C context) throws E { return visitFallback(expr, context); } + /** + * Visits a Lambda expression. + * + * @param expr the Lambda expression + * @param context the visitation context + * @return the visit result + * @throws E if visitation fails + */ + @Override + public O visit(Expression.Lambda expr, C context) throws E { + return visitFallback(expr, context); + } + + /** + * Visits a Lambda expression invocation. + * + * @param expr the Lambda expression invocation + * @param context the visitation context + * @return the visit result + * @throws E if visitation fails + */ + @Override + public O visit(Expression.LambdaInvocation expr, C context) throws E { + return visitFallback(expr, context); + } + /** * Visits a scalar function invocation. * diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 6361af3f5..1301c5c5c 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -707,6 +707,31 @@ public R accept( } } + @Value.Immutable + abstract class Lambda implements Expression { + public abstract Type.Struct parameters(); + + public abstract Expression body(); + + @Override + public Type getType() { + List paramTypes = parameters().fields(); + Type returnType = body().getType(); + + return Type.withNullability(false).func(paramTypes, returnType); + } + + public static ImmutableExpression.Lambda.Builder builder() { + return ImmutableExpression.Lambda.builder(); + } + + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + /** * Base interface for user-defined literals. * @@ -902,6 +927,28 @@ public R accept( } } + @Value.Immutable + abstract class LambdaInvocation implements Expression { + public abstract Lambda lambda(); + + public abstract Expression.NestedStruct arguments(); + + @Override + public Type getType() { + return ((Type.Func) lambda().getType()).returnType(); + } + + public static ImmutableExpression.LambdaInvocation.Builder builder() { + return ImmutableExpression.LambdaInvocation.builder(); + } + + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + @Value.Immutable abstract class ScalarFunctionInvocation implements Expression { public abstract SimpleExtension.ScalarFunctionVariant declaration(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 7f094b688..e36d92e4f 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -311,6 +311,16 @@ public interface ExpressionVisitor outerReferenceStepsOut(); + public abstract Optional lambdaParameterReferenceStepsOut(); + @Override public Type getType() { return type(); @@ -38,13 +40,18 @@ public R accept( public boolean isSimpleRootReference() { return segments().size() == 1 && !inputExpression().isPresent() - && !outerReferenceStepsOut().isPresent(); + && !outerReferenceStepsOut().isPresent() + && !lambdaParameterReferenceStepsOut().isPresent(); } public boolean isOuterReference() { return outerReferenceStepsOut().orElse(0) > 0; } + public boolean isLambdaParameterReference() { + return lambdaParameterReferenceStepsOut().isPresent(); + } + public FieldReference dereferenceStruct(int index) { Type newType = StructFieldFinder.getReferencedType(type(), index); return dereference(newType, StructField.of(index)); @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List rels) { index, currentOffset)); } + public static FieldReference newLambdaParameterReference( + int paramIndex, Type.Struct lambdaParamsType, int stepsOut) { + return ImmutableFieldReference.builder() + .addSegments(StructField.of(paramIndex)) + .type(lambdaParamsType.fields().get(paramIndex)) + .lambdaParameterReferenceStepsOut(stepsOut) + .build(); + } + public interface ReferenceSegment { FieldReference apply(FieldReference reference); diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 498e6eada..45d587d73 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -373,6 +373,18 @@ public Expression visit( }); } + @Override + public Expression visit( + io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return io.substrait.proto.Expression.newBuilder() + .setLambda( + io.substrait.proto.Expression.Lambda.newBuilder() + .setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct()) + .setBody(expr.body().accept(this, context))) + .build(); + } + @Override public Expression visit( io.substrait.expression.Expression.UserDefinedAnyLiteral expr, @@ -468,6 +480,18 @@ public Expression visit( .build(); } + @Override + public Expression visit( + io.substrait.expression.Expression.LambdaInvocation expr, EmptyVisitationContext context) + throws RuntimeException { + return io.substrait.proto.Expression.newBuilder() + .setLambdaInvocation( + io.substrait.proto.Expression.LambdaInvocation.newBuilder() + .setLambda(expr.lambda().accept(this, context).getLambda()) + .setArguments(expr.arguments().accept(this, context).getNested().getStruct())) + .build(); + } + @Override public Expression visit( io.substrait.expression.Expression.ScalarFunctionInvocation expr, @@ -603,6 +627,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { out.setOuterReference( io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder() .setStepsOut(expr.outerReferenceStepsOut().get())); + } else if (expr.lambdaParameterReferenceStepsOut().isPresent()) { + out.setLambdaParameterReference( + io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder() + .setStepsOut(expr.lambdaParameterReferenceStepsOut().get())); } else { out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 01e25c907..fd00e55b2 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -37,6 +37,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; + private final List lambdaParameterStack = new ArrayList<>(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -75,6 +76,26 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getDirectReference().getStructField().getField(), rootType, reference.getOuterReference().getStepsOut()); + case LAMBDA_PARAMETER_REFERENCE: + { + io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef = + reference.getLambdaParameterReference(); + + int stepsOut = lambdaParamRef.getStepsOut(); + int lambdaIndex = lambdaParameterStack.size() - 1 - stepsOut; + if (lambdaIndex < 0 || lambdaIndex >= lambdaParameterStack.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, lambdaParameterStack.size())); + } + + Type.Struct lambdaParameters = lambdaParameterStack.get(lambdaIndex); + return FieldReference.newLambdaParameterReference( + reference.getDirectReference().getStructField().getField(), + lambdaParameters, + stepsOut); + } case ROOTTYPE_NOT_SET: default: throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase()); @@ -260,6 +281,44 @@ public Type visit(Type.Struct type) throws RuntimeException { } } + case LAMBDA: + { + io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); + Type.Struct parameters = + (Type.Struct) + protoTypeConverter.from( + io.substrait.proto.Type.newBuilder() + .setStruct(protoLambda.getParameters()) + .build()); + + lambdaParameterStack.add(parameters); + + Expression body; + try { + body = from(protoLambda.getBody()); + } finally { + lambdaParameterStack.remove(lambdaParameterStack.size() - 1); + } + + return Expression.Lambda.builder().parameters(parameters).body(body).build(); + } + case LAMBDA_INVOCATION: + { + io.substrait.proto.Expression.LambdaInvocation protoInvocation = + expr.getLambdaInvocation(); + + Expression.Lambda lambda = + (Expression.Lambda) + from( + io.substrait.proto.Expression.newBuilder() + .setLambda(protoInvocation.getLambda()) + .build()); + + Expression.NestedStruct arguments = from(protoInvocation.getArguments()); + + return Expression.LambdaInvocation.builder().lambda(lambda).arguments(arguments).build(); + } + // TODO enum. case ENUM: throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 7b316d4be..478ed63e2 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog { /** Extension identifier for set functions. */ public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; + /** Extension identifier for list functions. */ + public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list"; + /** Extension identifier for string functions. */ public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; diff --git a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java index c3360093d..b802d93e0 100644 --- a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator { T listE(T type); T mapE(T key, T value); + + T funcE(Iterable parameterTypes, T returnType); } diff --git a/core/src/main/java/io/substrait/function/ParameterizedType.java b/core/src/main/java/io/substrait/function/ParameterizedType.java index e514fb975..f35c21f7a 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedType.java +++ b/core/src/main/java/io/substrait/function/ParameterizedType.java @@ -200,6 +200,23 @@ R accept(final ParameterizedTypeVisitor parameter } } + @Value.Immutable + abstract class Func extends BaseParameterizedType implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract ParameterizedType returnType(); + + public static ImmutableParameterizedType.Func.Builder builder() { + return ImmutableParameterizedType.Func.builder(); + } + + @Override + R accept(final ParameterizedTypeVisitor parameterizedTypeVisitor) + throws E { + return parameterizedTypeVisitor.visit(this); + } + } + @Value.Immutable abstract class ListType extends BaseParameterizedType implements NullableType { public abstract ParameterizedType name(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java index 6b89840f6..af7bda7e1 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java @@ -96,6 +96,16 @@ public ParameterizedType listE(ParameterizedType type) { return ParameterizedType.ListType.builder().nullable(nullable).name(type).build(); } + @Override + public ParameterizedType funcE( + Iterable parameterTypes, ParameterizedType returnType) { + return ParameterizedType.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + @Override public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) { return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java index 9ff42f549..755c99777 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java @@ -29,6 +29,8 @@ public interface ParameterizedTypeVisitor extends TypeVi R visit(ParameterizedType.StringLiteral stringLiteral) throws E; + R visit(ParameterizedType.Func expr) throws E; + abstract class ParameterizedTypeThrowsVisitor extends TypeVisitor.TypeThrowsVisitor implements ParameterizedTypeVisitor { @@ -100,5 +102,10 @@ public R visit(ParameterizedType.Map expr) throws E { public R visit(ParameterizedType.StringLiteral stringLiteral) throws E { throw t(); } + + @Override + public R visit(ParameterizedType.Func expr) throws E { + throw t(); + } } } diff --git a/core/src/main/java/io/substrait/function/TypeExpression.java b/core/src/main/java/io/substrait/function/TypeExpression.java index a183c1959..345fc0398 100644 --- a/core/src/main/java/io/substrait/function/TypeExpression.java +++ b/core/src/main/java/io/substrait/function/TypeExpression.java @@ -191,6 +191,22 @@ R acceptE(final TypeExpressionVisitor visitor) th } } + @Value.Immutable + abstract class Func extends BaseTypeExpression implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract TypeExpression returnType(); + + public static ImmutableTypeExpression.Func.Builder builder() { + return ImmutableTypeExpression.Func.builder(); + } + + @Override + R acceptE(final TypeExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract class BinaryOperation extends BaseTypeExpression { public enum OpType { diff --git a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java index b7524911b..9d822ffed 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java @@ -82,6 +82,16 @@ public TypeExpression mapE(TypeExpression key, TypeExpression value) { return TypeExpression.Map.builder().nullable(nullable).key(key).value(value).build(); } + @Override + public TypeExpression funcE( + Iterable parameterTypes, TypeExpression returnType) { + return TypeExpression.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public static class Assign { String name; TypeExpression expr; diff --git a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java index 31d632c71..44d871337 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java @@ -24,6 +24,8 @@ public interface TypeExpressionVisitor R visit(TypeExpression.Map expr) throws E; + R visit(TypeExpression.Func expr) throws E; + R visit(TypeExpression.BinaryOperation expr) throws E; R visit(TypeExpression.NotOperation expr) throws E; @@ -97,6 +99,11 @@ public R visit(TypeExpression.Map expr) throws E { throw t(); } + @Override + public R visit(TypeExpression.Func expr) throws E { + throw t(); + } + @Override public R visit(TypeExpression.BinaryOperation expr) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index cdb72aea8..59ebe1257 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -432,6 +432,35 @@ public Optional visit( .build()); } + @Override + public Optional visit(Expression.Lambda lambda, EmptyVisitationContext context) + throws E { + Optional newBody = lambda.body().accept(this, context); + + if (allEmpty(newBody)) { + return Optional.empty(); + } + return Optional.of( + Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); + } + + @Override + public Optional visit( + Expression.LambdaInvocation expr, EmptyVisitationContext context) throws E { + Optional lambda = expr.lambda().accept(this, context); + Optional arguments = expr.arguments().accept(this, context); + + if (allEmpty(lambda, arguments)) { + return Optional.empty(); + } + return Optional.of( + Expression.LambdaInvocation.builder() + .from(expr) + .lambda((Expression.Lambda) lambda.orElse(expr.lambda())) + .arguments((Expression.NestedStruct) arguments.orElse(expr.arguments())) + .build()); + } + // utilities protected Optional> visitExprList( diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index e71b0b00c..de7c29dbb 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -242,5 +242,11 @@ public Integer visit(Type.Map type) throws RuntimeException { public Integer visit(Type.UserDefined type) throws RuntimeException { return 0; } + + @Override + public Integer visit(Type.Func type) throws RuntimeException { + return type.parameterTypes().stream().mapToInt(p -> p.accept(this)).sum() + + type.returnType().accept(this); + } } } diff --git a/core/src/main/java/io/substrait/type/StringTypeVisitor.java b/core/src/main/java/io/substrait/type/StringTypeVisitor.java index d7c196148..e9d711d96 100644 --- a/core/src/main/java/io/substrait/type/StringTypeVisitor.java +++ b/core/src/main/java/io/substrait/type/StringTypeVisitor.java @@ -150,4 +150,13 @@ public String visit(Type.Map type) throws RuntimeException { public String visit(Type.UserDefined type) throws RuntimeException { return String.format("u!%s%s", type.name(), n(type)); } + + @Override + public String visit(Type.Func type) throws RuntimeException { + return String.format( + "func%s<%s -> %s>", + n(type), + type.parameterTypes().stream().map(t -> t.accept(this)).collect(Collectors.joining(", ")), + type.returnType().accept(this)); + } } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 5a1594e59..685cbe25a 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -352,6 +352,22 @@ public R accept(final TypeVisitor typeVisitor) th } } + @Value.Immutable + abstract class Func implements Type { + public abstract java.util.List parameterTypes(); + + public abstract Type returnType(); + + public static ImmutableType.Func.Builder builder() { + return ImmutableType.Func.builder(); + } + + @Override + public R accept(TypeVisitor typeVisitor) throws E { + return typeVisitor.visit(this); + } + } + @Value.Immutable abstract class Struct implements Type { public abstract java.util.List fields(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 43358e505..7e4b1eec4 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -89,6 +89,14 @@ public final Type intervalCompound(int precision) { return Type.IntervalCompound.builder().nullable(nullable).precision(precision).build(); } + public Type.Func func(java.util.List parameterTypes, Type returnType) { + return Type.Func.builder() + .nullable(nullable) + .parameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public Type.Struct struct(Iterable types) { return Type.Struct.builder().nullable(nullable).addAllFields(types).build(); } @@ -258,6 +266,11 @@ public Type visit(Type.Struct type) throws RuntimeException { return Type.Struct.builder().from(type).nullable(nullability).build(); } + @Override + public Type visit(Type.Func type) throws RuntimeException { + return Type.Func.builder().from(type).nullable(nullability).build(); + } + @Override public Type visit(Type.ListType type) throws RuntimeException { return Type.ListType.builder().from(type).nullable(nullability).build(); diff --git a/core/src/main/java/io/substrait/type/TypeVisitor.java b/core/src/main/java/io/substrait/type/TypeVisitor.java index ce6a08910..62e760175 100644 --- a/core/src/main/java/io/substrait/type/TypeVisitor.java +++ b/core/src/main/java/io/substrait/type/TypeVisitor.java @@ -51,6 +51,8 @@ public interface TypeVisitor { R visit(Type.Decimal type) throws E; + R visit(Type.Func type) throws E; + R visit(Type.Struct type) throws E; R visit(Type.ListType type) throws E; @@ -191,6 +193,11 @@ public R visit(Type.PrecisionTimestampTZ type) throws E { throw t(); } + @Override + public R visit(Type.Func type) throws E { + throw t(); + } + @Override public R visit(Type.Struct type) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 67d7bc9b5..3909d4033 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -142,6 +142,16 @@ public final T visit(final Type.PrecisionTimestampTZ expr) { return typeContainer(expr).precisionTimestampTZ(expr.precision()); } + @Override + public final T visit(final Type.Func expr) { + return typeContainer(expr) + .func( + expr.parameterTypes().stream() + .map(t -> t.accept(this)) + .collect(java.util.stream.Collectors.toList()), + expr.returnType().accept(this)); + } + @Override public final T visit(final Type.Struct expr) { return typeContainer(expr) diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 57b1f26b5..47842382f 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -119,6 +119,8 @@ public final T precisionTimestampTZ(int precision) { public abstract T intervalCompound(I precision); + public abstract T func(Iterable parameterTypes, T returnType); + public final T struct(T... types) { return struct(Arrays.asList(types)); } diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 817f7b0b3..f4428ec47 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -256,6 +256,13 @@ public ParameterizedType map(ParameterizedType key, ParameterizedType value) { .build()); } + @Override + public ParameterizedType func( + Iterable parameterTypes, ParameterizedType returnType) { + throw new UnsupportedOperationException( + "Function types are not supported in Parameterized Types - ParameterizedFunc does not exist in proto"); + } + @Override public ParameterizedType userDefined(int ref) { throw new UnsupportedOperationException( diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index bdb600c1c..24231aebc 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -77,6 +77,13 @@ public Type from(io.substrait.proto.Type type) { case PRECISION_TIMESTAMP_TZ: return n(type.getPrecisionTimestampTz().getNullability()) .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); + case FUNC: + return n(type.getFunc().getNullability()) + .func( + type.getFunc().getParameterTypesList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList()), + from(type.getFunc().getReturnType())); case STRUCT: return n(type.getStruct().getNullability()) .struct( diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index f9d2129d2..8c92789cb 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -317,6 +317,13 @@ public DerivationExpression intervalCompound(DerivationExpression precision) { .build()); } + @Override + public DerivationExpression func( + Iterable parameterTypes, DerivationExpression returnType) { + throw new UnsupportedOperationException( + "User defined types are not supported in Derivation Expressions for now"); + } + @Override public DerivationExpression struct(Iterable types) { return wrap( diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 6422904c4..c0e785db4 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -154,6 +154,16 @@ public Type precisionTimestampTZ(Integer precision) { .build()); } + @Override + public Type func(Iterable parameterTypes, Type returnType) { + return wrap( + Type.Func.newBuilder() + .addAllParameterTypes(parameterTypes) + .setReturnType(returnType) + .setNullability(nullability) + .build()); + } + @Override public Type struct(Iterable types) { return wrap(Type.Struct.newBuilder().addAllTypes(types).setNullability(nullability).build()); @@ -237,6 +247,8 @@ protected Type wrap(final Object o) { return bldr.setPrecisionTimestamp((Type.PrecisionTimestamp) o).build(); } else if (o instanceof Type.PrecisionTimestampTZ) { return bldr.setPrecisionTimestampTz((Type.PrecisionTimestampTZ) o).build(); + } else if (o instanceof Type.Func) { + return bldr.setFunc((Type.Func) o).build(); } else if (o instanceof Type.Struct) { return bldr.setStruct((Type.Struct) o).build(); } else if (o instanceof Type.List) { diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java new file mode 100644 index 000000000..537b7aaf0 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -0,0 +1,471 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** + * Tests for Lambda expression round-trip conversion through protobuf. + * Based on equivalent tests from substrait-go. + */ +class LambdaExpressionRoundtripTest extends TestBase { + + /** + * Test that lambdas with no parameters are valid. + * Building: () -> i32(42) : func<() -> i32> + */ + @Test + void zeroParameterLambda() { + Type.Struct emptyParams = Type.Struct.builder() + .nullable(false) + .build(); + + Expression body = ExpressionCreator.i32(false, 42); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(emptyParams) + .body(body) + .build(); + + verifyRoundTrip(lambda); + + // Verify the lambda type + Type lambdaType = lambda.getType(); + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + assertEquals(0, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.returnType()); + } + + /** + * Test valid stepsOut=0 references. + * Building: ($0: i32) -> $0 : func i32> + */ + @Test + void validStepsOut0() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Lambda body references parameter 0 with stepsOut=0 + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + verifyRoundTrip(lambda); + + // Verify types + Type lambdaType = lambda.getType(); + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + assertEquals(1, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.I32, funcType.returnType()); + } + + /** + * Test valid field index with multiple parameters. + * Building: ($0: i32, $1: i64, $2: string) -> $2 : func<(i32, i64, string) -> string> + */ + @Test + void validFieldIndex() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32, R.I64, R.STRING) + .build(); + + // Reference the 3rd parameter (string) + FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + verifyRoundTrip(lambda); + + // Verify return type is string + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(R.STRING, funcType.returnType()); + } + + /** + * Test type resolution for different parameter types. + */ + @Test + void typeResolution() { + // Test cases: (paramTypes, fieldIndex, expectedReturnType) + record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} + + List testCases = List.of( + new TestCase(List.of(R.I32), 0, R.I32), + new TestCase(List.of(R.I32, R.I64), 1, R.I64), + new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase(List.of(R.FP64), 0, R.FP64), + new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE) + ); + + for (TestCase tc : testCases) { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addAllFields(tc.paramTypes) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + verifyRoundTrip(lambda); + + // Verify the body type matches expected + assertEquals(tc.expectedType, lambda.body().getType(), + "Body type should match referenced parameter type"); + + // Verify lambda return type + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(tc.expectedType, funcType.returnType(), + "Lambda return type should match body type"); + } + } + + /** + * Test nested lambda with outer reference. + * Building: ($0: i64, $1: i64) -> (($0: i32) -> outer[$0] : i64) : func<(i64, i64) -> func i64>> + */ + @Test + void nestedLambdaWithOuterRef() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I64, R.I64) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Inner lambda references outer's parameter 0 with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = Expression.Lambda.builder() + .parameters(innerParams) + .body(outerRef) + .build(); + + Expression.Lambda outerLambda = Expression.Lambda.builder() + .parameters(outerParams) + .body(innerLambda) + .build(); + + verifyRoundTrip(outerLambda); + + // Verify structure + assertInstanceOf(Expression.Lambda.class, outerLambda.body()); + Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); + assertEquals(1, resultInner.parameters().fields().size()); + } + + /** + * Test outer reference type resolution in nested lambdas. + * Building: ($0: i32, $1: i64, $2: string) -> (($0: fp64) -> outer[$2] : string) : func<...> + */ + @Test + void outerRefTypeResolution() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32, R.I64, R.STRING) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.FP64) + .build(); + + // Inner references outer's field 2 (string) with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); + + Expression.Lambda innerLambda = Expression.Lambda.builder() + .parameters(innerParams) + .body(outerRef) + .build(); + + Expression.Lambda outerLambda = Expression.Lambda.builder() + .parameters(outerParams) + .body(innerLambda) + .build(); + + verifyRoundTrip(outerLambda); + + // Verify inner lambda's return type is string (from outer param 2) + Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); + Type.Func innerFuncType = (Type.Func) resultInner.getType(); + assertEquals(R.STRING, innerFuncType.returnType(), + "Inner lambda return type should be string from outer.$2"); + + // Verify body's type is also string + assertEquals(R.STRING, resultInner.body().getType(), + "Body type should be string"); + } + + /** + * Test deeply nested field ref inside Cast. + * Building: ($0: i32) -> cast($0 as i64) : func i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(castExpr) + .build(); + + verifyRoundTrip(lambda); + + // Verify the nested FieldRef has its type resolved + Expression.Cast resultCast = (Expression.Cast) lambda.body(); + assertInstanceOf(FieldReference.class, resultCast.input()); + FieldReference resultFieldRef = (FieldReference) resultCast.input(); + + assertNotNull(resultFieldRef.getType(), "Nested FieldRef should have type resolved"); + assertEquals(R.I32, resultFieldRef.getType(), "Should resolve to i32"); + + // Verify lambda return type is i64 (cast output) + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(R.I64, funcType.returnType()); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). + * Building: ($0: i32) -> cast(cast($0 as i64) as string) : func string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(outerCast) + .build(); + + verifyRoundTrip(lambda); + + // Navigate to the deeply nested FieldRef (2 levels deep) + Expression.Cast resultOuter = (Expression.Cast) lambda.body(); + Expression.Cast resultInner = (Expression.Cast) resultOuter.input(); + FieldReference resultFieldRef = (FieldReference) resultInner.input(); + + // Verify type is resolved even at depth 2 + assertNotNull(resultFieldRef.getType(), "FieldRef at depth 2 should have type resolved"); + assertEquals(R.I32, resultFieldRef.getType()); + } + + /** + * Test lambda with literal body (no parameter references). + * Building: ($0: i32) -> 42 : func i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + Expression body = ExpressionCreator.i32(false, 42); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(body) + .build(); + + verifyRoundTrip(lambda); + } + + /** + * Test lambda getType returns correct Func type. + */ + @Test + void lambdaGetTypeReturnsFunc() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32, R.STRING) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + Type lambdaType = lambda.getType(); + + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + + assertEquals(2, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.STRING, funcType.parameterTypes().get(1)); + assertEquals(R.STRING, funcType.returnType()); // body references param 1 which is STRING + } + + // ==================== Validation Error Tests ==================== + + /** + * Test that invalid outer reference (stepsOut too high) fails during proto conversion. + * Building: ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) + */ + @Test + void invalidOuterRef_stepsOutTooHigh() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Create a parameter reference with stepsOut=1 but no outer lambda exists + FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(invalidRef) + .build(); + + // Convert to proto - this should work + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); + + // Converting back should fail because stepsOut=1 references non-existent outer lambda + assertThrows(IllegalArgumentException.class, () -> { + protoExpressionConverter.from(protoExpression); + }, "Should fail when stepsOut references non-existent outer lambda"); + } + + /** + * Test that invalid field index (out of bounds) fails during proto conversion. + * Building: ($0: i32) -> $5 : INVALID (only has 1 param) + */ + @Test + void invalidFieldIndex_outOfBounds() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Create a reference to field 5, but lambda only has 1 parameter (index 0) + // This will fail at build time since newLambdaParameterReference accesses fields.get(5) + assertThrows(IndexOutOfBoundsException.class, () -> { + FieldReference.newLambdaParameterReference(5, params, 0); + }, "Should fail when field index is out of bounds"); + } + + /** + * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). + * Building: ($0: i64) -> (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) + */ + @Test + void nestedInvalidOuterRef() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I64) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Inner lambda references stepsOut=2, but only 1 outer lambda exists + FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); + + Expression.Lambda innerLambda = Expression.Lambda.builder() + .parameters(innerParams) + .body(invalidRef) + .build(); + + Expression.Lambda outerLambda = Expression.Lambda.builder() + .parameters(outerParams) + .body(innerLambda) + .build(); + + // Convert to proto + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); + + // Converting back should fail because stepsOut=2 references non-existent grandparent + assertThrows(IllegalArgumentException.class, () -> { + protoExpressionConverter.from(protoExpression); + }, "Should fail when stepsOut references non-existent grandparent lambda"); + } + + /** + * Test deeply nested invalid field ref inside Cast. + * Building: ($0: i32) -> cast($5 as i64) : INVALID (only has 1 param) + */ + @Test + void deeplyNestedInvalidFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // This should fail at build time since field 5 doesn't exist + assertThrows(IndexOutOfBoundsException.class, () -> { + FieldReference.newLambdaParameterReference(5, params, 0); + }, "Should fail when nested field index is out of bounds"); + } + + /** + * Test that outer field index out of bounds fails. + * Building: ($0: i64) -> (($0: i32) -> outer[$5]) : INVALID (outer only has 1 param) + */ + @Test + void nestedInvalidOuterFieldIndex() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I64) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // This should fail at build time since outer only has 1 parameter (index 0) + assertThrows(IndexOutOfBoundsException.class, () -> { + FieldReference.newLambdaParameterReference(5, outerParams, 1); + }, "Should fail when outer field index is out of bounds"); + } +} diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index a26ec963e..f9125587d 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -195,6 +195,12 @@ public String visit(Expression.NestedStruct expr, EmptyVisitationContext context return ""; } + @Override + public String visit(Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; + } + @Override public String visit(UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws RuntimeException { @@ -217,6 +223,12 @@ public String visit(IfThen expr, EmptyVisitationContext context) throws RuntimeE return ""; } + @Override + public String visit(Expression.LambdaInvocation expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; + } + @Override public String visit(ScalarFunctionInvocation expr, EmptyVisitationContext context) throws RuntimeException { diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java index 0e13c3e2e..f5f8d93e9 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java @@ -149,6 +149,11 @@ public String visit(Decimal type) throws RuntimeException { return type.getClass().getSimpleName(); } + @Override + public String visit(Type.Func type) throws RuntimeException { + return type.getClass().getSimpleName(); + } + @Override public String visit(Struct type) throws RuntimeException { StringBuffer sb = new StringBuffer(type.getClass().getSimpleName()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index e06f66e8b..0f31d3d4b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -41,6 +41,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexSubQuery; @@ -381,6 +382,23 @@ public RexNode visit(Expression.IfThen expr, Context context) throws RuntimeExce return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } + @Override + public RexNode visit(Expression.Lambda expr, Context context) throws RuntimeException { + List parameters = + IntStream.range(0, expr.parameters().fields().size()) + .mapToObj( + i -> + new RexLambdaRef( + i, + "p" + i, + typeConverter.toCalcite(typeFactory, expr.parameters().fields().get(i)))) + .collect(Collectors.toList()); + + RexNode body = expr.body().accept(this, context); + + return rexBuilder.makeLambdaCall(body, parameters); + } + @Override public RexNode visit(Switch expr, Context context) throws RuntimeException { RexNode match = expr.match().accept(this, context); @@ -655,6 +673,21 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti } return rexInputRef; + } else if (expr.isLambdaParameterReference()) { + int stepsOut = expr.lambdaParameterReferenceStepsOut().get(); + if (stepsOut != 0) { + throw new UnsupportedOperationException( + "Calcite does not support nested lambdas (stepsOut=" + stepsOut + ")"); + } + + final ReferenceSegment segment = expr.segments().get(0); + if (segment instanceof FieldReference.StructField) { + final FieldReference.StructField field = (FieldReference.StructField) segment; + RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); + return new RexLambdaRef(field.offset(), "p" + field.offset(), calciteType); + } else { + throw new IllegalArgumentException("Unhandled type: " + segment); + } } return visitFallback(expr, context); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java index f8b4be1dd..b56fa4fd3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java @@ -128,6 +128,11 @@ public Boolean visit(Type.Decimal type) { return typeToMatch instanceof Type.Decimal || typeToMatch instanceof ParameterizedType.Decimal; } + @Override + public Boolean visit(Type.Func type) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } + @Override public Boolean visit(Type.PrecisionTime type) { return typeToMatch instanceof Type.PrecisionTime @@ -234,4 +239,9 @@ public Boolean visit(ParameterizedType.Map expr) throws RuntimeException { public Boolean visit(ParameterizedType.StringLiteral stringLiteral) throws RuntimeException { return false; } + + @Override + public Boolean visit(ParameterizedType.Func expr) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 6993c8451..19b3d2224 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -7,6 +7,7 @@ import io.substrait.isthmus.TypeConverter; import io.substrait.relation.Rel; import io.substrait.type.StringTypeVisitor; +import io.substrait.type.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -202,12 +203,30 @@ public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) { @Override public Expression visitLambda(RexLambda rexLambda) { - throw new UnsupportedOperationException("RexLambda not supported"); + List paramTypes = + rexLambda.getParameters().stream() + .map(param -> typeConverter.toSubstrait(param.getType())) + .collect(Collectors.toList()); + + Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + + + Expression body = rexLambda.getExpression().accept(this); + + return Expression.Lambda.builder().parameters(parameters).body(body).build(); } @Override public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) { - throw new UnsupportedOperationException("RexLambdaRef not supported"); + int fieldIndex = rexLambdaRef.getIndex(); + Type paramType = typeConverter.toSubstrait(rexLambdaRef.getType()); + + return FieldReference.builder() + .addSegments(FieldReference.StructField.of(fieldIndex)) + .type(paramType) + .lambdaParameterReferenceStepsOut( + 0) // Always 0 since Calcite doesn't support nested Lambda expressions + .build(); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java new file mode 100644 index 000000000..8b391a01d --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -0,0 +1,191 @@ +package io.substrait.isthmus; + + +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Tests for Lambda expression conversion between Substrait and Calcite. + * Note: Calcite does not support nested lambda expressions for the moment, so all tests use stepsOut=0. + */ + +public class LambdaExpressionTest extends PlanTestBase{ + + final Rel emptyTable = sb.emptyVirtualTableScan(); + + /** + * Test that lambdas with no parameters are valid. + * Building: () -> i32(42) : func<() -> i32> + */ + @Test + void lambdaExpressionZeroParameters(){ + Type.Struct params = Type.Struct.builder() + .nullable(false) + .build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(body) + .build(); + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + + } + + /** + * Test valid field index with multiple parameters. + * Building: ($0: i32, $1: i64, $2: string) -> $2 : func<(i32, i64, string) -> string> + */ + @Test + void validFieldIndex() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32, R.I64, R.STRING) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + expressionList.add(lambda); + + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test deeply nested field ref inside Cast. + * Building: ($0: i32) -> cast($0 as i64) : func i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + List expressionList = new ArrayList<>(); + + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(castExpr) + .build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). + * Building: ($0: i32) -> cast(cast($0 as i64) as string) : func string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(outerCast) + .build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test lambda with literal body (no parameter references). + * Building: ($0: i32) -> 42 : func i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(body) + .build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. + * Calcite does not support nested lambda expressions. + */ + @Test + void nestedLambdaThrowsUnsupportedOperation() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I64) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Inner lambda references outer's parameter with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = Expression.Lambda.builder() + .parameters(innerParams) + .body(outerRef) + .build(); + + List expressionList = new ArrayList<>(); + + Expression.Lambda outerLambda = Expression.Lambda.builder() + .parameters(outerParams) + .body(innerLambda) + .build(); + + expressionList.add(outerLambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); + + } + } diff --git a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala index 9cb38b8d0..acf006215 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala @@ -157,4 +157,12 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) @throws[RuntimeException] override def visit(precisionTimestampTZ: Type.PrecisionTimestampTZ): Boolean = typeToMatch.isInstanceOf[Type.PrecisionTimestampTZ] + + @throws[RuntimeException] + override def visit(`type`: Type.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] } From a90c6099aad6b50b7178ec42e19f7f5a8421a518 Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Wed, 25 Feb 2026 12:06:54 +0100 Subject: [PATCH 02/26] adresse some of @benbellick's comments --- .../expression/AbstractExpressionVisitor.java | 13 - .../io/substrait/expression/Expression.java | 22 -- .../expression/ExpressionVisitor.java | 10 - .../proto/ExpressionProtoConverter.java | 12 - .../proto/ProtoExpressionConverter.java | 17 - .../ExpressionCopyOnWriteVisitor.java | 17 - .../proto/LambdaExpressionRoundtripTest.java | 344 ++++++------------ .../examples/util/ExpressionStringify.java | 6 - .../expression/ExpressionRexConverter.java | 2 + .../expression/RexExpressionConverter.java | 8 +- .../isthmus/LambdaExpressionTest.java | 312 +++++++--------- 11 files changed, 258 insertions(+), 505 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index a61f5d824..9b1fdc496 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -461,19 +461,6 @@ public O visit(Expression.Lambda expr, C context) throws E { return visitFallback(expr, context); } - /** - * Visits a Lambda expression invocation. - * - * @param expr the Lambda expression invocation - * @param context the visitation context - * @return the visit result - * @throws E if visitation fails - */ - @Override - public O visit(Expression.LambdaInvocation expr, C context) throws E { - return visitFallback(expr, context); - } - /** * Visits a scalar function invocation. * diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 1301c5c5c..320a8a614 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -927,28 +927,6 @@ public R accept( } } - @Value.Immutable - abstract class LambdaInvocation implements Expression { - public abstract Lambda lambda(); - - public abstract Expression.NestedStruct arguments(); - - @Override - public Type getType() { - return ((Type.Func) lambda().getType()).returnType(); - } - - public static ImmutableExpression.LambdaInvocation.Builder builder() { - return ImmutableExpression.LambdaInvocation.builder(); - } - - @Override - public R accept( - ExpressionVisitor visitor, C context) throws E { - return visitor.visit(this, context); - } - } - @Value.Immutable abstract class ScalarFunctionInvocation implements Expression { public abstract SimpleExtension.ScalarFunctionVariant declaration(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index e36d92e4f..a505af778 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -361,16 +361,6 @@ public interface ExpressionVisitor visit(Expression.Lambda lambda, EmptyVisitationConte Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); } - @Override - public Optional visit( - Expression.LambdaInvocation expr, EmptyVisitationContext context) throws E { - Optional lambda = expr.lambda().accept(this, context); - Optional arguments = expr.arguments().accept(this, context); - - if (allEmpty(lambda, arguments)) { - return Optional.empty(); - } - return Optional.of( - Expression.LambdaInvocation.builder() - .from(expr) - .lambda((Expression.Lambda) lambda.orElse(expr.lambda())) - .arguments((Expression.NestedStruct) arguments.orElse(expr.arguments())) - .build()); - } - // utilities protected Optional> visitExprList( diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java index 537b7aaf0..c0a979a94 100644 --- a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -4,39 +4,30 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import java.util.List; import org.junit.jupiter.api.Test; /** - * Tests for Lambda expression round-trip conversion through protobuf. - * Based on equivalent tests from substrait-go. + * Tests for Lambda expression round-trip conversion through protobuf. Based on equivalent tests + * from substrait-go. */ class LambdaExpressionRoundtripTest extends TestBase { - /** - * Test that lambdas with no parameters are valid. - * Building: () -> i32(42) : func<() -> i32> - */ + /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ @Test void zeroParameterLambda() { - Type.Struct emptyParams = Type.Struct.builder() - .nullable(false) - .build(); + Type.Struct emptyParams = Type.Struct.builder().nullable(false).build(); Expression body = ExpressionCreator.i32(false, 42); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(emptyParams) - .body(body) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(emptyParams).body(body).build(); verifyRoundTrip(lambda); @@ -48,24 +39,16 @@ void zeroParameterLambda() { assertEquals(R.I32, funcType.returnType()); } - /** - * Test valid stepsOut=0 references. - * Building: ($0: i32) -> $0 : func i32> - */ + /** Test valid stepsOut=0 references. Building: ($0: i32) -> $0 : func i32> */ @Test void validStepsOut0() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Lambda body references parameter 0 with stepsOut=0 FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); verifyRoundTrip(lambda); @@ -79,23 +62,19 @@ void validStepsOut0() { } /** - * Test valid field index with multiple parameters. - * Building: ($0: i32, $1: i64, $2: string) -> $2 : func<(i32, i64, string) -> string> + * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 + * : func<(i32, i64, string) -> string> */ @Test void validFieldIndex() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32, R.I64, R.STRING) - .build(); + Type.Struct params = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); // Reference the 3rd parameter (string) FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); verifyRoundTrip(lambda); @@ -104,76 +83,63 @@ void validFieldIndex() { assertEquals(R.STRING, funcType.returnType()); } - /** - * Test type resolution for different parameter types. - */ + /** Test type resolution for different parameter types. */ @Test void typeResolution() { // Test cases: (paramTypes, fieldIndex, expectedReturnType) record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} - List testCases = List.of( - new TestCase(List.of(R.I32), 0, R.I32), - new TestCase(List.of(R.I32, R.I64), 1, R.I64), - new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), - new TestCase(List.of(R.FP64), 0, R.FP64), - new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE) - ); + List testCases = + List.of( + new TestCase(List.of(R.I32), 0, R.I32), + new TestCase(List.of(R.I32, R.I64), 1, R.I64), + new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase(List.of(R.FP64), 0, R.FP64), + new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); for (TestCase tc : testCases) { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addAllFields(tc.paramTypes) - .build(); + Type.Struct params = + Type.Struct.builder().nullable(false).addAllFields(tc.paramTypes).build(); - FieldReference paramRef = FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); + FieldReference paramRef = + FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); verifyRoundTrip(lambda); // Verify the body type matches expected - assertEquals(tc.expectedType, lambda.body().getType(), + assertEquals( + tc.expectedType, + lambda.body().getType(), "Body type should match referenced parameter type"); // Verify lambda return type Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(tc.expectedType, funcType.returnType(), - "Lambda return type should match body type"); + assertEquals( + tc.expectedType, funcType.returnType(), "Lambda return type should match body type"); } } /** - * Test nested lambda with outer reference. - * Building: ($0: i64, $1: i64) -> (($0: i32) -> outer[$0] : i64) : func<(i64, i64) -> func i64>> + * Test nested lambda with outer reference. Building: ($0: i64, $1: i64) -> (($0: i32) -> + * outer[$0] : i64) : func<(i64, i64) -> func i64>> */ @Test void nestedLambdaWithOuterRef() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I64, R.I64) - .build(); + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64, R.I64).build(); - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Inner lambda references outer's parameter 0 with stepsOut=1 FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - Expression.Lambda innerLambda = Expression.Lambda.builder() - .parameters(innerParams) - .body(outerRef) - .build(); + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - Expression.Lambda outerLambda = Expression.Lambda.builder() - .parameters(outerParams) - .body(innerLambda) - .build(); + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); verifyRoundTrip(outerLambda); @@ -184,67 +150,55 @@ void nestedLambdaWithOuterRef() { } /** - * Test outer reference type resolution in nested lambdas. - * Building: ($0: i32, $1: i64, $2: string) -> (($0: fp64) -> outer[$2] : string) : func<...> + * Test outer reference type resolution in nested lambdas. Building: ($0: i32, $1: i64, $2: + * string) -> (($0: fp64) -> outer[$2] : string) : func<...> */ @Test void outerRefTypeResolution() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32, R.I64, R.STRING) - .build(); + Type.Struct outerParams = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.FP64) - .build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.FP64).build(); // Inner references outer's field 2 (string) with stepsOut=1 FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); - Expression.Lambda innerLambda = Expression.Lambda.builder() - .parameters(innerParams) - .body(outerRef) - .build(); + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - Expression.Lambda outerLambda = Expression.Lambda.builder() - .parameters(outerParams) - .body(innerLambda) - .build(); + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); verifyRoundTrip(outerLambda); // Verify inner lambda's return type is string (from outer param 2) Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); Type.Func innerFuncType = (Type.Func) resultInner.getType(); - assertEquals(R.STRING, innerFuncType.returnType(), + assertEquals( + R.STRING, + innerFuncType.returnType(), "Inner lambda return type should be string from outer.$2"); // Verify body's type is also string - assertEquals(R.STRING, resultInner.body().getType(), - "Body type should be string"); + assertEquals(R.STRING, resultInner.body().getType(), "Body type should be string"); } /** - * Test deeply nested field ref inside Cast. - * Building: ($0: i32) -> cast($0 as i64) : func i64> + * Test deeply nested field ref inside Cast. Building: ($0: i32) -> cast($0 as i64) : func + * i64> */ @Test void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast castExpr = (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast castExpr = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(castExpr) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(castExpr).build(); verifyRoundTrip(lambda); @@ -262,26 +216,23 @@ void deeplyNestedFieldRef() { } /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). - * Building: ($0: i32) -> cast(cast($0 as i64) as string) : func string> + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 + * as i64) as string) : func string> */ @Test void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast innerCast = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = + (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(outerCast) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(outerCast).build(); verifyRoundTrip(lambda); @@ -296,42 +247,29 @@ void doublyNestedFieldRef() { } /** - * Test lambda with literal body (no parameter references). - * Building: ($0: i32) -> 42 : func i32> + * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> */ @Test void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); Expression body = ExpressionCreator.i32(false, 42); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(body) - .build(); + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); verifyRoundTrip(lambda); } - /** - * Test lambda getType returns correct Func type. - */ + /** Test lambda getType returns correct Func type. */ @Test void lambdaGetTypeReturnsFunc() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32, R.STRING) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32, R.STRING).build(); FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); Type lambdaType = lambda.getType(); @@ -347,125 +285,77 @@ void lambdaGetTypeReturnsFunc() { // ==================== Validation Error Tests ==================== /** - * Test that invalid outer reference (stepsOut too high) fails during proto conversion. - * Building: ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) + * Test that invalid outer reference (stepsOut too high) fails during proto conversion. Building: + * ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) */ @Test void invalidOuterRef_stepsOutTooHigh() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Create a parameter reference with stepsOut=1 but no outer lambda exists FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(invalidRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(invalidRef).build(); // Convert to proto - this should work io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); // Converting back should fail because stepsOut=1 references non-existent outer lambda - assertThrows(IllegalArgumentException.class, () -> { - protoExpressionConverter.from(protoExpression); - }, "Should fail when stepsOut references non-existent outer lambda"); + assertThrows( + IllegalArgumentException.class, + () -> { + protoExpressionConverter.from(protoExpression); + }, + "Should fail when stepsOut references non-existent outer lambda"); } /** - * Test that invalid field index (out of bounds) fails during proto conversion. - * Building: ($0: i32) -> $5 : INVALID (only has 1 param) + * Test that invalid field index (out of bounds) fails during proto conversion. Building: ($0: + * i32) -> $5 : INVALID (only has 1 param) */ @Test void invalidFieldIndex_outOfBounds() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Create a reference to field 5, but lambda only has 1 parameter (index 0) // This will fail at build time since newLambdaParameterReference accesses fields.get(5) - assertThrows(IndexOutOfBoundsException.class, () -> { - FieldReference.newLambdaParameterReference(5, params, 0); - }, "Should fail when field index is out of bounds"); + assertThrows( + IndexOutOfBoundsException.class, + () -> { + FieldReference.newLambdaParameterReference(5, params, 0); + }, + "Should fail when field index is out of bounds"); } /** - * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). - * Building: ($0: i64) -> (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) + * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). Building: ($0: i64) -> + * (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) */ @Test void nestedInvalidOuterRef() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I64) - .build(); + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Inner lambda references stepsOut=2, but only 1 outer lambda exists FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); - Expression.Lambda innerLambda = Expression.Lambda.builder() - .parameters(innerParams) - .body(invalidRef) - .build(); + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(invalidRef).build(); - Expression.Lambda outerLambda = Expression.Lambda.builder() - .parameters(outerParams) - .body(innerLambda) - .build(); + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); // Convert to proto io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); // Converting back should fail because stepsOut=2 references non-existent grandparent - assertThrows(IllegalArgumentException.class, () -> { - protoExpressionConverter.from(protoExpression); - }, "Should fail when stepsOut references non-existent grandparent lambda"); - } - - /** - * Test deeply nested invalid field ref inside Cast. - * Building: ($0: i32) -> cast($5 as i64) : INVALID (only has 1 param) - */ - @Test - void deeplyNestedInvalidFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - // This should fail at build time since field 5 doesn't exist - assertThrows(IndexOutOfBoundsException.class, () -> { - FieldReference.newLambdaParameterReference(5, params, 0); - }, "Should fail when nested field index is out of bounds"); - } - - /** - * Test that outer field index out of bounds fails. - * Building: ($0: i64) -> (($0: i32) -> outer[$5]) : INVALID (outer only has 1 param) - */ - @Test - void nestedInvalidOuterFieldIndex() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I64) - .build(); - - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - // This should fail at build time since outer only has 1 parameter (index 0) - assertThrows(IndexOutOfBoundsException.class, () -> { - FieldReference.newLambdaParameterReference(5, outerParams, 1); - }, "Should fail when outer field index is out of bounds"); + assertThrows( + IllegalArgumentException.class, + () -> { + protoExpressionConverter.from(protoExpression); + }, + "Should fail when stepsOut references non-existent grandparent lambda"); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index f9125587d..93f8a9834 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -223,12 +223,6 @@ public String visit(IfThen expr, EmptyVisitationContext context) throws RuntimeE return ""; } - @Override - public String visit(Expression.LambdaInvocation expr, EmptyVisitationContext context) - throws RuntimeException { - return ""; - } - @Override public String visit(ScalarFunctionInvocation expr, EmptyVisitationContext context) throws RuntimeException { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 0f31d3d4b..a118beec0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -674,6 +674,8 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti return rexInputRef; } else if (expr.isLambdaParameterReference()) { + // as of now calcite doesn't support nested lambda functions + // https://github.com/substrait-io/substrait-java/issues/711 int stepsOut = expr.lambdaParameterReferenceStepsOut().get(); if (stepsOut != 0) { throw new UnsupportedOperationException( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 19b3d2224..176b246e7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -208,10 +208,9 @@ public Expression visitLambda(RexLambda rexLambda) { .map(param -> typeConverter.toSubstrait(param.getType())) .collect(Collectors.toList()); - Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); - - Expression body = rexLambda.getExpression().accept(this); + Expression body = rexLambda.getExpression().accept(this); return Expression.Lambda.builder().parameters(parameters).body(body).build(); } @@ -225,7 +224,8 @@ public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) { .addSegments(FieldReference.StructField.of(fieldIndex)) .type(paramType) .lambdaParameterReferenceStepsOut( - 0) // Always 0 since Calcite doesn't support nested Lambda expressions + 0) // Always 0 since Calcite doesn't support nested Lambda expressions for now + // https://github.com/substrait-io/substrait-java/issues/711 .build(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index 8b391a01d..add746ee1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; @@ -7,185 +8,142 @@ import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.type.Type; -import org.junit.jupiter.api.Test; - import java.util.ArrayList; import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertThrows; +import org.junit.jupiter.api.Test; /** - * Tests for Lambda expression conversion between Substrait and Calcite. - * Note: Calcite does not support nested lambda expressions for the moment, so all tests use stepsOut=0. + * Tests for Lambda expression conversion between Substrait and Calcite. Note: Calcite does not + * support nested lambda expressions for the moment, so all tests use stepsOut=0. */ - -public class LambdaExpressionTest extends PlanTestBase{ - - final Rel emptyTable = sb.emptyVirtualTableScan(); - - /** - * Test that lambdas with no parameters are valid. - * Building: () -> i32(42) : func<() -> i32> - */ - @Test - void lambdaExpressionZeroParameters(){ - Type.Struct params = Type.Struct.builder() - .nullable(false) - .build(); - - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(body) - .build(); - expressionList.add(lambda); - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - - } - - /** - * Test valid field index with multiple parameters. - * Building: ($0: i32, $1: i64, $2: string) -> $2 : func<(i32, i64, string) -> string> - */ - @Test - void validFieldIndex() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32, R.I64, R.STRING) - .build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); - - expressionList.add(lambda); - - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test deeply nested field ref inside Cast. - * Building: ($0: i32) -> cast($0 as i64) : func i64> - */ - @Test - void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - Expression.Cast castExpr = (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - List expressionList = new ArrayList<>(); - - - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(castExpr) - .build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). - * Building: ($0: i32) -> cast(cast($0 as i64) as string) : func string> - */ - @Test - void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); - - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(outerCast) - .build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test lambda with literal body (no parameter references). - * Building: ($0: i32) -> 42 : func i32> - */ - @Test - void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(body) - .build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. - * Calcite does not support nested lambda expressions. - */ - @Test - void nestedLambdaThrowsUnsupportedOperation() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I64) - .build(); - - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - // Inner lambda references outer's parameter with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - - Expression.Lambda innerLambda = Expression.Lambda.builder() - .parameters(innerParams) - .body(outerRef) - .build(); - - List expressionList = new ArrayList<>(); - - Expression.Lambda outerLambda = Expression.Lambda.builder() - .parameters(outerParams) - .body(innerLambda) - .build(); - - expressionList.add(outerLambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); - - } - } +class LambdaExpressionTest extends PlanTestBase { + + final Rel emptyTable = sb.emptyVirtualTableScan(); + + /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ + @Test + void lambdaExpressionZeroParameters() { + Type.Struct params = Type.Struct.builder().nullable(false).build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 + * : func<(i32, i64, string) -> string> + */ + @Test + void validFieldIndex() { + Type.Struct params = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test deeply nested field ref inside Cast. Building: ($0: i32) -> cast($0 as i64) : func + * i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(castExpr).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 + * as i64) as string) : func string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = + (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(outerCast).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. Calcite does not + * support nested lambda expressions. + */ + @Test + void nestedLambdaThrowsUnsupportedOperation() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Inner lambda references outer's parameter with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); + + List expressionList = new ArrayList<>(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + expressionList.add(outerLambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); + } +} From 40533570d4ef9b3f1765a315f72adaee5ed98e4a Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Mon, 2 Mar 2026 14:38:33 +0100 Subject: [PATCH 03/26] tweak: encapsulate LambdaParameterStack logic in a class --- .../io/substrait/expression/Expression.java | 2 + .../proto/ProtoExpressionConverter.java | 41 ++++++++++++++----- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 320a8a614..00a92d797 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -718,6 +718,8 @@ public Type getType() { List paramTypes = parameters().fields(); Type returnType = body().getType(); + // TO DO: fix Lambda return type once this issue + // https://github.com/substrait-io/substrait/issues/976 is resolved return Type.withNullability(false).func(paramTypes, returnType); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 3780a8e1b..09ee2da59 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -37,7 +37,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; - private final List lambdaParameterStack = new ArrayList<>(); + private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -82,15 +82,8 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getLambdaParameterReference(); int stepsOut = lambdaParamRef.getStepsOut(); - int lambdaIndex = lambdaParameterStack.size() - 1 - stepsOut; - if (lambdaIndex < 0 || lambdaIndex >= lambdaParameterStack.size()) { - throw new IllegalArgumentException( - String.format( - "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", - stepsOut, lambdaParameterStack.size())); - } + Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut); - Type.Struct lambdaParameters = lambdaParameterStack.get(lambdaIndex); return FieldReference.newLambdaParameterReference( reference.getDirectReference().getStructField().getField(), lambdaParameters, @@ -291,13 +284,13 @@ public Type visit(Type.Struct type) throws RuntimeException { .setStruct(protoLambda.getParameters()) .build()); - lambdaParameterStack.add(parameters); + lambdaParameterStack.push(parameters); Expression body; try { body = from(protoLambda.getBody()); } finally { - lambdaParameterStack.remove(lambdaParameterStack.size() - 1); + lambdaParameterStack.pop(); } return Expression.Lambda.builder().parameters(parameters).body(body).build(); @@ -616,4 +609,30 @@ public Expression.SortField fromSortField(SortField s) { public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } + + private static class LambdaParameterStack { + private final List stack = new ArrayList<>(); + + void push(Type.Struct parameters) { + stack.add(parameters); + } + + void pop() { + if (stack.isEmpty()) { + throw new IllegalArgumentException("Lambda parameter stack is empty"); + } + stack.remove(stack.size() - 1); + } + + Type.Struct get(int stepsOut) { + int index = stack.size() - 1 - stepsOut; + if (index < 0 || index >= stack.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, stack.size())); + } + return stack.get(index); + } + } } From eb7b2b23c8fb6e2f4eadde602c4db8128b0c99ea Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Mon, 2 Mar 2026 16:31:49 +0100 Subject: [PATCH 04/26] tweak: adding comments expailining LambdaParameterStack --- .../expression/proto/ProtoExpressionConverter.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 09ee2da59..ff864c259 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -610,6 +610,18 @@ public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOptio return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } + /** + * A stack for tracking lambda parameter types during expression parsing. + * + *

When parsing nested lambda expressions, each lambda's parameters are pushed onto this stack. + * Lambda parameter references use "stepsOut" to indicate which enclosing lambda they reference: + * + *

    + *
  • stepsOut=0 refers to the innermost (current) lambda + *
  • stepsOut=1 refers to the next enclosing lambda + *
  • stepsOut=N refers to N levels up + *
+ */ private static class LambdaParameterStack { private final List stack = new ArrayList<>(); From 1d1b26ee9def88df3c6f5f2f4c1636e7ec94377f Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 6 Mar 2026 16:04:44 -0800 Subject: [PATCH 05/26] build: ignore build related generated files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 5eab9ab03..4093b00d1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ out/** .metals .bloop .project +.classpath +.settings +bin/ From 2224c4b5708969500ca6e9a6148cfbc5b11e233f Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 6 Mar 2026 13:41:28 -0800 Subject: [PATCH 06/26] feat: enable parsing of func types in extensions --- .../extension/DefaultExtensionCatalog.java | 1 + .../io/substrait/function/ToTypeString.java | 10 ++++ .../io/substrait/type/parser/ParseToPojo.java | 49 +++++++++++++++++-- .../DefaultExtensionCatalogTest.java | 13 +++++ .../substrait/type/parser/TestTypeParser.java | 16 ++++++ 5 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 478ed63e2..d8b6b1fa6 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -78,6 +78,7 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { "arithmetic", "comparison", "datetime", + "list", "logarithmic", "rounding", "rounding_decimal", diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java index d6fc1bdb8..5c942f46a 100644 --- a/core/src/main/java/io/substrait/function/ToTypeString.java +++ b/core/src/main/java/io/substrait/function/ToTypeString.java @@ -150,6 +150,11 @@ public String visit(final Type.Map expr) { return "map"; } + @Override + public String visit(Type.Func type) throws RuntimeException { + return "func"; + } + @Override public String visit(final Type.UserDefined expr) { return String.format("u!%s", expr.name()); @@ -210,6 +215,11 @@ public String visit(ParameterizedType.Map expr) throws RuntimeException { return "map"; } + @Override + public String visit(ParameterizedType.Func expr) throws RuntimeException { + return "func"; + } + @Override public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException { if (expr.value().toLowerCase().startsWith("any")) { diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index 8555ab219..ddad886b1 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -1,5 +1,6 @@ package io.substrait.type.parser; +import io.substrait.function.ImmutableParameterizedType; import io.substrait.function.ImmutableTypeExpression; import io.substrait.function.ParameterizedType; import io.substrait.function.ParameterizedTypeCreator; @@ -411,19 +412,61 @@ public TypeExpression visitNStruct(final SubstraitTypeParser.NStructContext ctx) @Override public TypeExpression visitFunc(final SubstraitTypeParser.FuncContext ctx) { - throw new UnsupportedOperationException(); + boolean nullable = ctx.isnull != null; + + // Process function parameters + List paramExprs; + if (ctx.params instanceof SubstraitTypeParser.SingleFuncParamContext) { + paramExprs = + java.util.Collections.singletonList( + ((SubstraitTypeParser.SingleFuncParamContext) ctx.params).expr().accept(this)); + } else if (ctx.params instanceof SubstraitTypeParser.FuncParamsWithParensContext) { + paramExprs = + ((SubstraitTypeParser.FuncParamsWithParensContext) ctx.params) + .expr().stream() + .map(e -> e.accept(this)) + .collect(java.util.stream.Collectors.toList()); + } else { + throw new UnsupportedOperationException( + "Unknown funcParams type: " + ctx.params.getClass()); + } + + // Process return type + TypeExpression returnExpr = ctx.returnType.accept(this); + + // If all types are instances of Type, we return a Type + if (paramExprs.stream().allMatch(p -> p instanceof Type) && returnExpr instanceof Type) { + ImmutableType.Func.Builder builder = ImmutableType.Func.builder().nullable(nullable); + paramExprs.forEach(p -> builder.addParameterTypes((Type) p)); + return builder.returnType((Type) returnExpr).build(); + } + + // If all types are instances of ParameterizedType, we return a ParameterizedType + if (paramExprs.stream().allMatch(p -> p instanceof ParameterizedType) + && returnExpr instanceof ParameterizedType) { + checkParameterizedOrExpression(); + ImmutableParameterizedType.Func.Builder builder = + ParameterizedType.Func.builder().nullable(nullable); + paramExprs.forEach(p -> builder.addParameterTypes((ParameterizedType) p)); + return builder.returnType((ParameterizedType) returnExpr).build(); + } + + throw new UnsupportedOperationException( + "func type with TypeExpression-level parameter or return types are not yet supported"); } @Override public TypeExpression visitSingleFuncParam( final SubstraitTypeParser.SingleFuncParamContext ctx) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException( + "visitSingleFuncParam is handled in visitFunc directly"); } @Override public TypeExpression visitFuncParamsWithParens( final SubstraitTypeParser.FuncParamsWithParensContext ctx) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException( + "visitFuncParamsWithParens is handled in visitFunc directly"); } @Override diff --git a/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java new file mode 100644 index 000000000..7aed4e961 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java @@ -0,0 +1,13 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import org.junit.jupiter.api.Test; + +class DefaultExtensionCatalogTest { + + @Test + void defaultCollectionLoads() { + assertNotNull(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } +} diff --git a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java index 010ad123a..92989795b 100644 --- a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java +++ b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java @@ -8,6 +8,7 @@ import io.substrait.type.ImmutableType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.List; import org.junit.jupiter.api.Test; class TestTypeParser { @@ -120,6 +121,11 @@ private void compoundTests(ParseToPojo.Visitor v) { test(v, n.precisionTimestamp(9), "PRECISION_TIMESTAMP?<9>"); test(v, r.precisionTimestampTZ(6), "PRECISION_TIMESTAMP_TZ<6>"); test(v, n.precisionTimestampTZ(9), "PRECISION_TIMESTAMP_TZ?<9>"); + + test(v, r.func(List.of(r.I8), r.I32), "func i32>"); + test(v, r.func(List.of(r.I8, r.I8), r.I32), "func<(i8, i8) -> i32>"); + test(v, n.func(List.of(r.I8), r.I32), "func? i32>"); + test(v, r.func(List.of(n.I8), n.I32), "func i32?>"); } private void parameterizedTests(ParseToPojo.Visitor v) { @@ -142,6 +148,16 @@ private void parameterizedTests(ParseToPojo.Visitor v) { test(v, pr.precisionTimeE("P"), "PRECISION_TIME

"); test(v, pr.precisionTimestampE("P"), "PRECISION_TIMESTAMP

"); test(v, pr.precisionTimestampTZE("P"), "PRECISION_TIMESTAMP_TZ

"); + + test(v, pr.funcE(List.of(pr.parameter("any")), r.I64), "func i64>"); + test(v, pr.funcE(List.of(pr.parameter("any"), r.I64), r.I64), "func<(any, i64) -> i64>"); + test(v, pr.funcE(List.of(pr.parameter("any1")), pr.parameter("any1")), "func any1>"); + test(v, pn.funcE(List.of(pr.parameter("any")), n.I64), "func? i64?>"); + test(v, pn.funcE(List.of(pr.parameter("any1")), pr.parameter("any1")), "func? any1>"); + test( + v, + pr.funcE(List.of(pr.parameter("any1"), r.I8), pr.parameter("any1")), + "func<(any1, i8) -> any1>"); } @Test From 83b9e2480cfe0065ded455ee417abf6a79258eac Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Tue, 10 Mar 2026 14:47:27 -0700 Subject: [PATCH 07/26] test: copy substrait-go lambda plans Note that these had to be tweaked slightly because they were not entirely valid. This will be fixed in substrait-go. --- .../test/resources/lambdas/basic-lambda.json | 132 ++++++++++++++ .../resources/lambdas/lambda-field-ref.json | 135 ++++++++++++++ .../lambdas/lambda-with-function.json | 172 ++++++++++++++++++ 3 files changed, 439 insertions(+) create mode 100644 isthmus/src/test/resources/lambdas/basic-lambda.json create mode 100644 isthmus/src/test/resources/lambdas/lambda-field-ref.json create mode 100644 isthmus/src/test/resources/lambdas/lambda-with-function.json diff --git a/isthmus/src/test/resources/lambdas/basic-lambda.json b/isthmus/src/test/resources/lambdas/basic-lambda.json new file mode 100644 index 000000000..114e3ad6d --- /dev/null +++ b/isthmus/src/test/resources/lambdas/basic-lambda.json @@ -0,0 +1,132 @@ +{ + "version": { + "majorNumber": 0, + "minorNumber": 79 + }, + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "literal": { + "i32": 42 + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/isthmus/src/test/resources/lambdas/lambda-field-ref.json b/isthmus/src/test/resources/lambdas/lambda-field-ref.json new file mode 100644 index 000000000..58c041582 --- /dev/null +++ b/isthmus/src/test/resources/lambdas/lambda-field-ref.json @@ -0,0 +1,135 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/isthmus/src/test/resources/lambdas/lambda-with-function.json b/isthmus/src/test/resources/lambdas/lambda-with-function.json new file mode 100644 index 000000000..9c0a1a55b --- /dev/null +++ b/isthmus/src/test/resources/lambdas/lambda-with-function.json @@ -0,0 +1,172 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + }, + { + "extensionUrnAnchor": 2, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "multiply:i32_i32" + } + }, + { + "extensionFunction": { + "extensionUrnReference": 2, + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + }, + { + "value": { + "literal": { + "i32": 2 + } + } + } + ] + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} From d802c19fe9c76fe98c949da77aa498244521d158 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Tue, 10 Mar 2026 14:48:09 -0700 Subject: [PATCH 08/26] test: add LambdaRoundtripTests --- .../isthmus/LambdaRoundtripTest.java | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java new file mode 100644 index 000000000..d34fa7744 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java @@ -0,0 +1,38 @@ +package io.substrait.isthmus; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.plan.Plan; +import io.substrait.plan.ProtoPlanConverter; +import java.io.IOException; +import org.junit.jupiter.api.Test; + +public class LambdaRoundtripTest extends PlanTestBase { + + public static io.substrait.proto.Plan readJsonPlan(String resourcePath) throws IOException { + String json = asString(resourcePath); + io.substrait.proto.Plan.Builder builder = io.substrait.proto.Plan.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + @Test + void testBasicLambdaRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/basic-lambda.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } + + @Test + void testLambdaWithFieldRefRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/lambda-field-ref.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } + + @Test + void testLambdaWithFunctionRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/lambda-with-function.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } +} From 14738e31fe451e243025c77c465482077b58dc22 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 9 Mar 2026 07:56:42 -0700 Subject: [PATCH 09/26] feat(isthmus): add TRANSFORM SqlFunction to handle transform:list_func --- .../isthmus/expression/FunctionMappings.java | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 90f04a326..87b688b58 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -5,14 +5,25 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import org.apache.calcite.sql.SqlBasicFunction; +import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; public class FunctionMappings { // Static list of signature mapping between Calcite SQL operators and Substrait base function // names. + /** The transform:list_func function; applies a lambda to each element of an array. */ + public static final SqlFunction TRANSFORM = + SqlBasicFunction.create( + "transform", + opBinding -> opBinding.getTypeFactory().createArrayType(opBinding.getOperandType(1), -1), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + public static final ImmutableList SCALAR_SIGS = ImmutableList.builder() .add( @@ -100,7 +111,8 @@ public class FunctionMappings { s(SqlLibraryOperators.RPAD, "rpad"), s(SqlLibraryOperators.PARSE_TIME, "strptime_time"), s(SqlLibraryOperators.PARSE_TIMESTAMP, "strptime_timestamp"), - s(SqlLibraryOperators.PARSE_DATE, "strptime_date")) + s(SqlLibraryOperators.PARSE_DATE, "strptime_date"), + s(TRANSFORM, "transform")) .build(); public static final ImmutableList AGGREGATE_SIGS = From 6fae010c22e5ff204841ddf2a0da4d0472630ff8 Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Wed, 11 Mar 2026 16:44:03 +0100 Subject: [PATCH 10/26] tweak: add filter in the function mapping --- .../substrait/isthmus/expression/FunctionMappings.java | 10 +++++++++- .../java/io/substrait/isthmus/LambdaRoundtripTest.java | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 87b688b58..cc2b098e2 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -24,6 +24,13 @@ public class FunctionMappings { opBinding -> opBinding.getTypeFactory().createArrayType(opBinding.getOperandType(1), -1), OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + /** The filter:list_func function; filters elements of an array using a predicate lambda. */ + public static final SqlFunction FILTER = + SqlBasicFunction.create( + "filter", + opBinding -> opBinding.getOperandType(0), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + public static final ImmutableList SCALAR_SIGS = ImmutableList.builder() .add( @@ -112,7 +119,8 @@ public class FunctionMappings { s(SqlLibraryOperators.PARSE_TIME, "strptime_time"), s(SqlLibraryOperators.PARSE_TIMESTAMP, "strptime_timestamp"), s(SqlLibraryOperators.PARSE_DATE, "strptime_date"), - s(TRANSFORM, "transform")) + s(TRANSFORM, "transform"), + s(FILTER, "filter")) .build(); public static final ImmutableList AGGREGATE_SIGS = diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java index d34fa7744..427dc67ba 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java @@ -6,7 +6,7 @@ import java.io.IOException; import org.junit.jupiter.api.Test; -public class LambdaRoundtripTest extends PlanTestBase { +class LambdaRoundtripTest extends PlanTestBase { public static io.substrait.proto.Plan readJsonPlan(String resourcePath) throws IOException { String json = asString(resourcePath); From 1c1f8d2d6977a6038c4725c184320694a7bd1b0c Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Fri, 13 Mar 2026 14:16:53 +0100 Subject: [PATCH 11/26] adress @vbarua's comments --- core/src/main/java/io/substrait/expression/Expression.java | 5 +++-- .../expression/proto/ProtoExpressionConverter.java | 6 ++++++ .../main/java/io/substrait/relation/VirtualTableScan.java | 3 +-- .../io/substrait/examples/util/ExpressionStringify.java | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index a519a2166..72510ad2c 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -769,8 +769,9 @@ public Type getType() { List paramTypes = parameters().fields(); Type returnType = body().getType(); - // TO DO: fix Lambda return type once this issue - // https://github.com/substrait-io/substrait/issues/976 is resolved + // TODO: Type.Func nullability is hardcoded to false here because the spec does not allow for + // declaring otherwise. + // See: https://github.com/substrait-io/substrait/issues/976 return Type.withNullability(false).func(paramTypes, returnType); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 603196ebd..290e88c49 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -84,6 +84,12 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc int stepsOut = lambdaParamRef.getStepsOut(); Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut); + // Check for unsupported nested field access + if (reference.getDirectReference().getStructField().hasChild()) { + throw new UnsupportedOperationException( + "Nested field access in lambda parameters is not yet supported"); + } + return FieldReference.newLambdaParameterReference( reference.getDirectReference().getStructField().getField(), lambdaParameters, diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index de7c29dbb..36e5b6dcc 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -245,8 +245,7 @@ public Integer visit(Type.UserDefined type) throws RuntimeException { @Override public Integer visit(Type.Func type) throws RuntimeException { - return type.parameterTypes().stream().mapToInt(p -> p.accept(this)).sum() - + type.returnType().accept(this); + return 0; } } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 72f5ac556..2d345642c 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -205,7 +205,7 @@ public String visit(Expression.NestedStruct expr, EmptyVisitationContext context @Override public String visit(Expression.Lambda expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + return ""; } @Override From bd30a21e0816ff941cdfb5cbd8fa820101d2d316 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 13:52:10 -0400 Subject: [PATCH 12/26] feat(core): add LambdaBuilder for build-time validation of lambda parameter references Introduces LambdaBuilder, a context-aware builder that maintains a lambda parameter stack (lambdaContext) to validate parameter references at build time. Nested lambdas use the same builder, ensuring stepsOut is computed automatically. Mirrors the lambdaContext pattern from substrait-go. --- .../io/substrait/expression/Expression.java | 4 - .../substrait/expression/LambdaBuilder.java | 99 ++++ .../proto/ProtoExpressionConverter.java | 4 +- .../ExpressionCopyOnWriteVisitor.java | 6 +- .../expression/LambdaBuilderTest.java | 44 ++ .../proto/LambdaExpressionRoundtripTest.java | 473 ++++++++---------- .../expression/RexExpressionConverter.java | 3 +- .../isthmus/LambdaExpressionTest.java | 127 +---- 8 files changed, 377 insertions(+), 383 deletions(-) create mode 100644 core/src/main/java/io/substrait/expression/LambdaBuilder.java create mode 100644 core/src/test/java/io/substrait/expression/LambdaBuilderTest.java diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 72510ad2c..b116cbdc2 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -775,10 +775,6 @@ public Type getType() { return Type.withNullability(false).func(paramTypes, returnType); } - public static ImmutableExpression.Lambda.Builder builder() { - return ImmutableExpression.Lambda.builder(); - } - @Override public R accept( ExpressionVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java new file mode 100644 index 000000000..b2b89c02d --- /dev/null +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -0,0 +1,99 @@ +package io.substrait.expression; + +import io.substrait.type.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * Builds lambda expressions with build-time validation of parameter references. + * + *

Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters + * onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code + * lambda()} again on the same builder. + * + *

The callback receives a {@link Scope} handle for creating validated parameter references. The + * correct {@code stepsOut} value is computed automatically from the stack. + * + *

{@code
+ * LambdaBuilder lb = new LambdaBuilder();
+ *
+ * // Simple: (x: i32) -> x
+ * Expression.Lambda simple = lb.lambda(List.of(R.I32), x -> x.ref(0));
+ *
+ * // Nested: (x: i32) -> (y: i64) -> add(x, y)
+ * Expression.Lambda nested = lb.lambda(List.of(R.I32), x ->
+ *     lb.lambda(List.of(R.I64), y ->
+ *         add(x.ref(0), y.ref(0))
+ *     )
+ * );
+ * }
+ */ +public class LambdaBuilder { + private final List lambdaContext = new ArrayList<>(); + + /** + * Builds a lambda expression. The body function receives a {@link Scope} for creating validated + * parameter references. Nested lambdas are built by calling this method again inside the + * callback. + * + * @param paramTypes the lambda's parameter types + * @param bodyFn function that builds the lambda body given a scope handle + * @return the constructed lambda expression + */ + public Expression.Lambda lambda(List paramTypes, Function bodyFn) { + Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + pushLambdaContext(params); + try { + int index = lambdaContext.size() - 1; + Scope scope = new Scope(index); + Expression body = bodyFn.apply(scope); + return ImmutableExpression.Lambda.builder().parameters(params).body(body).build(); + } finally { + popLambdaContext(); + } + } + + /** + * Pushes a lambda's parameters onto the context stack. This makes the parameters available for + * validation when building the lambda's body, and allows nested lambda parameter references to + * correctly compute their stepsOut values. + */ + private void pushLambdaContext(Type.Struct params) { + lambdaContext.add(params); + } + + /** + * Pops the most recently pushed lambda parameters from the context stack. Called after a lambda's + * body has been built, restoring the context to the enclosing lambda's scope. + */ + private void popLambdaContext() { + lambdaContext.remove(lambdaContext.size() - 1); + } + + /** + * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated + * parameter references. + */ + public class Scope { + private final int index; + + private Scope(int index) { + this.index = index; + } + + /** + * Creates a validated reference to a parameter of this lambda. The correct {@code stepsOut} + * value is computed automatically. + * + * @param paramIndex index of the parameter within this lambda's parameter struct + * @return a {@link FieldReference} pointing to the specified parameter + * @throws IndexOutOfBoundsException if paramIndex is out of bounds + */ + public FieldReference ref(int paramIndex) { + int stepsOut = lambdaContext.size() - 1 - index; + return FieldReference.newLambdaParameterReference( + paramIndex, lambdaContext.get(index), stepsOut); + } + } +} diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 290e88c49..426b0f56a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,6 +6,7 @@ import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; +import io.substrait.expression.ImmutableExpression; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -282,6 +283,7 @@ public Type visit(Type.Struct type) throws RuntimeException { case LAMBDA: { + // TODO: Add build-time validation of lambda parameter references during deserialization. io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); Type.Struct parameters = (Type.Struct) @@ -299,7 +301,7 @@ public Type visit(Type.Struct type) throws RuntimeException { lambdaParameterStack.pop(); } - return Expression.Lambda.builder().parameters(parameters).body(body).build(); + return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); } // TODO enum. case ENUM: diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 11fc14f27..b097614a0 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -8,6 +8,7 @@ import io.substrait.expression.ExpressionVisitor; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; +import io.substrait.expression.ImmutableExpression; import io.substrait.util.EmptyVisitationContext; import java.util.List; import java.util.Optional; @@ -448,7 +449,10 @@ public Optional visit(Expression.Lambda lambda, EmptyVisitationConte return Optional.empty(); } return Optional.of( - Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); + ImmutableExpression.Lambda.builder() + .from(lambda) + .body(newBody.orElse(lambda.body())) + .build()); } // utilities diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java new file mode 100644 index 000000000..7d8f8976b --- /dev/null +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -0,0 +1,44 @@ +package io.substrait.expression; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** Tests for {@link LambdaBuilder} build-time validation. */ +class LambdaBuilderTest { + + static final TypeCreator R = TypeCreator.REQUIRED; + + final LambdaBuilder lb = new LambdaBuilder(); + + // (x: i32) -> x[5] — field index 5 is out of bounds (only 1 param) + @Test + void invalidFieldIndex_outOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(5))); + } + + // (x: i32) -> x[-1] — negative field index + @Test + void negativeFieldIndex() { + assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); + } + + // (x: i32) -> (y: i64) -> x[5] — outer field index 5 is out of bounds + @Test + void nestedOuterFieldIndexOutOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, + () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5)))); + } + + // (x: i32) -> (y: i64) -> y[3] — inner field index 3 is out of bounds (only 1 param) + @Test + void nestedInnerFieldIndexOutOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, + () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(3)))); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java index c0a979a94..3b9cfedad 100644 --- a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -1,361 +1,298 @@ package io.substrait.type.proto; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; +import io.substrait.expression.LambdaBuilder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.type.Type; import java.util.List; import org.junit.jupiter.api.Test; -/** - * Tests for Lambda expression round-trip conversion through protobuf. Based on equivalent tests - * from substrait-go. - */ class LambdaExpressionRoundtripTest extends TestBase { - /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ - @Test - void zeroParameterLambda() { - Type.Struct emptyParams = Type.Struct.builder().nullable(false).build(); + final LambdaBuilder lb = new LambdaBuilder(); - Expression body = ExpressionCreator.i32(false, 42); + // ==================== Single Lambda Tests ==================== - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(emptyParams).body(body).build(); + // () -> 42 + @Test + void zeroParameterLambda() { + Expression.Lambda lambda = lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42)); verifyRoundTrip(lambda); - // Verify the lambda type - Type lambdaType = lambda.getType(); - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; + Type.Func funcType = (Type.Func) lambda.getType(); assertEquals(0, funcType.parameterTypes().size()); assertEquals(R.I32, funcType.returnType()); } - /** Test valid stepsOut=0 references. Building: ($0: i32) -> $0 : func i32> */ + // (x: i32) -> x @Test - void validStepsOut0() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Lambda body references parameter 0 with stepsOut=0 - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + void identityLambda() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); verifyRoundTrip(lambda); - // Verify types - Type lambdaType = lambda.getType(); - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; + Type.Func funcType = (Type.Func) lambda.getType(); assertEquals(1, funcType.parameterTypes().size()); assertEquals(R.I32, funcType.parameterTypes().get(0)); assertEquals(R.I32, funcType.returnType()); + + assertInstanceOf(FieldReference.class, lambda.body()); + FieldReference ref = (FieldReference) lambda.body(); + assertTrue(ref.isLambdaParameterReference()); + assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); } - /** - * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 - * : func<(i32, i64, string) -> string> - */ + // (x: i32, y: i64, z: string) -> z @Test void validFieldIndex() { - Type.Struct params = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(2)); - // Reference the 3rd parameter (string) - FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); + verifyRoundTrip(lambda); + assertEquals(R.STRING, ((Type.Func) lambda.getType()).returnType()); + } + // (x: i32) -> 42 + @Test + void lambdaWithLiteralBody() { Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42)); verifyRoundTrip(lambda); - - // Verify return type is string - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(R.STRING, funcType.returnType()); + assertInstanceOf(Expression.I32Literal.class, lambda.body()); } - /** Test type resolution for different parameter types. */ + // Parameterized: (params...) -> params[fieldIndex], verifying type resolution @Test void typeResolution() { - // Test cases: (paramTypes, fieldIndex, expectedReturnType) - record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} + record TestCase(String name, List paramTypes, int fieldIndex, Type expectedType) {} List testCases = List.of( - new TestCase(List.of(R.I32), 0, R.I32), - new TestCase(List.of(R.I32, R.I64), 1, R.I64), - new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), - new TestCase(List.of(R.FP64), 0, R.FP64), - new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); + new TestCase("first param (i32)", List.of(R.I32), 0, R.I32), + new TestCase("second param (i64)", List.of(R.I32, R.I64), 1, R.I64), + new TestCase("third param (string)", List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase("float64 param", List.of(R.FP64), 0, R.FP64), + new TestCase("date param", List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); for (TestCase tc : testCases) { - Type.Struct params = - Type.Struct.builder().nullable(false).addAllFields(tc.paramTypes).build(); - - FieldReference paramRef = - FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + Expression.Lambda lambda = lb.lambda(tc.paramTypes, params -> params.ref(tc.fieldIndex)); verifyRoundTrip(lambda); - // Verify the body type matches expected - assertEquals( - tc.expectedType, - lambda.body().getType(), - "Body type should match referenced parameter type"); - - // Verify lambda return type + assertEquals(tc.expectedType, lambda.body().getType(), tc.name + ": body type mismatch"); Type.Func funcType = (Type.Func) lambda.getType(); assertEquals( - tc.expectedType, funcType.returnType(), "Lambda return type should match body type"); + tc.expectedType, funcType.returnType(), tc.name + ": lambda return type mismatch"); } } - /** - * Test nested lambda with outer reference. Building: ($0: i64, $1: i64) -> (($0: i32) -> - * outer[$0] : i64) : func<(i64, i64) -> func i64>> - */ + // (x: i32, y: string) -> y — verify full Func type structure @Test - void nestedLambdaWithOuterRef() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64, R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references outer's parameter 0 with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - verifyRoundTrip(outerLambda); + void lambdaGetTypeReturnsFunc() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.STRING), params -> params.ref(1)); - // Verify structure - assertInstanceOf(Expression.Lambda.class, outerLambda.body()); - Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); - assertEquals(1, resultInner.parameters().fields().size()); + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(2, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.STRING, funcType.parameterTypes().get(1)); + assertEquals(R.STRING, funcType.returnType()); } - /** - * Test outer reference type resolution in nested lambdas. Building: ($0: i32, $1: i64, $2: - * string) -> (($0: fp64) -> outer[$2] : string) : func<...> - */ + // (x: i32, y: i64, z: string) -> ... — verify FieldReference metadata for each param @Test - void outerRefTypeResolution() { - Type.Struct outerParams = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.FP64).build(); - - // Inner references outer's field 2 (string) with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - verifyRoundTrip(outerLambda); - - // Verify inner lambda's return type is string (from outer param 2) - Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); - Type.Func innerFuncType = (Type.Func) resultInner.getType(); - assertEquals( - R.STRING, - innerFuncType.returnType(), - "Inner lambda return type should be string from outer.$2"); - - // Verify body's type is also string - assertEquals(R.STRING, resultInner.body().getType(), "Body type should be string"); + void parameterReferenceMetadata() { + List paramTypes = List.of(R.I32, R.I64, R.STRING); + + lb.lambda( + paramTypes, + params -> { + for (int i = 0; i < 3; i++) { + FieldReference ref = params.ref(i); + assertTrue(ref.isLambdaParameterReference()); + assertFalse(ref.isOuterReference()); + assertFalse(ref.isSimpleRootReference()); + assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals(paramTypes.get(i), ref.getType()); + } + return params.ref(0); + }); } - /** - * Test deeply nested field ref inside Cast. Building: ($0: i32) -> cast($0 as i64) : func - * i64> - */ - @Test - void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + // ==================== Expression Body Tests ==================== - Expression.Cast castExpr = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(castExpr).build(); + // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + @Test + void nestedLambdaWithArithmeticBody() { + String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; + + Expression.Lambda result = + lb.lambda( + List.of(R.I64), + outer -> + lb.lambda( + List.of(R.I64, R.I64), + inner -> { + // y1 * x + Expression multiply = + sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); + // (y1 * x) + y2 + return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); + })); + + verifyRoundTrip(result); + + // Outer lambda returns a func type + Type.Func outerFuncType = (Type.Func) result.getType(); + assertInstanceOf(Type.Func.class, outerFuncType.returnType()); + + // Inner lambda returns i64 + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + Type.Func innerFuncType = (Type.Func) innerLambda.getType(); + assertEquals(R.I64, innerFuncType.returnType()); + + // Inner body is a scalar function (add) + assertInstanceOf(Expression.ScalarFunctionInvocation.class, innerLambda.body()); + } - verifyRoundTrip(lambda); + // ==================== Nested Lambda Tests ==================== - // Verify the nested FieldRef has its type resolved - Expression.Cast resultCast = (Expression.Cast) lambda.body(); - assertInstanceOf(FieldReference.class, resultCast.input()); - FieldReference resultFieldRef = (FieldReference) resultCast.input(); + // (x: i64, y: i64) -> (z: i32) -> x + @Test + void nestedLambdaWithOuterRef() { + Expression.Lambda result = + lb.lambda(List.of(R.I64, R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - assertNotNull(resultFieldRef.getType(), "Nested FieldRef should have type resolved"); - assertEquals(R.I32, resultFieldRef.getType(), "Should resolve to i32"); + verifyRoundTrip(result); - // Verify lambda return type is i64 (cast output) - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(R.I64, funcType.returnType()); + Expression.Lambda resultInner = (Expression.Lambda) result.body(); + assertEquals(1, resultInner.parameters().fields().size()); + assertEquals(R.I64, resultInner.body().getType()); } - /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 - * as i64) as string) : func string> - */ + // (x: i32, y: i64, z: string) -> (w: fp64) -> z @Test - void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = - (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(outerCast).build(); + void nestedLambdaOuterRefTypeResolution() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32, R.I64, R.STRING), + outer -> lb.lambda(List.of(R.FP64), inner -> outer.ref(2))); - verifyRoundTrip(lambda); - - // Navigate to the deeply nested FieldRef (2 levels deep) - Expression.Cast resultOuter = (Expression.Cast) lambda.body(); - Expression.Cast resultInner = (Expression.Cast) resultOuter.input(); - FieldReference resultFieldRef = (FieldReference) resultInner.input(); + verifyRoundTrip(result); - // Verify type is resolved even at depth 2 - assertNotNull(resultFieldRef.getType(), "FieldRef at depth 2 should have type resolved"); - assertEquals(R.I32, resultFieldRef.getType()); + Expression.Lambda resultInner = (Expression.Lambda) result.body(); + assertEquals(R.STRING, ((Type.Func) resultInner.getType()).returnType()); } - /** - * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> - */ + // (x: i32) -> (y: i64) -> y @Test - void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + void nestedLambdaInnerRefOnly() { + Expression.Lambda result = + lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(0))); - Expression body = ExpressionCreator.i32(false, 42); + verifyRoundTrip(result); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - - verifyRoundTrip(lambda); + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + assertEquals(R.I64, innerLambda.body().getType()); + assertInstanceOf(Type.Func.class, ((Type.Func) result.getType()).returnType()); } - /** Test lambda getType returns correct Func type. */ + // (x: i32) -> (y: i64) -> (x, y) — body references both outer and inner params @Test - void lambdaGetTypeReturnsFunc() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32, R.STRING).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); - - Type lambdaType = lambda.getType(); - - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; - - assertEquals(2, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.STRING, funcType.parameterTypes().get(1)); - assertEquals(R.STRING, funcType.returnType()); // body references param 1 which is STRING + void nestedLambdaBothInnerAndOuterRefs() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), + inner -> { + FieldReference innerRef = inner.ref(0); + assertEquals(R.I64, innerRef.getType()); + assertEquals(0, innerRef.lambdaParameterReferenceStepsOut().orElse(-1)); + + FieldReference outerRef = outer.ref(0); + assertEquals(R.I32, outerRef.getType()); + assertEquals(1, outerRef.lambdaParameterReferenceStepsOut().orElse(-1)); + + return innerRef; + })); + + verifyRoundTrip(result); } - // ==================== Validation Error Tests ==================== - - /** - * Test that invalid outer reference (stepsOut too high) fails during proto conversion. Building: - * ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) - */ + // (a: i32, b: string) -> (c: i64, d: fp64) -> b — verify all 4 params resolve correctly @Test - void invalidOuterRef_stepsOutTooHigh() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Create a parameter reference with stepsOut=1 but no outer lambda exists - FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(invalidRef).build(); - - // Convert to proto - this should work - io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); - - // Converting back should fail because stepsOut=1 references non-existent outer lambda - assertThrows( - IllegalArgumentException.class, - () -> { - protoExpressionConverter.from(protoExpression); - }, - "Should fail when stepsOut references non-existent outer lambda"); + void nestedLambdaMultiParamCorrectResolution() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32, R.STRING), + outer -> + lb.lambda( + List.of(R.I64, R.FP64), + inner -> { + assertEquals(R.I64, inner.ref(0).getType()); + assertEquals(R.FP64, inner.ref(1).getType()); + assertEquals(R.I32, outer.ref(0).getType()); + assertEquals(R.STRING, outer.ref(1).getType()); + + return outer.ref(1); + })); + + verifyRoundTrip(result); + + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + assertEquals(R.STRING, ((Type.Func) innerLambda.getType()).returnType()); } - /** - * Test that invalid field index (out of bounds) fails during proto conversion. Building: ($0: - * i32) -> $5 : INVALID (only has 1 param) - */ + // (x: i32) -> (y: i64) -> (z: string) -> x @Test - void invalidFieldIndex_outOfBounds() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Create a reference to field 5, but lambda only has 1 parameter (index 0) - // This will fail at build time since newLambdaParameterReference accesses fields.get(5) - assertThrows( - IndexOutOfBoundsException.class, - () -> { - FieldReference.newLambdaParameterReference(5, params, 0); - }, - "Should fail when field index is out of bounds"); + void tripleNestedLambdaRoundtrip() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), mid -> lb.lambda(List.of(R.STRING), inner -> outer.ref(0)))); + + verifyRoundTrip(result); + + Expression.Lambda l1 = (Expression.Lambda) result.body(); + Expression.Lambda l2 = (Expression.Lambda) l1.body(); + assertEquals(R.I32, l2.body().getType()); } - /** - * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). Building: ($0: i64) -> - * (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) - */ + // (x: i32) -> (y: i64) -> (z: string) -> ... — verify stepsOut is auto-computed at each level @Test - void nestedInvalidOuterRef() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references stepsOut=2, but only 1 outer lambda exists - FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(invalidRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - // Convert to proto - io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); - - // Converting back should fail because stepsOut=2 references non-existent grandparent - assertThrows( - IllegalArgumentException.class, - () -> { - protoExpressionConverter.from(protoExpression); - }, - "Should fail when stepsOut references non-existent grandparent lambda"); + void tripleNestedLambdaScopeTracking() { + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), + mid -> + lb.lambda( + List.of(R.STRING), + inner -> { + assertEquals(R.STRING, inner.ref(0).getType()); + assertEquals(R.I64, mid.ref(0).getType()); + assertEquals(R.I32, outer.ref(0).getType()); + + assertEquals( + 0, inner.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals(1, mid.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals( + 2, outer.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + + return inner.ref(0); + }))); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 176b246e7..d5330681d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableExpression; import io.substrait.isthmus.CallConverter; import io.substrait.isthmus.SubstraitRelVisitor; import io.substrait.isthmus.TypeConverter; @@ -212,7 +213,7 @@ public Expression visitLambda(RexLambda rexLambda) { Expression body = rexLambda.getExpression().accept(this); - return Expression.Lambda.builder().parameters(parameters).body(body).build(); + return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index add746ee1..55924081d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -4,146 +4,57 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; +import io.substrait.expression.LambdaBuilder; import io.substrait.relation.Project; import io.substrait.relation.Rel; -import io.substrait.type.Type; import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.Test; -/** - * Tests for Lambda expression conversion between Substrait and Calcite. Note: Calcite does not - * support nested lambda expressions for the moment, so all tests use stepsOut=0. - */ class LambdaExpressionTest extends PlanTestBase { final Rel emptyTable = sb.emptyVirtualTableScan(); + final LambdaBuilder lb = new LambdaBuilder(); - /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ + // () -> 42 @Test void lambdaExpressionZeroParameters() { - Type.Struct params = Type.Struct.builder().nullable(false).build(); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42))); - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - expressionList.add(lambda); - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 - * : func<(i32, i64, string) -> string> - */ + // (x: i32, y: i64, z: string) -> x @Test void validFieldIndex() { - Type.Struct params = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); - - expressionList.add(lambda); - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test deeply nested field ref inside Cast. Building: ($0: i32) -> cast($0 as i64) : func - * i64> - */ - @Test - void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - Expression.Cast castExpr = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(castExpr).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 - * as i64) as string) : func string> - */ - @Test - void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = - (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(0))); - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(outerCast).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> - */ + // (x: i32) -> 42 @Test void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42))); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. Calcite does not - * support nested lambda expressions. - */ + // (x: i64) -> (y: i32) -> x — Calcite doesn't support nested lambdas @Test void nestedLambdaThrowsUnsupportedOperation() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references outer's parameter with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - List expressionList = new ArrayList<>(); - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + lb.lambda(List.of(R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - expressionList.add(outerLambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + List exprs = new ArrayList<>(); + exprs.add(outerLambda); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } } From 50db6998d84c3ee9af532e5ca2073e14e727287d Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:28:21 -0400 Subject: [PATCH 13/26] refactor: unify lambda validation, add JSON-based roundtrip tests Moves ProtoExpressionConverter to use LambdaBuilder for lambda parameter validation, removing the private LambdaParameterStack. Replaces builder-based roundtrip tests with parameterized JSON fixtures in expressions/lambda/. Adds arithmetic body test to isthmus LambdaExpressionTest. --- .../substrait/expression/LambdaBuilder.java | 39 +++ .../proto/ProtoExpressionConverter.java | 56 +--- .../expression/LambdaBuilderTest.java | 51 ++- .../proto/LambdaExpressionRoundtripTest.java | 313 ++---------------- .../lambda/invalid/nested_steps_out.json | 30 ++ .../expressions/lambda/invalid/steps_out.json | 21 ++ .../expressions/lambda/valid/identity.json | 21 ++ .../lambda/valid/literal_body.json | 14 + .../expressions/lambda/valid/multi_param.json | 23 ++ .../expressions/lambda/valid/nested.json | 30 ++ .../lambda/valid/triple_nested.json | 39 +++ .../expressions/lambda/valid/zero_params.json | 12 + .../isthmus/LambdaExpressionTest.java | 33 ++ 13 files changed, 348 insertions(+), 334 deletions(-) create mode 100644 core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json create mode 100644 core/src/test/resources/expressions/lambda/invalid/steps_out.json create mode 100644 core/src/test/resources/expressions/lambda/valid/identity.json create mode 100644 core/src/test/resources/expressions/lambda/valid/literal_body.json create mode 100644 core/src/test/resources/expressions/lambda/valid/multi_param.json create mode 100644 core/src/test/resources/expressions/lambda/valid/nested.json create mode 100644 core/src/test/resources/expressions/lambda/valid/triple_nested.json create mode 100644 core/src/test/resources/expressions/lambda/valid/zero_params.json diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index b2b89c02d..1197cd1f1 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -54,6 +54,45 @@ public Expression.Lambda lambda(List paramTypes, Function bodyFn) { + pushLambdaContext(params); + try { + Expression body = bodyFn.get(); + return ImmutableExpression.Lambda.builder().parameters(params).body(body).build(); + } finally { + popLambdaContext(); + } + } + + /** + * Resolves the parameter struct for a lambda at the given stepsOut from the current innermost + * scope. Used by internal converters to validate lambda parameter references during + * deserialization. + * + * @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost) + * @return the parameter struct at the target scope level + * @throws IllegalArgumentException if stepsOut exceeds the current nesting depth + */ + public Type.Struct resolveParams(int stepsOut) { + int index = lambdaContext.size() - 1 - stepsOut; + if (index < 0 || index >= lambdaContext.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, lambdaContext.size())); + } + return lambdaContext.get(index); + } + /** * Pushes a lambda's parameters onto the context stack. This makes the parameters available for * validation when building the lambda's body, and allows nested lambda parameter references to diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 426b0f56a..470c13c6a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,7 +6,7 @@ import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; -import io.substrait.expression.ImmutableExpression; +import io.substrait.expression.LambdaBuilder; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -38,7 +38,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; - private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack(); + private final LambdaBuilder lambdaBuilder = new LambdaBuilder(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -83,7 +83,7 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getLambdaParameterReference(); int stepsOut = lambdaParamRef.getStepsOut(); - Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut); + Type.Struct lambdaParameters = lambdaBuilder.resolveParams(stepsOut); // Check for unsupported nested field access if (reference.getDirectReference().getStructField().hasChild()) { @@ -283,7 +283,6 @@ public Type visit(Type.Struct type) throws RuntimeException { case LAMBDA: { - // TODO: Add build-time validation of lambda parameter references during deserialization. io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); Type.Struct parameters = (Type.Struct) @@ -292,16 +291,7 @@ public Type visit(Type.Struct type) throws RuntimeException { .setStruct(protoLambda.getParameters()) .build()); - lambdaParameterStack.push(parameters); - - Expression body; - try { - body = from(protoLambda.getBody()); - } finally { - lambdaParameterStack.pop(); - } - - return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); + return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody())); } // TODO enum. case ENUM: @@ -622,42 +612,4 @@ public Expression.SortField fromSortField(SortField s) { public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } - - /** - * A stack for tracking lambda parameter types during expression parsing. - * - *

When parsing nested lambda expressions, each lambda's parameters are pushed onto this stack. - * Lambda parameter references use "stepsOut" to indicate which enclosing lambda they reference: - * - *

    - *
  • stepsOut=0 refers to the innermost (current) lambda - *
  • stepsOut=1 refers to the next enclosing lambda - *
  • stepsOut=N refers to N levels up - *
- */ - private static class LambdaParameterStack { - private final List stack = new ArrayList<>(); - - void push(Type.Struct parameters) { - stack.add(parameters); - } - - void pop() { - if (stack.isEmpty()) { - throw new IllegalArgumentException("Lambda parameter stack is empty"); - } - stack.remove(stack.size() - 1); - } - - Type.Struct get(int stepsOut) { - int index = stack.size() - 1 - stepsOut; - if (index < 0 || index >= stack.size()) { - throw new IllegalArgumentException( - String.format( - "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", - stepsOut, stack.size())); - } - return stack.get(index); - } - } } diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 7d8f8976b..492c56f30 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -1,32 +1,73 @@ package io.substrait.expression; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.List; import org.junit.jupiter.api.Test; -/** Tests for {@link LambdaBuilder} build-time validation. */ +/** Tests for {@link LambdaBuilder}. */ class LambdaBuilderTest { static final TypeCreator R = TypeCreator.REQUIRED; final LambdaBuilder lb = new LambdaBuilder(); - // (x: i32) -> x[5] — field index 5 is out of bounds (only 1 param) + // (x: i32)@p -> p[0] + @Test + void simpleLambda() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); + + Expression.Lambda expected = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) + .body( + FieldReference.newLambdaParameterReference( + 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 0)) + .build(); + + assertEquals(expected, lambda); + } + + // (x: i32)@outer -> (y: i64)@inner -> outer[0] + @Test + void nestedLambda() { + Expression.Lambda lambda = + lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(0))); + + Expression.Lambda expectedInner = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I64).build()) + .body( + FieldReference.newLambdaParameterReference( + 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 1)) + .build(); + + Expression.Lambda expected = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) + .body(expectedInner) + .build(); + + assertEquals(expected, lambda); + } + + // (x: i32)@p -> p[5] — only 1 param, index 5 is out of bounds @Test void invalidFieldIndex_outOfBounds() { assertThrows( IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(5))); } - // (x: i32) -> x[-1] — negative field index + // (x: i32)@p -> p[-1] — negative index @Test void negativeFieldIndex() { assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); } - // (x: i32) -> (y: i64) -> x[5] — outer field index 5 is out of bounds + // (x: i32)@outer -> (y: i64)@inner -> outer[5] — outer only has 1 param @Test void nestedOuterFieldIndexOutOfBounds() { assertThrows( @@ -34,7 +75,7 @@ void nestedOuterFieldIndexOutOfBounds() { () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5)))); } - // (x: i32) -> (y: i64) -> y[3] — inner field index 3 is out of bounds (only 1 param) + // (x: i32)@outer -> (y: i64)@inner -> inner[3] — inner only has 1 param @Test void nestedInnerFieldIndexOutOfBounds() { assertThrows( diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java index 3b9cfedad..1589c35ba 100644 --- a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -1,298 +1,57 @@ package io.substrait.type.proto; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; +import com.google.protobuf.util.JsonFormat; import io.substrait.TestBase; import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.LambdaBuilder; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.type.Type; -import java.util.List; -import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; class LambdaExpressionRoundtripTest extends TestBase { - final LambdaBuilder lb = new LambdaBuilder(); - - // ==================== Single Lambda Tests ==================== - - // () -> 42 - @Test - void zeroParameterLambda() { - Expression.Lambda lambda = lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42)); - - verifyRoundTrip(lambda); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(0, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.returnType()); - } - - // (x: i32) -> x - @Test - void identityLambda() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); - - verifyRoundTrip(lambda); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(1, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.I32, funcType.returnType()); - - assertInstanceOf(FieldReference.class, lambda.body()); - FieldReference ref = (FieldReference) lambda.body(); - assertTrue(ref.isLambdaParameterReference()); - assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); - } - - // (x: i32, y: i64, z: string) -> z - @Test - void validFieldIndex() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(2)); - - verifyRoundTrip(lambda); - assertEquals(R.STRING, ((Type.Func) lambda.getType()).returnType()); - } - - // (x: i32) -> 42 - @Test - void lambdaWithLiteralBody() { - Expression.Lambda lambda = - lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42)); - - verifyRoundTrip(lambda); - assertInstanceOf(Expression.I32Literal.class, lambda.body()); - } - - // Parameterized: (params...) -> params[fieldIndex], verifying type resolution - @Test - void typeResolution() { - record TestCase(String name, List paramTypes, int fieldIndex, Type expectedType) {} - - List testCases = - List.of( - new TestCase("first param (i32)", List.of(R.I32), 0, R.I32), - new TestCase("second param (i64)", List.of(R.I32, R.I64), 1, R.I64), - new TestCase("third param (string)", List.of(R.I32, R.I64, R.STRING), 2, R.STRING), - new TestCase("float64 param", List.of(R.FP64), 0, R.FP64), - new TestCase("date param", List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); - - for (TestCase tc : testCases) { - Expression.Lambda lambda = lb.lambda(tc.paramTypes, params -> params.ref(tc.fieldIndex)); - - verifyRoundTrip(lambda); - - assertEquals(tc.expectedType, lambda.body().getType(), tc.name + ": body type mismatch"); - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals( - tc.expectedType, funcType.returnType(), tc.name + ": lambda return type mismatch"); - } - } - - // (x: i32, y: string) -> y — verify full Func type structure - @Test - void lambdaGetTypeReturnsFunc() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.STRING), params -> params.ref(1)); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(2, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.STRING, funcType.parameterTypes().get(1)); - assertEquals(R.STRING, funcType.returnType()); - } - - // (x: i32, y: i64, z: string) -> ... — verify FieldReference metadata for each param - @Test - void parameterReferenceMetadata() { - List paramTypes = List.of(R.I32, R.I64, R.STRING); - - lb.lambda( - paramTypes, - params -> { - for (int i = 0; i < 3; i++) { - FieldReference ref = params.ref(i); - assertTrue(ref.isLambdaParameterReference()); - assertFalse(ref.isOuterReference()); - assertFalse(ref.isSimpleRootReference()); - assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals(paramTypes.get(i), ref.getType()); - } - return params.ref(0); - }); - } - - // ==================== Expression Body Tests ==================== - - // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 - @Test - void nestedLambdaWithArithmeticBody() { - String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; - - Expression.Lambda result = - lb.lambda( - List.of(R.I64), - outer -> - lb.lambda( - List.of(R.I64, R.I64), - inner -> { - // y1 * x - Expression multiply = - sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); - // (y1 * x) + y2 - return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); - })); - - verifyRoundTrip(result); - - // Outer lambda returns a func type - Type.Func outerFuncType = (Type.Func) result.getType(); - assertInstanceOf(Type.Func.class, outerFuncType.returnType()); - - // Inner lambda returns i64 - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - Type.Func innerFuncType = (Type.Func) innerLambda.getType(); - assertEquals(R.I64, innerFuncType.returnType()); - - // Inner body is a scalar function (add) - assertInstanceOf(Expression.ScalarFunctionInvocation.class, innerLambda.body()); - } - - // ==================== Nested Lambda Tests ==================== - - // (x: i64, y: i64) -> (z: i32) -> x - @Test - void nestedLambdaWithOuterRef() { - Expression.Lambda result = - lb.lambda(List.of(R.I64, R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - - verifyRoundTrip(result); - - Expression.Lambda resultInner = (Expression.Lambda) result.body(); - assertEquals(1, resultInner.parameters().fields().size()); - assertEquals(R.I64, resultInner.body().getType()); - } - - // (x: i32, y: i64, z: string) -> (w: fp64) -> z - @Test - void nestedLambdaOuterRefTypeResolution() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32, R.I64, R.STRING), - outer -> lb.lambda(List.of(R.FP64), inner -> outer.ref(2))); - - verifyRoundTrip(result); - - Expression.Lambda resultInner = (Expression.Lambda) result.body(); - assertEquals(R.STRING, ((Type.Func) resultInner.getType()).returnType()); + static Stream validLambdaExpressions() throws IOException { + return listJsonResources("expressions/lambda/valid"); } - // (x: i32) -> (y: i64) -> y - @Test - void nestedLambdaInnerRefOnly() { - Expression.Lambda result = - lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(0))); - - verifyRoundTrip(result); - - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - assertEquals(R.I64, innerLambda.body().getType()); - assertInstanceOf(Type.Func.class, ((Type.Func) result.getType()).returnType()); + static Stream invalidLambdaExpressions() throws IOException { + return listJsonResources("expressions/lambda/invalid"); } - // (x: i32) -> (y: i64) -> (x, y) — body references both outer and inner params - @Test - void nestedLambdaBothInnerAndOuterRefs() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), - inner -> { - FieldReference innerRef = inner.ref(0); - assertEquals(R.I64, innerRef.getType()); - assertEquals(0, innerRef.lambdaParameterReferenceStepsOut().orElse(-1)); - - FieldReference outerRef = outer.ref(0); - assertEquals(R.I32, outerRef.getType()); - assertEquals(1, outerRef.lambdaParameterReferenceStepsOut().orElse(-1)); - - return innerRef; - })); - - verifyRoundTrip(result); + @ParameterizedTest + @MethodSource("validLambdaExpressions") + void validLambdaExpressionRoundtrip(String resourcePath) throws IOException { + Expression deserialized = deserializeExpression(resourcePath); + assertInstanceOf(Expression.Lambda.class, deserialized); + verifyRoundTrip(deserialized); } - // (a: i32, b: string) -> (c: i64, d: fp64) -> b — verify all 4 params resolve correctly - @Test - void nestedLambdaMultiParamCorrectResolution() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32, R.STRING), - outer -> - lb.lambda( - List.of(R.I64, R.FP64), - inner -> { - assertEquals(R.I64, inner.ref(0).getType()); - assertEquals(R.FP64, inner.ref(1).getType()); - assertEquals(R.I32, outer.ref(0).getType()); - assertEquals(R.STRING, outer.ref(1).getType()); - - return outer.ref(1); - })); - - verifyRoundTrip(result); - - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - assertEquals(R.STRING, ((Type.Func) innerLambda.getType()).returnType()); + @ParameterizedTest + @MethodSource("invalidLambdaExpressions") + void invalidLambdaExpressionRejected(String resourcePath) { + assertThrows(Exception.class, () -> deserializeExpression(resourcePath)); } - // (x: i32) -> (y: i64) -> (z: string) -> x - @Test - void tripleNestedLambdaRoundtrip() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), mid -> lb.lambda(List.of(R.STRING), inner -> outer.ref(0)))); - - verifyRoundTrip(result); - - Expression.Lambda l1 = (Expression.Lambda) result.body(); - Expression.Lambda l2 = (Expression.Lambda) l1.body(); - assertEquals(R.I32, l2.body().getType()); + private static Stream listJsonResources(String dirPath) throws IOException { + Path dir = + Paths.get( + LambdaExpressionRoundtripTest.class.getClassLoader().getResource(dirPath).getPath()); + return Files.list(dir) + .filter(p -> p.toString().endsWith(".json")) + .map(p -> dirPath + "/" + p.getFileName().toString()) + .sorted(); } - // (x: i32) -> (y: i64) -> (z: string) -> ... — verify stepsOut is auto-computed at each level - @Test - void tripleNestedLambdaScopeTracking() { - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), - mid -> - lb.lambda( - List.of(R.STRING), - inner -> { - assertEquals(R.STRING, inner.ref(0).getType()); - assertEquals(R.I64, mid.ref(0).getType()); - assertEquals(R.I32, outer.ref(0).getType()); - - assertEquals( - 0, inner.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals(1, mid.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals( - 2, outer.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - - return inner.ref(0); - }))); + private Expression deserializeExpression(String resourcePath) throws IOException { + String json = asString(resourcePath); + io.substrait.proto.Expression.Builder builder = io.substrait.proto.Expression.newBuilder(); + JsonFormat.parser().merge(json, builder); + return protoExpressionConverter.from(builder.build()); } } diff --git a/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json b/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json new file mode 100644 index 000000000..54ec7321b --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json @@ -0,0 +1,30 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 2 + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/invalid/steps_out.json b/core/src/test/resources/expressions/lambda/invalid/steps_out.json new file mode 100644 index 000000000..148173e8b --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/steps_out.json @@ -0,0 +1,21 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 1 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/identity.json b/core/src/test/resources/expressions/lambda/valid/identity.json new file mode 100644 index 000000000..a984334ef --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/identity.json @@ -0,0 +1,21 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/literal_body.json b/core/src/test/resources/expressions/lambda/valid/literal_body.json new file mode 100644 index 000000000..c36e2df5e --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/literal_body.json @@ -0,0 +1,14 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "literal": { + "i32": 42 + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/multi_param.json b/core/src/test/resources/expressions/lambda/valid/multi_param.json new file mode 100644 index 000000000..1c31bc1e7 --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/multi_param.json @@ -0,0 +1,23 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } }, + { "i64": { "nullability": "NULLABILITY_REQUIRED" } }, + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/nested.json b/core/src/test/resources/expressions/lambda/valid/nested.json new file mode 100644 index 000000000..abf716b7a --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/nested.json @@ -0,0 +1,30 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 1 + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/triple_nested.json b/core/src/test/resources/expressions/lambda/valid/triple_nested.json new file mode 100644 index 000000000..e1e692521 --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/triple_nested.json @@ -0,0 +1,39 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 2 + } + } + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/zero_params.json b/core/src/test/resources/expressions/lambda/valid/zero_params.json new file mode 100644 index 000000000..df801c9cd --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/zero_params.json @@ -0,0 +1,12 @@ +{ + "lambda": { + "parameters": { + "types": [] + }, + "body": { + "literal": { + "i32": 42 + } + } + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index 55924081d..0c294ff9f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -1,10 +1,12 @@ package io.substrait.isthmus; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.LambdaBuilder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.relation.Project; import io.substrait.relation.Rel; import java.util.ArrayList; @@ -57,4 +59,35 @@ void nestedLambdaThrowsUnsupportedOperation() { Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } + + // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + @Test + void nestedLambdaWithArithmeticBody() { + String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; + + Expression.Lambda lambda = + lb.lambda( + List.of(R.I64), + outer -> + lb.lambda( + List.of(R.I64, R.I64), + inner -> { + Expression multiply = + sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); + return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); + })); + + // Proto-only roundtrip since Calcite doesn't support nested lambdas + List exprs = new ArrayList<>(); + exprs.add(lambda); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); + + io.substrait.extension.ExtensionCollector collector = + new io.substrait.extension.ExtensionCollector(); + io.substrait.proto.Rel proto = + new io.substrait.relation.RelProtoConverter(collector).toProto(project); + io.substrait.relation.Rel roundTripped = + new io.substrait.relation.ProtoRelConverter(collector, extensions).from(proto); + assertEquals(project, roundTripped); + } } From f172704dce0b71de4f971c55d3f7a653a6d66b5d Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:35:59 -0400 Subject: [PATCH 14/26] docs: fix LambdaBuilder javadoc to use params/outer/inner naming --- .../main/java/io/substrait/expression/LambdaBuilder.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index 1197cd1f1..621a5ef90 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -19,12 +19,12 @@ * LambdaBuilder lb = new LambdaBuilder(); * * // Simple: (x: i32) -> x - * Expression.Lambda simple = lb.lambda(List.of(R.I32), x -> x.ref(0)); + * Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0)); * * // Nested: (x: i32) -> (y: i64) -> add(x, y) - * Expression.Lambda nested = lb.lambda(List.of(R.I32), x -> - * lb.lambda(List.of(R.I64), y -> - * add(x.ref(0), y.ref(0)) + * Expression.Lambda nested = lb.lambda(List.of(R.I32), outer -> + * lb.lambda(List.of(R.I64), inner -> + * add(outer.ref(0), inner.ref(0)) * ) * ); * } From 67c5a8a763af262653cfd72a662b3e9841ab3339 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:45:12 -0400 Subject: [PATCH 15/26] refactor: clarify Scope internals, extract stepsOut() method and document depth-capture mechanism --- .../substrait/expression/LambdaBuilder.java | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index 621a5ef90..8709b6aa6 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -45,8 +45,7 @@ public Expression.Lambda lambda(List paramTypes, Function= lambdaContext.size()) { + int targetDepth = lambdaContext.size() - stepsOut; + if (targetDepth <= 0 || targetDepth > lambdaContext.size()) { throw new IllegalArgumentException( String.format( "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", stepsOut, lambdaContext.size())); } - return lambdaContext.get(index); + return lambdaContext.get(targetDepth - 1); } /** @@ -113,26 +112,39 @@ private void popLambdaContext() { /** * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated * parameter references. + * + *

Each Scope captures the depth of the lambdaContext stack at the time it was created. When + * {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference + * between the current stack depth and the captured depth. This means the same Scope produces + * different stepsOut values depending on the nesting level at the time of the call, which is what + * allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda. */ public class Scope { - private final int index; + private final Type.Struct params; + private final int depth; + + private Scope(Type.Struct params) { + this.params = params; + this.depth = lambdaContext.size(); + } - private Scope(int index) { - this.index = index; + /** + * Computes the number of lambda boundaries between this scope and the current innermost scope. + * This value changes dynamically as nested lambdas are built. + */ + private int stepsOut() { + return lambdaContext.size() - depth; } /** - * Creates a validated reference to a parameter of this lambda. The correct {@code stepsOut} - * value is computed automatically. + * Creates a validated reference to a parameter of this lambda. * * @param paramIndex index of the parameter within this lambda's parameter struct * @return a {@link FieldReference} pointing to the specified parameter * @throws IndexOutOfBoundsException if paramIndex is out of bounds */ public FieldReference ref(int paramIndex) { - int stepsOut = lambdaContext.size() - 1 - index; - return FieldReference.newLambdaParameterReference( - paramIndex, lambdaContext.get(index), stepsOut); + return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut()); } } } From eed9ea92cf79243a54f14486ea398f4e0a9daf6c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:46:49 -0400 Subject: [PATCH 16/26] test: add test verifying stepsOut changes dynamically with nesting depth --- .../expression/LambdaBuilderTest.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 492c56f30..a4c569219 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -54,6 +54,31 @@ void nestedLambda() { assertEquals(expected, lambda); } + // Verify that the same scope handle produces different stepsOut values depending on nesting. + // outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda. + @Test + void scopeStepsOutChangesDynamically() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + lb.lambda( + List.of(R.I32), + outer -> { + FieldReference atTopLevel = outer.ref(0); + assertEquals(0, atTopLevel.lambdaParameterReferenceStepsOut().orElse(-1)); + + lb.lambda( + List.of(R.I64), + inner -> { + FieldReference atNestedLevel = outer.ref(0); + assertEquals(1, atNestedLevel.lambdaParameterReferenceStepsOut().orElse(-1)); + return inner.ref(0); + }); + + return atTopLevel; + }); + } + // (x: i32)@p -> p[5] — only 1 param, index 5 is out of bounds @Test void invalidFieldIndex_outOfBounds() { From c37d527a2fea7eedbdd00af258fa78f7ef64f9dc Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:54:32 -0400 Subject: [PATCH 17/26] test: simplify arithmetic body test to single lambda (x -> x + x) --- .../isthmus/LambdaExpressionTest.java | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index 0c294ff9f..fb33407b8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -1,6 +1,5 @@ package io.substrait.isthmus; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; @@ -60,34 +59,19 @@ void nestedLambdaThrowsUnsupportedOperation() { assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } - // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + // (x: i64)@p -> add(p[0], p[0]) @Test - void nestedLambdaWithArithmeticBody() { + void lambdaWithArithmeticBody() { String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; Expression.Lambda lambda = lb.lambda( List.of(R.I64), - outer -> - lb.lambda( - List.of(R.I64, R.I64), - inner -> { - Expression multiply = - sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); - return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); - })); + params -> sb.scalarFn(ARITH, "add:i64_i64", R.I64, params.ref(0), params.ref(0))); - // Proto-only roundtrip since Calcite doesn't support nested lambdas List exprs = new ArrayList<>(); exprs.add(lambda); Project project = Project.builder().expressions(exprs).input(emptyTable).build(); - - io.substrait.extension.ExtensionCollector collector = - new io.substrait.extension.ExtensionCollector(); - io.substrait.proto.Rel proto = - new io.substrait.relation.RelProtoConverter(collector).toProto(project); - io.substrait.relation.Rel roundTripped = - new io.substrait.relation.ProtoRelConverter(collector, extensions).from(proto); - assertEquals(project, roundTripped); + assertFullRoundTrip(project); } } From d55762905a0a53252b98f5ae506ab974796cdd7c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 18:05:00 -0400 Subject: [PATCH 18/26] fix: remove unused local variables flagged by PMD --- .../test/java/io/substrait/expression/LambdaBuilderTest.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index a4c569219..01303d932 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -58,9 +58,6 @@ void nestedLambda() { // outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda. @Test void scopeStepsOutChangesDynamically() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - lb.lambda( List.of(R.I32), outer -> { From 5c2c99cb8ef057f2854b8e78cd1c7a04fffa1b41 Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Tue, 17 Mar 2026 16:15:50 +0100 Subject: [PATCH 19/26] fix: remove uri mentions left in substrait plans and add all_match and any_match in function mappings --- .../isthmus/expression/FunctionMappings.java | 19 ++++++++++++++++++- .../test/resources/lambdas/basic-lambda.json | 7 ------- .../resources/lambdas/lambda-field-ref.json | 7 ------- .../lambdas/lambda-with-function.json | 12 ------------ 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index b185abf15..43c6921e9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -12,6 +12,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.type.OperandTypes; import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; public class FunctionMappings { // Static list of signature mapping between Calcite SQL operators and Substrait base function @@ -31,6 +32,20 @@ public class FunctionMappings { opBinding -> opBinding.getOperandType(0), OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + /** The any_match:list_func function; returns true if any element matches the predicate. */ + public static final SqlFunction ANY_MATCH = + SqlBasicFunction.create( + "any_match", + opBinding -> opBinding.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + + /** The all_match:list_func function; returns true if all elements match the predicate. */ + public static final SqlFunction ALL_MATCH = + SqlBasicFunction.create( + "all_match", + opBinding -> opBinding.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + public static final ImmutableList SCALAR_SIGS = ImmutableList.builder() .add( @@ -120,7 +135,9 @@ public class FunctionMappings { s(SqlLibraryOperators.PARSE_TIMESTAMP, "strptime_timestamp"), s(SqlLibraryOperators.PARSE_DATE, "strptime_date"), s(TRANSFORM, "transform"), - s(FILTER, "filter")) + s(FILTER, "filter"), + s(ANY_MATCH, "any_match"), + s(ALL_MATCH, "all_match")) .build(); public static final ImmutableList AGGREGATE_SIGS = diff --git a/isthmus/src/test/resources/lambdas/basic-lambda.json b/isthmus/src/test/resources/lambdas/basic-lambda.json index 114e3ad6d..0a542c74e 100644 --- a/isthmus/src/test/resources/lambdas/basic-lambda.json +++ b/isthmus/src/test/resources/lambdas/basic-lambda.json @@ -9,17 +9,10 @@ "urn": "extension:io.substrait:functions_list" } ], - "extensionUris": [ - { - "extensionUriAnchor": 1, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" - } - ], "extensions": [ { "extensionFunction": { "extensionUrnReference": 1, - "extensionUriReference": 1, "functionAnchor": 1, "name": "transform:list_func" } diff --git a/isthmus/src/test/resources/lambdas/lambda-field-ref.json b/isthmus/src/test/resources/lambdas/lambda-field-ref.json index 58c041582..a16657663 100644 --- a/isthmus/src/test/resources/lambdas/lambda-field-ref.json +++ b/isthmus/src/test/resources/lambdas/lambda-field-ref.json @@ -5,17 +5,10 @@ "urn": "extension:io.substrait:functions_list" } ], - "extensionUris": [ - { - "extensionUriAnchor": 1, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" - } - ], "extensions": [ { "extensionFunction": { "extensionUrnReference": 1, - "extensionUriReference": 1, "functionAnchor": 1, "name": "transform:list_func" } diff --git a/isthmus/src/test/resources/lambdas/lambda-with-function.json b/isthmus/src/test/resources/lambdas/lambda-with-function.json index 9c0a1a55b..f11451160 100644 --- a/isthmus/src/test/resources/lambdas/lambda-with-function.json +++ b/isthmus/src/test/resources/lambdas/lambda-with-function.json @@ -9,21 +9,10 @@ "urn": "extension:io.substrait:functions_list" } ], - "extensionUris": [ - { - "extensionUriAnchor": 1, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" - }, - { - "extensionUriAnchor": 2, - "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" - } - ], "extensions": [ { "extensionFunction": { "extensionUrnReference": 1, - "extensionUriReference": 1, "functionAnchor": 1, "name": "multiply:i32_i32" } @@ -31,7 +20,6 @@ { "extensionFunction": { "extensionUrnReference": 2, - "extensionUriReference": 2, "functionAnchor": 2, "name": "transform:list_func" } From 8b81dbbd40bc2139f135766f73f3e4933584b525 Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Wed, 18 Mar 2026 14:33:45 +0100 Subject: [PATCH 20/26] adressing some of @benbellick's comments --- .../java/io/substrait/expression/FieldReference.java | 9 +++++++++ .../main/java/io/substrait/expression/LambdaBuilder.java | 2 +- .../java/io/substrait/expression/LambdaBuilderTest.java | 3 ++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index af99c74d1..a60f7218e 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -37,6 +37,14 @@ public R accept( return visitor.visit(this, context); } + @Value.Check + protected void check() { + if (outerReferenceStepsOut().isPresent() && lambdaParameterReferenceStepsOut().isPresent()) { + throw new IllegalArgumentException( + "FieldReference cannot have both outerReferenceStepsOut and lambdaParameterReferenceStepsOut set"); + } + } + public boolean isSimpleRootReference() { return segments().size() == 1 && !inputExpression().isPresent() @@ -48,6 +56,7 @@ public boolean isOuterReference() { return outerReferenceStepsOut().orElse(0) > 0; } + /** Returns true if this field reference refers to a lambda parameter. */ public boolean isLambdaParameterReference() { return lambdaParameterReferenceStepsOut().isPresent(); } diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index 8709b6aa6..baeef83a8 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -143,7 +143,7 @@ private int stepsOut() { * @return a {@link FieldReference} pointing to the specified parameter * @throws IndexOutOfBoundsException if paramIndex is out of bounds */ - public FieldReference ref(int paramIndex) { + public FieldReference ref(int paramIndex) throws IndexOutOfBoundsException { return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut()); } } diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 01303d932..86ff41ee3 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -86,7 +86,8 @@ void invalidFieldIndex_outOfBounds() { // (x: i32)@p -> p[-1] — negative index @Test void negativeFieldIndex() { - assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); + assertThrows( + IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); } // (x: i32)@outer -> (y: i64)@inner -> outer[5] — outer only has 1 param From aa659c5dbfd85f18699a47c2f8ee2ad2c9c95cc4 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 18 Mar 2026 13:54:54 -0400 Subject: [PATCH 21/26] refactor: reorder newLambdaParameterReference parameters for readability Change parameter order from (paramIndex, lambdaParamsType, stepsOut) to (stepsOut, paramIndex, lambdaParamsType) so the long type parameter is at the end, improving readability at call sites. --- .../src/main/java/io/substrait/expression/FieldReference.java | 2 +- core/src/main/java/io/substrait/expression/LambdaBuilder.java | 2 +- .../substrait/expression/proto/ProtoExpressionConverter.java | 4 ++-- .../test/java/io/substrait/expression/LambdaBuilderTest.java | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index a60f7218e..4a7ddb1a7 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -151,7 +151,7 @@ public static FieldReference newInputRelReference(int index, List rels) { } public static FieldReference newLambdaParameterReference( - int paramIndex, Type.Struct lambdaParamsType, int stepsOut) { + int stepsOut, int paramIndex, Type.Struct lambdaParamsType) { return ImmutableFieldReference.builder() .addSegments(StructField.of(paramIndex)) .type(lambdaParamsType.fields().get(paramIndex)) diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index baeef83a8..1f2d54b7d 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -144,7 +144,7 @@ private int stepsOut() { * @throws IndexOutOfBoundsException if paramIndex is out of bounds */ public FieldReference ref(int paramIndex) throws IndexOutOfBoundsException { - return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut()); + return FieldReference.newLambdaParameterReference(stepsOut(), paramIndex, params); } } } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 470c13c6a..e1093e4ab 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -92,9 +92,9 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc } return FieldReference.newLambdaParameterReference( + stepsOut, reference.getDirectReference().getStructField().getField(), - lambdaParameters, - stepsOut); + lambdaParameters); } case ROOTTYPE_NOT_SET: default: diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 86ff41ee3..e87dc144b 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -25,7 +25,7 @@ void simpleLambda() { .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) .body( FieldReference.newLambdaParameterReference( - 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 0)) + 0, 0, Type.Struct.builder().nullable(false).addFields(R.I32).build())) .build(); assertEquals(expected, lambda); @@ -42,7 +42,7 @@ void nestedLambda() { .parameters(Type.Struct.builder().nullable(false).addFields(R.I64).build()) .body( FieldReference.newLambdaParameterReference( - 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 1)) + 1, 0, Type.Struct.builder().nullable(false).addFields(R.I32).build())) .build(); Expression.Lambda expected = From e3e9f48b2a4a68f78c8b655abeaa8d1dba042222 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 18 Mar 2026 14:16:35 -0400 Subject: [PATCH 22/26] docs: add javadoc to newLambdaParameterReference explaining validation Clarify that this method does not validate stepsOut and that callers should use LambdaBuilder.Scope.ref() for validated lambda construction. --- .../java/io/substrait/expression/FieldReference.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 4a7ddb1a7..00ced72fb 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -150,6 +150,16 @@ public static FieldReference newInputRelReference(int index, List rels) { index, currentOffset)); } + /** + * Creates a field reference to a lambda parameter. This method does not validate that stepsOut is + * correct for any particular lambda nesting context. For validated lambda construction, use + * {@link LambdaBuilder} and {@link LambdaBuilder.Scope#ref(int)}. + * + * @param stepsOut number of lambda scopes to traverse outward (0 = innermost/current lambda) + * @param paramIndex index of the parameter within the lambda's parameter struct + * @param lambdaParamsType the lambda's parameter struct type + * @return a field reference to the specified lambda parameter + */ public static FieldReference newLambdaParameterReference( int stepsOut, int paramIndex, Type.Struct lambdaParamsType) { return ImmutableFieldReference.builder() From b35ad6343344131422de01ae7af9fb20bfb585e0 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 18 Mar 2026 14:26:23 -0400 Subject: [PATCH 23/26] refactor: make newLambdaParameterReference package-private Add LambdaBuilder.newParameterReference() as the public API for creating validated lambda parameter references. This ensures stepsOut is always validated against the current lambda nesting context. ProtoExpressionConverter now uses this method instead of calling FieldReference.newLambdaParameterReference() directly. --- .../io/substrait/expression/FieldReference.java | 12 +----------- .../io/substrait/expression/LambdaBuilder.java | 15 +++++++++++++++ .../proto/ProtoExpressionConverter.java | 10 +++------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 00ced72fb..55f4c0050 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -150,17 +150,7 @@ public static FieldReference newInputRelReference(int index, List rels) { index, currentOffset)); } - /** - * Creates a field reference to a lambda parameter. This method does not validate that stepsOut is - * correct for any particular lambda nesting context. For validated lambda construction, use - * {@link LambdaBuilder} and {@link LambdaBuilder.Scope#ref(int)}. - * - * @param stepsOut number of lambda scopes to traverse outward (0 = innermost/current lambda) - * @param paramIndex index of the parameter within the lambda's parameter struct - * @param lambdaParamsType the lambda's parameter struct type - * @return a field reference to the specified lambda parameter - */ - public static FieldReference newLambdaParameterReference( + static FieldReference newLambdaParameterReference( int stepsOut, int paramIndex, Type.Struct lambdaParamsType) { return ImmutableFieldReference.builder() .addSegments(StructField.of(paramIndex)) diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index 1f2d54b7d..d6ce6ab30 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -92,6 +92,21 @@ public Type.Struct resolveParams(int stepsOut) { return lambdaContext.get(targetDepth - 1); } + /** + * Creates a validated field reference to a lambda parameter. Validates that stepsOut is valid for + * the current lambda nesting context. + * + * @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost) + * @param paramIndex index of the parameter within the target lambda's parameter struct + * @return a field reference to the specified lambda parameter + * @throws IllegalArgumentException if stepsOut exceeds the current nesting depth + * @throws IndexOutOfBoundsException if paramIndex is out of bounds for the target lambda + */ + public FieldReference newParameterReference(int stepsOut, int paramIndex) { + Type.Struct params = resolveParams(stepsOut); + return FieldReference.newLambdaParameterReference(stepsOut, paramIndex, params); + } + /** * Pushes a lambda's parameters onto the context stack. This makes the parameters available for * validation when building the lambda's body, and allows nested lambda parameter references to diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index e1093e4ab..54fa1a93c 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -82,19 +82,15 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef = reference.getLambdaParameterReference(); - int stepsOut = lambdaParamRef.getStepsOut(); - Type.Struct lambdaParameters = lambdaBuilder.resolveParams(stepsOut); - // Check for unsupported nested field access if (reference.getDirectReference().getStructField().hasChild()) { throw new UnsupportedOperationException( "Nested field access in lambda parameters is not yet supported"); } - return FieldReference.newLambdaParameterReference( - stepsOut, - reference.getDirectReference().getStructField().getField(), - lambdaParameters); + return lambdaBuilder.newParameterReference( + lambdaParamRef.getStepsOut(), + reference.getDirectReference().getStructField().getField()); } case ROOTTYPE_NOT_SET: default: From 7dc36b9add55cb7beef72eb66fe513e4878468f5 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 18 Mar 2026 14:35:12 -0400 Subject: [PATCH 24/26] refactor: simplify newLambdaParameterReference to take Type directly Instead of passing the full Type.Struct and doing the field lookup internally, pass the already-extracted Type. This simplifies the method and moves the bounds checking to the caller. Also rename parameter to 'knownType' for consistency with other FieldReference factory methods. --- .../main/java/io/substrait/expression/FieldReference.java | 5 ++--- .../main/java/io/substrait/expression/LambdaBuilder.java | 6 ++++-- .../java/io/substrait/expression/LambdaBuilderTest.java | 8 ++------ 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 55f4c0050..322097ab2 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -150,11 +150,10 @@ public static FieldReference newInputRelReference(int index, List rels) { index, currentOffset)); } - static FieldReference newLambdaParameterReference( - int stepsOut, int paramIndex, Type.Struct lambdaParamsType) { + static FieldReference newLambdaParameterReference(int stepsOut, int paramIndex, Type knownType) { return ImmutableFieldReference.builder() .addSegments(StructField.of(paramIndex)) - .type(lambdaParamsType.fields().get(paramIndex)) + .type(knownType) .lambdaParameterReferenceStepsOut(stepsOut) .build(); } diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index d6ce6ab30..e8857dc11 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -104,7 +104,8 @@ public Type.Struct resolveParams(int stepsOut) { */ public FieldReference newParameterReference(int stepsOut, int paramIndex) { Type.Struct params = resolveParams(stepsOut); - return FieldReference.newLambdaParameterReference(stepsOut, paramIndex, params); + Type type = params.fields().get(paramIndex); + return FieldReference.newLambdaParameterReference(stepsOut, paramIndex, type); } /** @@ -159,7 +160,8 @@ private int stepsOut() { * @throws IndexOutOfBoundsException if paramIndex is out of bounds */ public FieldReference ref(int paramIndex) throws IndexOutOfBoundsException { - return FieldReference.newLambdaParameterReference(stepsOut(), paramIndex, params); + Type type = params.fields().get(paramIndex); + return FieldReference.newLambdaParameterReference(stepsOut(), paramIndex, type); } } } diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index e87dc144b..87e3c51e2 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -23,9 +23,7 @@ void simpleLambda() { Expression.Lambda expected = ImmutableExpression.Lambda.builder() .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) - .body( - FieldReference.newLambdaParameterReference( - 0, 0, Type.Struct.builder().nullable(false).addFields(R.I32).build())) + .body(FieldReference.newLambdaParameterReference(0, 0, R.I32)) .build(); assertEquals(expected, lambda); @@ -40,9 +38,7 @@ void nestedLambda() { Expression.Lambda expectedInner = ImmutableExpression.Lambda.builder() .parameters(Type.Struct.builder().nullable(false).addFields(R.I64).build()) - .body( - FieldReference.newLambdaParameterReference( - 1, 0, Type.Struct.builder().nullable(false).addFields(R.I32).build())) + .body(FieldReference.newLambdaParameterReference(1, 0, R.I32)) .build(); Expression.Lambda expected = From c88459d573de4a5578f1f809a3dd0bd317c01357 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 18 Mar 2026 17:24:13 -0400 Subject: [PATCH 25/26] test: add lambdaWithFunctionCall test to LambdaBuilderTest --- .../expression/LambdaBuilderTest.java | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 87e3c51e2..4e41aa856 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -3,6 +3,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.List; @@ -12,8 +15,11 @@ class LambdaBuilderTest { static final TypeCreator R = TypeCreator.REQUIRED; + static final SimpleExtension.ExtensionCollection EXTENSIONS = + DefaultExtensionCatalog.DEFAULT_COLLECTION; final LambdaBuilder lb = new LambdaBuilder(); + final SubstraitBuilder sb = new SubstraitBuilder(EXTENSIONS); // (x: i32)@p -> p[0] @Test @@ -50,6 +56,23 @@ void nestedLambda() { assertEquals(expected, lambda); } + // (x: i64) -> add(x, x) + // Example of a lambda with a function call in the body + @Test + void lambdaWithFunctionCall() { + String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; + + Expression.Lambda lambda = + lb.lambda( + List.of(R.I64), + params -> sb.scalarFn(ARITH, "add:i64_i64", R.I64, params.ref(0), params.ref(0))); + + Type.Struct expectedParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + assertEquals(expectedParams, lambda.parameters()); + + assertEquals(R.I64, lambda.body().getType()); + } + // Verify that the same scope handle produces different stepsOut values depending on nesting. // outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda. @Test From d34bb58c786e12865648ad55cf04e5420eafba98 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Wed, 18 Mar 2026 17:28:09 -0400 Subject: [PATCH 26/26] test: add invalid proto test for out-of-bounds param index --- .../lambda/invalid/param_index.json | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 core/src/test/resources/expressions/lambda/invalid/param_index.json diff --git a/core/src/test/resources/expressions/lambda/invalid/param_index.json b/core/src/test/resources/expressions/lambda/invalid/param_index.json new file mode 100644 index 000000000..7a5545e9f --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/param_index.json @@ -0,0 +1,21 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +}