diff --git a/pkg/ffapi/apiserver.go b/pkg/ffapi/apiserver.go index c800436..3c8b226 100644 --- a/pkg/ffapi/apiserver.go +++ b/pkg/ffapi/apiserver.go @@ -22,9 +22,12 @@ import ( "io" "net" "net/http" + "reflect" "strings" "time" + "github.com/getkin/kin-openapi/openapi3" + "github.com/getkin/kin-openapi/openapi3gen" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -71,26 +74,29 @@ type apiServer[T any] struct { monitoringPublicURL string mux *mux.Router oah *OpenAPIHandlerFactory + baseSwaggerGenOptions SwaggerGenOptions APIServerOptions[T] } type APIServerOptions[T any] struct { - MetricsRegistry metric.MetricsRegistry - MetricsSubsystemName string - Routes []*Route // move to use VersionedAPIs for support of Tags and ExternalDocs - VersionedAPIs *VersionedAPIs - MonitoringRoutes []*Route - EnrichRequest func(r *APIRequest) (T, error) - Description string - APIConfig config.Section - MonitoringConfig config.Section - CORSConfig config.Section - FavIcon16 []byte - FavIcon32 []byte - PanicOnMissingDescription bool - SupportFieldRedaction bool - HandleYAML bool + MetricsRegistry metric.MetricsRegistry + MetricsSubsystemName string + Routes []*Route // move to use VersionedAPIs for support of Tags and ExternalDocs + VersionedAPIs *VersionedAPIs + MonitoringRoutes []*Route + EnrichRequest func(r *APIRequest) (T, error) + Description string + APIConfig config.Section + MonitoringConfig config.Section + CORSConfig config.Section + FavIcon16 []byte + FavIcon32 []byte + PanicOnMissingDescription bool + SupportFieldRedaction bool + HandleYAML bool + SwaggerExportComponentOpts *openapi3gen.ExportComponentSchemasOptions + SwaggerAdditionalSchemaCustomizer func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error } type VersionedAPIs struct { @@ -130,6 +136,15 @@ func NewAPIServer[T any](ctx context.Context, options APIServerOptions[T]) APISe apiDynamicPublicURLHeader: options.APIConfig.GetString(ConfAPIDynamicPublicURLHeader), APIServerOptions: options, started: make(chan struct{}), + baseSwaggerGenOptions: SwaggerGenOptions{ + Title: options.Description, + Version: "1.0", + PanicOnMissingDescription: options.PanicOnMissingDescription, + DefaultRequestTimeout: options.APIConfig.GetDuration(ConfAPIRequestTimeout), + SupportFieldRedaction: options.SupportFieldRedaction, + ExportComponentOpts: options.SwaggerExportComponentOpts, + AdditionalSchemaCustomizer: options.SwaggerAdditionalSchemaCustomizer, + }, } if as.FavIcon16 == nil { as.FavIcon16 = ffLogo16 @@ -295,14 +310,8 @@ func (as *apiServer[T]) createMuxRouter(ctx context.Context) (*mux.Router, error hf := as.handlerFactory(logrus.InfoLevel) as.oah = &OpenAPIHandlerFactory{ - BaseSwaggerGenOptions: SwaggerGenOptions{ - Title: as.Description, - Version: "1.0", - PanicOnMissingDescription: as.PanicOnMissingDescription, - DefaultRequestTimeout: as.requestTimeout, - SupportFieldRedaction: as.SupportFieldRedaction, - }, - StaticPublicURL: as.apiPublicURL, // this is most likely not yet set, we'll ensure its set later on + BaseSwaggerGenOptions: as.baseSwaggerGenOptions, + StaticPublicURL: as.apiPublicURL, // this is most likely not yet set, we'll ensure its set later on } if as.monitoringEnabled { diff --git a/pkg/ffapi/openapi3.go b/pkg/ffapi/openapi3.go index b2fdfc3..b007977 100644 --- a/pkg/ffapi/openapi3.go +++ b/pkg/ffapi/openapi3.go @@ -66,8 +66,10 @@ type SwaggerGenOptions struct { RouteCustomizations func(ctx context.Context, sg *SwaggerGen, route *Route, op *openapi3.Operation) // OpenAPI 3.0.x specific options - Tags openapi3.Tags - ExternalDocs *openapi3.ExternalDocs + Tags openapi3.Tags + ExternalDocs *openapi3.ExternalDocs + ExportComponentOpts *openapi3gen.ExportComponentSchemasOptions + AdditionalSchemaCustomizer func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error } type BaseURLVariable struct { @@ -82,12 +84,16 @@ var ( ) type SwaggerGen struct { - options *SwaggerGenOptions + options *SwaggerGenOptions + ffTagHandler *OpenAPITagHandler } func NewSwaggerGen(options *SwaggerGenOptions) *SwaggerGen { return &SwaggerGen{ options: options, + ffTagHandler: &OpenAPITagHandler{ + PanicOnMissingDescription: options.PanicOnMissingDescription, + }, } } @@ -164,12 +170,12 @@ func (sg *SwaggerGen) initInput(op *openapi3.Operation) { } } -func (sg *SwaggerGen) isTrue(str string) bool { +func isTrue(str string) bool { return strings.EqualFold(str, "true") } func (sg *SwaggerGen) ffInputTagHandler(ctx context.Context, route *Route, name string, tag reflect.StructTag, schema *openapi3.Schema) error { - if sg.isTrue(tag.Get("ffexcludeinput")) { + if isTrue(tag.Get("ffexcludeinput")) { return &openapi3gen.ExcludeSchemaSentinel{} } if taggedRoutes, ok := tag.Lookup("ffexcludeinput"); ok { @@ -179,17 +185,17 @@ func (sg *SwaggerGen) ffInputTagHandler(ctx context.Context, route *Route, name } } } - return sg.ffTagHandler(ctx, route, name, tag, schema) + return sg.ffTagHandler.HandleFFTags(ctx, route, name, tag, schema) } func (sg *SwaggerGen) ffOutputTagHandler(ctx context.Context, route *Route, name string, tag reflect.StructTag, schema *openapi3.Schema) error { - if sg.isTrue(tag.Get("ffexcludeoutput")) { + if isTrue(tag.Get("ffexcludeoutput")) { return &openapi3gen.ExcludeSchemaSentinel{} } - return sg.ffTagHandler(ctx, route, name, tag, schema) + return sg.ffTagHandler.HandleFFTags(ctx, route, name, tag, schema) } -func (sg *SwaggerGen) applyFFExtensionsTag(ctx context.Context, schema *openapi3.Schema, tag string) error { +func applyFFExtensionsTag(ctx context.Context, schema *openapi3.Schema, tag string) error { if tag == "" { return nil } @@ -211,16 +217,20 @@ func (sg *SwaggerGen) applyFFExtensionsTag(ctx context.Context, schema *openapi3 return nil } -func (sg *SwaggerGen) ffTagHandler(ctx context.Context, route *Route, name string, tag reflect.StructTag, schema *openapi3.Schema) error { +type OpenAPITagHandler struct { + PanicOnMissingDescription bool +} + +func (th *OpenAPITagHandler) HandleFFTags(ctx context.Context, route *Route, name string, tag reflect.StructTag, schema *openapi3.Schema) error { if ffEnum := tag.Get("ffenum"); ffEnum != "" { schema.Enum = fftypes.FFEnumValues(ffEnum) } if ffExtensions := tag.Get("ffschemaext"); ffExtensions != "" { - if err := sg.applyFFExtensionsTag(ctx, schema, ffExtensions); err != nil { + if err := applyFFExtensionsTag(ctx, schema, ffExtensions); err != nil { return err } } - if sg.isTrue(tag.Get("ffexclude")) { + if isTrue(tag.Get("ffexclude")) { return &openapi3gen.ExcludeSchemaSentinel{} } if taggedRoutes, ok := tag.Lookup("ffexclude"); ok { @@ -234,11 +244,11 @@ func (sg *SwaggerGen) ffTagHandler(ctx context.Context, route *Route, name strin if structName, ok := tag.Lookup("ffstruct"); ok { key := fmt.Sprintf("%s.%s", structName, name) description := i18n.Expand(ctx, i18n.MessageKey(key)) - if description == key && sg.options.PanicOnMissingDescription { + if description == key && th.PanicOnMissingDescription { return i18n.NewError(ctx, i18n.MsgFieldDescriptionMissing, key, route.Name) } schema.Description = description - } else if sg.options.PanicOnMissingDescription { + } else if th.PanicOnMissingDescription { return i18n.NewError(ctx, i18n.MsgFFStructTagMissing, name, route.Name) } } @@ -269,23 +279,35 @@ func (sg *SwaggerGen) addCustomType(t reflect.Type, schema *openapi3.Schema) { } } -func (sg *SwaggerGen) addInput(ctx context.Context, doc *openapi3.T, route *Route, op *openapi3.Operation) { - var schemaRef *openapi3.SchemaRef - var err error +func (sg *SwaggerGen) schemaRefOptions(ctx context.Context, route *Route, tagHandler func(ctx context.Context, route *Route, name string, tag reflect.StructTag, schema *openapi3.Schema) error) []openapi3gen.Option { schemaCustomizer := func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { sg.addCustomType(t, schema) - return sg.ffInputTagHandler(ctx, route, name, tag, schema) + err := tagHandler(ctx, route, name, tag, schema) + if err == nil && sg.options != nil && sg.options.AdditionalSchemaCustomizer != nil { + err = sg.options.AdditionalSchemaCustomizer(name, t, tag, schema) + } + return err + } + options := []openapi3gen.Option{openapi3gen.SchemaCustomizer(schemaCustomizer)} + if sg.options != nil && sg.options.ExportComponentOpts != nil { + options = append(options, openapi3gen.CreateComponentSchemas(*sg.options.ExportComponentOpts)) } + return options +} + +func (sg *SwaggerGen) addInput(ctx context.Context, doc *openapi3.T, route *Route, op *openapi3.Operation) { + var schemaRef *openapi3.SchemaRef + var err error switch { case route.JSONInputSchema != nil: schemaRef, err = route.JSONInputSchema(ctx, func(obj interface{}) (*openapi3.SchemaRef, error) { - return openapi3gen.NewSchemaRefForValue(obj, doc.Components.Schemas, openapi3gen.SchemaCustomizer(schemaCustomizer)) + return openapi3gen.NewSchemaRefForValue(obj, doc.Components.Schemas, sg.schemaRefOptions(ctx, route, sg.ffInputTagHandler)...) }) if err != nil { panic(fmt.Sprintf("invalid schema: %s", err)) } case route.JSONInputValue != nil: - schemaRef, err = openapi3gen.NewSchemaRefForValue(route.JSONInputValue(), doc.Components.Schemas, openapi3gen.SchemaCustomizer(schemaCustomizer)) + schemaRef, err = openapi3gen.NewSchemaRefForValue(route.JSONInputValue(), doc.Components.Schemas, sg.schemaRefOptions(ctx, route, sg.ffInputTagHandler)...) if err != nil { panic(fmt.Sprintf("invalid schema: %s", err)) } @@ -351,11 +373,7 @@ func (sg *SwaggerGen) addURLEncodedFormInput(ctx context.Context, op *openapi3.O // CheckObjectDocumented lets unit tests on individual structures validate that all the ffstruct tags are set, // without having to build their own swagger. func CheckObjectDocumented(example interface{}) { - (&SwaggerGen{ - options: &SwaggerGenOptions{ - PanicOnMissingDescription: true, - }, - }).Generate(context.Background(), []*Route{{ + (NewSwaggerGen(&SwaggerGenOptions{PanicOnMissingDescription: true})).Generate(context.Background(), []*Route{{ Name: "doctest", Path: "doctest", Method: http.MethodPost, @@ -368,14 +386,10 @@ func (sg *SwaggerGen) addOutput(ctx context.Context, doc *openapi3.T, route *Rou var schemaRef *openapi3.SchemaRef var err error s := i18n.Expand(ctx, i18n.APISuccessResponse) - schemaCustomizer := func(name string, t reflect.Type, tag reflect.StructTag, schema *openapi3.Schema) error { - sg.addCustomType(t, schema) - return sg.ffOutputTagHandler(ctx, route, name, tag, schema) - } switch { case route.JSONOutputSchema != nil: schemaRef, err = route.JSONOutputSchema(ctx, func(obj interface{}) (*openapi3.SchemaRef, error) { - return openapi3gen.NewSchemaRefForValue(obj, doc.Components.Schemas, openapi3gen.SchemaCustomizer(schemaCustomizer)) + return openapi3gen.NewSchemaRefForValue(obj, doc.Components.Schemas, sg.schemaRefOptions(ctx, route, sg.ffOutputTagHandler)...) }) if err != nil { panic(fmt.Sprintf("invalid schema: %s", err)) @@ -383,7 +397,7 @@ func (sg *SwaggerGen) addOutput(ctx context.Context, doc *openapi3.T, route *Rou case route.JSONOutputValue != nil: outputValue := route.JSONOutputValue() if outputValue != nil { - schemaRef, err = openapi3gen.NewSchemaRefForValue(outputValue, doc.Components.Schemas, openapi3gen.SchemaCustomizer(schemaCustomizer)) + schemaRef, err = openapi3gen.NewSchemaRefForValue(outputValue, doc.Components.Schemas, sg.schemaRefOptions(ctx, route, sg.ffOutputTagHandler)...) if err != nil { panic(fmt.Sprintf("invalid schema: %s", err)) }