|
18 | 18 | import static com.google.common.collect.ImmutableList.toImmutableList; |
19 | 19 | import static java.util.stream.Collectors.toCollection; |
20 | 20 |
|
| 21 | +import com.google.common.base.Preconditions; |
21 | 22 | import com.google.common.collect.ImmutableList; |
22 | 23 | import com.google.common.collect.Lists; |
23 | 24 | import dev.cel.bundle.Cel; |
|
32 | 33 | import dev.cel.common.formats.ValueString; |
33 | 34 | import dev.cel.common.navigation.CelNavigableMutableAst; |
34 | 35 | import dev.cel.common.navigation.CelNavigableMutableExpr; |
| 36 | +import dev.cel.common.types.CelType; |
35 | 37 | import dev.cel.extensions.CelOptionalLibrary.Function; |
36 | 38 | import dev.cel.optimizer.AstMutator; |
37 | 39 | import dev.cel.optimizer.CelAstOptimizer; |
38 | 40 | import dev.cel.policy.CelCompiledRule.CelCompiledMatch; |
39 | 41 | import dev.cel.policy.CelCompiledRule.CelCompiledMatch.OutputValue; |
| 42 | +import dev.cel.policy.CelCompiledRule.CelCompiledMatch.Result; |
40 | 43 | import dev.cel.policy.CelCompiledRule.CelCompiledVariable; |
41 | 44 | import java.util.ArrayList; |
42 | 45 | import java.util.Arrays; |
@@ -74,54 +77,70 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) { |
74 | 77 | } |
75 | 78 |
|
76 | 79 | long lastOutputId = 0; |
| 80 | + // The expected output type of the rule, used to verify that all branches agree on the type. |
| 81 | + CelType lastOutputType = null; |
77 | 82 | for (CelCompiledMatch match : Lists.reverse(compiledRule.matches())) { |
78 | 83 | CelAbstractSyntaxTree conditionAst = match.condition(); |
79 | 84 | boolean isTriviallyTrue = match.isConditionTriviallyTrue(); |
80 | 85 | CelMutableAst condAst = CelMutableAst.fromCelAst(conditionAst); |
81 | 86 |
|
| 87 | + long currentSourceId = lastOutputId; |
| 88 | + |
82 | 89 | switch (match.result().kind()) { |
83 | 90 | case OUTPUT: |
84 | 91 | // If the match has an output, then it is considered a non-optional output since |
85 | 92 | // it is explicitly stated. If the rule itself is optional, then the base case value |
86 | 93 | // of output being optional.none() will convert the non-optional value to an optional |
87 | 94 | // one. |
88 | 95 | OutputValue matchOutput = match.result().output(); |
89 | | - CelMutableAst outAst = CelMutableAst.fromCelAst(matchOutput.ast()); |
90 | | - Step step = Step.newNonOptionalStep(!isTriviallyTrue, condAst, outAst); |
| 96 | + Step step = |
| 97 | + Step.newNonOptionalStep( |
| 98 | + !isTriviallyTrue, condAst, CelMutableAst.fromCelAst(matchOutput.ast())); |
| 99 | + currentSourceId = matchOutput.sourceId(); |
| 100 | + |
91 | 101 | output = combine(astMutator, step, output); |
92 | 102 |
|
93 | | - assertComposedAstIsValid( |
94 | | - cel, |
95 | | - output.expr, |
96 | | - "incompatible output types found.", |
97 | | - matchOutput.sourceId(), |
98 | | - lastOutputId); |
99 | | - lastOutputId = matchOutput.sourceId(); |
| 103 | + String outputFailureMessage = |
| 104 | + String.format( |
| 105 | + "incompatible output types: block has output type %s, but previous outputs have" |
| 106 | + + " type %s", |
| 107 | + lastOutputType == null ? "" : lastOutputType.name(), |
| 108 | + matchOutput.ast().getResultType().name()); |
| 109 | + lastOutputType = |
| 110 | + assertComposedAstIsValid( |
| 111 | + cel, output.expr, outputFailureMessage, currentSourceId, lastOutputId) |
| 112 | + .getResultType(); |
| 113 | + |
100 | 114 | break; |
101 | 115 | case RULE: |
102 | 116 | // If the match has a nested rule, then compute the rule and whether it has |
103 | 117 | // an optional return value. |
104 | 118 | CelCompiledRule matchNestedRule = match.result().rule(); |
105 | 119 | Step nestedRule = optimizeRule(cel, matchNestedRule); |
106 | | - boolean nestedHasOptional = matchNestedRule.hasOptionalOutput(); |
107 | | - |
108 | 120 | Step ruleStep = |
109 | | - nestedHasOptional |
110 | | - ? Step.newOptionalStep(!isTriviallyTrue, condAst, nestedRule.expr) |
111 | | - : Step.newNonOptionalStep(!isTriviallyTrue, condAst, nestedRule.expr); |
| 121 | + new Step( |
| 122 | + matchNestedRule.hasOptionalOutput(), !isTriviallyTrue, condAst, nestedRule.expr); |
| 123 | + currentSourceId = getFirstOutputSourceId(matchNestedRule); |
| 124 | + |
112 | 125 | output = combine(astMutator, ruleStep, output); |
113 | 126 |
|
114 | | - assertComposedAstIsValid( |
115 | | - cel, |
116 | | - output.expr, |
117 | | - String.format( |
118 | | - "failed composing the subrule '%s' due to incompatible output types.", |
119 | | - matchNestedRule.ruleId().map(ValueString::value).orElse("")), |
120 | | - lastOutputId); |
| 127 | + lastOutputType = |
| 128 | + assertComposedAstIsValid( |
| 129 | + cel, |
| 130 | + output.expr, |
| 131 | + String.format( |
| 132 | + "failed composing the subrule '%s' due to incompatible output types.", |
| 133 | + matchNestedRule.ruleId().map(ValueString::value).orElse("")), |
| 134 | + currentSourceId, |
| 135 | + lastOutputId) |
| 136 | + .getResultType(); |
121 | 137 | break; |
122 | 138 | } |
| 139 | + |
| 140 | + lastOutputId = currentSourceId; |
123 | 141 | } |
124 | 142 |
|
| 143 | + Preconditions.checkState(output != null, "Policy contains no outputs."); |
125 | 144 | CelMutableAst resultExpr = output.expr; |
126 | 145 | resultExpr = inlineCompiledVariables(resultExpr, compiledRule.variables()); |
127 | 146 | resultExpr = astMutator.renumberIdsConsecutively(resultExpr); |
@@ -266,21 +285,30 @@ private CelMutableAst inlineCompiledVariables( |
266 | 285 | return mutatedAst; |
267 | 286 | } |
268 | 287 |
|
269 | | - private void assertComposedAstIsValid( |
| 288 | + private CelAbstractSyntaxTree assertComposedAstIsValid( |
270 | 289 | Cel cel, CelMutableAst composedAst, String failureMessage, Long... ids) { |
271 | | - assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids)); |
| 290 | + return assertComposedAstIsValid(cel, composedAst, failureMessage, Arrays.asList(ids)); |
272 | 291 | } |
273 | 292 |
|
274 | | - private void assertComposedAstIsValid( |
| 293 | + private CelAbstractSyntaxTree assertComposedAstIsValid( |
275 | 294 | Cel cel, CelMutableAst composedAst, String failureMessage, List<Long> ids) { |
276 | 295 | try { |
277 | | - cel.check(composedAst.toParsedAst()).getAst(); |
| 296 | + return cel.check(composedAst.toParsedAst()).getAst(); |
278 | 297 | } catch (CelValidationException e) { |
279 | 298 | ids = ids.stream().filter(id -> id > 0).collect(toCollection(ArrayList::new)); |
280 | 299 | throw new RuleCompositionException(failureMessage, e, ids); |
281 | 300 | } |
282 | 301 | } |
283 | 302 |
|
| 303 | + private static long getFirstOutputSourceId(CelCompiledRule rule) { |
| 304 | + for (CelCompiledMatch match : rule.matches()) { |
| 305 | + if (match.result().kind() == Result.Kind.OUTPUT) { |
| 306 | + return match.result().output().sourceId(); |
| 307 | + } |
| 308 | + } |
| 309 | + return rule.sourceId(); |
| 310 | + } |
| 311 | + |
284 | 312 | // Step represents an intermediate stage of rule and match expression composition. |
285 | 313 | // |
286 | 314 | // The CelCompiledRule and CelCompiledMatch types are meant to represent standalone tuples of |
@@ -311,11 +339,6 @@ private Step( |
311 | 339 | this.expr = expr; |
312 | 340 | } |
313 | 341 |
|
314 | | - private static Step newOptionalStep( |
315 | | - boolean isConditional, CelMutableAst cond, CelMutableAst expr) { |
316 | | - return new Step(/* isOptional= */ true, isConditional, cond, expr); |
317 | | - } |
318 | | - |
319 | 342 | private static Step newNonOptionalStep( |
320 | 343 | boolean isConditional, CelMutableAst cond, CelMutableAst expr) { |
321 | 344 | return new Step(/* isOptional= */ false, isConditional, cond, expr); |
|
0 commit comments