Skip to content

Commit 5e713c4

Browse files
maskri17copybara-github
authored andcommitted
Checker and parser changes to support comprehensionsV2
PiperOrigin-RevId: 788649636
1 parent 5b06cc0 commit 5e713c4

11 files changed

Lines changed: 290 additions & 1 deletion

File tree

checker/src/main/java/dev/cel/checker/ExprChecker.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.google.auto.value.AutoValue;
2323
import com.google.common.base.Joiner;
2424
import com.google.common.base.Optional;
25+
import com.google.common.base.Strings;
2526
import com.google.common.collect.ImmutableList;
2627
import com.google.common.collect.Maps;
2728
import com.google.errorprone.annotations.CheckReturnValue;
@@ -510,13 +511,21 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
510511
CelType accuType = env.getType(visitedInit);
511512
CelType rangeType = inferenceContext.specialize(env.getType(visitedRange));
512513
CelType varType;
514+
CelType varType2 = null;
513515
switch (rangeType.kind()) {
514516
case LIST:
515517
varType = ((ListType) rangeType).elemType();
518+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
519+
varType2 = varType;
520+
varType = SimpleType.INT;
521+
}
516522
break;
517523
case MAP:
518524
// Ranges over the keys.
519525
varType = ((MapType) rangeType).keyType();
526+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
527+
varType2 = ((MapType) rangeType).valueType();
528+
}
520529
break;
521530
case DYN:
522531
case ERROR:
@@ -547,6 +556,9 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
547556
// Declare iteration variable on inner scope.
548557
env.enterScope();
549558
env.add(CelIdentDecl.newIdentDeclaration(compre.iterVar(), varType));
559+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
560+
env.add(CelIdentDecl.newIdentDeclaration(compre.iterVar2(), varType2));
561+
}
550562
CelExpr condition = visit(compre.loopCondition());
551563
assertType(condition, SimpleType.BOOL);
552564
CelExpr visitedStep = visit(compre.loopStep());

checker/src/test/java/dev/cel/checker/ExprCheckerTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,27 @@ public void quantifiers() throws Exception {
792792
runTest();
793793
}
794794

795+
@Test
796+
public void twoVarComprehensions() throws Exception {
797+
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
798+
declareVariable("x", messageType);
799+
source =
800+
"x.map_string_string.all(i, v, i < v) "
801+
+ "&& x.repeated_int64.all(i, v, i < v) "
802+
+ "&& [1, 2, 3, 4].all(i, v, i < 5 && v > 0)"
803+
+ "&& {'a': 1, 'b': 2}.all(k, v, k.startsWith('a') && v == 1)";
804+
runTest();
805+
}
806+
807+
@Test
808+
public void twoVarComprehensionsErrors() throws Exception {
809+
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
810+
declareVariable("x", messageType);
811+
source =
812+
"x.map_string_string.all(i + 1, v, i < v) " + "&& x.repeated_int64.all(i, v + 1, i < v)";
813+
runTest();
814+
}
815+
795816
@Test
796817
public void quantifiersErrors() throws Exception {
797818
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
Source: x.map_string_string.all(i, v, i < v) && x.repeated_int64.all(i, v, i < v) && [1, 2, 3, 4].all(i, v, i < 5 && v > 0)&& {'a': 1, 'b': 2}.all(k, v, k.startsWith('a') && v == 1)
2+
declare x {
3+
value cel.expr.conformance.proto3.TestAllTypes
4+
}
5+
=====>
6+
_&&_(
7+
_&&_(
8+
__comprehension__(
9+
// Variable
10+
i,
11+
// Target
12+
x~cel.expr.conformance.proto3.TestAllTypes^x.map_string_string~map(string, string),
13+
// Accumulator
14+
@result,
15+
// Init
16+
true~bool,
17+
// LoopCondition
18+
@not_strictly_false(
19+
@result~bool^@result
20+
)~bool^not_strictly_false,
21+
// LoopStep
22+
_&&_(
23+
@result~bool^@result,
24+
_<_(
25+
i~string^i,
26+
v~string^v
27+
)~bool^less_string
28+
)~bool^logical_and,
29+
// Result
30+
@result~bool^@result)~bool,
31+
__comprehension__(
32+
// Variable
33+
i,
34+
// Target
35+
x~cel.expr.conformance.proto3.TestAllTypes^x.repeated_int64~list(int),
36+
// Accumulator
37+
@result,
38+
// Init
39+
true~bool,
40+
// LoopCondition
41+
@not_strictly_false(
42+
@result~bool^@result
43+
)~bool^not_strictly_false,
44+
// LoopStep
45+
_&&_(
46+
@result~bool^@result,
47+
_<_(
48+
i~int^i,
49+
v~int^v
50+
)~bool^less_int64
51+
)~bool^logical_and,
52+
// Result
53+
@result~bool^@result)~bool
54+
)~bool^logical_and,
55+
_&&_(
56+
__comprehension__(
57+
// Variable
58+
i,
59+
// Target
60+
[
61+
1~int,
62+
2~int,
63+
3~int,
64+
4~int
65+
]~list(int),
66+
// Accumulator
67+
@result,
68+
// Init
69+
true~bool,
70+
// LoopCondition
71+
@not_strictly_false(
72+
@result~bool^@result
73+
)~bool^not_strictly_false,
74+
// LoopStep
75+
_&&_(
76+
@result~bool^@result,
77+
_&&_(
78+
_<_(
79+
i~int^i,
80+
5~int
81+
)~bool^less_int64,
82+
_>_(
83+
v~int^v,
84+
0~int
85+
)~bool^greater_int64
86+
)~bool^logical_and
87+
)~bool^logical_and,
88+
// Result
89+
@result~bool^@result)~bool,
90+
__comprehension__(
91+
// Variable
92+
k,
93+
// Target
94+
{
95+
"a"~string:1~int,
96+
"b"~string:2~int
97+
}~map(string, int),
98+
// Accumulator
99+
@result,
100+
// Init
101+
true~bool,
102+
// LoopCondition
103+
@not_strictly_false(
104+
@result~bool^@result
105+
)~bool^not_strictly_false,
106+
// LoopStep
107+
_&&_(
108+
@result~bool^@result,
109+
_&&_(
110+
k~string^k.startsWith(
111+
"a"~string
112+
)~bool^starts_with_string,
113+
_==_(
114+
v~int^v,
115+
1~int
116+
)~bool^equals
117+
)~bool^logical_and
118+
)~bool^logical_and,
119+
// Result
120+
@result~bool^@result)~bool
121+
)~bool^logical_and
122+
)~bool^logical_and
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Source: x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)
2+
declare x {
3+
value cel.expr.conformance.proto3.TestAllTypes
4+
}
5+
=====>
6+
ERROR: test_location:1:27: The argument must be a simple name
7+
| x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)
8+
| ..........................^
9+
ERROR: test_location:1:71: The argument must be a simple name
10+
| x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)
11+
| ......................................................................^

common/src/main/java/dev/cel/common/ast/CelExprFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ public final CelExpr.CelStruct.Entry newMessageField(String field, CelExpr value
143143
.build();
144144
}
145145

146-
/** Fold creates a fold comprehension instruction. */
146+
/** Fold creates a fold for one variable comprehension instruction. */
147147
public final CelExpr fold(
148148
String iterVar,
149149
CelExpr iterRange,

extensions/src/main/java/dev/cel/extensions/BUILD.bazel

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ java_library(
3030
],
3131
deps = [
3232
":bindings",
33+
":comprehensions",
3334
":encoders",
3435
":lists",
3536
":math",
@@ -289,3 +290,18 @@ java_library(
289290
"@maven//:com_google_re2j_re2j",
290291
],
291292
)
293+
294+
java_library(
295+
name = "comprehensions",
296+
srcs = ["CelComprehensionsExtensions.java"],
297+
visibility = ["//:internal"],
298+
deps = [
299+
"//common:compiler_common",
300+
"//common/ast",
301+
"//compiler:compiler_builder",
302+
"//parser:macro",
303+
"//parser:operator",
304+
"//parser:parser_builder",
305+
"@maven//:com_google_guava_guava",
306+
],
307+
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dev.cel.extensions;
16+
17+
import static com.google.common.base.Preconditions.checkArgument;
18+
import static com.google.common.base.Preconditions.checkNotNull;
19+
20+
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableSet;
22+
import dev.cel.common.CelIssue;
23+
import dev.cel.common.ast.CelExpr;
24+
import dev.cel.compiler.CelCompilerLibrary;
25+
import dev.cel.parser.CelMacro;
26+
import dev.cel.parser.CelMacroExprFactory;
27+
import dev.cel.parser.CelParserBuilder;
28+
import dev.cel.parser.Operator;
29+
import java.util.Optional;
30+
31+
/** Internal implementation of CEL comprehensions extensions. */
32+
public final class CelComprehensionsExtensions implements CelCompilerLibrary {
33+
34+
public ImmutableSet<CelMacro> macros() {
35+
return ImmutableSet.of(
36+
CelMacro.newReceiverMacro(
37+
Operator.ALL.getFunction(), 3, CelComprehensionsExtensions::expandAllMacro));
38+
}
39+
40+
@Override
41+
public void setParserOptions(CelParserBuilder parserBuilder) {
42+
parserBuilder.addMacros(macros());
43+
}
44+
45+
private static Optional<CelExpr> expandAllMacro(
46+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
47+
checkNotNull(exprFactory);
48+
checkNotNull(target);
49+
checkArgument(arguments.size() == 3);
50+
CelExpr arg0 = checkNotNull(arguments.get(0));
51+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
52+
return Optional.of(reportArgumentError(exprFactory, arg0));
53+
}
54+
CelExpr arg1 = checkNotNull(arguments.get(1));
55+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
56+
return Optional.of(reportArgumentError(exprFactory, arg1));
57+
}
58+
CelExpr arg2 = checkNotNull(arguments.get(2));
59+
CelExpr accuInit = exprFactory.newBoolLiteral(true);
60+
CelExpr condition =
61+
exprFactory.newGlobalCall(
62+
Operator.NOT_STRICTLY_FALSE.getFunction(),
63+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
64+
CelExpr step =
65+
exprFactory.newGlobalCall(
66+
Operator.LOGICAL_AND.getFunction(),
67+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
68+
arg2);
69+
CelExpr result = exprFactory.newIdentifier(exprFactory.getAccumulatorVarName());
70+
return Optional.of(
71+
exprFactory.fold(
72+
arg0.ident().name(),
73+
arg1.ident().name(),
74+
target,
75+
exprFactory.getAccumulatorVarName(),
76+
accuInit,
77+
condition,
78+
step,
79+
result));
80+
}
81+
82+
private static CelExpr reportArgumentError(CelMacroExprFactory exprFactory, CelExpr argument) {
83+
return exprFactory.reportError(
84+
CelIssue.formatError(
85+
exprFactory.getSourceLocation(argument), "The argument must be a simple name"));
86+
}
87+
}

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ public final class CelExtensions {
3535
private static final CelBindingsExtensions BINDINGS_EXTENSIONS = new CelBindingsExtensions();
3636
private static final CelEncoderExtensions ENCODER_EXTENSIONS = new CelEncoderExtensions();
3737
private static final CelRegexExtensions REGEX_EXTENSIONS = new CelRegexExtensions();
38+
private static final CelComprehensionsExtensions COMPREHENSIONS_EXTENSIONS =
39+
new CelComprehensionsExtensions();
3840

3941
/**
4042
* Implementation of optional values.
@@ -295,6 +297,18 @@ public static CelRegexExtensions regex() {
295297
return REGEX_EXTENSIONS;
296298
}
297299

300+
/**
301+
* Extended functions for Two Variable Comprehensions Expressions.
302+
*
303+
* <p>Refer to README.md for available functions.
304+
*
305+
* <p>This will include all functions denoted in {@link CelComprehensionsExtensions.Function}, including
306+
* any future additions.
307+
*/
308+
private static CelComprehensionsExtensions comprehensions() {
309+
return COMPREHENSIONS_EXTENSIONS;
310+
}
311+
298312
/**
299313
* Retrieves all function names used by every extension libraries.
300314
*

parser/src/main/java/dev/cel/parser/CelMacroExprFactory.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ public final CelExpr copy(CelExpr expr) {
133133
builder.setComprehension(
134134
CelExpr.CelComprehension.newBuilder()
135135
.setIterVar(expr.comprehension().iterVar())
136+
.setIterVar2(expr.comprehension().iterVar2())
136137
.setIterRange(copy(expr.comprehension().iterRange()))
137138
.setAccuVar(expr.comprehension().accuVar())
138139
.setAccuInit(copy(expr.comprehension().accuInit()))

testing/src/main/java/dev/cel/testing/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ java_library(
5555
"//common/types:type_providers",
5656
"//compiler",
5757
"//compiler:compiler_builder",
58+
"//extensions/src/main/java/dev/cel/extensions:comprehensions",
5859
"//parser:macro",
5960
"@maven//:com_google_guava_guava",
6061
"@maven//:com_google_protobuf_protobuf_java",

0 commit comments

Comments
 (0)