|
1 | 1 | package testhelpers |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "fmt" |
| 5 | + "log" |
| 6 | + "os" |
| 7 | + "path/filepath" |
| 8 | + "regexp" |
| 9 | + "strings" |
| 10 | + "sync" |
4 | 11 | "testing" |
5 | 12 |
|
6 | | - "gorm.io/driver/sqlite" |
7 | | - "gorm.io/gorm" |
8 | | - |
9 | 13 | "github.com/gov-dx-sandbox/exchange/policy-decision-point/v1/models" |
| 14 | + "github.com/joho/godotenv" |
| 15 | + "gorm.io/driver/postgres" |
| 16 | + "gorm.io/gorm" |
10 | 17 | ) |
11 | 18 |
|
12 | 19 | // StringPtr returns a pointer to the given string value. |
13 | | -// This is a convenience helper for test code that needs string pointers. |
14 | 20 | func StringPtr(s string) *string { |
15 | 21 | return &s |
16 | 22 | } |
17 | 23 |
|
18 | 24 | // OwnerPtr returns a pointer to the given Owner value. |
19 | | -// This is a convenience helper for test code that needs Owner pointers. |
20 | 25 | func OwnerPtr(o models.Owner) *models.Owner { |
21 | 26 | return &o |
22 | 27 | } |
23 | 28 |
|
24 | | -// SetupTestDB creates an in-memory SQLite database for testing. |
25 | | -// It creates the policy_metadata table with SQLite-compatible schema. |
26 | | -// SQLite doesn't support PostgreSQL-specific features like gen_random_uuid(), enums, jsonb. |
27 | | -func SetupTestDB(t *testing.T) *gorm.DB { |
28 | | - db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) |
| 29 | +var envLoadOnce sync.Once |
| 30 | + |
| 31 | +// loadEnvOnce loads environment variables from .env.local file (once) |
| 32 | +func loadEnvOnce() { |
| 33 | + envLoadOnce.Do(func() { |
| 34 | + // Try to load .env.local file from current directory and parent directories |
| 35 | + envFiles := []string{ |
| 36 | + ".env.local", |
| 37 | + "../.env.local", |
| 38 | + "../../.env.local", |
| 39 | + "../../../.env.local", |
| 40 | + } |
| 41 | + |
| 42 | + for _, envFile := range envFiles { |
| 43 | + if absPath, err := filepath.Abs(envFile); err == nil { |
| 44 | + if _, err := os.Stat(absPath); err == nil { |
| 45 | + if err := godotenv.Load(absPath); err == nil { |
| 46 | + log.Printf("Loaded test environment from: %s", absPath) |
| 47 | + return |
| 48 | + } |
| 49 | + } |
| 50 | + } |
| 51 | + } |
| 52 | + }) |
| 53 | +} |
| 54 | + |
| 55 | +// getEnvOrDefault returns the environment variable value or a default |
| 56 | +func getEnvOrDefault(key, defaultValue string) string { |
| 57 | + loadEnvOnce() // Ensure .env.local is loaded |
| 58 | + if value := os.Getenv(key); value != "" { |
| 59 | + return value |
| 60 | + } |
| 61 | + return defaultValue |
| 62 | +} |
| 63 | + |
| 64 | +// isValidDBName checks if the database name is safe to use in SQL |
| 65 | +func isValidDBName(name string) bool { |
| 66 | + match, _ := regexp.MatchString("^[a-zA-Z0-9_]+$", name) |
| 67 | + return match |
| 68 | +} |
| 69 | + |
| 70 | +// SetupPostgresTestDB creates a PostgreSQL test database connection |
| 71 | +func SetupPostgresTestDB(t *testing.T) *gorm.DB { |
| 72 | + host := getEnvOrDefault("TEST_DB_HOST", "localhost") |
| 73 | + port := getEnvOrDefault("TEST_DB_PORT", "5432") |
| 74 | + testDB := getEnvOrDefault("TEST_DB_DATABASE", "pdp_service_test") |
| 75 | + sslmode := getEnvOrDefault("TEST_DB_SSLMODE", "disable") |
| 76 | + |
| 77 | + // Try to get credentials from environment first |
| 78 | + loadEnvOnce() |
| 79 | + |
| 80 | + // Try credential combinations |
| 81 | + credentials := []struct { |
| 82 | + user string |
| 83 | + pass string |
| 84 | + }{ |
| 85 | + {getEnvOrDefault("TEST_DB_USERNAME", "postgres"), getEnvOrDefault("TEST_DB_PASSWORD", "password")}, |
| 86 | + {"postgres", "password"}, |
| 87 | + {"postgres", ""}, |
| 88 | + {os.Getenv("USER"), ""}, |
| 89 | + } |
| 90 | + |
| 91 | + var db *gorm.DB |
| 92 | + var err error |
| 93 | + |
| 94 | + for _, cred := range credentials { |
| 95 | + if cred.user == "" { |
| 96 | + continue |
| 97 | + } |
| 98 | + |
| 99 | + // 1. Try connecting to the test database directly |
| 100 | + db, err = tryConnection(host, port, cred.user, cred.pass, testDB, sslmode) |
| 101 | + if err == nil { |
| 102 | + t.Logf("Connected to PostgreSQL with user=%s", cred.user) |
| 103 | + return setupDatabase(t, db) |
| 104 | + } |
| 105 | + |
| 106 | + // 2. If test database doesn't exist, try to connect to default database and create it |
| 107 | + if isDBNotExistError(err) { |
| 108 | + defaultDB := "postgres" |
| 109 | + if adminDB, adminErr := tryConnection(host, port, cred.user, cred.pass, defaultDB, sslmode); adminErr == nil { |
| 110 | + t.Logf("Connected to admin database, creating test database") |
| 111 | + |
| 112 | + // Validate DB name to prevent SQL injection |
| 113 | + if !isValidDBName(testDB) { |
| 114 | + t.Fatalf("Invalid database name: %s", testDB) |
| 115 | + } |
| 116 | + if !isValidDBName(cred.user) { |
| 117 | + t.Fatalf("Invalid database owner: %s", cred.user) |
| 118 | + } |
| 119 | + |
| 120 | + // Create test database |
| 121 | + // Note: CREATE DATABASE cannot be parameterized in postgres, so we use string formatting |
| 122 | + // We rely on isValidDBName for safety, and also double-quote the identifiers |
| 123 | + createSQL := fmt.Sprintf("CREATE DATABASE \"%s\" WITH OWNER = \"%s\"", testDB, cred.user) |
| 124 | + if createErr := adminDB.Exec(createSQL).Error; createErr != nil { |
| 125 | + // Database might already exist (race condition), ignore error |
| 126 | + t.Logf("Note: Could not create test database: %v", createErr) |
| 127 | + } |
| 128 | + |
| 129 | + // Close admin connection properly |
| 130 | + if sqlDB, err := adminDB.DB(); err == nil { |
| 131 | + sqlDB.Close() |
| 132 | + } |
| 133 | + |
| 134 | + // Now try connecting to the test database again |
| 135 | + db, err = tryConnection(host, port, cred.user, cred.pass, testDB, sslmode) |
| 136 | + if err == nil { |
| 137 | + t.Logf("Successfully created and connected to test database with user=%s", cred.user) |
| 138 | + return setupDatabase(t, db) |
| 139 | + } |
| 140 | + } |
| 141 | + } |
| 142 | + } |
| 143 | + |
| 144 | + if err != nil { |
| 145 | + t.Skipf("Skipping test: could not connect to test database with any credentials: %v", err) |
| 146 | + return nil |
| 147 | + } |
| 148 | + |
| 149 | + return setupDatabase(t, db) |
| 150 | +} |
| 151 | + |
| 152 | +// tryConnection attempts to connect to PostgreSQL with given credentials |
| 153 | +func tryConnection(host, port, user, password, database, sslmode string) (*gorm.DB, error) { |
| 154 | + dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=%s", |
| 155 | + host, port, user, password, database, sslmode) |
| 156 | + return gorm.Open(postgres.Open(dsn), &gorm.Config{ |
| 157 | + DisableForeignKeyConstraintWhenMigrating: true, |
| 158 | + }) |
| 159 | +} |
| 160 | + |
| 161 | +// isDBNotExistError checks if the error is due to database not existing |
| 162 | +func isDBNotExistError(err error) bool { |
| 163 | + if err == nil { |
| 164 | + return false |
| 165 | + } |
| 166 | + return strings.Contains(err.Error(), "does not exist") || strings.Contains(err.Error(), "3D000") |
| 167 | +} |
| 168 | + |
| 169 | +// setupDatabase performs migration and cleanup for the test database |
| 170 | +func setupDatabase(t *testing.T, db *gorm.DB) *gorm.DB { |
| 171 | + // Auto-migrate all models |
| 172 | + err := db.AutoMigrate( |
| 173 | + &models.PolicyMetadata{}, |
| 174 | + ) |
29 | 175 | if err != nil { |
30 | | - t.Fatalf("Failed to connect to test database: %v", err) |
31 | | - } |
32 | | - |
33 | | - // Create table manually for SQLite compatibility |
34 | | - createTableSQL := ` |
35 | | - CREATE TABLE IF NOT EXISTS policy_metadata ( |
36 | | - id TEXT PRIMARY KEY, |
37 | | - schema_id TEXT NOT NULL, |
38 | | - field_name TEXT NOT NULL, |
39 | | - display_name TEXT, |
40 | | - description TEXT, |
41 | | - source TEXT NOT NULL DEFAULT 'fallback', |
42 | | - is_owner INTEGER NOT NULL DEFAULT 0, |
43 | | - access_control_type TEXT NOT NULL DEFAULT 'restricted', |
44 | | - allow_list TEXT NOT NULL DEFAULT '{}', |
45 | | - owner TEXT, |
46 | | - created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, |
47 | | - updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, |
48 | | - UNIQUE(schema_id, field_name) |
49 | | - ) |
50 | | - ` |
51 | | - if err := db.Exec(createTableSQL).Error; err != nil { |
52 | | - t.Fatalf("Failed to create table: %v", err) |
| 176 | + t.Skipf("Skipping test: could not migrate test database: %v", err) |
| 177 | + return nil |
53 | 178 | } |
54 | 179 |
|
| 180 | + // Clean up test data before each test |
| 181 | + CleanupTestData(t, db) |
| 182 | + |
55 | 183 | return db |
56 | 184 | } |
| 185 | + |
| 186 | +// CleanupTestData removes all test data from the database |
| 187 | +func CleanupTestData(t *testing.T, db *gorm.DB) { |
| 188 | + if db == nil { |
| 189 | + return |
| 190 | + } |
| 191 | + |
| 192 | + // Delete all policy metadata |
| 193 | + if err := db.Exec("DELETE FROM policy_metadata").Error; err != nil { |
| 194 | + t.Logf("Warning: could not cleanup policy_metadata: %v", err) |
| 195 | + } |
| 196 | +} |
0 commit comments