From 0bc9fbc6203cc533956b800bfb38d3f694b69cdb Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 20:47:30 +0100 Subject: [PATCH 01/10] chore: add new helper methods --- helpers/crypto.go | 63 +++++++++++++++++++++++++++++ helpers/email.go | 27 +++++++++++++ helpers/encoding.go | 22 ++++++++++ helpers/local.go | 15 ------- helpers/request.go | 69 +++++++++++++++++++++++++++++++ helpers/response.go | 23 +++++++++++ helpers/strings.go | 99 +++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 303 insertions(+), 15 deletions(-) create mode 100644 helpers/crypto.go create mode 100644 helpers/email.go create mode 100644 helpers/encoding.go delete mode 100644 helpers/local.go create mode 100644 helpers/request.go create mode 100644 helpers/response.go create mode 100644 helpers/strings.go diff --git a/helpers/crypto.go b/helpers/crypto.go new file mode 100644 index 0000000..e2c67e1 --- /dev/null +++ b/helpers/crypto.go @@ -0,0 +1,63 @@ +package helpers + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + + "golang.org/x/crypto/bcrypt" +) + +// HashPassword hashes using bcrypt +func HashPassword(password string) (string, error) { + bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bytes), err +} + +// CheckPassword compares hash with plain +func CheckPassword(hash, password string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} + +// HMACSHA256 generates HMAC +func HMACSHA256(secret, data string) string { + h := hmac.New(sha256.New, []byte(secret)) + h.Write([]byte(data)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +// GenerateOTP generates a random OTP string based on a given length +func GenerateOTP(length int, charset string) (string, error) { + if length <= 0 { + return "", fmt.Errorf("GenerateOTP: length must be > 0") + } + + if charset == "" { + charset = "23456789ABCDEFGHJKLMNPQRSTUVWXYZ" + } + + charsetLen := len(charset) + otp := make([]byte, length) + + for i := 0; i < length; { + b := make([]byte, 1) + _, err := rand.Read(b) + if err != nil { + return "", fmt.Errorf("GenerateOTP: failed to read random byte: %w", err) + } + + val := int(b[0]) + max := 256 - (256 % charsetLen) + if val >= max { + continue + } + + otp[i] = charset[val%charsetLen] + i++ + } + + return string(otp), nil +} diff --git a/helpers/email.go b/helpers/email.go new file mode 100644 index 0000000..6d51a49 --- /dev/null +++ b/helpers/email.go @@ -0,0 +1,27 @@ +package helpers + +import ( + "regexp" + "strings" +) + +// IsEmailValid performs basic validation +// not RFC perfect, but solid +func IsEmailValid(email string) bool { + re := regexp.MustCompile(`^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}$`) + return re.MatchString(strings.ToLower(email)) +} + +// MaskEmail hides part of the email +func MaskEmail(email string) string { + at := strings.Index(email, "@") + if at <= 1 { + return email + } + return email[:1] + strings.Repeat("*", at-1) + email[at:] +} + +// NormalizeEmail trims spaces and converts to lowercase. +func NormalizeEmail(email string) string { + return strings.ToLower(strings.TrimSpace(email)) +} diff --git a/helpers/encoding.go b/helpers/encoding.go new file mode 100644 index 0000000..20efda5 --- /dev/null +++ b/helpers/encoding.go @@ -0,0 +1,22 @@ +package helpers + +import ( + "encoding/base64" + "encoding/hex" +) + +// ToBase64 encodes string to base64 +func ToBase64(data string) string { + return base64.StdEncoding.EncodeToString([]byte(data)) +} + +// FromBase64 decodes base64 string +func FromBase64(encoded string) (string, error) { + bytes, err := base64.StdEncoding.DecodeString(encoded) + return string(bytes), err +} + +// ToHex encodes string to hex +func ToHex(data string) string { + return hex.EncodeToString([]byte(data)) +} diff --git a/helpers/local.go b/helpers/local.go deleted file mode 100644 index 772ac6e..0000000 --- a/helpers/local.go +++ /dev/null @@ -1,15 +0,0 @@ -package helpers - -import ( - "net/http" - - "github.com/go-chi/chi/v5" -) - -// URLParam returns the url parameter from a http.Request object. -func URLParam(r *http.Request, key string) string { - if value := chi.URLParam(r, key); value != "" { - return value - } - return "" -} diff --git a/helpers/request.go b/helpers/request.go new file mode 100644 index 0000000..2897ee4 --- /dev/null +++ b/helpers/request.go @@ -0,0 +1,69 @@ +package helpers + +import ( + "encoding/json" + "net" + "net/http" + "strings" + + "github.com/go-chi/chi/v5" +) + +// GetHeader safely retrieves a header key +func GetHeader(r *http.Request, key string) string { + return strings.TrimSpace(r.Header.Get(key)) +} + +// GetBearerToken extracts token from Authorization header +func GetBearerToken(r *http.Request) string { + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + return "" +} + +// IsJSONRequest checks Content-Type +func IsJSONRequest(r *http.Request) bool { + return strings.HasPrefix(r.Header.Get("Content-Type"), "application/json") +} + +// GetIP retrieves the real IP address from the request, accounting for proxies. +func GetIP(r *http.Request) string { + // Try X-Forwarded-For header + xff := r.Header.Get("X-Forwarded-For") + if xff != "" { + parts := strings.Split(xff, ",") + return strings.TrimSpace(parts[0]) + } + + // Try X-Real-IP header + if ip := r.Header.Get("X-Real-Ip"); ip != "" { + return ip + } + + // Fallback to remote address + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +// BindJSON decodes JSON body into the provided struct. +func BindJSON(r *http.Request, dst interface{}) error { + defer r.Body.Close() + decoder := json.NewDecoder(r.Body) + decoder.DisallowUnknownFields() + return decoder.Decode(dst) +} + +// RouteContext returns the chi.RouteContext from the request context. +func RouteContext(r *http.Request) *chi.Context { + return chi.RouteContext(r.Context()) +} + +// URLParam returns a URL parameter from a http.Request object. +func URLParam(r *http.Request, key string) string { + return chi.URLParam(r, key) +} diff --git a/helpers/response.go b/helpers/response.go new file mode 100644 index 0000000..b56a1ab --- /dev/null +++ b/helpers/response.go @@ -0,0 +1,23 @@ +package helpers + +import ( + "encoding/json" + "net/http" +) + +// RespondWithJSON writes a JSON response with status code +func RespondWithJSON(w http.ResponseWriter, status int, payload interface{}) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(payload) +} + +// RespondWithError writes a standardized error message +func RespondWithError(w http.ResponseWriter, status int, message string) { + RespondWithJSON(w, status, map[string]string{"error": message}) +} + +// NoContent sends a 204 No Content response +func NoContent(w http.ResponseWriter) { + w.WriteHeader(http.StatusNoContent) +} diff --git a/helpers/strings.go b/helpers/strings.go new file mode 100644 index 0000000..fd19356 --- /dev/null +++ b/helpers/strings.go @@ -0,0 +1,99 @@ +package helpers + +import ( + "crypto/rand" + "encoding/hex" + "regexp" + "strings" + "unicode" +) + +// IsEmpty returns true if the trimmed string is empty. +func IsEmpty(s string) bool { + return strings.TrimSpace(s) == "" +} + +// Truncate shortens a string to a max length with optional ellipsis. +func Truncate(s string, max int, withEllipsis bool) string { + if len(s) <= max { + return s + } + if withEllipsis && max > 3 { + return s[:max-3] + "..." + } + return s[:max] +} + +// Slugify creates a URL-safe slug (lowercase, hyphens, alphanumeric). +func Slugify(s string) string { + s = strings.ToLower(s) + s = regexp.MustCompile(`[^a-z0-9]+`).ReplaceAllString(s, "-") + s = strings.Trim(s, "-") + return s +} + +// RandomString generates a random alphanumeric string of n bytes. +func RandomString(n int) (string, error) { + bytes := make([]byte, n) + _, err := rand.Read(bytes) + if err != nil { + return "", err + } + return hex.EncodeToString(bytes)[:n], nil +} + +// ContainsAny checks if a string contains any of the substrings. +func ContainsAny(s string, substrs ...string) bool { + for _, sub := range substrs { + if strings.Contains(s, sub) { + return true + } + } + return false +} + +// RemoveWhitespace removes all space, tab, newline characters. +func RemoveWhitespace(s string) string { + return strings.Join(strings.Fields(s), "") +} + +// IsNumeric checks if string only contains digits. +func IsNumeric(s string) bool { + for _, r := range s { + if !unicode.IsDigit(r) { + return false + } + } + return true +} + +// Capitalize capitalizes the first letter of a string. +func Capitalize(s string) string { + if s == "" { + return "" + } + runes := []rune(s) + runes[0] = unicode.ToUpper(runes[0]) + return string(runes) +} + +// IsAlphaNumeric checks if string is alphanumeric +func IsAlphaNumeric(s string) bool { + re := regexp.MustCompile(`^[a-zA-Z0-9]+$`) + return re.MatchString(s) +} + +// IsSlug checks if string is URL slug friendly +func IsSlug(s string) bool { + re := regexp.MustCompile(`^[a-z0-9\-]+$`) + return re.MatchString(s) +} + +// IsStrongPassword performs basic password strength check +func IsStrongPassword(p string) bool { + length := len(p) >= 8 + hasUpper := regexp.MustCompile(`[A-Z]`).MatchString(p) + hasNumber := regexp.MustCompile(`[0-9]`).MatchString(p) + hasSymbol := regexp.MustCompile(`[!@#~$%^&*()+|_]`).MatchString(p) + return length && hasUpper && hasNumber && hasSymbol +} From 642ef313fec471876c2d31570fafa702df31e76e Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 20:47:38 +0100 Subject: [PATCH 02/10] chore: update configuration parameters --- config/config.go | 22 +++++++++++++++++++++- config/types.go | 44 ++++++++++++++++++++++++++++++-------------- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/config/config.go b/config/config.go index cb8f9ae..d4bf47c 100644 --- a/config/config.go +++ b/config/config.go @@ -31,14 +31,34 @@ func Load() error { } func Create() error { - file, err := json.MarshalIndent(&Config{Port: "7000", Address: "0.0.0.0", Experimental: false}, "", " ") + defaultConfig := Config{ + Port: "7000", + Address: "0.0.0.0", + Experimental: false, + ReadTimeout: 15, + WriteTimeout: 15, + IdleTimeout: 60, + LogLevel: "info", + MaxHeaderBytes: 1048576, + EnableTLS: false, + TLSCertFile: "", + TLSKeyFile: "", + ShutdownTimeout: 15, + EnableCORS: false, + AllowedOrigins: []string{"*"}, + EnableRequestLogging: false, + } + + file, err := json.MarshalIndent(&defaultConfig, "", " ") if err != nil { return fmt.Errorf("Create: failed marshalling config: %w", err) } + err = os.WriteFile(CONFIG, file, 0644) if err != nil { return fmt.Errorf("Create: failed writing config: %w", err) } + return nil } diff --git a/config/types.go b/config/types.go index be629ba..1fb57ec 100644 --- a/config/types.go +++ b/config/types.go @@ -1,19 +1,35 @@ package config type Config struct { - Port string `json:"port"` - Address string `json:"address"` - Experimental bool `json:"experimental"` + Port string `json:"port"` + Address string `json:"address"` + Experimental bool `json:"experimental"` + ReadTimeout int `json:"readTimeout"` // in seconds + WriteTimeout int `json:"writeTimeout"` // in seconds + IdleTimeout int `json:"idleTimeout"` // in seconds + LogLevel string `json:"logLevel"` // e.g. "debug", "info", "disabled" + MaxHeaderBytes int `json:"maxHeaderBytes"` + EnableTLS bool `json:"enableTLS"` + TLSCertFile string `json:"tlsCertFile"` + TLSKeyFile string `json:"tlsKeyFile"` + ShutdownTimeout int `json:"shutdownTimeout"` // graceful shutdown timeout seconds + EnableCORS bool `json:"enableCORS"` + AllowedOrigins []string `json:"allowedOrigins"` + EnableRequestLogging bool `json:"enableRequestLogging"` } -func Port() string { - return c.Port -} - -func Address() string { - return c.Address -} - -func Experimental() bool { - return c.Experimental -} +func Port() string { return c.Port } +func Address() string { return c.Address } +func Experimental() bool { return c.Experimental } +func ReadTimeout() int { return c.ReadTimeout } +func WriteTimeout() int { return c.WriteTimeout } +func IdleTimeout() int { return c.IdleTimeout } +func LogLevel() string { return c.LogLevel } +func MaxHeaderBytes() int { return c.MaxHeaderBytes } +func EnableTLS() bool { return c.EnableTLS } +func TLSCertFile() string { return c.TLSCertFile } +func TLSKeyFile() string { return c.TLSKeyFile } +func ShutdownTimeout() int { return c.ShutdownTimeout } +func EnableCORS() bool { return c.EnableCORS } +func AllowedOrigins() []string { return c.AllowedOrigins } +func EnableRequestLogging() bool { return c.EnableRequestLogging } From 5cc593b1feac62adbeb3c1ce1a0b96e513f11d09 Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 20:47:47 +0100 Subject: [PATCH 03/10] chore: add proper logging --- log/log.go | 20 ++++++++++++++++++++ log/nop.go | 20 ++++++++++++++++++++ log/zerolog.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) create mode 100644 log/log.go create mode 100644 log/nop.go create mode 100644 log/zerolog.go diff --git a/log/log.go b/log/log.go new file mode 100644 index 0000000..f65791e --- /dev/null +++ b/log/log.go @@ -0,0 +1,20 @@ +package log + +import "time" + +type Logger interface { + Debug() Entry + Info() Entry + Warn() Entry + Error() Entry + Fatal() Entry +} + +type Entry interface { + Str(key, value string) Entry + Dur(key string, value time.Duration) Entry + Int(key string, value int) Entry + Bool(key string, value bool) Entry + Msg(msg string) + Err(error) Entry +} diff --git a/log/nop.go b/log/nop.go new file mode 100644 index 0000000..c2f399c --- /dev/null +++ b/log/nop.go @@ -0,0 +1,20 @@ +package log + +import "time" + +type NoOpLogger struct{} + +func (l *NoOpLogger) Debug() Entry { return &noopEntry{} } +func (l *NoOpLogger) Info() Entry { return &noopEntry{} } +func (l *NoOpLogger) Warn() Entry { return &noopEntry{} } +func (l *NoOpLogger) Error() Entry { return &noopEntry{} } +func (l *NoOpLogger) Fatal() Entry { return &noopEntry{} } + +type noopEntry struct{} + +func (n *noopEntry) Str(string, string) Entry { return n } +func (n *noopEntry) Dur(string, time.Duration) Entry { return n } +func (n *noopEntry) Int(string, int) Entry { return n } +func (n *noopEntry) Bool(string, bool) Entry { return n } +func (n *noopEntry) Err(error) Entry { return n } +func (n *noopEntry) Msg(string) {} diff --git a/log/zerolog.go b/log/zerolog.go new file mode 100644 index 0000000..f6c6ff0 --- /dev/null +++ b/log/zerolog.go @@ -0,0 +1,32 @@ +package log + +import ( + "time" + + "github.com/rs/zerolog" +) + +type ZeroLogger struct { + z zerolog.Logger +} + +func NewZeroLogger(z zerolog.Logger) *ZeroLogger { + return &ZeroLogger{z} +} + +func (zl *ZeroLogger) Debug() Entry { return &zeroEntry{zl.z.Debug()} } +func (zl *ZeroLogger) Info() Entry { return &zeroEntry{zl.z.Info()} } +func (zl *ZeroLogger) Warn() Entry { return &zeroEntry{zl.z.Warn()} } +func (zl *ZeroLogger) Error() Entry { return &zeroEntry{zl.z.Error()} } +func (zl *ZeroLogger) Fatal() Entry { return &zeroEntry{zl.z.Fatal()} } + +type zeroEntry struct { + e *zerolog.Event +} + +func (z *zeroEntry) Str(k, v string) Entry { z.e.Str(k, v); return z } +func (z *zeroEntry) Dur(k string, v time.Duration) Entry { z.e.Dur(k, v); return z } +func (z *zeroEntry) Int(k string, v int) Entry { z.e.Int(k, v); return z } +func (z *zeroEntry) Bool(k string, v bool) Entry { z.e.Bool(k, v); return z } +func (z *zeroEntry) Err(e error) Entry { z.e.Err(e); return z } +func (z *zeroEntry) Msg(m string) { z.e.Msg(m) } From ad4fc7646fe91585e7625c82c06ed75cc275a6f0 Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 20:47:57 +0100 Subject: [PATCH 04/10] feat: add public middleware --- middleware/public.go | 100 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 middleware/public.go diff --git a/middleware/public.go b/middleware/public.go new file mode 100644 index 0000000..b3aacbb --- /dev/null +++ b/middleware/public.go @@ -0,0 +1,100 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/Etwodev/ramchi/log" +) + +// ctxKey is a private type used as a key for storing values in context. +// This prevents collisions with other context keys. +type ctxKey string + +// loggerCtxKey is the key used to store the logger instance in the request context. +var loggerCtxKey = ctxKey("logger") + +// LoggerInjectionMiddleware returns a Middleware that injects the provided logger +// instance into the request's context. This allows downstream handlers and middleware +// to retrieve the logger directly from the context for structured logging. +// If your preferred logging library is not supported, please raise an issue on this repo. +// +// Usage: +// +// // Create the logger (e.g., in main.go) +// myLogger := zerolog.New(format) +// +// // Create the middleware +// func Middlewares() []middleware.Middleware { +// return []middleware.Middleware{ +// middleware.NewLoggingMiddleware(myLogger), +// middleware.NewMiddleware(auth.Middleware(), "auth", true, false), +// } +// } +// +// // Load the middleware +// s.LoadMiddleware(Middlewares()) +// +// // In your handlers, you can retrieve the logger from the context like this: +// +// func MyHandler(w http.ResponseWriter, r *http.Request) { +// logger := middleware.LoggerFromContext(r.Context()) +// logger.Info().Msg("Handling request") +// // ... +// } +func NewLoggingMiddleware(logger log.Logger) Middleware { + return NewMiddleware(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := context.WithValue(r.Context(), loggerCtxKey, logger) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + }, "ramchi_logger_inject", true, false) +} + +// LoggerFromContext retrieves the logger instance from the context. +// Returns nil if no logger is found. Requires 'LoggerInjectionMiddleware' to +// be consumed. +func LoggerFromContext(ctx context.Context) log.Logger { + if logger, ok := ctx.Value(loggerCtxKey).(log.Logger); ok { + return logger + } + return nil +} + +// NewCORSMiddleware returns a simple CORS middleware. +// allowedOrigins is a list of origins that are allowed. Use ["*"] for allowing all. +func NewCORSMiddleware(allowedOrigins []string) Middleware { + return NewMiddleware(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + allowed := false + + if len(allowedOrigins) == 1 && allowedOrigins[0] == "*" { + allowed = true + } else { + for _, o := range allowedOrigins { + if o == origin { + allowed = true + break + } + } + } + + if allowed { + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Vary", "Origin") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type, X-CSRF-Token") + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + + // For OPTIONS requests, respond with 200 immediately (CORS preflight) + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) + }, "ramchi_cors", true, false) +} From c554c1621015bc78fc11cd1273e4e16810ff66f5 Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 20:48:07 +0100 Subject: [PATCH 05/10] chore: update router handling --- router/local.go | 135 ++++++++++++++++++++++++++--------------------- router/router.go | 18 ++++--- 2 files changed, 86 insertions(+), 67 deletions(-) diff --git a/router/local.go b/router/local.go index 3407d11..27f8cd4 100644 --- a/router/local.go +++ b/router/local.go @@ -4,112 +4,127 @@ import ( "net/http" ) -type preRouter struct { - status bool - prefix string - routes []Route -} +// --- Internal structs --- -type preRoute struct { +type route struct { method string path string status bool experimental bool handler http.HandlerFunc + middleware []func(http.Handler) http.Handler } -// RouterWrapper wraps a router with extra functionality . -// It is passed in when creating a new router. -type RouterWrapper func(r Router) Router +type router struct { + status bool + prefix string + routes []Route + middleware []func(http.Handler) http.Handler +} -// RouteWrapper wraps a route with extra functionality. -// It is passed in when creating a new route. -type RouteWrapper func(r Route) Route +// --- Route implementation --- -// Routes returns an array of routes -func (p preRouter) Routes() []Route { - return p.routes +func (r route) Handler() http.HandlerFunc { + return r.handler } -// Status returns whether the router should be enabled. -func (p preRouter) Status() bool { - return p.status +func (r route) Method() string { + return r.method } -// Prefix returns the starting string of a route, e.g. /api -func (p preRouter) Prefix() string { - return p.prefix +func (r route) Path() string { + return r.path } -// Function returns the function route applies. -func (p preRoute) Handler() http.HandlerFunc { - return p.handler +func (r route) Status() bool { + return r.status } -// Method returns the http method that the route responds to. -func (p preRoute) Method() string { - return p.method +func (r route) Experimental() bool { + return r.experimental +} + +func (r route) Middleware() []func(http.Handler) http.Handler { + return r.middleware +} + +// --- Router implementation --- + +func (r router) Routes() []Route { + return r.routes } -// Path returns the subpath where the route responds to. -func (p preRoute) Path() string { - return p.path +func (r router) Status() bool { + return r.status } -// Status returns whether the route should be enabled. -func (p preRoute) Status() bool { - return p.status +func (r router) Prefix() string { + return r.prefix } -// Experimental returns whether the route is enabled. -func (p preRoute) Experimental() bool { - return p.experimental +func (r router) Middleware() []func(http.Handler) http.Handler { + return r.middleware } -// NewRouter initializes a new local router for the system. -func NewRouter(prefix string, routes []Route, status bool, opts ...RouterWrapper) Router { - var r Router = preRouter{status, prefix, routes} +// --- Wrappers for extensibility --- + +type RouterWrapper func(r Router) Router +type RouteWrapper func(r Route) Route + +// --- Constructors --- + +// NewRouter creates a new Router with a prefix, status flag, routes, and optional middleware. +func NewRouter(prefix string, routes []Route, status bool, middleware []func(http.Handler) http.Handler, opts ...RouterWrapper) Router { + var r Router = router{ + status: status, + prefix: prefix, + routes: routes, + middleware: middleware, + } for _, o := range opts { r = o(r) } return r } -// NewRoute initializes a new local route for the router. -func NewRoute(method string, path string, status bool, experimental bool, handler http.HandlerFunc, opts ...RouteWrapper) Route { - var r Route = preRoute{method, path, status, experimental, handler} +// NewRoute creates a new Route with method, path, status, middleware, and handler. +func NewRoute(method, path string, status, experimental bool, handler http.HandlerFunc, middleware []func(http.Handler) http.Handler, opts ...RouteWrapper) Route { + var r Route = route{ + method: method, + path: path, + status: status, + experimental: experimental, + handler: handler, + middleware: middleware, + } for _, o := range opts { r = o(r) } return r } -// NewGetRoute initializes a new route with the http method GET. -func NewGetRoute(path string, status bool, experimental bool, handler http.HandlerFunc, opts ...RouteWrapper) Route { - return NewRoute(http.MethodGet, path, status, experimental, handler, opts...) +// --- Convenience functions for each HTTP verb --- + +func NewGetRoute(path string, status, experimental bool, handler http.HandlerFunc, middleware []func(http.Handler) http.Handler, opts ...RouteWrapper) Route { + return NewRoute(http.MethodGet, path, status, experimental, handler, middleware, opts...) } -// NewPostRoute initializes a new route with the http method POST. -func NewPostRoute(path string, status bool, experimental bool, handler http.HandlerFunc, opts ...RouteWrapper) Route { - return NewRoute(http.MethodPost, path, status, experimental, handler, opts...) +func NewPostRoute(path string, status, experimental bool, handler http.HandlerFunc, middleware []func(http.Handler) http.Handler, opts ...RouteWrapper) Route { + return NewRoute(http.MethodPost, path, status, experimental, handler, middleware, opts...) } -// NewPutRoute initializes a new route with the http method PUT. -func NewPutRoute(path string, status bool, experimental bool, handler http.HandlerFunc, opts ...RouteWrapper) Route { - return NewRoute(http.MethodPut, path, status, experimental, handler, opts...) +func NewPutRoute(path string, status, experimental bool, handler http.HandlerFunc, middleware []func(http.Handler) http.Handler, opts ...RouteWrapper) Route { + return NewRoute(http.MethodPut, path, status, experimental, handler, middleware, opts...) } -// NewDeleteRoute initializes a new route with the http method DELETE. -func NewDeleteRoute(path string, status bool, experimental bool, handler http.HandlerFunc, opts ...RouteWrapper) Route { - return NewRoute(http.MethodDelete, path, status, experimental, handler, opts...) +func NewDeleteRoute(path string, status, experimental bool, handler http.HandlerFunc, middleware []func(http.Handler) http.Handler, opts ...RouteWrapper) Route { + return NewRoute(http.MethodDelete, path, status, experimental, handler, middleware, opts...) } -// NewOptionsRoute initializes a new route with the http method OPTIONS. -func NewOptionsRoute(path string, status bool, experimental bool, handler http.HandlerFunc, opts ...RouteWrapper) Route { - return NewRoute(http.MethodOptions, path, status, experimental, handler, opts...) +func NewOptionsRoute(path string, status, experimental bool, handler http.HandlerFunc, middleware []func(http.Handler) http.Handler, opts ...RouteWrapper) Route { + return NewRoute(http.MethodOptions, path, status, experimental, handler, middleware, opts...) } -// NewHeadRoute initializes a new route with the http method HEAD. -func NewHeadRoute(path string, status bool, experimental bool, handler http.HandlerFunc, opts ...RouteWrapper) Route { - return NewRoute(http.MethodHead, path, status, experimental, handler, opts...) +func NewHeadRoute(path string, status, experimental bool, handler http.HandlerFunc, middleware []func(http.Handler) http.Handler, opts ...RouteWrapper) Route { + return NewRoute(http.MethodHead, path, status, experimental, handler, middleware, opts...) } diff --git a/router/router.go b/router/router.go index 38cc771..66f243c 100644 --- a/router/router.go +++ b/router/router.go @@ -5,23 +5,27 @@ import ( ) type Router interface { - // Returns the list of all routes + // Routes returns all registered routes Routes() []Route - // Is the router enabled + // Status returns whether the router is active Status() bool - // The router prefix + // Prefix returns the base path Prefix() string + // Middleware returns router-level middleware + Middleware() []func(http.Handler) http.Handler } type Route interface { - // Handler returns the function the route applies + // Handler is the HTTP handler function Handler() http.HandlerFunc - // Method returns the http method the route corresponds to + // Method is the HTTP verb (GET, POST, etc.) Method() string - // Path returns the subpath where the route responds + // Path is the relative route path Path() string - // Status returns whether the route is enabled + // Status returns whether the route is active Status() bool // Experimental returns whether the route is experimental Experimental() bool + // Middleware returns route-level middleware + Middleware() []func(http.Handler) http.Handler } From e1f34a47a21cec2b41584c4c82c27226c12dbffd Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 20:48:16 +0100 Subject: [PATCH 06/10] feat: update package handling --- README.md | 211 +++++++++++++++++++++++++++++++++++++++++++++++++++--- go.mod | 7 +- go.sum | 5 +- ramchi.go | 146 +++++++++++++++++++++++++++---------- 4 files changed, 319 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 0501e62..bf4b7e8 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,211 @@ # ramchi -`ramchi` is an extension to `chi` for rapid & modular development of sites and restful applications. -It allows modular composition of routes, with easy handler registering in a manner that is simple to use, -`ramchi` is based upon developer experience and usage, while still making your website fast and responsive. -## Install +`ramchi` is an extension to the [chi](https://github.com/go-chi/chi) HTTP router designed for rapid and modular development of web applications. `ramchi` focuses on developer experience while ensuring your website remains fast and responsive. -`go get -u github.com/etwodev/ramchi` +--- -## Config +## Features -When you create [your first server](), `ramchi` will generate a `ramchi.config.json` file, -which allows you to configure aspects of the server. +- Modular router and middleware loading +- Support for feature flagging via experimental toggles +- Unified backend and frontend serving capabilities +- Zero-configuration TLS support +- Structured, leveled logging powered by `zerolog` +- Graceful shutdown and signal handling +- Extensible helpers for requests, responses, crypto, email, and more + +--- + +## Installation + +```bash +go get -u github.com/etwodev/ramchi +``` + +--- + +## Getting Started + +ramchi allows easy, modular registration of endpoints through grouping. + +Create a new server instance and start it: + +```go +package main + +import ( + "github.com/etwodev/ramchi" + "encoding/json" + "net/http" +) + +func main() { + s := ramchi.New() + s.LoadRouter(Routers()) + s.Start() +} + +func Routers() []router.Router { + return []router.Router{ + router.NewRouter("example", Routes(), true), + } +} + +func Routes() []router.Route { + return []router.Route{ + router.NewGetRoute("/demo", true, false, ExampleGetHandler), + } +} + +// This route will be a GET endpoint registered at /example/demo +func ExampleGetHandler(w http.ResponseWriter, r *http.Request) { + res, _ := json.Marshal(map[string]string{"success": "ping"}) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(201) + if _, err := w.Write(res); err != nil { + t.Fatal(err) + } +} +``` + +On the first run, `ramchi` will automatically generate a default `ramchi.config.json` file in your working directory. + +--- + +## Configuration + +The `ramchi.config.json` file controls server behavior and feature toggling. + +### Default Configuration Example ```json { - "port": "8080", - "address": "localhost", - "experimental": false + "port": "7000", + "address": "0.0.0.0", + "experimental": false, + "logLevel": "info", + "enableTLS": false, + "tlsCertFile": "", + "tlsKeyFile": "", + "readTimeout": 15, + "writeTimeout": 15, + "idleTimeout": 60, + "maxHeaderBytes": 1048576, + "shutdownTimeout": 15, + "enableCORS": false, + "allowedOrigins": ["*"], + "enableRequestLogging": false +} +``` + +### Configuration Fields + +| Field | Type | Description | Default | +| ---------------------- | --------- | --------------------------------------------------------------------------- | ----------- | +| `port` | string | TCP port the server listens on | `"7000"` | +| `address` | string | IP address to bind to | `"0.0.0.0"` | +| `experimental` | bool | Enables or disables experimental feature flags | `false` | +| `logLevel` | string | Log verbosity level (`debug`, `info`, `warn`, `error`, `fatal`, `disabled`) | `"info"` | +| `enableTLS` | bool | Enable HTTPS by providing TLS certificate and key | `false` | +| `tlsCertFile` | string | Path to TLS certificate file (required if `enableTLS` is true) | `""` | +| `tlsKeyFile` | string | Path to TLS key file (required if `enableTLS` is true) | `""` | +| `readTimeout` | int | Maximum duration (in seconds) for reading the request | `15` | +| `writeTimeout` | int | Maximum duration (in seconds) before timing out response writes | `15` | +| `idleTimeout` | int | Maximum duration (in seconds) to keep idle connections open | `60` | +| `maxHeaderBytes` | int | Maximum size of request headers in bytes | `1048576` | +| `shutdownTimeout` | int | Time (in seconds) allowed for graceful shutdown | `15` | +| `enableCORS` | bool | Automatically enables CORS middleware | `false` | +| `allowedOrigins` | \[]string | List of allowed CORS origins (e.g., `["*"]`, `["https://example.com"]`) | `["*"]` | +| `enableRequestLogging` | bool | Automatically enables HTTP request logging middleware | `false` | + +--- + +## Togglable Middleware + +You can enable built-in middleware through the config file without registering them manually. + +### Available Middleware + +| Middleware | Config Flag | Description | +| ------------------- | ---------------------- | --------------------------------------------------------------- | +| **CORS** | `enableCORS` | Adds a permissive or origin-restricted CORS layer | +| **Request Logging** | `enableRequestLogging` | Logs all incoming HTTP requests using structured logging format | + +These are injected globally before any custom middleware or routes. + +If you require more control (e.g., middleware ordering or conditional logic), you can still register them manually through the `LoadMiddleware()` method. + +--- + +## Using the Configuration in Code + +Your application can access config values via the `config` package accessor functions: + +```go +import c "github.com/etwodev/ramchi/config" + +port := c.Port() // e.g. "7000" +address := c.Address() // e.g. "0.0.0.0" +if c.Experimental() { + // Enable experimental features +} +level := c.LogLevel() // e.g. "debug" +if c.EnableTLS() { + cert := c.TLSCertFile() + key := c.TLSKeyFile() + // Use cert and key to start HTTPS server } ``` +The server internally uses these config values to set up logging, timeouts, TLS, and feature flags. + +--- + +## Logging + +`ramchi` integrates [zerolog](https://github.com/rs/zerolog) for structured, leveled logging with console-friendly output by default. However, logging can be replaced. If you would like a specific package to be supported, please raise an issue. + +* Log verbosity is controlled by the `logLevel` config (e.g., `debug`, `info`, `disabled`). +* Logs include contextual fields such as the server group, function names, HTTP method, route path, and middleware names. +* Graceful shutdown logs warnings and fatal errors as appropriate. + +--- + +## Middleware & Routing + +* Load your middlewares and routers modularly before starting the server. +* `ramchi` respects middleware and route `Experimental` flags based on your config. +* Routes and middleware with disabled status or mismatched experimental flags are skipped. + +--- + +## TLS Support + +Set `enableTLS` to `true` and provide valid paths to `tlsCertFile` and `tlsKeyFile` in your config to serve HTTPS. + +--- + +## Extending Helpers + +`ramchi` ships with helper packages for common tasks: + +* `helpers/request.go`: HTTP request utilities (e.g., extracting IP, URL params) +* `helpers/response.go`: Response helpers for JSON encoding, error handling +* `helpers/crypto.go`: Crypto utilities (hashing, encryption helpers) +* `helpers/email.go`: Email sending and templating helpers +* `helpers/strings.go`: String manipulation utilities (e.g., truncation, padding, sanitization) +* `helpers/encoding.go`: Encoding utilities (e.g., toHex, toBase64) + +You are encouraged to extend these helper packages or create your own. + +--- + +## Contributing + +Contributions and suggestions are welcome. Please open issues or pull requests on the [GitHub repository](https://github.com/etwodev/ramchi). + +--- + +## License + +MIT License © Etwodev diff --git a/go.mod b/go.mod index 0d19641..30e3f0f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/Etwodev/ramchi -go 1.21.1 +go 1.23.0 + +toolchain go1.23.2 require ( github.com/go-chi/chi/v5 v5.0.10 @@ -10,5 +12,6 @@ require ( require ( github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect - golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 // indirect + golang.org/x/crypto v0.40.0 + golang.org/x/sys v0.34.0 // indirect ) diff --git a/go.sum b/go.sum index facc7b8..b372c6a 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,9 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 h1:foEbQz/B0Oz6YIqu/69kfXPYeFQAuuMYFkjaqXzl5Wo= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= diff --git a/ramchi.go b/ramchi.go index eccd3f9..a8b0059 100644 --- a/ramchi.go +++ b/ramchi.go @@ -7,8 +7,10 @@ import ( "os" "os/signal" "path" + "time" c "github.com/Etwodev/ramchi/config" + "github.com/Etwodev/ramchi/log" "github.com/Etwodev/ramchi/middleware" "github.com/Etwodev/ramchi/router" @@ -16,24 +18,39 @@ import ( "github.com/rs/zerolog" ) -var log zerolog.Logger - type Server struct { idle chan struct{} middlewares []middleware.Middleware routers []router.Router instance *http.Server + logger log.Logger } func New() *Server { - format := zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "2006-01-02T15:04:05"} - log = zerolog.New(format).With().Timestamp().Str("Group", "ramchi").Logger() - err := c.New() if err != nil { - log.Fatal().Str("Function", "New").Err(err).Msg("Unexpected error") + baseLogger := zerolog.New(os.Stdout).With().Timestamp().Str("Group", "ramchi").Logger() + baseLogger.Fatal().Str("Function", "New").Err(err).Msg("Failed to load config") + } + + level, err := zerolog.ParseLevel(c.LogLevel()) + if err != nil { + level = zerolog.InfoLevel + } + zerolog.SetGlobalLevel(level) + + format := zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: "2006-01-02T15:04:05"} + baseLogger := zerolog.New(format).With().Timestamp().Str("Group", "ramchi").Logger() + + logger := log.NewZeroLogger(baseLogger) + + return &Server{ + logger: logger, } - return &Server{} +} + +func (s *Server) Logger() log.Logger { + return s.logger } func (s *Server) LoadRouter(routers []router.Router) { @@ -45,34 +62,68 @@ func (s *Server) LoadMiddleware(middlewares []middleware.Middleware) { } func (s *Server) Start() { - s.instance = &http.Server{Addr: fmt.Sprintf("%s:%s", c.Address(), c.Port()), Handler: s.handler()} - log.Debug().Str("Port", c.Port()).Str("Address", c.Address()).Bool("Experimental", c.Experimental()).Msg("Server started") + // Load CORS middleware if enabled in config + if c.EnableCORS() && len(c.AllowedOrigins()) > 0 { + corsMw := middleware.NewCORSMiddleware(c.AllowedOrigins()) + s.LoadMiddleware([]middleware.Middleware{corsMw}) + } + + // Load Logging middleware if enabled + if c.EnableRequestLogging() { + loggingMw := middleware.NewLoggingMiddleware(s.logger) + s.LoadMiddleware([]middleware.Middleware{loggingMw}) + } + + s.instance = &http.Server{ + Addr: fmt.Sprintf("%s:%s", c.Address(), c.Port()), + Handler: s.handler(), + ReadTimeout: time.Duration(c.ReadTimeout()) * time.Second, + WriteTimeout: time.Duration(c.WriteTimeout()) * time.Second, + IdleTimeout: time.Duration(c.IdleTimeout()) * time.Second, + MaxHeaderBytes: c.MaxHeaderBytes(), + } + + s.logger.Debug(). + Str("Port", c.Port()). + Str("Address", c.Address()). + Bool("Experimental", c.Experimental()). + Msg("Server started") s.idle = make(chan struct{}) go func() { sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt) <-sigint - if err := s.instance.Shutdown(context.Background()); err != nil { - log.Warn().Str("Function", "Shutdown").Err(err).Msg("Server shutdown failed!") + + timeout := time.Duration(c.ShutdownTimeout()) * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := s.instance.Shutdown(ctx); err != nil { + s.logger.Warn().Str("Function", "Shutdown").Err(err).Msg("Server shutdown failed!") } close(s.idle) }() - if err := s.instance.ListenAndServe(); err != http.ErrServerClosed { - log.Fatal().Str("Function", "ListenAndServe").Err(err).Msg("Unexpected error") + if c.EnableTLS() { + s.logger.Info().Msg("Starting HTTPS server") + if err := s.instance.ListenAndServeTLS(c.TLSCertFile(), c.TLSKeyFile()); err != nil && err != http.ErrServerClosed { + s.logger.Fatal().Err(err).Msg("HTTPS server failed") + } + } else { + s.logger.Info().Msg("Starting HTTP server") + if err := s.instance.ListenAndServe(); err != nil && err != http.ErrServerClosed { + s.logger.Fatal().Err(err).Msg("HTTP server failed") + } } <-s.idle - log.Debug().Str("Port", c.Port()).Str("Address", c.Address()).Bool("Experimental", c.Experimental()).Msg("Server stopped") -} - -func Handle(w http.ResponseWriter, function string, err error, msg string, code int) { - if err != nil { - log.Error().Str("Function", function).Str("Status", http.StatusText(code)).Err(err).Msg(msg) - http.Error(w, http.StatusText(code), code) - } + s.logger.Debug(). + Str("Port", c.Port()). + Str("Address", c.Address()). + Bool("Experimental", c.Experimental()). + Msg("Server stopped") } func (s *Server) handler() *chi.Mux { @@ -82,28 +133,51 @@ func (s *Server) handler() *chi.Mux { } func (s *Server) initMux(m *chi.Mux) { + // Global middleware for _, middleware := range s.middlewares { if middleware.Status() && (middleware.Experimental() == c.Experimental() || !middleware.Experimental()) { - log.Debug().Str("Name", middleware.Name()).Bool("Experimental", middleware.Experimental()).Bool("Status", middleware.Status()).Msg("Registering middleware") + s.logger.Debug(). + Str("Name", middleware.Name()). + Bool("Experimental", middleware.Experimental()). + Bool("Status", middleware.Status()). + Msg("Registering middleware") + m.Use(middleware.Method()) } } - for _, router := range s.routers { - if router.Status() { - for _, r := range router.Routes() { - if r.Status() && (r.Experimental() == c.Experimental() || !r.Experimental()) { - fullPath := path.Join("/", router.Prefix(), r.Path()) - log.Debug(). - Bool("Experimental", r.Experimental()). - Bool("Status", r.Status()). - Str("Method", r.Method()). - Str("Path", fullPath). - Msg("Registering route") - m.Method(r.Method(), fullPath, r.Handler()) + // Routers + for _, rtr := range s.routers { + if !rtr.Status() { + continue + } + + m.Route("/"+rtr.Prefix(), func(r chi.Router) { + for _, rmw := range rtr.Middleware() { + r.Use(rmw) + } + + for _, rt := range rtr.Routes() { + if !rt.Status() || (rt.Experimental() != c.Experimental() && rt.Experimental()) { + continue } + + fullPath := "/" + rt.Path() + + s.logger.Debug(). + Bool("Experimental", rt.Experimental()). + Bool("Status", rt.Status()). + Str("Method", rt.Method()). + Str("Path", path.Join("/", rtr.Prefix(), rt.Path())). + Msg("Registering route") + + finalHandler := http.Handler(rt.Handler()) + for _, mw := range rt.Middleware() { + finalHandler = mw(finalHandler) + } + + r.Method(rt.Method(), fullPath, finalHandler) } - } + }) } - } From a9b02c73438f4706e31938de08b919ef7cbe9079 Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 20:58:20 +0100 Subject: [PATCH 07/10] chore: update tests --- log/log.go | 19 +++++++++- middleware/public.go | 19 +--------- ramchi_test.go | 82 ++++++++++++++++++++++++++++++++++++-------- 3 files changed, 87 insertions(+), 33 deletions(-) diff --git a/log/log.go b/log/log.go index f65791e..a8af734 100644 --- a/log/log.go +++ b/log/log.go @@ -1,6 +1,16 @@ package log -import "time" +import ( + "context" + "time" +) + +// ctxKey is a private type used as a key for storing values in context. +// This prevents collisions with other context keys. +type ctxKey string + +// loggerCtxKey is the key used to store the logger instance in the request context. +var LoggerCtxKey = ctxKey("logger") type Logger interface { Debug() Entry @@ -18,3 +28,10 @@ type Entry interface { Msg(msg string) Err(error) Entry } + +func FromContext(ctx context.Context) Logger { + if logger, ok := ctx.Value(LoggerCtxKey).(Logger); ok { + return logger + } + return nil +} diff --git a/middleware/public.go b/middleware/public.go index b3aacbb..38ae04b 100644 --- a/middleware/public.go +++ b/middleware/public.go @@ -7,13 +7,6 @@ import ( "github.com/Etwodev/ramchi/log" ) -// ctxKey is a private type used as a key for storing values in context. -// This prevents collisions with other context keys. -type ctxKey string - -// loggerCtxKey is the key used to store the logger instance in the request context. -var loggerCtxKey = ctxKey("logger") - // LoggerInjectionMiddleware returns a Middleware that injects the provided logger // instance into the request's context. This allows downstream handlers and middleware // to retrieve the logger directly from the context for structured logging. @@ -45,22 +38,12 @@ var loggerCtxKey = ctxKey("logger") func NewLoggingMiddleware(logger log.Logger) Middleware { return NewMiddleware(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := context.WithValue(r.Context(), loggerCtxKey, logger) + ctx := context.WithValue(r.Context(), log.LoggerCtxKey, logger) next.ServeHTTP(w, r.WithContext(ctx)) }) }, "ramchi_logger_inject", true, false) } -// LoggerFromContext retrieves the logger instance from the context. -// Returns nil if no logger is found. Requires 'LoggerInjectionMiddleware' to -// be consumed. -func LoggerFromContext(ctx context.Context) log.Logger { - if logger, ok := ctx.Value(loggerCtxKey).(log.Logger); ok { - return logger - } - return nil -} - // NewCORSMiddleware returns a simple CORS middleware. // allowedOrigins is a list of origins that are allowed. Use ["*"] for allowing all. func NewCORSMiddleware(allowedOrigins []string) Middleware { diff --git a/ramchi_test.go b/ramchi_test.go index c21716f..19e5ae9 100644 --- a/ramchi_test.go +++ b/ramchi_test.go @@ -2,34 +2,39 @@ package ramchi import ( "encoding/json" - "errors" "io" "net/http" "net/http/httptest" "testing" + c "github.com/Etwodev/ramchi/config" + "github.com/Etwodev/ramchi/log" + "github.com/Etwodev/ramchi/middleware" "github.com/Etwodev/ramchi/router" ) -func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, string) { +func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader, headers map[string]string) (*http.Response, string) { req, err := http.NewRequest(method, ts.URL+path, body) if err != nil { t.Fatal(err) return nil, "" } + for k, v := range headers { + req.Header.Set(k, v) + } resp, err := http.DefaultClient.Do(req) if err != nil { t.Fatal(err) return nil, "" } + defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) return nil, "" } - defer resp.Body.Close() return resp, string(respBody) } @@ -41,43 +46,92 @@ func TestBasicServer(t *testing.T) { ts := New() + // Simulate what Start() would do — apply middlewares based on config + if c.EnableCORS() && len(c.AllowedOrigins()) > 0 { + corsMw := middleware.NewCORSMiddleware(c.AllowedOrigins()) + ts.LoadMiddleware([]middleware.Middleware{corsMw}) + } + if c.EnableRequestLogging() { + loggingMw := middleware.NewLoggingMiddleware(ts.Logger()) + ts.LoadMiddleware([]middleware.Middleware{loggingMw}) + } + + // Handlers pingAll := func(w http.ResponseWriter, r *http.Request) { + // Confirm logger middleware injected the logger + logger := log.FromContext(r.Context()) + if logger == nil { + t.Error("Expected logger to be injected into context via middleware") + } + res, _ := json.Marshal(map[string]string{"success": "ping"}) w.Header().Set("Content-Type", "application/json") w.WriteHeader(201) - if _, err := w.Write(res); err != nil { - t.Fatal(err) - } + w.Write(res) } errorAll := func(w http.ResponseWriter, r *http.Request) { - Handle(w, "errorAll", errors.New(ERROR_RESPONSE), ERROR_MESSAGE, ERROR_STATUS_CODE) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(ERROR_STATUS_CODE) + response := map[string]string{ + "error": ERROR_MESSAGE, + "details": ERROR_RESPONSE, + } + json.NewEncoder(w).Encode(response) } + // Routes testRoutes := func() []router.Route { return []router.Route{ - router.NewGetRoute("ping", true, false, pingAll), // No leading slash - router.NewGetRoute("error", true, false, errorAll), // No leading slash + router.NewGetRoute("ping", true, false, pingAll, nil), + router.NewGetRoute("error", true, false, errorAll, nil), } } + // Routers testRouters := func() []router.Router { return []router.Router{ - router.NewRouter("test", testRoutes(), true), // Prefix is now "test" + router.NewRouter("test", testRoutes(), true, nil), } } - ts.LoadRouter(testRouters()) instance := httptest.NewServer(ts.handler()) defer instance.Close() - // Expect routes to be mounted under /test/ - if _, body := testRequest(t, instance, http.MethodGet, "/test/ping", nil); body != `{"success":"ping"}` { + // ─── Test /ping ───────────────────────────────────────────────────── + resp, body := testRequest(t, instance, http.MethodGet, "/test/ping", nil, nil) + if resp.StatusCode != http.StatusCreated { + t.Errorf("Expected 201 status, got %d", resp.StatusCode) + } + if body != `{"success":"ping"}` { t.Fatalf("Unexpected ping response: %s", body) } - if _, body := testRequest(t, instance, http.MethodGet, "/test/error", nil); body != "I'm a teapot\n" { + // ─── Test /error ──────────────────────────────────────────────────── + resp, body = testRequest(t, instance, http.MethodGet, "/test/error", nil, nil) + if resp.StatusCode != ERROR_STATUS_CODE { + t.Errorf("Expected status %d, got %d", ERROR_STATUS_CODE, resp.StatusCode) + } + expected := `{"details":"test error pass-through","error":"Example error has occurred"}` + "\n" + if body != expected { t.Fatalf("Unexpected error response: %s", body) } + + // ─── Test CORS ────────────────────────────────────────────────────── + req, _ := http.NewRequest(http.MethodOptions, instance.URL+"/test/ping", nil) + req.Header.Set("Origin", "http://example.com") + req.Header.Set("Access-Control-Request-Method", "GET") + corsResp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + + if got := corsResp.Header.Get("Access-Control-Allow-Origin"); got != "http://example.com" && got != "*" { + t.Errorf("CORS header mismatch: got '%s', expected 'http://example.com'", got) + } + + if got := corsResp.Header.Get("Access-Control-Allow-Methods"); got == "" { + t.Errorf("CORS Allow-Methods header missing") + } } From f397639c53b420ee3d437bfbf07bf4805a184b22 Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 21:06:18 +0100 Subject: [PATCH 08/10] chore: fix linting --- helpers/response.go | 8 ++++---- ramchi_test.go | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/helpers/response.go b/helpers/response.go index b56a1ab..c681a9c 100644 --- a/helpers/response.go +++ b/helpers/response.go @@ -6,15 +6,15 @@ import ( ) // RespondWithJSON writes a JSON response with status code -func RespondWithJSON(w http.ResponseWriter, status int, payload interface{}) { +func RespondWithJSON(w http.ResponseWriter, status int, payload interface{}) error { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) - json.NewEncoder(w).Encode(payload) + return json.NewEncoder(w).Encode(payload) } // RespondWithError writes a standardized error message -func RespondWithError(w http.ResponseWriter, status int, message string) { - RespondWithJSON(w, status, map[string]string{"error": message}) +func RespondWithError(w http.ResponseWriter, status int, message string) error { + return RespondWithJSON(w, status, map[string]string{"error": message}) } // NoContent sends a 204 No Content response diff --git a/ramchi_test.go b/ramchi_test.go index 19e5ae9..260e11e 100644 --- a/ramchi_test.go +++ b/ramchi_test.go @@ -67,7 +67,12 @@ func TestBasicServer(t *testing.T) { res, _ := json.Marshal(map[string]string{"success": "ping"}) w.Header().Set("Content-Type", "application/json") w.WriteHeader(201) - w.Write(res) + + _, err := w.Write(res) + if err != nil { + t.Fatal(err) + } + } errorAll := func(w http.ResponseWriter, r *http.Request) { @@ -77,7 +82,12 @@ func TestBasicServer(t *testing.T) { "error": ERROR_MESSAGE, "details": ERROR_RESPONSE, } - json.NewEncoder(w).Encode(response) + + err := json.NewEncoder(w).Encode(response) + if err != nil { + t.Fatal(err) + } + } // Routes From d6523594408ff195e241ef311a0d5f85010393ee Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 21:20:26 +0100 Subject: [PATCH 09/10] fix: add default config option --- config/config.go | 10 +++++++--- ramchi_test.go | 20 ++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/config/config.go b/config/config.go index d4bf47c..a89cdb3 100644 --- a/config/config.go +++ b/config/config.go @@ -13,7 +13,7 @@ var c *Config func Load() error { _, err := os.Stat(CONFIG) if os.IsNotExist(err) { - if err := Create(); err != nil { + if err := Create(nil); err != nil { return fmt.Errorf("Load: failed creating load: %w", err) } } @@ -30,8 +30,8 @@ func Load() error { return nil } -func Create() error { - defaultConfig := Config{ +func Create(override *Config) error { + var defaultConfig Config = Config{ Port: "7000", Address: "0.0.0.0", Experimental: false, @@ -49,6 +49,10 @@ func Create() error { EnableRequestLogging: false, } + if override != nil { + defaultConfig = *override + } + file, err := json.MarshalIndent(&defaultConfig, "", " ") if err != nil { return fmt.Errorf("Create: failed marshalling config: %w", err) diff --git a/ramchi_test.go b/ramchi_test.go index 260e11e..7b118bf 100644 --- a/ramchi_test.go +++ b/ramchi_test.go @@ -44,6 +44,26 @@ func TestBasicServer(t *testing.T) { const ERROR_MESSAGE = "Example error has occurred" const ERROR_RESPONSE = "test error pass-through" + defaultConfig := &c.Config{ + Port: "7000", + Address: "127.0.0.1", + Experimental: false, + ReadTimeout: 15, + WriteTimeout: 15, + IdleTimeout: 60, + LogLevel: "debug", + MaxHeaderBytes: 1048576, + EnableTLS: false, + TLSCertFile: "", + TLSKeyFile: "", + ShutdownTimeout: 5, + EnableCORS: true, + AllowedOrigins: []string{"http://example.com"}, + EnableRequestLogging: true, + } + + c.Create(defaultConfig) + ts := New() // Simulate what Start() would do — apply middlewares based on config From 3f5a94653f671438422dbe60f0e4191940549aac Mon Sep 17 00:00:00 2001 From: Zoe Woods Date: Thu, 10 Jul 2025 21:22:10 +0100 Subject: [PATCH 10/10] fix: add proper error handling --- ramchi_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ramchi_test.go b/ramchi_test.go index 7b118bf..eac9f35 100644 --- a/ramchi_test.go +++ b/ramchi_test.go @@ -62,7 +62,10 @@ func TestBasicServer(t *testing.T) { EnableRequestLogging: true, } - c.Create(defaultConfig) + err := c.Create(defaultConfig) + if err != nil { + t.Fatal(err) + } ts := New()