Skip to content

Commit 6040683

Browse files
committed
apiendpoint.Mount: take options struct including middleware stack
Allow for an endpoint-specific middleware stack.
1 parent 4500740 commit 6040683

2 files changed

Lines changed: 29 additions & 6 deletions

File tree

apiendpoint/api_endpoint.go

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/jackc/pgx/v5/pgconn"
1818

1919
"github.com/riverqueue/apiframe/apierror"
20+
"github.com/riverqueue/apiframe/apimiddleware"
2021
"github.com/riverqueue/apiframe/internal/validate"
2122
)
2223

@@ -87,18 +88,40 @@ func (m *EndpointMeta) validate() {
8788
}
8889
}
8990

91+
type MountOpts struct {
92+
Logger *slog.Logger
93+
// MiddlewareStack is a stack of middleware that will be mounted in front of
94+
// the API endpoint handler. If not specified, no middleware will be used.
95+
MiddlewareStack *apimiddleware.MiddlewareStack
96+
}
97+
9098
// Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log
9199
// information about endpoint execution.
92-
func Mount[TReq any, TResp any](mux *http.ServeMux, logger *slog.Logger, apiEndpoint EndpointExecuteInterface[TReq, TResp]) EndpointInterface {
100+
func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteInterface[TReq, TResp], opts *MountOpts) EndpointInterface {
101+
if opts == nil {
102+
opts = &MountOpts{}
103+
}
104+
105+
logger := opts.Logger
106+
if logger == nil {
107+
logger = slog.Default()
108+
}
109+
93110
apiEndpoint.SetLogger(logger)
94111

95112
meta := apiEndpoint.Meta()
96113
meta.validate() // panic on problem
97114
apiEndpoint.SetMeta(meta)
98115

99-
mux.HandleFunc(meta.Pattern, func(w http.ResponseWriter, r *http.Request) {
100-
executeAPIEndpoint(w, r, logger, meta, apiEndpoint.Execute)
101-
})
116+
innerHandler := func(w http.ResponseWriter, r *http.Request) {
117+
executeAPIEndpoint(w, r, opts.Logger, meta, apiEndpoint.Execute)
118+
}
119+
120+
if opts.MiddlewareStack != nil {
121+
mux.Handle(meta.Pattern, opts.MiddlewareStack.Mount(http.HandlerFunc(innerHandler)))
122+
} else {
123+
mux.HandleFunc(meta.Pattern, innerHandler)
124+
}
102125

103126
return apiEndpoint
104127
}

apiendpoint/api_endpoint_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ func TestMountAndServe(t *testing.T) {
3636
mux = http.NewServeMux()
3737
)
3838

39-
Mount(mux, logger, &getEndpoint{})
40-
Mount(mux, logger, &postEndpoint{})
39+
Mount(mux, &getEndpoint{}, &MountOpts{Logger: logger})
40+
Mount(mux, &postEndpoint{}, &MountOpts{Logger: logger})
4141

4242
return mux, &testBundle{
4343
recorder: httptest.NewRecorder(),

0 commit comments

Comments
 (0)