Skip to content

Commit e8d31e6

Browse files
maskri17copybara-github
authored andcommitted
Adding runtime support for two variable comprehensions
PiperOrigin-RevId: 797506843
1 parent 1ee7d1d commit e8d31e6

7 files changed

Lines changed: 292 additions & 23 deletions

File tree

checker/src/main/java/dev/cel/checker/ExprChecker.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
530530
case DYN:
531531
case ERROR:
532532
varType = SimpleType.DYN;
533+
varType2 = SimpleType.DYN;
533534
break;
534535
case TYPE_PARAM:
535536
// Mark the range as DYN to avoid its free variable being associated with the wrong type
@@ -538,6 +539,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
538539
inferenceContext.isAssignable(SimpleType.DYN, rangeType);
539540
// Mark the variable type as DYN.
540541
varType = SimpleType.DYN;
542+
varType2 = SimpleType.DYN;
541543
break;
542544
default:
543545
env.reportError(
@@ -547,6 +549,7 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
547549
+ "(must be list, map, or dynamic)",
548550
CelTypes.format(rangeType));
549551
varType = SimpleType.DYN;
552+
varType2 = SimpleType.DYN;
550553
break;
551554
}
552555

extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ private static Optional<CelExpr> expandAllMacro(
5959
checkNotNull(exprFactory);
6060
checkNotNull(target);
6161
checkArgument(arguments.size() == 3);
62-
CelExpr arg0 = checkNotNull(arguments.get(0));
62+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
6363
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
64-
return Optional.of(reportArgumentError(exprFactory, arg0));
64+
return Optional.of(arg0);
6565
}
66-
CelExpr arg1 = checkNotNull(arguments.get(1));
66+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
6767
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
68-
return Optional.of(reportArgumentError(exprFactory, arg1));
68+
return Optional.of(arg1);
6969
}
7070
CelExpr arg2 = checkNotNull(arguments.get(2));
7171
CelExpr accuInit = exprFactory.newBoolLiteral(true);
@@ -96,13 +96,13 @@ private static Optional<CelExpr> expandExistsMacro(
9696
checkNotNull(exprFactory);
9797
checkNotNull(target);
9898
checkArgument(arguments.size() == 3);
99-
CelExpr arg0 = checkNotNull(arguments.get(0));
99+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
100100
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
101-
return Optional.of(reportArgumentError(exprFactory, arg0));
101+
return Optional.of(arg0);
102102
}
103-
CelExpr arg1 = checkNotNull(arguments.get(1));
103+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
104104
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
105-
return Optional.of(reportArgumentError(exprFactory, arg1));
105+
return Optional.of(arg1);
106106
}
107107
CelExpr arg2 = checkNotNull(arguments.get(2));
108108
CelExpr accuInit = exprFactory.newBoolLiteral(false);
@@ -135,13 +135,13 @@ private static Optional<CelExpr> expandExistsOneMacro(
135135
checkNotNull(exprFactory);
136136
checkNotNull(target);
137137
checkArgument(arguments.size() == 3);
138-
CelExpr arg0 = checkNotNull(arguments.get(0));
138+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
139139
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
140-
return Optional.of(reportArgumentError(exprFactory, arg0));
140+
return Optional.of(arg0);
141141
}
142-
CelExpr arg1 = checkNotNull(arguments.get(1));
142+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
143143
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
144-
return Optional.of(reportArgumentError(exprFactory, arg1));
144+
return Optional.of(arg1);
145145
}
146146
CelExpr arg2 = checkNotNull(arguments.get(2));
147147
CelExpr accuInit = exprFactory.newIntLiteral(0);
@@ -177,13 +177,13 @@ private static Optional<CelExpr> transformListMacro(
177177
checkNotNull(exprFactory);
178178
checkNotNull(target);
179179
checkArgument(arguments.size() == 3 || arguments.size() == 4);
180-
CelExpr arg0 = checkNotNull(arguments.get(0));
180+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
181181
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
182-
return Optional.of(reportArgumentError(exprFactory, arg0));
182+
return Optional.of(arg0);
183183
}
184-
CelExpr arg1 = checkNotNull(arguments.get(1));
184+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
185185
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
186-
return Optional.of(reportArgumentError(exprFactory, arg1));
186+
return Optional.of(arg1);
187187
}
188188
CelExpr transform;
189189
CelExpr filter = null;
@@ -220,9 +220,30 @@ private static Optional<CelExpr> transformListMacro(
220220
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
221221
}
222222

223+
private static CelExpr validatedIterationVariable(
224+
CelMacroExprFactory exprFactory, CelExpr argument) {
225+
226+
CelExpr arg = checkNotNull(argument);
227+
if (arg.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
228+
return reportArgumentError(exprFactory, arg);
229+
} else if (arg.exprKind().ident().name().equals("__result__")) {
230+
return reportAccumulatorOverwriteError(exprFactory, arg);
231+
} else {
232+
return arg;
233+
}
234+
}
235+
223236
private static CelExpr reportArgumentError(CelMacroExprFactory exprFactory, CelExpr argument) {
224237
return exprFactory.reportError(
225238
CelIssue.formatError(
226239
exprFactory.getSourceLocation(argument), "The argument must be a simple name"));
227240
}
241+
242+
private static CelExpr reportAccumulatorOverwriteError(
243+
CelMacroExprFactory exprFactory, CelExpr argument) {
244+
return exprFactory.reportError(
245+
CelIssue.formatError(
246+
exprFactory.getSourceLocation(argument),
247+
"The iteration variable overwrites accumulator variable"));
248+
}
228249
}

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,9 +315,7 @@ public static CelRegexExtensions regex() {
315315
* <p>This will include all functions denoted in {@link CelComprehensionsExtensions.Function},
316316
* including any future additions.
317317
*/
318-
// TODO: Remove visibility restrictions and make this public once the feature is
319-
// ready.
320-
private static CelComprehensionsExtensions comprehensions() {
318+
public static CelComprehensionsExtensions comprehensions() {
321319
return COMPREHENSIONS_EXTENSIONS;
322320
}
323321

extensions/src/test/java/dev/cel/extensions/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ java_library(
3030
"//extensions:sets_function",
3131
"//extensions:strings",
3232
"//parser:macro",
33+
"//parser:unparser",
3334
"//runtime",
3435
"//runtime:function_binding",
3536
"//runtime:interpreter_util",
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dev.cel.extensions;
16+
17+
import static com.google.common.truth.Truth.assertThat;
18+
import static org.junit.Assert.assertThrows;
19+
20+
import com.google.testing.junit.testparameterinjector.TestParameter;
21+
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
22+
import com.google.testing.junit.testparameterinjector.TestParameters;
23+
import dev.cel.common.CelAbstractSyntaxTree;
24+
import dev.cel.common.CelOptions;
25+
import dev.cel.common.CelValidationException;
26+
import dev.cel.compiler.CelCompiler;
27+
import dev.cel.compiler.CelCompilerFactory;
28+
import dev.cel.parser.CelStandardMacro;
29+
import dev.cel.parser.CelUnparser;
30+
import dev.cel.parser.CelUnparserFactory;
31+
import dev.cel.runtime.CelEvaluationException;
32+
import dev.cel.runtime.CelRuntime;
33+
import dev.cel.runtime.CelRuntimeFactory;
34+
import org.junit.Test;
35+
import org.junit.runner.RunWith;
36+
37+
/** Test for {@link CelExtensions#comprehensions()} */
38+
@RunWith(TestParameterInjector.class)
39+
public class CelComprehensionsExtensionsTest {
40+
41+
private static final CelOptions CEL_OPTIONS =
42+
CelOptions.current()
43+
// Enable macro call population for unparsing
44+
.populateMacroCalls(true)
45+
.build();
46+
47+
private static final CelCompiler CEL_COMPILER =
48+
CelCompilerFactory.standardCelCompilerBuilder()
49+
.setOptions(CEL_OPTIONS)
50+
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
51+
.addLibraries(CelExtensions.comprehensions())
52+
.addLibraries(CelExtensions.lists())
53+
.addLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings())
54+
.build();
55+
private static final CelRuntime CEL_RUNTIME =
56+
CelRuntimeFactory.standardCelRuntimeBuilder()
57+
.addLibraries(CelOptionalLibrary.INSTANCE)
58+
.addLibraries(CelExtensions.lists())
59+
.build();
60+
61+
private static final CelUnparser UNPARSER = CelUnparserFactory.newUnparser();
62+
63+
@Test
64+
public void allMacro_twoVarComprehension_success(
65+
@TestParameter({
66+
// list.all()
67+
"[1, 2, 3, 4].all(i, v, i < 5 && v > 0)",
68+
"[1, 2, 3, 4].all(i, v, i < v)",
69+
"[1, 2, 3, 4].all(i, v, i > v) == false",
70+
"cel.bind(listA, [1, 2, 3, 4], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v,"
71+
+ " listB[?i].hasValue() && listB[i] == v)))",
72+
"cel.bind(listA, [1, 2, 3, 4, 5, 6], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v,"
73+
+ " listB[?i].hasValue() && listB[i] == v))) == false",
74+
// map.all()
75+
"{'hello': 'world', 'hello!': 'world'}.all(k, v, k.startsWith('hello') && v =="
76+
+ " 'world')",
77+
"{'hello': 'world', 'hello!': 'worlds'}.all(k, v, k.startsWith('hello') &&"
78+
+ " v.endsWith('world')) == false",
79+
"{'a': 1, 'b': 2}.all(k, v, k.startsWith('a') && v == 1) == false",
80+
})
81+
String expr)
82+
throws Exception {
83+
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
84+
85+
Object result = CEL_RUNTIME.createProgram(ast).eval();
86+
87+
assertThat(result).isEqualTo(true);
88+
}
89+
90+
@Test
91+
public void existsMacro_twoVarComprehension_success(
92+
@TestParameter({
93+
// list.exists()
94+
"[1, 2, 3, 4].exists(i, v, i > 2 && v < 5)",
95+
"[10, 1, 30].exists(i, v, i == v)",
96+
"[].exists(i, v, true) == false",
97+
"cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.exists(i, v,"
98+
+ " v.startsWith('hello') && l[?(i+1)].optMap(next,"
99+
+ " next.endsWith('world')).orValue(false)))",
100+
// map.exists()
101+
"{'hello': 'world', 'hello!': 'worlds'}.exists(k, v, k.startsWith('hello') &&"
102+
+ " v.endsWith('world'))",
103+
"{}.exists(k, v, true) == false",
104+
"{'a': 1, 'b': 2}.exists(k, v, v == 3) == false",
105+
"{'a': 'b', 'c': 'c'}.exists(k, v, k == v)"
106+
})
107+
String expr)
108+
throws Exception {
109+
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
110+
111+
Object result = CEL_RUNTIME.createProgram(ast).eval();
112+
113+
assertThat(result).isEqualTo(true);
114+
}
115+
116+
@Test
117+
public void exists_oneMacro_twoVarComprehension_success(
118+
@TestParameter({
119+
// list.exists_one()
120+
"[0, 5, 6].exists_one(i, v, i == v)",
121+
"[0, 1, 5].exists_one(i, v, i == v) == false",
122+
"[10, 11, 12].exists_one(i, v, i == v) == false",
123+
"cel.bind(l, ['hello', 'world', 'hello!', 'worlds'], l.exists_one(i, v,"
124+
+ " v.startsWith('hello') && l[?(i+1)].optMap(next,"
125+
+ " next.endsWith('world')).orValue(false)))",
126+
"cel.bind(l, ['hello', 'goodbye', 'hello!', 'goodbye'], l.exists_one(i, v,"
127+
+ " v.startsWith('hello') && l[?(i+1)].optMap(next, next =="
128+
+ " 'goodbye').orValue(false))) == false",
129+
// map.exists_one()
130+
"{'hello': 'world', 'hello!': 'worlds'}.exists_one(k, v, k.startsWith('hello') &&"
131+
+ " v.endsWith('world'))",
132+
"{'hello': 'world', 'hello!': 'wow, world'}.exists_one(k, v, k.startsWith('hello') &&"
133+
+ " v.endsWith('world')) == false",
134+
"{'a': 1, 'b': 1}.exists_one(k, v, v == 2) == false"
135+
})
136+
String expr)
137+
throws Exception {
138+
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
139+
140+
Object result = CEL_RUNTIME.createProgram(ast).eval();
141+
142+
assertThat(result).isEqualTo(true);
143+
}
144+
145+
@Test
146+
public void transformListMacro_twoVarComprehension_success(
147+
@TestParameter({
148+
// list.transformList()
149+
"[1, 2, 3].transformList(i, v, (i * v) + v) == [1, 4, 9]",
150+
"[1, 2, 3].transformList(i, v, i % 2 == 0, (i * v) + v) == [1, 9]",
151+
"[1, 2, 3].transformList(i, v, i > 0 && v < 3, (i * v) + v) == [4]",
152+
"[1, 2, 3].transformList(i, v, i % 2 == 0, (i * v) + v) == [1, 9]",
153+
"[1, 2, 3].transformList(i, v, (i * v) + v) == [1, 4, 9]",
154+
// map.transformList()
155+
"{'greeting': 'hello', 'farewell': 'goodbye'}.transformList(k, _, k).sort() =="
156+
+ " ['farewell', 'greeting']",
157+
"{'greeting': 'hello', 'farewell': 'goodbye'}.transformList(_, v, v).sort() =="
158+
+ " ['goodbye', 'hello']"
159+
})
160+
String expr)
161+
throws Exception {
162+
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
163+
164+
Object result = CEL_RUNTIME.createProgram(ast).eval();
165+
166+
assertThat(result).isEqualTo(true);
167+
}
168+
169+
@Test
170+
public void unparseAST_twoVarComprehension(
171+
@TestParameter({
172+
"cel.bind(listA, [1, 2, 3, 4], cel.bind(listB, [1, 2, 3, 4, 5], listA.all(i, v,"
173+
+ " listB[?i].hasValue() && listB[i] == v)))",
174+
"[1, 2, 3, 4].exists(i, v, i > 2 && v < 5)",
175+
"{\"a\": 1, \"b\": 1}.exists_one(k, v, v == 2) == false",
176+
"[1, 2, 3].transformList(i, v, i > 0 && v < 3, i * v + v) == [4]",
177+
"[1, 2, 2].transformList(i, v, i / 2 == 1)",
178+
"{\"a\": \"b\", \"c\": \"d\"}.exists_one(k, v, k == \"b\" || v == \"b\")",
179+
"{\"a\": \"b\", \"c\": \"d\"}.exists(k, v, k == \"b\" || v == \"b\")",
180+
"[null, null, \"hello\", string].all(i, v, i == 0 || type(v) != int)"
181+
})
182+
String expr)
183+
throws Exception {
184+
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
185+
String unparsed = UNPARSER.unparse(ast);
186+
assertThat(unparsed).isEqualTo(expr);
187+
}
188+
189+
@Test
190+
@TestParameters("{expr: '[].all(i.j, k, i.j < k)', err: 'The argument must be a simple name'}")
191+
@TestParameters("{expr: '1.exists_one(j, j < 5)', err: 'cannot be range'}")
192+
@TestParameters("{expr: '1.exists_one(j, k, j < k)', err: 'cannot be range'}")
193+
@TestParameters(
194+
"{expr: '[].transformList(__result__, i, __result__ < i)', err: 'The iteration variable"
195+
+ " overwrites accumulator variable'}")
196+
@TestParameters(
197+
"{expr: '[].exists(j, __result__, __result__ < j)', err: 'The iteration variable"
198+
+ " overwrites accumulator variable'}")
199+
public void twoVarComprehension_compilerErrors(String expr, String err) throws Exception {
200+
CelValidationException e =
201+
assertThrows(CelValidationException.class, () -> CEL_COMPILER.compile(expr).getAst());
202+
203+
assertThat(e).hasMessageThat().contains(err);
204+
}
205+
206+
@Test
207+
public void twoVarComprehension_arithematicRuntimeError() throws Exception {
208+
CelAbstractSyntaxTree ast = CEL_COMPILER.compile("[0].all(i, k, i/k < k)").getAst();
209+
210+
CelEvaluationException e =
211+
assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval());
212+
213+
assertThat(e).hasCauseThat().isInstanceOf(ArithmeticException.class);
214+
assertThat(e).hasCauseThat().hasMessageThat().contains("/ by zero");
215+
}
216+
217+
@Test
218+
public void twoVarComprehension_outOfBoundsRuntimeError() throws Exception {
219+
CelAbstractSyntaxTree ast = CEL_COMPILER.compile("[1, 2].exists(i, v, [0][v] > 0)").getAst();
220+
221+
CelEvaluationException e =
222+
assertThrows(CelEvaluationException.class, () -> CEL_RUNTIME.createProgram(ast).eval());
223+
224+
assertThat(e).hasCauseThat().isInstanceOf(IndexOutOfBoundsException.class);
225+
assertThat(e).hasCauseThat().hasMessageThat().contains("Index out of bounds: 1");
226+
}
227+
}

runtime/src/main/java/dev/cel/runtime/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ java_library(
309309
"@maven//:com_google_errorprone_error_prone_annotations",
310310
"@maven//:com_google_guava_guava",
311311
"@maven//:com_google_protobuf_protobuf_java",
312-
"@maven//:org_jspecify_jspecify",
312+
"@maven_android//:com_google_protobuf_protobuf_javalite",
313313
],
314314
)
315315

0 commit comments

Comments
 (0)