Skip to content

Commit be5aa9b

Browse files
l46kok authored and copybara-github committed
Change comprehension variable mangling logic to always generate a unique index per comprehension expr
PiperOrigin-RevId: 629561823
1 parent 980c9b5 commit be5aa9b

22 files changed

Lines changed: 7328 additions & 2444 deletions

optimizer/src/main/java/dev/cel/optimizer/AstMutator.java

Lines changed: 8 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,7 @@
2020

2121
import com.google.auto.value.AutoValue;
2222
import com.google.common.base.Preconditions;
23-
import com.google.common.collect.HashBasedTable;
2423
import com.google.common.collect.ImmutableMap;
25-
import com.google.common.collect.Table;
2624
import com.google.errorprone.annotations.Immutable;
2725
import dev.cel.common.CelAbstractSyntaxTree;
2826
import dev.cel.common.CelMutableAst;
@@ -261,43 +259,22 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
261259
"Unexpected CelNavigableMutableExpr collision");
262260
},
263261
LinkedHashMap::new));
264-
int iterCount = 0;
265262

266263
// The map that we'll eventually return to the caller.
267264
HashMap<MangledComprehensionName, MangledComprehensionType> mangledIdentNamesToType =
268265
new HashMap<>();
269-
// Intermediary table used for the purposes of generating a unique mangled variable name.
270-
Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType =
271-
HashBasedTable.create();
272266
CelMutableExpr mutatedComprehensionExpr = navigableMutableAst.getAst().expr();
273267
CelMutableSource newSource = navigableMutableAst.getAst().source();
268+
int iterCount = 0;
274269
for (Entry<CelNavigableMutableExpr, MangledComprehensionType> comprehensionEntry :
275270
comprehensionsToMangle.entrySet()) {
276-
iterCount++;
277-
CelNavigableMutableExpr comprehensionNode = comprehensionEntry.getKey();
278-
MangledComprehensionType comprehensionEntryType = comprehensionEntry.getValue();
279-
280-
CelMutableExpr comprehensionExpr = comprehensionNode.expr();
281-
int comprehensionNestingLevel = countComprehensionNestingLevel(comprehensionNode);
282-
MangledComprehensionName mangledComprehensionName;
283-
if (comprehensionLevelToType.contains(comprehensionNestingLevel, comprehensionEntryType)) {
284-
mangledComprehensionName =
285-
comprehensionLevelToType.get(comprehensionNestingLevel, comprehensionEntryType);
286-
} else {
287-
// First time encountering the pair of <ComprehensionLevel, CelType>. Generate a unique
288-
// mangled variable name for this.
289-
int uniqueTypeIdx = comprehensionLevelToType.row(comprehensionNestingLevel).size();
290-
String mangledIterVarName =
291-
newIterVarPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx;
292-
String mangledResultName =
293-
newResultPrefix + comprehensionNestingLevel + ":" + uniqueTypeIdx;
294-
mangledComprehensionName =
295-
MangledComprehensionName.of(mangledIterVarName, mangledResultName);
296-
comprehensionLevelToType.put(
297-
comprehensionNestingLevel, comprehensionEntryType, mangledComprehensionName);
298-
}
299-
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntryType);
271+
String mangledIterVarName = newIterVarPrefix + ":" + iterCount;
272+
String mangledResultName = newResultPrefix + ":" + iterCount;
273+
MangledComprehensionName mangledComprehensionName =
274+
MangledComprehensionName.of(mangledIterVarName, mangledResultName);
275+
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntry.getValue());
300276

277+
CelMutableExpr comprehensionExpr = comprehensionEntry.getKey().expr();
301278
String iterVar = comprehensionExpr.comprehension().iterVar();
302279
String accuVar = comprehensionExpr.comprehension().accuVar();
303280
mutatedComprehensionExpr =
@@ -315,6 +292,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
315292
iterVar,
316293
mangledComprehensionName,
317294
comprehensionExpr.id());
295+
iterCount++;
318296
}
319297

320298
if (iterCount >= iterationLimit) {
@@ -822,19 +800,6 @@ private static long getMaxId(CelMutableExpr mutableExpr) {
822800
.orElseThrow(NoSuchElementException::new);
823801
}
824802

825-
private static int countComprehensionNestingLevel(CelNavigableMutableExpr comprehensionExpr) {
826-
int nestedLevel = 0;
827-
Optional<CelNavigableMutableExpr> maybeParent = comprehensionExpr.parent();
828-
while (maybeParent.isPresent()) {
829-
if (maybeParent.get().getKind().equals(Kind.COMPREHENSION)) {
830-
nestedLevel++;
831-
}
832-
833-
maybeParent = maybeParent.get().parent();
834-
}
835-
return nestedLevel;
836-
}
837-
838803
/**
839804
* Intermediate value class to store the mangled identifiers for iteration variable and the
840805
* comprehension result.

optimizer/src/test/java/dev/cel/optimizer/AstMutatorTest.java

Lines changed: 25 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -706,7 +706,7 @@ public void list_replaceElement() throws Exception {
706706
}
707707

708708
@Test
709-
public void createStruct_replaceValue() throws Exception {
709+
public void struct_replaceValue() throws Exception {
710710
// Tree shape (brackets are expr IDs):
711711
// TestAllTypes [1]
712712
// single_int64 [2]
@@ -722,7 +722,7 @@ public void createStruct_replaceValue() throws Exception {
722722
}
723723

724724
@Test
725-
public void createMap_replaceKey() throws Exception {
725+
public void map_replaceKey() throws Exception {
726726
// Tree shape (brackets are expr IDs):
727727
// map [1]
728728
// map_entry [2]
@@ -737,7 +737,7 @@ public void createMap_replaceKey() throws Exception {
737737
}
738738

739739
@Test
740-
public void createMap_replaceValue() throws Exception {
740+
public void map_replaceValue() throws Exception {
741741
// Tree shape (brackets are expr IDs):
742742
// map [1]
743743
// map_entry [2]
@@ -808,15 +808,15 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
808808
assertThat(mangledAst.getExpr().toString())
809809
.isEqualTo(
810810
"COMPREHENSION [13] {\n"
811-
+ " iter_var: @c0:0\n"
811+
+ " iter_var: @c:0\n"
812812
+ " iter_range: {\n"
813813
+ " LIST [1] {\n"
814814
+ " elements: {\n"
815815
+ " CONSTANT [2] { value: false }\n"
816816
+ " }\n"
817817
+ " }\n"
818818
+ " }\n"
819-
+ " accu_var: @x0:0\n"
819+
+ " accu_var: @x:0\n"
820820
+ " accu_init: {\n"
821821
+ " CONSTANT [6] { value: false }\n"
822822
+ " }\n"
@@ -828,7 +828,7 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
828828
+ " function: !_\n"
829829
+ " args: {\n"
830830
+ " IDENT [7] {\n"
831-
+ " name: @x0:0\n"
831+
+ " name: @x:0\n"
832832
+ " }\n"
833833
+ " }\n"
834834
+ " }\n"
@@ -840,21 +840,22 @@ public void mangleComprehensionVariable_singleMacro() throws Exception {
840840
+ " function: _||_\n"
841841
+ " args: {\n"
842842
+ " IDENT [10] {\n"
843-
+ " name: @x0:0\n"
843+
+ " name: @x:0\n"
844844
+ " }\n"
845845
+ " IDENT [5] {\n"
846-
+ " name: @c0:0\n"
846+
+ " name: @c:0\n"
847847
+ " }\n"
848848
+ " }\n"
849849
+ " }\n"
850850
+ " }\n"
851851
+ " result: {\n"
852852
+ " IDENT [12] {\n"
853-
+ " name: @x0:0\n"
853+
+ " name: @x:0\n"
854854
+ " }\n"
855855
+ " }\n"
856856
+ "}");
857-
assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("[false].exists(@c0:0, @c0:0)");
857+
858+
assertThat(CEL_UNPARSER.unparse(mangledAst)).isEqualTo("[false].exists(@c:0, @c:0)");
858859
assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval()).isEqualTo(false);
859860
assertConsistentMacroCalls(ast);
860861
}
@@ -891,7 +892,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
891892
assertThat(mangledAst.getExpr().toString())
892893
.isEqualTo(
893894
"COMPREHENSION [27] {\n"
894-
+ " iter_var: @c0:0\n"
895+
+ " iter_var: @c:1\n"
895896
+ " iter_range: {\n"
896897
+ " LIST [1] {\n"
897898
+ " elements: {\n"
@@ -901,7 +902,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
901902
+ " }\n"
902903
+ " }\n"
903904
+ " }\n"
904-
+ " accu_var: @x0:0\n"
905+
+ " accu_var: @x:1\n"
905906
+ " accu_init: {\n"
906907
+ " CONSTANT [20] { value: false }\n"
907908
+ " }\n"
@@ -913,7 +914,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
913914
+ " function: !_\n"
914915
+ " args: {\n"
915916
+ " IDENT [21] {\n"
916-
+ " name: @x0:0\n"
917+
+ " name: @x:1\n"
917918
+ " }\n"
918919
+ " }\n"
919920
+ " }\n"
@@ -925,20 +926,20 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
925926
+ " function: _||_\n"
926927
+ " args: {\n"
927928
+ " IDENT [24] {\n"
928-
+ " name: @x0:0\n"
929+
+ " name: @x:1\n"
929930
+ " }\n"
930931
+ " COMPREHENSION [19] {\n"
931-
+ " iter_var: @c1:0\n"
932+
+ " iter_var: @c:0\n"
932933
+ " iter_range: {\n"
933934
+ " LIST [5] {\n"
934935
+ " elements: {\n"
935936
+ " IDENT [6] {\n"
936-
+ " name: @c0:0\n"
937+
+ " name: @c:1\n"
937938
+ " }\n"
938939
+ " }\n"
939940
+ " }\n"
940941
+ " }\n"
941-
+ " accu_var: @x1:0\n"
942+
+ " accu_var: @x:0\n"
942943
+ " accu_init: {\n"
943944
+ " CONSTANT [12] { value: false }\n"
944945
+ " }\n"
@@ -950,7 +951,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
950951
+ " function: !_\n"
951952
+ " args: {\n"
952953
+ " IDENT [13] {\n"
953-
+ " name: @x1:0\n"
954+
+ " name: @x:0\n"
954955
+ " }\n"
955956
+ " }\n"
956957
+ " }\n"
@@ -962,13 +963,13 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
962963
+ " function: _||_\n"
963964
+ " args: {\n"
964965
+ " IDENT [16] {\n"
965-
+ " name: @x1:0\n"
966+
+ " name: @x:0\n"
966967
+ " }\n"
967968
+ " CALL [10] {\n"
968969
+ " function: _==_\n"
969970
+ " args: {\n"
970971
+ " IDENT [9] {\n"
971-
+ " name: @c1:0\n"
972+
+ " name: @c:0\n"
972973
+ " }\n"
973974
+ " CONSTANT [11] { value: 1 }\n"
974975
+ " }\n"
@@ -978,7 +979,7 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
978979
+ " }\n"
979980
+ " result: {\n"
980981
+ " IDENT [18] {\n"
981-
+ " name: @x1:0\n"
982+
+ " name: @x:0\n"
982983
+ " }\n"
983984
+ " }\n"
984985
+ " }\n"
@@ -987,12 +988,13 @@ public void mangleComprehensionVariable_nestedMacroWithShadowedVariables() throw
987988
+ " }\n"
988989
+ " result: {\n"
989990
+ " IDENT [26] {\n"
990-
+ " name: @x0:0\n"
991+
+ " name: @x:1\n"
991992
+ " }\n"
992993
+ " }\n"
993994
+ "}");
995+
994996
assertThat(CEL_UNPARSER.unparse(mangledAst))
995-
.isEqualTo("[x].exists(@c0:0, [@c0:0].exists(@c1:0, @c1:0 == 1))");
997+
.isEqualTo("[x].exists(@c:1, [@c:1].exists(@c:0, @c:0 == 1))");
996998
assertThat(CEL.createProgram(CEL.check(mangledAst).getAst()).eval(ImmutableMap.of("x", 1)))
997999
.isEqualTo(true);
9981000
assertConsistentMacroCalls(ast);

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -456,6 +456,9 @@ private enum CseTestCase {
456456
MULTIPLE_MACROS_2(
457457
"[[1].exists(i, i > 0)] + [[1].exists(j, j > 0)] + [['a'].exists(k, k == 'a')] +"
458458
+ " [['a'].exists(l, l == 'a')] == [true, true, true, true]"),
459+
MULTIPLE_MACROS_3(
460+
"[1].exists(i, i > 0) && [1].exists(j, j > 0) && [1].exists(k, k > 1) && [2].exists(l, l >"
461+
+ " 1)"),
459462
NESTED_MACROS("[1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]"),
460463
NESTED_MACROS_2("[1, 2].map(y, [1, 2, 3].filter(x, x == y)) == [[1], [2]]"),
461464
INCLUSION_LIST("1 in [1,2,3] && 2 in [1,2,3] && 3 in [3, [1,2,3]] && 1 in [1,2,3]"),

optimizer/src/test/resources/constfold_before_subexpression_unparsed.baseline

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -302,6 +302,22 @@ Result: true
302302
[BLOCK_RECURSION_DEPTH_8]: true
303303
[BLOCK_RECURSION_DEPTH_9]: true
304304

305+
Test case: MULTIPLE_MACROS_3
306+
Source: [1].exists(i, i > 0) && [1].exists(j, j > 0) && [1].exists(k, k > 1) && [2].exists(l, l > 1)
307+
=====>
308+
Result: false
309+
[CASCADED_BINDS]: false
310+
[BLOCK_COMMON_SUBEXPR_ONLY]: false
311+
[BLOCK_RECURSION_DEPTH_1]: false
312+
[BLOCK_RECURSION_DEPTH_2]: false
313+
[BLOCK_RECURSION_DEPTH_3]: false
314+
[BLOCK_RECURSION_DEPTH_4]: false
315+
[BLOCK_RECURSION_DEPTH_5]: false
316+
[BLOCK_RECURSION_DEPTH_6]: false
317+
[BLOCK_RECURSION_DEPTH_7]: false
318+
[BLOCK_RECURSION_DEPTH_8]: false
319+
[BLOCK_RECURSION_DEPTH_9]: false
320+
305321
Test case: NESTED_MACROS
306322
Source: [1,2,3].map(i, [1, 2, 3].map(i, i + 1)) == [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
307323
=====>
@@ -386,17 +402,17 @@ Test case: MACRO_SHADOWED_VARIABLE
386402
Source: [x - 1 > 3 ? x - 1 : 5].exists(x, x - 1 > 3) || x - 1 > 3
387403
=====>
388404
Result: true
389-
[CASCADED_BINDS]: cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(@c0:0, @c0:0 - 1 > 3) || @r1))
390-
[BLOCK_COMMON_SUBEXPR_ONLY]: cel.@block([x - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@c0:0, @c0:0 - 1 > 3) || @index1)
391-
[BLOCK_RECURSION_DEPTH_1]: cel.@block([x - 1, @index0 > 3, @index1 ? @index0 : 5, [@index2], @c0:0 - 1, @index4 > 3, @x0:0 || @index5], @index3.exists(@c0:0, @index5) || @index1)
392-
[BLOCK_RECURSION_DEPTH_2]: cel.@block([x - 1 > 3, @index0 ? (x - 1) : 5, @c0:0 - 1 > 3, [@index1], @x0:0 || @index2], @index3.exists(@c0:0, @index2) || @index0)
393-
[BLOCK_RECURSION_DEPTH_3]: cel.@block([x - 1 > 3, [@index0 ? (x - 1) : 5], @x0:0 || @c0:0 - 1 > 3, @index1.exists(@c0:0, @c0:0 - 1 > 3)], @index3 || @index0)
394-
[BLOCK_RECURSION_DEPTH_4]: cel.@block([x - 1 > 3, [@index0 ? (x - 1) : 5].exists(@c0:0, @c0:0 - 1 > 3)], @index1 || @index0)
395-
[BLOCK_RECURSION_DEPTH_5]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c0:0, @c0:0 - 1 > 3) || @index0)
396-
[BLOCK_RECURSION_DEPTH_6]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c0:0, @c0:0 - 1 > 3) || @index0)
397-
[BLOCK_RECURSION_DEPTH_7]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c0:0, @c0:0 - 1 > 3) || @index0)
398-
[BLOCK_RECURSION_DEPTH_8]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c0:0, @c0:0 - 1 > 3) || @index0)
399-
[BLOCK_RECURSION_DEPTH_9]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c0:0, @c0:0 - 1 > 3) || @index0)
405+
[CASCADED_BINDS]: cel.bind(@r0, x - 1, cel.bind(@r1, @r0 > 3, [@r1 ? @r0 : 5].exists(@c:0, @c:0 - 1 > 3) || @r1))
406+
[BLOCK_COMMON_SUBEXPR_ONLY]: cel.@block([x - 1, @index0 > 3], [@index1 ? @index0 : 5].exists(@c:0, @c:0 - 1 > 3) || @index1)
407+
[BLOCK_RECURSION_DEPTH_1]: cel.@block([x - 1, @index0 > 3, @index1 ? @index0 : 5, [@index2], @c:0 - 1, @index4 > 3, @x:0 || @index5], @index3.exists(@c:0, @index5) || @index1)
408+
[BLOCK_RECURSION_DEPTH_2]: cel.@block([x - 1 > 3, @index0 ? (x - 1) : 5, @c:0 - 1 > 3, [@index1], @x:0 || @index2], @index3.exists(@c:0, @index2) || @index0)
409+
[BLOCK_RECURSION_DEPTH_3]: cel.@block([x - 1 > 3, [@index0 ? (x - 1) : 5], @x:0 || @c:0 - 1 > 3, @index1.exists(@c:0, @c:0 - 1 > 3)], @index3 || @index0)
410+
[BLOCK_RECURSION_DEPTH_4]: cel.@block([x - 1 > 3, [@index0 ? (x - 1) : 5].exists(@c:0, @c:0 - 1 > 3)], @index1 || @index0)
411+
[BLOCK_RECURSION_DEPTH_5]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c:0, @c:0 - 1 > 3) || @index0)
412+
[BLOCK_RECURSION_DEPTH_6]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c:0, @c:0 - 1 > 3) || @index0)
413+
[BLOCK_RECURSION_DEPTH_7]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c:0, @c:0 - 1 > 3) || @index0)
414+
[BLOCK_RECURSION_DEPTH_8]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c:0, @c:0 - 1 > 3) || @index0)
415+
[BLOCK_RECURSION_DEPTH_9]: cel.@block([x - 1 > 3], [@index0 ? (x - 1) : 5].exists(@c:0, @c:0 - 1 > 3) || @index0)
400416

401417
Test case: MACRO_SHADOWED_VARIABLE_2
402418
Source: ["foo", "bar"].map(x, [x + x, x + x]).map(x, [x + x, x + x])
@@ -668,4 +684,4 @@ Result: 31
668684
[BLOCK_RECURSION_DEPTH_6]: cel.@block([pure_custom_func(msg.oneof_type.payload.single_int64), @index0 + pure_custom_func(msg.oneof_type.payload.single_int32) + @index0], @index1 + pure_custom_func(msg.single_int64))
669685
[BLOCK_RECURSION_DEPTH_7]: cel.@block([pure_custom_func(msg.oneof_type.payload.single_int64)], @index0 + pure_custom_func(msg.oneof_type.payload.single_int32) + @index0 + pure_custom_func(msg.single_int64))
670686
[BLOCK_RECURSION_DEPTH_8]: cel.@block([pure_custom_func(msg.oneof_type.payload.single_int64)], @index0 + pure_custom_func(msg.oneof_type.payload.single_int32) + @index0 + pure_custom_func(msg.single_int64))
671-
[BLOCK_RECURSION_DEPTH_9]: cel.@block([pure_custom_func(msg.oneof_type.payload.single_int64)], @index0 + pure_custom_func(msg.oneof_type.payload.single_int32) + @index0 + pure_custom_func(msg.single_int64))
687+
[BLOCK_RECURSION_DEPTH_9]: cel.@block([pure_custom_func(msg.oneof_type.payload.single_int64)], @index0 + pure_custom_func(msg.oneof_type.payload.single_int32) + @index0 + pure_custom_func(msg.single_int64))

optimizer/src/test/resources/large_expressions_bind_cascaded.baseline

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

optimizer/src/test/resources/large_expressions_block_common_subexpr.baseline

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

optimizer/src/test/resources/large_expressions_block_recursion_depth_1.baseline

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

optimizer/src/test/resources/large_expressions_block_recursion_depth_2.baseline

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

optimizer/src/test/resources/large_expressions_block_recursion_depth_3.baseline

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)