diff --git a/cel/folding.go b/cel/folding.go index d1ea6b19..3bf35d1f 100644 --- a/cel/folding.go +++ b/cel/folding.go @@ -506,26 +506,84 @@ func (opt *constantFoldingOptimizer) constantExprMatcher(ctx *OptimizerContext, if isNestedComprehension(e) { return false } - vars := map[string]bool{} - constantExprs := true - visitor := ast.NewExprVisitor(func(e ast.Expr) { - if e.Kind() == ast.ComprehensionKind { - nested := e.AsComprehension() - vars[nested.AccuVar()] = true - vars[nested.IterVar()] = true + return opt.constantComprehension(ctx, a, e.AsComprehension(), map[string]bool{}) + default: + return false + } +} + +// constantComprehension reports whether a comprehension subtree can be folded into a constant, +// honoring CEL variable scoping: iteration ranges and accumulator initializers are evaluated in +// the enclosing scope, while loop conditions, loop steps, and results may also reference the +// iteration and accumulator variables. +func (opt *constantFoldingOptimizer) constantComprehension(ctx *OptimizerContext, a *ast.AST, compre ast.ComprehensionExpr, vars map[string]bool) bool { + if !opt.constantScopedExpr(ctx, a, compre.IterRange(), vars) || + !opt.constantScopedExpr(ctx, a, compre.AccuInit(), vars) { + return false + } + inner := map[string]bool{} + for k := range vars { + inner[k] = true + } + inner[compre.AccuVar()] = true + inner[compre.IterVar()] = true + if compre.HasIterVar2() { + inner[compre.IterVar2()] = true + } + return opt.constantScopedExpr(ctx, a, compre.LoopCondition(), inner) && + opt.constantScopedExpr(ctx, a, compre.LoopStep(), inner) && + opt.constantScopedExpr(ctx, a, compre.Result(), inner) +} + +// constantScopedExpr reports whether an expression references only the given in-scope variables +// and contains no late-bound function calls, recursing into nested comprehensions with their +// own iteration and accumulator variables. +func (opt *constantFoldingOptimizer) constantScopedExpr(ctx *OptimizerContext, a *ast.AST, e ast.Expr, vars map[string]bool) bool { + switch e.Kind() { + case ast.ComprehensionKind: + return opt.constantComprehension(ctx, a, e.AsComprehension(), vars) + case ast.IdentKind: + return vars[e.AsIdent()] + case ast.CallKind: + if isLateBoundFunctionCall(ctx, a, e) { + return false + } + call := e.AsCall() + if call.IsMemberFunction() && !opt.constantScopedExpr(ctx, a, call.Target(), vars) { + return false + } + for _, arg := range call.Args() { + if !opt.constantScopedExpr(ctx, a, arg, vars) { + return false } - if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { - constantExprs = false + } + return true + case ast.ListKind: + for _, elem := range e.AsList().Elements() { + if !opt.constantScopedExpr(ctx, a, elem, vars) { + return false } - // Late-bound function calls cannot be folded. - if e.Kind() == ast.CallKind && isLateBoundFunctionCall(ctx, a, e) { - constantExprs = false + } + return true + case ast.MapKind: + for _, entry := range e.AsMap().Entries() { + me := entry.AsMapEntry() + if !opt.constantScopedExpr(ctx, a, me.Key(), vars) || !opt.constantScopedExpr(ctx, a, me.Value(), vars) { + return false } - }) - ast.PreOrderVisit(e, visitor) - return constantExprs + } + return true + case ast.StructKind: + for _, field := range e.AsStruct().Fields() { + if !opt.constantScopedExpr(ctx, a, field.AsStructField().Value(), vars) { + return false + } + } + return true + case ast.SelectKind: + return opt.constantScopedExpr(ctx, a, e.AsSelect().Operand(), vars) default: - return false + return true } } diff --git a/cel/folding_test.go b/cel/folding_test.go index 08d5f517..2bdb1dcb 100644 --- a/cel/folding_test.go +++ b/cel/folding_test.go @@ -326,6 +326,10 @@ func TestConstantFoldingOptimizer(t *testing.T) { "o": &proto3pb.TestAllTypes{RepeatedInt32: []int32{1, 2, 3}}, }, }, + { + expr: `[(x - 1 > 3) ? (x - 1) : 5].exists(x, x - 1 > 3)`, + folded: `[(x - 1 > 3) ? (x - 1) : 5].exists(x, x - 1 > 3)`, + }, } e, err := NewEnv( OptionalTypes(),