From 5ca6ff8bfa94229ac22546fbc41c5ce6d493dabb Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Tue, 24 Feb 2026 15:03:53 +0100 Subject: [PATCH 01/18] feat: add lambda expression support --- .../expression/AbstractExpressionVisitor.java | 26 + .../io/substrait/expression/Expression.java | 47 ++ .../expression/ExpressionVisitor.java | 20 + .../substrait/expression/FieldReference.java | 18 +- .../proto/ExpressionProtoConverter.java | 28 ++ .../proto/ProtoExpressionConverter.java | 59 +++ .../extension/DefaultExtensionCatalog.java | 3 + .../function/ExtendedTypeCreator.java | 2 + .../substrait/function/ParameterizedType.java | 17 + .../function/ParameterizedTypeCreator.java | 10 + .../function/ParameterizedTypeVisitor.java | 7 + .../io/substrait/function/TypeExpression.java | 16 + .../function/TypeExpressionCreator.java | 10 + .../function/TypeExpressionVisitor.java | 7 + .../ExpressionCopyOnWriteVisitor.java | 29 ++ .../substrait/relation/VirtualTableScan.java | 6 + .../io/substrait/type/StringTypeVisitor.java | 9 + .../src/main/java/io/substrait/type/Type.java | 16 + .../java/io/substrait/type/TypeCreator.java | 13 + .../java/io/substrait/type/TypeVisitor.java | 7 + .../type/proto/BaseProtoConverter.java | 10 + .../substrait/type/proto/BaseProtoTypes.java | 2 + .../proto/ParameterizedProtoConverter.java | 7 + .../type/proto/ProtoTypeConverter.java | 7 + .../proto/TypeExpressionProtoVisitor.java | 7 + .../type/proto/TypeProtoConverter.java | 12 + .../proto/LambdaExpressionRoundtripTest.java | 471 ++++++++++++++++++ .../examples/util/ExpressionStringify.java | 12 + .../examples/util/TypeStringify.java | 5 + .../expression/ExpressionRexConverter.java | 33 ++ .../IgnoreNullableAndParameters.java | 10 + .../expression/RexExpressionConverter.java | 23 +- .../isthmus/LambdaExpressionTest.java | 191 +++++++ .../IgnoreNullableAndParameters.scala | 8 + 34 files changed, 1145 insertions(+), 3 deletions(-) create mode 100644 
core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java create mode 100644 isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index e6ebb5782..a61f5d824 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -448,6 +448,32 @@ public O visit(Expression.IfThen expr, C context) throws E { return visitFallback(expr, context); } + /** + * Visits a Lambda expression. + * + * @param expr the Lambda expression + * @param context the visitation context + * @return the visit result + * @throws E if visitation fails + */ + @Override + public O visit(Expression.Lambda expr, C context) throws E { + return visitFallback(expr, context); + } + + /** + * Visits a Lambda expression invocation. + * + * @param expr the Lambda expression invocation + * @param context the visitation context + * @return the visit result + * @throws E if visitation fails + */ + @Override + public O visit(Expression.LambdaInvocation expr, C context) throws E { + return visitFallback(expr, context); + } + /** * Visits a scalar function invocation. 
* diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 6361af3f5..1301c5c5c 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -707,6 +707,31 @@ public R accept( } } + @Value.Immutable + abstract class Lambda implements Expression { + public abstract Type.Struct parameters(); + + public abstract Expression body(); + + @Override + public Type getType() { + List paramTypes = parameters().fields(); + Type returnType = body().getType(); + + return Type.withNullability(false).func(paramTypes, returnType); + } + + public static ImmutableExpression.Lambda.Builder builder() { + return ImmutableExpression.Lambda.builder(); + } + + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + /** * Base interface for user-defined literals. * @@ -902,6 +927,28 @@ public R accept( } } + @Value.Immutable + abstract class LambdaInvocation implements Expression { + public abstract Lambda lambda(); + + public abstract Expression.NestedStruct arguments(); + + @Override + public Type getType() { + return ((Type.Func) lambda().getType()).returnType(); + } + + public static ImmutableExpression.LambdaInvocation.Builder builder() { + return ImmutableExpression.LambdaInvocation.builder(); + } + + @Override + public R accept( + ExpressionVisitor visitor, C context) throws E { + return visitor.visit(this, context); + } + } + @Value.Immutable abstract class ScalarFunctionInvocation implements Expression { public abstract SimpleExtension.ScalarFunctionVariant declaration(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index 7f094b688..e36d92e4f 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ 
b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -311,6 +311,16 @@ public interface ExpressionVisitor outerReferenceStepsOut(); + public abstract Optional lambdaParameterReferenceStepsOut(); + @Override public Type getType() { return type(); @@ -38,13 +40,18 @@ public R accept( public boolean isSimpleRootReference() { return segments().size() == 1 && !inputExpression().isPresent() - && !outerReferenceStepsOut().isPresent(); + && !outerReferenceStepsOut().isPresent() + && !lambdaParameterReferenceStepsOut().isPresent(); } public boolean isOuterReference() { return outerReferenceStepsOut().orElse(0) > 0; } + public boolean isLambdaParameterReference() { + return lambdaParameterReferenceStepsOut().isPresent(); + } + public FieldReference dereferenceStruct(int index) { Type newType = StructFieldFinder.getReferencedType(type(), index); return dereference(newType, StructField.of(index)); @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List rels) { index, currentOffset)); } + public static FieldReference newLambdaParameterReference( + int paramIndex, Type.Struct lambdaParamsType, int stepsOut) { + return ImmutableFieldReference.builder() + .addSegments(StructField.of(paramIndex)) + .type(lambdaParamsType.fields().get(paramIndex)) + .lambdaParameterReferenceStepsOut(stepsOut) + .build(); + } + public interface ReferenceSegment { FieldReference apply(FieldReference reference); diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 498e6eada..45d587d73 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -373,6 +373,18 @@ public Expression visit( }); } + @Override + public Expression visit( + io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context) + 
throws RuntimeException { + return io.substrait.proto.Expression.newBuilder() + .setLambda( + io.substrait.proto.Expression.Lambda.newBuilder() + .setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct()) + .setBody(expr.body().accept(this, context))) + .build(); + } + @Override public Expression visit( io.substrait.expression.Expression.UserDefinedAnyLiteral expr, @@ -468,6 +480,18 @@ public Expression visit( .build(); } + @Override + public Expression visit( + io.substrait.expression.Expression.LambdaInvocation expr, EmptyVisitationContext context) + throws RuntimeException { + return io.substrait.proto.Expression.newBuilder() + .setLambdaInvocation( + io.substrait.proto.Expression.LambdaInvocation.newBuilder() + .setLambda(expr.lambda().accept(this, context).getLambda()) + .setArguments(expr.arguments().accept(this, context).getNested().getStruct())) + .build(); + } + @Override public Expression visit( io.substrait.expression.Expression.ScalarFunctionInvocation expr, @@ -603,6 +627,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { out.setOuterReference( io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder() .setStepsOut(expr.outerReferenceStepsOut().get())); + } else if (expr.lambdaParameterReferenceStepsOut().isPresent()) { + out.setLambdaParameterReference( + io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder() + .setStepsOut(expr.lambdaParameterReferenceStepsOut().get())); } else { out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 01e25c907..fd00e55b2 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -37,6 +37,7 @@ public 
class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; + private final List lambdaParameterStack = new ArrayList<>(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -75,6 +76,26 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getDirectReference().getStructField().getField(), rootType, reference.getOuterReference().getStepsOut()); + case LAMBDA_PARAMETER_REFERENCE: + { + io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef = + reference.getLambdaParameterReference(); + + int stepsOut = lambdaParamRef.getStepsOut(); + int lambdaIndex = lambdaParameterStack.size() - 1 - stepsOut; + if (lambdaIndex < 0 || lambdaIndex >= lambdaParameterStack.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, lambdaParameterStack.size())); + } + + Type.Struct lambdaParameters = lambdaParameterStack.get(lambdaIndex); + return FieldReference.newLambdaParameterReference( + reference.getDirectReference().getStructField().getField(), + lambdaParameters, + stepsOut); + } case ROOTTYPE_NOT_SET: default: throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase()); @@ -260,6 +281,44 @@ public Type visit(Type.Struct type) throws RuntimeException { } } + case LAMBDA: + { + io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); + Type.Struct parameters = + (Type.Struct) + protoTypeConverter.from( + io.substrait.proto.Type.newBuilder() + .setStruct(protoLambda.getParameters()) + .build()); + + lambdaParameterStack.add(parameters); + + Expression body; + try { + body = from(protoLambda.getBody()); + } finally { + lambdaParameterStack.remove(lambdaParameterStack.size() - 1); + } + + return 
Expression.Lambda.builder().parameters(parameters).body(body).build(); + } + case LAMBDA_INVOCATION: + { + io.substrait.proto.Expression.LambdaInvocation protoInvocation = + expr.getLambdaInvocation(); + + Expression.Lambda lambda = + (Expression.Lambda) + from( + io.substrait.proto.Expression.newBuilder() + .setLambda(protoInvocation.getLambda()) + .build()); + + Expression.NestedStruct arguments = from(protoInvocation.getArguments()); + + return Expression.LambdaInvocation.builder().lambda(lambda).arguments(arguments).build(); + } + // TODO enum. case ENUM: throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase()); diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 7b316d4be..478ed63e2 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog { /** Extension identifier for set functions. */ public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set"; + /** Extension identifier for list functions. */ + public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list"; + /** Extension identifier for string functions. 
*/ public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string"; diff --git a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java index c3360093d..b802d93e0 100644 --- a/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ExtendedTypeCreator.java @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator { T listE(T type); T mapE(T key, T value); + + T funcE(Iterable parameterTypes, T returnType); } diff --git a/core/src/main/java/io/substrait/function/ParameterizedType.java b/core/src/main/java/io/substrait/function/ParameterizedType.java index e514fb975..f35c21f7a 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedType.java +++ b/core/src/main/java/io/substrait/function/ParameterizedType.java @@ -200,6 +200,23 @@ R accept(final ParameterizedTypeVisitor parameter } } + @Value.Immutable + abstract class Func extends BaseParameterizedType implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract ParameterizedType returnType(); + + public static ImmutableParameterizedType.Func.Builder builder() { + return ImmutableParameterizedType.Func.builder(); + } + + @Override + R accept(final ParameterizedTypeVisitor parameterizedTypeVisitor) + throws E { + return parameterizedTypeVisitor.visit(this); + } + } + @Value.Immutable abstract class ListType extends BaseParameterizedType implements NullableType { public abstract ParameterizedType name(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java index 6b89840f6..af7bda7e1 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeCreator.java @@ -96,6 +96,16 @@ public ParameterizedType listE(ParameterizedType type) { 
return ParameterizedType.ListType.builder().nullable(nullable).name(type).build(); } + @Override + public ParameterizedType funcE( + Iterable parameterTypes, ParameterizedType returnType) { + return ParameterizedType.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + @Override public ParameterizedType mapE(ParameterizedType key, ParameterizedType value) { return ParameterizedType.Map.builder().nullable(nullable).key(key).value(value).build(); diff --git a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java index 9ff42f549..755c99777 100644 --- a/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java +++ b/core/src/main/java/io/substrait/function/ParameterizedTypeVisitor.java @@ -29,6 +29,8 @@ public interface ParameterizedTypeVisitor extends TypeVi R visit(ParameterizedType.StringLiteral stringLiteral) throws E; + R visit(ParameterizedType.Func expr) throws E; + abstract class ParameterizedTypeThrowsVisitor extends TypeVisitor.TypeThrowsVisitor implements ParameterizedTypeVisitor { @@ -100,5 +102,10 @@ public R visit(ParameterizedType.Map expr) throws E { public R visit(ParameterizedType.StringLiteral stringLiteral) throws E { throw t(); } + + @Override + public R visit(ParameterizedType.Func expr) throws E { + throw t(); + } } } diff --git a/core/src/main/java/io/substrait/function/TypeExpression.java b/core/src/main/java/io/substrait/function/TypeExpression.java index a183c1959..345fc0398 100644 --- a/core/src/main/java/io/substrait/function/TypeExpression.java +++ b/core/src/main/java/io/substrait/function/TypeExpression.java @@ -191,6 +191,22 @@ R acceptE(final TypeExpressionVisitor visitor) th } } + @Value.Immutable + abstract class Func extends BaseTypeExpression implements NullableType { + public abstract java.util.List parameterTypes(); + + public abstract TypeExpression 
returnType(); + + public static ImmutableTypeExpression.Func.Builder builder() { + return ImmutableTypeExpression.Func.builder(); + } + + @Override + R acceptE(final TypeExpressionVisitor visitor) throws E { + return visitor.visit(this); + } + } + @Value.Immutable abstract class BinaryOperation extends BaseTypeExpression { public enum OpType { diff --git a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java index b7524911b..9d822ffed 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionCreator.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionCreator.java @@ -82,6 +82,16 @@ public TypeExpression mapE(TypeExpression key, TypeExpression value) { return TypeExpression.Map.builder().nullable(nullable).key(key).value(value).build(); } + @Override + public TypeExpression funcE( + Iterable parameterTypes, TypeExpression returnType) { + return TypeExpression.Func.builder() + .nullable(nullable) + .addAllParameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public static class Assign { String name; TypeExpression expr; diff --git a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java index 31d632c71..44d871337 100644 --- a/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java +++ b/core/src/main/java/io/substrait/function/TypeExpressionVisitor.java @@ -24,6 +24,8 @@ public interface TypeExpressionVisitor R visit(TypeExpression.Map expr) throws E; + R visit(TypeExpression.Func expr) throws E; + R visit(TypeExpression.BinaryOperation expr) throws E; R visit(TypeExpression.NotOperation expr) throws E; @@ -97,6 +99,11 @@ public R visit(TypeExpression.Map expr) throws E { throw t(); } + @Override + public R visit(TypeExpression.Func expr) throws E { + throw t(); + } + @Override public R visit(TypeExpression.BinaryOperation expr) throws E { throw 
t(); diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index cdb72aea8..59ebe1257 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -432,6 +432,35 @@ public Optional visit( .build()); } + @Override + public Optional visit(Expression.Lambda lambda, EmptyVisitationContext context) + throws E { + Optional newBody = lambda.body().accept(this, context); + + if (allEmpty(newBody)) { + return Optional.empty(); + } + return Optional.of( + Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); + } + + @Override + public Optional visit( + Expression.LambdaInvocation expr, EmptyVisitationContext context) throws E { + Optional lambda = expr.lambda().accept(this, context); + Optional arguments = expr.arguments().accept(this, context); + + if (allEmpty(lambda, arguments)) { + return Optional.empty(); + } + return Optional.of( + Expression.LambdaInvocation.builder() + .from(expr) + .lambda((Expression.Lambda) lambda.orElse(expr.lambda())) + .arguments((Expression.NestedStruct) arguments.orElse(expr.arguments())) + .build()); + } + // utilities protected Optional> visitExprList( diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index e71b0b00c..de7c29dbb 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -242,5 +242,11 @@ public Integer visit(Type.Map type) throws RuntimeException { public Integer visit(Type.UserDefined type) throws RuntimeException { return 0; } + + @Override + public Integer visit(Type.Func type) throws RuntimeException { + return type.parameterTypes().stream().mapToInt(p -> p.accept(this)).sum() + + 
type.returnType().accept(this); + } } } diff --git a/core/src/main/java/io/substrait/type/StringTypeVisitor.java b/core/src/main/java/io/substrait/type/StringTypeVisitor.java index d7c196148..e9d711d96 100644 --- a/core/src/main/java/io/substrait/type/StringTypeVisitor.java +++ b/core/src/main/java/io/substrait/type/StringTypeVisitor.java @@ -150,4 +150,13 @@ public String visit(Type.Map type) throws RuntimeException { public String visit(Type.UserDefined type) throws RuntimeException { return String.format("u!%s%s", type.name(), n(type)); } + + @Override + public String visit(Type.Func type) throws RuntimeException { + return String.format( + "func%s<%s -> %s>", + n(type), + type.parameterTypes().stream().map(t -> t.accept(this)).collect(Collectors.joining(", ")), + type.returnType().accept(this)); + } } diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index 5a1594e59..685cbe25a 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -352,6 +352,22 @@ public R accept(final TypeVisitor typeVisitor) th } } + @Value.Immutable + abstract class Func implements Type { + public abstract java.util.List parameterTypes(); + + public abstract Type returnType(); + + public static ImmutableType.Func.Builder builder() { + return ImmutableType.Func.builder(); + } + + @Override + public R accept(TypeVisitor typeVisitor) throws E { + return typeVisitor.visit(this); + } + } + @Value.Immutable abstract class Struct implements Type { public abstract java.util.List fields(); diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 43358e505..7e4b1eec4 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -89,6 +89,14 @@ public final Type intervalCompound(int precision) { return 
Type.IntervalCompound.builder().nullable(nullable).precision(precision).build(); } + public Type.Func func(java.util.List parameterTypes, Type returnType) { + return Type.Func.builder() + .nullable(nullable) + .parameterTypes(parameterTypes) + .returnType(returnType) + .build(); + } + public Type.Struct struct(Iterable types) { return Type.Struct.builder().nullable(nullable).addAllFields(types).build(); } @@ -258,6 +266,11 @@ public Type visit(Type.Struct type) throws RuntimeException { return Type.Struct.builder().from(type).nullable(nullability).build(); } + @Override + public Type visit(Type.Func type) throws RuntimeException { + return Type.Func.builder().from(type).nullable(nullability).build(); + } + @Override public Type visit(Type.ListType type) throws RuntimeException { return Type.ListType.builder().from(type).nullable(nullability).build(); diff --git a/core/src/main/java/io/substrait/type/TypeVisitor.java b/core/src/main/java/io/substrait/type/TypeVisitor.java index ce6a08910..62e760175 100644 --- a/core/src/main/java/io/substrait/type/TypeVisitor.java +++ b/core/src/main/java/io/substrait/type/TypeVisitor.java @@ -51,6 +51,8 @@ public interface TypeVisitor { R visit(Type.Decimal type) throws E; + R visit(Type.Func type) throws E; + R visit(Type.Struct type) throws E; R visit(Type.ListType type) throws E; @@ -191,6 +193,11 @@ public R visit(Type.PrecisionTimestampTZ type) throws E { throw t(); } + @Override + public R visit(Type.Func type) throws E { + throw t(); + } + @Override public R visit(Type.Struct type) throws E { throw t(); diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 67d7bc9b5..3909d4033 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -142,6 +142,16 @@ public final T visit(final Type.PrecisionTimestampTZ expr) { return 
typeContainer(expr).precisionTimestampTZ(expr.precision()); } + @Override + public final T visit(final Type.Func expr) { + return typeContainer(expr) + .func( + expr.parameterTypes().stream() + .map(t -> t.accept(this)) + .collect(java.util.stream.Collectors.toList()), + expr.returnType().accept(this)); + } + @Override public final T visit(final Type.Struct expr) { return typeContainer(expr) diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 57b1f26b5..47842382f 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -119,6 +119,8 @@ public final T precisionTimestampTZ(int precision) { public abstract T intervalCompound(I precision); + public abstract T func(Iterable parameterTypes, T returnType); + public final T struct(T... types) { return struct(Arrays.asList(types)); } diff --git a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java index 817f7b0b3..f4428ec47 100644 --- a/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ParameterizedProtoConverter.java @@ -256,6 +256,13 @@ public ParameterizedType map(ParameterizedType key, ParameterizedType value) { .build()); } + @Override + public ParameterizedType func( + Iterable parameterTypes, ParameterizedType returnType) { + throw new UnsupportedOperationException( + "Function types are not supported in Parameterized Types - ParameterizedFunc does not exist in proto"); + } + @Override public ParameterizedType userDefined(int ref) { throw new UnsupportedOperationException( diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index bdb600c1c..24231aebc 100644 --- 
a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -77,6 +77,13 @@ public Type from(io.substrait.proto.Type type) { case PRECISION_TIMESTAMP_TZ: return n(type.getPrecisionTimestampTz().getNullability()) .precisionTimestampTZ(type.getPrecisionTimestampTz().getPrecision()); + case FUNC: + return n(type.getFunc().getNullability()) + .func( + type.getFunc().getParameterTypesList().stream() + .map(this::from) + .collect(java.util.stream.Collectors.toList()), + from(type.getFunc().getReturnType())); case STRUCT: return n(type.getStruct().getNullability()) .struct( diff --git a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java index f9d2129d2..8c92789cb 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java +++ b/core/src/main/java/io/substrait/type/proto/TypeExpressionProtoVisitor.java @@ -317,6 +317,13 @@ public DerivationExpression intervalCompound(DerivationExpression precision) { .build()); } + @Override + public DerivationExpression func( + Iterable parameterTypes, DerivationExpression returnType) { + throw new UnsupportedOperationException( + "Function types are not supported in Derivation Expressions for now"); + } + @Override public DerivationExpression struct(Iterable types) { return wrap( diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 6422904c4..c0e785db4 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -154,6 +154,16 @@ public Type precisionTimestampTZ(Integer precision) { .build()); } + @Override + public Type func(Iterable parameterTypes, Type returnType) { + return wrap( + Type.Func.newBuilder() + 
.addAllParameterTypes(parameterTypes) + .setReturnType(returnType) + .setNullability(nullability) + .build()); + } + @Override public Type struct(Iterable types) { return wrap(Type.Struct.newBuilder().addAllTypes(types).setNullability(nullability).build()); @@ -237,6 +247,8 @@ protected Type wrap(final Object o) { return bldr.setPrecisionTimestamp((Type.PrecisionTimestamp) o).build(); } else if (o instanceof Type.PrecisionTimestampTZ) { return bldr.setPrecisionTimestampTz((Type.PrecisionTimestampTZ) o).build(); + } else if (o instanceof Type.Func) { + return bldr.setFunc((Type.Func) o).build(); } else if (o instanceof Type.Struct) { return bldr.setStruct((Type.Struct) o).build(); } else if (o instanceof Type.List) { diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java new file mode 100644 index 000000000..537b7aaf0 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -0,0 +1,471 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** + * Tests for Lambda expression round-trip conversion through protobuf. + * Based on equivalent tests from substrait-go. + */ +class LambdaExpressionRoundtripTest extends TestBase { + + /** + * Test that lambdas with no parameters are valid. 
+ * Building: () -> i32(42) : func<() -> i32> + */ + @Test + void zeroParameterLambda() { + Type.Struct emptyParams = Type.Struct.builder() + .nullable(false) + .build(); + + Expression body = ExpressionCreator.i32(false, 42); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(emptyParams) + .body(body) + .build(); + + verifyRoundTrip(lambda); + + // Verify the lambda type + Type lambdaType = lambda.getType(); + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + assertEquals(0, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.returnType()); + } + + /** + * Test valid stepsOut=0 references. + * Building: ($0: i32) -> $0 : func<i32 -> i32> + */ + @Test + void validStepsOut0() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Lambda body references parameter 0 with stepsOut=0 + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + verifyRoundTrip(lambda); + + // Verify types + Type lambdaType = lambda.getType(); + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + assertEquals(1, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.I32, funcType.returnType()); + } + + /** + * Test valid field index with multiple parameters. 
+ * Building: ($0: i32, $1: i64, $2: string) -> $2 : func<(i32, i64, string) -> string> + */ + @Test + void validFieldIndex() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32, R.I64, R.STRING) + .build(); + + // Reference the 3rd parameter (string) + FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + verifyRoundTrip(lambda); + + // Verify return type is string + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(R.STRING, funcType.returnType()); + } + + /** + * Test type resolution for different parameter types. + */ + @Test + void typeResolution() { + // Test cases: (paramTypes, fieldIndex, expectedReturnType) + record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} + + List testCases = List.of( + new TestCase(List.of(R.I32), 0, R.I32), + new TestCase(List.of(R.I32, R.I64), 1, R.I64), + new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase(List.of(R.FP64), 0, R.FP64), + new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE) + ); + + for (TestCase tc : testCases) { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addAllFields(tc.paramTypes) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + verifyRoundTrip(lambda); + + // Verify the body type matches expected + assertEquals(tc.expectedType, lambda.body().getType(), + "Body type should match referenced parameter type"); + + // Verify lambda return type + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(tc.expectedType, funcType.returnType(), + "Lambda return type should match body type"); + } + } + + /** + * Test nested lambda with outer reference. 
+ * Building: ($0: i64, $1: i64) -> (($0: i32) -> outer[$0] : i64) : func<(i64, i64) -> func i64>> + */ + @Test + void nestedLambdaWithOuterRef() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I64, R.I64) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Inner lambda references outer's parameter 0 with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = Expression.Lambda.builder() + .parameters(innerParams) + .body(outerRef) + .build(); + + Expression.Lambda outerLambda = Expression.Lambda.builder() + .parameters(outerParams) + .body(innerLambda) + .build(); + + verifyRoundTrip(outerLambda); + + // Verify structure + assertInstanceOf(Expression.Lambda.class, outerLambda.body()); + Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); + assertEquals(1, resultInner.parameters().fields().size()); + } + + /** + * Test outer reference type resolution in nested lambdas. 
+ * Building: ($0: i32, $1: i64, $2: string) -> (($0: fp64) -> outer[$2] : string) : func<...> + */ + @Test + void outerRefTypeResolution() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32, R.I64, R.STRING) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.FP64) + .build(); + + // Inner references outer's field 2 (string) with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); + + Expression.Lambda innerLambda = Expression.Lambda.builder() + .parameters(innerParams) + .body(outerRef) + .build(); + + Expression.Lambda outerLambda = Expression.Lambda.builder() + .parameters(outerParams) + .body(innerLambda) + .build(); + + verifyRoundTrip(outerLambda); + + // Verify inner lambda's return type is string (from outer param 2) + Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); + Type.Func innerFuncType = (Type.Func) resultInner.getType(); + assertEquals(R.STRING, innerFuncType.returnType(), + "Inner lambda return type should be string from outer.$2"); + + // Verify body's type is also string + assertEquals(R.STRING, resultInner.body().getType(), + "Body type should be string"); + } + + /** + * Test deeply nested field ref inside Cast. 
+ * Building: ($0: i32) -> cast($0 as i64) : func i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(castExpr) + .build(); + + verifyRoundTrip(lambda); + + // Verify the nested FieldRef has its type resolved + Expression.Cast resultCast = (Expression.Cast) lambda.body(); + assertInstanceOf(FieldReference.class, resultCast.input()); + FieldReference resultFieldRef = (FieldReference) resultCast.input(); + + assertNotNull(resultFieldRef.getType(), "Nested FieldRef should have type resolved"); + assertEquals(R.I32, resultFieldRef.getType(), "Should resolve to i32"); + + // Verify lambda return type is i64 (cast output) + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(R.I64, funcType.returnType()); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). 
+ * Building: ($0: i32) -> cast(cast($0 as i64) as string) : func string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(outerCast) + .build(); + + verifyRoundTrip(lambda); + + // Navigate to the deeply nested FieldRef (2 levels deep) + Expression.Cast resultOuter = (Expression.Cast) lambda.body(); + Expression.Cast resultInner = (Expression.Cast) resultOuter.input(); + FieldReference resultFieldRef = (FieldReference) resultInner.input(); + + // Verify type is resolved even at depth 2 + assertNotNull(resultFieldRef.getType(), "FieldRef at depth 2 should have type resolved"); + assertEquals(R.I32, resultFieldRef.getType()); + } + + /** + * Test lambda with literal body (no parameter references). + * Building: ($0: i32) -> 42 : func i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + Expression body = ExpressionCreator.i32(false, 42); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(body) + .build(); + + verifyRoundTrip(lambda); + } + + /** + * Test lambda getType returns correct Func type. 
+ */ + @Test + void lambdaGetTypeReturnsFunc() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32, R.STRING) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + Type lambdaType = lambda.getType(); + + assertInstanceOf(Type.Func.class, lambdaType); + Type.Func funcType = (Type.Func) lambdaType; + + assertEquals(2, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.STRING, funcType.parameterTypes().get(1)); + assertEquals(R.STRING, funcType.returnType()); // body references param 1 which is STRING + } + + // ==================== Validation Error Tests ==================== + + /** + * Test that invalid outer reference (stepsOut too high) fails during proto conversion. + * Building: ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) + */ + @Test + void invalidOuterRef_stepsOutTooHigh() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Create a parameter reference with stepsOut=1 but no outer lambda exists + FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(invalidRef) + .build(); + + // Convert to proto - this should work + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); + + // Converting back should fail because stepsOut=1 references non-existent outer lambda + assertThrows(IllegalArgumentException.class, () -> { + protoExpressionConverter.from(protoExpression); + }, "Should fail when stepsOut references non-existent outer lambda"); + } + + /** + * Test that invalid field index (out of bounds) fails during proto conversion. 
+ * Building: ($0: i32) -> $5 : INVALID (only has 1 param) + */ + @Test + void invalidFieldIndex_outOfBounds() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Create a reference to field 5, but lambda only has 1 parameter (index 0) + // This will fail at build time since newLambdaParameterReference accesses fields.get(5) + assertThrows(IndexOutOfBoundsException.class, () -> { + FieldReference.newLambdaParameterReference(5, params, 0); + }, "Should fail when field index is out of bounds"); + } + + /** + * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). + * Building: ($0: i64) -> (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) + */ + @Test + void nestedInvalidOuterRef() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I64) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Inner lambda references stepsOut=2, but only 1 outer lambda exists + FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); + + Expression.Lambda innerLambda = Expression.Lambda.builder() + .parameters(innerParams) + .body(invalidRef) + .build(); + + Expression.Lambda outerLambda = Expression.Lambda.builder() + .parameters(outerParams) + .body(innerLambda) + .build(); + + // Convert to proto + io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); + + // Converting back should fail because stepsOut=2 references non-existent grandparent + assertThrows(IllegalArgumentException.class, () -> { + protoExpressionConverter.from(protoExpression); + }, "Should fail when stepsOut references non-existent grandparent lambda"); + } + + /** + * Test deeply nested invalid field ref inside Cast. 
+ * Building: ($0: i32) -> cast($5 as i64) : INVALID (only has 1 param) + */ + @Test + void deeplyNestedInvalidFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // This should fail at build time since field 5 doesn't exist + assertThrows(IndexOutOfBoundsException.class, () -> { + FieldReference.newLambdaParameterReference(5, params, 0); + }, "Should fail when nested field index is out of bounds"); + } + + /** + * Test that outer field index out of bounds fails. + * Building: ($0: i64) -> (($0: i32) -> outer[$5]) : INVALID (outer only has 1 param) + */ + @Test + void nestedInvalidOuterFieldIndex() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I64) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // This should fail at build time since outer only has 1 parameter (index 0) + assertThrows(IndexOutOfBoundsException.class, () -> { + FieldReference.newLambdaParameterReference(5, outerParams, 1); + }, "Should fail when outer field index is out of bounds"); + } +} diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index a26ec963e..f9125587d 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -195,6 +195,12 @@ public String visit(Expression.NestedStruct expr, EmptyVisitationContext context return ""; } + @Override + public String visit(Expression.Lambda expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; + } + @Override public String visit(UserDefinedAnyLiteral expr, EmptyVisitationContext context) throws RuntimeException { @@ -217,6 +223,12 @@ public String visit(IfThen expr, 
EmptyVisitationContext context) throws RuntimeE return ""; } + @Override + public String visit(Expression.LambdaInvocation expr, EmptyVisitationContext context) + throws RuntimeException { + return ""; + } + @Override public String visit(ScalarFunctionInvocation expr, EmptyVisitationContext context) throws RuntimeException { diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java index 0e13c3e2e..f5f8d93e9 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/TypeStringify.java @@ -149,6 +149,11 @@ public String visit(Decimal type) throws RuntimeException { return type.getClass().getSimpleName(); } + @Override + public String visit(Type.Func type) throws RuntimeException { + return type.getClass().getSimpleName(); + } + @Override public String visit(Struct type) throws RuntimeException { StringBuffer sb = new StringBuffer(type.getClass().getSimpleName()); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index e06f66e8b..0f31d3d4b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -41,6 +41,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexLambdaRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexSubQuery; @@ -381,6 +382,23 @@ public RexNode visit(Expression.IfThen expr, Context context) throws RuntimeExce return rexBuilder.makeCall(SqlStdOperatorTable.CASE, args); } + @Override 
+ public RexNode visit(Expression.Lambda expr, Context context) throws RuntimeException { + List parameters = + IntStream.range(0, expr.parameters().fields().size()) + .mapToObj( + i -> + new RexLambdaRef( + i, + "p" + i, + typeConverter.toCalcite(typeFactory, expr.parameters().fields().get(i)))) + .collect(Collectors.toList()); + + RexNode body = expr.body().accept(this, context); + + return rexBuilder.makeLambdaCall(body, parameters); + } + @Override public RexNode visit(Switch expr, Context context) throws RuntimeException { RexNode match = expr.match().accept(this, context); @@ -655,6 +673,21 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti } return rexInputRef; + } else if (expr.isLambdaParameterReference()) { + int stepsOut = expr.lambdaParameterReferenceStepsOut().get(); + if (stepsOut != 0) { + throw new UnsupportedOperationException( + "Calcite does not support nested lambdas (stepsOut=" + stepsOut + ")"); + } + + final ReferenceSegment segment = expr.segments().get(0); + if (segment instanceof FieldReference.StructField) { + final FieldReference.StructField field = (FieldReference.StructField) segment; + RelDataType calciteType = typeConverter.toCalcite(typeFactory, expr.getType()); + return new RexLambdaRef(field.offset(), "p" + field.offset(), calciteType); + } else { + throw new IllegalArgumentException("Unhandled type: " + segment); + } } return visitFallback(expr, context); diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java index f8b4be1dd..b56fa4fd3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/IgnoreNullableAndParameters.java @@ -128,6 +128,11 @@ public Boolean visit(Type.Decimal type) { return typeToMatch instanceof Type.Decimal || typeToMatch instanceof 
ParameterizedType.Decimal; } + @Override + public Boolean visit(Type.Func type) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } + @Override public Boolean visit(Type.PrecisionTime type) { return typeToMatch instanceof Type.PrecisionTime @@ -234,4 +239,9 @@ public Boolean visit(ParameterizedType.Map expr) throws RuntimeException { public Boolean visit(ParameterizedType.StringLiteral stringLiteral) throws RuntimeException { return false; } + + @Override + public Boolean visit(ParameterizedType.Func expr) throws RuntimeException { + return typeToMatch instanceof Type.Func || typeToMatch instanceof ParameterizedType.Func; + } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 6993c8451..19b3d2224 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -7,6 +7,7 @@ import io.substrait.isthmus.TypeConverter; import io.substrait.relation.Rel; import io.substrait.type.StringTypeVisitor; +import io.substrait.type.Type; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -202,12 +203,30 @@ public Expression visitPatternFieldRef(RexPatternFieldRef fieldRef) { @Override public Expression visitLambda(RexLambda rexLambda) { - throw new UnsupportedOperationException("RexLambda not supported"); + List paramTypes = + rexLambda.getParameters().stream() + .map(param -> typeConverter.toSubstrait(param.getType())) + .collect(Collectors.toList()); + + Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + + + Expression body = rexLambda.getExpression().accept(this); + + return Expression.Lambda.builder().parameters(parameters).body(body).build(); } @Override public Expression 
visitLambdaRef(RexLambdaRef rexLambdaRef) { - throw new UnsupportedOperationException("RexLambdaRef not supported"); + int fieldIndex = rexLambdaRef.getIndex(); + Type paramType = typeConverter.toSubstrait(rexLambdaRef.getType()); + + return FieldReference.builder() + .addSegments(FieldReference.StructField.of(fieldIndex)) + .type(paramType) + .lambdaParameterReferenceStepsOut( + 0) // Always 0 since Calcite doesn't support nested Lambda expressions + .build(); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java new file mode 100644 index 000000000..8b391a01d --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -0,0 +1,191 @@ +package io.substrait.isthmus; + + +import io.substrait.expression.Expression; +import io.substrait.expression.ExpressionCreator; +import io.substrait.expression.FieldReference; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import io.substrait.type.Type; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Tests for Lambda expression conversion between Substrait and Calcite. + * Note: Calcite does not support nested lambda expressions for the moment, so all tests use stepsOut=0. + */ + +public class LambdaExpressionTest extends PlanTestBase{ + + final Rel emptyTable = sb.emptyVirtualTableScan(); + + /** + * Test that lambdas with no parameters are valid. 
+ * Building: () -> i32(42) : func<() -> i32> + */ + @Test + void lambdaExpressionZeroParameters(){ + Type.Struct params = Type.Struct.builder() + .nullable(false) + .build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(body) + .build(); + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + + } + + /** + * Test valid field index with multiple parameters. + * Building: ($0: i32, $1: i64, $2: string) -> $2 : func<(i32, i64, string) -> string> + */ + @Test + void validFieldIndex() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32, R.I64, R.STRING) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(paramRef) + .build(); + + expressionList.add(lambda); + + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test deeply nested field ref inside Cast. 
+ * Building: ($0: i32) -> cast($0 as i64) : func i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + List expressionList = new ArrayList<>(); + + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(castExpr) + .build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). + * Building: ($0: i32) -> cast(cast($0 as i64) as string) : func string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(outerCast) + .build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test lambda with literal body (no parameter references). 
+ * Building: ($0: i32) -> 42 : func i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = Expression.Lambda.builder() + .parameters(params) + .body(body) + .build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. + * Calcite does not support nested lambda expressions. + */ + @Test + void nestedLambdaThrowsUnsupportedOperation() { + Type.Struct outerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I64) + .build(); + + Type.Struct innerParams = Type.Struct.builder() + .nullable(false) + .addFields(R.I32) + .build(); + + // Inner lambda references outer's parameter with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = Expression.Lambda.builder() + .parameters(innerParams) + .body(outerRef) + .build(); + + List expressionList = new ArrayList<>(); + + Expression.Lambda outerLambda = Expression.Lambda.builder() + .parameters(outerParams) + .body(innerLambda) + .build(); + + expressionList.add(outerLambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); + + } + } diff --git a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala index 9cb38b8d0..acf006215 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala +++ 
b/spark/src/main/scala/io/substrait/spark/expression/IgnoreNullableAndParameters.scala @@ -157,4 +157,12 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType) @throws[RuntimeException] override def visit(precisionTimestampTZ: Type.PrecisionTimestampTZ): Boolean = typeToMatch.isInstanceOf[Type.PrecisionTimestampTZ] + + @throws[RuntimeException] + override def visit(`type`: Type.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] + + @throws[RuntimeException] + override def visit(expr: ParameterizedType.Func): Boolean = + typeToMatch.isInstanceOf[Type.Func] || typeToMatch.isInstanceOf[ParameterizedType.Func] } From a90c6099aad6b50b7178ec42e19f7f5a8421a518 Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Wed, 25 Feb 2026 12:06:54 +0100 Subject: [PATCH 02/18] adresse some of @benbellick's comments --- .../expression/AbstractExpressionVisitor.java | 13 - .../io/substrait/expression/Expression.java | 22 -- .../expression/ExpressionVisitor.java | 10 - .../proto/ExpressionProtoConverter.java | 12 - .../proto/ProtoExpressionConverter.java | 17 - .../ExpressionCopyOnWriteVisitor.java | 17 - .../proto/LambdaExpressionRoundtripTest.java | 344 ++++++------------ .../examples/util/ExpressionStringify.java | 6 - .../expression/ExpressionRexConverter.java | 2 + .../expression/RexExpressionConverter.java | 8 +- .../isthmus/LambdaExpressionTest.java | 312 +++++++--------- 11 files changed, 258 insertions(+), 505 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java index a61f5d824..9b1fdc496 100644 --- a/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/AbstractExpressionVisitor.java @@ -461,19 +461,6 @@ public O visit(Expression.Lambda expr, C context) throws E { return visitFallback(expr, context); } - /** - * 
Visits a Lambda expression invocation. - * - * @param expr the Lambda expression invocation - * @param context the visitation context - * @return the visit result - * @throws E if visitation fails - */ - @Override - public O visit(Expression.LambdaInvocation expr, C context) throws E { - return visitFallback(expr, context); - } - /** * Visits a scalar function invocation. * diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 1301c5c5c..320a8a614 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -927,28 +927,6 @@ public R accept( } } - @Value.Immutable - abstract class LambdaInvocation implements Expression { - public abstract Lambda lambda(); - - public abstract Expression.NestedStruct arguments(); - - @Override - public Type getType() { - return ((Type.Func) lambda().getType()).returnType(); - } - - public static ImmutableExpression.LambdaInvocation.Builder builder() { - return ImmutableExpression.LambdaInvocation.builder(); - } - - @Override - public R accept( - ExpressionVisitor visitor, C context) throws E { - return visitor.visit(this, context); - } - } - @Value.Immutable abstract class ScalarFunctionInvocation implements Expression { public abstract SimpleExtension.ScalarFunctionVariant declaration(); diff --git a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java index e36d92e4f..a505af778 100644 --- a/core/src/main/java/io/substrait/expression/ExpressionVisitor.java +++ b/core/src/main/java/io/substrait/expression/ExpressionVisitor.java @@ -361,16 +361,6 @@ public interface ExpressionVisitor visit(Expression.Lambda lambda, EmptyVisitationConte Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); } - @Override - public Optional visit( - Expression.LambdaInvocation expr, 
EmptyVisitationContext context) throws E { - Optional lambda = expr.lambda().accept(this, context); - Optional arguments = expr.arguments().accept(this, context); - - if (allEmpty(lambda, arguments)) { - return Optional.empty(); - } - return Optional.of( - Expression.LambdaInvocation.builder() - .from(expr) - .lambda((Expression.Lambda) lambda.orElse(expr.lambda())) - .arguments((Expression.NestedStruct) arguments.orElse(expr.arguments())) - .build()); - } - // utilities protected Optional> visitExprList( diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java index 537b7aaf0..c0a979a94 100644 --- a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -4,39 +4,30 @@ import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; import io.substrait.type.Type; -import io.substrait.type.TypeCreator; import java.util.List; import org.junit.jupiter.api.Test; /** - * Tests for Lambda expression round-trip conversion through protobuf. - * Based on equivalent tests from substrait-go. + * Tests for Lambda expression round-trip conversion through protobuf. Based on equivalent tests + * from substrait-go. */ class LambdaExpressionRoundtripTest extends TestBase { - /** - * Test that lambdas with no parameters are valid. - * Building: () -> i32(42) : func<() -> i32> - */ + /** Test that lambdas with no parameters are valid. 
Building: () -> i32(42) : func<() -> i32> */ @Test void zeroParameterLambda() { - Type.Struct emptyParams = Type.Struct.builder() - .nullable(false) - .build(); + Type.Struct emptyParams = Type.Struct.builder().nullable(false).build(); Expression body = ExpressionCreator.i32(false, 42); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(emptyParams) - .body(body) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(emptyParams).body(body).build(); verifyRoundTrip(lambda); @@ -48,24 +39,16 @@ void zeroParameterLambda() { assertEquals(R.I32, funcType.returnType()); } - /** - * Test valid stepsOut=0 references. - * Building: ($0: i32) -> $0 : func i32> - */ + /** Test valid stepsOut=0 references. Building: ($0: i32) -> $0 : func i32> */ @Test void validStepsOut0() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Lambda body references parameter 0 with stepsOut=0 FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); verifyRoundTrip(lambda); @@ -79,23 +62,19 @@ void validStepsOut0() { } /** - * Test valid field index with multiple parameters. - * Building: ($0: i32, $1: i64, $2: string) -> $2 : func<(i32, i64, string) -> string> + * Test valid field index with multiple parameters. 
Building: ($0: i32, $1: i64, $2: string) -> $2 + * : func<(i32, i64, string) -> string> */ @Test void validFieldIndex() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32, R.I64, R.STRING) - .build(); + Type.Struct params = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); // Reference the 3rd parameter (string) FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); verifyRoundTrip(lambda); @@ -104,76 +83,63 @@ void validFieldIndex() { assertEquals(R.STRING, funcType.returnType()); } - /** - * Test type resolution for different parameter types. - */ + /** Test type resolution for different parameter types. */ @Test void typeResolution() { // Test cases: (paramTypes, fieldIndex, expectedReturnType) record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} - List testCases = List.of( - new TestCase(List.of(R.I32), 0, R.I32), - new TestCase(List.of(R.I32, R.I64), 1, R.I64), - new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), - new TestCase(List.of(R.FP64), 0, R.FP64), - new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE) - ); + List testCases = + List.of( + new TestCase(List.of(R.I32), 0, R.I32), + new TestCase(List.of(R.I32, R.I64), 1, R.I64), + new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase(List.of(R.FP64), 0, R.FP64), + new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); for (TestCase tc : testCases) { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addAllFields(tc.paramTypes) - .build(); + Type.Struct params = + Type.Struct.builder().nullable(false).addAllFields(tc.paramTypes).build(); - FieldReference paramRef = 
FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); + FieldReference paramRef = + FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); verifyRoundTrip(lambda); // Verify the body type matches expected - assertEquals(tc.expectedType, lambda.body().getType(), + assertEquals( + tc.expectedType, + lambda.body().getType(), "Body type should match referenced parameter type"); // Verify lambda return type Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(tc.expectedType, funcType.returnType(), - "Lambda return type should match body type"); + assertEquals( + tc.expectedType, funcType.returnType(), "Lambda return type should match body type"); } } /** - * Test nested lambda with outer reference. - * Building: ($0: i64, $1: i64) -> (($0: i32) -> outer[$0] : i64) : func<(i64, i64) -> func i64>> + * Test nested lambda with outer reference. 
Building: ($0: i64, $1: i64) -> (($0: i32) -> + * outer[$0] : i64) : func<(i64, i64) -> func i64>> */ @Test void nestedLambdaWithOuterRef() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I64, R.I64) - .build(); + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64, R.I64).build(); - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Inner lambda references outer's parameter 0 with stepsOut=1 FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - Expression.Lambda innerLambda = Expression.Lambda.builder() - .parameters(innerParams) - .body(outerRef) - .build(); + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - Expression.Lambda outerLambda = Expression.Lambda.builder() - .parameters(outerParams) - .body(innerLambda) - .build(); + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); verifyRoundTrip(outerLambda); @@ -184,67 +150,55 @@ void nestedLambdaWithOuterRef() { } /** - * Test outer reference type resolution in nested lambdas. - * Building: ($0: i32, $1: i64, $2: string) -> (($0: fp64) -> outer[$2] : string) : func<...> + * Test outer reference type resolution in nested lambdas. 
Building: ($0: i32, $1: i64, $2: + * string) -> (($0: fp64) -> outer[$2] : string) : func<...> */ @Test void outerRefTypeResolution() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32, R.I64, R.STRING) - .build(); + Type.Struct outerParams = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.FP64) - .build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.FP64).build(); // Inner references outer's field 2 (string) with stepsOut=1 FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); - Expression.Lambda innerLambda = Expression.Lambda.builder() - .parameters(innerParams) - .body(outerRef) - .build(); + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - Expression.Lambda outerLambda = Expression.Lambda.builder() - .parameters(outerParams) - .body(innerLambda) - .build(); + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); verifyRoundTrip(outerLambda); // Verify inner lambda's return type is string (from outer param 2) Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); Type.Func innerFuncType = (Type.Func) resultInner.getType(); - assertEquals(R.STRING, innerFuncType.returnType(), + assertEquals( + R.STRING, + innerFuncType.returnType(), "Inner lambda return type should be string from outer.$2"); // Verify body's type is also string - assertEquals(R.STRING, resultInner.body().getType(), - "Body type should be string"); + assertEquals(R.STRING, resultInner.body().getType(), "Body type should be string"); } /** - * Test deeply nested field ref inside Cast. - * Building: ($0: i32) -> cast($0 as i64) : func i64> + * Test deeply nested field ref inside Cast. 
Building: ($0: i32) -> cast($0 as i64) : func + * i64> */ @Test void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast castExpr = (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast castExpr = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(castExpr) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(castExpr).build(); verifyRoundTrip(lambda); @@ -262,26 +216,23 @@ void deeplyNestedFieldRef() { } /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). - * Building: ($0: i32) -> cast(cast($0 as i64) as string) : func string> + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). 
Building: ($0: i32) -> cast(cast($0 + * as i64) as string) : func string> */ @Test void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast innerCast = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = + (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(outerCast) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(outerCast).build(); verifyRoundTrip(lambda); @@ -296,42 +247,29 @@ void doublyNestedFieldRef() { } /** - * Test lambda with literal body (no parameter references). - * Building: ($0: i32) -> 42 : func i32> + * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> */ @Test void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); Expression body = ExpressionCreator.i32(false, 42); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(body) - .build(); + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); verifyRoundTrip(lambda); } - /** - * Test lambda getType returns correct Func type. 
- */ + /** Test lambda getType returns correct Func type. */ @Test void lambdaGetTypeReturnsFunc() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32, R.STRING) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32, R.STRING).build(); FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); Type lambdaType = lambda.getType(); @@ -347,125 +285,77 @@ void lambdaGetTypeReturnsFunc() { // ==================== Validation Error Tests ==================== /** - * Test that invalid outer reference (stepsOut too high) fails during proto conversion. - * Building: ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) + * Test that invalid outer reference (stepsOut too high) fails during proto conversion. 
Building: + * ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) */ @Test void invalidOuterRef_stepsOutTooHigh() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Create a parameter reference with stepsOut=1 but no outer lambda exists FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(invalidRef) - .build(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(invalidRef).build(); // Convert to proto - this should work io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); // Converting back should fail because stepsOut=1 references non-existent outer lambda - assertThrows(IllegalArgumentException.class, () -> { - protoExpressionConverter.from(protoExpression); - }, "Should fail when stepsOut references non-existent outer lambda"); + assertThrows( + IllegalArgumentException.class, + () -> { + protoExpressionConverter.from(protoExpression); + }, + "Should fail when stepsOut references non-existent outer lambda"); } /** - * Test that invalid field index (out of bounds) fails during proto conversion. - * Building: ($0: i32) -> $5 : INVALID (only has 1 param) + * Test that invalid field index (out of bounds) fails during proto conversion. 
Building: ($0: + * i32) -> $5 : INVALID (only has 1 param) */ @Test void invalidFieldIndex_outOfBounds() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Create a reference to field 5, but lambda only has 1 parameter (index 0) // This will fail at build time since newLambdaParameterReference accesses fields.get(5) - assertThrows(IndexOutOfBoundsException.class, () -> { - FieldReference.newLambdaParameterReference(5, params, 0); - }, "Should fail when field index is out of bounds"); + assertThrows( + IndexOutOfBoundsException.class, + () -> { + FieldReference.newLambdaParameterReference(5, params, 0); + }, + "Should fail when field index is out of bounds"); } /** - * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). - * Building: ($0: i64) -> (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) + * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). 
Building: ($0: i64) -> + * (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) */ @Test void nestedInvalidOuterRef() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I64) - .build(); + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); // Inner lambda references stepsOut=2, but only 1 outer lambda exists FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); - Expression.Lambda innerLambda = Expression.Lambda.builder() - .parameters(innerParams) - .body(invalidRef) - .build(); + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(invalidRef).build(); - Expression.Lambda outerLambda = Expression.Lambda.builder() - .parameters(outerParams) - .body(innerLambda) - .build(); + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); // Convert to proto io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); // Converting back should fail because stepsOut=2 references non-existent grandparent - assertThrows(IllegalArgumentException.class, () -> { - protoExpressionConverter.from(protoExpression); - }, "Should fail when stepsOut references non-existent grandparent lambda"); - } - - /** - * Test deeply nested invalid field ref inside Cast. 
- * Building: ($0: i32) -> cast($5 as i64) : INVALID (only has 1 param) - */ - @Test - void deeplyNestedInvalidFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - // This should fail at build time since field 5 doesn't exist - assertThrows(IndexOutOfBoundsException.class, () -> { - FieldReference.newLambdaParameterReference(5, params, 0); - }, "Should fail when nested field index is out of bounds"); - } - - /** - * Test that outer field index out of bounds fails. - * Building: ($0: i64) -> (($0: i32) -> outer[$5]) : INVALID (outer only has 1 param) - */ - @Test - void nestedInvalidOuterFieldIndex() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I64) - .build(); - - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - // This should fail at build time since outer only has 1 parameter (index 0) - assertThrows(IndexOutOfBoundsException.class, () -> { - FieldReference.newLambdaParameterReference(5, outerParams, 1); - }, "Should fail when outer field index is out of bounds"); + assertThrows( + IllegalArgumentException.class, + () -> { + protoExpressionConverter.from(protoExpression); + }, + "Should fail when stepsOut references non-existent grandparent lambda"); } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index f9125587d..93f8a9834 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -223,12 +223,6 @@ public String visit(IfThen expr, EmptyVisitationContext context) throws RuntimeE return ""; } - @Override - public String visit(Expression.LambdaInvocation expr, EmptyVisitationContext context) - throws RuntimeException { - 
return ""; - } - @Override public String visit(ScalarFunctionInvocation expr, EmptyVisitationContext context) throws RuntimeException { diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java index 0f31d3d4b..a118beec0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/ExpressionRexConverter.java @@ -674,6 +674,8 @@ public RexNode visit(FieldReference expr, Context context) throws RuntimeExcepti return rexInputRef; } else if (expr.isLambdaParameterReference()) { + // as of now calcite doesn't support nested lambda functions + // https://github.com/substrait-io/substrait-java/issues/711 int stepsOut = expr.lambdaParameterReferenceStepsOut().get(); if (stepsOut != 0) { throw new UnsupportedOperationException( diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 19b3d2224..176b246e7 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -208,10 +208,9 @@ public Expression visitLambda(RexLambda rexLambda) { .map(param -> typeConverter.toSubstrait(param.getType())) .collect(Collectors.toList()); - Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + Type.Struct parameters = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); - - Expression body = rexLambda.getExpression().accept(this); + Expression body = rexLambda.getExpression().accept(this); return Expression.Lambda.builder().parameters(parameters).body(body).build(); } @@ -225,7 +224,8 @@ public Expression visitLambdaRef(RexLambdaRef rexLambdaRef) { 
.addSegments(FieldReference.StructField.of(fieldIndex)) .type(paramType) .lambdaParameterReferenceStepsOut( - 0) // Always 0 since Calcite doesn't support nested Lambda expressions + 0) // Always 0 since Calcite doesn't support nested Lambda expressions for now + // https://github.com/substrait-io/substrait-java/issues/711 .build(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index 8b391a01d..add746ee1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -1,5 +1,6 @@ package io.substrait.isthmus; +import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; @@ -7,185 +8,142 @@ import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.type.Type; -import org.junit.jupiter.api.Test; - import java.util.ArrayList; import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertThrows; +import org.junit.jupiter.api.Test; /** - * Tests for Lambda expression conversion between Substrait and Calcite. - * Note: Calcite does not support nested lambda expressions for the moment, so all tests use stepsOut=0. + * Tests for Lambda expression conversion between Substrait and Calcite. Note: Calcite does not + * support nested lambda expressions for the moment, so all tests use stepsOut=0. */ - -public class LambdaExpressionTest extends PlanTestBase{ - - final Rel emptyTable = sb.emptyVirtualTableScan(); - - /** - * Test that lambdas with no parameters are valid. 
- * Building: () -> i32(42) : func<() -> i32> - */ - @Test - void lambdaExpressionZeroParameters(){ - Type.Struct params = Type.Struct.builder() - .nullable(false) - .build(); - - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(body) - .build(); - expressionList.add(lambda); - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - - } - - /** - * Test valid field index with multiple parameters. - * Building: ($0: i32, $1: i64, $2: string) -> $2 : func<(i32, i64, string) -> string> - */ - @Test - void validFieldIndex() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32, R.I64, R.STRING) - .build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(paramRef) - .build(); - - expressionList.add(lambda); - - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test deeply nested field ref inside Cast. 
- * Building: ($0: i32) -> cast($0 as i64) : func i64> - */ - @Test - void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - Expression.Cast castExpr = (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - List expressionList = new ArrayList<>(); - - - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(castExpr) - .build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). - * Building: ($0: i32) -> cast(cast($0 as i64) as string) : func string> - */ - @Test - void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); - - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(outerCast) - .build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test lambda with literal body (no parameter references). 
- * Building: ($0: i32) -> 42 : func i32> - */ - @Test - void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = Expression.Lambda.builder() - .parameters(params) - .body(body) - .build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. - * Calcite does not support nested lambda expressions. - */ - @Test - void nestedLambdaThrowsUnsupportedOperation() { - Type.Struct outerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I64) - .build(); - - Type.Struct innerParams = Type.Struct.builder() - .nullable(false) - .addFields(R.I32) - .build(); - - // Inner lambda references outer's parameter with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - - Expression.Lambda innerLambda = Expression.Lambda.builder() - .parameters(innerParams) - .body(outerRef) - .build(); - - List expressionList = new ArrayList<>(); - - Expression.Lambda outerLambda = Expression.Lambda.builder() - .parameters(outerParams) - .body(innerLambda) - .build(); - - expressionList.add(outerLambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); - - } - } +class LambdaExpressionTest extends PlanTestBase { + + final Rel emptyTable = sb.emptyVirtualTableScan(); + + /** Test that lambdas with no parameters are valid. 
Building: () -> i32(42) : func<() -> i32> */ + @Test + void lambdaExpressionZeroParameters() { + Type.Struct params = Type.Struct.builder().nullable(false).build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 + * : func<(i32, i64, string) -> string> + */ + @Test + void validFieldIndex() { + Type.Struct params = + Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + List expressionList = new ArrayList<>(); + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(paramRef).build(); + + expressionList.add(lambda); + + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test deeply nested field ref inside Cast. 
Building: ($0: i32) -> cast($0 as i64) : func + * i64> + */ + @Test + void deeplyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + + Expression.Cast castExpr = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(castExpr).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 + * as i64) as string) : func string> + */ + @Test + void doublyNestedFieldRef() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + Expression.Cast innerCast = + (Expression.Cast) + ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); + Expression.Cast outerCast = + (Expression.Cast) + ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = + Expression.Lambda.builder().parameters(params).body(outerCast).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test lambda with literal body (no parameter references). 
Building: ($0: i32) -> 42 : func i32> + */ + @Test + void lambdaWithLiteralBody() { + Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + Expression body = ExpressionCreator.i32(false, 42); + List expressionList = new ArrayList<>(); + + Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); + + expressionList.add(lambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertFullRoundTrip(project); + } + + /** + * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. Calcite does not + * support nested lambda expressions. + */ + @Test + void nestedLambdaThrowsUnsupportedOperation() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + + // Inner lambda references outer's parameter with stepsOut=1 + FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); + + Expression.Lambda innerLambda = + Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); + + List expressionList = new ArrayList<>(); + + Expression.Lambda outerLambda = + Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + + expressionList.add(outerLambda); + Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); + } +} From 40533570d4ef9b3f1765a315f72adaee5ed98e4a Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Mon, 2 Mar 2026 14:38:33 +0100 Subject: [PATCH 03/18] tweak: encapsulate LambdaParameterStack logic in a class --- .../io/substrait/expression/Expression.java | 2 + .../proto/ProtoExpressionConverter.java | 41 ++++++++++++++----- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git 
a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 320a8a614..00a92d797 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -718,6 +718,8 @@ public Type getType() { List paramTypes = parameters().fields(); Type returnType = body().getType(); + // TO DO: fix Lambda return type once this issue + // https://github.com/substrait-io/substrait/issues/976 is resolved return Type.withNullability(false).func(paramTypes, returnType); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 3780a8e1b..09ee2da59 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -37,7 +37,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; - private final List lambdaParameterStack = new ArrayList<>(); + private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -82,15 +82,8 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getLambdaParameterReference(); int stepsOut = lambdaParamRef.getStepsOut(); - int lambdaIndex = lambdaParameterStack.size() - 1 - stepsOut; - if (lambdaIndex < 0 || lambdaIndex >= lambdaParameterStack.size()) { - throw new IllegalArgumentException( - String.format( - "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", - stepsOut, lambdaParameterStack.size())); - } + Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut); - Type.Struct lambdaParameters = 
lambdaParameterStack.get(lambdaIndex); return FieldReference.newLambdaParameterReference( reference.getDirectReference().getStructField().getField(), lambdaParameters, @@ -291,13 +284,13 @@ public Type visit(Type.Struct type) throws RuntimeException { .setStruct(protoLambda.getParameters()) .build()); - lambdaParameterStack.add(parameters); + lambdaParameterStack.push(parameters); Expression body; try { body = from(protoLambda.getBody()); } finally { - lambdaParameterStack.remove(lambdaParameterStack.size() - 1); + lambdaParameterStack.pop(); } return Expression.Lambda.builder().parameters(parameters).body(body).build(); @@ -616,4 +609,30 @@ public Expression.SortField fromSortField(SortField s) { public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } + + private static class LambdaParameterStack { + private final List stack = new ArrayList<>(); + + void push(Type.Struct parameters) { + stack.add(parameters); + } + + void pop() { + if (stack.isEmpty()) { + throw new IllegalArgumentException("Lambda parameter stack is empty"); + } + stack.remove(stack.size() - 1); + } + + Type.Struct get(int stepsOut) { + int index = stack.size() - 1 - stepsOut; + if (index < 0 || index >= stack.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, stack.size())); + } + return stack.get(index); + } + } } From eb7b2b23c8fb6e2f4eadde602c4db8128b0c99ea Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Mon, 2 Mar 2026 16:31:49 +0100 Subject: [PATCH 04/18] tweak: adding comments explaining LambdaParameterStack --- .../expression/proto/ProtoExpressionConverter.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java
b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 09ee2da59..ff864c259 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -610,6 +610,18 @@ public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOptio return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } + /** + * A stack for tracking lambda parameter types during expression parsing. + * + *

<p>When parsing nested lambda expressions, each lambda's parameters are pushed onto this stack. + * Lambda parameter references use "stepsOut" to indicate which enclosing lambda they reference: + * + * <ul>
+ *   <li>stepsOut=0 refers to the innermost (current) lambda
+ *   <li>stepsOut=1 refers to the next enclosing lambda
+ *   <li>stepsOut=N refers to N levels up
+ * </ul>
+ */ private static class LambdaParameterStack { private final List stack = new ArrayList<>(); From 1d1b26ee9def88df3c6f5f2f4c1636e7ec94377f Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 6 Mar 2026 16:04:44 -0800 Subject: [PATCH 05/18] build: ignore build related generated files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 5eab9ab03..4093b00d1 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ out/** .metals .bloop .project +.classpath +.settings +bin/ From 2224c4b5708969500ca6e9a6148cfbc5b11e233f Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Fri, 6 Mar 2026 13:41:28 -0800 Subject: [PATCH 06/18] feat: enable parsing of func types in extensions --- .../extension/DefaultExtensionCatalog.java | 1 + .../io/substrait/function/ToTypeString.java | 10 ++++ .../io/substrait/type/parser/ParseToPojo.java | 49 +++++++++++++++++-- .../DefaultExtensionCatalogTest.java | 13 +++++ .../substrait/type/parser/TestTypeParser.java | 16 ++++++ 5 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index 478ed63e2..d8b6b1fa6 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -78,6 +78,7 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { "arithmetic", "comparison", "datetime", + "list", "logarithmic", "rounding", "rounding_decimal", diff --git a/core/src/main/java/io/substrait/function/ToTypeString.java b/core/src/main/java/io/substrait/function/ToTypeString.java index d6fc1bdb8..5c942f46a 100644 --- a/core/src/main/java/io/substrait/function/ToTypeString.java +++ b/core/src/main/java/io/substrait/function/ToTypeString.java @@ -150,6 
+150,11 @@ public String visit(final Type.Map expr) { return "map"; } + @Override + public String visit(Type.Func type) throws RuntimeException { + return "func"; + } + @Override public String visit(final Type.UserDefined expr) { return String.format("u!%s", expr.name()); @@ -210,6 +215,11 @@ public String visit(ParameterizedType.Map expr) throws RuntimeException { return "map"; } + @Override + public String visit(ParameterizedType.Func expr) throws RuntimeException { + return "func"; + } + @Override public String visit(ParameterizedType.StringLiteral expr) throws RuntimeException { if (expr.value().toLowerCase().startsWith("any")) { diff --git a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java index 8555ab219..ddad886b1 100644 --- a/core/src/main/java/io/substrait/type/parser/ParseToPojo.java +++ b/core/src/main/java/io/substrait/type/parser/ParseToPojo.java @@ -1,5 +1,6 @@ package io.substrait.type.parser; +import io.substrait.function.ImmutableParameterizedType; import io.substrait.function.ImmutableTypeExpression; import io.substrait.function.ParameterizedType; import io.substrait.function.ParameterizedTypeCreator; @@ -411,19 +412,61 @@ public TypeExpression visitNStruct(final SubstraitTypeParser.NStructContext ctx) @Override public TypeExpression visitFunc(final SubstraitTypeParser.FuncContext ctx) { - throw new UnsupportedOperationException(); + boolean nullable = ctx.isnull != null; + + // Process function parameters + List paramExprs; + if (ctx.params instanceof SubstraitTypeParser.SingleFuncParamContext) { + paramExprs = + java.util.Collections.singletonList( + ((SubstraitTypeParser.SingleFuncParamContext) ctx.params).expr().accept(this)); + } else if (ctx.params instanceof SubstraitTypeParser.FuncParamsWithParensContext) { + paramExprs = + ((SubstraitTypeParser.FuncParamsWithParensContext) ctx.params) + .expr().stream() + .map(e -> e.accept(this)) + 
.collect(java.util.stream.Collectors.toList()); + } else { + throw new UnsupportedOperationException( + "Unknown funcParams type: " + ctx.params.getClass()); + } + + // Process return type + TypeExpression returnExpr = ctx.returnType.accept(this); + + // If all types are instances of Type, we return a Type + if (paramExprs.stream().allMatch(p -> p instanceof Type) && returnExpr instanceof Type) { + ImmutableType.Func.Builder builder = ImmutableType.Func.builder().nullable(nullable); + paramExprs.forEach(p -> builder.addParameterTypes((Type) p)); + return builder.returnType((Type) returnExpr).build(); + } + + // If all types are instances of ParameterizedType, we return a ParameterizedType + if (paramExprs.stream().allMatch(p -> p instanceof ParameterizedType) + && returnExpr instanceof ParameterizedType) { + checkParameterizedOrExpression(); + ImmutableParameterizedType.Func.Builder builder = + ParameterizedType.Func.builder().nullable(nullable); + paramExprs.forEach(p -> builder.addParameterTypes((ParameterizedType) p)); + return builder.returnType((ParameterizedType) returnExpr).build(); + } + + throw new UnsupportedOperationException( + "func type with TypeExpression-level parameter or return types are not yet supported"); } @Override public TypeExpression visitSingleFuncParam( final SubstraitTypeParser.SingleFuncParamContext ctx) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException( + "visitSingleFuncParam is handled in visitFunc directly"); } @Override public TypeExpression visitFuncParamsWithParens( final SubstraitTypeParser.FuncParamsWithParensContext ctx) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException( + "visitFuncParamsWithParens is handled in visitFunc directly"); } @Override diff --git a/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java new file mode 100644 index 
000000000..7aed4e961 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java @@ -0,0 +1,13 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import org.junit.jupiter.api.Test; + +class DefaultExtensionCatalogTest { + + @Test + void defaultCollectionLoads() { + assertNotNull(DefaultExtensionCatalog.DEFAULT_COLLECTION); + } +} diff --git a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java index 010ad123a..92989795b 100644 --- a/core/src/test/java/io/substrait/type/parser/TestTypeParser.java +++ b/core/src/test/java/io/substrait/type/parser/TestTypeParser.java @@ -8,6 +8,7 @@ import io.substrait.type.ImmutableType; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.List; import org.junit.jupiter.api.Test; class TestTypeParser { @@ -120,6 +121,11 @@ private void compoundTests(ParseToPojo.Visitor v) { test(v, n.precisionTimestamp(9), "PRECISION_TIMESTAMP?<9>"); test(v, r.precisionTimestampTZ(6), "PRECISION_TIMESTAMP_TZ<6>"); test(v, n.precisionTimestampTZ(9), "PRECISION_TIMESTAMP_TZ?<9>"); + + test(v, r.func(List.of(r.I8), r.I32), "func i32>"); + test(v, r.func(List.of(r.I8, r.I8), r.I32), "func<(i8, i8) -> i32>"); + test(v, n.func(List.of(r.I8), r.I32), "func? i32>"); + test(v, r.func(List.of(n.I8), n.I32), "func i32?>"); } private void parameterizedTests(ParseToPojo.Visitor v) { @@ -142,6 +148,16 @@ private void parameterizedTests(ParseToPojo.Visitor v) { test(v, pr.precisionTimeE("P"), "PRECISION_TIME

"); test(v, pr.precisionTimestampE("P"), "PRECISION_TIMESTAMP

"); test(v, pr.precisionTimestampTZE("P"), "PRECISION_TIMESTAMP_TZ

"); + + test(v, pr.funcE(List.of(pr.parameter("any")), r.I64), "func i64>"); + test(v, pr.funcE(List.of(pr.parameter("any"), r.I64), r.I64), "func<(any, i64) -> i64>"); + test(v, pr.funcE(List.of(pr.parameter("any1")), pr.parameter("any1")), "func any1>"); + test(v, pn.funcE(List.of(pr.parameter("any")), n.I64), "func? i64?>"); + test(v, pn.funcE(List.of(pr.parameter("any1")), pr.parameter("any1")), "func? any1>"); + test( + v, + pr.funcE(List.of(pr.parameter("any1"), r.I8), pr.parameter("any1")), + "func<(any1, i8) -> any1>"); } @Test From 83b9e2480cfe0065ded455ee417abf6a79258eac Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Tue, 10 Mar 2026 14:47:27 -0700 Subject: [PATCH 07/18] test: copy substrait-go lambda plans Note that these had to be tweaked slightly because they were not entirely valid. This will be fixed in substrait-go. --- .../test/resources/lambdas/basic-lambda.json | 132 ++++++++++++++ .../resources/lambdas/lambda-field-ref.json | 135 ++++++++++++++ .../lambdas/lambda-with-function.json | 172 ++++++++++++++++++ 3 files changed, 439 insertions(+) create mode 100644 isthmus/src/test/resources/lambdas/basic-lambda.json create mode 100644 isthmus/src/test/resources/lambdas/lambda-field-ref.json create mode 100644 isthmus/src/test/resources/lambdas/lambda-with-function.json diff --git a/isthmus/src/test/resources/lambdas/basic-lambda.json b/isthmus/src/test/resources/lambdas/basic-lambda.json new file mode 100644 index 000000000..114e3ad6d --- /dev/null +++ b/isthmus/src/test/resources/lambdas/basic-lambda.json @@ -0,0 +1,132 @@ +{ + "version": { + "majorNumber": 0, + "minorNumber": 79 + }, + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 
1, + "functionAnchor": 1, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "literal": { + "i32": 42 + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/isthmus/src/test/resources/lambdas/lambda-field-ref.json b/isthmus/src/test/resources/lambdas/lambda-field-ref.json new file mode 100644 index 000000000..58c041582 --- /dev/null +++ b/isthmus/src/test/resources/lambdas/lambda-field-ref.json @@ -0,0 +1,135 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + 
"root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} diff --git a/isthmus/src/test/resources/lambdas/lambda-with-function.json b/isthmus/src/test/resources/lambdas/lambda-with-function.json new file mode 100644 index 000000000..9c0a1a55b --- /dev/null +++ b/isthmus/src/test/resources/lambdas/lambda-with-function.json @@ -0,0 +1,172 @@ +{ + "extensionUrns": [ + { + "extensionUrnAnchor": 1, + "urn": "extension:io.substrait:functions_arithmetic" + }, + { + "extensionUrnAnchor": 2, + "urn": "extension:io.substrait:functions_list" + } + ], + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": 
"https://github.com/substrait-io/substrait/blob/main/extensions/functions_list.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUrnReference": 1, + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "multiply:i32_i32" + } + }, + { + "extensionFunction": { + "extensionUrnReference": 2, + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "transform:list_func" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "values" + ], + "struct": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + } + }, + "namedTable": { + "names": [ + "test_table" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "lambda": { + "parameters": { + "nullability": "NULLABILITY_REQUIRED", + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ] + }, + "body": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + }, + { + "value": { + "literal": { + "i32": 2 + } + } + } + ] + } + } + } + } + } + ] + } + } + ] + } + }, + "names": [ + "result" + ] + } + } + ] +} From d802c19fe9c76fe98c949da77aa498244521d158 Mon Sep 17 
00:00:00 2001 From: Victor Barua Date: Tue, 10 Mar 2026 14:48:09 -0700 Subject: [PATCH 08/18] test: add LambdaRoundtripTests --- .../isthmus/LambdaRoundtripTest.java | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java new file mode 100644 index 000000000..d34fa7744 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java @@ -0,0 +1,38 @@ +package io.substrait.isthmus; + +import com.google.protobuf.util.JsonFormat; +import io.substrait.plan.Plan; +import io.substrait.plan.ProtoPlanConverter; +import java.io.IOException; +import org.junit.jupiter.api.Test; + +public class LambdaRoundtripTest extends PlanTestBase { + + public static io.substrait.proto.Plan readJsonPlan(String resourcePath) throws IOException { + String json = asString(resourcePath); + io.substrait.proto.Plan.Builder builder = io.substrait.proto.Plan.newBuilder(); + JsonFormat.parser().merge(json, builder); + return builder.build(); + } + + @Test + void testBasicLambdaRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/basic-lambda.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } + + @Test + void testLambdaWithFieldRefRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/lambda-field-ref.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } + + @Test + void testLambdaWithFunctionRoundtrip() throws IOException { + io.substrait.proto.Plan proto = readJsonPlan("lambdas/lambda-with-function.json"); + Plan plan = new ProtoPlanConverter(extensions).from(proto); + assertFullRoundTrip(plan.getRoots().get(0)); + } +} From 
14738e31fe451e243025c77c465482077b58dc22 Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Mon, 9 Mar 2026 07:56:42 -0700 Subject: [PATCH 09/18] feat(isthmus): add TRANSFORM SqlFunction to handle transform:list_func --- .../isthmus/expression/FunctionMappings.java | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 90f04a326..87b688b58 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -5,14 +5,25 @@ import java.util.Locale; import java.util.Map; import java.util.Set; +import org.apache.calcite.sql.SqlBasicFunction; +import org.apache.calcite.sql.SqlFunction; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlLibraryOperators; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; public class FunctionMappings { // Static list of signature mapping between Calcite SQL operators and Substrait base function // names. + /** The transform:list_func function; applies a lambda to each element of an array. 
*/ + public static final SqlFunction TRANSFORM = + SqlBasicFunction.create( + "transform", + opBinding -> opBinding.getTypeFactory().createArrayType(opBinding.getOperandType(1), -1), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + public static final ImmutableList SCALAR_SIGS = ImmutableList.builder() .add( @@ -100,7 +111,8 @@ public class FunctionMappings { s(SqlLibraryOperators.RPAD, "rpad"), s(SqlLibraryOperators.PARSE_TIME, "strptime_time"), s(SqlLibraryOperators.PARSE_TIMESTAMP, "strptime_timestamp"), - s(SqlLibraryOperators.PARSE_DATE, "strptime_date")) + s(SqlLibraryOperators.PARSE_DATE, "strptime_date"), + s(TRANSFORM, "transform")) .build(); public static final ImmutableList AGGREGATE_SIGS = From 6fae010c22e5ff204841ddf2a0da4d0472630ff8 Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Wed, 11 Mar 2026 16:44:03 +0100 Subject: [PATCH 10/18] tweak: add filter in the function mapping --- .../substrait/isthmus/expression/FunctionMappings.java | 10 +++++++++- .../java/io/substrait/isthmus/LambdaRoundtripTest.java | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 87b688b58..cc2b098e2 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -24,6 +24,13 @@ public class FunctionMappings { opBinding -> opBinding.getTypeFactory().createArrayType(opBinding.getOperandType(1), -1), OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + /** The filter:list_func function; filters elements of an array using a predicate lambda. 
*/ + public static final SqlFunction FILTER = + SqlBasicFunction.create( + "filter", + opBinding -> opBinding.getOperandType(0), + OperandTypes.family(SqlTypeFamily.ARRAY, SqlTypeFamily.ANY)); + public static final ImmutableList SCALAR_SIGS = ImmutableList.builder() .add( @@ -112,7 +119,8 @@ public class FunctionMappings { s(SqlLibraryOperators.PARSE_TIME, "strptime_time"), s(SqlLibraryOperators.PARSE_TIMESTAMP, "strptime_timestamp"), s(SqlLibraryOperators.PARSE_DATE, "strptime_date"), - s(TRANSFORM, "transform")) + s(TRANSFORM, "transform"), + s(FILTER, "filter")) .build(); public static final ImmutableList AGGREGATE_SIGS = diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java index d34fa7744..427dc67ba 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaRoundtripTest.java @@ -6,7 +6,7 @@ import java.io.IOException; import org.junit.jupiter.api.Test; -public class LambdaRoundtripTest extends PlanTestBase { +class LambdaRoundtripTest extends PlanTestBase { public static io.substrait.proto.Plan readJsonPlan(String resourcePath) throws IOException { String json = asString(resourcePath); From 1c1f8d2d6977a6038c4725c184320694a7bd1b0c Mon Sep 17 00:00:00 2001 From: Limame Malainine Date: Fri, 13 Mar 2026 14:16:53 +0100 Subject: [PATCH 11/18] address @vbarua's comments --- core/src/main/java/io/substrait/expression/Expression.java | 5 +++-- .../expression/proto/ProtoExpressionConverter.java | 6 ++++++ .../main/java/io/substrait/relation/VirtualTableScan.java | 3 +-- .../io/substrait/examples/util/ExpressionStringify.java | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index a519a2166..72510ad2c 100644 --- 
a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -769,8 +769,9 @@ public Type getType() { List paramTypes = parameters().fields(); Type returnType = body().getType(); - // TO DO: fix Lambda return type once this issue - // https://github.com/substrait-io/substrait/issues/976 is resolved + // TODO: Type.Func nullability is hardcoded to false here because the spec does not allow for + // declaring otherwise. + // See: https://github.com/substrait-io/substrait/issues/976 return Type.withNullability(false).func(paramTypes, returnType); } diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 603196ebd..290e88c49 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -84,6 +84,12 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc int stepsOut = lambdaParamRef.getStepsOut(); Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut); + // Check for unsupported nested field access + if (reference.getDirectReference().getStructField().hasChild()) { + throw new UnsupportedOperationException( + "Nested field access in lambda parameters is not yet supported"); + } + return FieldReference.newLambdaParameterReference( reference.getDirectReference().getStructField().getField(), lambdaParameters, diff --git a/core/src/main/java/io/substrait/relation/VirtualTableScan.java b/core/src/main/java/io/substrait/relation/VirtualTableScan.java index de7c29dbb..36e5b6dcc 100644 --- a/core/src/main/java/io/substrait/relation/VirtualTableScan.java +++ b/core/src/main/java/io/substrait/relation/VirtualTableScan.java @@ -245,8 +245,7 @@ public Integer visit(Type.UserDefined type) throws RuntimeException { @Override public Integer 
visit(Type.Func type) throws RuntimeException { - return type.parameterTypes().stream().mapToInt(p -> p.accept(this)).sum() - + type.returnType().accept(this); + return 0; } } } diff --git a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java index 72f5ac556..2d345642c 100644 --- a/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java +++ b/examples/substrait-spark/src/main/java/io/substrait/examples/util/ExpressionStringify.java @@ -205,7 +205,7 @@ public String visit(Expression.NestedStruct expr, EmptyVisitationContext context @Override public String visit(Expression.Lambda expr, EmptyVisitationContext context) throws RuntimeException { - return ""; + return ""; } @Override From bd30a21e0816ff941cdfb5cbd8fa820101d2d316 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 13:52:10 -0400 Subject: [PATCH 12/18] feat(core): add LambdaBuilder for build-time validation of lambda parameter references Introduces LambdaBuilder, a context-aware builder that maintains a lambda parameter stack (lambdaContext) to validate parameter references at build time. Nested lambdas use the same builder, ensuring stepsOut is computed automatically. Mirrors the lambdaContext pattern from substrait-go. 
--- .../io/substrait/expression/Expression.java | 4 - .../substrait/expression/LambdaBuilder.java | 99 ++++ .../proto/ProtoExpressionConverter.java | 4 +- .../ExpressionCopyOnWriteVisitor.java | 6 +- .../expression/LambdaBuilderTest.java | 44 ++ .../proto/LambdaExpressionRoundtripTest.java | 473 ++++++++---------- .../expression/RexExpressionConverter.java | 3 +- .../isthmus/LambdaExpressionTest.java | 127 +---- 8 files changed, 377 insertions(+), 383 deletions(-) create mode 100644 core/src/main/java/io/substrait/expression/LambdaBuilder.java create mode 100644 core/src/test/java/io/substrait/expression/LambdaBuilderTest.java diff --git a/core/src/main/java/io/substrait/expression/Expression.java b/core/src/main/java/io/substrait/expression/Expression.java index 72510ad2c..b116cbdc2 100644 --- a/core/src/main/java/io/substrait/expression/Expression.java +++ b/core/src/main/java/io/substrait/expression/Expression.java @@ -775,10 +775,6 @@ public Type getType() { return Type.withNullability(false).func(paramTypes, returnType); } - public static ImmutableExpression.Lambda.Builder builder() { - return ImmutableExpression.Lambda.builder(); - } - @Override public R accept( ExpressionVisitor visitor, C context) throws E { diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java new file mode 100644 index 000000000..b2b89c02d --- /dev/null +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -0,0 +1,99 @@ +package io.substrait.expression; + +import io.substrait.type.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; + +/** + * Builds lambda expressions with build-time validation of parameter references. + * + *

<p>Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters + * onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code + * lambda()} again on the same builder. + * + *

<p>The callback receives a {@link Scope} handle for creating validated parameter references. The + * correct {@code stepsOut} value is computed automatically from the stack. + * + *

<pre>{@code
+ * LambdaBuilder lb = new LambdaBuilder();
+ *
+ * // Simple: (x: i32) -> x
+ * Expression.Lambda simple = lb.lambda(List.of(R.I32), x -> x.ref(0));
+ *
+ * // Nested: (x: i32) -> (y: i64) -> add(x, y)
+ * Expression.Lambda nested = lb.lambda(List.of(R.I32), x ->
+ *     lb.lambda(List.of(R.I64), y ->
+ *         add(x.ref(0), y.ref(0))
+ *     )
+ * );
+ * }</pre>
+ */ +public class LambdaBuilder { + private final List lambdaContext = new ArrayList<>(); + + /** + * Builds a lambda expression. The body function receives a {@link Scope} for creating validated + * parameter references. Nested lambdas are built by calling this method again inside the + * callback. + * + * @param paramTypes the lambda's parameter types + * @param bodyFn function that builds the lambda body given a scope handle + * @return the constructed lambda expression + */ + public Expression.Lambda lambda(List paramTypes, Function bodyFn) { + Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build(); + pushLambdaContext(params); + try { + int index = lambdaContext.size() - 1; + Scope scope = new Scope(index); + Expression body = bodyFn.apply(scope); + return ImmutableExpression.Lambda.builder().parameters(params).body(body).build(); + } finally { + popLambdaContext(); + } + } + + /** + * Pushes a lambda's parameters onto the context stack. This makes the parameters available for + * validation when building the lambda's body, and allows nested lambda parameter references to + * correctly compute their stepsOut values. + */ + private void pushLambdaContext(Type.Struct params) { + lambdaContext.add(params); + } + + /** + * Pops the most recently pushed lambda parameters from the context stack. Called after a lambda's + * body has been built, restoring the context to the enclosing lambda's scope. + */ + private void popLambdaContext() { + lambdaContext.remove(lambdaContext.size() - 1); + } + + /** + * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated + * parameter references. + */ + public class Scope { + private final int index; + + private Scope(int index) { + this.index = index; + } + + /** + * Creates a validated reference to a parameter of this lambda. The correct {@code stepsOut} + * value is computed automatically. 
+ * + * @param paramIndex index of the parameter within this lambda's parameter struct + * @return a {@link FieldReference} pointing to the specified parameter + * @throws IndexOutOfBoundsException if paramIndex is out of bounds + */ + public FieldReference ref(int paramIndex) { + int stepsOut = lambdaContext.size() - 1 - index; + return FieldReference.newLambdaParameterReference( + paramIndex, lambdaContext.get(index), stepsOut); + } + } +} diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 290e88c49..426b0f56a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,6 +6,7 @@ import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; +import io.substrait.expression.ImmutableExpression; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -282,6 +283,7 @@ public Type visit(Type.Struct type) throws RuntimeException { case LAMBDA: { + // TODO: Add build-time validation of lambda parameter references during deserialization. io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); Type.Struct parameters = (Type.Struct) @@ -299,7 +301,7 @@ public Type visit(Type.Struct type) throws RuntimeException { lambdaParameterStack.pop(); } - return Expression.Lambda.builder().parameters(parameters).body(body).build(); + return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); } // TODO enum. 
case ENUM: diff --git a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java index 11fc14f27..b097614a0 100644 --- a/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java @@ -8,6 +8,7 @@ import io.substrait.expression.ExpressionVisitor; import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; +import io.substrait.expression.ImmutableExpression; import io.substrait.util.EmptyVisitationContext; import java.util.List; import java.util.Optional; @@ -448,7 +449,10 @@ public Optional visit(Expression.Lambda lambda, EmptyVisitationConte return Optional.empty(); } return Optional.of( - Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build()); + ImmutableExpression.Lambda.builder() + .from(lambda) + .body(newBody.orElse(lambda.body())) + .build()); } // utilities diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java new file mode 100644 index 000000000..7d8f8976b --- /dev/null +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -0,0 +1,44 @@ +package io.substrait.expression; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** Tests for {@link LambdaBuilder} build-time validation. 
*/ +class LambdaBuilderTest { + + static final TypeCreator R = TypeCreator.REQUIRED; + + final LambdaBuilder lb = new LambdaBuilder(); + + // (x: i32) -> x[5] — field index 5 is out of bounds (only 1 param) + @Test + void invalidFieldIndex_outOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(5))); + } + + // (x: i32) -> x[-1] — negative field index + @Test + void negativeFieldIndex() { + assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); + } + + // (x: i32) -> (y: i64) -> x[5] — outer field index 5 is out of bounds + @Test + void nestedOuterFieldIndexOutOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, + () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5)))); + } + + // (x: i32) -> (y: i64) -> y[3] — inner field index 3 is out of bounds (only 1 param) + @Test + void nestedInnerFieldIndexOutOfBounds() { + assertThrows( + IndexOutOfBoundsException.class, + () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(3)))); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java index c0a979a94..3b9cfedad 100644 --- a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -1,361 +1,298 @@ package io.substrait.type.proto; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.substrait.TestBase; import io.substrait.expression.Expression; 
import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; +import io.substrait.expression.LambdaBuilder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.type.Type; import java.util.List; import org.junit.jupiter.api.Test; -/** - * Tests for Lambda expression round-trip conversion through protobuf. Based on equivalent tests - * from substrait-go. - */ class LambdaExpressionRoundtripTest extends TestBase { - /** Test that lambdas with no parameters are valid. Building: () -> i32(42) : func<() -> i32> */ - @Test - void zeroParameterLambda() { - Type.Struct emptyParams = Type.Struct.builder().nullable(false).build(); + final LambdaBuilder lb = new LambdaBuilder(); - Expression body = ExpressionCreator.i32(false, 42); + // ==================== Single Lambda Tests ==================== - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(emptyParams).body(body).build(); + // () -> 42 + @Test + void zeroParameterLambda() { + Expression.Lambda lambda = lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42)); verifyRoundTrip(lambda); - // Verify the lambda type - Type lambdaType = lambda.getType(); - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; + Type.Func funcType = (Type.Func) lambda.getType(); assertEquals(0, funcType.parameterTypes().size()); assertEquals(R.I32, funcType.returnType()); } - /** Test valid stepsOut=0 references. 
Building: ($0: i32) -> $0 : func i32> */ + // (x: i32) -> x @Test - void validStepsOut0() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Lambda body references parameter 0 with stepsOut=0 - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + void identityLambda() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); verifyRoundTrip(lambda); - // Verify types - Type lambdaType = lambda.getType(); - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; + Type.Func funcType = (Type.Func) lambda.getType(); assertEquals(1, funcType.parameterTypes().size()); assertEquals(R.I32, funcType.parameterTypes().get(0)); assertEquals(R.I32, funcType.returnType()); + + assertInstanceOf(FieldReference.class, lambda.body()); + FieldReference ref = (FieldReference) lambda.body(); + assertTrue(ref.isLambdaParameterReference()); + assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); } - /** - * Test valid field index with multiple parameters. 
Building: ($0: i32, $1: i64, $2: string) -> $2 - * : func<(i32, i64, string) -> string> - */ + // (x: i32, y: i64, z: string) -> z @Test void validFieldIndex() { - Type.Struct params = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); + Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(2)); - // Reference the 3rd parameter (string) - FieldReference paramRef = FieldReference.newLambdaParameterReference(2, params, 0); + verifyRoundTrip(lambda); + assertEquals(R.STRING, ((Type.Func) lambda.getType()).returnType()); + } + // (x: i32) -> 42 + @Test + void lambdaWithLiteralBody() { Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42)); verifyRoundTrip(lambda); - - // Verify return type is string - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(R.STRING, funcType.returnType()); + assertInstanceOf(Expression.I32Literal.class, lambda.body()); } - /** Test type resolution for different parameter types. */ + // Parameterized: (params...) 
-> params[fieldIndex], verifying type resolution @Test void typeResolution() { - // Test cases: (paramTypes, fieldIndex, expectedReturnType) - record TestCase(List paramTypes, int fieldIndex, Type expectedType) {} + record TestCase(String name, List paramTypes, int fieldIndex, Type expectedType) {} List testCases = List.of( - new TestCase(List.of(R.I32), 0, R.I32), - new TestCase(List.of(R.I32, R.I64), 1, R.I64), - new TestCase(List.of(R.I32, R.I64, R.STRING), 2, R.STRING), - new TestCase(List.of(R.FP64), 0, R.FP64), - new TestCase(List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); + new TestCase("first param (i32)", List.of(R.I32), 0, R.I32), + new TestCase("second param (i64)", List.of(R.I32, R.I64), 1, R.I64), + new TestCase("third param (string)", List.of(R.I32, R.I64, R.STRING), 2, R.STRING), + new TestCase("float64 param", List.of(R.FP64), 0, R.FP64), + new TestCase("date param", List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); for (TestCase tc : testCases) { - Type.Struct params = - Type.Struct.builder().nullable(false).addAllFields(tc.paramTypes).build(); - - FieldReference paramRef = - FieldReference.newLambdaParameterReference(tc.fieldIndex, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); + Expression.Lambda lambda = lb.lambda(tc.paramTypes, params -> params.ref(tc.fieldIndex)); verifyRoundTrip(lambda); - // Verify the body type matches expected - assertEquals( - tc.expectedType, - lambda.body().getType(), - "Body type should match referenced parameter type"); - - // Verify lambda return type + assertEquals(tc.expectedType, lambda.body().getType(), tc.name + ": body type mismatch"); Type.Func funcType = (Type.Func) lambda.getType(); assertEquals( - tc.expectedType, funcType.returnType(), "Lambda return type should match body type"); + tc.expectedType, funcType.returnType(), tc.name + ": lambda return type mismatch"); } } - /** - * Test nested lambda with outer reference. 
Building: ($0: i64, $1: i64) -> (($0: i32) -> - * outer[$0] : i64) : func<(i64, i64) -> func i64>> - */ + // (x: i32, y: string) -> y — verify full Func type structure @Test - void nestedLambdaWithOuterRef() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64, R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references outer's parameter 0 with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - verifyRoundTrip(outerLambda); + void lambdaGetTypeReturnsFunc() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.STRING), params -> params.ref(1)); - // Verify structure - assertInstanceOf(Expression.Lambda.class, outerLambda.body()); - Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); - assertEquals(1, resultInner.parameters().fields().size()); + Type.Func funcType = (Type.Func) lambda.getType(); + assertEquals(2, funcType.parameterTypes().size()); + assertEquals(R.I32, funcType.parameterTypes().get(0)); + assertEquals(R.STRING, funcType.parameterTypes().get(1)); + assertEquals(R.STRING, funcType.returnType()); } - /** - * Test outer reference type resolution in nested lambdas. Building: ($0: i32, $1: i64, $2: - * string) -> (($0: fp64) -> outer[$2] : string) : func<...> - */ + // (x: i32, y: i64, z: string) -> ... 
— verify FieldReference metadata for each param @Test - void outerRefTypeResolution() { - Type.Struct outerParams = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.FP64).build(); - - // Inner references outer's field 2 (string) with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(2, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - verifyRoundTrip(outerLambda); - - // Verify inner lambda's return type is string (from outer param 2) - Expression.Lambda resultInner = (Expression.Lambda) outerLambda.body(); - Type.Func innerFuncType = (Type.Func) resultInner.getType(); - assertEquals( - R.STRING, - innerFuncType.returnType(), - "Inner lambda return type should be string from outer.$2"); - - // Verify body's type is also string - assertEquals(R.STRING, resultInner.body().getType(), "Body type should be string"); + void parameterReferenceMetadata() { + List paramTypes = List.of(R.I32, R.I64, R.STRING); + + lb.lambda( + paramTypes, + params -> { + for (int i = 0; i < 3; i++) { + FieldReference ref = params.ref(i); + assertTrue(ref.isLambdaParameterReference()); + assertFalse(ref.isOuterReference()); + assertFalse(ref.isSimpleRootReference()); + assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals(paramTypes.get(i), ref.getType()); + } + return params.ref(0); + }); } - /** - * Test deeply nested field ref inside Cast. 
Building: ($0: i32) -> cast($0 as i64) : func - * i64> - */ - @Test - void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); + // ==================== Expression Body Tests ==================== - Expression.Cast castExpr = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(castExpr).build(); + // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + @Test + void nestedLambdaWithArithmeticBody() { + String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; + + Expression.Lambda result = + lb.lambda( + List.of(R.I64), + outer -> + lb.lambda( + List.of(R.I64, R.I64), + inner -> { + // y1 * x + Expression multiply = + sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); + // (y1 * x) + y2 + return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); + })); + + verifyRoundTrip(result); + + // Outer lambda returns a func type + Type.Func outerFuncType = (Type.Func) result.getType(); + assertInstanceOf(Type.Func.class, outerFuncType.returnType()); + + // Inner lambda returns i64 + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + Type.Func innerFuncType = (Type.Func) innerLambda.getType(); + assertEquals(R.I64, innerFuncType.returnType()); + + // Inner body is a scalar function (add) + assertInstanceOf(Expression.ScalarFunctionInvocation.class, innerLambda.body()); + } - verifyRoundTrip(lambda); + // ==================== Nested Lambda Tests ==================== - // Verify the nested FieldRef has its type resolved - Expression.Cast resultCast = (Expression.Cast) lambda.body(); - assertInstanceOf(FieldReference.class, resultCast.input()); - FieldReference resultFieldRef = (FieldReference) resultCast.input(); + // (x: i64, y: i64) -> (z: 
i32) -> x + @Test + void nestedLambdaWithOuterRef() { + Expression.Lambda result = + lb.lambda(List.of(R.I64, R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - assertNotNull(resultFieldRef.getType(), "Nested FieldRef should have type resolved"); - assertEquals(R.I32, resultFieldRef.getType(), "Should resolve to i32"); + verifyRoundTrip(result); - // Verify lambda return type is i64 (cast output) - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(R.I64, funcType.returnType()); + Expression.Lambda resultInner = (Expression.Lambda) result.body(); + assertEquals(1, resultInner.parameters().fields().size()); + assertEquals(R.I64, resultInner.body().getType()); } - /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 - * as i64) as string) : func string> - */ + // (x: i32, y: i64, z: string) -> (w: fp64) -> z @Test - void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = - (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(outerCast).build(); + void nestedLambdaOuterRefTypeResolution() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32, R.I64, R.STRING), + outer -> lb.lambda(List.of(R.FP64), inner -> outer.ref(2))); - verifyRoundTrip(lambda); - - // Navigate to the deeply nested FieldRef (2 levels deep) - Expression.Cast resultOuter = (Expression.Cast) lambda.body(); - Expression.Cast resultInner = (Expression.Cast) resultOuter.input(); - FieldReference resultFieldRef = (FieldReference) resultInner.input(); + 
verifyRoundTrip(result); - // Verify type is resolved even at depth 2 - assertNotNull(resultFieldRef.getType(), "FieldRef at depth 2 should have type resolved"); - assertEquals(R.I32, resultFieldRef.getType()); + Expression.Lambda resultInner = (Expression.Lambda) result.body(); + assertEquals(R.STRING, ((Type.Func) resultInner.getType()).returnType()); } - /** - * Test lambda with literal body (no parameter references). Building: ($0: i32) -> 42 : func i32> - */ + // (x: i32) -> (y: i64) -> y @Test - void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + void nestedLambdaInnerRefOnly() { + Expression.Lambda result = + lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(0))); - Expression body = ExpressionCreator.i32(false, 42); + verifyRoundTrip(result); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - - verifyRoundTrip(lambda); + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + assertEquals(R.I64, innerLambda.body().getType()); + assertInstanceOf(Type.Func.class, ((Type.Func) result.getType()).returnType()); } - /** Test lambda getType returns correct Func type. 
*/ + // (x: i32) -> (y: i64) -> (x, y) — body references both outer and inner params @Test - void lambdaGetTypeReturnsFunc() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32, R.STRING).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(1, params, 0); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); - - Type lambdaType = lambda.getType(); - - assertInstanceOf(Type.Func.class, lambdaType); - Type.Func funcType = (Type.Func) lambdaType; - - assertEquals(2, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.STRING, funcType.parameterTypes().get(1)); - assertEquals(R.STRING, funcType.returnType()); // body references param 1 which is STRING + void nestedLambdaBothInnerAndOuterRefs() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), + inner -> { + FieldReference innerRef = inner.ref(0); + assertEquals(R.I64, innerRef.getType()); + assertEquals(0, innerRef.lambdaParameterReferenceStepsOut().orElse(-1)); + + FieldReference outerRef = outer.ref(0); + assertEquals(R.I32, outerRef.getType()); + assertEquals(1, outerRef.lambdaParameterReferenceStepsOut().orElse(-1)); + + return innerRef; + })); + + verifyRoundTrip(result); } - // ==================== Validation Error Tests ==================== - - /** - * Test that invalid outer reference (stepsOut too high) fails during proto conversion. 
Building: - * ($0: i32) -> outer[$0] : INVALID (no outer lambda, stepsOut=1) - */ + // (a: i32, b: string) -> (c: i64, d: fp64) -> b — verify all 4 params resolve correctly @Test - void invalidOuterRef_stepsOutTooHigh() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Create a parameter reference with stepsOut=1 but no outer lambda exists - FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, params, 1); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(invalidRef).build(); - - // Convert to proto - this should work - io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(lambda); - - // Converting back should fail because stepsOut=1 references non-existent outer lambda - assertThrows( - IllegalArgumentException.class, - () -> { - protoExpressionConverter.from(protoExpression); - }, - "Should fail when stepsOut references non-existent outer lambda"); + void nestedLambdaMultiParamCorrectResolution() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32, R.STRING), + outer -> + lb.lambda( + List.of(R.I64, R.FP64), + inner -> { + assertEquals(R.I64, inner.ref(0).getType()); + assertEquals(R.FP64, inner.ref(1).getType()); + assertEquals(R.I32, outer.ref(0).getType()); + assertEquals(R.STRING, outer.ref(1).getType()); + + return outer.ref(1); + })); + + verifyRoundTrip(result); + + Expression.Lambda innerLambda = (Expression.Lambda) result.body(); + assertEquals(R.STRING, ((Type.Func) innerLambda.getType()).returnType()); } - /** - * Test that invalid field index (out of bounds) fails during proto conversion. 
Building: ($0: - * i32) -> $5 : INVALID (only has 1 param) - */ + // (x: i32) -> (y: i64) -> (z: string) -> x @Test - void invalidFieldIndex_outOfBounds() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Create a reference to field 5, but lambda only has 1 parameter (index 0) - // This will fail at build time since newLambdaParameterReference accesses fields.get(5) - assertThrows( - IndexOutOfBoundsException.class, - () -> { - FieldReference.newLambdaParameterReference(5, params, 0); - }, - "Should fail when field index is out of bounds"); + void tripleNestedLambdaRoundtrip() { + Expression.Lambda result = + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), mid -> lb.lambda(List.of(R.STRING), inner -> outer.ref(0)))); + + verifyRoundTrip(result); + + Expression.Lambda l1 = (Expression.Lambda) result.body(); + Expression.Lambda l2 = (Expression.Lambda) l1.body(); + assertEquals(R.I32, l2.body().getType()); } - /** - * Test nested invalid outer ref (stepsOut=2 but only 1 outer lambda). Building: ($0: i64) -> - * (($0: i32) -> outer.outer[$0]) : INVALID (no grandparent lambda) - */ + // (x: i32) -> (y: i64) -> (z: string) -> ... 
— verify stepsOut is auto-computed at each level @Test - void nestedInvalidOuterRef() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references stepsOut=2, but only 1 outer lambda exists - FieldReference invalidRef = FieldReference.newLambdaParameterReference(0, outerParams, 2); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(invalidRef).build(); - - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); - - // Convert to proto - io.substrait.proto.Expression protoExpression = expressionProtoConverter.toProto(outerLambda); - - // Converting back should fail because stepsOut=2 references non-existent grandparent - assertThrows( - IllegalArgumentException.class, - () -> { - protoExpressionConverter.from(protoExpression); - }, - "Should fail when stepsOut references non-existent grandparent lambda"); + void tripleNestedLambdaScopeTracking() { + lb.lambda( + List.of(R.I32), + outer -> + lb.lambda( + List.of(R.I64), + mid -> + lb.lambda( + List.of(R.STRING), + inner -> { + assertEquals(R.STRING, inner.ref(0).getType()); + assertEquals(R.I64, mid.ref(0).getType()); + assertEquals(R.I32, outer.ref(0).getType()); + + assertEquals( + 0, inner.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals(1, mid.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + assertEquals( + 2, outer.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); + + return inner.ref(0); + }))); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java index 176b246e7..d5330681d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java +++ 
b/isthmus/src/main/java/io/substrait/isthmus/expression/RexExpressionConverter.java @@ -2,6 +2,7 @@ import io.substrait.expression.Expression; import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableExpression; import io.substrait.isthmus.CallConverter; import io.substrait.isthmus.SubstraitRelVisitor; import io.substrait.isthmus.TypeConverter; @@ -212,7 +213,7 @@ public Expression visitLambda(RexLambda rexLambda) { Expression body = rexLambda.getExpression().accept(this); - return Expression.Lambda.builder().parameters(parameters).body(body).build(); + return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index add746ee1..55924081d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -4,146 +4,57 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; +import io.substrait.expression.LambdaBuilder; import io.substrait.relation.Project; import io.substrait.relation.Rel; -import io.substrait.type.Type; import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.Test; -/** - * Tests for Lambda expression conversion between Substrait and Calcite. Note: Calcite does not - * support nested lambda expressions for the moment, so all tests use stepsOut=0. - */ class LambdaExpressionTest extends PlanTestBase { final Rel emptyTable = sb.emptyVirtualTableScan(); + final LambdaBuilder lb = new LambdaBuilder(); - /** Test that lambdas with no parameters are valid. 
Building: () -> i32(42) : func<() -> i32> */ + // () -> 42 @Test void lambdaExpressionZeroParameters() { - Type.Struct params = Type.Struct.builder().nullable(false).build(); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42))); - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - expressionList.add(lambda); - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test valid field index with multiple parameters. Building: ($0: i32, $1: i64, $2: string) -> $2 - * : func<(i32, i64, string) -> string> - */ + // (x: i32, y: i64, z: string) -> x @Test void validFieldIndex() { - Type.Struct params = - Type.Struct.builder().nullable(false).addFields(R.I32, R.I64, R.STRING).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - List expressionList = new ArrayList<>(); - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(paramRef).build(); - - expressionList.add(lambda); - - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test deeply nested field ref inside Cast. 
Building: ($0: i32) -> cast($0 as i64) : func - * i64> - */ - @Test - void deeplyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - - Expression.Cast castExpr = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(castExpr).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); - assertFullRoundTrip(project); - } - - /** - * Test doubly nested field ref (Cast(Cast(LambdaParamRef))). Building: ($0: i32) -> cast(cast($0 - * as i64) as string) : func string> - */ - @Test - void doublyNestedFieldRef() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - FieldReference paramRef = FieldReference.newLambdaParameterReference(0, params, 0); - Expression.Cast innerCast = - (Expression.Cast) - ExpressionCreator.cast(R.I64, paramRef, Expression.FailureBehavior.THROW_EXCEPTION); - Expression.Cast outerCast = - (Expression.Cast) - ExpressionCreator.cast(R.STRING, innerCast, Expression.FailureBehavior.THROW_EXCEPTION); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(0))); - List expressionList = new ArrayList<>(); - - Expression.Lambda lambda = - Expression.Lambda.builder().parameters(params).body(outerCast).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test lambda with literal body (no parameter references). 
Building: ($0: i32) -> 42 : func i32> - */ + // (x: i32) -> 42 @Test void lambdaWithLiteralBody() { - Type.Struct params = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - Expression body = ExpressionCreator.i32(false, 42); - List expressionList = new ArrayList<>(); + List exprs = new ArrayList<>(); + exprs.add(lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42))); - Expression.Lambda lambda = Expression.Lambda.builder().parameters(params).body(body).build(); - - expressionList.add(lambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertFullRoundTrip(project); } - /** - * Test that nested lambda (stepsOut > 0) throws UnsupportedOperationException. Calcite does not - * support nested lambda expressions. - */ + // (x: i64) -> (y: i32) -> x — Calcite doesn't support nested lambdas @Test void nestedLambdaThrowsUnsupportedOperation() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - - // Inner lambda references outer's parameter with stepsOut=1 - FieldReference outerRef = FieldReference.newLambdaParameterReference(0, outerParams, 1); - - Expression.Lambda innerLambda = - Expression.Lambda.builder().parameters(innerParams).body(outerRef).build(); - - List expressionList = new ArrayList<>(); - Expression.Lambda outerLambda = - Expression.Lambda.builder().parameters(outerParams).body(innerLambda).build(); + lb.lambda(List.of(R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - expressionList.add(outerLambda); - Project project = Project.builder().expressions(expressionList).input(emptyTable).build(); + List exprs = new ArrayList<>(); + exprs.add(outerLambda); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); 
assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } } From 50db6998d84c3ee9af532e5ca2073e14e727287d Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:28:21 -0400 Subject: [PATCH 13/18] refactor: unify lambda validation, add JSON-based roundtrip tests Moves ProtoExpressionConverter to use LambdaBuilder for lambda parameter validation, removing the private LambdaParameterStack. Replaces builder-based roundtrip tests with parameterized JSON fixtures in expressions/lambda/. Adds arithmetic body test to isthmus LambdaExpressionTest. --- .../substrait/expression/LambdaBuilder.java | 39 +++ .../proto/ProtoExpressionConverter.java | 56 +--- .../expression/LambdaBuilderTest.java | 51 ++- .../proto/LambdaExpressionRoundtripTest.java | 313 ++---------------- .../lambda/invalid/nested_steps_out.json | 30 ++ .../expressions/lambda/invalid/steps_out.json | 21 ++ .../expressions/lambda/valid/identity.json | 21 ++ .../lambda/valid/literal_body.json | 14 + .../expressions/lambda/valid/multi_param.json | 23 ++ .../expressions/lambda/valid/nested.json | 30 ++ .../lambda/valid/triple_nested.json | 39 +++ .../expressions/lambda/valid/zero_params.json | 12 + .../isthmus/LambdaExpressionTest.java | 33 ++ 13 files changed, 348 insertions(+), 334 deletions(-) create mode 100644 core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json create mode 100644 core/src/test/resources/expressions/lambda/invalid/steps_out.json create mode 100644 core/src/test/resources/expressions/lambda/valid/identity.json create mode 100644 core/src/test/resources/expressions/lambda/valid/literal_body.json create mode 100644 core/src/test/resources/expressions/lambda/valid/multi_param.json create mode 100644 core/src/test/resources/expressions/lambda/valid/nested.json create mode 100644 core/src/test/resources/expressions/lambda/valid/triple_nested.json create mode 100644 core/src/test/resources/expressions/lambda/valid/zero_params.json diff 
--git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index b2b89c02d..1197cd1f1 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -54,6 +54,45 @@ public Expression.Lambda lambda(List paramTypes, Function bodyFn) { + pushLambdaContext(params); + try { + Expression body = bodyFn.get(); + return ImmutableExpression.Lambda.builder().parameters(params).body(body).build(); + } finally { + popLambdaContext(); + } + } + + /** + * Resolves the parameter struct for a lambda at the given stepsOut from the current innermost + * scope. Used by internal converters to validate lambda parameter references during + * deserialization. + * + * @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost) + * @return the parameter struct at the target scope level + * @throws IllegalArgumentException if stepsOut exceeds the current nesting depth + */ + public Type.Struct resolveParams(int stepsOut) { + int index = lambdaContext.size() - 1 - stepsOut; + if (index < 0 || index >= lambdaContext.size()) { + throw new IllegalArgumentException( + String.format( + "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", + stepsOut, lambdaContext.size())); + } + return lambdaContext.get(index); + } + /** * Pushes a lambda's parameters onto the context stack. 
This makes the parameters available for * validation when building the lambda's body, and allows nested lambda parameter references to diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 426b0f56a..470c13c6a 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -6,7 +6,7 @@ import io.substrait.expression.FieldReference.ReferenceSegment; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; -import io.substrait.expression.ImmutableExpression; +import io.substrait.expression.LambdaBuilder; import io.substrait.expression.WindowBound; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; @@ -38,7 +38,7 @@ public class ProtoExpressionConverter { private final Type.Struct rootType; private final ProtoTypeConverter protoTypeConverter; private final ProtoRelConverter protoRelConverter; - private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack(); + private final LambdaBuilder lambdaBuilder = new LambdaBuilder(); public ProtoExpressionConverter( ExtensionLookup lookup, @@ -83,7 +83,7 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc reference.getLambdaParameterReference(); int stepsOut = lambdaParamRef.getStepsOut(); - Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut); + Type.Struct lambdaParameters = lambdaBuilder.resolveParams(stepsOut); // Check for unsupported nested field access if (reference.getDirectReference().getStructField().hasChild()) { @@ -283,7 +283,6 @@ public Type visit(Type.Struct type) throws RuntimeException { case LAMBDA: { - // TODO: Add build-time validation of lambda parameter references during deserialization. 
io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda(); Type.Struct parameters = (Type.Struct) @@ -292,16 +291,7 @@ public Type visit(Type.Struct type) throws RuntimeException { .setStruct(protoLambda.getParameters()) .build()); - lambdaParameterStack.push(parameters); - - Expression body; - try { - body = from(protoLambda.getBody()); - } finally { - lambdaParameterStack.pop(); - } - - return ImmutableExpression.Lambda.builder().parameters(parameters).body(body).build(); + return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody())); } // TODO enum. case ENUM: @@ -622,42 +612,4 @@ public Expression.SortField fromSortField(SortField s) { public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) { return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build(); } - - /** - * A stack for tracking lambda parameter types during expression parsing. - * - *

When parsing nested lambda expressions, each lambda's parameters are pushed onto this stack. - * Lambda parameter references use "stepsOut" to indicate which enclosing lambda they reference: - * - *

    - *
  • stepsOut=0 refers to the innermost (current) lambda - *
  • stepsOut=1 refers to the next enclosing lambda - *
  • stepsOut=N refers to N levels up - *
- */ - private static class LambdaParameterStack { - private final List stack = new ArrayList<>(); - - void push(Type.Struct parameters) { - stack.add(parameters); - } - - void pop() { - if (stack.isEmpty()) { - throw new IllegalArgumentException("Lambda parameter stack is empty"); - } - stack.remove(stack.size() - 1); - } - - Type.Struct get(int stepsOut) { - int index = stack.size() - 1 - stepsOut; - if (index < 0 || index >= stack.size()) { - throw new IllegalArgumentException( - String.format( - "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", - stepsOut, stack.size())); - } - return stack.get(index); - } - } } diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 7d8f8976b..492c56f30 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -1,32 +1,73 @@ package io.substrait.expression; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.List; import org.junit.jupiter.api.Test; -/** Tests for {@link LambdaBuilder} build-time validation. */ +/** Tests for {@link LambdaBuilder}. 
*/ class LambdaBuilderTest { static final TypeCreator R = TypeCreator.REQUIRED; final LambdaBuilder lb = new LambdaBuilder(); - // (x: i32) -> x[5] — field index 5 is out of bounds (only 1 param) + // (x: i32)@p -> p[0] + @Test + void simpleLambda() { + Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); + + Expression.Lambda expected = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) + .body( + FieldReference.newLambdaParameterReference( + 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 0)) + .build(); + + assertEquals(expected, lambda); + } + + // (x: i32)@outer -> (y: i64)@inner -> outer[0] + @Test + void nestedLambda() { + Expression.Lambda lambda = + lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(0))); + + Expression.Lambda expectedInner = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I64).build()) + .body( + FieldReference.newLambdaParameterReference( + 0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 1)) + .build(); + + Expression.Lambda expected = + ImmutableExpression.Lambda.builder() + .parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build()) + .body(expectedInner) + .build(); + + assertEquals(expected, lambda); + } + + // (x: i32)@p -> p[5] — only 1 param, index 5 is out of bounds @Test void invalidFieldIndex_outOfBounds() { assertThrows( IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(5))); } - // (x: i32) -> x[-1] — negative field index + // (x: i32)@p -> p[-1] — negative index @Test void negativeFieldIndex() { assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1))); } - // (x: i32) -> (y: i64) -> x[5] — outer field index 5 is out of bounds + // (x: i32)@outer -> (y: i64)@inner -> outer[5] — outer only has 1 param @Test void 
nestedOuterFieldIndexOutOfBounds() { assertThrows( @@ -34,7 +75,7 @@ void nestedOuterFieldIndexOutOfBounds() { () -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5)))); } - // (x: i32) -> (y: i64) -> y[3] — inner field index 3 is out of bounds (only 1 param) + // (x: i32)@outer -> (y: i64)@inner -> inner[3] — inner only has 1 param @Test void nestedInnerFieldIndexOutOfBounds() { assertThrows( diff --git a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java index 3b9cfedad..1589c35ba 100644 --- a/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/LambdaExpressionRoundtripTest.java @@ -1,298 +1,57 @@ package io.substrait.type.proto; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; +import com.google.protobuf.util.JsonFormat; import io.substrait.TestBase; import io.substrait.expression.Expression; -import io.substrait.expression.ExpressionCreator; -import io.substrait.expression.FieldReference; -import io.substrait.expression.LambdaBuilder; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.type.Type; -import java.util.List; -import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; class LambdaExpressionRoundtripTest extends TestBase { - final LambdaBuilder lb = new LambdaBuilder(); - - // ==================== Single Lambda Tests 
==================== - - // () -> 42 - @Test - void zeroParameterLambda() { - Expression.Lambda lambda = lb.lambda(List.of(), params -> ExpressionCreator.i32(false, 42)); - - verifyRoundTrip(lambda); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(0, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.returnType()); - } - - // (x: i32) -> x - @Test - void identityLambda() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0)); - - verifyRoundTrip(lambda); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(1, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.I32, funcType.returnType()); - - assertInstanceOf(FieldReference.class, lambda.body()); - FieldReference ref = (FieldReference) lambda.body(); - assertTrue(ref.isLambdaParameterReference()); - assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); - } - - // (x: i32, y: i64, z: string) -> z - @Test - void validFieldIndex() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.I64, R.STRING), params -> params.ref(2)); - - verifyRoundTrip(lambda); - assertEquals(R.STRING, ((Type.Func) lambda.getType()).returnType()); - } - - // (x: i32) -> 42 - @Test - void lambdaWithLiteralBody() { - Expression.Lambda lambda = - lb.lambda(List.of(R.I32), params -> ExpressionCreator.i32(false, 42)); - - verifyRoundTrip(lambda); - assertInstanceOf(Expression.I32Literal.class, lambda.body()); - } - - // Parameterized: (params...) 
-> params[fieldIndex], verifying type resolution - @Test - void typeResolution() { - record TestCase(String name, List paramTypes, int fieldIndex, Type expectedType) {} - - List testCases = - List.of( - new TestCase("first param (i32)", List.of(R.I32), 0, R.I32), - new TestCase("second param (i64)", List.of(R.I32, R.I64), 1, R.I64), - new TestCase("third param (string)", List.of(R.I32, R.I64, R.STRING), 2, R.STRING), - new TestCase("float64 param", List.of(R.FP64), 0, R.FP64), - new TestCase("date param", List.of(R.BOOLEAN, R.DATE, R.TIMESTAMP), 1, R.DATE)); - - for (TestCase tc : testCases) { - Expression.Lambda lambda = lb.lambda(tc.paramTypes, params -> params.ref(tc.fieldIndex)); - - verifyRoundTrip(lambda); - - assertEquals(tc.expectedType, lambda.body().getType(), tc.name + ": body type mismatch"); - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals( - tc.expectedType, funcType.returnType(), tc.name + ": lambda return type mismatch"); - } - } - - // (x: i32, y: string) -> y — verify full Func type structure - @Test - void lambdaGetTypeReturnsFunc() { - Expression.Lambda lambda = lb.lambda(List.of(R.I32, R.STRING), params -> params.ref(1)); - - Type.Func funcType = (Type.Func) lambda.getType(); - assertEquals(2, funcType.parameterTypes().size()); - assertEquals(R.I32, funcType.parameterTypes().get(0)); - assertEquals(R.STRING, funcType.parameterTypes().get(1)); - assertEquals(R.STRING, funcType.returnType()); - } - - // (x: i32, y: i64, z: string) -> ... 
— verify FieldReference metadata for each param - @Test - void parameterReferenceMetadata() { - List paramTypes = List.of(R.I32, R.I64, R.STRING); - - lb.lambda( - paramTypes, - params -> { - for (int i = 0; i < 3; i++) { - FieldReference ref = params.ref(i); - assertTrue(ref.isLambdaParameterReference()); - assertFalse(ref.isOuterReference()); - assertFalse(ref.isSimpleRootReference()); - assertEquals(0, ref.lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals(paramTypes.get(i), ref.getType()); - } - return params.ref(0); - }); - } - - // ==================== Expression Body Tests ==================== - - // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 - @Test - void nestedLambdaWithArithmeticBody() { - String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; - - Expression.Lambda result = - lb.lambda( - List.of(R.I64), - outer -> - lb.lambda( - List.of(R.I64, R.I64), - inner -> { - // y1 * x - Expression multiply = - sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); - // (y1 * x) + y2 - return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); - })); - - verifyRoundTrip(result); - - // Outer lambda returns a func type - Type.Func outerFuncType = (Type.Func) result.getType(); - assertInstanceOf(Type.Func.class, outerFuncType.returnType()); - - // Inner lambda returns i64 - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - Type.Func innerFuncType = (Type.Func) innerLambda.getType(); - assertEquals(R.I64, innerFuncType.returnType()); - - // Inner body is a scalar function (add) - assertInstanceOf(Expression.ScalarFunctionInvocation.class, innerLambda.body()); - } - - // ==================== Nested Lambda Tests ==================== - - // (x: i64, y: i64) -> (z: i32) -> x - @Test - void nestedLambdaWithOuterRef() { - Expression.Lambda result = - lb.lambda(List.of(R.I64, R.I64), outer -> lb.lambda(List.of(R.I32), inner -> outer.ref(0))); - - verifyRoundTrip(result); - - Expression.Lambda 
resultInner = (Expression.Lambda) result.body(); - assertEquals(1, resultInner.parameters().fields().size()); - assertEquals(R.I64, resultInner.body().getType()); - } - - // (x: i32, y: i64, z: string) -> (w: fp64) -> z - @Test - void nestedLambdaOuterRefTypeResolution() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32, R.I64, R.STRING), - outer -> lb.lambda(List.of(R.FP64), inner -> outer.ref(2))); - - verifyRoundTrip(result); - - Expression.Lambda resultInner = (Expression.Lambda) result.body(); - assertEquals(R.STRING, ((Type.Func) resultInner.getType()).returnType()); + static Stream validLambdaExpressions() throws IOException { + return listJsonResources("expressions/lambda/valid"); } - // (x: i32) -> (y: i64) -> y - @Test - void nestedLambdaInnerRefOnly() { - Expression.Lambda result = - lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(0))); - - verifyRoundTrip(result); - - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - assertEquals(R.I64, innerLambda.body().getType()); - assertInstanceOf(Type.Func.class, ((Type.Func) result.getType()).returnType()); + static Stream invalidLambdaExpressions() throws IOException { + return listJsonResources("expressions/lambda/invalid"); } - // (x: i32) -> (y: i64) -> (x, y) — body references both outer and inner params - @Test - void nestedLambdaBothInnerAndOuterRefs() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), - inner -> { - FieldReference innerRef = inner.ref(0); - assertEquals(R.I64, innerRef.getType()); - assertEquals(0, innerRef.lambdaParameterReferenceStepsOut().orElse(-1)); - - FieldReference outerRef = outer.ref(0); - assertEquals(R.I32, outerRef.getType()); - assertEquals(1, outerRef.lambdaParameterReferenceStepsOut().orElse(-1)); - - return innerRef; - })); - - verifyRoundTrip(result); + @ParameterizedTest + @MethodSource("validLambdaExpressions") + void validLambdaExpressionRoundtrip(String 
resourcePath) throws IOException { + Expression deserialized = deserializeExpression(resourcePath); + assertInstanceOf(Expression.Lambda.class, deserialized); + verifyRoundTrip(deserialized); } - // (a: i32, b: string) -> (c: i64, d: fp64) -> b — verify all 4 params resolve correctly - @Test - void nestedLambdaMultiParamCorrectResolution() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32, R.STRING), - outer -> - lb.lambda( - List.of(R.I64, R.FP64), - inner -> { - assertEquals(R.I64, inner.ref(0).getType()); - assertEquals(R.FP64, inner.ref(1).getType()); - assertEquals(R.I32, outer.ref(0).getType()); - assertEquals(R.STRING, outer.ref(1).getType()); - - return outer.ref(1); - })); - - verifyRoundTrip(result); - - Expression.Lambda innerLambda = (Expression.Lambda) result.body(); - assertEquals(R.STRING, ((Type.Func) innerLambda.getType()).returnType()); + @ParameterizedTest + @MethodSource("invalidLambdaExpressions") + void invalidLambdaExpressionRejected(String resourcePath) { + assertThrows(Exception.class, () -> deserializeExpression(resourcePath)); } - // (x: i32) -> (y: i64) -> (z: string) -> x - @Test - void tripleNestedLambdaRoundtrip() { - Expression.Lambda result = - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), mid -> lb.lambda(List.of(R.STRING), inner -> outer.ref(0)))); - - verifyRoundTrip(result); - - Expression.Lambda l1 = (Expression.Lambda) result.body(); - Expression.Lambda l2 = (Expression.Lambda) l1.body(); - assertEquals(R.I32, l2.body().getType()); + private static Stream listJsonResources(String dirPath) throws IOException { + Path dir = + Paths.get( + LambdaExpressionRoundtripTest.class.getClassLoader().getResource(dirPath).getPath()); + return Files.list(dir) + .filter(p -> p.toString().endsWith(".json")) + .map(p -> dirPath + "/" + p.getFileName().toString()) + .sorted(); } - // (x: i32) -> (y: i64) -> (z: string) -> ... 
— verify stepsOut is auto-computed at each level - @Test - void tripleNestedLambdaScopeTracking() { - lb.lambda( - List.of(R.I32), - outer -> - lb.lambda( - List.of(R.I64), - mid -> - lb.lambda( - List.of(R.STRING), - inner -> { - assertEquals(R.STRING, inner.ref(0).getType()); - assertEquals(R.I64, mid.ref(0).getType()); - assertEquals(R.I32, outer.ref(0).getType()); - - assertEquals( - 0, inner.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals(1, mid.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - assertEquals( - 2, outer.ref(0).lambdaParameterReferenceStepsOut().orElse(-1)); - - return inner.ref(0); - }))); + private Expression deserializeExpression(String resourcePath) throws IOException { + String json = asString(resourcePath); + io.substrait.proto.Expression.Builder builder = io.substrait.proto.Expression.newBuilder(); + JsonFormat.parser().merge(json, builder); + return protoExpressionConverter.from(builder.build()); } } diff --git a/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json b/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json new file mode 100644 index 000000000..54ec7321b --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/nested_steps_out.json @@ -0,0 +1,30 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 2 + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/invalid/steps_out.json b/core/src/test/resources/expressions/lambda/invalid/steps_out.json new file mode 100644 index 000000000..148173e8b --- /dev/null +++ b/core/src/test/resources/expressions/lambda/invalid/steps_out.json @@ -0,0 +1,21 @@ +{ + 
"lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 1 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/identity.json b/core/src/test/resources/expressions/lambda/valid/identity.json new file mode 100644 index 000000000..a984334ef --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/identity.json @@ -0,0 +1,21 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/literal_body.json b/core/src/test/resources/expressions/lambda/valid/literal_body.json new file mode 100644 index 000000000..c36e2df5e --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/literal_body.json @@ -0,0 +1,14 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "literal": { + "i32": 42 + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/multi_param.json b/core/src/test/resources/expressions/lambda/valid/multi_param.json new file mode 100644 index 000000000..1c31bc1e7 --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/multi_param.json @@ -0,0 +1,23 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } }, + { "i64": { "nullability": "NULLABILITY_REQUIRED" } }, + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "lambdaParameterReference": { + "stepsOut": 0 + } + } + } + } +} diff --git 
a/core/src/test/resources/expressions/lambda/valid/nested.json b/core/src/test/resources/expressions/lambda/valid/nested.json new file mode 100644 index 000000000..abf716b7a --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/nested.json @@ -0,0 +1,30 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 1 + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/triple_nested.json b/core/src/test/resources/expressions/lambda/valid/triple_nested.json new file mode 100644 index 000000000..e1e692521 --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/triple_nested.json @@ -0,0 +1,39 @@ +{ + "lambda": { + "parameters": { + "types": [ + { "i32": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "i64": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "lambda": { + "parameters": { + "types": [ + { "string": { "nullability": "NULLABILITY_REQUIRED" } } + ] + }, + "body": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "lambdaParameterReference": { + "stepsOut": 2 + } + } + } + } + } + } + } + } +} diff --git a/core/src/test/resources/expressions/lambda/valid/zero_params.json b/core/src/test/resources/expressions/lambda/valid/zero_params.json new file mode 100644 index 000000000..df801c9cd --- /dev/null +++ b/core/src/test/resources/expressions/lambda/valid/zero_params.json @@ -0,0 +1,12 @@ +{ + "lambda": { + "parameters": { + "types": [] + }, + "body": { + "literal": { + "i32": 42 + } + } + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java 
b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index 55924081d..0c294ff9f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -1,10 +1,12 @@ package io.substrait.isthmus; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.LambdaBuilder; +import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.relation.Project; import io.substrait.relation.Rel; import java.util.ArrayList; @@ -57,4 +59,35 @@ void nestedLambdaThrowsUnsupportedOperation() { Project project = Project.builder().expressions(exprs).input(emptyTable).build(); assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } + + // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + @Test + void nestedLambdaWithArithmeticBody() { + String ARITH = DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; + + Expression.Lambda lambda = + lb.lambda( + List.of(R.I64), + outer -> + lb.lambda( + List.of(R.I64, R.I64), + inner -> { + Expression multiply = + sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); + return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); + })); + + // Proto-only roundtrip since Calcite doesn't support nested lambdas + List exprs = new ArrayList<>(); + exprs.add(lambda); + Project project = Project.builder().expressions(exprs).input(emptyTable).build(); + + io.substrait.extension.ExtensionCollector collector = + new io.substrait.extension.ExtensionCollector(); + io.substrait.proto.Rel proto = + new io.substrait.relation.RelProtoConverter(collector).toProto(project); + io.substrait.relation.Rel roundTripped = + new io.substrait.relation.ProtoRelConverter(collector, extensions).from(proto); + 
assertEquals(project, roundTripped); + } } From f172704dce0b71de4f971c55d3f7a653a6d66b5d Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:35:59 -0400 Subject: [PATCH 14/18] docs: fix LambdaBuilder javadoc to use params/outer/inner naming --- .../main/java/io/substrait/expression/LambdaBuilder.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index 1197cd1f1..621a5ef90 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -19,12 +19,12 @@ * LambdaBuilder lb = new LambdaBuilder(); * * // Simple: (x: i32) -> x - * Expression.Lambda simple = lb.lambda(List.of(R.I32), x -> x.ref(0)); + * Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0)); * * // Nested: (x: i32) -> (y: i64) -> add(x, y) - * Expression.Lambda nested = lb.lambda(List.of(R.I32), x -> - * lb.lambda(List.of(R.I64), y -> - * add(x.ref(0), y.ref(0)) + * Expression.Lambda nested = lb.lambda(List.of(R.I32), outer -> + * lb.lambda(List.of(R.I64), inner -> + * add(outer.ref(0), inner.ref(0)) * ) * ); * } From 67c5a8a763af262653cfd72a662b3e9841ab3339 Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:45:12 -0400 Subject: [PATCH 15/18] refactor: clarify Scope internals, extract stepsOut() method and document depth-capture mechanism --- .../substrait/expression/LambdaBuilder.java | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/LambdaBuilder.java b/core/src/main/java/io/substrait/expression/LambdaBuilder.java index 621a5ef90..8709b6aa6 100644 --- a/core/src/main/java/io/substrait/expression/LambdaBuilder.java +++ b/core/src/main/java/io/substrait/expression/LambdaBuilder.java @@ -45,8 +45,7 @@ public Expression.Lambda lambda(List 
paramTypes, Function= lambdaContext.size()) { + int targetDepth = lambdaContext.size() - stepsOut; + if (targetDepth <= 0 || targetDepth > lambdaContext.size()) { throw new IllegalArgumentException( String.format( "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)", stepsOut, lambdaContext.size())); } - return lambdaContext.get(index); + return lambdaContext.get(targetDepth - 1); } /** @@ -113,26 +112,39 @@ private void popLambdaContext() { /** * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated * parameter references. + * + *

Each Scope captures the depth of the lambdaContext stack at the time it was created. When + * {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference + * between the current stack depth and the captured depth. This means the same Scope produces + * different stepsOut values depending on the nesting level at the time of the call, which is what + * allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda. */ public class Scope { - private final int index; + private final Type.Struct params; + private final int depth; + + private Scope(Type.Struct params) { + this.params = params; + this.depth = lambdaContext.size(); + } - private Scope(int index) { - this.index = index; + /** + * Computes the number of lambda boundaries between this scope and the current innermost scope. + * This value changes dynamically as nested lambdas are built. + */ + private int stepsOut() { + return lambdaContext.size() - depth; } /** - * Creates a validated reference to a parameter of this lambda. The correct {@code stepsOut} - * value is computed automatically. + * Creates a validated reference to a parameter of this lambda. 
* * @param paramIndex index of the parameter within this lambda's parameter struct * @return a {@link FieldReference} pointing to the specified parameter * @throws IndexOutOfBoundsException if paramIndex is out of bounds */ public FieldReference ref(int paramIndex) { - int stepsOut = lambdaContext.size() - 1 - index; - return FieldReference.newLambdaParameterReference( - paramIndex, lambdaContext.get(index), stepsOut); + return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut()); } } } From eed9ea92cf79243a54f14486ea398f4e0a9daf6c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:46:49 -0400 Subject: [PATCH 16/18] test: add test verifying stepsOut changes dynamically with nesting depth --- .../expression/LambdaBuilderTest.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index 492c56f30..a4c569219 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -54,6 +54,31 @@ void nestedLambda() { assertEquals(expected, lambda); } + // Verify that the same scope handle produces different stepsOut values depending on nesting. + // outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda. 
+ @Test + void scopeStepsOutChangesDynamically() { + Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); + Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); + + lb.lambda( + List.of(R.I32), + outer -> { + FieldReference atTopLevel = outer.ref(0); + assertEquals(0, atTopLevel.lambdaParameterReferenceStepsOut().orElse(-1)); + + lb.lambda( + List.of(R.I64), + inner -> { + FieldReference atNestedLevel = outer.ref(0); + assertEquals(1, atNestedLevel.lambdaParameterReferenceStepsOut().orElse(-1)); + return inner.ref(0); + }); + + return atTopLevel; + }); + } + // (x: i32)@p -> p[5] — only 1 param, index 5 is out of bounds @Test void invalidFieldIndex_outOfBounds() { From c37d527a2fea7eedbdd00af258fa78f7ef64f9dc Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 17:54:32 -0400 Subject: [PATCH 17/18] test: simplify arithmetic body test to single lambda (x -> x + x) --- .../isthmus/LambdaExpressionTest.java | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java index 0c294ff9f..fb33407b8 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/LambdaExpressionTest.java @@ -1,6 +1,5 @@ package io.substrait.isthmus; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import io.substrait.expression.Expression; @@ -60,34 +59,19 @@ void nestedLambdaThrowsUnsupportedOperation() { assertThrows(UnsupportedOperationException.class, () -> assertFullRoundTrip(project)); } - // (x: i64) -> (y1: i64, y2: i64) -> y1 * x + y2 + // (x: i64)@p -> add(p[0], p[0]) @Test - void nestedLambdaWithArithmeticBody() { + void lambdaWithArithmeticBody() { String ARITH = 
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC; Expression.Lambda lambda = lb.lambda( List.of(R.I64), - outer -> - lb.lambda( - List.of(R.I64, R.I64), - inner -> { - Expression multiply = - sb.scalarFn(ARITH, "multiply:i64_i64", R.I64, inner.ref(0), outer.ref(0)); - return sb.scalarFn(ARITH, "add:i64_i64", R.I64, multiply, inner.ref(1)); - })); + params -> sb.scalarFn(ARITH, "add:i64_i64", R.I64, params.ref(0), params.ref(0))); - // Proto-only roundtrip since Calcite doesn't support nested lambdas List exprs = new ArrayList<>(); exprs.add(lambda); Project project = Project.builder().expressions(exprs).input(emptyTable).build(); - - io.substrait.extension.ExtensionCollector collector = - new io.substrait.extension.ExtensionCollector(); - io.substrait.proto.Rel proto = - new io.substrait.relation.RelProtoConverter(collector).toProto(project); - io.substrait.relation.Rel roundTripped = - new io.substrait.relation.ProtoRelConverter(collector, extensions).from(proto); - assertEquals(project, roundTripped); + assertFullRoundTrip(project); } } From d55762905a0a53252b98f5ae506ab974796cdd7c Mon Sep 17 00:00:00 2001 From: Ben Bellick Date: Mon, 16 Mar 2026 18:05:00 -0400 Subject: [PATCH 18/18] fix: remove unused local variables flagged by PMD --- .../test/java/io/substrait/expression/LambdaBuilderTest.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java index a4c569219..01303d932 100644 --- a/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java +++ b/core/src/test/java/io/substrait/expression/LambdaBuilderTest.java @@ -58,9 +58,6 @@ void nestedLambda() { // outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda. 
@Test void scopeStepsOutChangesDynamically() { - Type.Struct outerParams = Type.Struct.builder().nullable(false).addFields(R.I32).build(); - Type.Struct innerParams = Type.Struct.builder().nullable(false).addFields(R.I64).build(); - lb.lambda( List.of(R.I32), outer -> {