Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions cl/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -853,8 +853,21 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue
b.DeferStackDrain()
}
case *ssa.BinOp:
x := p.compileValue(b, v.X)
y := p.compileValue(b, v.Y)
if isUntypedNilConst(v.X) && isUntypedNilConst(v.Y) {
switch v.Op {
case token.EQL:
ret = p.prog.BoolVal(true)
break
case token.NEQ:
ret = p.prog.BoolVal(false)
break
}
if !ret.IsNil() {
break
}
}
x := p.compileValueAs(b, v.X, v.Y.Type())
y := p.compileValueAs(b, v.Y, v.X.Type())
ret = b.BinOp(v.Op, x, y)
case *ssa.UnOp:
if v.Op == token.MUL {
Expand Down Expand Up @@ -891,10 +904,18 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue
}
case *ssa.ChangeType:
t := v.Type()
if isUntypedNilConst(v.X) {
ret = p.nilOf(t)
break
}
x := p.compileValue(b, v.X)
ret = b.ChangeType(p.type_(t, llssa.InGo), x)
case *ssa.Convert:
t := v.Type()
if isUntypedNilConst(v.X) {
ret = p.nilOf(t)
break
}
x := p.compileValue(b, v.X)
ret = b.Convert(p.type_(t, llssa.InGo), x)
case *ssa.FieldAddr:
Expand Down Expand Up @@ -967,6 +988,10 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue
}
}
t := p.type_(v.Type(), llssa.InGo)
if isUntypedNilConst(v.X) {
ret = p.prog.Nil(t)
break
}
if unop, ok := v.X.(*ssa.UnOp); ok && unop.Op == token.MUL {
if vt := p.type_(unop.Type(), llssa.InGo); vt.RawType() != nil {
if p.isLargeNonPointerValue(vt) {
Expand Down Expand Up @@ -1047,6 +1072,26 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue
return ret
}

func isUntypedNilConst(v ssa.Value) bool {
c, ok := v.(*ssa.Const)
if !ok || c.Value != nil {
return false
}
basic, ok := c.Type().Underlying().(*types.Basic)
return ok && basic.Kind() == types.UntypedNil
}

func (p *context) nilOf(typ types.Type) llssa.Expr {
return p.prog.Nil(p.type_(typ, llssa.InGo))
}

func (p *context) compileValueAs(b llssa.Builder, v ssa.Value, typ types.Type) llssa.Expr {
if isUntypedNilConst(v) {
return p.nilOf(typ)
}
return p.compileValue(b, v)
}

func (p *context) assertNilDerefBase(b llssa.Builder, addr ssa.Value) {
if field, ok := addr.(*ssa.FieldAddr); ok {
base := p.compileValue(b, field.X)
Expand Down
115 changes: 115 additions & 0 deletions test/go/switch_nil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package gotest

import "testing"

func TestSwitchNilCases(t *testing.T) {
switch f := func() {}; f {
case nil:
t.Fatal("non-nil func matched nil")
default:
}

var nilFunc func()
switch nilFunc {
case nil:
default:
t.Fatal("nil func did not match nil")
}

switch m := make(map[int]int); m {
case nil:
t.Fatal("non-nil map matched nil")
default:
}

var nilMap map[int]int
switch nilMap {
case nil:
default:
t.Fatal("nil map did not match nil")
}

switch s := make([]int, 1); s {
case nil:
t.Fatal("non-nil slice matched nil")
default:
}

var nilSlice []int
switch nilSlice {
case nil:
default:
t.Fatal("nil slice did not match nil")
}

switch c1, c2 := make(chan int), make(chan int); c1 {
case nil:
t.Fatal("non-nil channel matched nil")
case c2:
t.Fatal("channel matched a different channel")
case c1:
default:
t.Fatal("channel did not match itself")
}

switch (*int)(nil) {
case nil:
case any(nil):
t.Fatal("typed nil pointer matched nil interface")
default:
t.Fatal("typed nil pointer did not match nil")
}
}

func TestGenericSwitchNilCases(t *testing.T) {
checkGenericNilSwitch[int](t)
checkGenericNilSwitch[string](t)
}

func checkGenericNilSwitch[T any](t *testing.T) {
switch []T(nil) {
case nil:
default:
t.Fatal("nil generic slice did not match nil")
}
if []T(nil) != nil {
t.Fatal("nil generic slice compared non-nil")
}

switch make([]T, 1) {
case nil:
t.Fatal("non-nil generic slice matched nil")
default:
}

switch (func() T)(nil) {
case nil:
default:
t.Fatal("nil generic func did not match nil")
}
if (func() T)(nil) != nil {
t.Fatal("nil generic func compared non-nil")
}

var zero T
switch func() T { return zero } {
case nil:
t.Fatal("non-nil generic func matched nil")
default:
}

switch (map[int]T)(nil) {
case nil:
default:
t.Fatal("nil generic map did not match nil")
}
if (map[int]T)(nil) != nil {
t.Fatal("nil generic map compared non-nil")
}

switch make(map[int]T) {
case nil:
t.Fatal("non-nil generic map matched nil")
default:
}
}
12 changes: 0 additions & 12 deletions test/goroot/xfail.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1818,10 +1818,6 @@ xfails:
directive: run
case: unsafebuiltins.go
reason: current main goroot run failure on darwin/arm64
- platform: darwin/arm64
directive: run
case: switch.go
reason: current main goroot run failure on darwin/arm64
- platform: darwin/arm64
directive: run
case: zerodivide.go
Expand Down Expand Up @@ -3660,10 +3656,6 @@ xfails:
directive: run
case: fixedbugs/issue47928.go
reason: current main goroot run failure on darwin/arm64
- platform: darwin/arm64
directive: run
case: fixedbugs/issue53635.go
reason: current main goroot run failure on darwin/arm64
- platform: darwin/arm64
directive: run
case: fixedbugs/issue60601.go
Expand Down Expand Up @@ -3858,10 +3850,6 @@ xfails:
directive: run
case: fixedbugs/issue47928.go
reason: current main goroot run failure on linux/amd64
- platform: linux/amd64
directive: run
case: fixedbugs/issue53635.go
reason: current main goroot run failure on linux/amd64
- platform: linux/amd64
directive: run
case: fixedbugs/issue60601.go
Expand Down
Loading