3232import dev .cel .common .formats .ValueString ;
3333import dev .cel .common .navigation .CelNavigableMutableAst ;
3434import dev .cel .common .navigation .CelNavigableMutableExpr ;
35+ import dev .cel .common .types .CelType ;
36+ import dev .cel .common .types .OptionalType ;
3537import dev .cel .extensions .CelOptionalLibrary .Function ;
3638import dev .cel .optimizer .AstMutator ;
3739import dev .cel .optimizer .CelAstOptimizer ;
@@ -67,10 +69,14 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) {
6769 // If the rule has an optional output, the last result in the ternary should return
6870 // `optional.none`. This output is implicit and created here to reflect the desired
6971 // last possible output of this type of rule.
72+ CelType ruleOutputType = getRuleOutputType (compiledRule );
73+ // The expected output type of the rule, used to verify that all branches agree on the type.
74+ CelType expectedOutputType = null ;
7075 if (compiledRule .hasOptionalOutput ()) {
7176 output =
7277 Step .newUnconditionalOptionalStep (
7378 newTrueLiteral (), astMutator .newGlobalCall (Function .OPTIONAL_NONE .getFunction ()));
79+ expectedOutputType = OptionalType .create (ruleOutputType );
7480 }
7581
7682 long lastOutputId = 0 ;
@@ -79,47 +85,61 @@ private Step optimizeRule(Cel cel, CelCompiledRule compiledRule) {
7985 boolean isTriviallyTrue = match .isConditionTriviallyTrue ();
8086 CelMutableAst condAst = CelMutableAst .fromCelAst (conditionAst );
8187
88+ CelType currentType ;
89+ Step step ;
90+ long currentSourceId = lastOutputId ;
91+
8292 switch (match .result ().kind ()) {
8393 case OUTPUT :
8494 // If the match has an output, then it is considered a non-optional output since
8595 // it is explicitly stated. If the rule itself is optional, then the base case value
8696 // of output being optional.none() will convert the non-optional value to an optional
8797 // one.
8898 OutputValue matchOutput = match .result ().output ();
89- CelMutableAst outAst = CelMutableAst .fromCelAst (matchOutput .ast ());
90- Step step = Step .newNonOptionalStep (!isTriviallyTrue , condAst , outAst );
99+ currentType = matchOutput .ast ().getResultType ();
100+ step =
101+ Step .newNonOptionalStep (
102+ !isTriviallyTrue , condAst , CelMutableAst .fromCelAst (matchOutput .ast ()));
103+ currentSourceId = matchOutput .sourceId ();
104+
105+ CelType baseType = (expectedOutputType == null ) ? currentType : expectedOutputType ;
106+ String outputFailureMessage =
107+ String .format (
108+ "incompatible output types: block has output type %s, but previous outputs have"
109+ + " type %s" ,
110+ baseType .name (), currentType .name ());
111+
91112 output = combine (astMutator , step , output );
113+ expectedOutputType = (expectedOutputType == null ) ? currentType : expectedOutputType ;
92114
93115 assertComposedAstIsValid (
94- cel ,
95- output .expr ,
96- "conflicting output types found." ,
97- matchOutput .sourceId (),
98- lastOutputId );
99- lastOutputId = matchOutput .sourceId ();
116+ cel , output .expr , outputFailureMessage , currentSourceId , lastOutputId );
100117 break ;
101118 case RULE :
102119 // If the match has a nested rule, then compute the rule and whether it has
103120 // an optional return value.
104121 CelCompiledRule matchNestedRule = match .result ().rule ();
122+ currentType = getRuleOutputType (matchNestedRule );
105123 Step nestedRule = optimizeRule (cel , matchNestedRule );
106- boolean nestedHasOptional = matchNestedRule .hasOptionalOutput ();
124+ step =
125+ new Step (
126+ matchNestedRule .hasOptionalOutput (), !isTriviallyTrue , condAst , nestedRule .expr );
127+ currentSourceId = matchNestedRule .sourceId ();
128+
129+ String ruleFailureMessage =
130+ String .format (
131+ "failed composing the subrule '%s' due to incompatible output types." ,
132+ matchNestedRule .ruleId ().map (ValueString ::value ).orElse ("" ));
107133
108- Step ruleStep =
109- nestedHasOptional
110- ? Step .newOptionalStep (!isTriviallyTrue , condAst , nestedRule .expr )
111- : Step .newNonOptionalStep (!isTriviallyTrue , condAst , nestedRule .expr );
112- output = combine (astMutator , ruleStep , output );
134+ output = combine (astMutator , step , output );
135+ expectedOutputType = (expectedOutputType == null ) ? currentType : expectedOutputType ;
113136
114137 assertComposedAstIsValid (
115- cel ,
116- output .expr ,
117- String .format (
118- "failed composing the subrule '%s' due to conflicting output types." ,
119- matchNestedRule .ruleId ().map (ValueString ::value ).orElse ("" )),
120- lastOutputId );
138+ cel , output .expr , ruleFailureMessage , currentSourceId , lastOutputId );
121139 break ;
122140 }
141+
142+ lastOutputId = currentSourceId ;
123143 }
124144
125145 CelMutableAst resultExpr = output .expr ;
@@ -281,6 +301,20 @@ private void assertComposedAstIsValid(
281301 }
282302 }
283303
304+ private CelType getRuleOutputType (CelCompiledRule rule ) {
305+ for (CelCompiledMatch match : rule .matches ()) {
306+ switch (match .result ().kind ()) {
307+ case OUTPUT :
308+ return match .result ().output ().ast ().getResultType ();
309+ case RULE :
310+ return getRuleOutputType (match .result ().rule ());
311+ }
312+ throw new IllegalStateException ("Unknown match result kind: " + match .result ().kind ());
313+ }
314+
315+ throw new IllegalStateException ("Policy rule contains no outputs" );
316+ }
317+
284318 // Step represents an intermediate stage of rule and match expression composition.
285319 //
286320 // The CelCompiledRule and CelCompiledMatch types are meant to represent standalone tuples of
@@ -311,11 +345,6 @@ private Step(
311345 this .expr = expr ;
312346 }
313347
314- private static Step newOptionalStep (
315- boolean isConditional , CelMutableAst cond , CelMutableAst expr ) {
316- return new Step (/* isOptional= */ true , isConditional , cond , expr );
317- }
318-
319348 private static Step newNonOptionalStep (
320349 boolean isConditional , CelMutableAst cond , CelMutableAst expr ) {
321350 return new Step (/* isOptional= */ false , isConditional , cond , expr );
0 commit comments