From cd79cef4ddc3e1e3709d8dd0dc8295708b38420e Mon Sep 17 00:00:00 2001 From: cymoo Date: Thu, 7 May 2026 23:55:53 +0800 Subject: [PATCH 1/2] fix: resolve correctness bugs and improve robustness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - fix(path): replace concrete type switch with reflect.Kind so named type aliases (e.g. type UserID int) work correctly with Path[T] - fix(path): normalize catch-all pattern keys by stripping trailing ... so /files/{path...} extracts via r.PathValue("path") as expected - fix(result): write error status and JSON body code consistently when Result.Err is set; Result.Code is now used as the explicit override - fix(headers): use setHeaderIfMissing so a Content-Type set in Result.Headers is not overwritten by automatic response handling - fix(handler): validate unsupported parameter types at H() registration time instead of panicking on the first request - fix(handler): allow Responder as a valid single-value interface return; allow http.Handler/io.Reader/Responder as first interface return in two-value signatures - fix(writer): add Unwrap() to ResponseWriter for net/http compatibility with optional-interface discovery (e.g. http.NewResponseController) - fix(config): lock mutex inside Initialize's sync.Once to be consistent with Configure/Reset; use logger() instead of log.Println for 5xx logs - fix(validation): extend default tag name lookup order to json→schema→form so schema-tagged query/form fields appear correctly in error messages - test: add targeted tests for each of the above issues Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/copilot-instructions.md | 39 +++++++ README.md | 2 +- mint.go | 179 ++++++++++++++++++++------------ mint_test.go | 133 +++++++++++++++++++++++- 4 files changed, 281 insertions(+), 72 deletions(-) create mode 100644 .github/copilot-instructions.md diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..e87add8 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,39 @@ +# Copilot instructions for Mint + +## Build, test, and run commands + +- Full test suite: `go test ./...` +- Root package tests: `go test .` +- Single test: `go test -run '^TestJSONExtractor$' .` +- Single subtest: `go test -run 'TestJSONExtractor/valid_json' .` +- Run the demo server: `go run ./_examples` + +## High-level architecture + +Mint is a small Go module (`github.com/cymoo/mint`) whose public API lives in `mint.go` under package name `m`. It wraps standard `net/http`; applications still use `http.NewServeMux()` and register routes with `mux.HandleFunc(pattern, m.H(handler))`. + +`H(fn any)` is the central adapter. At registration time it uses reflection to validate the handler signature and cache parameter types. At request time it builds arguments from supported inputs, calls the handler, and writes the returned value through the framework response pipeline. + +Request data is injected through extractor types that implement `Extractor`: + +- `JSON[T]` reads and unmarshals the request body, then validates the target. +- `Query[T]` decodes `r.URL.Query()` through the configured `gorilla/schema.Decoder`, then validates. +- `Form[T]` calls `r.ParseForm()`, decodes `r.Form`, then validates. +- `Path[T]` also implements `KeySetter`; `H` derives path variable names from `r.Pattern` and assigns them to `Path` parameters in handler parameter order before `Extract` reads `r.PathValue`. + +Responses are normalized in `handleOneResult`, `handleTwoResults`, `handleCommonTypes`, and `handleResult`. Special return types include `StatusCode`, `HTML`, `[]byte`, `io.Reader`, `http.Handler`, `Responder`, `Result[T]`, `error`, and `(T, error)`; other values are JSON-encoded. + +Configuration is global and thread-safe through `Initialize`, `Configure`, and `Reset`. Defaults include a schema decoder with `IgnoreUnknownKeys(true)`, validation enabled with `go-playground/validator`, `json.Marshal`/`json.Unmarshal`, and `log.Default()`. Tests that mutate configuration should call `Reset()`. + +Error handling flows through `handleError` and `toHTTPError`. Prefer returning `HTTPError` for explicit HTTP status/error payloads and `ExtractError` from extractors for request extraction failures. Generic errors are converted by message heuristics such as "not found" -> 404. + +## Key conventions + +- Handler functions passed to `H` may return zero, one, or two values only. Two-value returns must be `(non-interface, error)`, and the first return value cannot be `Result[T]`. +- Handler parameters are limited to built-in/custom extractors, `http.ResponseWriter`, and `*http.Request`; unsupported parameter types panic during request handling. +- Custom extractors should define `Extract(*http.Request) error` on a pointer receiver so `reflect.PointerTo(paramType).Implements(extractorType)` detects them when handlers accept the non-pointer extractor value. +- Use `Result[T]`, `OK[T]`, and `Err[T]` when a handler needs custom status codes, headers, or a typed error result. +- Use `json` tags for JSON bodies and validation field names. Use `schema` tags for query/form decoding; the default validator also uses `schema` tags when choosing validation field names. +- Path parameters depend on Go 1.23 `net/http` route patterns and tests should set both `req.Pattern` and `req.SetPathValue(...)`; `createRequestWithPattern` in `mint_test.go` is the existing helper. +- Keep tests table/subtest-oriented with `testing`, `httptest`, and helpers like `parseJSONResponse`; test files are in package `m`, so they can exercise unexported helpers directly. +- The `_examples` directory is a runnable demo, but it is ignored by `go test ./...` because its name starts with an underscore. diff --git a/README.md b/README.md index 140708f..842cd4d 100644 --- a/README.md +++ b/README.md @@ -574,7 +574,7 @@ If you don't call `Initialize()` or `Configure()`, Mint uses sensible defaults: { SchemaDecoder: schema.NewDecoder() with IgnoreUnknownKeys(true), EnableValidation: true, - Validator: validator with JSON/form tag support, + Validator: validator with JSON/schema/form tag support, Logger: log.Default(), JSONMarshalFunc: json.Marshal, JSONUnmarshalFunc: json.Unmarshal, diff --git a/mint.go b/mint.go index 43382e6..d969678 100644 --- a/mint.go +++ b/mint.go @@ -125,21 +125,18 @@ func newDefaultSchemaDecoder() *schema.Decoder { // newDefaultValidator creates a validator with sensible defaults func newDefaultValidator() *validator.Validate { v := validator.New() - // Use json tag as field name for validation errors + // Use request-facing tags as field names for validation errors. v.RegisterTagNameFunc(func(fld reflect.StructField) string { - name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] - if name == "-" { - return "" - } - if name != "" { - return name - } - // Fallback to form tag - name = strings.SplitN(fld.Tag.Get("form"), ",", 2)[0] - if name == "-" { - return "" + for _, tag := range []string{"json", "schema", "form"} { + name := strings.SplitN(fld.Tag.Get(tag), ",", 2)[0] + if name == "-" { + return "" + } + if name != "" { + return name + } } - return name + return fld.Name }) return v } @@ -163,6 +160,8 @@ func Initialize(opts ...Option) { for _, opt := range opts { opt(cfg) } + global.mu.Lock() + defer global.mu.Unlock() global.config = cfg }) } @@ -276,6 +275,8 @@ var ( readerType = reflect.TypeOf((*io.Reader)(nil)).Elem() handlerType = reflect.TypeOf((*http.Handler)(nil)).Elem() + responderType = reflect.TypeOf((*Responder)(nil)).Elem() + resultMarkerType = reflect.TypeOf((*resultMarker)(nil)).Elem() responseWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem() httpRequestType = reflect.TypeOf((*http.Request)(nil)) ) @@ -429,50 +430,41 @@ func (p *Path[T]) Extract(r *http.Request) error { return NewMissingPathError(p.Key) } - switch ptr := any(&p.Value).(type) { - case *string: - *ptr = pv - case *int: - if val, err := strconv.Atoi(pv); err != nil { - return NewPathConversionError(p.Key, pv, "int", err) - } else { - *ptr = val - } - case *int64: - if val, err := strconv.ParseInt(pv, 10, 64); err != nil { - return NewPathConversionError(p.Key, pv, "int64", err) - } else { - *ptr = val - } - case *uint: - if val, err := strconv.ParseUint(pv, 10, 0); err != nil { - return NewPathConversionError(p.Key, pv, "uint", err) - } else { - *ptr = uint(val) + val := reflect.ValueOf(&p.Value).Elem() + targetType := val.Type().String() + + switch val.Kind() { + case reflect.String: + val.SetString(pv) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + parsed, err := strconv.ParseInt(pv, 10, val.Type().Bits()) + if err != nil { + return NewPathConversionError(p.Key, pv, targetType, err) } - case *uint64: - if val, err := strconv.ParseUint(pv, 10, 64); err != nil { - return NewPathConversionError(p.Key, pv, "uint64", err) - } else { - *ptr = val + val.SetInt(parsed) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + parsed, err := strconv.ParseUint(pv, 10, val.Type().Bits()) + if err != nil { + return NewPathConversionError(p.Key, pv, targetType, err) } - case *float64: - if val, err := strconv.ParseFloat(pv, 64); err != nil { - return NewPathConversionError(p.Key, pv, "float64", err) - } else { - *ptr = val + val.SetUint(parsed) + case reflect.Float32, reflect.Float64: + parsed, err := strconv.ParseFloat(pv, val.Type().Bits()) + if err != nil { + return NewPathConversionError(p.Key, pv, targetType, err) } - case *bool: - if val, err := strconv.ParseBool(pv); err != nil { - return NewPathConversionError(p.Key, pv, "bool", err) - } else { - *ptr = val + val.SetFloat(parsed) + case reflect.Bool: + parsed, err := strconv.ParseBool(pv) + if err != nil { + return NewPathConversionError(p.Key, pv, targetType, err) } + val.SetBool(parsed) default: return &ExtractError{ Type: "unsupported_type", Field: p.Key, - Message: fmt.Sprintf("Unsupported path parameter type: %T", &p.Value), + Message: fmt.Sprintf("Unsupported path parameter type: %s", targetType), } } return nil @@ -520,12 +512,20 @@ func (rw *ResponseWriter) Write(b []byte) (int, error) { return rw.ResponseWriter.Write(b) } +func (rw *ResponseWriter) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} + type resultMarker interface { isResultType() bool toResult() Result[any] } func H(fn any) http.HandlerFunc { + if fn == nil { + log.Panic("H: handler must be a function, got ") + } + fnVal := reflect.ValueOf(fn) fnType := fnVal.Type() @@ -536,6 +536,9 @@ func H(fn any) http.HandlerFunc { paramTypes := make([]reflect.Type, fnType.NumIn()) for i := 0; i < fnType.NumIn(); i++ { paramTypes[i] = fnType.In(i) + if !isSupportedParamType(paramTypes[i]) { + log.Panicf("H: unsupported parameter type %s", paramTypes[i].String()) + } } numOut := fnType.NumOut() @@ -546,8 +549,8 @@ func H(fn any) http.HandlerFunc { if numOut == 1 { rt := fnType.Out(0) if rt.Kind() == reflect.Interface { - if !rt.Implements(errorType) && !rt.Implements(handlerType) && !rt.Implements(readerType) { - log.Panic("H: interface return type must implement error, http.Handler or io.Reader") + if !isSupportedSingleInterfaceReturn(rt) { + log.Panic("H: interface return type must implement error, http.Handler, io.Reader or Responder") } } } @@ -556,10 +559,10 @@ func H(fn any) http.HandlerFunc { rt1 := fnType.Out(0) rt2 := fnType.Out(1) - if rt1.Kind() == reflect.Interface { - log.Panic("H: first return value cannot be an interface when returning two values") + if rt1.Kind() == reflect.Interface && !isSupportedDataInterfaceReturn(rt1) { + log.Panic("H: first return interface must implement http.Handler, io.Reader or Responder") } - if rt1.Implements(reflect.TypeOf((*resultMarker)(nil)).Elem()) { + if rt1.Implements(resultMarkerType) { log.Panicf("H: first return value cannot be Result when returning two values") } @@ -578,7 +581,7 @@ func H(fn any) http.HandlerFunc { for i, paramType := range paramTypes { switch { - case reflect.PointerTo(paramType).Implements(extractorType): + case isExtractorParam(paramType): paramVal := reflect.New(paramType).Elem() extractor := paramVal.Addr().Interface().(Extractor) @@ -649,6 +652,29 @@ func H(fn any) http.HandlerFunc { } } +func isExtractorParam(paramType reflect.Type) bool { + return reflect.PointerTo(paramType).Implements(extractorType) +} + +func isSupportedParamType(paramType reflect.Type) bool { + return isExtractorParam(paramType) || + paramType == httpRequestType || + (paramType.Kind() == reflect.Interface && paramType.Implements(responseWriterType)) +} + +func isSupportedSingleInterfaceReturn(rt reflect.Type) bool { + return rt.Implements(errorType) || + rt.Implements(handlerType) || + rt.Implements(readerType) || + rt.Implements(responderType) +} + +func isSupportedDataInterfaceReturn(rt reflect.Type) bool { + return rt.Implements(handlerType) || + rt.Implements(readerType) || + rt.Implements(responderType) +} + func WriteHeaders(w http.ResponseWriter, headers http.Header) { for key, values := range headers { for _, value := range values { @@ -657,6 +683,12 @@ func WriteHeaders(w http.ResponseWriter, headers http.Header) { } } +func setHeaderIfMissing(w http.ResponseWriter, key, value string) { + if w.Header().Get(key) == "" { + w.Header().Set(key, value) + } +} + func NewBodyReadError(err error) error { return &ExtractError{ Type: ErrTypeBodyRead, @@ -827,25 +859,25 @@ func handleCommonTypes(w http.ResponseWriter, data any) error { switch v := data.(type) { case string: - w.Header().Set("Content-Type", "text/plain; charset=utf-8") + setHeaderIfMissing(w, "Content-Type", "text/plain; charset=utf-8") _, err := fmt.Fprint(w, v) return err case StatusCode: w.WriteHeader(int(v)) return nil case []byte: - w.Header().Set("Content-Type", "application/octet-stream") + setHeaderIfMissing(w, "Content-Type", "application/octet-stream") _, err := w.Write(v) return err case HTML, template.HTML: - w.Header().Set("Content-Type", "text/html; charset=utf-8") + setHeaderIfMissing(w, "Content-Type", "text/html; charset=utf-8") _, err := fmt.Fprint(w, v) return err case io.Reader: _, err := io.Copy(w, v) return err default: - w.Header().Set("Content-Type", "application/json; charset=utf-8") + setHeaderIfMissing(w, "Content-Type", "application/json; charset=utf-8") return jsonEncode(w, data) } } @@ -855,18 +887,22 @@ func handleResult(w http.ResponseWriter, result Result[any]) error { WriteHeaders(w, result.Headers) } - if result.Code != 0 { - w.WriteHeader(result.Code) + if result.Err != nil { + return handleErrorWithCode(w, result.Err, result.Code) } - if result.Err != nil { - return handleError(w, result.Err) + if result.Code != 0 { + w.WriteHeader(result.Code) } return handleCommonTypes(w, result.Data) } func handleError(w http.ResponseWriter, err error) error { + return handleErrorWithCode(w, err, 0) +} + +func handleErrorWithCode(w http.ResponseWriter, err error, codeOverride int) error { if errorHandler() != nil { errorHandler()(w, err) return nil @@ -882,14 +918,22 @@ func handleError(w http.ResponseWriter, err error) error { return nil } - w.Header().Set("Content-Type", "application/json; charset=utf-8") + if codeOverride != 0 { + httpErr = &HTTPError{ + Code: codeOverride, + Err: inferErrorType(codeOverride), + Message: httpErr.Message, + } + } + + setHeaderIfMissing(w, "Content-Type", "application/json; charset=utf-8") if !statusWritten { w.WriteHeader(httpErr.Code) } if httpErr.Code >= 500 { - log.Println(httpErr.Error()) + logger().Println(httpErr.Error()) } return jsonEncode(w, httpErr) @@ -1061,8 +1105,9 @@ func extractPatternNames(pattern string) []string { } inParam = false depth-- - if currentName != "" { - names = append(names, currentName) + name := strings.TrimSuffix(currentName, "...") + if name != "" && name != "$" { + names = append(names, name) } } else if inParam { currentName += string(char) diff --git a/mint_test.go b/mint_test.go index 5b1977b..828a3e3 100644 --- a/mint_test.go +++ b/mint_test.go @@ -368,6 +368,41 @@ func TestPathExtractor(t *testing.T) { t.Fatal("expected error for invalid bool") } }) + + t.Run("named int path value", func(t *testing.T) { + type UserID int + + req := createRequestWithPattern("GET", "/users/42", "/users/{id}") + req.SetPathValue("id", "42") + + var p Path[UserID] + p.SetKey("id") + err := p.Extract(req) + if err != nil { + t.Fatalf("Extract failed: %v", err) + } + if p.Value != UserID(42) { + t.Errorf("expected Value=42, got %d", p.Value) + } + }) + + t.Run("catch-all path value", func(t *testing.T) { + handler := H(func(path Path[string]) string { + return path.Value + }) + rec := httptest.NewRecorder() + req := createRequestWithPattern("GET", "/files/docs/readme.md", "/files/{path...}") + req.SetPathValue("path", "docs/readme.md") + + handler(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rec.Code) + } + if rec.Body.String() != "docs/readme.md" { + t.Errorf("expected catch-all path, got %q", rec.Body.String()) + } + }) } // ========== Handler Tests ========== @@ -709,6 +744,27 @@ func TestH_ResultType(t *testing.T) { } }) + t.Run("Err result keeps status and body code consistent", func(t *testing.T) { + handler := H(func() Result[User] { + return Err[User](400, errors.New("name is required")) + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/", nil) + handler(rec, req) + + if rec.Code != 400 { + t.Errorf("expected status 400, got %d", rec.Code) + } + var httpErr HTTPError + parseJSONResponse(t, rec.Body.Bytes(), &httpErr) + if httpErr.Code != 400 { + t.Errorf("expected body code 400, got %d", httpErr.Code) + } + if httpErr.Err != "bad_request" { + t.Errorf("expected error type bad_request, got %s", httpErr.Err) + } + }) + t.Run("Result with custom headers", func(t *testing.T) { handler := H(func() Result[string] { r := OK("test") @@ -724,6 +780,24 @@ func TestH_ResultType(t *testing.T) { } }) + t.Run("Result preserves custom content type", func(t *testing.T) { + handler := H(func() Result[[]byte] { + return Result[[]byte]{ + Headers: http.Header{ + "Content-Type": []string{"text/csv"}, + }, + Data: []byte("id,name\n1,Alice\n"), + } + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + handler(rec, req) + + if rec.Header().Get("Content-Type") != "text/csv" { + t.Errorf("expected Content-Type=text/csv, got %s", rec.Header().Get("Content-Type")) + } + }) + t.Run("Result with custom status code", func(t *testing.T) { handler := H(func() Result[User] { r := OK(User{Name: "Dave"}) @@ -1060,6 +1134,14 @@ func TestResponseWriter(t *testing.T) { t.Errorf("expected statusCode=200, got %d", rw.statusCode) } }) + + t.Run("unwrap returns underlying writer", func(t *testing.T) { + rec := httptest.NewRecorder() + rw := &ResponseWriter{ResponseWriter: rec} + if rw.Unwrap() != rec { + t.Error("expected Unwrap to return underlying writer") + } + }) } // ========== Responder Interface Tests ========== @@ -1089,6 +1171,22 @@ func TestResponderInterface(t *testing.T) { } } +func TestResponderInterfaceReturn(t *testing.T) { + handler := H(func() Responder { + return CustomResponder{statusCode: 201, body: "created"} + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + handler(rec, req) + + if rec.Code != 201 { + t.Errorf("expected status 201, got %d", rec.Code) + } + if rec.Body.String() != "created" { + t.Errorf("unexpected body: %s", rec.Body.String()) + } +} + // ========== WriteHeaders Tests ========== func TestWriteHeaders(t *testing.T) { @@ -1337,15 +1435,12 @@ func TestH_Panics(t *testing.T) { }) t.Run("panic on unsupported parameter type", func(t *testing.T) { - handler := H(func(x int) {}) defer func() { if r := recover(); r == nil { t.Error("expected panic") } }() - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - handler(rec, req) + H(func(x int) {}) }) } @@ -1886,6 +1981,36 @@ func TestConfigWithValidation(t *testing.T) { } }) + t.Run("validation uses schema tag field names", func(t *testing.T) { + Reset() + + type Request struct { + MaxResults int `schema:"max_results" validate:"lte=10"` + } + + handler := H(func(q Query[Request]) Request { + return q.Value + }) + + req := httptest.NewRequest("GET", "/?max_results=20", nil) + rec := httptest.NewRecorder() + handler(rec, req) + + if rec.Code != 400 { + t.Fatalf("expected validation error, got status %d", rec.Code) + } + + var errResp map[string]any + parseJSONResponse(t, rec.Body.Bytes(), &errResp) + message := errResp["message"].(string) + if !strings.Contains(message, "max_results") { + t.Errorf("expected schema field name in validation message, got %q", message) + } + if strings.Contains(message, "MaxResults") { + t.Errorf("did not expect Go field name in validation message, got %q", message) + } + }) + t.Run("validation can be disabled", func(t *testing.T) { Reset() Configure(WithValidation(false)) From a2696240a559040a0777194ac5e3eb7f2fbd7e10 Mon Sep 17 00:00:00 2001 From: cymoo Date: Sat, 16 May 2026 22:57:10 +0800 Subject: [PATCH 2/2] fix: address Copilot review issues from PR #4 - fix(handler): add fnVal.IsNil() guard to catch typed-nil function values at H() registration time instead of panicking on first request - fix(handler): handle http.Handler in two-value return path; previously (http.Handler, error) would fall through to JSON encoding instead of calling ServeHTTP, mirroring the existing one-value path behavior - docs: update copilot-instructions.md to reflect that two-value returns allow interface first value (http.Handler/io.Reader/Responder), and that unsupported param types panic at H() registration not request time - test: add regression tests for (http.Handler, error) two-value return Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/copilot-instructions.md | 4 ++-- mint.go | 14 ++++++++++++-- mint_test.go | 34 +++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index e87add8..66b0a45 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -29,8 +29,8 @@ Error handling flows through `handleError` and `toHTTPError`. Prefer returning ` ## Key conventions -- Handler functions passed to `H` may return zero, one, or two values only. Two-value returns must be `(non-interface, error)`, and the first return value cannot be `Result[T]`. -- Handler parameters are limited to built-in/custom extractors, `http.ResponseWriter`, and `*http.Request`; unsupported parameter types panic during request handling. +- Handler functions passed to `H` may return zero, one, or two values only. Two-value returns must be `(T, error)` where T is a concrete type, `http.Handler`, `io.Reader`, or `Responder`; the first return value cannot be `Result[T]`. +- Handler parameters are limited to built-in/custom extractors, `http.ResponseWriter`, and `*http.Request`; unsupported parameter types panic at `H()` registration time. - Custom extractors should define `Extract(*http.Request) error` on a pointer receiver so `reflect.PointerTo(paramType).Implements(extractorType)` detects them when handlers accept the non-pointer extractor value. - Use `Result[T]`, `OK[T]`, and `Err[T]` when a handler needs custom status codes, headers, or a typed error result. - Use `json` tags for JSON bodies and validation field names. Use `schema` tags for query/form decoding; the default validator also uses `schema` tags when choosing validation field names. diff --git a/mint.go b/mint.go index d969678..ae76947 100644 --- a/mint.go +++ b/mint.go @@ -532,6 +532,9 @@ func H(fn any) http.HandlerFunc { if fnType.Kind() != reflect.Func { log.Panicf("H: handler must be a function, got %T", fn) } + if fnVal.IsNil() { + log.Panic("H: handler function must not be nil") + } paramTypes := make([]reflect.Type, fnType.NumIn()) for i := 0; i < fnType.NumIn(); i++ { @@ -642,9 +645,16 @@ func H(fn any) http.HandlerFunc { } rv := results[0].Interface() - err := results[1].Interface() + errVal := results[1].Interface() + + if errVal == nil { + if handler, ok := rv.(http.Handler); ok { + handler.ServeHTTP(rw, r) + return + } + } - e := handleTwoResults(rw, rv, err) + e := handleTwoResults(rw, rv, errVal) if e != nil { logger().Printf("failed to write response: %v", e) } diff --git a/mint_test.go b/mint_test.go index 828a3e3..6974989 100644 --- a/mint_test.go +++ b/mint_test.go @@ -649,6 +649,40 @@ func TestH_ErrorHandling(t *testing.T) { } }) + t.Run("return http.Handler and error - serves handler", func(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("delegated")) + }) + handler := H(func() (http.Handler, error) { + return inner, nil + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + handler(rec, req) + if rec.Code != http.StatusAccepted { + t.Errorf("expected status 202, got %d", rec.Code) + } + if rec.Body.String() != "delegated" { + t.Errorf("unexpected body: %s", rec.Body.String()) + } + }) + + t.Run("return http.Handler and error - error wins", func(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + }) + handler := H(func() (http.Handler, error) { + return inner, errors.New("something failed") + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + handler(rec, req) + if rec.Code != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", rec.Code) + } + }) + t.Run("return data and error - error only", func(t *testing.T) { handler := H(func() (string, error) { return "", errors.New("failed")