Skip to content

Commit 7718a6e

Browse files
l46kokcopybara-github
authored andcommitted
Optimize composed policies using Constant Folding and Common Subexpression Elimination
PiperOrigin-RevId: 797449393
1 parent 37828da commit 7718a6e

11 files changed

Lines changed: 154 additions & 85 deletions

File tree

optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr)
202202
CelNavigableMutableExpr parent = identNode.parent().orElse(null);
203203
while (parent != null) {
204204
if (parent.getKind().equals(Kind.COMPREHENSION)) {
205-
if (parent.expr().comprehension().accuVar().equals(identNode.expr().ident().name())) {
205+
String identName = identNode.expr().ident().name();
206+
if (parent.expr().comprehension().accuVar().equals(identName)
207+
|| parent.expr().comprehension().iterVar().equals(identName)) {
206208
// Prevent folding a subexpression if it contains a variable declared by a
207209
// comprehension. The subexpression cannot be compiled without the full context of the
208210
// surrounding comprehension.

parser/src/main/java/dev/cel/parser/Operator.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,22 +103,34 @@ static Optional<Operator> find(String text) {
103103
private static final ImmutableMap<String, Operator> REVERSE_OPERATORS =
104104
ImmutableMap.<String, Operator>builder()
105105
.put(ADD.getFunction(), ADD)
106+
.put(ALL.getFunction(), ALL)
107+
.put(CONDITIONAL.getFunction(), CONDITIONAL)
106108
.put(DIVIDE.getFunction(), DIVIDE)
107109
.put(EQUALS.getFunction(), EQUALS)
110+
.put(EXISTS.getFunction(), EXISTS)
111+
.put(EXISTS_ONE.getFunction(), EXISTS_ONE)
112+
.put(FILTER.getFunction(), FILTER)
108113
.put(GREATER.getFunction(), GREATER)
109114
.put(GREATER_EQUALS.getFunction(), GREATER_EQUALS)
115+
.put(HAS.getFunction(), HAS)
110116
.put(IN.getFunction(), IN)
117+
.put(INDEX.getFunction(), INDEX)
111118
.put(LESS.getFunction(), LESS)
112119
.put(LESS_EQUALS.getFunction(), LESS_EQUALS)
113120
.put(LOGICAL_AND.getFunction(), LOGICAL_AND)
114121
.put(LOGICAL_NOT.getFunction(), LOGICAL_NOT)
115122
.put(LOGICAL_OR.getFunction(), LOGICAL_OR)
123+
.put(MAP.getFunction(), MAP)
116124
.put(MODULO.getFunction(), MODULO)
117125
.put(MULTIPLY.getFunction(), MULTIPLY)
118126
.put(NEGATE.getFunction(), NEGATE)
119127
.put(NOT_EQUALS.getFunction(), NOT_EQUALS)
120-
.put(SUBTRACT.getFunction(), SUBTRACT)
128+
.put(NOT_STRICTLY_FALSE.getFunction(), NOT_STRICTLY_FALSE)
121129
.put(OLD_IN.getFunction(), OLD_IN)
130+
.put(OLD_NOT_STRICTLY_FALSE.getFunction(), OLD_NOT_STRICTLY_FALSE)
131+
.put(OPTIONAL_INDEX.getFunction(), OPTIONAL_INDEX)
132+
.put(OPTIONAL_SELECT.getFunction(), OPTIONAL_SELECT)
133+
.put(SUBTRACT.getFunction(), SUBTRACT)
122134
.buildOrThrow();
123135

124136
// precedence of the operator, where the higher value means higher.
@@ -168,8 +180,8 @@ static Optional<Operator> find(String text) {
168180
.put(MODULO.getFunction(), "%")
169181
.buildOrThrow();
170182

171-
/** Lookup an operator by its mangled name, as used within the AST. */
172-
static Optional<Operator> findReverse(String op) {
183+
/** Lookup an operator by its mangled name (ex: _&&_), as used within the AST. */
184+
public static Optional<Operator> findReverse(String op) {
173185
return Optional.ofNullable(REVERSE_OPERATORS.get(op));
174186
}
175187

parser/src/test/java/dev/cel/parser/OperatorTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static org.junit.Assert.assertFalse;
2020
import static org.junit.Assert.assertTrue;
2121

22+
import com.google.testing.junit.testparameterinjector.TestParameter;
2223
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
2324
import com.google.testing.junit.testparameterinjector.TestParameters;
2425
import dev.cel.common.ast.CelExpr;
@@ -50,6 +51,11 @@ public void findReverse_returnsCorrectOperator() {
5051
assertThat(Operator.findReverse("_+_")).hasValue(Operator.ADD);
5152
}
5253

54+
@Test
55+
public void findReverse_allOperators(@TestParameter Operator operator) {
56+
assertThat(Operator.findReverse(operator.getFunction())).hasValue(operator);
57+
}
58+
5359
@Test
5460
public void findReverseBinaryOperator_returnsEmptyWhenNotFound() {
5561
assertThat(Operator.findReverseBinaryOperator("+")).isEmpty();

policy/src/main/java/dev/cel/policy/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ java_library(
215215
"//optimizer",
216216
"//optimizer:optimization_exception",
217217
"//optimizer:optimizer_builder",
218+
"//optimizer/optimizers:common_subexpression_elimination",
219+
"//optimizer/optimizers:constant_folding",
218220
"//validator",
219221
"//validator:ast_validator",
220222
"//validator:validator_builder",
@@ -247,7 +249,9 @@ java_library(
247249
"//common:cel_ast",
248250
"//common:compiler_common",
249251
"//common:mutable_ast",
252+
"//common/ast",
250253
"//common/formats:value_string",
254+
"//common/navigation:mutable_navigation",
251255
"//extensions:optional_library",
252256
"//optimizer:ast_optimizer",
253257
"//optimizer:mutable_ast",

policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
import dev.cel.optimizer.CelOptimizationException;
3737
import dev.cel.optimizer.CelOptimizer;
3838
import dev.cel.optimizer.CelOptimizerFactory;
39+
import dev.cel.optimizer.optimizers.ConstantFoldingOptimizer;
40+
import dev.cel.optimizer.optimizers.SubexpressionOptimizer;
41+
import dev.cel.optimizer.optimizers.SubexpressionOptimizer.SubexpressionOptimizerOptions;
3942
import dev.cel.policy.CelCompiledRule.CelCompiledMatch;
4043
import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result;
4144
import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result.Kind;
@@ -98,7 +101,7 @@ public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationE
98101
public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledRule)
99102
throws CelPolicyValidationException {
100103
Cel cel = compiledRule.cel();
101-
CelOptimizer optimizer =
104+
CelOptimizer composingOptimizer =
102105
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
103106
.addAstOptimizers(
104107
RuleComposer.newInstance(compiledRule, variablesPrefix, iterationLimit))
@@ -110,7 +113,7 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR
110113
// This is a minimal expression used as a basis of stitching together all the rules into a
111114
// single graph.
112115
ast = cel.compile("true").getAst();
113-
ast = optimizer.optimize(ast);
116+
ast = composingOptimizer.optimize(ast);
114117
} catch (CelValidationException | CelOptimizationException e) {
115118
if (e.getCause() instanceof RuleCompositionException) {
116119
RuleCompositionException re = (RuleCompositionException) e.getCause();
@@ -136,6 +139,24 @@ public CelAbstractSyntaxTree compose(CelPolicy policy, CelCompiledRule compiledR
136139
throw new CelPolicyValidationException("Unexpected error while composing rules.", e);
137140
}
138141

142+
CelOptimizer astOptimizer =
143+
CelOptimizerFactory.standardCelOptimizerBuilder(cel)
144+
.addAstOptimizers(
145+
ConstantFoldingOptimizer.getInstance(),
146+
SubexpressionOptimizer.newInstance(
147+
SubexpressionOptimizerOptions.newBuilder()
148+
.populateMacroCalls(true)
149+
.enableCelBlock(true)
150+
.build()))
151+
.build();
152+
try {
153+
// Optimize the composed graph using const fold and CSE
154+
ast = astOptimizer.optimize(ast);
155+
} catch (CelOptimizationException e) {
156+
throw new CelPolicyValidationException(
157+
"Failed to optimize the composed policy. Reason: " + e.getMessage(), e);
158+
}
159+
139160
assertAstDepthIsSafe(ast, cel);
140161

141162
return ast;

policy/src/main/java/dev/cel/policy/RuleComposer.java

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
import static java.util.stream.Collectors.toCollection;
2020

2121
import com.google.auto.value.AutoValue;
22+
import com.google.common.collect.ImmutableList;
2223
import com.google.common.collect.Lists;
2324
import dev.cel.bundle.Cel;
2425
import dev.cel.common.CelAbstractSyntaxTree;
2526
import dev.cel.common.CelMutableAst;
2627
import dev.cel.common.CelValidationException;
28+
import dev.cel.common.ast.CelExpr.ExprKind.Kind;
2729
import dev.cel.common.formats.ValueString;
30+
import dev.cel.common.navigation.CelNavigableMutableAst;
31+
import dev.cel.common.navigation.CelNavigableMutableExpr;
2832
import dev.cel.extensions.CelOptionalLibrary.Function;
2933
import dev.cel.optimizer.AstMutator;
3034
import dev.cel.optimizer.CelAstOptimizer;
@@ -151,23 +155,37 @@ private RuleOptimizationResult optimizeRule(Cel cel, CelCompiledRule compiledRul
151155
}
152156
}
153157

154-
CelMutableAst result = matchAst;
155-
for (CelCompiledVariable variable : Lists.reverse(compiledRule.variables())) {
156-
result =
157-
astMutator.replaceSubtreeWithNewBindMacro(
158-
result,
159-
variablePrefix + variable.name(),
160-
CelMutableAst.fromCelAst(variable.ast()),
161-
result.expr(),
162-
result.expr().id(),
163-
true);
164-
}
158+
CelMutableAst result = inlineCompiledVariables(matchAst, compiledRule.variables());
165159

166160
result = astMutator.renumberIdsConsecutively(result);
167161

168162
return RuleOptimizationResult.create(result, isOptionalResult);
169163
}
170164

165+
private CelMutableAst inlineCompiledVariables(
166+
CelMutableAst ast, List<CelCompiledVariable> compiledVariables) {
167+
CelMutableAst mutatedAst = ast;
168+
for (CelCompiledVariable compiledVariable : Lists.reverse(compiledVariables)) {
169+
String variableName = variablePrefix + compiledVariable.name();
170+
ImmutableList<CelNavigableMutableExpr> exprsToReplace =
171+
CelNavigableMutableAst.fromAst(mutatedAst)
172+
.getRoot()
173+
.allNodes()
174+
.filter(
175+
node ->
176+
node.expr().getKind().equals(Kind.IDENT)
177+
&& node.expr().ident().name().equals(variableName))
178+
.collect(toImmutableList());
179+
180+
for (CelNavigableMutableExpr expr : exprsToReplace) {
181+
CelMutableAst variableAst = CelMutableAst.fromCelAst(compiledVariable.ast());
182+
mutatedAst = astMutator.replaceSubtree(mutatedAst, variableAst, expr.id());
183+
}
184+
}
185+
186+
return mutatedAst;
187+
}
188+
171189
static RuleComposer newInstance(
172190
CelCompiledRule compiledRule, String variablePrefix, int iterationLimit) {
173191
return new RuleComposer(compiledRule, variablePrefix, iterationLimit);

policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ public void compileYamlPolicy_multilineContainsError_throws(
136136

137137
@Test
138138
public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Exception {
139-
String longExpr =
140-
"0+1+2+3+4+5+6+7+8+9+10+11+12+13+14+15+16+17+18+19+20+21+22+23+24+25+26+27+28+29+30+31+32+33+34+35+36+37+38+39+40+41+42+43+44+45+46+47+48+49+50";
139+
Cel cel = newCel().toCelBuilder().addVar("msg", SimpleType.DYN).build();
140+
String longExpr = "msg.b.c.d.e.f";
141141
String policyContent =
142142
String.format(
143143
"name: deeply_nested_ast\n" + "rule:\n" + " match:\n" + " - output: %s", longExpr);
@@ -146,11 +146,35 @@ public void compileYamlPolicy_exceedsDefaultAstDepthLimit_throws() throws Except
146146
CelPolicyValidationException e =
147147
assertThrows(
148148
CelPolicyValidationException.class,
149-
() -> CelPolicyCompilerFactory.newPolicyCompiler(newCel()).build().compile(policy));
149+
() ->
150+
CelPolicyCompilerFactory.newPolicyCompiler(cel)
151+
.setAstDepthLimit(5)
152+
.build()
153+
.compile(policy));
154+
155+
assertThat(e)
156+
.hasMessageThat()
157+
.isEqualTo("ERROR: <input>:-1:0: AST's depth exceeds the configured limit: 5.");
158+
}
150159

160+
@Test
161+
public void compileYamlPolicy_constantFoldingFailure_throwsDuringComposition() throws Exception {
162+
String policyContent =
163+
"name: ast_with_div_by_zero\n" //
164+
+ "rule:\n" //
165+
+ " match:\n" //
166+
+ " - output: 1 / 0";
167+
CelPolicy policy = POLICY_PARSER.parse(policyContent);
168+
169+
CelPolicyValidationException e =
170+
assertThrows(
171+
CelPolicyValidationException.class,
172+
() -> CelPolicyCompilerFactory.newPolicyCompiler(newCel()).build().compile(policy));
151173
assertThat(e)
152174
.hasMessageThat()
153-
.isEqualTo("ERROR: <input>:-1:0: AST's depth exceeds the configured limit: 50.");
175+
.isEqualTo(
176+
"Failed to optimize the composed policy. Reason: Constant folding failure. Failed to"
177+
+ " evaluate subtree due to: evaluation error: / by zero");
154178
}
155179

156180
@Test

0 commit comments

Comments
 (0)