Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 74 additions & 16 deletions cel/folding.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,26 +506,84 @@ func (opt *constantFoldingOptimizer) constantExprMatcher(ctx *OptimizerContext,
if isNestedComprehension(e) {
return false
}
vars := map[string]bool{}
constantExprs := true
visitor := ast.NewExprVisitor(func(e ast.Expr) {
if e.Kind() == ast.ComprehensionKind {
nested := e.AsComprehension()
vars[nested.AccuVar()] = true
vars[nested.IterVar()] = true
return opt.constantComprehension(ctx, a, e.AsComprehension(), map[string]bool{})
default:
return false
}
}

// constantComprehension reports whether a comprehension subtree can be folded into a constant,
// honoring CEL variable scoping: iteration ranges and accumulator initializers are evaluated in
// the enclosing scope, while loop conditions, loop steps, and results may also reference the
// iteration and accumulator variables.
func (opt *constantFoldingOptimizer) constantComprehension(ctx *OptimizerContext, a *ast.AST, compre ast.ComprehensionExpr, vars map[string]bool) bool {
if !opt.constantScopedExpr(ctx, a, compre.IterRange(), vars) ||
!opt.constantScopedExpr(ctx, a, compre.AccuInit(), vars) {
return false
}
inner := map[string]bool{}
for k := range vars {
inner[k] = true
}
inner[compre.AccuVar()] = true
inner[compre.IterVar()] = true
if compre.HasIterVar2() {
inner[compre.IterVar2()] = true
}
return opt.constantScopedExpr(ctx, a, compre.LoopCondition(), inner) &&
opt.constantScopedExpr(ctx, a, compre.LoopStep(), inner) &&
opt.constantScopedExpr(ctx, a, compre.Result(), inner)
}

// constantScopedExpr reports whether an expression references only the given in-scope variables
// and contains no late-bound function calls, recursing into nested comprehensions with their
// own iteration and accumulator variables.
func (opt *constantFoldingOptimizer) constantScopedExpr(ctx *OptimizerContext, a *ast.AST, e ast.Expr, vars map[string]bool) bool {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's probably a simpler way to achieve this check which the cel-java implementation uses:

https://github.com/google/cel-java/blob/14d4c2e39151f2e99e36f9818a9118b01c1d9ed3/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java#L224

Thanks for looking into this! Consider using the NavigableExpr methods to mirror how Java checks for this case.

switch e.Kind() {
case ast.ComprehensionKind:
return opt.constantComprehension(ctx, a, e.AsComprehension(), vars)
case ast.IdentKind:
return vars[e.AsIdent()]
case ast.CallKind:
if isLateBoundFunctionCall(ctx, a, e) {
return false
}
call := e.AsCall()
if call.IsMemberFunction() && !opt.constantScopedExpr(ctx, a, call.Target(), vars) {
return false
}
for _, arg := range call.Args() {
if !opt.constantScopedExpr(ctx, a, arg, vars) {
return false
}
if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] {
constantExprs = false
}
return true
case ast.ListKind:
for _, elem := range e.AsList().Elements() {
if !opt.constantScopedExpr(ctx, a, elem, vars) {
return false
}
// Late-bound function calls cannot be folded.
if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) {
constantExprs = false
}
return true
case ast.MapKind:
for _, entry := range e.AsMap().Entries() {
me := entry.AsMapEntry()
if !opt.constantScopedExpr(ctx, a, me.Key(), vars) || !opt.constantScopedExpr(ctx, a, me.Value(), vars) {
return false
}
})
ast.PreOrderVisit(e, visitor)
return constantExprs
}
return true
case ast.StructKind:
for _, field := range e.AsStruct().Fields() {
if !opt.constantScopedExpr(ctx, a, field.AsStructField().Value(), vars) {
return false
}
}
return true
case ast.SelectKind:
return opt.constantScopedExpr(ctx, a, e.AsSelect().Operand(), vars)
default:
return false
return true
}
}

Expand Down
4 changes: 4 additions & 0 deletions cel/folding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ func TestConstantFoldingOptimizer(t *testing.T) {
"o": &proto3pb.TestAllTypes{RepeatedInt32: []int32{1, 2, 3}},
},
},
{
expr: `[(x - 1 > 3) ? (x - 1) : 5].exists(x, x - 1 > 3)`,
folded: `[(x - 1 > 3) ? (x - 1) : 5].exists(x, x - 1 > 3)`,
},
}
e, err := NewEnv(
OptionalTypes(),
Expand Down