diff --git a/Dockerfile b/Dockerfile index abf1edc..f17df1f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,7 +31,10 @@ FROM scratch COPY --from=builder /app/mbproxy /mbproxy -HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \ +ENV HEALTH_LISTEN=:8080 +EXPOSE 8080 + +HEALTHCHECK --interval=5s --timeout=3s --start-period=10s --retries=3 \ CMD ["/mbproxy", "-health"] ENTRYPOINT ["/mbproxy"] diff --git a/cmd/mbproxy/main.go b/cmd/mbproxy/main.go index ebe21a0..6088512 100644 --- a/cmd/mbproxy/main.go +++ b/cmd/mbproxy/main.go @@ -8,10 +8,11 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/tma/mbproxy/internal/config" + "github.com/tma/mbproxy/internal/health" "github.com/tma/mbproxy/internal/logging" - "github.com/tma/mbproxy/internal/modbus" "github.com/tma/mbproxy/internal/proxy" ) @@ -20,7 +21,12 @@ func main() { flag.Parse() if *healthCheck { - if err := runHealthCheck(); err != nil { + addr := os.Getenv("HEALTH_LISTEN") + if addr == "" { + fmt.Fprintln(os.Stderr, "HEALTH_LISTEN is not set") + os.Exit(1) + } + if err := health.CheckHealth(addr); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } @@ -48,6 +54,22 @@ func main() { os.Exit(1) } + // Start health server if configured + var hs *health.Server + if cfg.HealthListen != "" { + hs = health.NewServer(cfg.HealthListen, p, logger) + hsLn, err := hs.Listen() + if err != nil { + logger.Error("failed to start health server", "error", err) + os.Exit(1) + } + go func() { + if err := hs.Serve(hsLn); err != nil { + logger.Error("health server error", "error", err) + } + }() + } + // Start proxy in background errCh := make(chan error, 1) go func() { @@ -68,6 +90,14 @@ func main() { // Graceful shutdown cancel() + if hs != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := hs.Shutdown(shutdownCtx); err != nil { + logger.Error("health server shutdown error", "error", err) + } + } + if err := p.Shutdown(cfg.ShutdownTimeout); err != nil { logger.Error("shutdown error", "error", err) os.Exit(1) @@ -75,24 +105,3 @@ func main() { logger.Info("shutdown complete") } - -func runHealthCheck() error { - cfg, err := config.Load() - if err != nil { - return err - } - - return checkUpstreamHealth(cfg, logging.New(cfg.LogLevel)) -} - -func checkUpstreamHealth(cfg *config.Config, logger *slog.Logger) (err error) { - client := modbus.NewClient(cfg.Upstream, cfg.Timeout, cfg.RequestDelay, cfg.ConnectDelay, logger) - defer func() { - closeErr := client.Close() - if err == nil && closeErr != nil { - err = closeErr - } - }() - - return client.Connect() -} diff --git a/cmd/mbproxy/main_test.go b/cmd/mbproxy/main_test.go deleted file mode 100644 index 0bc9260..0000000 --- a/cmd/mbproxy/main_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package main - -import ( - "io" - "log/slog" - "net" - "testing" - "time" - - "github.com/tma/mbproxy/internal/config" -) - -func newTestLogger() *slog.Logger { - return slog.New(slog.NewTextHandler(io.Discard, nil)) -} - -func TestCheckUpstreamHealth_Success(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("failed to listen: %v", err) - } - defer ln.Close() - - acceptDone := make(chan struct{}) - go func() { - defer close(acceptDone) - conn, err := ln.Accept() - if err == nil { - conn.Close() - } - }() - - cfg := &config.Config{ - Upstream: ln.Addr().String(), - Timeout: time.Second, - } - if err := checkUpstreamHealth(cfg, newTestLogger()); err != nil { - t.Fatalf("expected health check to succeed, got %v", err) - } - - <-acceptDone -} - -func TestCheckUpstreamHealth_Failure(t *testing.T) { - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("failed to reserve port: %v", err) - } - addr := ln.Addr().String() - ln.Close() - - cfg := &config.Config{ - Upstream: addr, - Timeout: 100 * time.Millisecond, - } - if err := checkUpstreamHealth(cfg, newTestLogger()); err == nil { - t.Fatal("expected health check to fail") - } -} diff --git a/internal/config/config.go b/internal/config/config.go index 1c86d1b..7e933bf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -30,6 +30,7 @@ type Config struct { RequestDelay time.Duration ConnectDelay time.Duration ShutdownTimeout time.Duration + HealthListen string LogLevel string } @@ -46,6 +47,7 @@ func Load() (*Config, error) { RequestDelay: 0, ConnectDelay: 0, ShutdownTimeout: 30 * time.Second, + HealthListen: os.Getenv("HEALTH_LISTEN"), LogLevel: GetEnv("LOG_LEVEL", "INFO"), } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 55cc946..ebc316b 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -19,6 +19,7 @@ func TestLoad_Defaults(t *testing.T) { os.Unsetenv("MODBUS_READONLY") os.Unsetenv("MODBUS_TIMEOUT") os.Unsetenv("MODBUS_SHUTDOWN_TIMEOUT") + os.Unsetenv("HEALTH_LISTEN") os.Unsetenv("LOG_LEVEL") cfg, err := Load() @@ -56,6 +57,9 @@ func TestLoad_Defaults(t *testing.T) { if cfg.ShutdownTimeout != 30*time.Second { t.Errorf("expected 30s shutdown timeout, got %v", cfg.ShutdownTimeout) } + if cfg.HealthListen != "" { + t.Errorf("expected empty health listen, got %s", cfg.HealthListen) + } if cfg.LogLevel != "INFO" { t.Errorf("expected INFO log level, got %s", cfg.LogLevel) } @@ -192,3 +196,27 @@ func TestLoad_InvalidDuration(t *testing.T) { os.Unsetenv(envVar) } } + +func TestLoad_HealthListenCustom(t *testing.T) { + // Ensure optional env vars that Load() may read do not inherit + // potentially invalid values from the surrounding environment. + t.Setenv("MODBUS_LISTEN", "") + t.Setenv("MODBUS_READONLY", "") + t.Setenv("MODBUS_CACHE_TTL", "") + t.Setenv("MODBUS_TIMEOUT", "") + t.Setenv("MODBUS_REQUEST_DELAY", "") + t.Setenv("MODBUS_CONNECT_DELAY", "") + t.Setenv("MODBUS_SHUTDOWN_TIMEOUT", "") + + // Set required and explicitly tested env vars. + t.Setenv("MODBUS_UPSTREAM", "localhost:502") + t.Setenv("HEALTH_LISTEN", ":9090") + + cfg, err := Load() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.HealthListen != ":9090" { + t.Errorf("expected :9090, got %s", cfg.HealthListen) + } +} diff --git a/internal/health/health.go b/internal/health/health.go index 625c7db..a32763a 100644 --- a/internal/health/health.go +++ b/internal/health/health.go @@ -107,17 +107,21 @@ func (s *Server) Shutdown(ctx context.Context) error { // CheckHealth performs an HTTP health check against the given address. // It returns nil if the endpoint responds with 200 OK. +// Wildcard listen addresses (e.g. ":8080", "0.0.0.0:8080", "[::]:8080") are +// normalized to localhost so they can be used as dial targets. IPv6 addresses +// are handled correctly via net.JoinHostPort. func CheckHealth(addr string) error { // Resolve the address so we can build a proper URL. host, port, err := net.SplitHostPort(addr) if err != nil { return fmt.Errorf("invalid address %q: %w", addr, err) } - if host == "" { + // Normalize wildcard and empty hosts to localhost. + if host == "" || host == "0.0.0.0" || host == "::" { host = "localhost" } - url := fmt.Sprintf("http://%s:%s/health", host, port) + url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, port)) client := &http.Client{Timeout: 3 * time.Second} resp, err := client.Get(url) diff --git a/internal/health/health_test.go b/internal/health/health_test.go index 25c8e52..85d209c 100644 --- a/internal/health/health_test.go +++ b/internal/health/health_test.go @@ -149,3 +149,49 @@ func TestCheckHealth_ConnectionRefused(t *testing.T) { t.Error("expected error for connection refused") } } + +func TestCheckHealth_WildcardAddresses(t *testing.T) { + // Start a test server bound to localhost so wildcard addresses can reach it. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(Response{Status: "ok"}) + })) + defer ts.Close() + + _, port, err := net.SplitHostPort(ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("failed to parse test server address: %v", err) + } + + // Each of these listen-style addresses should be normalized to localhost. + wildcards := []string{ + ":" + port, // empty host (":8080") + "0.0.0.0:" + port, // IPv4 wildcard + "[::]:" + port, // IPv6 wildcard + } + for _, addr := range wildcards { + if err := CheckHealth(addr); err != nil { + t.Errorf("CheckHealth(%q) expected success, got: %v", addr, err) + } + } +} + +func TestCheckHealth_IPv6Loopback(t *testing.T) { + // Start a test server bound to the IPv6 loopback address. + ln, err := net.Listen("tcp6", "[::1]:0") + if err != nil { + t.Skip("IPv6 loopback not available:", err) + } + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(Response{Status: "ok"}) + })) + ts.Listener = ln + ts.Start() + defer ts.Close() + + addr := ts.Listener.Addr().String() // "[::1]:PORT" + if err := CheckHealth(addr); err != nil { + t.Errorf("CheckHealth(%q) expected success for IPv6 loopback, got: %v", addr, err) + } +}