Skip to content

Commit 85aad90

Browse files
authored
allow passing a custom validator instance (#9)
Rework the internal validation logic so that a customized validator can be provided. This allows for custom types and validations to be registered for use within the framework. Custom validators are passed as part of MountOpts.
1 parent 79df6b8 commit 85aad90

5 files changed

Lines changed: 63 additions & 37 deletions

File tree

apiendpoint/api_endpoint.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"net/http"
1515
"time"
1616

17+
"github.com/go-playground/validator/v10"
1718
"github.com/jackc/pgerrcode"
1819
"github.com/jackc/pgx/v5/pgconn"
1920

@@ -94,6 +95,9 @@ type MountOpts struct {
9495
// MiddlewareStack is a stack of middleware that will be mounted in front of
9596
// the API endpoint handler. If not specified, no middleware will be used.
9697
MiddlewareStack *apimiddleware.MiddlewareStack
98+
// Validator is the validator to use for this endpoint. If not specified,
99+
// the default validator will be used.
100+
Validator *validator.Validate
97101
}
98102

99103
// Mount mounts an endpoint to a Go http.ServeMux. The logger is used to log
@@ -108,14 +112,19 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteI
108112
logger = slog.Default()
109113
}
110114

115+
validator := opts.Validator
116+
if validator == nil {
117+
validator = validate.Default
118+
}
119+
111120
apiEndpoint.SetLogger(logger)
112121

113122
meta := apiEndpoint.Meta()
114123
meta.validate() // panic on problem
115124
apiEndpoint.SetMeta(meta)
116125

117126
innerHandler := func(w http.ResponseWriter, r *http.Request) {
118-
executeAPIEndpoint(w, r, opts.Logger, meta, apiEndpoint.Execute)
127+
executeAPIEndpoint(w, r, opts.Logger, meta, validator, apiEndpoint.Execute)
119128
}
120129

121130
if opts.MiddlewareStack != nil {
@@ -127,13 +136,10 @@ func Mount[TReq any, TResp any](mux *http.ServeMux, apiEndpoint EndpointExecuteI
127136
return apiEndpoint
128137
}
129138

130-
func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, execute func(ctx context.Context, req *TReq) (*TResp, error)) {
139+
func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Request, logger *slog.Logger, meta *EndpointMeta, validator *validator.Validate, execute func(ctx context.Context, req *TReq) (*TResp, error)) {
131140
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
132141
defer cancel()
133142

134-
// Run as much code as we can in a sub-function that can return an error.
135-
// This is more convenient to write, but is also safer because unlike when
136-
// writing errors to ResponseWriter, there's no danger of a missing return.
137143
err := func() error {
138144
var req TReq
139145
if r.Method != http.MethodGet {
@@ -161,8 +167,8 @@ func executeAPIEndpoint[TReq any, TResp any](w http.ResponseWriter, r *http.Requ
161167
}
162168
}
163169

164-
if err := validate.StructCtx(ctx, &req); err != nil {
165-
return apierror.NewBadRequest(validate.PublicFacingMessage(err))
170+
if err := validator.StructCtx(ctx, &req); err != nil {
171+
return apierror.NewBadRequest(validate.PublicFacingMessage(validator, err))
166172
}
167173

168174
resp, err := execute(ctx, &req)

apitest/apitest.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66

7+
"github.com/riverqueue/apiframe/apiendpoint"
78
"github.com/riverqueue/apiframe/apierror"
89
"github.com/riverqueue/apiframe/internal/validate"
910
)
@@ -21,19 +22,28 @@ import (
2122
// Sample invocation:
2223
//
2324
// endpoint := &testEndpoint{}
24-
// resp, err := apitest.InvokeHandler(ctx, endpoint.Execute, &testRequest{ReqField: "string"})
25+
// resp, err := apitest.InvokeHandler(ctx, endpoint.Execute, nil, &testRequest{ReqField: "string"})
2526
// require.NoError(t, err)
26-
func InvokeHandler[TReq any, TResp any](ctx context.Context, handler func(context.Context, *TReq) (*TResp, error), req *TReq) (*TResp, error) {
27-
if err := validate.StructCtx(ctx, req); err != nil {
28-
return nil, apierror.NewBadRequest(validate.PublicFacingMessage(err))
27+
func InvokeHandler[TReq any, TResp any](ctx context.Context, handler func(context.Context, *TReq) (*TResp, error), opts *apiendpoint.MountOpts, req *TReq) (*TResp, error) {
28+
if opts == nil {
29+
opts = &apiendpoint.MountOpts{}
30+
}
31+
32+
validator := opts.Validator
33+
if validator == nil {
34+
validator = validate.Default
35+
}
36+
37+
if err := validator.StructCtx(ctx, req); err != nil {
38+
return nil, apierror.NewBadRequest(validate.PublicFacingMessage(validator, err))
2939
}
3040

3141
resp, err := handler(ctx, req)
3242
if err != nil {
3343
return nil, err
3444
}
3545

36-
if err := validate.StructCtx(ctx, resp); err != nil {
46+
if err := validator.StructCtx(ctx, resp); err != nil {
3747
return nil, fmt.Errorf("apitest: error validating response API resource: %w", err)
3848
}
3949

apitest/apitest_test.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"context"
55
"testing"
66

7+
"github.com/go-playground/validator/v10"
78
"github.com/stretchr/testify/require"
89

10+
"github.com/riverqueue/apiframe/apiendpoint"
911
"github.com/riverqueue/apiframe/apierror"
1012
)
1113

@@ -28,15 +30,15 @@ func TestInvokeHandler(t *testing.T) {
2830
t.Run("Success", func(t *testing.T) {
2931
t.Parallel()
3032

31-
resp, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: "string"})
33+
resp, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: "string"})
3234
require.NoError(t, err)
3335
require.Equal(t, &testResponse{RequiredRespField: "response value"}, resp)
3436
})
3537

3638
t.Run("ValidatesRequest", func(t *testing.T) {
3739
t.Parallel()
3840

39-
_, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: ""})
41+
_, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: ""})
4042
require.Equal(t, apierror.NewBadRequestf("Field `req_field` is required."), err)
4143
})
4244

@@ -47,7 +49,18 @@ func TestInvokeHandler(t *testing.T) {
4749
return &testResponse{RequiredRespField: ""}, nil
4850
}
4951

50-
_, err := InvokeHandler(ctx, handler, &testRequest{RequiredReqField: "string"})
52+
_, err := InvokeHandler(ctx, handler, nil, &testRequest{RequiredReqField: "string"})
5153
require.EqualError(t, err, "apitest: error validating response API resource: Key: 'testResponse.resp_field' Error:Field validation for 'resp_field' failed on the 'required' tag")
5254
})
55+
56+
t.Run("CustomValidator", func(t *testing.T) {
57+
t.Parallel()
58+
59+
customValidator := validator.New()
60+
opts := &apiendpoint.MountOpts{Validator: customValidator}
61+
62+
resp, err := InvokeHandler(ctx, handler, opts, &testRequest{RequiredReqField: "string"})
63+
require.NoError(t, err)
64+
require.Equal(t, &testResponse{RequiredRespField: "response value"}, resp)
65+
})
5366
}

internal/validate/validate.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,27 @@
44
package validate
55

66
import (
7-
"context"
87
"fmt"
98
"reflect"
109
"strings"
1110

1211
"github.com/go-playground/validator/v10"
1312
)
1413

15-
// WithRequiredStructEnabled can be removed once validator/v11 is released.
16-
var validate = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals
14+
// Default is the package's default validator instance. WithRequiredStructEnabled
15+
// can be removed once validator/v11 is released.
16+
var Default = validator.New(validator.WithRequiredStructEnabled()) //nolint:gochecknoglobals
1717

1818
func init() { //nolint:gochecknoinits
19-
validate.RegisterTagNameFunc(preferPublicName)
19+
Default.RegisterTagNameFunc(preferPublicName)
2020
}
2121

2222
// PublicFacingMessage builds a complete error message from a validator error
2323
// that's suitable for public-facing consumption.
2424
//
2525
// I only added a few possible validations to start. We'll probably need to add
2626
// more as we go and expand our usage.
27-
func PublicFacingMessage(validatorErr error) string {
27+
func PublicFacingMessage(v *validator.Validate, validatorErr error) string {
2828
var message string
2929

3030
//nolint:errorlint
@@ -97,13 +97,6 @@ func PublicFacingMessage(validatorErr error) string {
9797
return strings.TrimSpace(message)
9898
}
9999

100-
// StructCtx validates a structs exposed fields, and automatically validates
101-
// nested structs, unless otherwise specified and also allows passing of
102-
// context.Context for contextual validation information.
103-
func StructCtx(ctx context.Context, s any) error {
104-
return validate.StructCtx(ctx, s)
105-
}
106-
107100
// preferPublicName is a validator tag naming function that uses public names
108101
// like a field's JSON tag instead of actual field names in structs.
109102
// This is important because we sent these back as user-facing errors (and the

internal/validate/validate_test.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@ import (
44
"reflect"
55
"testing"
66

7+
"github.com/go-playground/validator/v10"
78
"github.com/stretchr/testify/require"
89
)
910

1011
func TestFromValidator(t *testing.T) {
1112
t.Parallel()
1213

14+
validator := validator.New(validator.WithRequiredStructEnabled())
15+
validator.RegisterTagNameFunc(preferPublicName)
16+
1317
// Fields have JSON tags so we can verify those are used over the
1418
// property name.
1519
type TestStruct struct {
@@ -43,71 +47,71 @@ func TestFromValidator(t *testing.T) {
4347

4448
testStruct := validTestStruct()
4549
testStruct.MaxInt = 1
46-
require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validate.Struct(testStruct)))
50+
require.Equal(t, "Field `max_int` must be less than or equal to 0.", PublicFacingMessage(validator, validator.Struct(testStruct)))
4751
})
4852

4953
t.Run("MaxSlice", func(t *testing.T) {
5054
t.Parallel()
5155

5256
testStruct := validTestStruct()
5357
testStruct.MaxSlice = []string{"1"}
54-
require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validate.Struct(testStruct)))
58+
require.Equal(t, "Field `max_slice` must contain at most 0 element(s).", PublicFacingMessage(validator, validator.Struct(testStruct)))
5559
})
5660

5761
t.Run("MaxString", func(t *testing.T) {
5862
t.Parallel()
5963

6064
testStruct := validTestStruct()
6165
testStruct.MaxString = "value"
62-
require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validate.Struct(testStruct)))
66+
require.Equal(t, "Field `max_string` must be at most 0 character(s) long.", PublicFacingMessage(validator, validator.Struct(testStruct)))
6367
})
6468

6569
t.Run("MinInt", func(t *testing.T) {
6670
t.Parallel()
6771

6872
testStruct := validTestStruct()
6973
testStruct.MinInt = 0
70-
require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validate.Struct(testStruct)))
74+
require.Equal(t, "Field `min_int` must be greater or equal to 1.", PublicFacingMessage(validator, validator.Struct(testStruct)))
7175
})
7276

7377
t.Run("MinSlice", func(t *testing.T) {
7478
t.Parallel()
7579

7680
testStruct := validTestStruct()
7781
testStruct.MinSlice = nil
78-
require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validate.Struct(testStruct)))
82+
require.Equal(t, "Field `min_slice` must contain at least 1 element(s).", PublicFacingMessage(validator, validator.Struct(testStruct)))
7983
})
8084

8185
t.Run("MinString", func(t *testing.T) {
8286
t.Parallel()
8387

8488
testStruct := validTestStruct()
8589
testStruct.MinString = ""
86-
require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validate.Struct(testStruct)))
90+
require.Equal(t, "Field `min_string` must be at least 1 character(s) long.", PublicFacingMessage(validator, validator.Struct(testStruct)))
8791
})
8892

8993
t.Run("OneOf", func(t *testing.T) {
9094
t.Parallel()
9195

9296
testStruct := validTestStruct()
9397
testStruct.OneOf = "red"
94-
require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validate.Struct(testStruct)))
98+
require.Equal(t, "Field `one_of` should be one of the following values: blue green.", PublicFacingMessage(validator, validator.Struct(testStruct)))
9599
})
96100

97101
t.Run("Required", func(t *testing.T) {
98102
t.Parallel()
99103

100104
testStruct := validTestStruct()
101105
testStruct.Required = ""
102-
require.Equal(t, "Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct)))
106+
require.Equal(t, "Field `required` is required.", PublicFacingMessage(validator, validator.Struct(testStruct)))
103107
})
104108

105109
t.Run("Unsupported", func(t *testing.T) {
106110
t.Parallel()
107111

108112
testStruct := validTestStruct()
109113
testStruct.Unsupported = "abc"
110-
require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validate.Struct(testStruct)))
114+
require.Equal(t, "Validation on field `unsupported` failed on the `e164` tag.", PublicFacingMessage(validator, validator.Struct(testStruct)))
111115
})
112116

113117
t.Run("MultipleErrors", func(t *testing.T) {
@@ -116,7 +120,7 @@ func TestFromValidator(t *testing.T) {
116120
testStruct := validTestStruct()
117121
testStruct.MinInt = 0
118122
testStruct.Required = ""
119-
require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validate.Struct(testStruct)))
123+
require.Equal(t, "Field `min_int` must be greater or equal to 1. Field `required` is required.", PublicFacingMessage(validator, validator.Struct(testStruct)))
120124
})
121125
}
122126

0 commit comments

Comments
 (0)