Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 192 additions & 0 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,77 @@ func TestCustomInterpreterDecorator(t *testing.T) {
}
}

func TestCustomInterpreterDecoratorV2(t *testing.T) {
var lastInstruction interpreter.InterpretableV2
optimizeArith := func(i interpreter.InterpretableV2) (interpreter.InterpretableV2, error) {
lastInstruction = i
// Only optimize the instruction if it is a call.
call, ok := i.(interpreter.InterpretableCall)
if !ok {
return i, nil
}
// Only optimize the math functions when they have constant arguments.
switch call.Function() {
case operators.Add,
operators.Subtract,
operators.Multiply,
operators.Divide:
// These are all binary operators so they should have two arguments
args := call.Args()
_, lhsIsConst := args[0].(interpreter.InterpretableConst)
_, rhsIsConst := args[1].(interpreter.InterpretableConst)
// When the values are constant then the call can be evaluated with
// an empty activation and the value returns as a constant.
if !lhsIsConst || !rhsIsConst {
return i, nil
}
val := call.Eval(interpreter.EmptyActivation())
if types.IsError(val) {
return nil, val.(*types.Err)
}
return interpreter.NewConstValue(call.ID(), val), nil
default:
return i, nil
}
}

env := testEnv(t, Variable("foo", IntType))
ast, iss := env.Compile(`foo == -1 + 2 * 3 / 3`)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
_, err := env.Program(ast,
EvalOptions(OptPartialEval),
CustomDecoratorV2(optimizeArith))
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
call, ok := lastInstruction.(interpreter.InterpretableCall)
if !ok {
t.Errorf("got %v, expected call", lastInstruction)
}
args := call.Args()
lhs := args[0]
lastAttr, ok := lhs.(interpreter.InterpretableAttribute)
if !ok {
t.Errorf("got %v, wanted attribute", lhs)
}
absAttr := lastAttr.Attr().(interpreter.NamespacedAttribute)
varNames := absAttr.CandidateVariableNames()
if len(varNames) != 1 || varNames[0] != "foo" {
t.Errorf("got variables %v, wanted foo", varNames)
}
rhs := args[1]
lastConst, ok := rhs.(interpreter.InterpretableConst)
if !ok {
t.Errorf("got %v, wanted constant", rhs)
}
// This is the last number produced by the optimization.
if lastConst.Value().Equal(types.IntOne) == types.False {
t.Errorf("got %v as the last observed constant, wanted 1", lastConst)
}
}

// TestEstimateCostAndRuntimeCost sanity checks that the cost systems are usable from the program API.
func TestEstimateCostAndRuntimeCost(t *testing.T) {
intList := ListType(IntType)
Expand Down Expand Up @@ -3961,3 +4032,124 @@ func TestExpressionSizeLimitEarlyEnforcement(t *testing.T) {
})
}
}

func TestProgramEvalInvalidInput(t *testing.T) {
env, err := NewEnv()
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile("true")
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("Program() failed: %v", err)
}

tests := []struct {
name string
input any
wantErr string
}{
{
name: "int input",
input: 123,
wantErr: "invalid input, wanted Activation or map[string]any",
},
{
name: "nil input",
input: nil,
wantErr: "invalid input, wanted Activation or map[string]any",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, _, err := prg.Eval(tc.input)
if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
t.Errorf("Eval(%v) err = %v, expected error containing %q", tc.input, err, tc.wantErr)
}
})
}
}

func TestProgramContextEvalInvalidInput(t *testing.T) {
env, err := NewEnv()
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile("true")
if iss.Err() != nil {
t.Fatalf("Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("Program() failed: %v", err)
}

tests := []struct {
name string
ctx context.Context
input any
wantErr string
}{
{
name: "nil context",
ctx: nil,
input: map[string]any{},
wantErr: "context can not be nil",
},
{
name: "invalid input type",
ctx: context.Background(),
input: 123,
wantErr: "invalid input, wanted Activation or map[string]any",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, _, err := prg.ContextEval(tc.ctx, tc.input)
if err == nil || !strings.Contains(err.Error(), tc.wantErr) {
t.Errorf("ContextEval(%v, %v) err = %v, expected error containing %q", tc.ctx, tc.input, err, tc.wantErr)
}
})
}
}

func TestOptionalOperatorsLegacyEval(t *testing.T) {
tests := []struct {
name string
expr interpreter.Interpretable
want ref.Val
}{
{
name: "optional or",
expr: &evalOptionalOr{
id: 1,
lhs: interpreter.NewConstValue(2, types.OptionalOf(types.True)),
rhs: interpreter.NewConstValue(3, types.False),
},
want: types.OptionalOf(types.True),
},
{
name: "optional or value mismatch",
expr: &evalOptionalOrValue{
id: 4,
lhs: interpreter.NewConstValue(5, types.True),
rhs: interpreter.NewConstValue(6, types.False),
},
want: types.NoSuchOverloadErr(),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := tc.expr.Eval(NoVars())
if got.Equal(tc.want) != types.True && !types.IsError(got) {
t.Errorf("Eval() = %v, wanted %v", got, tc.want)
}
if types.IsError(got) && got.(*types.Err).Error() != tc.want.(*types.Err).Error() {
t.Errorf("Eval() = %v, wanted %v", got, tc.want)
}
})
}
}
40 changes: 24 additions & 16 deletions cel/library.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption {
// ProgramOptions implements the Library interface method.
func (lib *optionalLib) ProgramOptions() []ProgramOption {
return []ProgramOption{
CustomDecorator(decorateOptionalOr),
CustomDecoratorV2(decorateOptionalOr),
}
}

Expand Down Expand Up @@ -683,7 +683,7 @@ func EnableErrorOnBadPresenceTest(value bool) EnvOption {
return features(featureEnableErrorOnBadPresenceTest, value)
}

func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable, error) {
func decorateOptionalOr(i interpreter.InterpretableV2) (interpreter.InterpretableV2, error) {
call, ok := i.(interpreter.InterpretableCall)
if !ok {
return i, nil
Expand Down Expand Up @@ -720,51 +720,53 @@ func decorateOptionalOr(i interpreter.Interpretable) (interpreter.Interpretable,
// the second optional expression is evaluated and returned.
type evalOptionalOr struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
lhs interpreter.InterpretableV2
rhs interpreter.InterpretableV2
}

// ID implements the Interpretable interface method.
func (opt *evalOptionalOr) ID() int64 {
return opt.id
}

// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
func (opt *evalOptionalOr) Exec(frame *interpreter.ExecutionFrame) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optLHS := opt.lhs.Exec(frame)
switch val := optLHS.(type) {
case *types.Err, *types.Unknown:
return optLHS
case *types.Optional:
if val.HasValue() {
return optLHS
}
return opt.rhs.Eval(ctx)
return opt.rhs.Exec(frame)
default:
return types.NoSuchOverloadErr()
}
}

// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOr) Eval(ctx interpreter.Activation) ref.Val {
return opt.Exec(interpreter.AsFrame(ctx))
}

// evalOptionalOrValue selects between an optional or a concrete value. If the optional has a value,
// its value is returned, otherwise the alternative value expression is evaluated and returned.
type evalOptionalOrValue struct {
id int64
lhs interpreter.Interpretable
rhs interpreter.Interpretable
lhs interpreter.InterpretableV2
rhs interpreter.InterpretableV2
}

// ID implements the Interpretable interface method.
func (opt *evalOptionalOrValue) ID() int64 {
return opt.id
}

// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
func (opt *evalOptionalOrValue) Exec(frame *interpreter.ExecutionFrame) ref.Val {
// short-circuit lhs.
optLHS := opt.lhs.Eval(ctx)
optLHS := opt.lhs.Exec(frame)

switch val := optLHS.(type) {
case *types.Err, *types.Unknown:
Expand All @@ -773,12 +775,18 @@ func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
if val.HasValue() {
return val.GetValue()
}
return opt.rhs.Eval(ctx)
return opt.rhs.Exec(frame)
default:
return types.NoSuchOverloadErr()
}
}

// Eval evaluates the left-hand side optional to determine whether it contains a value, else
// proceeds with the right-hand side evaluation.
func (opt *evalOptionalOrValue) Eval(ctx interpreter.Activation) ref.Val {
return opt.Exec(interpreter.AsFrame(ctx))
}

type timeLegacyLibrary struct{}

func (timeLegacyLibrary) CompileOptions() []EnvOption {
Expand Down
8 changes: 8 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,14 @@ func CustomDecorator(dec interpreter.InterpretableDecorator) ProgramOption {
}
}

// CustomDecoratorV2 appends an InterpreterDecoratorV2 to the program.
func CustomDecoratorV2(dec interpreter.InterpretableDecoratorV2) ProgramOption {
return func(p *prog) (*prog, error) {
p.plannerOptions = append(p.plannerOptions, interpreter.CustomDecoratorV2(dec))
return p, nil
}
}

// Functions adds function overloads that extend or override the set of CEL built-ins.
//
// Deprecated: use Function() instead to declare the function, its overload signatures,
Expand Down
Loading