Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5ca6ff8
feat: add lambda expression support
limameml Feb 24, 2026
ed61127
Merge branch 'main' into limame.malainine/add-lambda-support
limameml Feb 25, 2026
a90c609
address some of @benbellick's comments
limameml Feb 25, 2026
4053357
tweak: encapsulate LambdaParameterStack logic in a class
limameml Mar 2, 2026
ed5de56
Merge branch 'main' into limame.malainine/add-lambda-support
limameml Mar 2, 2026
eb7b2b2
tweak: adding comments explaining LambdaParameterStack
limameml Mar 2, 2026
60b66be
Merge branch 'main' into limame.malainine/add-lambda-support
vbarua Mar 6, 2026
1d1b26e
build: ignore build related generated files
vbarua Mar 7, 2026
2224c4b
feat: enable parsing of func types in extensions
vbarua Mar 6, 2026
83b9e24
test: copy substrait-go lambda plans
vbarua Mar 10, 2026
d802c19
test: add LambdaRoundtripTests
vbarua Mar 10, 2026
14738e3
feat(isthmus): add TRANSFORM SqlFunction to handle transform:list_func
vbarua Mar 9, 2026
2b5436c
Merge pull request #1 from substrait-io/vbarua/lambda-testing
limameml Mar 11, 2026
6fae010
tweak: add filter in the function mapping
limameml Mar 11, 2026
1c1f8d2
address @vbarua's comments
limameml Mar 13, 2026
bd30a21
feat(core): add LambdaBuilder for build-time validation of lambda par…
benbellick Mar 16, 2026
50db699
refactor: unify lambda validation, add JSON-based roundtrip tests
benbellick Mar 16, 2026
f172704
docs: fix LambdaBuilder javadoc to use params/outer/inner naming
benbellick Mar 16, 2026
67c5a8a
refactor: clarify Scope internals, extract stepsOut() method and docu…
benbellick Mar 16, 2026
eed9ea9
test: add test verifying stepsOut changes dynamically with nesting depth
benbellick Mar 16, 2026
c37d527
test: simplify arithmetic body test to single lambda (x -> x + x)
benbellick Mar 16, 2026
d557629
fix: remove unused local variables flagged by PMD
benbellick Mar 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ out/**
.metals
.bloop
.project
.classpath
.settings
bin/
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,19 @@ public O visit(Expression.IfThen expr, C context) throws E {
return visitFallback(expr, context);
}

/**
 * Visits a Lambda expression.
 *
 * <p>This default implementation does no lambda-specific work: it hands the expression to {@code
 * visitFallback} and returns whatever that call produces.
 *
 * @param expr the Lambda expression
 * @param context the visitation context
 * @return the result returned by {@code visitFallback}
 * @throws E if visitation fails
 */
@Override
public O visit(Expression.Lambda expr, C context) throws E {
  return visitFallback(expr, context);
}

/**
* Visits a scalar function invocation.
*
Expand Down
24 changes: 24 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,30 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
}
}

/**
 * A lambda expression: a struct of parameter types together with a body expression. Its type is
 * a function type from the parameter types to the type of the body.
 */
@Value.Immutable
abstract class Lambda implements Expression {
  /**
   * The lambda's parameters, modelled as a struct whose fields are the parameter types.
   *
   * @return the parameter struct
   */
  public abstract Type.Struct parameters();

  /**
   * The expression evaluated when the lambda is invoked.
   *
   * @return the body expression
   */
  public abstract Expression body();

  /**
   * Derives the function type of this lambda from its parameter types and the body's type.
   *
   * @return a non-nullable function type
   */
  @Override
  public Type getType() {
    // TODO: Type.Func nullability is hardcoded to false here because the spec does not allow for
    // declaring otherwise.
    // See: https://github.com/substrait-io/substrait/issues/976
    return Type.withNullability(false).func(parameters().fields(), body().getType());
  }

  /** Dispatches to the visitor's Lambda overload. */
  @Override
  public <R, C extends VisitationContext, E extends Throwable> R accept(
      ExpressionVisitor<R, C, E> visitor, C context) throws E {
    return visitor.visit(this, context);
  }
}

/**
* Base interface for user-defined literals.
*
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/java/io/substrait/expression/ExpressionVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,16 @@ public interface ExpressionVisitor<R, C extends VisitationContext, E extends Thr
*/
R visit(Expression.NestedStruct expr, C context) throws E;

/**
 * Visit a Lambda expression, i.e. an expression made of a parameter struct and a body
 * expression.
 *
 * @param expr the Lambda expression
 * @param context visitation context
 * @return visit result
 * @throws E on visit failure
 */
R visit(Expression.Lambda expr, C context) throws E;

/**
* Visit a user-defined any literal.
*
Expand Down
18 changes: 17 additions & 1 deletion core/src/main/java/io/substrait/expression/FieldReference.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public abstract class FieldReference implements Expression {

public abstract Optional<Integer> outerReferenceStepsOut();

/**
 * Number of lambda scopes to step outward to reach the lambda whose parameter is referenced.
 * Present only when this reference targets a lambda parameter; a present value marks the
 * reference as a lambda parameter reference.
 *
 * @return the steps-out value, or empty if this is not a lambda parameter reference
 */
public abstract Optional<Integer> lambdaParameterReferenceStepsOut();

@Override
public Type getType() {
return type();
Expand All @@ -38,13 +40,18 @@ public <R, C extends VisitationContext, E extends Throwable> R accept(
public boolean isSimpleRootReference() {
return segments().size() == 1
&& !inputExpression().isPresent()
&& !outerReferenceStepsOut().isPresent();
&& !outerReferenceStepsOut().isPresent()
&& !lambdaParameterReferenceStepsOut().isPresent();
}

/**
 * Tells whether this reference points outside the current relation.
 *
 * @return {@code true} if an outer-reference steps-out value is present and strictly positive
 */
public boolean isOuterReference() {
  return outerReferenceStepsOut().filter(steps -> steps > 0).isPresent();
}

/**
 * Tells whether this reference targets a lambda parameter.
 *
 * <p>Only the presence of the steps-out value is checked, so a value of 0 also counts.
 *
 * @return {@code true} if a lambda parameter steps-out value is present
 */
public boolean isLambdaParameterReference() {
  return lambdaParameterReferenceStepsOut().isPresent();
}

public FieldReference dereferenceStruct(int index) {
Type newType = StructFieldFinder.getReferencedType(type(), index);
return dereference(newType, StructField.of(index));
Expand Down Expand Up @@ -134,6 +141,15 @@ public static FieldReference newInputRelReference(int index, List<Rel> rels) {
index, currentOffset));
}

/**
 * Creates a reference to a single parameter of a lambda.
 *
 * @param paramIndex zero-based index of the parameter within {@code lambdaParamsType}
 * @param lambdaParamsType parameter struct of the lambda being referenced
 * @param stepsOut number of lambda scopes to step outward to reach that lambda
 * @return a single-segment reference typed as the referenced parameter
 * @throws IllegalArgumentException if {@code stepsOut} is negative
 * @throws IndexOutOfBoundsException if {@code paramIndex} is not a valid parameter index
 */
public static FieldReference newLambdaParameterReference(
    int paramIndex, Type.Struct lambdaParamsType, int stepsOut) {
  // Reject values that can never describe a valid reference before building anything.
  if (stepsOut < 0) {
    throw new IllegalArgumentException(
        String.format(
            "Lambda parameter reference stepsOut must be non-negative, got %d", stepsOut));
  }
  List<Type> paramTypes = lambdaParamsType.fields();
  if (paramIndex < 0 || paramIndex >= paramTypes.size()) {
    throw new IndexOutOfBoundsException(
        String.format(
            "Lambda parameter index %d is out of bounds for a lambda with %d parameter(s)",
            paramIndex, paramTypes.size()));
  }
  return ImmutableFieldReference.builder()
      .addSegments(StructField.of(paramIndex))
      .type(paramTypes.get(paramIndex))
      .lambdaParameterReferenceStepsOut(stepsOut)
      .build();
}

public interface ReferenceSegment {
FieldReference apply(FieldReference reference);

Expand Down
150 changes: 150 additions & 0 deletions core/src/main/java/io/substrait/expression/LambdaBuilder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package io.substrait.expression;

import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.function.Supplier;

/**
 * Builds lambda expressions with build-time validation of parameter references.
 *
 * <p>Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters
 * onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code
 * lambda()} again on the same builder.
 *
 * <p>The callback receives a {@link Scope} handle for creating validated parameter references. The
 * correct {@code stepsOut} value is computed automatically from the stack. A {@link Scope} is only
 * usable while the body of the lambda it belongs to is being built.
 *
 * <pre>{@code
 * LambdaBuilder lb = new LambdaBuilder();
 *
 * // Simple: (x: i32) -> x
 * Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0));
 *
 * // Nested: (x: i32) -> (y: i64) -> add(x, y)
 * Expression.Lambda nested = lb.lambda(List.of(R.I32), outer ->
 *   lb.lambda(List.of(R.I64), inner ->
 *     add(outer.ref(0), inner.ref(0))
 *   )
 * );
 * }</pre>
 */
public class LambdaBuilder {
  /** Parameter structs of the lambdas currently being built, outermost first. */
  private final List<Type.Struct> lambdaContext = new ArrayList<>();

  /**
   * Builds a lambda expression. The body function receives a {@link Scope} for creating validated
   * parameter references. Nested lambdas are built by calling this method again inside the
   * callback.
   *
   * @param paramTypes the lambda's parameter types
   * @param bodyFn function that builds the lambda body given a scope handle
   * @return the constructed lambda expression
   */
  public Expression.Lambda lambda(List<Type> paramTypes, Function<Scope, Expression> bodyFn) {
    Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build();
    // The Scope is created inside the supplier, i.e. after params has been pushed, so that it
    // captures the stack depth of its own lambda.
    return buildInContext(params, () -> bodyFn.apply(new Scope(params)));
  }

  /**
   * Builds a lambda expression from a pre-built parameter struct. Used by internal converters that
   * already have a Type.Struct (e.g., during protobuf deserialization).
   *
   * @param params the lambda's parameter struct
   * @param bodyFn function that builds the lambda body
   * @return the constructed lambda expression
   */
  public Expression.Lambda lambdaFromStruct(Type.Struct params, Supplier<Expression> bodyFn) {
    return buildInContext(params, bodyFn);
  }

  /**
   * Resolves the parameter struct for a lambda at the given stepsOut from the current innermost
   * scope. Used by internal converters to validate lambda parameter references during
   * deserialization.
   *
   * @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost)
   * @return the parameter struct at the target scope level
   * @throws IllegalArgumentException if stepsOut is negative or exceeds the current nesting depth
   */
  public Type.Struct resolveParams(int stepsOut) {
    int currentDepth = lambdaContext.size();
    if (stepsOut < 0 || stepsOut >= currentDepth) {
      throw new IllegalArgumentException(
          String.format(
              "Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
              stepsOut, currentDepth));
    }
    return lambdaContext.get(currentDepth - 1 - stepsOut);
  }

  /**
   * Pushes {@code params}, builds the body, and assembles the lambda. The context is popped again
   * even when building the body fails, so the builder stays usable afterwards.
   */
  private Expression.Lambda buildInContext(Type.Struct params, Supplier<Expression> bodyFn) {
    pushLambdaContext(params);
    try {
      Expression body = bodyFn.get();
      return ImmutableExpression.Lambda.builder().parameters(params).body(body).build();
    } finally {
      popLambdaContext();
    }
  }

  /**
   * Pushes a lambda's parameters onto the context stack. This makes the parameters available for
   * validation when building the lambda's body, and allows nested lambda parameter references to
   * correctly compute their stepsOut values.
   */
  private void pushLambdaContext(Type.Struct params) {
    lambdaContext.add(params);
  }

  /**
   * Pops the most recently pushed lambda parameters from the context stack. Called after a lambda's
   * body has been built, restoring the context to the enclosing lambda's scope.
   */
  private void popLambdaContext() {
    lambdaContext.remove(lambdaContext.size() - 1);
  }

  /**
   * A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated
   * parameter references.
   *
   * <p>Each Scope captures the depth of the lambdaContext stack at the time it was created. When
   * {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference
   * between the current stack depth and the captured depth. This means the same Scope produces
   * different stepsOut values depending on the nesting level at the time of the call, which is what
   * allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda.
   */
  public class Scope {
    private final Type.Struct params;
    private final int depth;

    private Scope(Type.Struct params) {
      this.params = params;
      this.depth = lambdaContext.size();
    }

    /**
     * Computes the number of lambda boundaries between this scope and the current innermost scope.
     * This value changes dynamically as nested lambdas are built.
     *
     * @throws IllegalStateException if this scope's lambda is no longer being built
     */
    private int stepsOut() {
      int currentDepth = lambdaContext.size();
      // Identity comparison is intentional: the stack slot at this scope's depth must still hold
      // the very struct pushed for this scope's lambda. Otherwise the handle has escaped its
      // lambda body and would yield a negative stepsOut or point at an unrelated lambda.
      if (depth > currentDepth || lambdaContext.get(depth - 1) != params) {
        throw new IllegalStateException(
            "Lambda scope is being used outside of the lambda body it was created for");
      }
      return currentDepth - depth;
    }

    /**
     * Creates a validated reference to a parameter of this lambda.
     *
     * @param paramIndex index of the parameter within this lambda's parameter struct
     * @return a {@link FieldReference} pointing to the specified parameter
     * @throws IndexOutOfBoundsException if paramIndex is out of bounds
     * @throws IllegalStateException if called after this scope's lambda body has been built
     */
    public FieldReference ref(int paramIndex) {
      return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut());
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,18 @@ public Expression visit(
});
}

/**
 * Converts a Lambda expression to its protobuf form: the parameter struct is converted through
 * the type converter, and the body is converted by visiting it with this converter.
 */
@Override
public Expression visit(
    io.substrait.expression.Expression.Lambda expr, EmptyVisitationContext context)
    throws RuntimeException {
  io.substrait.proto.Expression.Builder expressionBuilder =
      io.substrait.proto.Expression.newBuilder();
  io.substrait.proto.Expression.Lambda.Builder lambdaBuilder =
      io.substrait.proto.Expression.Lambda.newBuilder();

  // Parameters are converted before the body, matching the order of the fields being set.
  lambdaBuilder.setParameters(typeProtoConverter.toProto(expr.parameters()).getStruct());
  lambdaBuilder.setBody(expr.body().accept(this, context));

  return expressionBuilder.setLambda(lambdaBuilder).build();
}

@Override
public Expression visit(
io.substrait.expression.Expression.UserDefinedAnyLiteral expr,
Expand Down Expand Up @@ -617,6 +629,10 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) {
out.setOuterReference(
io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder()
.setStepsOut(expr.outerReferenceStepsOut().get()));
} else if (expr.lambdaParameterReferenceStepsOut().isPresent()) {
out.setLambdaParameterReference(
io.substrait.proto.Expression.FieldReference.LambdaParameterReference.newBuilder()
.setStepsOut(expr.lambdaParameterReferenceStepsOut().get()));
} else {
out.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.substrait.expression.FieldReference.ReferenceSegment;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.FunctionOption;
import io.substrait.expression.LambdaBuilder;
import io.substrait.expression.WindowBound;
import io.substrait.extension.ExtensionLookup;
import io.substrait.extension.SimpleExtension;
Expand Down Expand Up @@ -37,6 +38,7 @@ public class ProtoExpressionConverter {
private final Type.Struct rootType;
private final ProtoTypeConverter protoTypeConverter;
private final ProtoRelConverter protoRelConverter;
// Tracks the parameter structs of the lambdas currently being converted, so that lambda
// parameter references inside a lambda body can be resolved against their enclosing lambda.
private final LambdaBuilder lambdaBuilder = new LambdaBuilder();

public ProtoExpressionConverter(
ExtensionLookup lookup,
Expand Down Expand Up @@ -75,6 +77,25 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
reference.getDirectReference().getStructField().getField(),
rootType,
reference.getOuterReference().getStepsOut());
case LAMBDA_PARAMETER_REFERENCE:
{
io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef =
reference.getLambdaParameterReference();

int stepsOut = lambdaParamRef.getStepsOut();
Type.Struct lambdaParameters = lambdaBuilder.resolveParams(stepsOut);

// Check for unsupported nested field access
if (reference.getDirectReference().getStructField().hasChild()) {
throw new UnsupportedOperationException(
"Nested field access in lambda parameters is not yet supported");
}

return FieldReference.newLambdaParameterReference(
reference.getDirectReference().getStructField().getField(),
lambdaParameters,
stepsOut);
}
case ROOTTYPE_NOT_SET:
default:
throw new IllegalArgumentException("Unhandled type: " + reference.getRootTypeCase());
Expand Down Expand Up @@ -260,6 +281,18 @@ public Type visit(Type.Struct type) throws RuntimeException {
}
}

case LAMBDA:
{
io.substrait.proto.Expression.Lambda protoLambda = expr.getLambda();
Type.Struct parameters =
(Type.Struct)
protoTypeConverter.from(
io.substrait.proto.Type.newBuilder()
.setStruct(protoLambda.getParameters())
.build());

return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody()));
}
// TODO enum.
case ENUM:
throw new UnsupportedOperationException("Unsupported type: " + expr.getRexTypeCase());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ public class DefaultExtensionCatalog {
/** Extension identifier for set functions. */
public static final String FUNCTIONS_SET = "extension:io.substrait:functions_set";

/** Extension identifier for list functions. */
public static final String FUNCTIONS_LIST = "extension:io.substrait:functions_list";

/** Extension identifier for string functions. */
public static final String FUNCTIONS_STRING = "extension:io.substrait:functions_string";

Expand All @@ -75,6 +78,7 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
"arithmetic",
"comparison",
"datetime",
"list",
"logarithmic",
"rounding",
"rounding_decimal",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ public interface ExtendedTypeCreator<T, I> {
T listE(T type);

T mapE(T key, T value);

/**
 * Creates a function type from the given parameter types and return type.
 *
 * <p>NOTE(review): semantics are inferred from the signature and the sibling {@code listE} /
 * {@code mapE} methods; confirm against the implementations.
 *
 * @param parameterTypes the types of the function's parameters
 * @param returnType the function's return type
 * @return the function type
 */
T funcE(Iterable<? extends T> parameterTypes, T returnType);
}
Loading
Loading