diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..66b0a45 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,39 @@ +# Copilot instructions for Mint + +## Build, test, and run commands + +- Full test suite: `go test ./...` +- Root package tests: `go test .` +- Single test: `go test -run '^TestJSONExtractor$' .` +- Single subtest: `go test -run 'TestJSONExtractor/valid_json' .` +- Run the demo server: `go run ./_examples` + +## High-level architecture + +Mint is a small Go module (`github.com/cymoo/mint`) whose public API lives in `mint.go` under package name `m`. It wraps standard `net/http`; applications still use `http.NewServeMux()` and register routes with `mux.HandleFunc(pattern, m.H(handler))`. + +`H(fn any)` is the central adapter. At registration time it uses reflection to validate the handler signature and cache parameter types. At request time it builds arguments from supported inputs, calls the handler, and writes the returned value through the framework response pipeline. + +Request data is injected through extractor types that implement `Extractor`: + +- `JSON[T]` reads and unmarshals the request body, then validates the target. +- `Query[T]` decodes `r.URL.Query()` through the configured `gorilla/schema.Decoder`, then validates. +- `Form[T]` calls `r.ParseForm()`, decodes `r.Form`, then validates. +- `Path[T]` also implements `KeySetter`; `H` derives path variable names from `r.Pattern` and assigns them to `Path` parameters in handler parameter order before `Extract` reads `r.PathValue`. + +Responses are normalized in `handleOneResult`, `handleTwoResults`, `handleCommonTypes`, and `handleResult`. Special return types include `StatusCode`, `HTML`, `[]byte`, `io.Reader`, `http.Handler`, `Responder`, `Result[T]`, `error`, and `(T, error)`; other values are JSON-encoded. + +Configuration is global and thread-safe through `Initialize`, `Configure`, and `Reset`. Defaults include a schema decoder with `IgnoreUnknownKeys(true)`, validation enabled with `go-playground/validator`, `json.Marshal`/`json.Unmarshal`, and `log.Default()`. Tests that mutate configuration should call `Reset()`. + +Error handling flows through `handleError` and `toHTTPError`. Prefer returning `HTTPError` for explicit HTTP status/error payloads and `ExtractError` from extractors for request extraction failures. Generic errors are converted by message heuristics such as "not found" -> 404. + +## Key conventions + +- Handler functions passed to `H` may return zero, one, or two values only. Two-value returns must be `(T, error)` where T is a concrete type, `http.Handler`, `io.Reader`, or `Responder`; the first return value cannot be `Result[T]`. +- Handler parameters are limited to built-in/custom extractors, `http.ResponseWriter`, and `*http.Request`; unsupported parameter types panic at `H()` registration time. +- Custom extractors should define `Extract(*http.Request) error` on a pointer receiver so `reflect.PointerTo(paramType).Implements(extractorType)` detects them when handlers accept the non-pointer extractor value. +- Use `Result[T]`, `OK[T]`, and `Err[T]` when a handler needs custom status codes, headers, or a typed error result. +- Use `json` tags for JSON bodies and validation field names. Use `schema` tags for query/form decoding; the default validator also uses `schema` tags when choosing validation field names. +- Path parameters depend on Go 1.23 `net/http` route patterns and tests should set both `req.Pattern` and `req.SetPathValue(...)`; `createRequestWithPattern` in `mint_test.go` is the existing helper. +- Keep tests table/subtest-oriented with `testing`, `httptest`, and helpers like `parseJSONResponse`; test files are in package `m`, so they can exercise unexported helpers directly. +- The `_examples` directory is a runnable demo, but it is ignored by `go test ./...` because its name starts with an underscore. diff --git a/README.md b/README.md index 140708f..842cd4d 100644 --- a/README.md +++ b/README.md @@ -574,7 +574,7 @@ If you don't call `Initialize()` or `Configure()`, Mint uses sensible defaults: { SchemaDecoder: schema.NewDecoder() with IgnoreUnknownKeys(true), EnableValidation: true, - Validator: validator with JSON/form tag support, + Validator: validator with JSON/schema/form tag support, Logger: log.Default(), JSONMarshalFunc: json.Marshal, JSONUnmarshalFunc: json.Unmarshal, diff --git a/mint.go b/mint.go index 43382e6..ae76947 100644 --- a/mint.go +++ b/mint.go @@ -125,21 +125,18 @@ func newDefaultSchemaDecoder() *schema.Decoder { // newDefaultValidator creates a validator with sensible defaults func newDefaultValidator() *validator.Validate { v := validator.New() - // Use json tag as field name for validation errors + // Use request-facing tags as field names for validation errors. v.RegisterTagNameFunc(func(fld reflect.StructField) string { - name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0] - if name == "-" { - return "" - } - if name != "" { - return name - } - // Fallback to form tag - name = strings.SplitN(fld.Tag.Get("form"), ",", 2)[0] - if name == "-" { - return "" + for _, tag := range []string{"json", "schema", "form"} { + name := strings.SplitN(fld.Tag.Get(tag), ",", 2)[0] + if name == "-" { + return "" + } + if name != "" { + return name + } } - return name + return fld.Name }) return v } @@ -163,6 +160,8 @@ func Initialize(opts ...Option) { for _, opt := range opts { opt(cfg) } + global.mu.Lock() + defer global.mu.Unlock() global.config = cfg }) } @@ -276,6 +275,8 @@ var ( readerType = reflect.TypeOf((*io.Reader)(nil)).Elem() handlerType = reflect.TypeOf((*http.Handler)(nil)).Elem() + responderType = reflect.TypeOf((*Responder)(nil)).Elem() + resultMarkerType = reflect.TypeOf((*resultMarker)(nil)).Elem() responseWriterType = reflect.TypeOf((*http.ResponseWriter)(nil)).Elem() httpRequestType = reflect.TypeOf((*http.Request)(nil)) ) @@ -429,50 +430,41 @@ func (p *Path[T]) Extract(r *http.Request) error { return NewMissingPathError(p.Key) } - switch ptr := any(&p.Value).(type) { - case *string: - *ptr = pv - case *int: - if val, err := strconv.Atoi(pv); err != nil { - return NewPathConversionError(p.Key, pv, "int", err) - } else { - *ptr = val - } - case *int64: - if val, err := strconv.ParseInt(pv, 10, 64); err != nil { - return NewPathConversionError(p.Key, pv, "int64", err) - } else { - *ptr = val - } - case *uint: - if val, err := strconv.ParseUint(pv, 10, 0); err != nil { - return NewPathConversionError(p.Key, pv, "uint", err) - } else { - *ptr = uint(val) + val := reflect.ValueOf(&p.Value).Elem() + targetType := val.Type().String() + + switch val.Kind() { + case reflect.String: + val.SetString(pv) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + parsed, err := strconv.ParseInt(pv, 10, val.Type().Bits()) + if err != nil { + return NewPathConversionError(p.Key, pv, targetType, err) } - case *uint64: - if val, err := strconv.ParseUint(pv, 10, 64); err != nil { - return NewPathConversionError(p.Key, pv, "uint64", err) - } else { - *ptr = val + val.SetInt(parsed) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + parsed, err := strconv.ParseUint(pv, 10, val.Type().Bits()) + if err != nil { + return NewPathConversionError(p.Key, pv, targetType, err) } - case *float64: - if val, err := strconv.ParseFloat(pv, 64); err != nil { - return NewPathConversionError(p.Key, pv, "float64", err) - } else { - *ptr = val + val.SetUint(parsed) + case reflect.Float32, reflect.Float64: + parsed, err := strconv.ParseFloat(pv, val.Type().Bits()) + if err != nil { + return NewPathConversionError(p.Key, pv, targetType, err) } - case *bool: - if val, err := strconv.ParseBool(pv); err != nil { - return NewPathConversionError(p.Key, pv, "bool", err) - } else { - *ptr = val + val.SetFloat(parsed) + case reflect.Bool: + parsed, err := strconv.ParseBool(pv) + if err != nil { + return NewPathConversionError(p.Key, pv, targetType, err) } + val.SetBool(parsed) default: return &ExtractError{ Type: "unsupported_type", Field: p.Key, - Message: fmt.Sprintf("Unsupported path parameter type: %T", &p.Value), + Message: fmt.Sprintf("Unsupported path parameter type: %s", targetType), } } return nil @@ -520,22 +512,36 @@ func (rw *ResponseWriter) Write(b []byte) (int, error) { return rw.ResponseWriter.Write(b) } +func (rw *ResponseWriter) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} + type resultMarker interface { isResultType() bool toResult() Result[any] } func H(fn any) http.HandlerFunc { + if fn == nil { + log.Panic("H: handler must be a function, got ") + } + fnVal := reflect.ValueOf(fn) fnType := fnVal.Type() if fnType.Kind() != reflect.Func { log.Panicf("H: handler must be a function, got %T", fn) } + if fnVal.IsNil() { + log.Panic("H: handler function must not be nil") + } paramTypes := make([]reflect.Type, fnType.NumIn()) for i := 0; i < fnType.NumIn(); i++ { paramTypes[i] = fnType.In(i) + if !isSupportedParamType(paramTypes[i]) { + log.Panicf("H: unsupported parameter type %s", paramTypes[i].String()) + } } numOut := fnType.NumOut() @@ -546,8 +552,8 @@ func H(fn any) http.HandlerFunc { if numOut == 1 { rt := fnType.Out(0) if rt.Kind() == reflect.Interface { - if !rt.Implements(errorType) && !rt.Implements(handlerType) && !rt.Implements(readerType) { - log.Panic("H: interface return type must implement error, http.Handler or io.Reader") + if !isSupportedSingleInterfaceReturn(rt) { + log.Panic("H: interface return type must implement error, http.Handler, io.Reader or Responder") } } } @@ -556,10 +562,10 @@ func H(fn any) http.HandlerFunc { rt1 := fnType.Out(0) rt2 := fnType.Out(1) - if rt1.Kind() == reflect.Interface { - log.Panic("H: first return value cannot be an interface when returning two values") + if rt1.Kind() == reflect.Interface && !isSupportedDataInterfaceReturn(rt1) { + log.Panic("H: first return interface must implement http.Handler, io.Reader or Responder") } - if rt1.Implements(reflect.TypeOf((*resultMarker)(nil)).Elem()) { + if rt1.Implements(resultMarkerType) { log.Panicf("H: first return value cannot be Result when returning two values") } @@ -578,7 +584,7 @@ func H(fn any) http.HandlerFunc { for i, paramType := range paramTypes { switch { - case reflect.PointerTo(paramType).Implements(extractorType): + case isExtractorParam(paramType): paramVal := reflect.New(paramType).Elem() extractor := paramVal.Addr().Interface().(Extractor) @@ -639,9 +645,16 @@ func H(fn any) http.HandlerFunc { } rv := results[0].Interface() - err := results[1].Interface() + errVal := results[1].Interface() - e := handleTwoResults(rw, rv, err) + if errVal == nil { + if handler, ok := rv.(http.Handler); ok { + handler.ServeHTTP(rw, r) + return + } + } + + e := handleTwoResults(rw, rv, errVal) if e != nil { logger().Printf("failed to write response: %v", e) } @@ -649,6 +662,29 @@ func H(fn any) http.HandlerFunc { } } +func isExtractorParam(paramType reflect.Type) bool { + return reflect.PointerTo(paramType).Implements(extractorType) +} + +func isSupportedParamType(paramType reflect.Type) bool { + return isExtractorParam(paramType) || + paramType == httpRequestType || + (paramType.Kind() == reflect.Interface && paramType.Implements(responseWriterType)) +} + +func isSupportedSingleInterfaceReturn(rt reflect.Type) bool { + return rt.Implements(errorType) || + rt.Implements(handlerType) || + rt.Implements(readerType) || + rt.Implements(responderType) +} + +func isSupportedDataInterfaceReturn(rt reflect.Type) bool { + return rt.Implements(handlerType) || + rt.Implements(readerType) || + rt.Implements(responderType) +} + func WriteHeaders(w http.ResponseWriter, headers http.Header) { for key, values := range headers { for _, value := range values { @@ -657,6 +693,12 @@ func WriteHeaders(w http.ResponseWriter, headers http.Header) { } } +func setHeaderIfMissing(w http.ResponseWriter, key, value string) { + if w.Header().Get(key) == "" { + w.Header().Set(key, value) + } +} + func NewBodyReadError(err error) error { return &ExtractError{ Type: ErrTypeBodyRead, @@ -827,25 +869,25 @@ func handleCommonTypes(w http.ResponseWriter, data any) error { switch v := data.(type) { case string: - w.Header().Set("Content-Type", "text/plain; charset=utf-8") + setHeaderIfMissing(w, "Content-Type", "text/plain; charset=utf-8") _, err := fmt.Fprint(w, v) return err case StatusCode: w.WriteHeader(int(v)) return nil case []byte: - w.Header().Set("Content-Type", "application/octet-stream") + setHeaderIfMissing(w, "Content-Type", "application/octet-stream") _, err := w.Write(v) return err case HTML, template.HTML: - w.Header().Set("Content-Type", "text/html; charset=utf-8") + setHeaderIfMissing(w, "Content-Type", "text/html; charset=utf-8") _, err := fmt.Fprint(w, v) return err case io.Reader: _, err := io.Copy(w, v) return err default: - w.Header().Set("Content-Type", "application/json; charset=utf-8") + setHeaderIfMissing(w, "Content-Type", "application/json; charset=utf-8") return jsonEncode(w, data) } } @@ -855,18 +897,22 @@ func handleResult(w http.ResponseWriter, result Result[any]) error { WriteHeaders(w, result.Headers) } - if result.Code != 0 { - w.WriteHeader(result.Code) + if result.Err != nil { + return handleErrorWithCode(w, result.Err, result.Code) } - if result.Err != nil { - return handleError(w, result.Err) + if result.Code != 0 { + w.WriteHeader(result.Code) } return handleCommonTypes(w, result.Data) } func handleError(w http.ResponseWriter, err error) error { + return handleErrorWithCode(w, err, 0) +} + +func handleErrorWithCode(w http.ResponseWriter, err error, codeOverride int) error { if errorHandler() != nil { errorHandler()(w, err) return nil @@ -882,14 +928,22 @@ func handleError(w http.ResponseWriter, err error) error { return nil } - w.Header().Set("Content-Type", "application/json; charset=utf-8") + if codeOverride != 0 { + httpErr = &HTTPError{ + Code: codeOverride, + Err: inferErrorType(codeOverride), + Message: httpErr.Message, + } + } + + setHeaderIfMissing(w, "Content-Type", "application/json; charset=utf-8") if !statusWritten { w.WriteHeader(httpErr.Code) } if httpErr.Code >= 500 { - log.Println(httpErr.Error()) + logger().Println(httpErr.Error()) } return jsonEncode(w, httpErr) @@ -1061,8 +1115,9 @@ func extractPatternNames(pattern string) []string { } inParam = false depth-- - if currentName != "" { - names = append(names, currentName) + name := strings.TrimSuffix(currentName, "...") + if name != "" && name != "$" { + names = append(names, name) } } else if inParam { currentName += string(char) diff --git a/mint_test.go b/mint_test.go index 5b1977b..6974989 100644 --- a/mint_test.go +++ b/mint_test.go @@ -368,6 +368,41 @@ func TestPathExtractor(t *testing.T) { t.Fatal("expected error for invalid bool") } }) + + t.Run("named int path value", func(t *testing.T) { + type UserID int + + req := createRequestWithPattern("GET", "/users/42", "/users/{id}") + req.SetPathValue("id", "42") + + var p Path[UserID] + p.SetKey("id") + err := p.Extract(req) + if err != nil { + t.Fatalf("Extract failed: %v", err) + } + if p.Value != UserID(42) { + t.Errorf("expected Value=42, got %d", p.Value) + } + }) + + t.Run("catch-all path value", func(t *testing.T) { + handler := H(func(path Path[string]) string { + return path.Value + }) + rec := httptest.NewRecorder() + req := createRequestWithPattern("GET", "/files/docs/readme.md", "/files/{path...}") + req.SetPathValue("path", "docs/readme.md") + + handler(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rec.Code) + } + if rec.Body.String() != "docs/readme.md" { + t.Errorf("expected catch-all path, got %q", rec.Body.String()) + } + }) } // ========== Handler Tests ========== @@ -614,6 +649,40 @@ func TestH_ErrorHandling(t *testing.T) { } }) + t.Run("return http.Handler and error - serves handler", func(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + w.Write([]byte("delegated")) + }) + handler := H(func() (http.Handler, error) { + return inner, nil + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + handler(rec, req) + if rec.Code != http.StatusAccepted { + t.Errorf("expected status 202, got %d", rec.Code) + } + if rec.Body.String() != "delegated" { + t.Errorf("unexpected body: %s", rec.Body.String()) + } + }) + + t.Run("return http.Handler and error - error wins", func(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusAccepted) + }) + handler := H(func() (http.Handler, error) { + return inner, errors.New("something failed") + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + handler(rec, req) + if rec.Code != http.StatusInternalServerError { + t.Errorf("expected status 500, got %d", rec.Code) + } + }) + t.Run("return data and error - error only", func(t *testing.T) { handler := H(func() (string, error) { return "", errors.New("failed") @@ -709,6 +778,27 @@ func TestH_ResultType(t *testing.T) { } }) + t.Run("Err result keeps status and body code consistent", func(t *testing.T) { + handler := H(func() Result[User] { + return Err[User](400, errors.New("name is required")) + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("POST", "/", nil) + handler(rec, req) + + if rec.Code != 400 { + t.Errorf("expected status 400, got %d", rec.Code) + } + var httpErr HTTPError + parseJSONResponse(t, rec.Body.Bytes(), &httpErr) + if httpErr.Code != 400 { + t.Errorf("expected body code 400, got %d", httpErr.Code) + } + if httpErr.Err != "bad_request" { + t.Errorf("expected error type bad_request, got %s", httpErr.Err) + } + }) + t.Run("Result with custom headers", func(t *testing.T) { handler := H(func() Result[string] { r := OK("test") @@ -724,6 +814,24 @@ func TestH_ResultType(t *testing.T) { } }) + t.Run("Result preserves custom content type", func(t *testing.T) { + handler := H(func() Result[[]byte] { + return Result[[]byte]{ + Headers: http.Header{ + "Content-Type": []string{"text/csv"}, + }, + Data: []byte("id,name\n1,Alice\n"), + } + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + handler(rec, req) + + if rec.Header().Get("Content-Type") != "text/csv" { + t.Errorf("expected Content-Type=text/csv, got %s", rec.Header().Get("Content-Type")) + } + }) + t.Run("Result with custom status code", func(t *testing.T) { handler := H(func() Result[User] { r := OK(User{Name: "Dave"}) @@ -1060,6 +1168,14 @@ func TestResponseWriter(t *testing.T) { t.Errorf("expected statusCode=200, got %d", rw.statusCode) } }) + + t.Run("unwrap returns underlying writer", func(t *testing.T) { + rec := httptest.NewRecorder() + rw := &ResponseWriter{ResponseWriter: rec} + if rw.Unwrap() != rec { + t.Error("expected Unwrap to return underlying writer") + } + }) } // ========== Responder Interface Tests ========== @@ -1089,6 +1205,22 @@ func TestResponderInterface(t *testing.T) { } } +func TestResponderInterfaceReturn(t *testing.T) { + handler := H(func() Responder { + return CustomResponder{statusCode: 201, body: "created"} + }) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", nil) + handler(rec, req) + + if rec.Code != 201 { + t.Errorf("expected status 201, got %d", rec.Code) + } + if rec.Body.String() != "created" { + t.Errorf("unexpected body: %s", rec.Body.String()) + } +} + // ========== WriteHeaders Tests ========== func TestWriteHeaders(t *testing.T) { @@ -1337,15 +1469,12 @@ func TestH_Panics(t *testing.T) { }) t.Run("panic on unsupported parameter type", func(t *testing.T) { - handler := H(func(x int) {}) defer func() { if r := recover(); r == nil { t.Error("expected panic") } }() - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) - handler(rec, req) + H(func(x int) {}) }) } @@ -1886,6 +2015,36 @@ func TestConfigWithValidation(t *testing.T) { } }) + t.Run("validation uses schema tag field names", func(t *testing.T) { + Reset() + + type Request struct { + MaxResults int `schema:"max_results" validate:"lte=10"` + } + + handler := H(func(q Query[Request]) Request { + return q.Value + }) + + req := httptest.NewRequest("GET", "/?max_results=20", nil) + rec := httptest.NewRecorder() + handler(rec, req) + + if rec.Code != 400 { + t.Fatalf("expected validation error, got status %d", rec.Code) + } + + var errResp map[string]any + parseJSONResponse(t, rec.Body.Bytes(), &errResp) + message := errResp["message"].(string) + if !strings.Contains(message, "max_results") { + t.Errorf("expected schema field name in validation message, got %q", message) + } + if strings.Contains(message, "MaxResults") { + t.Errorf("did not expect Go field name in validation message, got %q", message) + } + }) + t.Run("validation can be disabled", func(t *testing.T) { Reset() Configure(WithValidation(false))