Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions internal/check/port.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package check

import (
"context"
"fmt"
"net"
"net/url"
"strings"
"time"
)

type PortCheck struct {
Service string
Port string
dialer func(address string) error
}

func (c *PortCheck) Name() string {
return fmt.Sprintf("%s port reachable", c.Service)
}

func (c *PortCheck) Run(_ context.Context) Result {
dial := c.dialer
if dial == nil {
dial = func(address string) error {
conn, err := net.DialTimeout("tcp", address, 2*time.Second)
if err != nil {
return err
}
conn.Close()
return nil
}
}

address := "localhost:" + c.Port
if err := dial(address); err != nil {
return Result{
Name: c.Name(),
Status: StatusFail,
Message: fmt.Sprintf("nothing listening on port %s", c.Port),
Fix: fmt.Sprintf("start %s and make sure it is running on port %s", c.Service, c.Port),
}
}
return Result{
Name: c.Name(),
Status: StatusPass,
Message: fmt.Sprintf("%s is listening on port %s", c.Service, c.Port),
}
}

// portFromURL extracts the port from a service URL, falling back to defaultPort.
// Handles standard URLs (postgres://, redis://, mongodb://) and MySQL DSNs (user:pass@tcp(host:port)/db).
func portFromURL(rawURL, defaultPort string) string {
if rawURL == "" {
return defaultPort
}

// MySQL DSN format: user:pass@tcp(host:port)/db
if idx := strings.Index(rawURL, "tcp("); idx != -1 {
rest := rawURL[idx+4:]
if end := strings.Index(rest, ")"); end != -1 {
_, port, err := net.SplitHostPort(rest[:end])
if err == nil && port != "" {
return port
}
}
}

// Standard URL format
u, err := url.Parse(rawURL)
if err == nil && u.Port() != "" {
return u.Port()
}

return defaultPort
}
71 changes: 71 additions & 0 deletions internal/check/port_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package check

import (
"context"
"errors"
"testing"
)

func TestPortCheck_Pass(t *testing.T) {
c := &PortCheck{Service: "PostgreSQL", Port: "5432", dialer: func(_ string) error {
return nil
}}
result := c.Run(context.Background())
if result.Status != StatusPass {
t.Errorf("expected pass, got %v: %s", result.Status, result.Message)
}
}

func TestPortCheck_Fail(t *testing.T) {
c := &PortCheck{Service: "PostgreSQL", Port: "5432", dialer: func(_ string) error {
return errors.New("connection refused")
}}
result := c.Run(context.Background())
if result.Status != StatusFail {
t.Errorf("expected fail, got %v", result.Status)
}
}

func TestPortCheck_MessageContainsPort(t *testing.T) {
c := &PortCheck{Service: "Redis", Port: "6379", dialer: func(_ string) error {
return errors.New("connection refused")
}}
result := c.Run(context.Background())
if result.Message != "nothing listening on port 6379" {
t.Errorf("unexpected message: %s", result.Message)
}
}

func TestPortFromURL_StandardURL(t *testing.T) {
cases := []struct {
url string
defaultPort string
expectedPort string
}{
{"postgres://user:pass@localhost:5555/db", "5432", "5555"},
{"redis://localhost:6380", "6379", "6380"},
{"mongodb://localhost:27018", "27017", "27018"},
{"", "5432", "5432"},
{"postgres://localhost/db", "5432", "5432"}, // no port in URL → default
}
for _, tc := range cases {
got := portFromURL(tc.url, tc.defaultPort)
if got != tc.expectedPort {
t.Errorf("portFromURL(%q, %q) = %q, want %q", tc.url, tc.defaultPort, got, tc.expectedPort)
}
}
}

func TestPortFromURL_MySQLDSN(t *testing.T) {
got := portFromURL("user:pass@tcp(localhost:3307)/db", "3306")
if got != "3307" {
t.Errorf("expected 3307, got %s", got)
}
}

func TestPortFromURL_MySQLDSN_Default(t *testing.T) {
got := portFromURL("user:pass@tcp(localhost:3306)/db", "3306")
if got != "3306" {
t.Errorf("expected 3306, got %s", got)
}
}
23 changes: 18 additions & 5 deletions internal/check/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ func Build(stack detector.DetectedStack) []Check {
cs = append(cs, &ComposeImageCheck{})
}
if stack.Postgres {
cs = append(cs, &PostgresCheck{URL: os.Getenv("DATABASE_URL")})
dbURL := os.Getenv("DATABASE_URL")
cs = append(cs, &PortCheck{Service: "PostgreSQL", Port: portFromURL(dbURL, "5432")})
cs = append(cs, &PostgresCheck{URL: dbURL})
}
if stack.MySQL {
cs = append(cs, &MySQLCheck{URL: os.Getenv("MYSQL_URL")})
Expand All @@ -62,11 +64,22 @@ func Build(stack detector.DetectedStack) []Check {
cs = append(cs, &MongoCheck{URL: url})
}
if stack.Redis {
url := os.Getenv("REDIS_URL")
if url == "" {
url = os.Getenv("REDIS_URI")
redisURL := os.Getenv("REDIS_URL")
if redisURL == "" {
redisURL = os.Getenv("REDIS_URI")
}
cs = append(cs, &PortCheck{Service: "Redis", Port: portFromURL(redisURL, "6379")})
cs = append(cs, &RedisCheck{URL: redisURL})
}
if stack.MySQL {
cs = append(cs, &PortCheck{Service: "MySQL", Port: portFromURL(os.Getenv("MYSQL_URL"), "3306")})
}
if stack.MongoDB {
mongoURL := os.Getenv("MONGODB_URI")
if mongoURL == "" {
mongoURL = os.Getenv("MONGO_URL")
}
cs = append(cs, &RedisCheck{URL: url})
cs = append(cs, &PortCheck{Service: "MongoDB", Port: portFromURL(mongoURL, "27017")})
}

if stack.EnvExample {
Expand Down