Skip to content

Commit a156ddc

Browse files
l46kokcopybara-github
authored andcommitted
Fix replaceSubtree to properly populate three arg map macro source
PiperOrigin-RevId: 797491661
1 parent fe46a63 commit a156ddc

12 files changed

Lines changed: 214 additions & 213 deletions

optimizer/src/main/java/dev/cel/optimizer/AstMutator.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -910,21 +910,32 @@ private static void unwrapListArgumentsInMacroCallExpr(
910910

911911
CelMutableExpr loopStepExpr = comprehension.loopStep();
912912
List<CelMutableExpr> loopStepArgs = loopStepExpr.call().args();
913-
if (loopStepArgs.size() != 2) {
913+
if (loopStepArgs.size() != 2 && loopStepArgs.size() != 3) {
914914
throw new IllegalArgumentException(
915915
String.format(
916-
"Expected exactly 2 arguments but got %d instead on expr id: %d",
916+
"Expected exactly 2 or 3 arguments but got %d instead on expr id: %d",
917917
loopStepArgs.size(), loopStepExpr.id()));
918918
}
919919

920920
CelMutableCall existingMacroCall = newMacroCallExpr.call();
921921
CelMutableCall newMacroCall =
922922
existingMacroCall.target().isPresent()
923-
? CelMutableCall.create(existingMacroCall.target().get(), existingMacroCall.function())
923+
? CelMutableCall.create(existingMacroCall.target().get(),
924+
existingMacroCall.function())
924925
: CelMutableCall.create(existingMacroCall.function());
925926
newMacroCall.addArgs(
926927
existingMacroCall.args().get(0)); // iter_var is first argument of the call by convention
927-
newMacroCall.addArgs(loopStepArgs.get(1).list().elements());
928+
929+
CelMutableList extraneousList = null;
930+
if (loopStepArgs.size() == 2) {
931+
extraneousList = loopStepArgs.get(1).list();
932+
} else {
933+
newMacroCall.addArgs(loopStepArgs.get(0));
934+
// For map(x,y,z), z is wrapped in a _+_(@result, [z])
935+
extraneousList = loopStepArgs.get(1).call().args().get(1).list();
936+
}
937+
938+
newMacroCall.addArgs(extraneousList.elements());
928939

929940
newMacroCallExpr.setCall(newMacroCall);
930941
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
184184

185185
if (iterCount == 0) {
186186
// No modification has been made.
187-
return OptimizationResult.create(astToModify.toParsedAst());
187+
return OptimizationResult.create(ast);
188188
}
189189

190190
ImmutableList.Builder<CelVarDecl> newVarDecls = ImmutableList.builder();
@@ -395,7 +395,7 @@ private OptimizationResult optimizeUsingCelBind(CelAbstractSyntaxTree ast) {
395395

396396
if (iterCount == 0) {
397397
// No modification has been made.
398-
return OptimizationResult.create(astToModify.toParsedAst());
398+
return OptimizationResult.create(ast);
399399
}
400400

401401
astToModify = astMutator.renumberIdsConsecutively(astToModify);

optimizer/src/test/java/dev/cel/optimizer/AstMutatorTest.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,34 @@ public void replaceSubtree_replaceExtraneousListCreatedByMacro_unparseSuccess()
581581
.containsExactly(2L);
582582
}
583583

584+
@Test
585+
@SuppressWarnings("unchecked") // Test only
586+
public void replaceSubtree_replaceExtraneousListCreatedByThreeArgMacro_unparseSuccess()
587+
throws Exception {
588+
CelAbstractSyntaxTree ast = CEL.compile("[1].map(x, true, 1)").getAst();
589+
CelMutableAst mutableAst = CelMutableAst.fromCelAst(ast);
590+
CelMutableAst mutableAst2 = CelMutableAst.fromCelAst(ast);
591+
592+
// These two mutation are equivalent.
593+
CelAbstractSyntaxTree mutatedAstWithList =
594+
AST_MUTATOR
595+
.replaceSubtree(
596+
mutableAst,
597+
CelMutableExpr.ofList(
598+
CelMutableList.create(CelMutableExpr.ofConstant(CelConstant.ofValue(2L)))),
599+
10L)
600+
.toParsedAst();
601+
CelAbstractSyntaxTree mutatedAstWithConstant =
602+
AST_MUTATOR
603+
.replaceSubtree(mutableAst2, CelMutableExpr.ofConstant(CelConstant.ofValue(2L)), 6L)
604+
.toParsedAst();
605+
606+
assertThat(CEL_UNPARSER.unparse(mutatedAstWithList)).isEqualTo("[1].map(x, true, 2)");
607+
assertThat(CEL_UNPARSER.unparse(mutatedAstWithConstant)).isEqualTo("[1].map(x, true, 2)");
608+
assertThat((List<Long>) CEL.createProgram(CEL.check(mutatedAstWithList).getAst()).eval())
609+
.containsExactly(2L);
610+
}
611+
584612
@Test
585613
public void globalCallExpr_replaceRoot() throws Exception {
586614
// Tree shape (brackets are expr IDs):

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerBaselineTest.java

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -187,45 +187,6 @@ public void subexpression_ast(@TestParameter CseTestOptimizer cseTestOptimizer)
187187
}
188188
}
189189

190-
@Test
191-
public void populateMacroCallsDisabled_macroMapUnpopulated(@TestParameter CseTestCase testCase)
192-
throws Exception {
193-
skipBaselineVerification();
194-
Cel cel = newCelBuilder().build();
195-
CelOptimizer celOptimizerWithBinds =
196-
newCseOptimizer(
197-
cel,
198-
SubexpressionOptimizerOptions.newBuilder()
199-
.populateMacroCalls(false)
200-
.enableCelBlock(false)
201-
.build());
202-
CelOptimizer celOptimizerWithBlocks =
203-
newCseOptimizer(
204-
cel,
205-
SubexpressionOptimizerOptions.newBuilder()
206-
.populateMacroCalls(false)
207-
.enableCelBlock(true)
208-
.build());
209-
CelOptimizer celOptimizerWithFlattenedBlocks =
210-
newCseOptimizer(
211-
cel,
212-
SubexpressionOptimizerOptions.newBuilder()
213-
.populateMacroCalls(false)
214-
.enableCelBlock(true)
215-
.subexpressionMaxRecursionDepth(1)
216-
.build());
217-
CelAbstractSyntaxTree originalAst = cel.compile(testCase.source).getAst();
218-
219-
CelAbstractSyntaxTree astOptimizedWithBinds = celOptimizerWithBinds.optimize(originalAst);
220-
CelAbstractSyntaxTree astOptimizedWithBlocks = celOptimizerWithBlocks.optimize(originalAst);
221-
CelAbstractSyntaxTree astOptimizedWithFlattenedBlocks =
222-
celOptimizerWithFlattenedBlocks.optimize(originalAst);
223-
224-
assertThat(astOptimizedWithBinds.getSource().getMacroCalls()).isEmpty();
225-
assertThat(astOptimizedWithBlocks.getSource().getMacroCalls()).isEmpty();
226-
assertThat(astOptimizedWithFlattenedBlocks.getSource().getMacroCalls()).isEmpty();
227-
}
228-
229190
@Test
230191
public void large_expressions_bind_cascaded() throws Exception {
231192
CelOptimizer celOptimizer =

optimizer/src/test/java/dev/cel/optimizer/optimizers/SubexpressionOptimizerTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ private enum CseNoOpTestCase {
160160
NESTED_FUNCTION("int(timestamp(int(timestamp(1000000000))))"),
161161
// This cannot be optimized. Extracting the common subexpression would presence test
162162
// the bound identifier (e.g: has(@r0)), which is not valid.
163-
UNOPTIMIZABLE_TERNARY("has(msg.single_any) ? msg.single_any : 10");
163+
UNOPTIMIZABLE_TERNARY("has(msg.single_any) ? msg.single_any : 10"),
164+
MACRO("[1, 2, 3].exists(x, x > 0)");
164165

165166
private final String source;
166167

optimizer/src/test/resources/subexpression_ast_block_common_subexpr_only.baseline

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,7 +1781,7 @@ CALL [31] {
17811781
function: _==_
17821782
args: {
17831783
COMPREHENSION [30] {
1784-
iter_var: @it:0:0
1784+
iter_var: y
17851785
iter_range: {
17861786
LIST [1] {
17871787
elements: {
@@ -1790,7 +1790,7 @@ CALL [31] {
17901790
}
17911791
}
17921792
}
1793-
accu_var: @ac:0:0
1793+
accu_var: @result
17941794
accu_init: {
17951795
LIST [24] {
17961796
elements: {
@@ -1805,12 +1805,12 @@ CALL [31] {
18051805
function: _+_
18061806
args: {
18071807
IDENT [26] {
1808-
name: @ac:0:0
1808+
name: @result
18091809
}
18101810
LIST [27] {
18111811
elements: {
18121812
COMPREHENSION [23] {
1813-
iter_var: @it:1:0
1813+
iter_var: x
18141814
iter_range: {
18151815
LIST [6] {
18161816
elements: {
@@ -1820,7 +1820,7 @@ CALL [31] {
18201820
}
18211821
}
18221822
}
1823-
accu_var: @ac:1:0
1823+
accu_var: @result
18241824
accu_init: {
18251825
LIST [15] {
18261826
elements: {
@@ -1838,37 +1838,37 @@ CALL [31] {
18381838
function: _==_
18391839
args: {
18401840
IDENT [12] {
1841-
name: @it:1:0
1841+
name: x
18421842
}
18431843
IDENT [14] {
1844-
name: @it:0:0
1844+
name: y
18451845
}
18461846
}
18471847
}
18481848
CALL [19] {
18491849
function: _+_
18501850
args: {
18511851
IDENT [17] {
1852-
name: @ac:1:0
1852+
name: @result
18531853
}
18541854
LIST [18] {
18551855
elements: {
18561856
IDENT [11] {
1857-
name: @it:1:0
1857+
name: x
18581858
}
18591859
}
18601860
}
18611861
}
18621862
}
18631863
IDENT [20] {
1864-
name: @ac:1:0
1864+
name: @result
18651865
}
18661866
}
18671867
}
18681868
}
18691869
result: {
18701870
IDENT [22] {
1871-
name: @ac:1:0
1871+
name: @result
18721872
}
18731873
}
18741874
}
@@ -1879,7 +1879,7 @@ CALL [31] {
18791879
}
18801880
result: {
18811881
IDENT [29] {
1882-
name: @ac:0:0
1882+
name: @result
18831883
}
18841884
}
18851885
}
@@ -2380,10 +2380,10 @@ Test case: MACRO_SHADOWED_VARIABLE_2
23802380
Source: ["foo", "bar"].map(x, [x + x, x + x]).map(x, [x + x, x + x])
23812381
=====>
23822382
COMPREHENSION [35] {
2383-
iter_var: @it:0:0
2383+
iter_var: x
23842384
iter_range: {
23852385
COMPREHENSION [19] {
2386-
iter_var: @it:1:0
2386+
iter_var: x
23872387
iter_range: {
23882388
LIST [1] {
23892389
elements: {
@@ -2392,7 +2392,7 @@ COMPREHENSION [35] {
23922392
}
23932393
}
23942394
}
2395-
accu_var: @ac:1:0
2395+
accu_var: @result
23962396
accu_init: {
23972397
LIST [13] {
23982398
elements: {
@@ -2407,7 +2407,7 @@ COMPREHENSION [35] {
24072407
function: _+_
24082408
args: {
24092409
IDENT [15] {
2410-
name: @ac:1:0
2410+
name: @result
24112411
}
24122412
LIST [16] {
24132413
elements: {
@@ -2417,21 +2417,21 @@ COMPREHENSION [35] {
24172417
function: _+_
24182418
args: {
24192419
IDENT [7] {
2420-
name: @it:1:0
2420+
name: x
24212421
}
24222422
IDENT [9] {
2423-
name: @it:1:0
2423+
name: x
24242424
}
24252425
}
24262426
}
24272427
CALL [11] {
24282428
function: _+_
24292429
args: {
24302430
IDENT [10] {
2431-
name: @it:1:0
2431+
name: x
24322432
}
24332433
IDENT [12] {
2434-
name: @it:1:0
2434+
name: x
24352435
}
24362436
}
24372437
}
@@ -2444,12 +2444,12 @@ COMPREHENSION [35] {
24442444
}
24452445
result: {
24462446
IDENT [18] {
2447-
name: @ac:1:0
2447+
name: @result
24482448
}
24492449
}
24502450
}
24512451
}
2452-
accu_var: @ac:0:0
2452+
accu_var: @result
24532453
accu_init: {
24542454
LIST [29] {
24552455
elements: {
@@ -2464,7 +2464,7 @@ COMPREHENSION [35] {
24642464
function: _+_
24652465
args: {
24662466
IDENT [31] {
2467-
name: @ac:0:0
2467+
name: @result
24682468
}
24692469
LIST [32] {
24702470
elements: {
@@ -2474,21 +2474,21 @@ COMPREHENSION [35] {
24742474
function: _+_
24752475
args: {
24762476
IDENT [23] {
2477-
name: @it:0:0
2477+
name: x
24782478
}
24792479
IDENT [25] {
2480-
name: @it:0:0
2480+
name: x
24812481
}
24822482
}
24832483
}
24842484
CALL [27] {
24852485
function: _+_
24862486
args: {
24872487
IDENT [26] {
2488-
name: @it:0:0
2488+
name: x
24892489
}
24902490
IDENT [28] {
2491-
name: @it:0:0
2491+
name: x
24922492
}
24932493
}
24942494
}
@@ -2501,7 +2501,7 @@ COMPREHENSION [35] {
25012501
}
25022502
result: {
25032503
IDENT [34] {
2504-
name: @ac:0:0
2504+
name: @result
25052505
}
25062506
}
25072507
}

0 commit comments

Comments
 (0)