From 8956ee2343f0388d0df6e0988a21a6c3e9f1e49f Mon Sep 17 00:00:00 2001 From: Ville Vesilehto Date: Tue, 2 Dec 2025 21:32:09 +0200 Subject: [PATCH] fix: auto-deref for map/slice and conditionals Ensure pointers are automatically derefer'd in map keys, slice indices, and conditional expressions. Add dereferencing for map keys to prefer exact type matches (e.g. map[*T]) before falling back to deref'd types. Regression test added. Signed-off-by: Ville Vesilehto --- checker/checker.go | 12 ++++ compiler/compiler.go | 15 +++++ test/issues/836/issue_test.go | 101 ++++++++++++++++++++++++++++++++++ 3 files changed, 128 insertions(+) create mode 100644 test/issues/836/issue_test.go diff --git a/checker/checker.go b/checker/checker.go index 0210b416..c0c0dd2e 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -549,6 +549,13 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { switch base.Kind { case reflect.Map: + // If the map key is a pointer, we should not dereference the property. + if !prop.AssignableTo(base.Key(&v.config.NtCache)) { + propDeref := prop.Deref(&v.config.NtCache) + if propDeref.AssignableTo(base.Key(&v.config.NtCache)) { + prop = propDeref + } + } if !prop.AssignableTo(base.Key(&v.config.NtCache)) && !prop.IsUnknown(&v.config.NtCache) { return v.error(node.Property, "cannot use %s to get an element from %s", prop.String(), base.String()) } @@ -562,6 +569,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { return base.Elem(&v.config.NtCache) case reflect.Array, reflect.Slice: + prop = prop.Deref(&v.config.NtCache) if !prop.IsInteger && !prop.IsUnknown(&v.config.NtCache) { return v.error(node.Property, "array elements can only be selected using an integer (got %s)", prop.String()) } @@ -607,6 +615,7 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature { if node.From != nil { from := v.visit(node.From) + from = from.Deref(&v.config.NtCache) if !from.IsInteger && !from.IsUnknown(&v.config.NtCache) { return v.error(node.From, "non-integer slice index %v", from.String()) } @@ -614,6 +623,7 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature { if node.To != nil { to := v.visit(node.To) + to = to.Deref(&v.config.NtCache) if !to.IsInteger && !to.IsUnknown(&v.config.NtCache) { return v.error(node.To, "non-integer slice index %v", to.String()) } @@ -942,6 +952,7 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature { base := v.visit(node.Arguments[0]) prop := v.visit(node.Arguments[1]) + prop = prop.Deref(&v.config.NtCache) if id, ok := node.Arguments[0].(*ast.IdentifierNode); ok && id.Value == "$env" { if s, ok := node.Arguments[1].(*ast.StringNode); ok { @@ -1260,6 +1271,7 @@ func (v *Checker) sequenceNode(node *ast.SequenceNode) Nature { func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature { c := v.visit(node.Cond) + c = c.Deref(&v.config.NtCache) if !c.IsBool() && !c.IsUnknown(&v.config.NtCache) { return v.error(node.Cond, "non-bool expression (type %v) used as condition", c.String()) } diff --git a/compiler/compiler.go b/compiler/compiler.go index d8a53244..f657ecce 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -716,6 +716,18 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { if op == OpFetch { c.compile(node.Property) + deref := true + // If the map key is a pointer, we should not dereference the property. + if node.Node.Type() != nil && node.Node.Type().Kind() == reflect.Map { + keyType := node.Node.Type().Key() + propType := node.Property.Type() + if propType != nil && propType.AssignableTo(keyType) { + deref = false + } + } + if deref { + c.derefInNeeded(node.Property) + } c.emit(OpFetch) } else { c.emitLocation(node.Location(), op, c.addConstant( @@ -728,11 +740,13 @@ func (c *compiler) SliceNode(node *ast.SliceNode) { c.compile(node.Node) if node.To != nil { c.compile(node.To) + c.derefInNeeded(node.To) } else { c.emit(OpLen) } if node.From != nil { c.compile(node.From) + c.derefInNeeded(node.From) } else { c.emitPush(0) } @@ -1213,6 +1227,7 @@ func (c *compiler) lookupVariable(name string) (int, bool) { func (c *compiler) ConditionalNode(node *ast.ConditionalNode) { c.compile(node.Cond) + c.derefInNeeded(node.Cond) otherwise := c.emit(OpJumpIfFalse, placeholder) c.emit(OpPop) diff --git a/test/issues/836/issue_test.go b/test/issues/836/issue_test.go new file mode 100644 index 00000000..e4ef6b34 --- /dev/null +++ b/test/issues/836/issue_test.go @@ -0,0 +1,101 @@ +package issue_test + +import ( + "testing" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/internal/testify/require" +) + +type InputStruct struct { + Enabled *bool `json:"enabled"` +} + +func TestIssue836(t *testing.T) { + str := "foo" + ptrStr := &str + b := true + ptrBool := &b + i := 1 + ptrInt := &i + + env := map[string]interface{}{ + "ptrStr": ptrStr, + "ptrBool": ptrBool, + "ptrInt": ptrInt, + "arr": []int{1, 2, 3}, + "mapPtr": map[*int]int{ptrInt: 42}, + } + + t.Run("map access with pointer key", func(t *testing.T) { + program, err := expr.Compile(`{"foo": "bar"}[ptrStr]`, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, "bar", output) + }) + + t.Run("conditional with pointer condition", func(t *testing.T) { + program, err := expr.Compile(`ptrBool ? 1 : 0`, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 1, output) + }) + + t.Run("get() with pointer key", func(t *testing.T) { + program, err := expr.Compile(`get({"foo": "bar"}, ptrStr)`, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, "bar", output) + }) + + t.Run("struct field pointer check in ternary", func(t *testing.T) { + var v InputStruct + // v.Enabled is nil + + env := map[string]any{ + "v": v, + } + + code := `v.Enabled == nil ? 'default' : ( v.Enabled ? 'enabled' : 'disabled' )` + + program, err := expr.Compile(code, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, "default", output) + }) + + t.Run("struct field pointer check in ternary (enabled)", func(t *testing.T) { + b := true + v := InputStruct{Enabled: &b} + + env := map[string]any{ + "v": v, + } + + code := `v.Enabled == nil ? 'default' : ( v.Enabled ? 'enabled' : 'disabled' )` + + program, err := expr.Compile(code, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, "enabled", output) + }) + + t.Run("slice with pointer indices", func(t *testing.T) { + program, err := expr.Compile(`arr[ptrInt:ptrInt]`, expr.Env(env)) + require.NoError(t, err) + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, []int{}, output) + }) +}