Skip to content

Commit 0dd2114

Browse files
authored
Merge pull request #2 from substrait-io/benbellick/proposed-lambda-validation-impl
feat(core): add LambdaBuilder for build-time lambda validation
2 parents 1c1f8d2 + d557629 commit 0dd2114

16 files changed

Lines changed: 527 additions & 501 deletions

File tree

core/src/main/java/io/substrait/expression/Expression.java

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -775,10 +775,6 @@ public Type getType() {
775775
return Type.withNullability(false).func(paramTypes, returnType);
776776
}
777777

778-
public static ImmutableExpression.Lambda.Builder builder() {
779-
return ImmutableExpression.Lambda.builder();
780-
}
781-
782778
@Override
783779
public <R, C extends VisitationContext, E extends Throwable> R accept(
784780
ExpressionVisitor<R, C, E> visitor, C context) throws E {
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package io.substrait.expression;
2+
3+
import io.substrait.type.Type;
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
import java.util.function.Function;
7+
8+
/**
9+
* Builds lambda expressions with build-time validation of parameter references.
10+
*
11+
* <p>Maintains a stack of lambda parameter scopes. Each call to {@link #lambda} pushes parameters
12+
* onto the stack, builds the body via a callback, and pops. Nested lambdas simply call {@code
13+
* lambda()} again on the same builder.
14+
*
15+
* <p>The callback receives a {@link Scope} handle for creating validated parameter references. The
16+
* correct {@code stepsOut} value is computed automatically from the stack.
17+
*
18+
* <pre>{@code
19+
* LambdaBuilder lb = new LambdaBuilder();
20+
*
21+
* // Simple: (x: i32) -> x
22+
* Expression.Lambda simple = lb.lambda(List.of(R.I32), params -> params.ref(0));
23+
*
24+
* // Nested: (x: i32) -> (y: i64) -> add(x, y)
25+
* Expression.Lambda nested = lb.lambda(List.of(R.I32), outer ->
26+
* lb.lambda(List.of(R.I64), inner ->
27+
* add(outer.ref(0), inner.ref(0))
28+
* )
29+
* );
30+
* }</pre>
31+
*/
32+
public class LambdaBuilder {
33+
private final List<Type.Struct> lambdaContext = new ArrayList<>();
34+
35+
/**
36+
* Builds a lambda expression. The body function receives a {@link Scope} for creating validated
37+
* parameter references. Nested lambdas are built by calling this method again inside the
38+
* callback.
39+
*
40+
* @param paramTypes the lambda's parameter types
41+
* @param bodyFn function that builds the lambda body given a scope handle
42+
* @return the constructed lambda expression
43+
*/
44+
public Expression.Lambda lambda(List<Type> paramTypes, Function<Scope, Expression> bodyFn) {
45+
Type.Struct params = Type.Struct.builder().nullable(false).addAllFields(paramTypes).build();
46+
pushLambdaContext(params);
47+
try {
48+
Scope scope = new Scope(params);
49+
Expression body = bodyFn.apply(scope);
50+
return ImmutableExpression.Lambda.builder().parameters(params).body(body).build();
51+
} finally {
52+
popLambdaContext();
53+
}
54+
}
55+
56+
/**
57+
* Builds a lambda expression from a pre-built parameter struct. Used by internal converters that
58+
* already have a Type.Struct (e.g., during protobuf deserialization).
59+
*
60+
* @param params the lambda's parameter struct
61+
* @param bodyFn function that builds the lambda body
62+
* @return the constructed lambda expression
63+
*/
64+
public Expression.Lambda lambdaFromStruct(
65+
Type.Struct params, java.util.function.Supplier<Expression> bodyFn) {
66+
pushLambdaContext(params);
67+
try {
68+
Expression body = bodyFn.get();
69+
return ImmutableExpression.Lambda.builder().parameters(params).body(body).build();
70+
} finally {
71+
popLambdaContext();
72+
}
73+
}
74+
75+
/**
76+
* Resolves the parameter struct for a lambda at the given stepsOut from the current innermost
77+
* scope. Used by internal converters to validate lambda parameter references during
78+
* deserialization.
79+
*
80+
* @param stepsOut number of lambda scopes to traverse outward (0 = current/innermost)
81+
* @return the parameter struct at the target scope level
82+
* @throws IllegalArgumentException if stepsOut exceeds the current nesting depth
83+
*/
84+
public Type.Struct resolveParams(int stepsOut) {
85+
int targetDepth = lambdaContext.size() - stepsOut;
86+
if (targetDepth <= 0 || targetDepth > lambdaContext.size()) {
87+
throw new IllegalArgumentException(
88+
String.format(
89+
"Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
90+
stepsOut, lambdaContext.size()));
91+
}
92+
return lambdaContext.get(targetDepth - 1);
93+
}
94+
95+
/**
96+
* Pushes a lambda's parameters onto the context stack. This makes the parameters available for
97+
* validation when building the lambda's body, and allows nested lambda parameter references to
98+
* correctly compute their stepsOut values.
99+
*/
100+
private void pushLambdaContext(Type.Struct params) {
101+
lambdaContext.add(params);
102+
}
103+
104+
/**
105+
* Pops the most recently pushed lambda parameters from the context stack. Called after a lambda's
106+
* body has been built, restoring the context to the enclosing lambda's scope.
107+
*/
108+
private void popLambdaContext() {
109+
lambdaContext.remove(lambdaContext.size() - 1);
110+
}
111+
112+
/**
113+
* A handle to a particular lambda's parameter scope. Use {@link #ref} to create validated
114+
* parameter references.
115+
*
116+
* <p>Each Scope captures the depth of the lambdaContext stack at the time it was created. When
117+
* {@link #ref} is called, the Substrait {@code stepsOut} value is computed as the difference
118+
* between the current stack depth and the captured depth. This means the same Scope produces
119+
* different stepsOut values depending on the nesting level at the time of the call, which is what
120+
* allows outer.ref(0) to produce stepsOut=1 when called inside a nested lambda.
121+
*/
122+
public class Scope {
123+
private final Type.Struct params;
124+
private final int depth;
125+
126+
private Scope(Type.Struct params) {
127+
this.params = params;
128+
this.depth = lambdaContext.size();
129+
}
130+
131+
/**
132+
* Computes the number of lambda boundaries between this scope and the current innermost scope.
133+
* This value changes dynamically as nested lambdas are built.
134+
*/
135+
private int stepsOut() {
136+
return lambdaContext.size() - depth;
137+
}
138+
139+
/**
140+
* Creates a validated reference to a parameter of this lambda.
141+
*
142+
* @param paramIndex index of the parameter within this lambda's parameter struct
143+
* @return a {@link FieldReference} pointing to the specified parameter
144+
* @throws IndexOutOfBoundsException if paramIndex is out of bounds
145+
*/
146+
public FieldReference ref(int paramIndex) {
147+
return FieldReference.newLambdaParameterReference(paramIndex, params, stepsOut());
148+
}
149+
}
150+
}

core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java

Lines changed: 4 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.substrait.expression.FieldReference.ReferenceSegment;
77
import io.substrait.expression.FunctionArg;
88
import io.substrait.expression.FunctionOption;
9+
import io.substrait.expression.LambdaBuilder;
910
import io.substrait.expression.WindowBound;
1011
import io.substrait.extension.ExtensionLookup;
1112
import io.substrait.extension.SimpleExtension;
@@ -37,7 +38,7 @@ public class ProtoExpressionConverter {
3738
private final Type.Struct rootType;
3839
private final ProtoTypeConverter protoTypeConverter;
3940
private final ProtoRelConverter protoRelConverter;
40-
private final LambdaParameterStack lambdaParameterStack = new LambdaParameterStack();
41+
private final LambdaBuilder lambdaBuilder = new LambdaBuilder();
4142

4243
public ProtoExpressionConverter(
4344
ExtensionLookup lookup,
@@ -82,7 +83,7 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc
8283
reference.getLambdaParameterReference();
8384

8485
int stepsOut = lambdaParamRef.getStepsOut();
85-
Type.Struct lambdaParameters = lambdaParameterStack.get(stepsOut);
86+
Type.Struct lambdaParameters = lambdaBuilder.resolveParams(stepsOut);
8687

8788
// Check for unsupported nested field access
8889
if (reference.getDirectReference().getStructField().hasChild()) {
@@ -290,16 +291,7 @@ public Type visit(Type.Struct type) throws RuntimeException {
290291
.setStruct(protoLambda.getParameters())
291292
.build());
292293

293-
lambdaParameterStack.push(parameters);
294-
295-
Expression body;
296-
try {
297-
body = from(protoLambda.getBody());
298-
} finally {
299-
lambdaParameterStack.pop();
300-
}
301-
302-
return Expression.Lambda.builder().parameters(parameters).body(body).build();
294+
return lambdaBuilder.lambdaFromStruct(parameters, () -> from(protoLambda.getBody()));
303295
}
304296
// TODO enum.
305297
case ENUM:
@@ -620,42 +612,4 @@ public Expression.SortField fromSortField(SortField s) {
620612
public static FunctionOption fromFunctionOption(io.substrait.proto.FunctionOption o) {
621613
return FunctionOption.builder().name(o.getName()).addAllValues(o.getPreferenceList()).build();
622614
}
623-
624-
/**
625-
* A stack for tracking lambda parameter types during expression parsing.
626-
*
627-
* <p>When parsing nested lambda expressions, each lambda's parameters are pushed onto this stack.
628-
* Lambda parameter references use "stepsOut" to indicate which enclosing lambda they reference:
629-
*
630-
* <ul>
631-
* <li>stepsOut=0 refers to the innermost (current) lambda
632-
* <li>stepsOut=1 refers to the next enclosing lambda
633-
* <li>stepsOut=N refers to N levels up
634-
* </ul>
635-
*/
636-
private static class LambdaParameterStack {
637-
private final List<Type.Struct> stack = new ArrayList<>();
638-
639-
void push(Type.Struct parameters) {
640-
stack.add(parameters);
641-
}
642-
643-
void pop() {
644-
if (stack.isEmpty()) {
645-
throw new IllegalArgumentException("Lambda parameter stack is empty");
646-
}
647-
stack.remove(stack.size() - 1);
648-
}
649-
650-
Type.Struct get(int stepsOut) {
651-
int index = stack.size() - 1 - stepsOut;
652-
if (index < 0 || index >= stack.size()) {
653-
throw new IllegalArgumentException(
654-
String.format(
655-
"Lambda parameter reference with stepsOut=%d is invalid (current depth: %d)",
656-
stepsOut, stack.size()));
657-
}
658-
return stack.get(index);
659-
}
660-
}
661615
}

core/src/main/java/io/substrait/relation/ExpressionCopyOnWriteVisitor.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import io.substrait.expression.ExpressionVisitor;
99
import io.substrait.expression.FieldReference;
1010
import io.substrait.expression.FunctionArg;
11+
import io.substrait.expression.ImmutableExpression;
1112
import io.substrait.util.EmptyVisitationContext;
1213
import java.util.List;
1314
import java.util.Optional;
@@ -448,7 +449,10 @@ public Optional<Expression> visit(Expression.Lambda lambda, EmptyVisitationConte
448449
return Optional.empty();
449450
}
450451
return Optional.of(
451-
Expression.Lambda.builder().from(lambda).body(newBody.orElse(lambda.body())).build());
452+
ImmutableExpression.Lambda.builder()
453+
.from(lambda)
454+
.body(newBody.orElse(lambda.body()))
455+
.build());
452456
}
453457

454458
// utilities
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package io.substrait.expression;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertThrows;
5+
6+
import io.substrait.type.Type;
7+
import io.substrait.type.TypeCreator;
8+
import java.util.List;
9+
import org.junit.jupiter.api.Test;
10+
11+
/** Tests for {@link LambdaBuilder}. */
12+
class LambdaBuilderTest {
13+
14+
static final TypeCreator R = TypeCreator.REQUIRED;
15+
16+
final LambdaBuilder lb = new LambdaBuilder();
17+
18+
// (x: i32)@p -> p[0]
19+
@Test
20+
void simpleLambda() {
21+
Expression.Lambda lambda = lb.lambda(List.of(R.I32), params -> params.ref(0));
22+
23+
Expression.Lambda expected =
24+
ImmutableExpression.Lambda.builder()
25+
.parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build())
26+
.body(
27+
FieldReference.newLambdaParameterReference(
28+
0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 0))
29+
.build();
30+
31+
assertEquals(expected, lambda);
32+
}
33+
34+
// (x: i32)@outer -> (y: i64)@inner -> outer[0]
35+
@Test
36+
void nestedLambda() {
37+
Expression.Lambda lambda =
38+
lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(0)));
39+
40+
Expression.Lambda expectedInner =
41+
ImmutableExpression.Lambda.builder()
42+
.parameters(Type.Struct.builder().nullable(false).addFields(R.I64).build())
43+
.body(
44+
FieldReference.newLambdaParameterReference(
45+
0, Type.Struct.builder().nullable(false).addFields(R.I32).build(), 1))
46+
.build();
47+
48+
Expression.Lambda expected =
49+
ImmutableExpression.Lambda.builder()
50+
.parameters(Type.Struct.builder().nullable(false).addFields(R.I32).build())
51+
.body(expectedInner)
52+
.build();
53+
54+
assertEquals(expected, lambda);
55+
}
56+
57+
// Verify that the same scope handle produces different stepsOut values depending on nesting.
58+
// outer.ref(0) should produce stepsOut=0 at the top level and stepsOut=1 inside a nested lambda.
59+
@Test
60+
void scopeStepsOutChangesDynamically() {
61+
lb.lambda(
62+
List.of(R.I32),
63+
outer -> {
64+
FieldReference atTopLevel = outer.ref(0);
65+
assertEquals(0, atTopLevel.lambdaParameterReferenceStepsOut().orElse(-1));
66+
67+
lb.lambda(
68+
List.of(R.I64),
69+
inner -> {
70+
FieldReference atNestedLevel = outer.ref(0);
71+
assertEquals(1, atNestedLevel.lambdaParameterReferenceStepsOut().orElse(-1));
72+
return inner.ref(0);
73+
});
74+
75+
return atTopLevel;
76+
});
77+
}
78+
79+
// (x: i32)@p -> p[5] — only 1 param, index 5 is out of bounds
80+
@Test
81+
void invalidFieldIndex_outOfBounds() {
82+
assertThrows(
83+
IndexOutOfBoundsException.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(5)));
84+
}
85+
86+
// (x: i32)@p -> p[-1] — negative index
87+
@Test
88+
void negativeFieldIndex() {
89+
assertThrows(Exception.class, () -> lb.lambda(List.of(R.I32), params -> params.ref(-1)));
90+
}
91+
92+
// (x: i32)@outer -> (y: i64)@inner -> outer[5] — outer only has 1 param
93+
@Test
94+
void nestedOuterFieldIndexOutOfBounds() {
95+
assertThrows(
96+
IndexOutOfBoundsException.class,
97+
() -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> outer.ref(5))));
98+
}
99+
100+
// (x: i32)@outer -> (y: i64)@inner -> inner[3] — inner only has 1 param
101+
@Test
102+
void nestedInnerFieldIndexOutOfBounds() {
103+
assertThrows(
104+
IndexOutOfBoundsException.class,
105+
() -> lb.lambda(List.of(R.I32), outer -> lb.lambda(List.of(R.I64), inner -> inner.ref(3))));
106+
}
107+
}

0 commit comments

Comments
 (0)