diff --git a/cel/cel_test.go b/cel/cel_test.go index c2b9ad0b..05adeb2f 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -1713,6 +1713,77 @@ func TestCustomInterpreterDecorator(t *testing.T) { } } +func TestCustomInterpreterDecoratorV2(t *testing.T) { + var lastInstruction interpreter.InterpretableV2 + optimizeArith := func(i interpreter.InterpretableV2) (interpreter.InterpretableV2, error) { + lastInstruction = i + // Only optimize the instruction if it is a call. + call, ok := i.(interpreter.InterpretableCall) + if !ok { + return i, nil + } + // Only optimize the math functions when they have constant arguments. + switch call.Function() { + case operators.Add, + operators.Subtract, + operators.Multiply, + operators.Divide: + // These are all binary operators so they should have two arguments + args := call.Args() + _, lhsIsConst := args[0].(interpreter.InterpretableConst) + _, rhsIsConst := args[1].(interpreter.InterpretableConst) + // When the values are constant then the call can be evaluated with + // an empty activation and the value returns as a constant. + if !lhsIsConst || !rhsIsConst { + return i, nil + } + val := call.Eval(interpreter.EmptyActivation()) + if types.IsError(val) { + return nil, val.(*types.Err) + } + return interpreter.NewConstValue(call.ID(), val), nil + default: + return i, nil + } + } + + env := testEnv(t, Variable("foo", IntType)) + ast, iss := env.Compile(`foo == -1 + 2 * 3 / 3`) + if iss.Err() != nil { + t.Fatalf("env.Compile() failed: %v", iss.Err()) + } + _, err := env.Program(ast, + EvalOptions(OptPartialEval), + CustomDecoratorV2(optimizeArith)) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + call, ok := lastInstruction.(interpreter.InterpretableCall) + if !ok { + t.Errorf("got %v, expected call", lastInstruction) + } + args := call.Args() + lhs := args[0] + lastAttr, ok := lhs.(interpreter.InterpretableAttribute) + if !ok { + t.Errorf("got %v, wanted attribute", lhs) + } + absAttr := lastAttr.Attr().(interpreter.NamespacedAttribute) + varNames := absAttr.CandidateVariableNames() + if len(varNames) != 1 || varNames[0] != "foo" { + t.Errorf("got variables %v, wanted foo", varNames) + } + rhs := args[1] + lastConst, ok := rhs.(interpreter.InterpretableConst) + if !ok { + t.Errorf("got %v, wanted constant", rhs) + } + // This is the last number produced by the optimization. + if lastConst.Value().Equal(types.IntOne) == types.False { + t.Errorf("got %v as the last observed constant, wanted 1", lastConst) + } +} + // TestEstimateCostAndRuntimeCost sanity checks that the cost systems are usable from the program API. func TestEstimateCostAndRuntimeCost(t *testing.T) { intList := ListType(IntType) @@ -3961,3 +4032,124 @@ func TestExpressionSizeLimitEarlyEnforcement(t *testing.T) { }) } } + +func TestProgramEvalInvalidInput(t *testing.T) { + env, err := NewEnv() + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + ast, iss := env.Compile("true") + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("Program() failed: %v", err) + } + + tests := []struct { + name string + input any + wantErr string + }{ + { + name: "int input", + input: 123, + wantErr: "invalid input, wanted Activation or map[string]any", + }, + { + name: "nil input", + input: nil, + wantErr: "invalid input, wanted Activation or map[string]any", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, _, err := prg.Eval(tc.input) + if err == nil || !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("Eval(%v) err = %v, expected error containing %q", tc.input, err, tc.wantErr) + } + }) + } +} + +func TestProgramContextEvalInvalidInput(t *testing.T) { + env, err := NewEnv() + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + ast, iss := env.Compile("true") + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("Program() failed: %v", err) + } + + tests := []struct { + name string + ctx context.Context + input any + wantErr string + }{ + { + name: "nil context", + ctx: nil, + input: map[string]any{}, + wantErr: "context can not be nil", + }, + { + name: "invalid input type", + ctx: context.Background(), + input: 123, + wantErr: "invalid input, wanted Activation or map[string]any", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, _, err := prg.ContextEval(tc.ctx, tc.input) + if err == nil || !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("ContextEval(%v, %v) err = %v, expected error containing %q", tc.ctx, tc.input, err, tc.wantErr) + } + }) + } +} + +func TestOptionalOperatorsLegacyEval(t *testing.T) { + tests := []struct { + name string + expr interpreter.Interpretable + want ref.Val + }{ + { + name: "optional or", + expr: &evalOptionalOr{ + id: 1, + lhs: interpreter.NewConstValue(2, types.OptionalOf(types.True)), + rhs: interpreter.NewConstValue(3, types.False), + }, + want: types.OptionalOf(types.True), + }, + { + name: "optional or value mismatch", + expr: &evalOptionalOrValue{ + id: 4, + lhs: interpreter.NewConstValue(5, types.True), + rhs: interpreter.NewConstValue(6, types.False), + }, + want: types.NoSuchOverloadErr(), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.expr.Eval(NoVars()) + if got.Equal(tc.want) != types.True && !types.IsError(got) { + t.Errorf("Eval() = %v, wanted %v", got, tc.want) + } + if types.IsError(got) && got.(*types.Err).Error() != tc.want.(*types.Err).Error() { + t.Errorf("Eval() = %v, wanted %v", got, tc.want) + } + }) + } +} diff --git a/cel/library.go b/cel/library.go index 3c8b6ba3..332eb3f1 100644 --- a/cel/library.go +++ b/cel/library.go @@ -590,7 +590,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption { // ProgramOptions implements the Library interface method. func (lib *optionalLib) ProgramOptions() []ProgramOption { return []ProgramOption{ - CustomDecorator(decorateOptionalOr), + CustomDecoratorV2(decorateOptionalOr), } } @@ -683,7 +683,7 @@ func EnableErrorOnBadPresenceTest(value bool) EnvOption { return features(featureEnableErrorOnBadPresenceTest, value) } -func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable, error) { +func decorateOptionalOr(i interpreter.InterpretableV2) (interpreter.InterpretableV2, error) { call, ok := i.(interpreter.InterpretableCall) if !ok { return i, nil @@ -720,8 +720,8 @@ func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable, // the second optional expression is evaluated and returned. type evalOptionalOr struct { id int64 - lhs interpreter.Interpretable - rhs interpreter.Interpretable + lhs interpreter.InterpretableV2 + rhs interpreter.InterpretableV2 } // ID implements the Interpretable interface method. @@ -729,11 +729,9 @@ func (opt *evalOptionalOr) ID() int64 { return opt.id } -// Eval evaluates the left-hand side optional to determine whether it contains a value, else -// proceeds with the right-hand side evaluation. -func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val { +func (opt *evalOptionalOr) Exec(frame *interpreter.ExecutionFrame) ref.Val { // short-circuit lhs. - optLHS := opt.lhs.Eval(ctx) + optLHS := opt.lhs.Exec(frame) switch val := optLHS.(type) { case *types.Err, *types.Unknown: return optLHS @@ -741,18 +739,24 @@ func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val { if val.HasValue() { return optLHS } - return opt.rhs.Eval(ctx) + return opt.rhs.Exec(frame) default: return types.NoSuchOverloadErr() } } +// Eval evaluates the left-hand side optional to determine whether it contains a value, else +// proceeds with the right-hand side evaluation. +func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val { + return opt.Exec(interpreter.AsFrame(ctx)) +} + // evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value, // its value is returned, otherwise the alternative value expression is evaluated and returned. type evalOptionalOrValue struct { id int64 - lhs interpreter.Interpretable - rhs interpreter.Interpretable + lhs interpreter.InterpretableV2 + rhs interpreter.InterpretableV2 } // ID implements the Interpretable interface method. @@ -760,11 +764,9 @@ func (opt *evalOptionalOrValue) ID() int64 { return opt.id } -// Eval evaluates the left-hand side optional to determine whether it contains a value, else -// proceeds with the right-hand side evaluation. -func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val { +func (opt *evalOptionalOrValue) Exec(frame *interpreter.ExecutionFrame) ref.Val { // short-circuit lhs. - optLHS := opt.lhs.Eval(ctx) + optLHS := opt.lhs.Exec(frame) switch val := optLHS.(type) { case *types.Err, *types.Unknown: @@ -773,12 +775,18 @@ func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val { if val.HasValue() { return val.GetValue() } - return opt.rhs.Eval(ctx) + return opt.rhs.Exec(frame) default: return types.NoSuchOverloadErr() } } +// Eval evaluates the left-hand side optional to determine whether it contains a value, else +// proceeds with the right-hand side evaluation. +func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val { + return opt.Exec(interpreter.AsFrame(ctx)) +} + type timeLegacyLibrary struct{} func (timeLegacyLibrary) CompileOptions() []EnvOption { diff --git a/cel/options.go b/cel/options.go index d7d2ab03..86a98d35 100644 --- a/cel/options.go +++ b/cel/options.go @@ -456,6 +456,14 @@ func CustomDecorator(dec interpreter.InterpretableDecorator) ProgramOption { } } +// CustomDecoratorV2 appends an InterpreterDecoratorV2 to the program. +func CustomDecoratorV2(dec interpreter.InterpretableDecoratorV2) ProgramOption { + return func(p *prog) (*prog, error) { + p.plannerOptions = append(p.plannerOptions, interpreter.CustomDecoratorV2(dec)) + return p, nil + } +} + // Functions adds function overloads that extend or override the set of CEL built-ins. // // Deprecated: use Function() instead to declare the function, its overload signatures, diff --git a/cel/program.go b/cel/program.go index c8b72504..79df0374 100644 --- a/cel/program.go +++ b/cel/program.go @@ -18,7 +18,6 @@ import ( "context" "errors" "fmt" - "sync" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/functions" @@ -160,7 +159,7 @@ type prog struct { regexOptimizations []*interpreter.RegexOptimization // Interpretable configured from an Ast and aggregate decorator set based on program options. - interpretable interpreter.Interpretable + interpretable interpreter.InterpretableV2 observable *interpreter.ObservableInterpretable callCostEstimator interpreter.ActualCostEstimator costOptions []interpreter.CostTrackerOption @@ -313,22 +312,19 @@ func (p *prog) Eval(input any) (out ref.Val, det *EvalDetails, err error) { } }() // Build a hierarchical activation if there are default vars set. - var vars Activation - switch v := input.(type) { - case Activation: - vars = v - case map[string]any: - vars = activationPool.Setup(v) - defer activationPool.Put(vars) - default: - return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input) - } - if p.defaultVars != nil { - vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars) + var frame *interpreter.ExecutionFrame + if f, ok := input.(*interpreter.ExecutionFrame); ok { + frame = f + } else { + frame, err = p.newExecutionFrame(input) + if err != nil { + return nil, nil, err + } + defer frame.Close() } if p.observable != nil { det = &EvalDetails{} - out = p.observable.ObserveEval(vars, func(observed any) { + out = p.observable.ObserveExec(frame, func(observed any) { switch o := observed.(type) { case interpreter.EvalState: det.state = o @@ -337,7 +333,7 @@ func (p *prog) Eval(input any) (out ref.Val, det *EvalDetails, err error) { } }) } else { - out = p.interpretable.Eval(vars) + out = p.interpretable.Exec(frame) } // The output of an internal Eval may have a value (`v`) that is a types.Err. This step // translates the CEL value to a Go error response. This interface does not quite match the @@ -353,164 +349,29 @@ func (p *prog) ContextEval(ctx context.Context, input any) (ref.Val, *EvalDetail if ctx == nil { return nil, nil, fmt.Errorf("context can not be nil") } - // Configure the input, making sure to wrap Activation inputs in the special ctxActivation which - // exposes the #interrupted variable and manages rate-limited checks of the ctx.Done() state. - var vars Activation - switch v := input.(type) { - case Activation: - vars = ctxActivationPool.Setup(v, ctx.Done(), p.interruptCheckFrequency) - defer ctxActivationPool.Put(vars) - case map[string]any: - rawVars := activationPool.Setup(v) - defer activationPool.Put(rawVars) - vars = ctxActivationPool.Setup(rawVars, ctx.Done(), p.interruptCheckFrequency) - defer ctxActivationPool.Put(vars) - default: - return nil, nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input) - } - out, det, err := p.Eval(vars) - if err != nil && errors.Is(err, interpreter.InterruptError{}) { - return out, det, fmt.Errorf("%w: %w", err, context.Cause(ctx)) - } - return out, det, err -} - -type ctxEvalActivation struct { - parent Activation - interrupt <-chan struct{} - interruptCheckCount uint - interruptCheckFrequency uint -} - -// ResolveName implements the Activation interface method, but adds a special #interrupted variable -// which is capable of testing whether a 'done' signal is provided from a context.Context channel. -func (a *ctxEvalActivation) ResolveName(name string) (any, bool) { - if name == "#interrupted" { - a.interruptCheckCount++ - if a.interruptCheckCount%a.interruptCheckFrequency == 0 { - select { - case <-a.interrupt: - return true, true - default: - return nil, false - } - } - return nil, false + frame, err := p.newExecutionFrame(input) + if err != nil { + return nil, nil, err } - return a.parent.ResolveName(name) -} - -func (a *ctxEvalActivation) Parent() Activation { - return a.parent -} - -func (a *ctxEvalActivation) AsPartialActivation() (interpreter.PartialActivation, bool) { - pa, ok := a.parent.(interpreter.PartialActivation) - return pa, ok -} - -func newCtxEvalActivationPool() *ctxEvalActivationPool { - return &ctxEvalActivationPool{ - Pool: sync.Pool{ - New: func() any { - return &ctxEvalActivation{} - }, - }, + defer frame.Close() + frame.SetContext(ctx, p.interruptCheckFrequency) + out, det, errEval := p.Eval(frame) + if errEval != nil && errors.Is(errEval, interpreter.InterruptError{}) { + return out, det, fmt.Errorf("%w: %w", errEval, context.Cause(ctx)) } + return out, det, errEval } -type ctxEvalActivationPool struct { - sync.Pool -} - -// Setup initializes a pooled Activation with the ability check for context.Context cancellation -func (p *ctxEvalActivationPool) Setup(vars Activation, done <-chan struct{}, interruptCheckRate uint) *ctxEvalActivation { - a := p.Pool.Get().(*ctxEvalActivation) - a.parent = vars - a.interrupt = done - a.interruptCheckCount = 0 - a.interruptCheckFrequency = interruptCheckRate - return a -} - -type evalActivation struct { - vars map[string]any - lazyVars map[string]any -} - -// ResolveName looks up the value of the input variable name, if found. -// -// Lazy bindings may be supplied within the map-based input in either of the following forms: -// - func() any -// - func() ref.Val -// -// The lazy binding will only be invoked once per evaluation. -// -// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using -// the types.Adapter configured in the environment. -func (a *evalActivation) ResolveName(name string) (any, bool) { - v, found := a.vars[name] - if !found { - return nil, false - } - switch obj := v.(type) { - case func() ref.Val: - if resolved, found := a.lazyVars[name]; found { - return resolved, true - } - lazy := obj() - a.lazyVars[name] = lazy - return lazy, true - case func() any: - if resolved, found := a.lazyVars[name]; found { - return resolved, true - } - lazy := obj() - a.lazyVars[name] = lazy - return lazy, true - default: - return obj, true +// newExecutionFrame creates an ExecutionFrame for the given input without a timeout context. +func (p *prog) newExecutionFrame(input any) (*interpreter.ExecutionFrame, error) { + frame, err := interpreter.NewExecutionFrame(input) + if err != nil { + return nil, err } -} - -// Parent implements the Activation interface -func (a *evalActivation) Parent() Activation { - return nil -} - -func newEvalActivationPool() *evalActivationPool { - return &evalActivationPool{ - Pool: sync.Pool{ - New: func() any { - return &evalActivation{lazyVars: make(map[string]any)} - }, - }, + if p.defaultVars != nil { + // Update the frame's activation in place. + frame.Activation = interpreter.NewHierarchicalActivation(p.defaultVars, frame.Activation) } -} - -type evalActivationPool struct { - sync.Pool -} -// Setup initializes a pooled Activation object with the map input. -func (p *evalActivationPool) Setup(vars map[string]any) *evalActivation { - a := p.Pool.Get().(*evalActivation) - a.vars = vars - return a + return frame, nil } - -func (p *evalActivationPool) Put(value any) { - a := value.(*evalActivation) - for k := range a.lazyVars { - delete(a.lazyVars, k) - } - p.Pool.Put(a) -} - -var ( - // activationPool is an internally managed pool of Activation values that wrap map[string]any inputs - activationPool = newEvalActivationPool() - - // ctxActivationPool is an internally managed pool of Activation values that expose a special #interrupted variable - ctxActivationPool = newCtxEvalActivationPool() -) diff --git a/ext/bindings.go b/ext/bindings.go index bef29ae2..4c070b1a 100644 --- a/ext/bindings.go +++ b/ext/bindings.go @@ -108,7 +108,7 @@ func (lib *celBindings) CompileOptions() []cel.EnvOption { func (lib *celBindings) ProgramOptions() []cel.ProgramOption { if lib.version >= 1 { - celBlockPlan := func(i interpreter.Interpretable) (interpreter.Interpretable, error) { + celBlockPlan := func(i interpreter.InterpretableV2) (interpreter.InterpretableV2, error) { call, ok := i.(interpreter.InterpretableCall) if !ok { return i, nil @@ -140,7 +140,7 @@ func (lib *celBindings) ProgramOptions() []cel.ProgramOption { return i, nil } } - return []cel.ProgramOption{cel.CustomDecorator(celBlockPlan)} + return []cel.ProgramOption{cel.CustomDecoratorV2(celBlockPlan)} } return []cel.ProgramOption{} } @@ -190,7 +190,7 @@ func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Ex ), nil } -func newDynamicBlock(slotExprs []interpreter.Interpretable, expr interpreter.Interpretable) interpreter.Interpretable { +func newDynamicBlock(slotExprs []interpreter.InterpretableV2, expr interpreter.InterpretableV2) interpreter.InterpretableV2 { bs := &dynamicBlock{ slotExprs: slotExprs, expr: expr, @@ -213,8 +213,8 @@ func newDynamicBlock(slotExprs []interpreter.Interpretable, expr interpreter.Int } type dynamicBlock struct { - slotExprs []interpreter.Interpretable - expr interpreter.Interpretable + slotExprs []interpreter.InterpretableV2 + expr interpreter.InterpretableV2 slotActivationPool *sync.Pool } @@ -223,14 +223,18 @@ func (b *dynamicBlock) ID() int64 { return b.expr.ID() } -// Eval implements the Interpretable interface method. -func (b *dynamicBlock) Eval(activation cel.Activation) ref.Val { +func (b *dynamicBlock) Exec(frame *interpreter.ExecutionFrame) ref.Val { sa := b.slotActivationPool.Get().(*dynamicSlotActivation) - sa.Activation = activation + sa.Activation = frame defer b.clearSlots(sa) return b.expr.Eval(sa) } +// Eval implements the Interpretable interface method. +func (b *dynamicBlock) Eval(activation cel.Activation) ref.Val { + return b.Exec(interpreter.AsFrame(activation)) +} + func (b *dynamicBlock) clearSlots(sa *dynamicSlotActivation) { sa.reset() b.slotActivationPool.Put(sa) @@ -243,7 +247,7 @@ type slotVal struct { type dynamicSlotActivation struct { cel.Activation - slotExprs []interpreter.Interpretable + slotExprs []interpreter.InterpretableV2 slotCount int slotVals []*slotVal } @@ -282,7 +286,7 @@ func (sa *dynamicSlotActivation) reset() { } } -func newConstantBlock(slots traits.Lister, expr interpreter.Interpretable) interpreter.Interpretable { +func newConstantBlock(slots traits.Lister, expr interpreter.InterpretableV2) interpreter.InterpretableV2 { count := slots.Size().(types.Int) return &constantBlock{slots: slots, slotCount: int(count), expr: expr} } @@ -290,7 +294,7 @@ func newConstantBlock(slots traits.Lister, expr interpreter.Interpretable) inter type constantBlock struct { slots traits.Lister slotCount int - expr interpreter.Interpretable + expr interpreter.InterpretableV2 } // ID implements the interpreter.Interpretable interface method. @@ -298,11 +302,15 @@ func (b *constantBlock) ID() int64 { return b.expr.ID() } +func (b *constantBlock) Exec(frame *interpreter.ExecutionFrame) ref.Val { + vars := constantSlotActivation{Activation: frame, slots: b.slots, slotCount: b.slotCount} + return b.expr.Eval(vars) +} + // Eval implements the interpreter.Interpretable interface method, and will proxy @index prefixed variable // lookups into a set of constant slots determined from the plan step. func (b *constantBlock) Eval(activation cel.Activation) ref.Val { - vars := constantSlotActivation{Activation: activation, slots: b.slots, slotCount: b.slotCount} - return b.expr.Eval(vars) + return b.Exec(interpreter.AsFrame(activation)) } type constantSlotActivation struct { diff --git a/ext/bindings_test.go b/ext/bindings_test.go index ca570772..3dddc588 100644 --- a/ext/bindings_test.go +++ b/ext/bindings_test.go @@ -17,6 +17,7 @@ package ext import ( "fmt" "strings" + "sync" "testing" "github.com/google/cel-go/cel" @@ -25,6 +26,7 @@ import ( "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/interpreter" ) var bindingTests = []struct { @@ -495,3 +497,28 @@ func TestBlockEval_RuntimeErrors(t *testing.T) { }) } } + +func TestDynamicBlockEval(t *testing.T) { + db := &dynamicBlock{ + expr: interpreter.NewConstValue(1, types.IntOne), + slotActivationPool: &sync.Pool{ + New: func() any { + return &dynamicSlotActivation{} + }, + }, + } + res := db.Eval(cel.NoVars()) + if res.Equal(types.IntOne) != types.True { + t.Errorf("db.Eval() = %v, wanted 1", res) + } +} + +func TestConstantBlockEval(t *testing.T) { + cb := &constantBlock{ + expr: interpreter.NewConstValue(2, types.IntOne), + } + res := cb.Eval(cel.NoVars()) + if res.Equal(types.IntOne) != types.True { + t.Errorf("cb.Eval() = %v, wanted 1", res) + } +} diff --git a/interpreter/BUILD.bazel b/interpreter/BUILD.bazel index 220e23d4..1274a6f1 100644 --- a/interpreter/BUILD.bazel +++ b/interpreter/BUILD.bazel @@ -14,6 +14,7 @@ go_library( "decorators.go", "dispatcher.go", "evalstate.go", + "frame.go", "interpretable.go", "interpreter.go", "optimizations.go", @@ -47,6 +48,7 @@ go_test( "activation_test.go", "attribute_patterns_test.go", "attributes_test.go", + "frame_test.go", "interpreter_test.go", "prune_test.go", "runtimecost_test.go", diff --git a/interpreter/activation.go b/interpreter/activation.go index dd40619e..023efbb8 100644 --- a/interpreter/activation.go +++ b/interpreter/activation.go @@ -110,8 +110,9 @@ func (a *mapActivation) ResolveName(name string) (any, bool) { // hierarchicalActivation which implements Activation and contains a parent and // child activation. type hierarchicalActivation struct { - parent Activation - child Activation + parent Activation + child Activation + poolAllocated bool } // Parent implements the Activation interface method. @@ -127,10 +128,28 @@ func (a *hierarchicalActivation) ResolveName(name string) (any, bool) { return a.parent.ResolveName(name) } +// Unwrap returns the parent activation, stripping the local child scope. +// This allows global disambiguation to skip past locally introduced variables. +func (a *hierarchicalActivation) Unwrap() Activation { + return a.parent +} + +// AsPartialActivation checks the child first via direct type assertion (to +// avoid recursion through the folder → frame → hierarchicalActivation cycle), +// then walks the parent hierarchy via the free function. +func (a *hierarchicalActivation) AsPartialActivation() (PartialActivation, bool) { + if pv, ok := a.child.(partialActivationConverter); ok { + if p, ok := pv.AsPartialActivation(); ok { + return p, true + } + } + return AsPartialActivation(a.parent) +} + // NewHierarchicalActivation takes two activations and produces a new one which prioritizes // resolution in the child first and parent(s) second. func NewHierarchicalActivation(parent Activation, child Activation) Activation { - return &hierarchicalActivation{parent, child} + return &hierarchicalActivation{parent: parent, child: child, poolAllocated: false} } // NewPartialActivation returns an Activation which contains a list of AttributePattern values diff --git a/interpreter/attributes.go b/interpreter/attributes.go index 107b5eea..51e3ddd4 100644 --- a/interpreter/attributes.go +++ b/interpreter/attributes.go @@ -190,7 +190,7 @@ func (r *attrFactory) AbsoluteAttribute(id int64, names ...string) NamespacedAtt func (r *attrFactory) ConditionalAttribute(id int64, expr Interpretable, t, f Attribute) Attribute { return &conditionalAttribute{ id: id, - expr: expr, + expr: adaptToV2(expr), truthy: t, falsy: f, adapter: r.adapter, @@ -225,7 +225,7 @@ func (r *attrFactory) MaybeAttribute(id int64, name string) Attribute { func (r *attrFactory) RelativeAttribute(id int64, operand Interpretable) Attribute { return &relativeAttribute{ id: id, - operand: operand, + operand: adaptToV2(operand), qualifiers: []Qualifier{}, adapter: r.adapter, fac: r, @@ -384,7 +384,7 @@ func (a *absoluteAttribute) Resolve(vars Activation) (any, error) { type conditionalAttribute struct { id int64 - expr Interpretable + expr InterpretableV2 truthy Attribute falsy Attribute adapter types.Adapter @@ -571,7 +571,7 @@ func (a *maybeAttribute) String() string { type relativeAttribute struct { id int64 - operand Interpretable + operand InterpretableV2 qualifiers []Qualifier adapter types.Adapter fac AttributeFactory diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index aebe01aa..fb2362e2 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -1177,9 +1177,6 @@ func TestAttributeStateTracking(t *testing.T) { if err != nil { t.Fatal(err) } - if err != nil { - t.Fatal(err) - } out := i.Eval(in) if types.IsUnknown(tc.out) && types.IsUnknown(out) { if !reflect.DeepEqual(tc.out, out) { @@ -1324,3 +1321,113 @@ func testExprTypeToType(t testing.TB, fieldType *exprpb.Type) *types.Type { } return ft } + +func TestConditionalAttributeQualify(t *testing.T) { + reg, _ := types.NewRegistry() + cont := containers.DefaultContainer + fac := NewAttributeFactory(cont, reg, reg) + + truthy := fac.AbsoluteAttribute(1, "a") + falsy := fac.AbsoluteAttribute(2, "b") + cond := &conditionalAttribute{ + id: 3, + expr: NewConstValue(4, types.True), + truthy: truthy, + falsy: falsy, + adapter: reg, + fac: fac, + } + + activation, _ := NewActivation(map[string]any{"a": "key", "b": "other"}) + obj := map[string]any{"key": 100} + + // Test Qualify + res, err := cond.Qualify(activation, obj) + if err != nil { + t.Fatalf("Qualify() failed: %v", err) + } + if res != 100 { + t.Errorf("Qualify() returned %v, wanted 100", res) + } + + // Test QualifyIfPresent + res, found, err := cond.QualifyIfPresent(activation, obj, false) + if err != nil { + t.Fatalf("QualifyIfPresent() failed: %v", err) + } + if !found || res != 100 { + t.Errorf("QualifyIfPresent() returned (%v, %v), wanted (100, true)", res, found) + } +} + +func TestQualifyIfPresent(t *testing.T) { + reg := newTestRegistry(t) + cont := containers.DefaultContainer + fac := NewAttributeFactory(cont, reg, reg) + activation, _ := NewActivation(map[string]any{ + "a": "b", + "c": int64(1), + }) + + tests := []struct { + name string + qual Qualifier + obj any + out any + }{ + { + name: "absolute_attribute", + qual: fac.AbsoluteAttribute(1, "a"), + obj: map[string]any{"b": 100}, + out: 100, + }, + { + name: "maybe_attribute", + qual: fac.MaybeAttribute(1, "a"), + obj: map[string]any{"b": 100}, + out: 100, + }, + { + name: "relative_attribute", + qual: fac.RelativeAttribute(2, NewConstValue(1, types.String("b"))), + obj: map[string]any{"b": 200}, + out: 200, + }, + { + name: "string_qualifier", + qual: makeOptQualifier(t, fac, nil, 1, "b"), + obj: map[string]any{"b": 300}, + out: 300, + }, + { + name: "int_qualifier", + qual: makeOptQualifier(t, fac, nil, 1, int64(1)), + obj: map[int64]any{int64(1): "value"}, + out: "value", + }, + { + name: "uint_qualifier", + qual: makeOptQualifier(t, fac, nil, 1, uint64(1)), + obj: map[uint64]any{uint64(1): "uvalue"}, + out: "uvalue", + }, + { + name: "bool_qualifier", + qual: makeOptQualifier(t, fac, nil, 1, true), + obj: map[bool]any{true: "bvalue"}, + out: "bvalue", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + res, found, err := tc.qual.QualifyIfPresent(activation, tc.obj, false) + if err != nil { + t.Fatalf("QualifyIfPresent() failed: %v", err) + } + if !found || !reflect.DeepEqual(res, tc.out) { + t.Errorf("QualifyIfPresent() returned (%v, %v), wanted (%v, true)", res, found, tc.out) + } + }) + } +} diff --git a/interpreter/decorators.go b/interpreter/decorators.go index 502db35f..9c973664 100644 --- a/interpreter/decorators.go +++ b/interpreter/decorators.go @@ -25,9 +25,13 @@ import ( // Interpretable expression nodes at construction time. type InterpretableDecorator func(Interpretable) (Interpretable, error) +// InterpretableDecoratorV2 is a functional interface for decorating or replacing +// InterpretableV2 expression nodes at construction time. +type InterpretableDecoratorV2 func(InterpretableV2) (InterpretableV2, error) + // decObserveEval records evaluation state into an EvalState object. -func decObserveEval(observer EvalObserver) InterpretableDecorator { - return func(i Interpretable) (Interpretable, error) { +func decObserveEval(observer EvalObserver) InterpretableDecoratorV2 { + return func(i InterpretableV2) (InterpretableV2, error) { switch inst := i.(type) { case *evalWatch, *evalWatchAttr, *evalWatchConst, *evalWatchConstructor: // these instruction are already watching, return straight-away. @@ -49,8 +53,8 @@ func decObserveEval(observer EvalObserver) InterpretableDecorator { }, nil default: return &evalWatch{ - Interpretable: i, - observer: observer, + InterpretableV2: i, + observer: observer, }, nil } } @@ -58,8 +62,8 @@ func decObserveEval(observer EvalObserver) InterpretableDecorator { // decInterruptFolds creates an intepretable decorator which marks comprehensions as interruptable // where the interrupt state is communicated via a hidden variable on the Activation. -func decInterruptFolds() InterpretableDecorator { - return func(i Interpretable) (Interpretable, error) { +func decInterruptFolds() InterpretableDecoratorV2 { + return func(i InterpretableV2) (InterpretableV2, error) { fold, ok := i.(*evalFold) if !ok { return i, nil @@ -70,8 +74,8 @@ func decInterruptFolds() InterpretableDecorator { } // decDisableShortcircuits ensures that all branches of an expression will be evaluated, no short-circuiting. -func decDisableShortcircuits() InterpretableDecorator { - return func(i Interpretable) (Interpretable, error) { +func decDisableShortcircuits() InterpretableDecoratorV2 { + return func(i InterpretableV2) (InterpretableV2, error) { switch expr := i.(type) { case *evalOr: return &evalExhaustiveOr{ @@ -104,8 +108,8 @@ func decDisableShortcircuits() InterpretableDecorator { // conditionally precomputing the result. // - build list and map values with constant elements. // - convert 'in' operations to set membership tests if possible. -func decOptimize() InterpretableDecorator { - return func(i Interpretable) (Interpretable, error) { +func decOptimize() InterpretableDecoratorV2 { + return func(i InterpretableV2) (InterpretableV2, error) { switch inst := i.(type) { case *evalList: return maybeBuildListLiteral(i, inst) @@ -124,7 +128,7 @@ func decOptimize() InterpretableDecorator { } // decRegexOptimizer compiles regex pattern string constants. -func decRegexOptimizer(regexOptimizations ...*RegexOptimization) InterpretableDecorator { +func decRegexOptimizer(regexOptimizations ...*RegexOptimization) InterpretableDecoratorV2 { functionMatchMap := make(map[string]*RegexOptimization) overloadMatchMap := make(map[string]*RegexOptimization) for _, m := range regexOptimizations { @@ -134,7 +138,7 @@ func decRegexOptimizer(regexOptimizations ...*RegexOptimization) InterpretableDe } } - return func(i Interpretable) (Interpretable, error) { + return func(i InterpretableV2) (InterpretableV2, error) { call, ok := i.(InterpretableCall) if !ok { return i, nil @@ -165,7 +169,7 @@ func decRegexOptimizer(regexOptimizations ...*RegexOptimization) InterpretableDe } } -func maybeOptimizeConstUnary(i Interpretable, call InterpretableCall) (Interpretable, error) { +func maybeOptimizeConstUnary(i InterpretableV2, call InterpretableCall) (InterpretableV2, error) { args := call.Args() if len(args) != 1 { return i, nil @@ -181,7 +185,7 @@ func maybeOptimizeConstUnary(i Interpretable, call InterpretableCall) (Interpret return NewConstValue(call.ID(), val), nil } -func maybeBuildListLiteral(i Interpretable, l *evalList) (Interpretable, error) { +func maybeBuildListLiteral(i InterpretableV2, l *evalList) (InterpretableV2, error) { for _, elem := range l.elems { _, isConst := elem.(InterpretableConst) if !isConst { @@ -191,7 +195,7 @@ func maybeBuildListLiteral(i Interpretable, l *evalList) (Interpretable, error) return NewConstValue(l.ID(), l.Eval(EmptyActivation())), nil } -func maybeBuildMapLiteral(i Interpretable, mp *evalMap) (Interpretable, error) { +func maybeBuildMapLiteral(i InterpretableV2, mp *evalMap) (InterpretableV2, error) { for idx, key := range mp.keys { _, isConst := key.(InterpretableConst) if !isConst { @@ -209,7 +213,7 @@ func maybeBuildMapLiteral(i Interpretable, mp *evalMap) (Interpretable, error) { // test if the following conditions are true: // - the list is a constant with homogeneous element types. // - the elements are all of primitive type. -func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Interpretable, error) { +func maybeOptimizeSetMembership(i InterpretableV2, inlist InterpretableCall) (InterpretableV2, error) { args := inlist.Args() lhs := args[0] rhs := args[1] diff --git a/interpreter/frame.go b/interpreter/frame.go index 357b5565..c6149d50 100644 --- a/interpreter/frame.go +++ b/interpreter/frame.go @@ -16,8 +16,12 @@ package interpreter import ( "context" + "errors" + "fmt" "sync" "sync/atomic" + + "github.com/google/cel-go/common/types/ref" ) // evalContext contains the stateful information needed for a single evaluation. @@ -66,27 +70,39 @@ type ExecutionFrame struct { } // NewExecutionFrame creates a new execution frame from the pool. -func NewExecutionFrame(vars Activation) *ExecutionFrame { +func NewExecutionFrame(input any) (*ExecutionFrame, error) { f := frameStack.Get().(*ExecutionFrame) - f.Activation = vars - return f + switch v := input.(type) { + case Activation: + f.Activation = v + case map[string]any: + f.Activation = activationInput.create(v) + default: + return nil, fmt.Errorf("invalid input, wanted Activation or map[string]any, got: (%T)%v", input, input) + } + return f, nil } // SetContext sets the context for the execution frame. -func (f *ExecutionFrame) SetContext(ctx context.Context, interruptCheckFrequency uint) { - if f.ctx == nil { - f.ctx = evalContextPool.Get().(*evalContext) +func (f *ExecutionFrame) SetContext(ctx context.Context, interruptCheckFrequency uint) error { + if f.parent != nil { + return errors.New("SetContext() called on child frame") + } + if f.ctx != nil { + return errors.New("SetContext() called more than once") } + f.ctx = evalContextPool.Get().(*evalContext) f.ctx.ctx, f.ctx.cancel = context.WithCancel(ctx) f.ctx.interrupt = ctx.Done() f.ctx.interruptCheckFrequency = interruptCheckFrequency f.ctx.interruptCheckCount.Store(0) f.ctx.interrupted.Store(false) + return nil } // Close releases the resources held by the execution frame and returns it to the pool. func (f *ExecutionFrame) Close() { - if f.ctx != nil { + if f.parent == nil && f.ctx != nil { if f.ctx.cancel != nil { f.ctx.cancel() f.ctx.cancel = nil @@ -99,10 +115,18 @@ func (f *ExecutionFrame) Close() { f.ctx.interruptCheckCount.Store(0) f.ctx.interruptCheckFrequency = 0 evalContextPool.Put(f.ctx) - f.ctx = nil } + f.ctx = nil f.parent = nil - activationStack.release(f.Activation) + switch a := f.Activation.(type) { + case *hierarchicalActivation: + if child, ok := a.child.(*inputActivation); ok { + activationInput.release(child) + } + activationStack.release(a) + case *inputActivation: + activationInput.release(a) + } f.Activation = nil frameStack.Put(f) } @@ -121,6 +145,9 @@ func (f *ExecutionFrame) push(activation Activation) *ExecutionFrame { // pop returns the parent frame, releasing the current frame back to the pool. func (f *ExecutionFrame) pop() *ExecutionFrame { + if f.parent == nil { + return f + } parent := f.parent activationStack.release(f.Activation) f.Activation = nil @@ -193,12 +220,13 @@ func (pool *activationStackPool) create(parent, child Activation) Activation { h := pool.Get().(*hierarchicalActivation) h.child = child h.parent = parent + h.poolAllocated = true return h } func (pool *activationStackPool) release(activation Activation) { h, ok := activation.(*hierarchicalActivation) - if !ok { + if !ok || !h.poolAllocated { return } h.parent = nil @@ -206,7 +234,7 @@ func (pool *activationStackPool) release(activation Activation) { pool.Pool.Put(h) } -func newActivationPool() *activationStackPool { +func newActivationStackPool() *activationStackPool { return &activationStackPool{ Pool: sync.Pool{ New: func() any { @@ -216,6 +244,84 @@ func newActivationPool() *activationStackPool { } } +type inputActivation struct { + vars map[string]any + lazyVars map[string]any +} + +// ResolveName looks up the value of the input variable name, if found. +// +// Lazy bindings may be supplied within the map-based input in either of the following forms: +// - func() any +// - func() ref.Val +// +// The lazy binding will only be invoked once per evaluation. +// +// Values which are not represented as ref.Val types on input may be adapted to a ref.Val using +// the types.Adapter configured in the environment. +func (a *inputActivation) ResolveName(name string) (any, bool) { + v, found := a.vars[name] + if !found { + return nil, false + } + switch obj := v.(type) { + case func() ref.Val: + if resolved, found := a.lazyVars[name]; found { + return resolved, true + } + lazy := obj() + a.lazyVars[name] = lazy + return lazy, true + case func() any: + if resolved, found := a.lazyVars[name]; found { + return resolved, true + } + lazy := obj() + a.lazyVars[name] = lazy + return lazy, true + default: + return obj, true + } +} + +// Parent implements the Activation interface +func (a *inputActivation) Parent() Activation { + return nil +} + +func newActivationInputPool() *activationInputPool { + return &activationInputPool{ + Pool: sync.Pool{ + New: func() any { + return &inputActivation{ + lazyVars: make(map[string]any), + } + }, + }, + } +} + +type activationInputPool struct { + sync.Pool +} + +// create initializes a pooled Activation object with the map input. +func (p *activationInputPool) create(vars map[string]any) *inputActivation { + a := p.Pool.Get().(*inputActivation) + a.vars = vars + return a +} + +func (p *activationInputPool) release(value any) { + a := value.(*inputActivation) + for k := range a.lazyVars { + delete(a.lazyVars, k) + } + a.vars = nil + p.Pool.Put(a) +} + var ( - activationStack = newActivationPool() + activationStack = newActivationStackPool() + activationInput = newActivationInputPool() ) diff --git a/interpreter/frame_test.go b/interpreter/frame_test.go index 5c8b47b5..816827b5 100644 --- a/interpreter/frame_test.go +++ b/interpreter/frame_test.go @@ -17,6 +17,9 @@ package interpreter import ( "context" "testing" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" ) func TestFrameCheckInterrupt(t *testing.T) { @@ -83,7 +86,7 @@ func TestFrameCheckInterrupt(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - frame := NewExecutionFrame(EmptyActivation()) + frame := mustNewExecutionFrame(t, EmptyActivation()) defer frame.Close() var cleanup context.CancelFunc @@ -128,7 +131,7 @@ func TestFrameResolveName(t *testing.T) { { name: "resolve in base activation", setup: func() *ExecutionFrame { - return NewExecutionFrame(baseAct) + return mustNewExecutionFrame(t, baseAct) }, varName: "x", wantVal: 1, @@ -137,7 +140,7 @@ func TestFrameResolveName(t *testing.T) { { name: "missing in base activation", setup: func() *ExecutionFrame { - return NewExecutionFrame(baseAct) + return mustNewExecutionFrame(t, baseAct) }, varName: "y", wantVal: nil, @@ -146,7 +149,7 @@ func TestFrameResolveName(t *testing.T) { { name: "resolve in child activation", setup: func() *ExecutionFrame { - f := NewExecutionFrame(baseAct) + f := mustNewExecutionFrame(t, baseAct) return f.push(childAct) }, varName: "y", @@ -156,7 +159,7 @@ func TestFrameResolveName(t *testing.T) { { name: "resolve in parent activation from child", setup: func() *ExecutionFrame { - f := NewExecutionFrame(baseAct) + f := mustNewExecutionFrame(t, baseAct) return f.push(childAct) }, varName: "x", @@ -166,7 +169,7 @@ func TestFrameResolveName(t *testing.T) { { name: "missing in hierarchical activation", setup: func() *ExecutionFrame { - f := NewExecutionFrame(baseAct) + f := mustNewExecutionFrame(t, baseAct) return f.push(childAct) }, varName: "z", @@ -212,7 +215,7 @@ func TestFrameParent(t *testing.T) { { name: "base frame has no parent activation", setup: func() (*ExecutionFrame, func()) { - f := NewExecutionFrame(baseAct) + f := mustNewExecutionFrame(t, baseAct) return f, func() { f.Close() } }, want: nil, @@ -220,7 +223,7 @@ func TestFrameParent(t *testing.T) { { name: "pushed frame returns parent activation", setup: func() (*ExecutionFrame, func()) { - f := NewExecutionFrame(baseAct) + f := mustNewExecutionFrame(t, baseAct) child := f.push(childAct) return child, func() { child.pop() @@ -248,7 +251,7 @@ func TestFrameUnwrap(t *testing.T) { if err != nil { t.Fatalf("NewActivation(x) failed: %v", err) } - frame := NewExecutionFrame(baseAct) + frame := mustNewExecutionFrame(t, baseAct) defer frame.Close() if got := frame.Unwrap(); got != baseAct { @@ -285,21 +288,21 @@ func TestFrameAsPartialActivation(t *testing.T) { { name: "non-partial activation returns false", setup: func() *ExecutionFrame { - return NewExecutionFrame(baseAct) + return mustNewExecutionFrame(t, baseAct) }, wantFound: false, }, { name: "partial activation returns true", setup: func() *ExecutionFrame { - return NewExecutionFrame(partAct) + return mustNewExecutionFrame(t, partAct) }, wantFound: true, }, { name: "hierarchical activation wrapping partial activation returns true", setup: func() *ExecutionFrame { - f := NewExecutionFrame(partAct) + f := mustNewExecutionFrame(t, partAct) return f.push(baseAct) }, wantFound: true, @@ -338,7 +341,7 @@ func TestFramePushPop(t *testing.T) { t.Fatalf("NewActivation(y) failed: %v", err) } - frame := NewExecutionFrame(baseAct) + frame := mustNewExecutionFrame(t, baseAct) defer frame.Close() childFrame := frame.push(childAct) @@ -356,7 +359,7 @@ func TestFramePushPop(t *testing.T) { } func TestFrameClose(t *testing.T) { - frame := NewExecutionFrame(EmptyActivation()) + frame := mustNewExecutionFrame(t, EmptyActivation()) ctx := context.Background() frame.SetContext(ctx, 1) @@ -376,3 +379,199 @@ func TestFrameClose(t *testing.T) { t.Error("context not canceled after Close()") } } + +func TestFrameLifecycleAndPooling(t *testing.T) { + vars := map[string]any{"a": 1, "b": 2} + frame := mustNewExecutionFrame(t, vars) + val, found := frame.ResolveName("a") + if !found || val != 1 { + t.Errorf("ResolveName('a') got %v, %t; want 1, true", val, found) + } + + // Wrap in a hierarchical activation (e.g. simulating defaultVars setup in program.go) + parentAct, err := NewActivation(map[string]any{"c": 3}) + if err != nil { + t.Fatalf("NewActivation failed: %v", err) + } + frame.Activation = NewHierarchicalActivation(parentAct, frame.Activation) + + val, found = frame.ResolveName("c") + if !found || val != 3 { + t.Errorf("ResolveName('c') got %v, %t; want 3, true", val, found) + } + + val, found = frame.ResolveName("a") + if !found || val != 1 { + t.Errorf("ResolveName('a') got %v, %t; want 1, true", val, found) + } + + // Close the frame. This should release the pooled evalActivation under the hierarchical activation. + frame.Close() + + // Verify that we can obtain a clean frame from the pool. + newFrame := mustNewExecutionFrame(t, map[string]any{"x": 10}) + defer newFrame.Close() + val, found = newFrame.ResolveName("x") + if !found || val != 10 { + t.Errorf("ResolveName('x') got %v, %t; want 10, true", val, found) + } + val, found = newFrame.ResolveName("a") + if found { + t.Errorf("ResolveName('a') found on fresh frame: %v", val) + } +} + +func TestFrameSetContext(t *testing.T) { + ctx := context.Background() + f := mustNewExecutionFrame(t, EmptyActivation()) + defer f.Close() + + if err := f.SetContext(ctx, 1); err != nil { + t.Errorf("SetContext failed: %v", err) + } +} + +func TestFrameSetContextTwiceError(t *testing.T) { + ctx := context.Background() + f := mustNewExecutionFrame(t, EmptyActivation()) + defer f.Close() + + if err := f.SetContext(ctx, 1); err != nil { + t.Fatalf("SetContext failed first time: %v", err) + } + + if err := f.SetContext(ctx, 1); err == nil { + t.Error("expected SetContext to return an error when called twice, got nil") + } +} + +func TestFrameSetContextChildError(t *testing.T) { + ctx := context.Background() + f := mustNewExecutionFrame(t, EmptyActivation()) + defer f.Close() + + child := f.push(EmptyActivation()) + defer child.pop() + + if err := child.SetContext(ctx, 1); err == nil { + t.Error("expected SetContext on a child frame to return an error, got nil") + } +} + +func TestNewExecutionFrameInvalidInput(t *testing.T) { + f, err := NewExecutionFrame(123) + if err == nil { + f.Close() + t.Error("NewExecutionFrame with int input did not return error") + } +} + +func TestFramePopBaseFrame(t *testing.T) { + f := mustNewExecutionFrame(t, EmptyActivation()) + defer f.Close() + + popped := f.pop() + if popped != f { + t.Errorf("pop() on base frame got %v, want %v", popped, f) + } +} + +func TestLazyVariableResolution(t *testing.T) { + lazyRefValCalled := 0 + lazyAnyCalled := 0 + + vars := map[string]any{ + "lazy_ref": func() ref.Val { + lazyRefValCalled++ + return types.IntOne + }, + "lazy_any": func() any { + lazyAnyCalled++ + return 2 + }, + "normal": 3, + } + + frame := mustNewExecutionFrame(t, vars) + defer frame.Close() + + tests := []struct { + name string + lookupName string + wantFound bool + wantVal any + wantRefCall int + wantAnyCall int + }{ + { + name: "missing variable", + lookupName: "missing", + wantFound: false, + }, + { + name: "normal variable", + lookupName: "normal", + wantFound: true, + wantVal: 3, + }, + { + name: "lazy ref.Val first call", + lookupName: "lazy_ref", + wantFound: true, + wantVal: types.IntOne, + wantRefCall: 1, + }, + { + name: "lazy ref.Val second call cached", + lookupName: "lazy_ref", + wantFound: true, + wantVal: types.IntOne, + wantRefCall: 1, + }, + { + name: "lazy any first call", + lookupName: "lazy_any", + wantFound: true, + wantVal: 2, + wantRefCall: 1, + wantAnyCall: 1, + }, + { + name: "lazy any second call cached", + lookupName: "lazy_any", + wantFound: true, + wantVal: 2, + wantRefCall: 1, + wantAnyCall: 1, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + val, found := frame.ResolveName(tc.lookupName) + if found != tc.wantFound { + t.Fatalf("ResolveName(%q) found = %t, want = %t", tc.lookupName, found, tc.wantFound) + } + if found { + if val != tc.wantVal { + t.Errorf("ResolveName(%q) value = %v, want = %v", tc.lookupName, val, tc.wantVal) + } + } + if lazyRefValCalled != tc.wantRefCall { + t.Errorf("lazyRefValCalled = %d, want = %d", lazyRefValCalled, tc.wantRefCall) + } + if lazyAnyCalled != tc.wantAnyCall { + t.Errorf("lazyAnyCalled = %d, want = %d", lazyAnyCalled, tc.wantAnyCall) + } + }) + } +} + +func mustNewExecutionFrame(t testing.TB, input any) *ExecutionFrame { + t.Helper() + f, err := NewExecutionFrame(input) + if err != nil { + t.Fatalf("NewExecutionFrame() failed: %v", err) + } + return f +} diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index ff7080c5..50766039 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -26,20 +26,53 @@ import ( "github.com/google/cel-go/common/types/traits" ) -// Interpretable can accept a given Activation and produce a value along with -// an accompanying EvalState which can be used to inspect whether additional -// data might be necessary to complete the evaluation. +// Interpretable evaluates an Activation and produces a value. type Interpretable interface { // ID value corresponding to the expression node. ID() int64 - // Eval an Activation to produce an output. + // Eval evaluates an Activation and produces an output. Eval(activation Activation) ref.Val } +// InterpretableV2 evaluates an ExecutionFrame and produces a value. +// +// The ExecutionFrame should not be stored and should always be passed as the first +// argument to any function as it behaves like Golang's context.Context. +type InterpretableV2 interface { + Interpretable + + // Exec evaluates the expression within the given ExecutionFrame. + Exec(frame *ExecutionFrame) ref.Val +} + +// adaptToV2 adapts a V1 Interpretable implementation to the V2 interface. +// +// This adapter is used to bridge the legacy Interpretable interface to the +// modern InterpretableV2 interface, providing a shim that allows the use of +// both interfaces in the same system. +func adaptToV2(i Interpretable) InterpretableV2 { + switch v := i.(type) { + case InterpretableV2: + return v + default: + return &v1Adapter{Interpretable: v} + } +} + +// v1Adapter handles bridging a V1 Interpretable implementation to the V2 interface. +type v1Adapter struct { + Interpretable +} + +// Exec implements the InterpretableV2 interface method. +func (a *v1Adapter) Exec(f *ExecutionFrame) ref.Val { + return a.Eval(f) +} + // InterpretableConst interface for tracking whether the Interpretable is a constant value. type InterpretableConst interface { - Interpretable + InterpretableV2 // Value returns the constant value of the instruction. Value() ref.Val @@ -47,7 +80,7 @@ type InterpretableConst interface { // InterpretableAttribute interface for tracking whether the Interpretable is an attribute. type InterpretableAttribute interface { - Interpretable + InterpretableV2 // Attr returns the Attribute value. Attr() Attribute @@ -81,7 +114,7 @@ type InterpretableAttribute interface { // InterpretableCall interface for inspecting Interpretable instructions related to function calls. type InterpretableCall interface { - Interpretable + InterpretableV2 // Function returns the function name as it appears in text or mangled operator name as it // appears in the operators.go file. @@ -94,16 +127,16 @@ type InterpretableCall interface { // Args returns the normalized arguments to the function overload. // For receiver-style functions, the receiver target is arg 0. - Args() []Interpretable + Args() []InterpretableV2 } // InterpretableConstructor interface for inspecting Interpretable instructions that initialize a list, map // or struct. type InterpretableConstructor interface { - Interpretable + InterpretableV2 // InitVals returns all the list elements, map key and values or struct field values. - InitVals() []Interpretable + InitVals() []InterpretableV2 // Type returns the type constructed. Type() ref.Type @@ -112,44 +145,81 @@ type InterpretableConstructor interface { // ObservableInterpretable is an Interpretable which supports stateful observation, such as tracing // or cost-tracking. type ObservableInterpretable struct { - Interpretable + InterpretableV2 observers []StatefulObserver } // ID implements the Interpretable method to get the expression id associated with the step. func (oi *ObservableInterpretable) ID() int64 { - return oi.Interpretable.ID() + return oi.InterpretableV2.ID() +} + +// Exec implements the InterpretableV2 interface method. +func (oi *ObservableInterpretable) Exec(frame *ExecutionFrame) ref.Val { + return oi.ObserveExec(frame, func(any) {}) } // Eval proxies to the ObserveEval method while invoking a no-op callback to report the observations. func (oi *ObservableInterpretable) Eval(vars Activation) ref.Val { - return oi.ObserveEval(vars, func(any) {}) + return oi.ObserveExec(AsFrame(vars), func(any) {}) } -// ObserveEval evaluates an interpretable and performs per-evaluation state-tracking. +// ObserveExec evaluates an interpretable and performs per-evaluation state-tracking. // // This method is concurrency safe and the expectation is that the observer function will use // a switch statement to determine the type of the state which has been reported back from the call. -func (oi *ObservableInterpretable) ObserveEval(vars Activation, observer func(any)) ref.Val { - var err error +func (oi *ObservableInterpretable) ObserveExec(frame *ExecutionFrame, observer func(any)) ref.Val { // Initialize the state needed for the observers to function. for _, obs := range oi.observers { - vars, err = obs.InitState(vars) + state, err := obs.InitState(frame) if err != nil { return types.WrapErr(err) } // Provide an initial reference to the state to ensure state is available // even in cases of interrupting errors generated during evaluation. - observer(obs.GetState(vars)) + observer(state) } - result := oi.Interpretable.Eval(vars) + result := oi.InterpretableV2.Exec(frame) // Get the state which needs to be reported back as having been observed. for _, obs := range oi.observers { - observer(obs.GetState(vars)) + observer(obs.GetState(frame)) } return result } +// AsFrame promotes an Activation to an ExecutionFrame. +func AsFrame(a Activation) *ExecutionFrame { + if f, ok := a.(*ExecutionFrame); ok { + return f + } + frame := &ExecutionFrame{Activation: a} + // Walk the activation hierarchy to find a parent ExecutionFrame and inherit + // its shared context. + if parent := findFrame(a); parent != nil { + frame.ctx = parent.ctx + } + return frame +} + +// findFrame walks the activation hierarchy via Unwrap and Parent to locate an +// existing ExecutionFrame, if one exists. +func findFrame(a Activation) *ExecutionFrame { + if wrapper, ok := a.(activationWrapper); ok { + unwrapped := wrapper.Unwrap() + if f, ok := unwrapped.(*ExecutionFrame); ok { + return f + } + return findFrame(unwrapped) + } + if p := a.Parent(); p != nil { + if f, ok := p.(*ExecutionFrame); ok { + return f + } + return findFrame(p) + } + return nil +} + // Core Interpretable implementations used during the program planning phase. type evalTestOnly struct { @@ -162,9 +232,9 @@ func (test *evalTestOnly) ID() int64 { return test.id } -// Eval implements the Interpretable interface method. -func (test *evalTestOnly) Eval(ctx Activation) ref.Val { - val, err := test.Resolve(ctx) +// Exec implements the InterpretableV2 interface method. +func (test *evalTestOnly) Exec(frame *ExecutionFrame) ref.Val { + val, err := test.Resolve(frame) // Return an error if the resolve step fails if err != nil { return types.LabelErrNode(test.id, types.WrapErr(err)) @@ -175,6 +245,11 @@ func (test *evalTestOnly) Eval(ctx Activation) ref.Val { return test.Adapter().NativeToValue(val) } +// Eval implements the Interpretable interface method. +func (test *evalTestOnly) Eval(ctx Activation) ref.Val { + return test.Exec(AsFrame(ctx)) +} + // AddQualifier appends a qualifier that will always and only perform a presence test. func (test *evalTestOnly) AddQualifier(q Qualifier) (Attribute, error) { cq, ok := q.(ConstantQualifier) @@ -230,6 +305,11 @@ func (cons *evalConst) ID() int64 { return cons.id } +// Exec implements the InterpretableV2 interface method. +func (cons *evalConst) Exec(frame *ExecutionFrame) ref.Val { + return cons.val +} + // Eval implements the Interpretable interface method. func (cons *evalConst) Eval(ctx Activation) ref.Val { return cons.val @@ -242,7 +322,7 @@ func (cons *evalConst) Value() ref.Val { type evalOr struct { id int64 - terms []Interpretable + terms []InterpretableV2 } // ID implements the Interpretable interface method. @@ -250,12 +330,12 @@ func (or *evalOr) ID() int64 { return or.id } -// Eval implements the Interpretable interface method. -func (or *evalOr) Eval(ctx Activation) ref.Val { +// Exec implements the InterpretableV2 interface method. +func (or *evalOr) Exec(frame *ExecutionFrame) ref.Val { var err ref.Val = nil var unk *types.Unknown for _, term := range or.terms { - val := term.Eval(ctx) + val := term.Exec(frame) boolVal, ok := val.(types.Bool) // short-circuit on true. if ok && boolVal == types.True { @@ -283,9 +363,14 @@ func (or *evalOr) Eval(ctx Activation) ref.Val { return types.False } +// Eval implements the Interpretable interface method. +func (or *evalOr) Eval(ctx Activation) ref.Val { + return or.Exec(AsFrame(ctx)) +} + type evalAnd struct { id int64 - terms []Interpretable + terms []InterpretableV2 } // ID implements the Interpretable interface method. @@ -293,12 +378,12 @@ func (and *evalAnd) ID() int64 { return and.id } -// Eval implements the Interpretable interface method. -func (and *evalAnd) Eval(ctx Activation) ref.Val { +// Exec implements the InterpretableV2 interface method. +func (and *evalAnd) Exec(frame *ExecutionFrame) ref.Val { var err ref.Val = nil var unk *types.Unknown for _, term := range and.terms { - val := term.Eval(ctx) + val := term.Exec(frame) boolVal, ok := val.(types.Bool) // short-circuit on false. if ok && boolVal == types.False { @@ -326,10 +411,15 @@ func (and *evalAnd) Eval(ctx Activation) ref.Val { return types.True } +// Eval implements the Interpretable interface method. +func (and *evalAnd) Eval(ctx Activation) ref.Val { + return and.Exec(AsFrame(ctx)) +} + type evalEq struct { id int64 - lhs Interpretable - rhs Interpretable + lhs InterpretableV2 + rhs InterpretableV2 } // ID implements the Interpretable interface method. @@ -337,10 +427,10 @@ func (eq *evalEq) ID() int64 { return eq.id } -// Eval implements the Interpretable interface method. -func (eq *evalEq) Eval(ctx Activation) ref.Val { - lVal := eq.lhs.Eval(ctx) - rVal := eq.rhs.Eval(ctx) +// Exec implements the InterpretableV2 interface method. +func (eq *evalEq) Exec(frame *ExecutionFrame) ref.Val { + lVal := eq.lhs.Exec(frame) + rVal := eq.rhs.Exec(frame) if types.IsUnknownOrError(lVal) { return lVal } @@ -350,6 +440,11 @@ func (eq *evalEq) Eval(ctx Activation) ref.Val { return types.Equal(lVal, rVal) } +// Eval implements the Interpretable interface method. +func (eq *evalEq) Eval(ctx Activation) ref.Val { + return eq.Exec(AsFrame(ctx)) +} + // Function implements the InterpretableCall interface method. func (*evalEq) Function() string { return operators.Equals @@ -361,14 +456,14 @@ func (*evalEq) OverloadID() string { } // Args implements the InterpretableCall interface method. -func (eq *evalEq) Args() []Interpretable { - return []Interpretable{eq.lhs, eq.rhs} +func (eq *evalEq) Args() []InterpretableV2 { + return []InterpretableV2{eq.lhs, eq.rhs} } type evalNe struct { id int64 - lhs Interpretable - rhs Interpretable + lhs InterpretableV2 + rhs InterpretableV2 } // ID implements the Interpretable interface method. @@ -376,10 +471,10 @@ func (ne *evalNe) ID() int64 { return ne.id } -// Eval implements the Interpretable interface method. -func (ne *evalNe) Eval(ctx Activation) ref.Val { - lVal := ne.lhs.Eval(ctx) - rVal := ne.rhs.Eval(ctx) +// Exec implements the InterpretableV2 interface method. +func (ne *evalNe) Exec(frame *ExecutionFrame) ref.Val { + lVal := ne.lhs.Exec(frame) + rVal := ne.rhs.Exec(frame) if types.IsUnknownOrError(lVal) { return lVal } @@ -389,6 +484,11 @@ func (ne *evalNe) Eval(ctx Activation) ref.Val { return types.Bool(types.Equal(lVal, rVal) != types.True) } +// Eval implements the Interpretable interface method. +func (ne *evalNe) Eval(ctx Activation) ref.Val { + return ne.Exec(AsFrame(ctx)) +} + // Function implements the InterpretableCall interface method. func (*evalNe) Function() string { return operators.NotEquals @@ -400,8 +500,8 @@ func (*evalNe) OverloadID() string { } // Args implements the InterpretableCall interface method. -func (ne *evalNe) Args() []Interpretable { - return []Interpretable{ne.lhs, ne.rhs} +func (ne *evalNe) Args() []InterpretableV2 { + return []InterpretableV2{ne.lhs, ne.rhs} } type evalZeroArity struct { @@ -416,9 +516,14 @@ func (zero *evalZeroArity) ID() int64 { return zero.id } +// Exec implements the InterpretableV2 interface method. +func (zero *evalZeroArity) Exec(frame *ExecutionFrame) ref.Val { + return types.LabelErrNode(zero.id, zero.impl()) +} + // Eval implements the Interpretable interface method. func (zero *evalZeroArity) Eval(ctx Activation) ref.Val { - return types.LabelErrNode(zero.id, zero.impl()) + return zero.Exec(AsFrame(ctx)) } // Function implements the InterpretableCall interface method. @@ -432,15 +537,15 @@ func (zero *evalZeroArity) OverloadID() string { } // Args returns the argument to the unary function. -func (zero *evalZeroArity) Args() []Interpretable { - return []Interpretable{} +func (zero *evalZeroArity) Args() []InterpretableV2 { + return []InterpretableV2{} } type evalUnary struct { id int64 function string overload string - arg Interpretable + arg InterpretableV2 trait int impl functions.UnaryOp nonStrict bool @@ -451,9 +556,9 @@ func (un *evalUnary) ID() int64 { return un.id } -// Eval implements the Interpretable interface method. -func (un *evalUnary) Eval(ctx Activation) ref.Val { - argVal := un.arg.Eval(ctx) +// Exec implements the InterpretableV2 interface method. +func (un *evalUnary) Exec(frame *ExecutionFrame) ref.Val { + argVal := un.arg.Exec(frame) // Early return if the argument to the function is unknown or error. strict := !un.nonStrict if strict && types.IsUnknownOrError(argVal) { @@ -472,6 +577,11 @@ func (un *evalUnary) Eval(ctx Activation) ref.Val { return types.NewErrWithNodeID(un.id, "no such overload: %s", un.function) } +// Eval implements the Interpretable interface method. +func (un *evalUnary) Eval(ctx Activation) ref.Val { + return un.Exec(AsFrame(ctx)) +} + // Function implements the InterpretableCall interface method. func (un *evalUnary) Function() string { return un.function @@ -483,16 +593,16 @@ func (un *evalUnary) OverloadID() string { } // Args returns the argument to the unary function. -func (un *evalUnary) Args() []Interpretable { - return []Interpretable{un.arg} +func (un *evalUnary) Args() []InterpretableV2 { + return []InterpretableV2{un.arg} } type evalBinary struct { id int64 function string overload string - lhs Interpretable - rhs Interpretable + lhs InterpretableV2 + rhs InterpretableV2 trait int impl functions.BinaryOp nonStrict bool @@ -503,10 +613,10 @@ func (bin *evalBinary) ID() int64 { return bin.id } -// Eval implements the Interpretable interface method. -func (bin *evalBinary) Eval(ctx Activation) ref.Val { - lVal := bin.lhs.Eval(ctx) - rVal := bin.rhs.Eval(ctx) +// Exec implements the InterpretableV2 interface method. +func (bin *evalBinary) Exec(frame *ExecutionFrame) ref.Val { + lVal := bin.lhs.Exec(frame) + rVal := bin.rhs.Exec(frame) // Early return if any argument to the function is unknown or error. strict := !bin.nonStrict if strict { @@ -530,6 +640,11 @@ func (bin *evalBinary) Eval(ctx Activation) ref.Val { return types.NewErrWithNodeID(bin.id, "no such overload: %s", bin.function) } +// Eval implements the Interpretable interface method. +func (bin *evalBinary) Eval(ctx Activation) ref.Val { + return bin.Exec(AsFrame(ctx)) +} + // Function implements the InterpretableCall interface method. func (bin *evalBinary) Function() string { return bin.function @@ -541,22 +656,22 @@ func (bin *evalBinary) OverloadID() string { } // Args returns the argument to the unary function. -func (bin *evalBinary) Args() []Interpretable { - return []Interpretable{bin.lhs, bin.rhs} +func (bin *evalBinary) Args() []InterpretableV2 { + return []InterpretableV2{bin.lhs, bin.rhs} } type evalVarArgs struct { id int64 function string overload string - args []Interpretable + args []InterpretableV2 trait int impl functions.FunctionOp nonStrict bool } // NewCall creates a new call Interpretable. -func NewCall(id int64, function, overload string, args []Interpretable, impl functions.FunctionOp) InterpretableCall { +func NewCall(id int64, function, overload string, args []InterpretableV2, impl functions.FunctionOp) InterpretableCall { return &evalVarArgs{ id: id, function: function, @@ -571,13 +686,13 @@ func (fn *evalVarArgs) ID() int64 { return fn.id } -// Eval implements the Interpretable interface method. -func (fn *evalVarArgs) Eval(ctx Activation) ref.Val { +// Exec implements the InterpretableV2 interface method. +func (fn *evalVarArgs) Exec(frame *ExecutionFrame) ref.Val { argVals := make([]ref.Val, len(fn.args)) // Early return if any argument to the function is unknown or error. strict := !fn.nonStrict for i, arg := range fn.args { - argVals[i] = arg.Eval(ctx) + argVals[i] = arg.Exec(frame) if strict && types.IsUnknownOrError(argVals[i]) { return argVals[i] } @@ -596,6 +711,11 @@ func (fn *evalVarArgs) Eval(ctx Activation) ref.Val { return types.NewErrWithNodeID(fn.id, "no such overload: %s %d", fn.function, fn.id) } +// Eval implements the Interpretable interface method. +func (fn *evalVarArgs) Eval(ctx Activation) ref.Val { + return fn.Exec(AsFrame(ctx)) +} + // Function implements the InterpretableCall interface method. func (fn *evalVarArgs) Function() string { return fn.function @@ -607,13 +727,13 @@ func (fn *evalVarArgs) OverloadID() string { } // Args returns the argument to the unary function. -func (fn *evalVarArgs) Args() []Interpretable { +func (fn *evalVarArgs) Args() []InterpretableV2 { return fn.args } type evalList struct { id int64 - elems []Interpretable + elems []InterpretableV2 optionals []bool hasOptionals bool adapter types.Adapter @@ -624,12 +744,12 @@ func (l *evalList) ID() int64 { return l.id } -// Eval implements the Interpretable interface method. -func (l *evalList) Eval(ctx Activation) ref.Val { +// Exec implements the InterpretableV2 interface method. +func (l *evalList) Exec(frame *ExecutionFrame) ref.Val { elemVals := make([]ref.Val, 0, len(l.elems)) // If any argument is unknown or error early terminate. for i, elem := range l.elems { - elemVal := elem.Eval(ctx) + elemVal := elem.Exec(frame) if types.IsUnknownOrError(elemVal) { return elemVal } @@ -645,10 +765,15 @@ func (l *evalList) Eval(ctx Activation) ref.Val { } elemVals = append(elemVals, elemVal) } - return l.adapter.NativeToValue(elemVals) + return types.NewRefValList(l.adapter, elemVals) } -func (l *evalList) InitVals() []Interpretable { +// Eval implements the Interpretable interface method. +func (l *evalList) Eval(ctx Activation) ref.Val { + return l.Exec(AsFrame(ctx)) +} + +func (l *evalList) InitVals() []InterpretableV2 { return l.elems } @@ -658,8 +783,8 @@ func (l *evalList) Type() ref.Type { type evalMap struct { id int64 - keys []Interpretable - vals []Interpretable + keys []InterpretableV2 + vals []InterpretableV2 optionals []bool hasOptionals bool adapter types.Adapter @@ -670,16 +795,16 @@ func (m *evalMap) ID() int64 { return m.id } -// Eval implements the Interpretable interface method. -func (m *evalMap) Eval(ctx Activation) ref.Val { - entries := make(map[ref.Val]ref.Val) +// Exec implements the InterpretableV2 interface method. +func (m *evalMap) Exec(frame *ExecutionFrame) ref.Val { + entries := make(map[ref.Val]ref.Val, len(m.keys)) // If any argument is unknown or error early terminate. for i, key := range m.keys { - keyVal := key.Eval(ctx) + keyVal := key.Exec(frame) if types.IsUnknownOrError(keyVal) { return keyVal } - valVal := m.vals[i].Eval(ctx) + valVal := m.vals[i].Exec(frame) if types.IsUnknownOrError(valVal) { return valVal } @@ -696,14 +821,19 @@ func (m *evalMap) Eval(ctx Activation) ref.Val { } entries[keyVal] = valVal } - return m.adapter.NativeToValue(entries) + return types.NewRefValMap(m.adapter, entries) +} + +// Eval implements the Interpretable interface method. +func (m *evalMap) Eval(ctx Activation) ref.Val { + return m.Exec(AsFrame(ctx)) } -func (m *evalMap) InitVals() []Interpretable { +func (m *evalMap) InitVals() []InterpretableV2 { if len(m.keys) != len(m.vals) { return nil } - result := make([]Interpretable, len(m.keys)+len(m.vals)) + result := make([]InterpretableV2, len(m.keys)+len(m.vals)) idx := 0 for i, k := range m.keys { v := m.vals[i] @@ -723,7 +853,7 @@ type evalObj struct { id int64 typeName string fields []string - vals []Interpretable + vals []InterpretableV2 optionals []bool hasOptionals bool provider types.Provider @@ -734,12 +864,12 @@ func (o *evalObj) ID() int64 { return o.id } -// Eval implements the Interpretable interface method. -func (o *evalObj) Eval(ctx Activation) ref.Val { - fieldVals := make(map[string]ref.Val) +// Exec implements the InterpretableV2 interface method. +func (o *evalObj) Exec(frame *ExecutionFrame) ref.Val { + fieldVals := make(map[string]ref.Val, len(o.fields)) // If any argument is unknown or error early terminate. for i, field := range o.fields { - val := o.vals[i].Eval(ctx) + val := o.vals[i].Exec(frame) if types.IsUnknownOrError(val) { return val } @@ -759,8 +889,13 @@ func (o *evalObj) Eval(ctx Activation) ref.Val { return types.LabelErrNode(o.id, o.provider.NewValue(o.typeName, fieldVals)) } +// Eval implements the Interpretable interface method. +func (o *evalObj) Eval(ctx Activation) ref.Val { + return o.Exec(AsFrame(ctx)) +} + // InitVals implements the InterpretableConstructor interface method. -func (o *evalObj) InitVals() []Interpretable { +func (o *evalObj) InitVals() []InterpretableV2 { return o.vals } @@ -774,11 +909,11 @@ type evalFold struct { accuVar string iterVar string iterVar2 string - iterRange Interpretable - accu Interpretable - cond Interpretable - step Interpretable - result Interpretable + iterRange InterpretableV2 + accu InterpretableV2 + cond InterpretableV2 + step InterpretableV2 + result InterpretableV2 adapter types.Adapter // note an exhaustive fold will ensure that all branches are evaluated @@ -793,13 +928,13 @@ func (fold *evalFold) ID() int64 { return fold.id } -// Eval implements the Interpretable interface method. -func (fold *evalFold) Eval(ctx Activation) ref.Val { +// Exec implements the InterpretableV2 interface method. +func (fold *evalFold) Exec(frame *ExecutionFrame) ref.Val { // Initialize the folder interface - f := newFolder(fold, ctx) + f := newFolder(fold, frame) defer releaseFolder(f) - foldRange := fold.iterRange.Eval(ctx) + foldRange := fold.iterRange.Exec(frame) if types.IsUnknownOrError(foldRange) { return foldRange } @@ -824,14 +959,19 @@ func (fold *evalFold) Eval(ctx Activation) ref.Val { return f.foldIterable(iterable) } +// Eval implements the Interpretable interface method. +func (fold *evalFold) Eval(ctx Activation) ref.Val { + return fold.Exec(AsFrame(ctx)) +} + // Optional Interpretable implementations that specialize, subsume, or extend the core evaluation // plan via decorators. // evalSetMembership is an Interpretable implementation which tests whether an input value // exists within the set of map keys used to model a set. type evalSetMembership struct { - inst Interpretable - arg Interpretable + inst InterpretableV2 + arg InterpretableV2 valueSet map[ref.Val]ref.Val } @@ -840,9 +980,9 @@ func (e *evalSetMembership) ID() int64 { return e.inst.ID() } -// Eval implements the Interpretable interface method. -func (e *evalSetMembership) Eval(ctx Activation) ref.Val { - val := e.arg.Eval(ctx) +// Exec implements the InterpretableV2 interface method. +func (e *evalSetMembership) Exec(frame *ExecutionFrame) ref.Val { + val := e.arg.Exec(frame) if types.IsUnknownOrError(val) { return val } @@ -852,18 +992,28 @@ func (e *evalSetMembership) Eval(ctx Activation) ref.Val { return types.False } +// Eval implements the Interpretable interface method. +func (e *evalSetMembership) Eval(ctx Activation) ref.Val { + return e.Exec(AsFrame(ctx)) +} + // evalWatch is an Interpretable implementation that wraps the execution of a given // expression so that it may observe the computed value and send it to an observer. type evalWatch struct { - Interpretable + InterpretableV2 observer EvalObserver } +// Exec implements the InterpretableV2 interface method. +func (e *evalWatch) Exec(frame *ExecutionFrame) ref.Val { + val := e.InterpretableV2.Exec(frame) + e.observer(frame, e.ID(), e.InterpretableV2, val) + return val +} + // Eval implements the Interpretable interface method. func (e *evalWatch) Eval(vars Activation) ref.Val { - val := e.Interpretable.Eval(vars) - e.observer(vars, e.ID(), e.Interpretable, val) - return val + return e.Exec(AsFrame(vars)) } // evalWatchAttr describes a watcher of an InterpretableAttribute Interpretable. @@ -918,11 +1068,16 @@ func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) { return e, err } +// Exec implements the InterpretableV2 interface method. +func (e *evalWatchAttr) Exec(frame *ExecutionFrame) ref.Val { + val := e.InterpretableAttribute.Exec(frame) + e.observer(frame, e.ID(), e.InterpretableAttribute, val) + return val +} + // Eval implements the Interpretable interface method. func (e *evalWatchAttr) Eval(vars Activation) ref.Val { - val := e.InterpretableAttribute.Eval(vars) - e.observer(vars, e.ID(), e.InterpretableAttribute, val) - return val + return e.Exec(AsFrame(vars)) } // evalWatchConstQual observes the qualification of an object using a constant boolean, int, @@ -1049,17 +1204,22 @@ type evalWatchConst struct { observer EvalObserver } -// Eval implements the Interpretable interface method. -func (e *evalWatchConst) Eval(vars Activation) ref.Val { +// Exec implements the InterpretableV2 interface method. +func (e *evalWatchConst) Exec(frame *ExecutionFrame) ref.Val { val := e.Value() - e.observer(vars, e.ID(), e.InterpretableConst, val) + e.observer(frame, e.ID(), e.InterpretableConst, val) return val } +// Eval implements the Interpretable interface method. +func (e *evalWatchConst) Eval(vars Activation) ref.Val { + return e.Exec(AsFrame(vars)) +} + // evalExhaustiveOr is just like evalOr, but does not short-circuit argument evaluation. type evalExhaustiveOr struct { id int64 - terms []Interpretable + terms []InterpretableV2 } // ID implements the Interpretable interface method. @@ -1067,13 +1227,13 @@ func (or *evalExhaustiveOr) ID() int64 { return or.id } -// Eval implements the Interpretable interface method. -func (or *evalExhaustiveOr) Eval(ctx Activation) ref.Val { +// Exec implements the InterpretableV2 interface method. +func (or *evalExhaustiveOr) Exec(frame *ExecutionFrame) ref.Val { var err ref.Val = nil var unk *types.Unknown isTrue := false for _, term := range or.terms { - val := term.Eval(ctx) + val := term.Exec(frame) boolVal, ok := val.(types.Bool) // flag the result as true if ok && boolVal == types.True { @@ -1103,10 +1263,15 @@ func (or *evalExhaustiveOr) Eval(ctx Activation) ref.Val { return types.False } +// Eval implements the Interpretable interface method. +func (or *evalExhaustiveOr) Eval(ctx Activation) ref.Val { + return or.Exec(AsFrame(ctx)) +} + // evalExhaustiveAnd is just like evalAnd, but does not short-circuit argument evaluation. type evalExhaustiveAnd struct { id int64 - terms []Interpretable + terms []InterpretableV2 } // ID implements the Interpretable interface method. @@ -1114,13 +1279,13 @@ func (and *evalExhaustiveAnd) ID() int64 { return and.id } -// Eval implements the Interpretable interface method. -func (and *evalExhaustiveAnd) Eval(ctx Activation) ref.Val { +// Exec implements the InterpretableV2 interface method. +func (and *evalExhaustiveAnd) Exec(frame *ExecutionFrame) ref.Val { var err ref.Val = nil var unk *types.Unknown isFalse := false for _, term := range and.terms { - val := term.Eval(ctx) + val := term.Exec(frame) boolVal, ok := val.(types.Bool) // short-circuit on false. if ok && boolVal == types.False { @@ -1150,6 +1315,11 @@ func (and *evalExhaustiveAnd) Eval(ctx Activation) ref.Val { return types.True } +// Eval implements the Interpretable interface method. +func (and *evalExhaustiveAnd) Eval(ctx Activation) ref.Val { + return and.Exec(AsFrame(ctx)) +} + // evalExhaustiveConditional is like evalConditional, but does not short-circuit argument // evaluation. type evalExhaustiveConditional struct { @@ -1163,11 +1333,11 @@ func (cond *evalExhaustiveConditional) ID() int64 { return cond.id } -// Eval implements the Interpretable interface method. -func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val { - cVal := cond.attr.expr.Eval(ctx) - tVal, tErr := cond.attr.truthy.Resolve(ctx) - fVal, fErr := cond.attr.falsy.Resolve(ctx) +// Exec implements the InterpretableV2 interface method. +func (cond *evalExhaustiveConditional) Exec(frame *ExecutionFrame) ref.Val { + cVal := cond.attr.expr.Exec(frame) + tVal, tErr := cond.attr.truthy.Resolve(frame) + fVal, fErr := cond.attr.falsy.Resolve(frame) cBool, ok := cVal.(types.Bool) if !ok { return types.ValOrErr(cVal, "no such overload") @@ -1184,6 +1354,11 @@ func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val { return cond.adapter.NativeToValue(fVal) } +// Eval implements the Interpretable interface method. +func (cond *evalExhaustiveConditional) Eval(ctx Activation) ref.Val { + return cond.Exec(AsFrame(ctx)) +} + // evalAttr evaluates an Attribute value. type evalAttr struct { adapter types.Adapter @@ -1215,15 +1390,20 @@ func (a *evalAttr) Adapter() types.Adapter { return a.adapter } -// Eval implements the Interpretable interface method. -func (a *evalAttr) Eval(ctx Activation) ref.Val { - v, err := a.attr.Resolve(ctx) +// Exec implements the InterpretableV2 interface method. +func (a *evalAttr) Exec(frame *ExecutionFrame) ref.Val { + v, err := a.attr.Resolve(frame) if err != nil { return types.LabelErrNode(a.ID(), types.WrapErr(err)) } return a.adapter.NativeToValue(v) } +// Eval implements the Interpretable interface method. +func (a *evalAttr) Eval(ctx Activation) ref.Val { + return a.Exec(AsFrame(ctx)) +} + // Qualify proxies to the Attribute's Qualify method. func (a *evalAttr) Qualify(vars Activation, obj any) (any, error) { return a.attr.Qualify(vars, obj) @@ -1249,7 +1429,7 @@ type evalWatchConstructor struct { } // InitVals implements the InterpretableConstructor InitVals function. -func (c *evalWatchConstructor) InitVals() []Interpretable { +func (c *evalWatchConstructor) InitVals() []InterpretableV2 { return c.constructor.InitVals() } @@ -1263,11 +1443,16 @@ func (c *evalWatchConstructor) ID() int64 { return c.constructor.ID() } +// Exec implements the InterpretableV2 interface method. +func (c *evalWatchConstructor) Exec(frame *ExecutionFrame) ref.Val { + val := c.constructor.Exec(frame) + c.observer(frame, c.ID(), c.constructor, val) + return val +} + // Eval implements the Interpretable Eval function. func (c *evalWatchConstructor) Eval(vars Activation) ref.Val { - val := c.constructor.Eval(vars) - c.observer(vars, c.ID(), c.constructor, val) - return val + return c.Exec(AsFrame(vars)) } func invalidOptionalEntryInit(field any, value ref.Val) ref.Val { @@ -1279,10 +1464,10 @@ func invalidOptionalElementInit(value ref.Val) ref.Val { } // newFolder creates or initializes a pooled folder instance. -func newFolder(eval *evalFold, ctx Activation) *folder { +func newFolder(eval *evalFold, frame *ExecutionFrame) *folder { f := folderPool.Get().(*folder) f.evalFold = eval - f.activation = ctx + f.frame = frame.push(f) return f } @@ -1303,7 +1488,7 @@ func releaseFolder(f *folder) { // cel.bind or cel.@block. type folder struct { *evalFold - activation Activation + frame *ExecutionFrame // fold state objects. accuVal ref.Val @@ -1322,16 +1507,16 @@ func (f *folder) foldIterable(iterable traits.Iterable) ref.Val { for it.HasNext() == types.True { f.iterVar1Val = it.Next() - cond := f.cond.Eval(f) + cond := f.cond.Exec(f.frame) condBool, ok := cond.(types.Bool) if f.interrupted || (!f.exhaustive && ok && condBool != types.True) { return f.evalResult() } // Update the accumulation value and check for eval interuption. - f.accuVal = f.step.Eval(f) + f.accuVal = f.step.Exec(f.frame) f.initialized = true - if f.interruptable && checkInterrupt(f.activation) { + if f.interruptable && f.frame.CheckInterrupt() { f.interrupted = true return f.evalResult() } @@ -1348,16 +1533,16 @@ func (f *folder) FoldEntry(key, val any) bool { // Terminate evaluation if evaluation is interrupted or the condition is not true and exhaustive // eval is not enabled. - cond := f.cond.Eval(f) + cond := f.cond.Exec(f.frame) condBool, ok := cond.(types.Bool) if f.interrupted || (!f.exhaustive && ok && condBool != types.True) { return false } // Update the accumulation value and check for eval interuption. - f.accuVal = f.step.Eval(f) + f.accuVal = f.step.Exec(f.frame) f.initialized = true - if f.interruptable && checkInterrupt(f.activation) { + if f.interruptable && f.frame.CheckInterrupt() { f.interrupted = true return false } @@ -1371,7 +1556,7 @@ func (f *folder) ResolveName(name string) (any, bool) { if name == f.accuVar { if !f.initialized { f.initialized = true - initVal := f.accu.Eval(f.activation) + initVal := f.accu.Exec(f.frame.parent) if !f.exhaustive { if l, isList := initVal.(traits.Lister); isList && l.Size() == types.IntZero { initVal = types.NewMutableList(f.adapter) @@ -1396,23 +1581,23 @@ func (f *folder) ResolveName(name string) (any, bool) { return f.iterVar2Val, true } } - return f.activation.ResolveName(name) + return f.frame.parent.ResolveName(name) } // Parent returns the activation embedded into the folder. func (f *folder) Parent() Activation { - return f.activation + return f.frame.parent } // Unwrap returns the parent activation, thus omitting access to local state func (f *folder) Unwrap() Activation { - return f.activation + return f.frame.parent } // UnknownAttributePatterns implements the PartialActivation interface returning the unknown patterns // if they were provided to the input activation, or an empty set if the proxied activation is not partial. func (f *folder) UnknownAttributePatterns() []*AttributePattern { - if pv, ok := f.activation.(partialActivationConverter); ok { + if pv, ok := f.frame.parent.Activation.(partialActivationConverter); ok { if partial, isPartial := pv.AsPartialActivation(); isPartial { return partial.UnknownAttributePatterns() } @@ -1421,7 +1606,7 @@ func (f *folder) UnknownAttributePatterns() []*AttributePattern { } func (f *folder) AsPartialActivation() (PartialActivation, bool) { - if pv, ok := f.activation.(partialActivationConverter); ok { + if pv, ok := f.frame.parent.Activation.(partialActivationConverter); ok { if _, isPartial := pv.AsPartialActivation(); isPartial { return f, true } @@ -1435,7 +1620,7 @@ func (f *folder) evalResult() ref.Val { if f.interrupted { return types.WrapErr(InterruptError{}) } - res := f.result.Eval(f) + res := f.result.Exec(f.frame) // Convert a mutable list or map to an immutable one if the comprehension has generated a list or // map as a result. if !types.IsUnknownOrError(res) && f.mutableValue { @@ -1452,7 +1637,8 @@ func (f *folder) evalResult() ref.Val { // reset clears any state associated with folder evaluation. func (f *folder) reset() { f.evalFold = nil - f.activation = nil + f.frame.pop() + f.frame = nil f.accuVal = nil f.iterVar1Val = nil f.iterVar2Val = nil @@ -1463,11 +1649,6 @@ func (f *folder) reset() { f.computeResult = false } -func checkInterrupt(a Activation) bool { - stop, found := a.ResolveName("#interrupted") - return found && stop == true -} - // InterruptError is a specialized error type used to signal that program evaluation should check // whether a context cancellation is responsible for the error. type InterruptError struct{} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index d81ef128..77a8b3ac 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -29,11 +29,11 @@ import ( // PlannerOption configures the program plan options during interpretable setup. type PlannerOption func(*planner) (*planner, error) -// Interpreter generates a new Interpretable from a checked or unchecked expression. +// Interpreter generates a new InterpretableV2 from a checked or unchecked expression. type Interpreter interface { - // NewInterpretable creates an Interpretable from a checked expression and an + // NewInterpretable creates an InterpretableV2 from a checked expression and an // optional list of PlannerOption values. - NewInterpretable(exprAST *ast.AST, opts ...PlannerOption) (Interpretable, error) + NewInterpretable(exprAST *ast.AST, opts ...PlannerOption) (InterpretableV2, error) } // EvalObserver is a functional interface that accepts an expression id and an observed value. @@ -44,15 +44,15 @@ type EvalObserver func(vars Activation, id int64, programStep any, value ref.Val // StatefulObserver observes evaluation while tracking or utilizing stateful behavior. type StatefulObserver interface { // InitState configures stateful metadata on the activation. - InitState(Activation) (Activation, error) + InitState(*ExecutionFrame) (any, error) // GetState retrieves the stateful metadata from the activation. - GetState(Activation) any + GetState(*ExecutionFrame) any // Observe passes the activation and relevant evaluation metadata to the observer. // The observe method is expected to do the equivalent of GetState(vars) in order // to find the metadata that needs to be updated upon invocation. - Observe(vars Activation, id int64, programStep any, value ref.Val) + Observe(Activation, int64, any, ref.Val) } // EvalCancelledError represents a cancelled program evaluation operation. @@ -106,37 +106,6 @@ func EvalStateObserver(opts ...evalStateOption) PlannerOption { } } -// evalStateConverter identifies an object which is convertible to an EvalState instance. -type evalStateConverter interface { - asEvalState() EvalState -} - -// evalStateActivation hides state in the Activation in a manner not accessible to expressions. -type evalStateActivation struct { - vars Activation - state EvalState -} - -// ResolveName proxies variable lookups to the backing activation. -func (esa evalStateActivation) ResolveName(name string) (any, bool) { - return esa.vars.ResolveName(name) -} - -// Parent proxies parent lookups to the backing activation. -func (esa evalStateActivation) Parent() Activation { - return esa.vars -} - -// AsPartialActivation supports conversion to a partial activation in order to detect unknown attributes. -func (esa evalStateActivation) AsPartialActivation() (PartialActivation, bool) { - return AsPartialActivation(esa.vars) -} - -// asEvalState implements the evalStateConverter method. -func (esa evalStateActivation) asEvalState() EvalState { - return esa.state -} - // activationWrapper identifies an object carrying local variables which should not be exposed to the user // Activations used for such purposes can be unwrapped to return the activation which omits local state. type activationWrapper interface { @@ -144,57 +113,56 @@ type activationWrapper interface { Unwrap() Activation } -// asEvalState walks the Activation hierarchy and returns the first EvalState found, if present. -func asEvalState(vars Activation) (EvalState, bool) { - if conv, ok := vars.(evalStateConverter); ok { - return conv.asEvalState(), true - } - // Check if the current activation wraps another activation. This is used to support - // wrappers such as the @block() activation which may be composed of a dynamicSlotActivation or a - // constantSlotActivation. In this case, the underlying activation is the portion which interacts - // with the EvalState. - if wrapper, ok := vars.(activationWrapper); ok { - unwrapped := wrapper.Unwrap() - // Recursively call asEvalState on the unwrapped activation. This will check the unwrapped value and its parents. - return asEvalState(unwrapped) - } - if vars.Parent() != nil { - return asEvalState(vars.Parent()) - } - return nil, false -} - // evalStateFactory holds a reference to a factory function that produces an EvalState instance. type evalStateFactory struct { factory func() EvalState } -// InitState produces an EvalState instance and bundles it into the Activation in a way which is +// InitState produces an EvalState instance and bundles it into the ExecutionFrame in a way which is // not visible to expression evaluation. -func (et *evalStateFactory) InitState(vars Activation) (Activation, error) { +func (et *evalStateFactory) InitState(frame *ExecutionFrame) (any, error) { state := et.factory() - return evalStateActivation{vars: vars, state: state}, nil + if frame.ctx == nil { + frame.ctx = evalContextPool.Get().(*evalContext) + } + frame.ctx.state = state + return state, nil } // GetState extracts the EvalState from the Activation. -func (et *evalStateFactory) GetState(vars Activation) any { - if state, found := asEvalState(vars); found { - return state +func (et *evalStateFactory) GetState(frame *ExecutionFrame) any { + if frame.ctx == nil { + return nil } - return nil + return frame.ctx.state } // Observe records the evaluation state for a given expression node and program step. func (et *evalStateFactory) Observe(vars Activation, id int64, programStep any, val ref.Val) { - state, found := asEvalState(vars) - if !found { + frame := AsFrame(vars) + if frame.ctx == nil || frame.ctx.state == nil { return } - state.SetValue(id, val) + frame.ctx.state.SetValue(id, val) } // CustomDecorator configures a custom interpretable decorator for the program. func CustomDecorator(dec InterpretableDecorator) PlannerOption { + return func(p *planner) (*planner, error) { + dec2 := func(i InterpretableV2) (InterpretableV2, error) { + legacy, err := dec(i) + if err != nil { + return nil, err + } + return adaptToV2(legacy), nil + } + p.decorators = append(p.decorators, dec2) + return p, nil + } +} + +// CustomDecoratorV2 configures a custom V2 interpretable decorator for the program. +func CustomDecoratorV2(dec InterpretableDecoratorV2) PlannerOption { return func(p *planner) (*planner, error) { p.decorators = append(p.decorators, dec) return p, nil @@ -207,7 +175,7 @@ func CustomDecorator(dec InterpretableDecorator) PlannerOption { // provided to the decorator. This decorator is not thread-safe, and the EvalState // must be reset between Eval() calls. func ExhaustiveEval() PlannerOption { - return CustomDecorator(decDisableShortcircuits()) + return CustomDecoratorV2(decDisableShortcircuits()) } // InterruptableEval annotates comprehension loops with information that indicates they @@ -216,13 +184,13 @@ func ExhaustiveEval() PlannerOption { // The custom activation is currently managed higher up in the stack within the 'cel' package // and should not require any custom support on behalf of callers. func InterruptableEval() PlannerOption { - return CustomDecorator(decInterruptFolds()) + return CustomDecoratorV2(decInterruptFolds()) } // Optimize will pre-compute operations such as list and map construction and optimize // call arguments to set membership tests. The set of optimizations will increase over time. func Optimize() PlannerOption { - return CustomDecorator(decOptimize()) + return CustomDecoratorV2(decOptimize()) } // RegexOptimization provides a way to replace an InterpretableCall for a regex function when the @@ -247,7 +215,7 @@ type RegexOptimization struct { // CompileRegexConstants compiles regex pattern string constants at program creation time and reports any regex pattern // compile errors. func CompileRegexConstants(regexOptimizations ...*RegexOptimization) PlannerOption { - return CustomDecorator(decRegexOptimizer(regexOptimizations...)) + return CustomDecoratorV2(decRegexOptimizer(regexOptimizations...)) } type exprInterpreter struct { @@ -273,10 +241,10 @@ func NewInterpreter(dispatcher Dispatcher, attrFactory: attrFactory} } -// NewIntepretable implements the Interpreter interface method. +// NewInterpretable implements the Interpreter interface method. func (i *exprInterpreter) NewInterpretable( checked *ast.AST, - opts ...PlannerOption) (Interpretable, error) { + opts ...PlannerOption) (InterpretableV2, error) { p := newPlanner(i.dispatcher, i.provider, i.adapter, i.attrFactory, i.container, checked) var err error for _, o := range opts { diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 6bdb03e1..02f4b370 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -1695,7 +1695,7 @@ func BenchmarkInterpreter(b *testing.B) { if tst.err != "" || tst.progErr != "" { continue } - prg, vars, err := program(b, &tst, Optimize(), CompileRegexConstants(MatchesRegexOptimization)) + prg, frame, err := program(b, &tst, Optimize(), CompileRegexConstants(MatchesRegexOptimization)) if err != nil { b.Fatal(err) } @@ -1704,7 +1704,7 @@ func BenchmarkInterpreter(b *testing.B) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - prg.Eval(vars) + prg.Exec(frame) } }) } @@ -1894,26 +1894,6 @@ func TestInterpreter_ExhaustiveConditionalExpr(t *testing.T) { } } -func TestInterpreter_WrappedActivationEvalState(t *testing.T) { - vars, _ := NewActivation(map[string]any{ - "a": types.True, - "b": types.True, - "c": types.False, - "d": types.False, - }) - state := NewEvalState() - esa := &evalStateActivation{vars: vars, state: state} - wrappedVars := &testActivationWrapper{esa, "test_activation_wrapper"} - ac, _ := NewActivation(wrappedVars) - es, found := asEvalState(ac) - if !found { - t.Errorf("asEvalState(%v) failed to find EvalState", ac) - } - if es != state { - t.Errorf("asEvalState(%v) returned %v, wanted %v", ac, es, state) - } -} - func TestInterpreter_InterruptableEval(t *testing.T) { items := make([]int64, 5000) for i := int64(0); i < 5000; i++ { @@ -1929,7 +1909,7 @@ func TestInterpreter_InterruptableEval(t *testing.T) { }, out: true, } - prg, vars, err := program(t, &tc, InterruptableEval()) + prg, frame, err := program(t, &tc, InterruptableEval()) if err != nil { t.Fatalf("program(%s) failed: %v", tc.expr, err) } @@ -1938,37 +1918,14 @@ func TestInterpreter_InterruptableEval(t *testing.T) { evalCtx, cancel := context.WithTimeout(ctx, 10*time.Microsecond) defer cancel() - ctxVars := &contextActivation{ - Activation: vars, - interrupt: func() bool { - select { - case <-evalCtx.Done(): - return true - default: - return false - } - }, - } - out := prg.Eval(ctxVars) + frame.SetContext(evalCtx, 100) + out := prg.Exec(frame) + frame.Close() if !types.IsError(out) || out.(*types.Err).String() != "operation interrupted" { t.Errorf("Got %v, wanted operation interrupted error", out) } } -type contextActivation struct { - Activation - interruptCount int - interrupt func() bool -} - -func (ca *contextActivation) ResolveName(name string) (any, bool) { - if name == "#interrupted" { - ca.interruptCount++ - return ca.interruptCount%100 == 0 && ca.interrupt(), true - } - return ca.Activation.ResolveName(name) -} - func TestInterpreter_ExhaustiveLogicalOrEquals(t *testing.T) { // a || b == "b" // Operator "==" is at Expr 4, should be evaluated though "a" is true @@ -2277,7 +2234,7 @@ func testContainer(name string) *containers.Container { return cont } -func program(t testing.TB, tst *testCase, opts ...PlannerOption) (Interpretable, Activation, error) { +func program(t testing.TB, tst *testCase, opts ...PlannerOption) (InterpretableV2, *ExecutionFrame, error) { // Configure the package. cont := containers.DefaultContainer if tst.container != "" { @@ -2353,7 +2310,7 @@ func program(t testing.TB, tst *testCase, opts ...PlannerOption) (Interpretable, if err != nil { return nil, nil, err } - return prg, vars, nil + return prg, AsFrame(vars), nil } // Check the expression. checked, errs := checker.Check(parsed, s, env) @@ -2365,7 +2322,7 @@ func program(t testing.TB, tst *testCase, opts ...PlannerOption) (Interpretable, if err != nil { return nil, nil, err } - return prg, vars, nil + return prg, AsFrame(vars), nil } func base64Encode(val ref.Val) ref.Val { @@ -2513,3 +2470,237 @@ type testActivationWrapper struct { func (tw *testActivationWrapper) Unwrap() Activation { return tw.Activation } + +func TestInterruptErrorIs(t *testing.T) { + ie := InterruptError{} + tests := []struct { + name string + target error + want bool + }{ + { + name: "same type", + target: InterruptError{}, + want: true, + }, + { + name: "different error", + target: fmt.Errorf("other error"), + want: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := ie.Is(tc.target); got != tc.want { + t.Errorf("Is(%v) = %t, wanted %t", tc.target, got, tc.want) + } + }) + } +} + +func TestFolderActivation(t *testing.T) { + parentAct := EmptyActivation() + frame := mustNewExecutionFrame(t, parentAct) + defer frame.Close() + fld := newFolder(&evalFold{}, frame) + defer releaseFolder(fld) + if fld.Parent() != frame { + t.Errorf("fld.Parent() = %v, wanted %v", fld.Parent(), frame) + } + if fld.Unwrap() != frame { + t.Errorf("fld.Unwrap() = %v, wanted %v", fld.Unwrap(), frame) + } +} + +func TestEvalWatchConstructor(t *testing.T) { + watchConst := &evalWatchConstructor{ + constructor: &evalList{ + id: 1, + elems: []InterpretableV2{NewConstValue(2, types.IntOne)}, + }, + } + if len(watchConst.InitVals()) != 1 { + t.Errorf("watchConst.InitVals() len = %d, wanted 1", len(watchConst.InitVals())) + } + if watchConst.Type() != types.ListType { + t.Errorf("watchConst.Type() = %v, wanted %v", watchConst.Type(), types.ListType) + } +} + +func TestCustomDecorator(t *testing.T) { + decV1 := func(i Interpretable) (Interpretable, error) { + return i, nil + } + opt := CustomDecorator(decV1) + p := &planner{} + p2, err := opt(p) + if err != nil { + t.Fatalf("CustomDecorator failed: %v", err) + } + if len(p2.decorators) != 1 { + t.Fatalf("CustomDecorator did not add decorator") + } + res, err := p2.decorators[0](NewConstValue(1, types.IntOne)) + if err != nil { + t.Fatalf("wrapped decorator failed: %v", err) + } + if res.ID() != 1 { + t.Errorf("wrapped decorator returned node with ID %d, wanted 1", res.ID()) + } +} + +func TestCostTrackerActualCost(t *testing.T) { + ct := &CostTracker{cost: 42} + if ct.ActualCost() != 42 { + t.Errorf("ct.ActualCost() = %d, wanted 42", ct.ActualCost()) + } +} + +func TestV2Adapter(t *testing.T) { + legacy := &testLegacyInterpretable{id: 42} + adapted := adaptToV2(legacy) + if adapted.ID() != 42 { + t.Errorf("adapted.ID() = %d, wanted 42", adapted.ID()) + } + frame := mustNewExecutionFrame(t, EmptyActivation()) + defer frame.Close() + val := adapted.Exec(frame) + if val.Equal(types.IntOne) != types.True { + t.Errorf("adapted.Exec() = %v, wanted 1", val) + } +} + +func TestInterpretableArgs(t *testing.T) { + tests := []struct { + name string + call InterpretableCall + want int + }{ + { + name: "evalNe", + call: &evalNe{lhs: NewConstValue(1, types.IntOne), rhs: NewConstValue(2, types.IntOne)}, + want: 2, + }, + { + name: "evalZeroArity", + call: &evalZeroArity{}, + want: 0, + }, + { + name: "evalVarArgs", + call: &evalVarArgs{args: []InterpretableV2{NewConstValue(1, types.IntOne)}}, + want: 1, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := len(tc.call.Args()); got != tc.want { + t.Errorf("Args() len = %d, wanted %d", got, tc.want) + } + }) + } +} + +func TestNewCall(t *testing.T) { + call := NewCall(10, "f", "f_overload", []InterpretableV2{NewConstValue(1, types.IntOne)}, func(args ...ref.Val) ref.Val { return types.True }) + if call.ID() != 10 { + t.Errorf("NewCall.ID() = %d, wanted 10", call.ID()) + } +} + +func TestExhaustiveOperatorsLegacyEval(t *testing.T) { + reg, _ := types.NewRegistry() + fac := NewAttributeFactory(containers.DefaultContainer, reg, reg) + + tests := []struct { + name string + expr Interpretable + }{ + { + name: "exhaustive or", + expr: &evalExhaustiveOr{id: 1}, + }, + { + name: "exhaustive and", + expr: &evalExhaustiveAnd{id: 2}, + }, + { + name: "exhaustive conditional", + expr: &evalExhaustiveConditional{ + id: 3, + attr: &conditionalAttribute{ + expr: NewConstValue(4, types.True), + truthy: fac.AbsoluteAttribute(5, "a"), + falsy: fac.AbsoluteAttribute(6, "b"), + }, + }, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.expr.Eval(EmptyActivation()) + }) + } +} + +func TestFindFrame(t *testing.T) { + frame := mustNewExecutionFrame(t, EmptyActivation()) + defer frame.Close() + + tests := []struct { + name string + act Activation + }{ + { + name: "nested wrapper", + act: &testActivationWrapper{Activation: &testActivationWrapper{Activation: frame, name: "w1"}, name: "w2"}, + }, + { + name: "parent hierarchy", + act: &parentActivationWrapper{parent: frame}, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + found := findFrame(tc.act) + if found != frame { + t.Errorf("findFrame() = %v, wanted %v", found, frame) + } + }) + } +} + +func TestObservableInterpretable(t *testing.T) { + obsInt := &ObservableInterpretable{InterpretableV2: NewConstValue(12, types.True)} + if obsInt.ID() != 12 { + t.Errorf("obsInt.ID() = %d, wanted 12", obsInt.ID()) + } + res := obsInt.Eval(EmptyActivation()) + if res.Equal(types.True) != types.True { + t.Errorf("obsInt.Eval() = %v, wanted true", res) + } +} + +type testLegacyInterpretable struct { + id int64 +} + +func (t *testLegacyInterpretable) ID() int64 { + return t.id +} + +func (t *testLegacyInterpretable) Eval(vars Activation) ref.Val { + return types.IntOne +} + +type parentActivationWrapper struct { + parent Activation +} + +func (paw *parentActivationWrapper) ResolveName(name string) (any, bool) { + return paw.parent.ResolveName(name) +} + +func (paw *parentActivationWrapper) Parent() Activation { + return paw.parent +} diff --git a/interpreter/planner.go b/interpreter/planner.go index 0bc38449..95fd3455 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -43,7 +43,7 @@ func newPlanner(disp Dispatcher, container: cont, refMap: exprAST.ReferenceMap(), typeMap: exprAST.TypeMap(), - decorators: make([]InterpretableDecorator, 0), + decorators: make([]InterpretableDecoratorV2, 0), observers: make([]StatefulObserver, 0), } } @@ -57,7 +57,7 @@ type planner struct { container *containers.Container refMap map[int64]*ast.ReferenceInfo typeMap map[int64]*types.Type - decorators []InterpretableDecorator + decorators []InterpretableDecoratorV2 observers []StatefulObserver } @@ -72,7 +72,7 @@ type planBuilder struct { // useful for layering functionality into the evaluation that is not natively understood by CEL, // such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of // repeated expressions. -func (p *planner) Plan(expr ast.Expr) (Interpretable, error) { +func (p *planner) Plan(expr ast.Expr) (InterpretableV2, error) { pb := &planBuilder{planner: p, localVars: make(map[string]int)} i, err := pb.plan(expr) if err != nil { @@ -81,10 +81,10 @@ func (p *planner) Plan(expr ast.Expr) (Interpretable, error) { if len(p.observers) == 0 { return i, nil } - return &ObservableInterpretable{Interpretable: i, observers: p.observers}, nil + return &ObservableInterpretable{InterpretableV2: i, observers: p.observers}, nil } -func (p *planBuilder) plan(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) plan(expr ast.Expr) (InterpretableV2, error) { switch expr.Kind() { case ast.CallKind: return p.decorate(p.planCall(expr)) @@ -109,7 +109,7 @@ func (p *planBuilder) plan(expr ast.Expr) (Interpretable, error) { // decorate applies the InterpretableDecorator functions to the given Interpretable. // Both the Interpretable and error generated by a Plan step are accepted as arguments // for convenience. -func (p *planBuilder) decorate(i Interpretable, err error) (Interpretable, error) { +func (p *planBuilder) decorate(i InterpretableV2, err error) (InterpretableV2, error) { if err != nil { return nil, err } @@ -123,7 +123,7 @@ func (p *planBuilder) decorate(i Interpretable, err error) (Interpretable, error } // planIdent creates an Interpretable that resolves an identifier from an Activation. -func (p *planBuilder) planIdent(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planIdent(expr ast.Expr) (InterpretableV2, error) { // Establish whether the identifier is in the reference map. if identRef, found := p.refMap[expr.ID()]; found { return p.planCheckedIdent(expr.ID(), identRef) @@ -142,7 +142,7 @@ func (p *planBuilder) planIdent(expr ast.Expr) (Interpretable, error) { }, nil } -func (p *planBuilder) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Interpretable, error) { +func (p *planBuilder) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (InterpretableV2, error) { // Plan a constant reference if this is the case for this simple identifier. if identRef.Value != nil { return NewConstValue(id, identRef.Value), nil @@ -171,7 +171,7 @@ func (p *planBuilder) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (I // a) selects a field from a map or proto. // b) creates a field presence test for a select within a has() macro. // c) resolves the select expression to a namespaced identifier. -func (p *planBuilder) planSelect(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planSelect(expr ast.Expr) (InterpretableV2, error) { // If the Select id appears in the reference map from the CheckedExpr proto then it is either // a namespaced identifier or enum value. if identRef, found := p.refMap[expr.ID()]; found { @@ -227,7 +227,7 @@ func (p *planBuilder) planSelect(expr ast.Expr) (Interpretable, error) { // planCall creates a callable Interpretable while specializing for common functions and invocation // patterns. Specifically, conditional operators &&, ||, ?:, and (in)equality functions result in // optimized Interpretable values. -func (p *planBuilder) planCall(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planCall(expr ast.Expr) (InterpretableV2, error) { call := expr.AsCall() target, fnName, oName := p.resolveFunction(expr) argCount := len(call.Args()) @@ -237,7 +237,7 @@ func (p *planBuilder) planCall(expr ast.Expr) (Interpretable, error) { offset++ } - args := make([]Interpretable, argCount) + args := make([]InterpretableV2, argCount) if target != nil { arg, err := p.plan(target) if err != nil { @@ -307,7 +307,7 @@ func (p *planBuilder) planCall(expr ast.Expr) (Interpretable, error) { func (p *planBuilder) planCallZero(expr ast.Expr, function string, overload string, - impl *functions.Overload) (Interpretable, error) { + impl *functions.Overload) (InterpretableV2, error) { if impl == nil || impl.Function == nil { return nil, fmt.Errorf("no such overload: %s()", function) } @@ -324,7 +324,7 @@ func (p *planBuilder) planCallUnary(expr ast.Expr, function string, overload string, impl *functions.Overload, - args []Interpretable) (Interpretable, error) { + args []InterpretableV2) (InterpretableV2, error) { var fn functions.UnaryOp var trait int var nonStrict bool @@ -352,7 +352,7 @@ func (p *planBuilder) planCallBinary(expr ast.Expr, function string, overload string, impl *functions.Overload, - args []Interpretable) (Interpretable, error) { + args []InterpretableV2) (InterpretableV2, error) { var fn functions.BinaryOp var trait int var nonStrict bool @@ -381,7 +381,7 @@ func (p *planBuilder) planCallVarArgs(expr ast.Expr, function string, overload string, impl *functions.Overload, - args []Interpretable) (Interpretable, error) { + args []InterpretableV2) (InterpretableV2, error) { var fn functions.FunctionOp var trait int var nonStrict bool @@ -405,7 +405,7 @@ func (p *planBuilder) planCallVarArgs(expr ast.Expr, } // planCallEqual generates an equals (==) Interpretable. -func (p *planBuilder) planCallEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallEqual(expr ast.Expr, args []InterpretableV2) (InterpretableV2, error) { return &evalEq{ id: expr.ID(), lhs: args[0], @@ -414,7 +414,7 @@ func (p *planBuilder) planCallEqual(expr ast.Expr, args []Interpretable) (Interp } // planCallNotEqual generates a not equals (!=) Interpretable. -func (p *planBuilder) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallNotEqual(expr ast.Expr, args []InterpretableV2) (InterpretableV2, error) { return &evalNe{ id: expr.ID(), lhs: args[0], @@ -423,7 +423,7 @@ func (p *planBuilder) planCallNotEqual(expr ast.Expr, args []Interpretable) (Int } // planCallLogicalAnd generates a logical and (&&) Interpretable. -func (p *planBuilder) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallLogicalAnd(expr ast.Expr, args []InterpretableV2) (InterpretableV2, error) { return &evalAnd{ id: expr.ID(), terms: args, @@ -431,7 +431,7 @@ func (p *planBuilder) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (I } // planCallLogicalOr generates a logical or (||) Interpretable. -func (p *planBuilder) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallLogicalOr(expr ast.Expr, args []InterpretableV2) (InterpretableV2, error) { return &evalOr{ id: expr.ID(), terms: args, @@ -439,7 +439,7 @@ func (p *planBuilder) planCallLogicalOr(expr ast.Expr, args []Interpretable) (In } // planCallConditional generates a conditional / ternary (c ? t : f) Interpretable. -func (p *planBuilder) planCallConditional(expr ast.Expr, args []Interpretable) (Interpretable, error) { +func (p *planBuilder) planCallConditional(expr ast.Expr, args []InterpretableV2) (InterpretableV2, error) { cond := args[0] t := args[1] var tAttr Attribute @@ -467,7 +467,7 @@ func (p *planBuilder) planCallConditional(expr ast.Expr, args []Interpretable) ( // planCallIndex either extends an attribute with the argument to the index operation, or creates // a relative attribute based on the return of a function call or operation. -func (p *planBuilder) planCallIndex(expr ast.Expr, args []Interpretable, optional bool) (Interpretable, error) { +func (p *planBuilder) planCallIndex(expr ast.Expr, args []InterpretableV2, optional bool) (InterpretableV2, error) { op := args[0] ind := args[1] opType := p.typeMap[op.ID()] @@ -502,7 +502,7 @@ func (p *planBuilder) planCallIndex(expr ast.Expr, args []Interpretable, optiona } // planCreateList generates a list construction Interpretable. -func (p *planBuilder) planCreateList(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planCreateList(expr ast.Expr) (InterpretableV2, error) { list := expr.AsList() optionalIndices := list.OptionalIndices() elements := list.Elements() @@ -513,7 +513,7 @@ func (p *planBuilder) planCreateList(expr ast.Expr) (Interpretable, error) { } optionals[index] = true } - elems := make([]Interpretable, len(elements)) + elems := make([]InterpretableV2, len(elements)) for i, elem := range elements { elemVal, err := p.plan(elem) if err != nil { @@ -531,12 +531,12 @@ func (p *planBuilder) planCreateList(expr ast.Expr) (Interpretable, error) { } // planCreateStruct generates a map or object construction Interpretable. -func (p *planBuilder) planCreateMap(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planCreateMap(expr ast.Expr) (InterpretableV2, error) { m := expr.AsMap() entries := m.Entries() optionals := make([]bool, len(entries)) - keys := make([]Interpretable, len(entries)) - vals := make([]Interpretable, len(entries)) + keys := make([]InterpretableV2, len(entries)) + vals := make([]InterpretableV2, len(entries)) hasOptionals := false for i, e := range entries { entry := e.AsMapEntry() @@ -565,7 +565,7 @@ func (p *planBuilder) planCreateMap(expr ast.Expr) (Interpretable, error) { } // planCreateObj generates an object construction Interpretable. -func (p *planBuilder) planCreateStruct(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planCreateStruct(expr ast.Expr) (InterpretableV2, error) { obj := expr.AsStruct() typeName, defined := p.resolveTypeName(obj.TypeName()) if !defined { @@ -574,7 +574,7 @@ func (p *planBuilder) planCreateStruct(expr ast.Expr) (Interpretable, error) { objFields := obj.Fields() optionals := make([]bool, len(objFields)) fields := make([]string, len(objFields)) - vals := make([]Interpretable, len(objFields)) + vals := make([]InterpretableV2, len(objFields)) hasOptionals := false for i, f := range objFields { field := f.AsStructField() @@ -599,7 +599,7 @@ func (p *planBuilder) planCreateStruct(expr ast.Expr) (Interpretable, error) { } // planComprehension generates an Interpretable fold operation. -func (p *planBuilder) planComprehension(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planComprehension(expr ast.Expr) (InterpretableV2, error) { fold := expr.AsComprehension() accu, err := p.plan(fold.AccuInit()) if err != nil { @@ -639,7 +639,7 @@ func (p *planBuilder) planComprehension(expr ast.Expr) (Interpretable, error) { } // planConst generates a constant valued Interpretable. -func (p *planBuilder) planConst(expr ast.Expr) (Interpretable, error) { +func (p *planBuilder) planConst(expr ast.Expr) (InterpretableV2, error) { return NewConstValue(expr.ID(), expr.AsLiteral()), nil } @@ -726,7 +726,7 @@ func (p *planBuilder) resolveFunction(expr ast.Expr) (ast.Expr, string, string) // relativeAttr indicates that the attribute in this case acts as a qualifier and as such needs to // be observed to ensure that it's evaluation value is properly recorded for state tracking. -func (p *planBuilder) relativeAttr(id int64, eval Interpretable, opt bool) (InterpretableAttribute, error) { +func (p *planBuilder) relativeAttr(id int64, eval InterpretableV2, opt bool) (InterpretableAttribute, error) { eAttr, ok := eval.(InterpretableAttribute) if !ok { eAttr = &evalAttr{ diff --git a/interpreter/runtimecost.go b/interpreter/runtimecost.go index 6c44cd79..15853cf9 100644 --- a/interpreter/runtimecost.go +++ b/interpreter/runtimecost.go @@ -62,48 +62,6 @@ func CostObserver(opts ...costTrackPlanOption) PlannerOption { } } -// costTrackerConverter identifies an object which is convertible to a CostTracker instance. -type costTrackerConverter interface { - asCostTracker() *CostTracker -} - -// costTrackActivation hides state in the Activation in a manner not accessible to expressions. -type costTrackActivation struct { - vars Activation - costTracker *CostTracker -} - -// ResolveName proxies variable lookups to the backing activation. -func (cta costTrackActivation) ResolveName(name string) (any, bool) { - return cta.vars.ResolveName(name) -} - -// Parent proxies parent lookups to the backing activation. -func (cta costTrackActivation) Parent() Activation { - return cta.vars -} - -// AsPartialActivation supports conversion to a partial activation in order to detect unknown attributes. -func (cta costTrackActivation) AsPartialActivation() (PartialActivation, bool) { - return AsPartialActivation(cta.vars) -} - -// asCostTracker implements the costTrackerConverter method. -func (cta costTrackActivation) asCostTracker() *CostTracker { - return cta.costTracker -} - -// asCostTracker walks the Activation hierarchy and returns the first cost tracker found, if present. -func asCostTracker(vars Activation) (*CostTracker, bool) { - if conv, ok := vars.(costTrackerConverter); ok { - return conv.asCostTracker(), true - } - if vars.Parent() != nil { - return asCostTracker(vars.Parent()) - } - return nil, false -} - // costTrackerFactory holds a factory for producing new CostTracker instances on each Eval call. type costTrackerFactory struct { factory func() (*CostTracker, error) @@ -111,27 +69,32 @@ type costTrackerFactory struct { // InitState produces a CostTracker and bundles it into an Activation in a way which is not visible // to expression evaluation. -func (ct *costTrackerFactory) InitState(vars Activation) (Activation, error) { +func (ct *costTrackerFactory) InitState(frame *ExecutionFrame) (any, error) { tracker, err := ct.factory() if err != nil { return nil, err } - return costTrackActivation{vars: vars, costTracker: tracker}, nil + if frame.ctx == nil { + frame.ctx = evalContextPool.Get().(*evalContext) + } + frame.ctx.costs = tracker + return tracker, nil } // GetState extracts the CostTracker from the Activation. -func (ct *costTrackerFactory) GetState(vars Activation) any { - if tracker, found := asCostTracker(vars); found { - return tracker +func (ct *costTrackerFactory) GetState(frame *ExecutionFrame) any { + if frame == nil || frame.ctx == nil { + return nil } - return nil + return frame.ctx.costs } // Observe computes the incremental cost of each step and records it into the CostTracker associated // with the evaluation. func (ct *costTrackerFactory) Observe(vars Activation, id int64, programStep any, val ref.Val) { - tracker, found := asCostTracker(vars) - if !found { + frame := AsFrame(vars) + tracker := ct.GetState(frame).(*CostTracker) + if tracker == nil { return } switch t := programStep.(type) { @@ -397,7 +360,7 @@ func (s *refValStack) drop(ids ...int64) { // the stack. // WARNING: It is possible for multiple expressions with the same ID to exist (due to how macros are implemented) so it's // possible that a dropped ID will remain on the stack. They should be removed when IDs on the stack are popped. -func (s *refValStack) dropArgs(args []Interpretable) ([]ref.Val, bool) { +func (s *refValStack) dropArgs(args []InterpretableV2) ([]ref.Val, bool) { result := make([]ref.Val, len(args)) argloop: for nIdx := len(args) - 1; nIdx >= 0; nIdx-- { diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index a19308c6..679a5a46 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -161,7 +161,8 @@ func computeCost(t *testing.T, expr string, vars []*decls.VariableDecl, ctx Acti } } }() - prg.Eval(ctx) + frame := AsFrame(ctx) + prg.Exec(frame) // TODO: enable this once all attributes are properly pushed and popped from stack. //if len(costTracker.stack) != 1 { // t.Fatalf(`Expected resulting stack size to be 1 but got %d: %#+v`, len(costTracker.stack), costTracker.stack)