Skip to content

Commit 4cc2cfd

Browse files
maskri17copybara-github
authored andcommitted
Checker and parser changes to support comprehensionsV2
PiperOrigin-RevId: 788649636
1 parent 5b06cc0 commit 4cc2cfd

11 files changed

Lines changed: 283 additions & 1 deletion

File tree

checker/src/main/java/dev/cel/checker/ExprChecker.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.google.auto.value.AutoValue;
2323
import com.google.common.base.Joiner;
2424
import com.google.common.base.Optional;
25+
import com.google.common.base.Strings;
2526
import com.google.common.collect.ImmutableList;
2627
import com.google.common.collect.Maps;
2728
import com.google.errorprone.annotations.CheckReturnValue;
@@ -510,13 +511,21 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
510511
CelType accuType = env.getType(visitedInit);
511512
CelType rangeType = inferenceContext.specialize(env.getType(visitedRange));
512513
CelType varType;
514+
CelType varType2 = null;
513515
switch (rangeType.kind()) {
514516
case LIST:
515517
varType = ((ListType) rangeType).elemType();
518+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
519+
varType2 = varType;
520+
varType = SimpleType.INT;
521+
}
516522
break;
517523
case MAP:
518524
// Ranges over the keys.
519525
varType = ((MapType) rangeType).keyType();
526+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
527+
varType2 = ((MapType) rangeType).valueType();
528+
}
520529
break;
521530
case DYN:
522531
case ERROR:
@@ -547,6 +556,9 @@ private CelExpr visit(CelExpr expr, CelExpr.CelComprehension compre) {
547556
// Declare iteration variable on inner scope.
548557
env.enterScope();
549558
env.add(CelIdentDecl.newIdentDeclaration(compre.iterVar(), varType));
559+
if (!Strings.isNullOrEmpty(compre.iterVar2())) {
560+
env.add(CelIdentDecl.newIdentDeclaration(compre.iterVar2(), varType2));
561+
}
550562
CelExpr condition = visit(compre.loopCondition());
551563
assertType(condition, SimpleType.BOOL);
552564
CelExpr visitedStep = visit(compre.loopStep());

checker/src/test/java/dev/cel/checker/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ java_library(
4141
"//common/types:type_providers",
4242
"//compiler",
4343
"//compiler:compiler_builder",
44+
"//extensions",
4445
"//parser:macro",
4546
"//parser:operator",
4647
"//testing:adorner",

checker/src/test/java/dev/cel/checker/ExprCheckerTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import dev.cel.common.CelMutableAst;
3636
import dev.cel.common.CelOverloadDecl;
3737
import dev.cel.common.CelProtoAbstractSyntaxTree;
38+
import dev.cel.common.CelValidationResult;
3839
import dev.cel.common.CelVarDecl;
3940
import dev.cel.common.ast.CelConstant;
4041
import dev.cel.common.internal.EnvVisitable;
@@ -52,7 +53,10 @@
5253
import dev.cel.common.types.StructTypeReference;
5354
import dev.cel.common.types.TypeParamType;
5455
import dev.cel.common.types.TypeType;
56+
import dev.cel.compiler.CelCompiler;
57+
import dev.cel.compiler.CelCompilerFactory;
5558
import dev.cel.expr.conformance.proto3.TestAllTypes;
59+
import dev.cel.extensions.CelExtensions;
5660
import dev.cel.parser.CelMacro;
5761
import dev.cel.testing.CelAdorner;
5862
import dev.cel.testing.CelBaselineTestCase;
@@ -792,6 +796,47 @@ public void quantifiers() throws Exception {
792796
runTest();
793797
}
794798

799+
@Test
800+
public void twoVarCompreQuantifiers() throws Exception {
801+
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
802+
source = "x.map_string_string.all(i, v, i < v) " + "&& x.repeated_int64.all(i, v, i < v)";
803+
CelCompiler compiler =
804+
CelCompilerFactory.standardCelCompilerBuilder()
805+
.addMessageTypes(TestAllTypes.getDescriptor())
806+
.addLibraries(CelExtensions.comprehensions())
807+
.addVar("x", messageType)
808+
.build();
809+
810+
CelAbstractSyntaxTree ast = compiler.check(compiler.parse(source).getAst()).getAst();
811+
812+
if (ast != null) {
813+
testOutput()
814+
.println(
815+
CelDebug.toAdornedDebugString(
816+
CelProtoAbstractSyntaxTree.fromCelAst(ast).getExpr(),
817+
new CheckedExprAdorner(
818+
CelProtoAbstractSyntaxTree.fromCelAst(ast).toCheckedExpr())));
819+
}
820+
testOutput().println();
821+
}
822+
823+
@Test
824+
public void twoVarCompreQuantifiers_invalidArgument_reportsError() throws Exception {
825+
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
826+
source =
827+
"x.map_string_string.all(i + 1, v, i < v) " + "&& x.repeated_int64.all(i, v + 1, i < v)";
828+
CelCompiler compiler =
829+
CelCompilerFactory.standardCelCompilerBuilder()
830+
.addMessageTypes(TestAllTypes.getDescriptor())
831+
.addLibraries(CelExtensions.comprehensions())
832+
.addVar("x", messageType)
833+
.build();
834+
835+
CelValidationResult parseResult = compiler.parse(source);
836+
testOutput().println(parseResult.getErrorString());
837+
testOutput().println();
838+
}
839+
795840
@Test
796841
public void quantifiersErrors() throws Exception {
797842
CelType messageType = StructTypeReference.create("cel.expr.conformance.proto3.TestAllTypes");
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
_&&_(
2+
__comprehension__(
3+
// Variable
4+
i,
5+
// Target
6+
x~cel.expr.conformance.proto3.TestAllTypes^x.map_string_string~map(string, string),
7+
// Accumulator
8+
@result,
9+
// Init
10+
true~bool,
11+
// LoopCondition
12+
@not_strictly_false(
13+
@result~bool^@result
14+
)~bool^not_strictly_false,
15+
// LoopStep
16+
_&&_(
17+
@result~bool^@result,
18+
_<_(
19+
i~string^i,
20+
v~string^v
21+
)~bool^less_string
22+
)~bool^logical_and,
23+
// Result
24+
@result~bool^@result)~bool,
25+
__comprehension__(
26+
// Variable
27+
i,
28+
// Target
29+
x~cel.expr.conformance.proto3.TestAllTypes^x.repeated_int64~list(int),
30+
// Accumulator
31+
@result,
32+
// Init
33+
true~bool,
34+
// LoopCondition
35+
@not_strictly_false(
36+
@result~bool^@result
37+
)~bool^not_strictly_false,
38+
// LoopStep
39+
_&&_(
40+
@result~bool^@result,
41+
_<_(
42+
i~int^i,
43+
v~int^v
44+
)~bool^less_int64
45+
)~bool^logical_and,
46+
// Result
47+
@result~bool^@result)~bool
48+
)~bool^logical_and
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
ERROR: <input>:1:27: The argument must be a simple name
2+
| x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)
3+
| ..........................^
4+
ERROR: <input>:1:71: The argument must be a simple name
5+
| x.map_string_string.all(i + 1, v, i < v) && x.repeated_int64.all(i, v + 1, i < v)
6+
| ......................................................................^

common/src/main/java/dev/cel/common/ast/CelExprFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ public final CelExpr.CelStruct.Entry newMessageField(String field, CelExpr value
143143
.build();
144144
}
145145

146-
/** Fold creates a fold comprehension instruction. */
146+
/** Fold creates a fold for one variable comprehension instruction. */
147147
public final CelExpr fold(
148148
String iterVar,
149149
CelExpr iterRange,

extensions/src/main/java/dev/cel/extensions/BUILD.bazel

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ java_library(
3030
],
3131
deps = [
3232
":bindings",
33+
":comprehensions",
3334
":encoders",
3435
":lists",
3536
":math",
@@ -289,3 +290,21 @@ java_library(
289290
"@maven//:com_google_re2j_re2j",
290291
],
291292
)
293+
294+
java_library(
295+
name = "comprehensions",
296+
srcs = ["CelComprehensionsExtensions.java"],
297+
tags = [
298+
],
299+
deps = [
300+
"//common:compiler_common",
301+
"//common/ast",
302+
"//compiler:compiler_builder",
303+
"//extensions:extension_library",
304+
"//parser:macro",
305+
"//parser:operator",
306+
"//parser:parser_builder",
307+
"//runtime",
308+
"@maven//:com_google_guava_guava",
309+
],
310+
)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dev.cel.extensions;
16+
17+
import static com.google.common.base.Preconditions.checkArgument;
18+
import static com.google.common.base.Preconditions.checkNotNull;
19+
20+
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableSet;
22+
import dev.cel.common.CelIssue;
23+
import dev.cel.common.ast.CelExpr;
24+
import dev.cel.compiler.CelCompilerLibrary;
25+
import dev.cel.parser.CelMacro;
26+
import dev.cel.parser.CelMacroExprFactory;
27+
import dev.cel.parser.CelParserBuilder;
28+
import dev.cel.parser.Operator;
29+
import java.util.Optional;
30+
31+
/** Internal implementation of CEL comprehensions extensions. */
32+
final class CelComprehensionsExtensions implements CelCompilerLibrary {
33+
// @Override
34+
public ImmutableSet<CelMacro> macros() {
35+
return ImmutableSet.of(
36+
CelMacro.newReceiverMacro(
37+
Operator.ALL.getFunction(), 3, CelComprehensionsExtensions::expandAllMacro));
38+
}
39+
40+
@Override
41+
public void setParserOptions(CelParserBuilder parserBuilder) {
42+
parserBuilder.addMacros(macros());
43+
}
44+
45+
private static Optional<CelExpr> expandAllMacro(
46+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
47+
checkNotNull(exprFactory);
48+
checkNotNull(target);
49+
checkArgument(arguments.size() == 3);
50+
CelExpr arg0 = checkNotNull(arguments.get(0));
51+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
52+
return Optional.of(reportArgumentError(exprFactory, arg0));
53+
}
54+
CelExpr arg1 = checkNotNull(arguments.get(1));
55+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
56+
return Optional.of(reportArgumentError(exprFactory, arg1));
57+
}
58+
CelExpr arg2 = checkNotNull(arguments.get(2));
59+
CelExpr accuInit = exprFactory.newBoolLiteral(true);
60+
CelExpr condition =
61+
exprFactory.newGlobalCall(
62+
Operator.NOT_STRICTLY_FALSE.getFunction(),
63+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
64+
CelExpr step =
65+
exprFactory.newGlobalCall(
66+
Operator.LOGICAL_AND.getFunction(),
67+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
68+
arg2);
69+
CelExpr result = exprFactory.newIdentifier(exprFactory.getAccumulatorVarName());
70+
return Optional.of(
71+
exprFactory.fold(
72+
arg0.ident().name(),
73+
arg1.ident().name(),
74+
target,
75+
exprFactory.getAccumulatorVarName(),
76+
accuInit,
77+
condition,
78+
step,
79+
result));
80+
}
81+
82+
private static CelExpr reportArgumentError(CelMacroExprFactory exprFactory, CelExpr argument) {
83+
return exprFactory.reportError(
84+
CelIssue.formatError(
85+
exprFactory.getSourceLocation(argument), "The argument must be a simple name"));
86+
}
87+
}

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ public final class CelExtensions {
3535
private static final CelBindingsExtensions BINDINGS_EXTENSIONS = new CelBindingsExtensions();
3636
private static final CelEncoderExtensions ENCODER_EXTENSIONS = new CelEncoderExtensions();
3737
private static final CelRegexExtensions REGEX_EXTENSIONS = new CelRegexExtensions();
38+
private static final CelComprehensionsExtensions COMPREHENSIONS_EXTENSIONS =
39+
new CelComprehensionsExtensions();
3840

3941
/**
4042
* Implementation of optional values.
@@ -295,6 +297,18 @@ public static CelRegexExtensions regex() {
295297
return REGEX_EXTENSIONS;
296298
}
297299

300+
/**
301+
* Extended functions for Regular Expressions.
302+
*
303+
* <p>Refer to README.md for available functions.
304+
*
305+
* <p>This will include all functions denoted in {@link CelRegexExtensions.Function}, including
306+
* any future additions.
307+
*/
308+
public static CelComprehensionsExtensions comprehensions() {
309+
return COMPREHENSIONS_EXTENSIONS;
310+
}
311+
298312
/**
299313
* Retrieves all function names used by every extension libraries.
300314
*
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package dev.cel.extensions;
16+
17+
import static com.google.common.truth.Truth.assertThat;
18+
19+
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
20+
import dev.cel.common.CelAbstractSyntaxTree;
21+
import dev.cel.compiler.CelCompiler;
22+
import dev.cel.compiler.CelCompilerFactory;
23+
import dev.cel.runtime.CelRuntime;
24+
import dev.cel.runtime.CelRuntimeFactory;
25+
import org.junit.Test;
26+
import org.junit.runner.RunWith;
27+
28+
/** Test for {@link CelExtensions#comprehensions()} */
29+
@RunWith(TestParameterInjector.class)
30+
public class CelComprehensionsExtensionsTest {
31+
private static final CelCompiler CEL_COMPILER =
32+
CelCompilerFactory.standardCelCompilerBuilder()
33+
.addLibraries(CelExtensions.comprehensions())
34+
.build();
35+
private static final CelRuntime CEL_RUNTIME =
36+
CelRuntimeFactory.standardCelRuntimeBuilder()
37+
// .addLibraries(CelExtensions.comprehensions())
38+
.build();
39+
40+
@Test
41+
public void all_comprehension_success() throws Exception {
42+
CelAbstractSyntaxTree ast =
43+
CEL_COMPILER.compile("[4,5,6].all(x, s, x > 3)").getAst();
44+
45+
Object result = CEL_RUNTIME.createProgram(ast).eval();
46+
47+
assertThat(result).isEqualTo(true);
48+
}
49+
}

0 commit comments

Comments
 (0)