Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ FROM scratch

COPY --from=builder /app/mbproxy /mbproxy

HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
ENV HEALTH_LISTEN=:8080
EXPOSE 8080

HEALTHCHECK --interval=5s --timeout=3s --start-period=10s --retries=3 \
CMD ["/mbproxy", "-health"]

ENTRYPOINT ["/mbproxy"]
55 changes: 32 additions & 23 deletions cmd/mbproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (
"os"
"os/signal"
"syscall"
"time"

"github.com/tma/mbproxy/internal/config"
"github.com/tma/mbproxy/internal/health"
"github.com/tma/mbproxy/internal/logging"
"github.com/tma/mbproxy/internal/modbus"
"github.com/tma/mbproxy/internal/proxy"
)

Expand All @@ -20,7 +21,12 @@ func main() {
flag.Parse()

if *healthCheck {
if err := runHealthCheck(); err != nil {
addr := os.Getenv("HEALTH_LISTEN")
if addr == "" {
fmt.Fprintln(os.Stderr, "HEALTH_LISTEN is not set")
os.Exit(1)
}
if err := health.CheckHealth(addr); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
Expand Down Expand Up @@ -48,6 +54,22 @@ func main() {
os.Exit(1)
}

// Start health server if configured
var hs *health.Server
if cfg.HealthListen != "" {
hs = health.NewServer(cfg.HealthListen, p, logger)
hsLn, err := hs.Listen()
if err != nil {
logger.Error("failed to start health server", "error", err)
os.Exit(1)
}
go func() {
if err := hs.Serve(hsLn); err != nil {
logger.Error("health server error", "error", err)
}
}()
}

// Start proxy in background
errCh := make(chan error, 1)
go func() {
Expand All @@ -68,31 +90,18 @@ func main() {
// Graceful shutdown
cancel()

if hs != nil {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := hs.Shutdown(shutdownCtx); err != nil {
logger.Error("health server shutdown error", "error", err)
}
}

if err := p.Shutdown(cfg.ShutdownTimeout); err != nil {
logger.Error("shutdown error", "error", err)
os.Exit(1)
}

logger.Info("shutdown complete")
}

func runHealthCheck() error {
cfg, err := config.Load()
if err != nil {
return err
}

return checkUpstreamHealth(cfg, logging.New(cfg.LogLevel))
}

func checkUpstreamHealth(cfg *config.Config, logger *slog.Logger) (err error) {
client := modbus.NewClient(cfg.Upstream, cfg.Timeout, cfg.RequestDelay, cfg.ConnectDelay, logger)
defer func() {
closeErr := client.Close()
if err == nil && closeErr != nil {
err = closeErr
}
}()

return client.Connect()
}
59 changes: 0 additions & 59 deletions cmd/mbproxy/main_test.go

This file was deleted.

2 changes: 2 additions & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Config struct {
RequestDelay time.Duration
ConnectDelay time.Duration
ShutdownTimeout time.Duration
HealthListen string
LogLevel string
}

Expand All @@ -46,6 +47,7 @@ func Load() (*Config, error) {
RequestDelay: 0,
ConnectDelay: 0,
ShutdownTimeout: 30 * time.Second,
HealthListen: os.Getenv("HEALTH_LISTEN"),
LogLevel: GetEnv("LOG_LEVEL", "INFO"),
}

Expand Down
28 changes: 28 additions & 0 deletions internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func TestLoad_Defaults(t *testing.T) {
os.Unsetenv("MODBUS_READONLY")
os.Unsetenv("MODBUS_TIMEOUT")
os.Unsetenv("MODBUS_SHUTDOWN_TIMEOUT")
os.Unsetenv("HEALTH_LISTEN")
os.Unsetenv("LOG_LEVEL")

cfg, err := Load()
Expand Down Expand Up @@ -56,6 +57,9 @@ func TestLoad_Defaults(t *testing.T) {
if cfg.ShutdownTimeout != 30*time.Second {
t.Errorf("expected 30s shutdown timeout, got %v", cfg.ShutdownTimeout)
}
if cfg.HealthListen != "" {
t.Errorf("expected empty health listen, got %s", cfg.HealthListen)
}
if cfg.LogLevel != "INFO" {
t.Errorf("expected INFO log level, got %s", cfg.LogLevel)
}
Expand Down Expand Up @@ -192,3 +196,27 @@ func TestLoad_InvalidDuration(t *testing.T) {
os.Unsetenv(envVar)
}
}

func TestLoad_HealthListenCustom(t *testing.T) {
// Ensure optional env vars that Load() may read do not inherit
// potentially invalid values from the surrounding environment.
t.Setenv("MODBUS_LISTEN", "")
t.Setenv("MODBUS_READONLY", "")
t.Setenv("MODBUS_CACHE_TTL", "")
t.Setenv("MODBUS_TIMEOUT", "")
t.Setenv("MODBUS_REQUEST_DELAY", "")
t.Setenv("MODBUS_CONNECT_DELAY", "")
t.Setenv("MODBUS_SHUTDOWN_TIMEOUT", "")

// Set required and explicitly tested env vars.
t.Setenv("MODBUS_UPSTREAM", "localhost:502")
t.Setenv("HEALTH_LISTEN", ":9090")

cfg, err := Load()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.HealthListen != ":9090" {
t.Errorf("expected :9090, got %s", cfg.HealthListen)
}
}
8 changes: 6 additions & 2 deletions internal/health/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,17 +107,21 @@ func (s *Server) Shutdown(ctx context.Context) error {

// CheckHealth performs an HTTP health check against the given address.
// It returns nil if the endpoint responds with 200 OK.
// Wildcard listen addresses (e.g. ":8080", "0.0.0.0:8080", "[::]:8080") are
// normalized to localhost so they can be used as dial targets. IPv6 addresses
// are handled correctly via net.JoinHostPort.
func CheckHealth(addr string) error {
// Resolve the address so we can build a proper URL.
host, port, err := net.SplitHostPort(addr)
if err != nil {
return fmt.Errorf("invalid address %q: %w", addr, err)
}
if host == "" {
// Normalize wildcard and empty hosts to localhost.
if host == "" || host == "0.0.0.0" || host == "::" {
host = "localhost"
}

url := fmt.Sprintf("http://%s:%s/health", host, port)
url := fmt.Sprintf("http://%s/health", net.JoinHostPort(host, port))

client := &http.Client{Timeout: 3 * time.Second}
resp, err := client.Get(url)
Expand Down
46 changes: 46 additions & 0 deletions internal/health/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,49 @@ func TestCheckHealth_ConnectionRefused(t *testing.T) {
t.Error("expected error for connection refused")
}
}

func TestCheckHealth_WildcardAddresses(t *testing.T) {
// Start a test server bound to localhost so wildcard addresses can reach it.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(Response{Status: "ok"})
}))
defer ts.Close()

_, port, err := net.SplitHostPort(ts.Listener.Addr().String())
if err != nil {
t.Fatalf("failed to parse test server address: %v", err)
}

// Each of these listen-style addresses should be normalized to localhost.
wildcards := []string{
":" + port, // empty host (":8080")
"0.0.0.0:" + port, // IPv4 wildcard
"[::]:" + port, // IPv6 wildcard
}
for _, addr := range wildcards {
if err := CheckHealth(addr); err != nil {
t.Errorf("CheckHealth(%q) expected success, got: %v", addr, err)
}
}
}

func TestCheckHealth_IPv6Loopback(t *testing.T) {
// Start a test server bound to the IPv6 loopback address.
ln, err := net.Listen("tcp6", "[::1]:0")
if err != nil {
t.Skip("IPv6 loopback not available:", err)
}
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(Response{Status: "ok"})
}))
ts.Listener = ln
ts.Start()
defer ts.Close()

addr := ts.Listener.Addr().String() // "[::1]:PORT"
if err := CheckHealth(addr); err != nil {
t.Errorf("CheckHealth(%q) expected success for IPv6 loopback, got: %v", addr, err)
}
}
Loading