Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5ca6ff8
feat: add lambda expression support
limameml Feb 24, 2026
ed61127
Merge branch 'main' into limame.malainine/add-lambda-support
limameml Feb 25, 2026
a90c609
adresse some of @benbellick's comments
limameml Feb 25, 2026
4053357
tweak: encapsulate LambdaParameterStack logic in a class
limameml Mar 2, 2026
ed5de56
Merge branch 'main' into limame.malainine/add-lambda-support
limameml Mar 2, 2026
eb7b2b2
tweak: adding comments expailining LambdaParameterStack
limameml Mar 2, 2026
60b66be
Merge branch 'main' into limame.malainine/add-lambda-support
vbarua Mar 6, 2026
1d1b26e
build: ignore build related generated files
vbarua Mar 7, 2026
2224c4b
feat: enable parsing of func types in extensions
vbarua Mar 6, 2026
83b9e24
test: copy substrait-go lambda plans
vbarua Mar 10, 2026
d802c19
test: add LambdaRoundtripTests
vbarua Mar 10, 2026
14738e3
feat(isthmus): add TRANSFORM SqlFunction to handle transform:list_func
vbarua Mar 9, 2026
2b5436c
Merge pull request #1 from substrait-io/vbarua/lambda-testing
limameml Mar 11, 2026
6fae010
tweak: add filter in the function mapping
limameml Mar 11, 2026
1c1f8d2
adress @vbarua's comments
limameml Mar 13, 2026
bd30a21
feat(core): add LambdaBuilder for build-time validation of lambda par…
benbellick Mar 16, 2026
50db699
refactor: unify lambda validation, add JSON-based roundtrip tests
benbellick Mar 16, 2026
f172704
docs: fix LambdaBuilder javadoc to use params/outer/inner naming
benbellick Mar 16, 2026
67c5a8a
refactor: clarify Scope internals, extract stepsOut() method and docu…
benbellick Mar 16, 2026
eed9ea9
test: add test verifying stepsOut changes dynamically with nesting depth
benbellick Mar 16, 2026
c37d527
test: simplify arithmetic body test to single lambda (x -> x + x)
benbellick Mar 16, 2026
d557629
fix: remove unused local variables flagged by PMD
benbellick Mar 16, 2026
0dd2114
Merge pull request #2 from substrait-io/benbellick/proposed-lambda-va…
limameml Mar 17, 2026
d158f07
Merge branch 'main' into limame.malainine/add-lambda-support
limameml Mar 17, 2026
5c2c99c
fix: remove uri mentions left in substrait plans and add all_match an…
limameml Mar 17, 2026
8b81dbb
adressing some of @benbellick's comments
limameml Mar 18, 2026
aa659c5
refactor: reorder newLambdaParameterReference parameters for readability
benbellick Mar 18, 2026
e3e9f48
docs: add javadoc to newLambdaParameterReference explaining validation
benbellick Mar 18, 2026
b35ad63
refactor: make newLambdaParameterReference package-private
benbellick Mar 18, 2026
7dc36b9
refactor: simplify newLambdaParameterReference to take Type directly
benbellick Mar 18, 2026
c88459d
test: add lambdaWithFunctionCall test to LambdaBuilderTest
benbellick Mar 18, 2026
d34bb58
test: add invalid proto test for out-of-bounds param index
benbellick Mar 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ out/**
.metals
.bloop
.project
.classpath
.settings
bin/
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,19 @@ public O visit(Expression.IfThen expr, C context) throws E {
return visitFallback(expr, context);
}

/**
* Visits a Lambda expression.
*
* @param expr the Lambda expression
* @param context the visitation context
* @return the visit result
* @throws E if visitation fails
*/
@Override
public O visit(Expression.Lambda expr, C context) throws E {
return visitFallback(expr, context);
}

/**
* Visits a scalar function invocation.
*
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,30 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

@Value.Immutable
abstract class Lambda implements Expression {
public abstract Type.Struct parameters();

public abstract Expression body();

@Override
public Type getType() {
List<Type> paramTypes = parameters().fields();
Type returnType = body().getType();

// TODO: Type.Func nullability is hardcoded to false here because the spec does not allow for
// declaring otherwise.
// See: https://github.com/substrait-io/substrait/issues/976
return Type.withNullability(false).func(paramTypes, returnType);
}

@Override
public <R, C extends VisitationContext, E extends Throwable> R accept(
ExpressionVisitor<R, C, E> visitor, C context) throws E {
return visitor.visit(this, context);
}
}

/**
* Base interface for user-defined literals.
*
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
*/
R visit(Expression.NestedStruct expr, C context) throws E;

/**
* Visit a Lambda expression.
*
* @param expr the Lambda expression
* @param context visitation context
* @return visit result
* @throws E on visit failure
*/
R visit(Expression.Lambda expr, C context) throws E;

/**
* Visit a user-defined any literal.
*
Expand Down
26 changes: 25 additions & 1 deletion core/src/main/java/io/substrait/expression/FieldReference.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public abstract class FieldReference implements Expression {

public abstract Optional<Integer> outerReferenceStepsOut();

public abstract Optional<Integer> lambdaParameterReferenceStepsOut();

@Override
public Type getType() {
return type();
Expand All @@ -35,16 +37,30 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
return visitor.visit(this, context);
}

@Value.Check
protected void check() {
if (outerReferenceStepsOut().isPresent() && lambdaParameterReferenceStepsOut().isPresent()) {
throw new IllegalArgumentException(
"FieldReference cannot have both outerReferenceStepsOut and lambdaParameterReferenceStepsOut set");
}
}

public boolean isSimpleRootReference() {
return segments().size() == 1
&& !inputExpression().isPresent()
&& !outerReferenceStepsOut().isPresent();
&& !outerReferenceStepsOut().isPresent()
&& !lambdaParameterReferenceStepsOut().isPresent();
}

public boolean isOuterReference() {
return outerReferenceStepsOut().orElse(0) > 0;
}

/** Returns true if this field reference refers to a lambda parameter. */
public boolean isLambdaParameterReference() {
return lambdaParameterReferenceStepsOut().isPresent();
}

public FieldReference dereferenceStruct(int index) {
Type newType = StructFieldFinder.getReferencedType(type(), index);
return dereference(newType, StructField.of(index));
Expand Down Expand Up @@ -134,6 +150,14 @@ public static FieldReference newInputRelReference(int index, List<Rel> rels) {
index, currentOffset));
}

static FieldReference newLambdaParameterReference(int stepsOut, int paramIndex, Type knownType) {
return ImmutableFieldReference.builder()
.addSegments(StructField.of(paramIndex))
.type(knownType)
.lambdaParameterReferenceStepsOut(stepsOut)
.build();
}

public interface ReferenceSegment {
FieldReference apply(FieldReference reference);

Expand Down
167 changes: 167 additions & 0 deletions core/src/main/java/io/substrait/expression/LambdaBuilder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package io.substrait.expression;

import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
* Builds lambda expressions with build-time validation of parameter references.
*
* <p>Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters
* onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code
* lambda()} again on the same builder.
*
* <p>The callback receives a {@link Scope} handle for creating validated parameter references. The
* correct {@code stepsOut} value is computed automatically from the stack.
*
* <pre>{@code
* LambdaBuilder lb = new LambdaBuilder();
*
* // Simple: (x: i32) -> x
* Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0));
*
* // Nested: (x: i32) -> (y: i64) -> add(x, y)
* Expression.Lambda nested = lb.lambda(List.of(R.I32), outer ->
* lb.lambda(List.of(R.I64), inner ->
* add(outer.ref(0), inner.ref(0))
* )
* );
* }</pre>
*/
public class LambdaBuilder {
private final List<Type.Struct> lambdaContext = new ArrayList<>();

/**
* Builds a lambda expression. The body function receives a {@link Scope} for creating validated
* parameter references. Nested lambdas are built by calling this method again inside the
* callback.
*
* @param paramTypes the lambda's parameter types
* @param bodyFn function that builds the lambda body given a scope handle
* @return the constructed lambda expression
*/
public Expression.Lambda lambda(List<Type> paramTypes, Function<Scope, Expression> bodyFn) {
Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build();
pushLambdaContext(params);
try {
Scope scope = new Scope(params);
Expression body = bodyFn.apply(scope);
return ImmutableExpression.Lambda.builder().parameters(params).body(body).build();
} finally {
popLambdaContext();
}
}

/**
* Builds a lambda expression from a pre-built parameter struct. Used by internal converters that
* already have a Type.Struct (e.g., during protobuf deserialization).
*
* @param params the lambda's parameter struct
* @param bodyFn function that builds the lambda body
* @return the constructed lambda expression
*/
public Expression.Lambda lambdaFromStruct(
Type.Struct params, java.util.function.Supplier<Expression> bodyFn) {
pushLambdaContext(params);
try {
Expression body = bodyFn.get();
return ImmutableExpression.Lambda.builder().parameters(params).body(body).build();
} finally {
popLambdaContext();
}
}

/**
* Resolves the parameter struct for a lambda at the given stepsOut from the current innermost
* scope. Used by internal converters to validate lambda parameter references during
* deserialization.
*
* @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost)
* @return the parameter struct at the target scope level
* @throws IllegalArgumentException if stepsOut exceeds the current nesting depth
*/
public Type.Struct resolveParams(int stepsOut) {
int targetDepth = lambdaContext.size() - stepsOut;
if (targetDepth <= 0 || targetDepth > lambdaContext.size()) {
throw new IllegalArgumentException(
String.format(
"Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
stepsOut, lambdaContext.size()));
}
return lambdaContext.get(targetDepth - 1);
}

/**
* Creates a validated field reference to a lambda parameter. Validates that stepsOut is valid for
* the current lambda nesting context.
*
* @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost)
* @param paramIndex index of the parameter within the target lambda's parameter struct
* @return a field reference to the specified lambda parameter
* @throws IllegalArgumentException if stepsOut exceeds the current nesting depth
* @throws IndexOutOfBoundsException if paramIndex is out of bounds for the target lambda
*/
public FieldReference newParameterReference(int stepsOut, int paramIndex) {
Type.Struct params = resolveParams(stepsOut);
Type type = params.fields().get(paramIndex);
return FieldReference.newLambdaParameterReference(stepsOut, paramIndex, type);
}

/**
* Pushes a lambda's parameters onto the context stack. This makes the parameters available for
* validation when building the lambda's body, and allows nested lambda parameter references to
* correctly compute their stepsOut values.
*/
private void pushLambdaContext(Type.Struct params) {
lambdaContext.add(params);
}

/**
* Pops the most recently pushed lambda parameters from the context stack. Called after a lambda's
* body has been built, restoring the context to the enclosing lambda's scope.
*/
private void popLambdaContext() {
lambdaContext.remove(lambdaContext.size() - 1);
}

/**
* A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated
* parameter references.
*
* <p>Each Scope captures the depth of the lambdaContext stack at the time it was created. When
* {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference
* between the current stack depth and the captured depth. This means the same Scope produces
* different stepsOut values depending on the nesting level at the time of the call, which is what
* allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda.
*/
public class Scope {
private final Type.Struct params;
private final int depth;

private Scope(Type.Struct params) {
this.params = params;
this.depth = lambdaContext.size();
}

/**
* Computes the number of lambda boundaries between this scope and the current innermost scope.
* This value changes dynamically as nested lambdas are built.
*/
private int stepsOut() {
return lambdaContext.size() - depth;
}

/**
* Creates a validated reference to a parameter of this lambda.
*
* @param paramIndex index of the parameter within this lambda's parameter struct
* @return a {@link FieldReference} pointing to the specified parameter
* @throws IndexOutOfBoundsException if paramIndex is out of bounds
*/
public FieldReference ref(int paramIndex) throws IndexOutOfBoundsException {
Type type = params.fields().get(paramIndex);
return FieldReference.newLambdaParameterReference(stepsOut(), paramIndex, type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,18 @@ public Expression visit(
});
}

@Override
public Expression visit(
io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context)
throws RuntimeException {
return io.substrait.proto.Expression.newBuilder()
.setLambda(
io.substrait.proto.Expression.Lambda.newBuilder()
.setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct())
.setBody(expr.body().accept(this, context)))
.build();
}

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
Expand Down Expand Up @@ -617,6 +629,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) {
out.setOuterReference(
io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder()
.setStepsOut(expr.outerReferenceStepsOut().get()));
} else if (expr.lambdaParameterReferenceStepsOut().isPresent()) {
out.setLambdaParameterReference(
io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder()
.setStepsOut(expr.lambdaParameterReferenceStepsOut().get()));
} else {
out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.FieldReference.ReferenceSegment;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.LambdaBuilder;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -37,6 +38,7 @@ public class ProtoExpressionConverter {
private final Type.Struct rootType;
private final ProtoTypeConverter protoTypeConverter;
private final ProtoRelConverter protoRelConverter;
private final LambdaBuilder lambdaBuilder = new LambdaBuilder();

public ProtoExpressionConverter(
ExtensionLookup lookup,
Expand Down Expand Up @@ -75,6 +77,21 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
reference.getDirectReference().getStructField().getField(),
rootType,
reference.getOuterReference().getStepsOut());
case LAMBDA_PARAMETER_REFERENCE:
{
io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef =
reference.getLambdaParameterReference();

// Check for unsupported nested field access
if (reference.getDirectReference().getStructField().hasChild()) {
throw new UnsupportedOperationException(
"Nested field access in lambda parameters is not yet supported");
}

return lambdaBuilder.newParameterReference(
lambdaParamRef.getStepsOut(),
reference.getDirectReference().getStructField().getField());
}
case ROOTTYPE_NOT_SET:
default:
throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase());
Expand Down Expand Up @@ -260,6 +277,18 @@ public Type visit(Type.Struct type) throws RuntimeException {
}
}

case LAMBDA:
{
io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda();
Type.Struct parameters =
(Type.Struct)
protoTypeConverter.from(
io.substrait.proto.Type.newBuilder()
.setStruct(protoLambda.getParameters())
.build());

return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody()));
}
// TODO enum.
case ENUM:
throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase());
Expand Down
Loading
Loading