From f23b6afa05d1ca3a88f254a480ee1704d6aa8f35 Mon Sep 17 00:00:00 2001 From: Vuks69 <51289041+Vuks69@users.noreply.github.com> Date: Thu, 21 May 2026 21:13:37 +0200 Subject: [PATCH 1/2] feat(database): add postgres integration --- .env.example | 13 ++ docker-compose.yml | 20 +++ go.mod | 2 + go.sum | 2 + internal/core/start.go | 64 ++++++++- internal/db/cache.go | 70 +++++++++ internal/db/connection.go | 156 +++++++++++++++++++++ internal/db/guild.go | 54 +++++++ internal/db/migrations.go | 117 ++++++++++++++++ internal/db/migrations/001_init_schema.sql | 31 ++++ internal/db/moderation.go | 74 ++++++++++ internal/db/types.go | 35 +++++ internal/db/user.go | 64 +++++++++ 13 files changed, 700 insertions(+), 2 deletions(-) create mode 100644 internal/db/cache.go create mode 100644 internal/db/connection.go create mode 100644 internal/db/guild.go create mode 100644 internal/db/migrations.go create mode 100644 internal/db/migrations/001_init_schema.sql create mode 100644 internal/db/moderation.go create mode 100644 internal/db/types.go create mode 100644 internal/db/user.go diff --git a/.env.example b/.env.example index ab6f364..b64d3e6 100644 --- a/.env.example +++ b/.env.example @@ -1 +1,14 @@ +ENV="staging" BOT_TOKEN="your-bot-token" +DATABASE_URL= #"postgres://pen_bot:pen_bot_password@postgres:5432/pen_bot?sslmode=disable" +DB_NAME="pen_bot" +DB_HOST="localhost" +DB_PORT="5432" +DB_USER="pen_bot" +DB_PASSWORD="pen_bot_password" +DB_BOT_INSTANCE_ID="pen-bot-1" +DB_MAX_OPEN_CONNS="16" +DB_MAX_IDLE_CONNS="4" +DB_CONN_MAX_LIFETIME="30m" +DB_CACHE_CLEANUP_INTERVAL="15m" +POSTGRES_PASSWORD="pen_bot_password" diff --git a/docker-compose.yml b/docker-compose.yml index c218bfb..c270361 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,7 +4,27 @@ services: pull_policy: build env_file: - .env + depends_on: + postgres: + condition: service_healthy develop: watch: - action: rebuild path: . + + postgres: + image: postgres:16-alpine + environment: + POSTGRES_DB: pen_bot + POSTGRES_USER: pen_bot + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-pen_bot_password} + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER}"] + interval: 10s + timeout: 5s + retries: 5 + +volumes: + postgres_data: diff --git a/go.mod b/go.mod index 7df91a2..d845980 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.26.2 require github.com/disgoorg/disgo v0.19.3 +require github.com/lib/pq v1.12.3 + require ( github.com/disgoorg/godave v0.1.0 // indirect github.com/disgoorg/json/v2 v2.0.0 // indirect diff --git a/go.sum b/go.sum index f2e3793..7de5208 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ= +github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/sasha-s/go-csync v0.0.0-20240107134140-fcbab37b09ad h1:qIQkSlF5vAUHxEmTbaqt1hkJ/t6skqEGYiMag343ucI= diff --git a/internal/core/start.go b/internal/core/start.go index e620a3f..bec54b0 100644 --- a/internal/core/start.go +++ b/internal/core/start.go @@ -8,7 +8,9 @@ import ( "strings" "sync" "syscall" + "time" + "github.com/Neon-Genesis-Linux/pen-bot/internal/db" "github.com/disgoorg/disgo" "github.com/disgoorg/disgo/bot" "github.com/disgoorg/disgo/events" @@ -78,10 +80,68 @@ func Start(ctx context.Context, token string, listener func(*events.MessageCreat return err } + cleanupCtx, cleanupCancel := context.WithCancel(context.Background()) + defer cleanupCancel() + cleanupInterval := db.ParseCleanupIntervalEnv("DB_CACHE_CLEANUP_INTERVAL", 15) + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("cache cleanup panicked", "recover", r) + } + }() + db.StartCleanup(cleanupCtx, cleanupInterval) + }() + + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("db connect panicked", "recover", r) + } + }() + connectDBWithRetry(ctx) + }() + + defer db.CloseDB() + slog.Info("pen bot is now running. Press CTRL-C to exit.") s := make(chan os.Signal, 1) signal.Notify(s, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) - <-s - return nil + select { + case <-ctx.Done(): + return ctx.Err() + case <-s: + return nil + } +} + +func connectDBWithRetry(ctx context.Context) { + var dbClient *db.DB + var err error + + for attempt := range 10 { + dbClient, err = db.NewFromEnv() + if err == nil { + break + } + slog.Warn("db connection attempt failed", "attempt", attempt+1, "error", err) + select { + case <-ctx.Done(): + return + case <-time.After(time.Duration(attempt+1) * 2 * time.Second): + } + } + if err != nil { + slog.Error("db unavailable after retries", "error", err) + return + } + + if err := db.ApplyMigrations(ctx, dbClient); err != nil { + slog.Error("db migration failed", "error", err) + _ = dbClient.Close() + return + } + + db.SetGlobalDB(dbClient) + slog.Info("db connected and ready") } diff --git a/internal/db/cache.go b/internal/db/cache.go new file mode 100644 index 0000000..6f562bb --- /dev/null +++ b/internal/db/cache.go @@ -0,0 +1,70 @@ +package db + +import ( + "context" + "sync" + "time" +) + +type memCacheEntry struct { + data []byte + expiresAt time.Time +} + +var ( + mc = make(map[string]memCacheEntry) + mcMu sync.RWMutex +) + +func SetCacheEntry(key string, valueJSON []byte, ttl time.Duration) { + mcMu.Lock() + defer mcMu.Unlock() + var expiresAt time.Time + if ttl > 0 { + expiresAt = time.Now().Add(ttl) + } + mc[key] = memCacheEntry{data: valueJSON, expiresAt: expiresAt} +} + +func GetCacheEntry(key string) ([]byte, bool) { + mcMu.RLock() + e, ok := mc[key] + mcMu.RUnlock() + if !ok { + return nil, false + } + if !e.expiresAt.IsZero() && time.Now().After(e.expiresAt) { + mcMu.Lock() + delete(mc, key) + mcMu.Unlock() + return nil, false + } + return e.data, true +} + +func DeleteCacheEntry(key string) { + mcMu.Lock() + defer mcMu.Unlock() + delete(mc, key) +} + +func StartCleanup(ctx context.Context, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + mcMu.Lock() + now := time.Now() + for k, e := range mc { + if !e.expiresAt.IsZero() && now.After(e.expiresAt) { + delete(mc, k) + } + } + mcMu.Unlock() + } + } +} diff --git a/internal/db/connection.go b/internal/db/connection.go new file mode 100644 index 0000000..588260f --- /dev/null +++ b/internal/db/connection.go @@ -0,0 +1,156 @@ +package db + +import ( + "context" + "database/sql" + "fmt" + "net" + "net/url" + "os" + "strconv" + "sync" + "time" + + _ "github.com/lib/pq" +) + +type DB struct { + *sql.DB +} + +type Config struct { + DSN string + MaxOpenConns int + MaxIdleConns int + ConnMaxLifetime time.Duration +} + +func New(cfg Config) (*DB, error) { + if cfg.DSN == "" { + return nil, fmt.Errorf("database DSN is required") + } + + db, err := sql.Open("postgres", cfg.DSN) + if err != nil { + return nil, err + } + + if cfg.MaxOpenConns > 0 { + db.SetMaxOpenConns(cfg.MaxOpenConns) + } + if cfg.MaxIdleConns > 0 { + db.SetMaxIdleConns(cfg.MaxIdleConns) + } + if cfg.ConnMaxLifetime > 0 { + db.SetConnMaxLifetime(cfg.ConnMaxLifetime) + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := db.PingContext(ctx); err != nil { + _ = db.Close() + return nil, err + } + + return &DB{db}, nil +} + +func NewFromEnv() (*DB, error) { + dsn := os.Getenv("DATABASE_URL") + if dsn == "" { + host := getEnv("DB_HOST", "localhost") + port := getEnv("DB_PORT", "5432") + user := getEnv("DB_USER", "pen_bot") + password := os.Getenv("DB_PASSWORD") + dbName := os.Getenv("DB_NAME") + if dbName == "" { + botInstance := os.Getenv("DB_BOT_INSTANCE_ID") + if botInstance == "" { + dbName = "pen_bot" + } else { + dbName = fmt.Sprintf("pen_bot_%s", botInstance) + } + } + + u := &url.URL{ + Scheme: "postgres", + User: url.UserPassword(user, password), + Host: net.JoinHostPort(host, port), + Path: dbName, + } + u.RawQuery = "sslmode=disable" + dsn = u.String() + } + + cfg := Config{ + DSN: dsn, + MaxOpenConns: parseIntEnv("DB_MAX_OPEN_CONNS", 16), + MaxIdleConns: parseIntEnv("DB_MAX_IDLE_CONNS", 4), + ConnMaxLifetime: parseDurationEnv("DB_CONN_MAX_LIFETIME", 30*time.Minute), + } + + return New(cfg) +} + +func getEnv(key, fallback string) string { + value := os.Getenv(key) + if value == "" { + return fallback + } + return value +} + +func parseIntEnv(key string, fallback int) int { + value := os.Getenv(key) + if value == "" { + return fallback + } + parsed, err := strconv.Atoi(value) + if err != nil { + return fallback + } + return parsed +} + +func parseDurationEnv(key string, fallback time.Duration) time.Duration { + value := os.Getenv(key) + if value == "" { + return fallback + } + parsed, err := time.ParseDuration(value) + if err != nil { + return fallback + } + return parsed +} + +var ( + globalDB *DB + globalDBMu sync.RWMutex +) + +func SetGlobalDB(d *DB) { + globalDBMu.Lock() + defer globalDBMu.Unlock() + globalDB = d +} + +func GlobalDB() *DB { + globalDBMu.RLock() + defer globalDBMu.RUnlock() + return globalDB +} + +func CloseDB() { + globalDBMu.Lock() + defer globalDBMu.Unlock() + if globalDB != nil { + _ = globalDB.Close() + globalDB = nil + } +} + +func ParseCleanupIntervalEnv(key string, defaultMinutes int) time.Duration { + return parseDurationEnv(key, time.Duration(defaultMinutes)*time.Minute) +} diff --git a/internal/db/guild.go b/internal/db/guild.go new file mode 100644 index 0000000..4c79929 --- /dev/null +++ b/internal/db/guild.go @@ -0,0 +1,54 @@ +package db + +import ( + "context" + "database/sql" + "fmt" +) + +func GetGuild(ctx context.Context, db *DB, guildID string) (*Guild, error) { + if guildID == "" { + return nil, fmt.Errorf("guildID required") + } + var guild Guild + err := db.QueryRowContext(ctx, ` + SELECT id, guild_id, settings_json, created_at, updated_at + FROM guilds + WHERE guild_id = $1 + `, guildID).Scan( + &guild.ID, + &guild.GuildID, + &guild.SettingsJSON, + &guild.CreatedAt, + &guild.UpdatedAt, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &guild, nil +} + +func UpsertGuild(ctx context.Context, db *DB, guildID string, settingsJSON string) error { + if guildID == "" { + return fmt.Errorf("guildID required") + } + _, err := db.ExecContext(ctx, ` + INSERT INTO guilds (guild_id, settings_json, created_at, updated_at) + VALUES ($1, $2::jsonb, now(), now()) + ON CONFLICT (guild_id) DO UPDATE SET + settings_json = EXCLUDED.settings_json, + updated_at = now() + `, guildID, settingsJSON) + return err +} + +func DeleteGuild(ctx context.Context, db *DB, guildID string) error { + if guildID == "" { + return fmt.Errorf("guildID required") + } + _, err := db.ExecContext(ctx, `DELETE FROM guilds WHERE guild_id = $1`, guildID) + return err +} diff --git a/internal/db/migrations.go b/internal/db/migrations.go new file mode 100644 index 0000000..67e244f --- /dev/null +++ b/internal/db/migrations.go @@ -0,0 +1,117 @@ +package db + +import ( + "context" + "database/sql" + "embed" + "fmt" + "log/slog" + "sort" + "strings" + "time" +) + +//go:embed migrations/*.sql +var migrationsFS embed.FS + +func ApplyMigrations(ctx context.Context, db *DB) error { + if err := ensureSchemaMigrations(ctx, db); err != nil { + return err + } + + entries, err := migrationsFS.ReadDir("migrations") + if err != nil { + return err + } + + var names []string + for _, entry := range entries { + if entry.IsDir() { + continue + } + names = append(names, entry.Name()) + } + sort.Strings(names) + + for _, name := range names { + if strings.HasPrefix(name, ".") { + continue + } + applied, err := migrationApplied(ctx, db, name) + if err != nil { + return err + } + if applied { + continue + } + + content, err := migrationsFS.ReadFile("migrations/" + name) + if err != nil { + return err + } + sqlText := strings.TrimSpace(string(content)) + if sqlText == "" { + continue + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + for i, stmt := range strings.Split(sqlText, ";") { + stmt = strings.TrimSpace(stmt) + if stmt == "" { + continue + } + + if _, err := tx.ExecContext(ctx, stmt); err != nil { + slog.Warn("migration statement failed", "migration", name, "statement", i, "error", err) + return fmt.Errorf("migration %s statement %d failed: %w", name, i, err) + } + } + + if _, err := tx.ExecContext(ctx, `INSERT INTO schema_migrations(version, applied_at) VALUES($1, $2)`, name, time.Now().UTC()); err != nil { + return fmt.Errorf("failed to record migration %s: %w", name, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("migration %s commit failed: %w", name, err) + } + } + + return nil +} + +func ensureSchemaMigrations(ctx context.Context, db *DB) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + _, err = tx.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS schema_migrations ( + version TEXT PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL + ) + `) + if err != nil { + return err + } + + return tx.Commit() +} + +func migrationApplied(ctx context.Context, db *DB, version string) (bool, error) { + var existing string + err := db.QueryRowContext(ctx, `SELECT version FROM schema_migrations WHERE version = $1`, version).Scan(&existing) + if err == sql.ErrNoRows { + return false, nil + } + if err != nil { + return false, err + } + return true, nil +} diff --git a/internal/db/migrations/001_init_schema.sql b/internal/db/migrations/001_init_schema.sql new file mode 100644 index 0000000..12bf54b --- /dev/null +++ b/internal/db/migrations/001_init_schema.sql @@ -0,0 +1,31 @@ +CREATE TABLE IF NOT EXISTS guilds ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + guild_id TEXT NOT NULL UNIQUE, + settings_json JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE IF NOT EXISTS users ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + guild_id TEXT NOT NULL, + user_id TEXT NOT NULL, + profile_json JSONB NOT NULL DEFAULT '{}'::jsonb, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + UNIQUE (guild_id, user_id) +); + +CREATE TABLE IF NOT EXISTS moderation_logs ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + guild_id TEXT NOT NULL, + moderator_id TEXT NOT NULL, + target_id TEXT NOT NULL, + action TEXT NOT NULL, + reason TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_moderation_logs_guild_id ON moderation_logs (guild_id); +CREATE INDEX IF NOT EXISTS idx_moderation_logs_guild_created ON moderation_logs (guild_id, created_at DESC); +CREATE INDEX IF NOT EXISTS idx_users_guild_id ON users (guild_id); \ No newline at end of file diff --git a/internal/db/moderation.go b/internal/db/moderation.go new file mode 100644 index 0000000..48f1d9a --- /dev/null +++ b/internal/db/moderation.go @@ -0,0 +1,74 @@ +package db + +import ( + "context" + "fmt" + "strings" +) + +func LogModeration(ctx context.Context, db *DB, guildID, moderatorID, targetID, action, reason string) error { + if guildID == "" { + return fmt.Errorf("guildID required") + } + _, err := db.ExecContext(ctx, ` + INSERT INTO moderation_logs (guild_id, moderator_id, target_id, action, reason, created_at) + VALUES ($1, $2, $3, $4, $5, now()) + `, guildID, moderatorID, targetID, action, reason) + return err +} + +func GetModerationLogs(ctx context.Context, db *DB, guildID string, limit int, cursor *ModerationCursor) ([]*ModerationLog, error) { + if guildID == "" { + return nil, fmt.Errorf("guildID required") + } + if limit <= 0 { + limit = 20 + } + + var args []interface{} + var clauses []string + + args = append(args, guildID) + clauses = append(clauses, fmt.Sprintf("guild_id = $%d", len(args))) + + if cursor != nil { + args = append(args, cursor.CreatedAt, cursor.ID) + clauses = append(clauses, fmt.Sprintf("(created_at, id) < ($%d, $%d)", len(args)-1, len(args))) + } + + args = append(args, limit) + query := fmt.Sprintf(` + SELECT id, guild_id, moderator_id, target_id, action, reason, created_at + FROM moderation_logs + WHERE %s + ORDER BY created_at DESC, id DESC + LIMIT $%d`, strings.Join(clauses, " AND "), len(args)) + + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + logs := make([]*ModerationLog, 0) + for rows.Next() { + var logEntry ModerationLog + if err := rows.Scan( + &logEntry.ID, + &logEntry.GuildID, + &logEntry.ModeratorID, + &logEntry.TargetID, + &logEntry.Action, + &logEntry.Reason, + &logEntry.CreatedAt, + ); err != nil { + return nil, err + } + logs = append(logs, &logEntry) + } + + if err := rows.Err(); err != nil { + return nil, err + } + return logs, nil +} diff --git a/internal/db/types.go b/internal/db/types.go new file mode 100644 index 0000000..2d845d0 --- /dev/null +++ b/internal/db/types.go @@ -0,0 +1,35 @@ +package db + +import "time" + +type Guild struct { + ID int64 + GuildID string + SettingsJSON []byte + CreatedAt time.Time + UpdatedAt time.Time +} + +type User struct { + ID int64 + GuildID string + UserID string + ProfileJSON []byte + CreatedAt time.Time + UpdatedAt time.Time +} + +type ModerationLog struct { + ID int64 + GuildID string + ModeratorID string + TargetID string + Action string + Reason string + CreatedAt time.Time +} + +type ModerationCursor struct { + CreatedAt time.Time + ID int64 +} diff --git a/internal/db/user.go b/internal/db/user.go new file mode 100644 index 0000000..5f7f75a --- /dev/null +++ b/internal/db/user.go @@ -0,0 +1,64 @@ +package db + +import ( + "context" + "database/sql" + "fmt" +) + +func GetUser(ctx context.Context, db *DB, guildID, userID string) (*User, error) { + if guildID == "" { + return nil, fmt.Errorf("guildID required") + } + if userID == "" { + return nil, fmt.Errorf("userID required") + } + var user User + err := db.QueryRowContext(ctx, ` + SELECT id, guild_id, user_id, profile_json, created_at, updated_at + FROM users + WHERE guild_id = $1 AND user_id = $2 + `, guildID, userID).Scan( + &user.ID, + &user.GuildID, + &user.UserID, + &user.ProfileJSON, + &user.CreatedAt, + &user.UpdatedAt, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &user, nil +} + +func UpsertUser(ctx context.Context, db *DB, guildID, userID, profileJSON string) error { + if guildID == "" { + return fmt.Errorf("guildID required") + } + if userID == "" { + return fmt.Errorf("userID required") + } + _, err := db.ExecContext(ctx, ` + INSERT INTO users (guild_id, user_id, profile_json, created_at, updated_at) + VALUES ($1, $2, $3::jsonb, now(), now()) + ON CONFLICT (guild_id, user_id) DO UPDATE SET + profile_json = EXCLUDED.profile_json, + updated_at = now() + `, guildID, userID, profileJSON) + return err +} + +func DeleteUser(ctx context.Context, db *DB, guildID, userID string) error { + if guildID == "" { + return fmt.Errorf("guildID required") + } + if userID == "" { + return fmt.Errorf("userID required") + } + _, err := db.ExecContext(ctx, `DELETE FROM users WHERE guild_id = $1 AND user_id = $2`, guildID, userID) + return err +} From c4adcbab7bf3d8260120a9bd9edb4c1507d4b217 Mon Sep 17 00:00:00 2001 From: Vuks69 <51289041+Vuks69@users.noreply.github.com> Date: Thu, 21 May 2026 23:48:31 +0200 Subject: [PATCH 2/2] test: add unit tests for command dispatching and environment variable parsing --- internal/core/start_test.go | 99 +++++++++++++++++++++++++++ internal/db/cache_test.go | 118 +++++++++++++++++++++++++++++++++ internal/db/connection_test.go | 52 +++++++++++++++ 3 files changed, 269 insertions(+) create mode 100644 internal/core/start_test.go create mode 100644 internal/db/cache_test.go create mode 100644 internal/db/connection_test.go diff --git a/internal/core/start_test.go b/internal/core/start_test.go new file mode 100644 index 0000000..a952852 --- /dev/null +++ b/internal/core/start_test.go @@ -0,0 +1,99 @@ +package core + +import ( + "testing" + + "github.com/disgoorg/disgo/discord" + "github.com/disgoorg/disgo/events" +) + +func TestDispatchExactMatch(t *testing.T) { + called := false + RegisterCommand("test-exact", func(e *events.MessageCreate) { called = true }) + defer delete(commandRegistry, "test-exact") + + DispatchCommand(&events.MessageCreate{ + GenericMessage: &events.GenericMessage{ + Message: discord.Message{ + Content: "!test-exact", + Author: discord.User{Bot: false}, + }, + }, + }) + + if !called { + t.Fatal("expected handler called") + } +} + +func TestDispatchIgnoresBot(t *testing.T) { + called := false + RegisterCommand("test-bot", func(e *events.MessageCreate) { called = true }) + defer delete(commandRegistry, "test-bot") + + DispatchCommand(&events.MessageCreate{ + GenericMessage: &events.GenericMessage{ + Message: discord.Message{ + Content: "!test-bot", + Author: discord.User{Bot: true}, + }, + }, + }) + + if called { + t.Fatal("expected bot message ignored") + } +} + +func TestDispatchNoPrefix(t *testing.T) { + called := false + RegisterCommand("test-noprefix", func(e *events.MessageCreate) { called = true }) + defer delete(commandRegistry, "test-noprefix") + + DispatchCommand(&events.MessageCreate{ + GenericMessage: &events.GenericMessage{ + Message: discord.Message{ + Content: "test-noprefix", + Author: discord.User{Bot: false}, + }, + }, + }) + + if called { + t.Fatal("expected no dispatch without prefix") + } +} + +func TestDispatchUnknownCommand(t *testing.T) { + DispatchCommand(&events.MessageCreate{ + GenericMessage: &events.GenericMessage{ + Message: discord.Message{ + Content: "!unknown-cmd", + Author: discord.User{Bot: false}, + }, + }, + }) +} + +func TestCustomPrefix(t *testing.T) { + oldPrefix := botPrefix + botPrefix = "?" + t.Cleanup(func() { botPrefix = oldPrefix }) + + called := false + RegisterCommand("test-custom", func(e *events.MessageCreate) { called = true }) + defer delete(commandRegistry, "test-custom") + + DispatchCommand(&events.MessageCreate{ + GenericMessage: &events.GenericMessage{ + Message: discord.Message{ + Content: "?test-custom", + Author: discord.User{Bot: false}, + }, + }, + }) + + if !called { + t.Fatal("expected handler called with custom prefix") + } +} diff --git a/internal/db/cache_test.go b/internal/db/cache_test.go new file mode 100644 index 0000000..e43811b --- /dev/null +++ b/internal/db/cache_test.go @@ -0,0 +1,118 @@ +package db_test + +import ( + "context" + "testing" + "time" + + "github.com/Neon-Genesis-Linux/pen-bot/internal/db" +) + +func TestSetGet(t *testing.T) { + const key = "test-set-get" + t.Cleanup(func() { db.DeleteCacheEntry(key) }) + + db.SetCacheEntry(key, []byte("v"), 0) + got, ok := db.GetCacheEntry(key) + if !ok { + t.Fatal("expected ok") + } + if string(got) != "v" { + t.Fatalf("got %q, want %q", string(got), "v") + } + + db.DeleteCacheEntry(key) + _, ok = db.GetCacheEntry(key) + if ok { + t.Fatal("expected false after delete") + } +} + +func TestExpiry(t *testing.T) { + const key = "test-expiry" + t.Cleanup(func() { db.DeleteCacheEntry(key) }) + + db.SetCacheEntry(key, []byte("v"), time.Millisecond) + time.Sleep(2 * time.Millisecond) + _, ok := db.GetCacheEntry(key) + if ok { + t.Fatal("expected false after expiry") + } +} + +func TestNoExpiryWhenTTLZero(t *testing.T) { + const key = "test-no-expiry" + t.Cleanup(func() { db.DeleteCacheEntry(key) }) + + db.SetCacheEntry(key, []byte("v"), 0) + got, ok := db.GetCacheEntry(key) + if !ok { + t.Fatal("expected ok for TTL=0") + } + if string(got) != "v" { + t.Fatalf("got %q, want %q", string(got), "v") + } +} + +func TestOverwrite(t *testing.T) { + const key = "test-overwrite" + t.Cleanup(func() { db.DeleteCacheEntry(key) }) + + db.SetCacheEntry(key, []byte("v1"), 0) + db.SetCacheEntry(key, []byte("v2"), 0) + got, ok := db.GetCacheEntry(key) + if !ok { + t.Fatal("expected ok") + } + if string(got) != "v2" { + t.Fatalf("got %q, want %q", string(got), "v2") + } +} + +func TestCleanup(t *testing.T) { + expireKey := "test-cleanup-expire" + keepKey := "test-cleanup-keep" + t.Cleanup(func() { + db.DeleteCacheEntry(expireKey) + db.DeleteCacheEntry(keepKey) + }) + + db.SetCacheEntry(expireKey, []byte("x"), time.Millisecond) + db.SetCacheEntry(keepKey, []byte("y"), 0) + + time.Sleep(2 * time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go db.StartCleanup(ctx, time.Microsecond) + time.Sleep(10 * time.Millisecond) + + _, ok := db.GetCacheEntry(expireKey) + if ok { + t.Fatal("expected expired entry removed by cleanup") + } + _, ok = db.GetCacheEntry(keepKey) + if !ok { + t.Fatal("expected non-expired entry kept") + } +} + +func TestConcurrentAccess(t *testing.T) { + t.Parallel() + + done := make(chan struct{}, 2) + go func() { + for i := 0; i < 100; i++ { + db.SetCacheEntry("race-key", []byte("v"), time.Minute) + } + done <- struct{}{} + }() + go func() { + for i := 0; i < 100; i++ { + db.GetCacheEntry("race-key") + } + done <- struct{}{} + }() + <-done + <-done +} diff --git a/internal/db/connection_test.go b/internal/db/connection_test.go new file mode 100644 index 0000000..73a0872 --- /dev/null +++ b/internal/db/connection_test.go @@ -0,0 +1,52 @@ +package db + +import ( + "testing" + "time" +) + +func TestGetEnv(t *testing.T) { + t.Setenv("TEST_GETENV_KEY", "val") + if got := getEnv("TEST_GETENV_KEY", "fall"); got != "val" { + t.Fatalf("got %q, want %q", got, "val") + } + if got := getEnv("TEST_GETENV_MISSING", "fall"); got != "fall" { + t.Fatalf("got %q, want %q", got, "fall") + } +} + +func TestParseIntEnv(t *testing.T) { + t.Setenv("TEST_INT_VALID", "42") + if got := parseIntEnv("TEST_INT_VALID", 1); got != 42 { + t.Fatalf("got %d, want %d", got, 42) + } + if got := parseIntEnv("TEST_INT_MISSING", 1); got != 1 { + t.Fatalf("got %d, want %d", got, 1) + } + t.Setenv("TEST_INT_INVALID", "notanumber") + if got := parseIntEnv("TEST_INT_INVALID", 5); got != 5 { + t.Fatalf("got %d, want %d", got, 5) + } +} + +func TestParseDurationEnv(t *testing.T) { + want := 5 * time.Minute + t.Setenv("TEST_DUR_VALID", "5m") + if got := parseDurationEnv("TEST_DUR_VALID", time.Hour); got != want { + t.Fatalf("got %v, want %v", got, want) + } + if got := parseDurationEnv("TEST_DUR_MISSING", time.Hour); got != time.Hour { + t.Fatalf("got %v, want %v", got, time.Hour) + } +} + +func TestParseCleanupIntervalEnv(t *testing.T) { + want := 10 * time.Minute + t.Setenv("TEST_CLEANUP", "10m") + if got := ParseCleanupIntervalEnv("TEST_CLEANUP", 15); got != want { + t.Fatalf("got %v, want %v", got, want) + } + if got := ParseCleanupIntervalEnv("TEST_CLEANUP_MISSING", 15); got != 15*time.Minute { + t.Fatalf("got %v, want %v", got, 15*time.Minute) + } +}