diff --git a/pkg/generators/openapi.go b/pkg/generators/openapi.go index c5c009381..d42455a3a 100644 --- a/pkg/generators/openapi.go +++ b/pkg/generators/openapi.go @@ -957,6 +957,14 @@ func (g openAPITypeWriter) generateProperty(m *types.Member, parent *types.Type) if name == "" { return nil } + if listType, err := getSingleTagsValue(m.CommentLines, "listType"); err != nil { + return err + } else if listType != "" { + t := resolveAliasAndPtrType(m.Type) + if t.Kind != types.Slice && t.Kind != types.Array { + return fmt.Errorf("listType marker may only be used on slice or array fields, but was used on %v", m.Name) + } + } validationSchema, err := ParseCommentTags(m.Type, m.CommentLines, markerPrefix) if err != nil { return err diff --git a/pkg/generators/openapi_test.go b/pkg/generators/openapi_test.go index 55be4c557..ee63a3925 100644 --- a/pkg/generators/openapi_test.go +++ b/pkg/generators/openapi_test.go @@ -88,6 +88,78 @@ func assertEqual(t *testing.T, got, want string) { } } +func TestListTypeValidation(t *testing.T) { + tests := []struct { + name string + inputFile string + expectedError string + }{ + { + name: "listType on slice", + inputFile: ` + package foo + + type Blah struct { + // +listType=atomic + Slice []string + } + `, + expectedError: "", + }, + { + name: "listType on array", + inputFile: ` + package foo + + type Blah struct { + // +listType=atomic + Array [1]string + } + `, + expectedError: "", + }, + { + name: "listType on string (fail)", + inputFile: ` + package foo + + type Blah struct { + // +listType=atomic + String string + } + `, + expectedError: "listType marker may only be used on slice or array fields", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + packagestest.TestAll(t, func(t *testing.T, exporter packagestest.Exporter) { + exported := packagestest.Export(t, exporter, []packagestest.Module{{ + Name: "example.com/base", + Files: map[string]interface{}{ + "foo/foo.go": test.inputFile, + }, + }}) + defer exported.Cleanup() + + _, funcErr, _, _, _ := testOpenAPITypeWriter(t, exported.Config) + if test.expectedError == "" { + if funcErr != nil { + t.Errorf("Expected no error, got: %v", funcErr) + } + } else { + if funcErr == nil { + t.Errorf("Expected error containing %q, got nil", test.expectedError) + } else if !strings.Contains(funcErr.Error(), test.expectedError) { + t.Errorf("Expected error containing %q, got: %v", test.expectedError, funcErr) + } + } + }) + }) + } +} + func TestSimple(t *testing.T) { inputFile := ` package foo