diff --git a/cl/compile.go b/cl/compile.go index bc073c23de..efc048c80d 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -853,8 +853,21 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue b.DeferStackDrain() } case *ssa.BinOp: - x := p.compileValue(b, v.X) - y := p.compileValue(b, v.Y) + if isUntypedNilConst(v.X) && isUntypedNilConst(v.Y) { + switch v.Op { + case token.EQL: + ret = p.prog.BoolVal(true) + break + case token.NEQ: + ret = p.prog.BoolVal(false) + break + } + if !ret.IsNil() { + break + } + } + x := p.compileValueAs(b, v.X, v.Y.Type()) + y := p.compileValueAs(b, v.Y, v.X.Type()) ret = b.BinOp(v.Op, x, y) case *ssa.UnOp: if v.Op == token.MUL { @@ -891,10 +904,18 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue } case *ssa.ChangeType: t := v.Type() + if isUntypedNilConst(v.X) { + ret = p.nilOf(t) + break + } x := p.compileValue(b, v.X) ret = b.ChangeType(p.type_(t, llssa.InGo), x) case *ssa.Convert: t := v.Type() + if isUntypedNilConst(v.X) { + ret = p.nilOf(t) + break + } x := p.compileValue(b, v.X) ret = b.Convert(p.type_(t, llssa.InGo), x) case *ssa.FieldAddr: @@ -967,6 +988,10 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue } } t := p.type_(v.Type(), llssa.InGo) + if isUntypedNilConst(v.X) { + ret = p.prog.Nil(t) + break + } if unop, ok := v.X.(*ssa.UnOp); ok && unop.Op == token.MUL { if vt := p.type_(unop.Type(), llssa.InGo); vt.RawType() != nil { if p.isLargeNonPointerValue(vt) { @@ -1047,6 +1072,26 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue return ret } +func isUntypedNilConst(v ssa.Value) bool { + c, ok := v.(*ssa.Const) + if !ok || c.Value != nil { + return false + } + basic, ok := c.Type().Underlying().(*types.Basic) + return ok && basic.Kind() == types.UntypedNil +} + +func (p *context) nilOf(typ types.Type) llssa.Expr { + return p.prog.Nil(p.type_(typ, llssa.InGo)) +} + +func (p *context) compileValueAs(b llssa.Builder, v ssa.Value, typ types.Type) llssa.Expr { + if isUntypedNilConst(v) { + return p.nilOf(typ) + } + return p.compileValue(b, v) +} + func (p *context) assertNilDerefBase(b llssa.Builder, addr ssa.Value) { if field, ok := addr.(*ssa.FieldAddr); ok { base := p.compileValue(b, field.X) diff --git a/test/go/switch_nil_test.go b/test/go/switch_nil_test.go new file mode 100644 index 0000000000..f825aee370 --- /dev/null +++ b/test/go/switch_nil_test.go @@ -0,0 +1,115 @@ +package gotest + +import "testing" + +func TestSwitchNilCases(t *testing.T) { + switch f := func() {}; f { + case nil: + t.Fatal("non-nil func matched nil") + default: + } + + var nilFunc func() + switch nilFunc { + case nil: + default: + t.Fatal("nil func did not match nil") + } + + switch m := make(map[int]int); m { + case nil: + t.Fatal("non-nil map matched nil") + default: + } + + var nilMap map[int]int + switch nilMap { + case nil: + default: + t.Fatal("nil map did not match nil") + } + + switch s := make([]int, 1); s { + case nil: + t.Fatal("non-nil slice matched nil") + default: + } + + var nilSlice []int + switch nilSlice { + case nil: + default: + t.Fatal("nil slice did not match nil") + } + + switch c1, c2 := make(chan int), make(chan int); c1 { + case nil: + t.Fatal("non-nil channel matched nil") + case c2: + t.Fatal("channel matched a different channel") + case c1: + default: + t.Fatal("channel did not match itself") + } + + switch (*int)(nil) { + case nil: + case any(nil): + t.Fatal("typed nil pointer matched nil interface") + default: + t.Fatal("typed nil pointer did not match nil") + } +} + +func TestGenericSwitchNilCases(t *testing.T) { + checkGenericNilSwitch[int](t) + checkGenericNilSwitch[string](t) +} + +func checkGenericNilSwitch[T any](t *testing.T) { + switch []T(nil) { + case nil: + default: + t.Fatal("nil generic slice did not match nil") + } + if []T(nil) != nil { + t.Fatal("nil generic slice compared non-nil") + } + + switch make([]T, 1) { + case nil: + t.Fatal("non-nil generic slice matched nil") + default: + } + + switch (func() T)(nil) { + case nil: + default: + t.Fatal("nil generic func did not match nil") + } + if (func() T)(nil) != nil { + t.Fatal("nil generic func compared non-nil") + } + + var zero T + switch func() T { return zero } { + case nil: + t.Fatal("non-nil generic func matched nil") + default: + } + + switch (map[int]T)(nil) { + case nil: + default: + t.Fatal("nil generic map did not match nil") + } + if (map[int]T)(nil) != nil { + t.Fatal("nil generic map compared non-nil") + } + + switch make(map[int]T) { + case nil: + t.Fatal("non-nil generic map matched nil") + default: + } +} diff --git a/test/goroot/xfail.yaml b/test/goroot/xfail.yaml index a5a66d10b4..ae8e5a0787 100644 --- a/test/goroot/xfail.yaml +++ b/test/goroot/xfail.yaml @@ -1818,10 +1818,6 @@ xfails: directive: run case: unsafebuiltins.go reason: current main goroot run failure on darwin/arm64 - - platform: darwin/arm64 - directive: run - case: switch.go - reason: current main goroot run failure on darwin/arm64 - platform: darwin/arm64 directive: run case: zerodivide.go @@ -3660,10 +3656,6 @@ xfails: directive: run case: fixedbugs/issue47928.go reason: current main goroot run failure on darwin/arm64 - - platform: darwin/arm64 - directive: run - case: fixedbugs/issue53635.go - reason: current main goroot run failure on darwin/arm64 - platform: darwin/arm64 directive: run case: fixedbugs/issue60601.go @@ -3858,10 +3850,6 @@ xfails: directive: run case: fixedbugs/issue47928.go reason: current main goroot run failure on linux/amd64 - - platform: linux/amd64 - directive: run - case: fixedbugs/issue53635.go - reason: current main goroot run failure on linux/amd64 - platform: linux/amd64 directive: run case: fixedbugs/issue60601.go