diff --git a/internal/check/port.go b/internal/check/port.go new file mode 100644 index 0000000..d18ddd5 --- /dev/null +++ b/internal/check/port.go @@ -0,0 +1,76 @@ +package check + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "time" +) + +type PortCheck struct { + Service string + Port string + dialer func(address string) error +} + +func (c *PortCheck) Name() string { + return fmt.Sprintf("%s port reachable", c.Service) +} + +func (c *PortCheck) Run(_ context.Context) Result { + dial := c.dialer + if dial == nil { + dial = func(address string) error { + conn, err := net.DialTimeout("tcp", address, 2*time.Second) + if err != nil { + return err + } + conn.Close() + return nil + } + } + + address := "localhost:" + c.Port + if err := dial(address); err != nil { + return Result{ + Name: c.Name(), + Status: StatusFail, + Message: fmt.Sprintf("nothing listening on port %s", c.Port), + Fix: fmt.Sprintf("start %s and make sure it is running on port %s", c.Service, c.Port), + } + } + return Result{ + Name: c.Name(), + Status: StatusPass, + Message: fmt.Sprintf("%s is listening on port %s", c.Service, c.Port), + } +} + +// portFromURL extracts the port from a service URL, falling back to defaultPort. +// Handles standard URLs (postgres://, redis://, mongodb://) and MySQL DSNs (user:pass@tcp(host:port)/db). +func portFromURL(rawURL, defaultPort string) string { + if rawURL == "" { + return defaultPort + } + + // MySQL DSN format: user:pass@tcp(host:port)/db + if idx := strings.Index(rawURL, "tcp("); idx != -1 { + rest := rawURL[idx+4:] + if end := strings.Index(rest, ")"); end != -1 { + _, port, err := net.SplitHostPort(rest[:end]) + if err == nil && port != "" { + return port + } + } + } + + // Standard URL format + u, err := url.Parse(rawURL) + if err == nil && u.Port() != "" { + return u.Port() + } + + return defaultPort +} diff --git a/internal/check/port_test.go b/internal/check/port_test.go new file mode 100644 index 0000000..b569196 --- /dev/null +++ b/internal/check/port_test.go @@ -0,0 +1,71 @@ +package check + +import ( + "context" + "errors" + "testing" +) + +func TestPortCheck_Pass(t *testing.T) { + c := &PortCheck{Service: "PostgreSQL", Port: "5432", dialer: func(_ string) error { + return nil + }} + result := c.Run(context.Background()) + if result.Status != StatusPass { + t.Errorf("expected pass, got %v: %s", result.Status, result.Message) + } +} + +func TestPortCheck_Fail(t *testing.T) { + c := &PortCheck{Service: "PostgreSQL", Port: "5432", dialer: func(_ string) error { + return errors.New("connection refused") + }} + result := c.Run(context.Background()) + if result.Status != StatusFail { + t.Errorf("expected fail, got %v", result.Status) + } +} + +func TestPortCheck_MessageContainsPort(t *testing.T) { + c := &PortCheck{Service: "Redis", Port: "6379", dialer: func(_ string) error { + return errors.New("connection refused") + }} + result := c.Run(context.Background()) + if result.Message != "nothing listening on port 6379" { + t.Errorf("unexpected message: %s", result.Message) + } +} + +func TestPortFromURL_StandardURL(t *testing.T) { + cases := []struct { + url string + defaultPort string + expectedPort string + }{ + {"postgres://user:pass@localhost:5555/db", "5432", "5555"}, + {"redis://localhost:6380", "6379", "6380"}, + {"mongodb://localhost:27018", "27017", "27018"}, + {"", "5432", "5432"}, + {"postgres://localhost/db", "5432", "5432"}, // no port in URL → default + } + for _, tc := range cases { + got := portFromURL(tc.url, tc.defaultPort) + if got != tc.expectedPort { + t.Errorf("portFromURL(%q, %q) = %q, want %q", tc.url, tc.defaultPort, got, tc.expectedPort) + } + } +} + +func TestPortFromURL_MySQLDSN(t *testing.T) { + got := portFromURL("user:pass@tcp(localhost:3307)/db", "3306") + if got != "3307" { + t.Errorf("expected 3307, got %s", got) + } +} + +func TestPortFromURL_MySQLDSN_Default(t *testing.T) { + got := portFromURL("user:pass@tcp(localhost:3306)/db", "3306") + if got != "3306" { + t.Errorf("expected 3306, got %s", got) + } +} diff --git a/internal/check/registry.go b/internal/check/registry.go index 4fca558..3b8dec7 100644 --- a/internal/check/registry.go +++ b/internal/check/registry.go @@ -49,7 +49,9 @@ func Build(stack detector.DetectedStack) []Check { cs = append(cs, &ComposeImageCheck{}) } if stack.Postgres { - cs = append(cs, &PostgresCheck{URL: os.Getenv("DATABASE_URL")}) + dbURL := os.Getenv("DATABASE_URL") + cs = append(cs, &PortCheck{Service: "PostgreSQL", Port: portFromURL(dbURL, "5432")}) + cs = append(cs, &PostgresCheck{URL: dbURL}) } if stack.MySQL { cs = append(cs, &MySQLCheck{URL: os.Getenv("MYSQL_URL")}) @@ -62,11 +64,22 @@ func Build(stack detector.DetectedStack) []Check { cs = append(cs, &MongoCheck{URL: url}) } if stack.Redis { - url := os.Getenv("REDIS_URL") - if url == "" { - url = os.Getenv("REDIS_URI") + redisURL := os.Getenv("REDIS_URL") + if redisURL == "" { + redisURL = os.Getenv("REDIS_URI") + } + cs = append(cs, &PortCheck{Service: "Redis", Port: portFromURL(redisURL, "6379")}) + cs = append(cs, &RedisCheck{URL: redisURL}) + } + if stack.MySQL { + cs = append(cs, &PortCheck{Service: "MySQL", Port: portFromURL(os.Getenv("MYSQL_URL"), "3306")}) + } + if stack.MongoDB { + mongoURL := os.Getenv("MONGODB_URI") + if mongoURL == "" { + mongoURL = os.Getenv("MONGO_URL") } - cs = append(cs, &RedisCheck{URL: url}) + cs = append(cs, &PortCheck{Service: "MongoDB", Port: portFromURL(mongoURL, "27017")}) } if stack.EnvExample {