Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions gopls/internal/golang/addtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
{{- if ne $index 0}}, {{end}}
{{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}}
{{- end -}}
{{- if and .Receiver.Constructor.IsVariadic (not .Receiver.Constructor.IsUnusedVariadic) }}...{{- end -}}
)

{{- /* Handles the error return from constructor. */}}
Expand Down Expand Up @@ -123,6 +124,7 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
{{- if ne $index 0}}, {{end}}
{{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}}
{{- end -}}
{{- if and .Func.IsVariadic (not .Func.IsUnusedVariadic) }}...{{- end -}}
)

{{- /* Handles the returned error before the rest of return value. */}}
Expand Down Expand Up @@ -167,6 +169,12 @@ type function struct {
Name string
Args []field
Results []field
// IsVariadic holds information if the function is variadic.
// In case of the IsVariadic=true must expand it in the function-under-test call.
IsVariadic bool
// IsUnusedVariadic holds information if the function's variadic args is not used.
// In this case ... must be skipped.
IsUnusedVariadic bool
}

type receiver struct {
Expand Down Expand Up @@ -238,7 +246,7 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
extraImports = make(map[string]string) // imports to add to test file
)

var collectImports = func(file *ast.File) (map[string]string, error) {
collectImports := func(file *ast.File) (map[string]string, error) {
imps := make(map[string]string)
for _, spec := range file.Imports {
// TODO(hxjiang): support dot imports.
Expand Down Expand Up @@ -478,21 +486,31 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
PackageName: qual(pkg.Types()),
TestFuncName: testName,
Func: function{
Name: fn.Name(),
Name: fn.Name(),
IsVariadic: sig.Variadic(),
},
}

isContextType := func(t types.Type) bool {
return typesinternal.IsTypeNamed(t, "context", "Context")
}

isUnusedParameter := func(name string) bool {
return name == "" || name == "_"
}

for i := range sig.Params().Len() {
param := sig.Params().At(i)
name, typ := param.Name(), param.Type()
f := field{Type: types.TypeString(typ, qual)}
if i == 0 && isContextType(typ) {
f.Value = qual(types.NewPackage("context", "context")) + ".Background()"
} else if name == "" || name == "_" {
} else if isUnusedParameter(name) && data.Func.IsVariadic && sig.Params().Len()-1 == i {
// The last argument is the variadic argument, and it's not used in the function body,
// so we don't need to render it in the test case struct.
data.Func.IsUnusedVariadic = true
continue
} else if isUnusedParameter(name) {
f.Value, _ = typesinternal.ZeroString(typ, qual)
} else {
f.Name = name
Expand Down Expand Up @@ -619,14 +637,22 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
}

if constructor != nil {
data.Receiver.Constructor = &function{Name: constructor.Name()}
data.Receiver.Constructor = &function{
Name: constructor.Name(),
IsVariadic: constructor.Signature().Variadic(),
}
for i := range constructor.Signature().Params().Len() {
param := constructor.Signature().Params().At(i)
name, typ := param.Name(), param.Type()
f := field{Type: types.TypeString(typ, qual)}
if i == 0 && isContextType(typ) {
f.Value = qual(types.NewPackage("context", "context")) + ".Background()"
} else if name == "" || name == "_" {
} else if isUnusedParameter(name) && data.Receiver.Constructor.IsVariadic && constructor.Signature().Params().Len()-1 == i {
// The last argument is the variadic argument, and it's not used in the function body,
// so we don't need to render it in the test case struct.
data.Receiver.Constructor.IsUnusedVariadic = true
continue
} else if isUnusedParameter(name) {
f.Value, _ = typesinternal.ZeroString(typ, qual)
} else {
f.Name = name
Expand Down
179 changes: 179 additions & 0 deletions gopls/internal/test/marker/testdata/codeaction/addtest.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1540,3 +1540,182 @@ type Foo struct {}
func NewFoo()

func (*Foo) Method[T any]() {} // no suggested fix
-- varargsinput/varargsinput.go --
package main

func Function(args ...string) (out, out1, out2 string) {return args[0], "", ""} //@codeaction("Function", "source.addTest", edit=function_varargsinput)

type Foo struct {}

func NewFoo(args ...string) (*Foo, error) {
_ = args
return nil, nil
}

func (*Foo) Method(args ...string) (out, out1, out2 string) {return args[0], "", ""} //@codeaction("Method", "source.addTest", edit=method_varargsinput)
-- varargsinput/varargsinput_test.go --
package main_test
-- @function_varargsinput/varargsinput/varargsinput_test.go --
@@ -2 +2,34 @@
+
+import (
+ "testing"
+
+ "golang.org/lsptests/addtest/varargsinput"
+)
+
+func TestFunction(t *testing.T) {
+ tests := []struct {
+ name string // description of this test case
+ // Named input parameters for target function.
+ args []string
+ want string
+ want2 string
+ want3 string
+ }{
+ // TODO: Add test cases.
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, got2, got3 := main.Function(tt.args...)
+ // TODO: update the condition below to compare got with tt.want.
+ if true {
+ t.Errorf("Function() = %v, want %v", got, tt.want)
+ }
+ if true {
+ t.Errorf("Function() = %v, want %v", got2, tt.want2)
+ }
+ if true {
+ t.Errorf("Function() = %v, want %v", got3, tt.want3)
+ }
+ })
+ }
+}
-- @method_varargsinput/varargsinput/varargsinput_test.go --
@@ -2 +2,40 @@
+
+import (
+ "testing"
+
+ "golang.org/lsptests/addtest/varargsinput"
+)
+
+func TestFoo_Method(t *testing.T) {
+ tests := []struct {
+ name string // description of this test case
+ // Named input parameters for receiver constructor.
+ cargs []string
+ // Named input parameters for target function.
+ args []string
+ want string
+ want2 string
+ want3 string
+ }{
+ // TODO: Add test cases.
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ f, err := main.NewFoo(tt.cargs...)
+ if err != nil {
+ t.Fatalf("could not construct receiver type: %v", err)
+ }
+ got, got2, got3 := f.Method(tt.args...)
+ // TODO: update the condition below to compare got with tt.want.
+ if true {
+ t.Errorf("Method() = %v, want %v", got, tt.want)
+ }
+ if true {
+ t.Errorf("Method() = %v, want %v", got2, tt.want2)
+ }
+ if true {
+ t.Errorf("Method() = %v, want %v", got3, tt.want3)
+ }
+ })
+ }
+}
-- unusedvarargsinput/unusedvarargsinput.go --
package main

func Function(_ ...string) (out, out1, out2 string) {return "", "", ""} //@codeaction("Function", "source.addTest", edit=function_unusedvarargsinput)

type Foo struct {}

func NewFoo(_ ...string) (*Foo, error) {
return nil, nil
}

func (*Foo) Method(_ ...string) (out, out1, out2 string) {return "", "", ""} //@codeaction("Method", "source.addTest", edit=method_unusedvarargsinput)
-- unusedvarargsinput/unusedvarargsinput_test.go --
package main_test
-- @function_unusedvarargsinput/unusedvarargsinput/unusedvarargsinput_test.go --
@@ -2 +2,32 @@
+
+import (
+ "testing"
+
+ "golang.org/lsptests/addtest/unusedvarargsinput"
+)
+
+func TestFunction(t *testing.T) {
+ tests := []struct {
+ name string // description of this test case
+ want string
+ want2 string
+ want3 string
+ }{
+ // TODO: Add test cases.
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, got2, got3 := main.Function()
+ // TODO: update the condition below to compare got with tt.want.
+ if true {
+ t.Errorf("Function() = %v, want %v", got, tt.want)
+ }
+ if true {
+ t.Errorf("Function() = %v, want %v", got2, tt.want2)
+ }
+ if true {
+ t.Errorf("Function() = %v, want %v", got3, tt.want3)
+ }
+ })
+ }
+}
-- @method_unusedvarargsinput/unusedvarargsinput/unusedvarargsinput_test.go --
@@ -2 +2,36 @@
+
+import (
+ "testing"
+
+ "golang.org/lsptests/addtest/unusedvarargsinput"
+)
+
+func TestFoo_Method(t *testing.T) {
+ tests := []struct {
+ name string // description of this test case
+ want string
+ want2 string
+ want3 string
+ }{
+ // TODO: Add test cases.
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ f, err := main.NewFoo()
+ if err != nil {
+ t.Fatalf("could not construct receiver type: %v", err)
+ }
+ got, got2, got3 := f.Method()
+ // TODO: update the condition below to compare got with tt.want.
+ if true {
+ t.Errorf("Method() = %v, want %v", got, tt.want)
+ }
+ if true {
+ t.Errorf("Method() = %v, want %v", got2, tt.want2)
+ }
+ if true {
+ t.Errorf("Method() = %v, want %v", got3, tt.want3)
+ }
+ })
+ }
+}