Skip to content

Commit 28a6c81

Browse files
author
Meru Patel
committed
test: add unit tests for generator, safety packages
- Add 13 tests for SQLGenerator (CREATE/DROP/ALTER TABLE, COLUMN, INDEX, etc) - Add 4 tests for safety/validator - Add 4 tests for safety/backup - Coverage improved: 24% -> 35%
1 parent 92d4512 commit 28a6c81

3 files changed

Lines changed: 316 additions & 0 deletions

File tree

internal/diff/generator_test.go

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
package diff
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/meru143/dbdiff/pkg/types"
8+
)
9+
10+
func TestSQLGenerator_CreateTable(t *testing.T) {
11+
diffs := types.DiffList{
12+
{Type: types.DiffAdd, Object: types.ObjectTable, Name: "users"},
13+
}
14+
gen := NewSQLGenerator(diffs, "public")
15+
gen.SetTransaction(false)
16+
sql := gen.Generate()
17+
18+
if !strings.Contains(sql, "CREATE TABLE") {
19+
t.Errorf("Expected CREATE TABLE, got: %s", sql)
20+
}
21+
}
22+
23+
func TestSQLGenerator_DropTable(t *testing.T) {
24+
diffs := types.DiffList{
25+
{Type: types.DiffDrop, Object: types.ObjectTable, Name: "users"},
26+
}
27+
gen := NewSQLGenerator(diffs, "public")
28+
gen.SetTransaction(false)
29+
sql := gen.Generate()
30+
31+
if !strings.Contains(sql, "DROP TABLE") {
32+
t.Errorf("Expected DROP TABLE, got: %s", sql)
33+
}
34+
}
35+
36+
func TestSQLGenerator_AddColumn(t *testing.T) {
37+
diffs := types.DiffList{
38+
{Type: types.DiffAdd, Object: types.ObjectColumn, Name: "email", TableName: "users", NewValue: "varchar"},
39+
}
40+
gen := NewSQLGenerator(diffs, "public")
41+
gen.SetTransaction(false)
42+
sql := gen.Generate()
43+
44+
if !strings.Contains(sql, "ADD COLUMN") {
45+
t.Errorf("Expected ADD COLUMN, got: %s", sql)
46+
}
47+
}
48+
49+
func TestSQLGenerator_DropColumn(t *testing.T) {
50+
diffs := types.DiffList{
51+
{Type: types.DiffDrop, Object: types.ObjectColumn, Name: "email", TableName: "users"},
52+
}
53+
gen := NewSQLGenerator(diffs, "public")
54+
gen.SetTransaction(false)
55+
sql := gen.Generate()
56+
57+
if !strings.Contains(sql, "DROP COLUMN") {
58+
t.Errorf("Expected DROP COLUMN, got: %s", sql)
59+
}
60+
}
61+
62+
func TestSQLGenerator_AlterColumn(t *testing.T) {
63+
diffs := types.DiffList{
64+
{Type: types.DiffAlter, Object: types.ObjectColumn, Name: "email", TableName: "users", OldValue: "varchar", NewValue: "text"},
65+
}
66+
gen := NewSQLGenerator(diffs, "public")
67+
gen.SetTransaction(false)
68+
sql := gen.Generate()
69+
70+
if !strings.Contains(sql, "ALTER COLUMN") {
71+
t.Errorf("Expected ALTER COLUMN, got: %s", sql)
72+
}
73+
}
74+
75+
func TestSQLGenerator_CreateIndex(t *testing.T) {
76+
diffs := types.DiffList{
77+
{Type: types.DiffAdd, Object: types.ObjectIndex, Name: "users_email_idx", TableName: "users"},
78+
}
79+
gen := NewSQLGenerator(diffs, "public")
80+
gen.SetTransaction(false)
81+
sql := gen.Generate()
82+
83+
if !strings.Contains(sql, "CREATE INDEX") {
84+
t.Errorf("Expected CREATE INDEX, got: %s", sql)
85+
}
86+
}
87+
88+
func TestSQLGenerator_DropIndex(t *testing.T) {
89+
diffs := types.DiffList{
90+
{Type: types.DiffDrop, Object: types.ObjectIndex, Name: "users_email_idx"},
91+
}
92+
gen := NewSQLGenerator(diffs, "public")
93+
gen.SetTransaction(false)
94+
sql := gen.Generate()
95+
96+
if !strings.Contains(sql, "DROP INDEX") {
97+
t.Errorf("Expected DROP INDEX, got: %s", sql)
98+
}
99+
}
100+
101+
func TestSQLGenerator_AddConstraint(t *testing.T) {
102+
diffs := types.DiffList{
103+
{Type: types.DiffAdd, Object: types.ObjectConstraint, Name: "users_pkey", TableName: "users"},
104+
}
105+
gen := NewSQLGenerator(diffs, "public")
106+
gen.SetTransaction(false)
107+
sql := gen.Generate()
108+
109+
if !strings.Contains(sql, "ADD CONSTRAINT") {
110+
t.Errorf("Expected ADD CONSTRAINT, got: %s", sql)
111+
}
112+
}
113+
114+
func TestSQLGenerator_DropConstraint(t *testing.T) {
115+
diffs := types.DiffList{
116+
{Type: types.DiffDrop, Object: types.ObjectConstraint, Name: "users_pkey", TableName: "users"},
117+
}
118+
gen := NewSQLGenerator(diffs, "public")
119+
gen.SetTransaction(false)
120+
sql := gen.Generate()
121+
122+
if !strings.Contains(sql, "DROP CONSTRAINT") {
123+
t.Errorf("Expected DROP CONSTRAINT, got: %s", sql)
124+
}
125+
}
126+
127+
func TestSQLGenerator_CreateSequence(t *testing.T) {
128+
diffs := types.DiffList{
129+
{Type: types.DiffAdd, Object: types.ObjectSequence, Name: "users_id_seq"},
130+
}
131+
gen := NewSQLGenerator(diffs, "public")
132+
gen.SetTransaction(false)
133+
sql := gen.Generate()
134+
135+
if !strings.Contains(sql, "CREATE SEQUENCE") {
136+
t.Errorf("Expected CREATE SEQUENCE, got: %s", sql)
137+
}
138+
}
139+
140+
func TestSQLGenerator_TransactionWrapper(t *testing.T) {
141+
diffs := types.DiffList{
142+
{Type: types.DiffAdd, Object: types.ObjectTable, Name: "users"},
143+
}
144+
gen := NewSQLGenerator(diffs, "public")
145+
gen.SetTransaction(true)
146+
sql := gen.Generate()
147+
148+
if !strings.Contains(sql, "DO $$") {
149+
t.Errorf("Expected transaction wrapper, got: %s", sql)
150+
}
151+
}
152+
153+
func TestSQLGenerator_EmptyDiffs(t *testing.T) {
154+
diffs := types.DiffList{}
155+
gen := NewSQLGenerator(diffs, "public")
156+
gen.SetTransaction(false)
157+
sql := gen.Generate()
158+
159+
if strings.Contains(sql, "CREATE") || strings.Contains(sql, "DROP") {
160+
t.Errorf("Expected empty output for empty diffs, got: %s", sql)
161+
}
162+
}
163+
164+
func TestSQLGenerator_MultipleDiffs(t *testing.T) {
165+
diffs := types.DiffList{
166+
{Type: types.DiffAdd, Object: types.ObjectTable, Name: "users"},
167+
{Type: types.DiffAdd, Object: types.ObjectTable, Name: "posts"},
168+
{Type: types.DiffDrop, Object: types.ObjectTable, Name: "old_table"},
169+
}
170+
gen := NewSQLGenerator(diffs, "public")
171+
gen.SetTransaction(false)
172+
sql := gen.Generate()
173+
174+
if !strings.Contains(sql, "users") || !strings.Contains(sql, "posts") || !strings.Contains(sql, "old_table") {
175+
t.Errorf("Expected all table names in output, got: %s", sql)
176+
}
177+
}

internal/safety/backup_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package safety
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"strings"
7+
"testing"
8+
)
9+
10+
func TestBackupManager_NewBackupManager(t *testing.T) {
11+
tmpDir := t.TempDir()
12+
bm := NewBackupManager(tmpDir, 3)
13+
14+
if bm == nil {
15+
t.Fatal("Expected backup manager, got nil")
16+
}
17+
}
18+
19+
func TestBackupManager_Backup(t *testing.T) {
20+
tmpDir := t.TempDir()
21+
bm := NewBackupManager(tmpDir, 3)
22+
23+
content := []byte("test migration content")
24+
backupPath, err := bm.Backup("test", content)
25+
if err != nil {
26+
t.Fatalf("Backup failed: %v", err)
27+
}
28+
29+
// Check file exists
30+
if _, err := os.Stat(backupPath); os.IsNotExist(err) {
31+
t.Error("Backup file not created")
32+
}
33+
34+
// Check content
35+
readContent, err := os.ReadFile(backupPath)
36+
if err != nil {
37+
t.Fatalf("Failed to read backup: %v", err)
38+
}
39+
if string(readContent) != string(content) {
40+
t.Error("Backup content mismatch")
41+
}
42+
}
43+
44+
func TestBackupManager_RestorePath(t *testing.T) {
45+
tmpDir := t.TempDir()
46+
bm := NewBackupManager(tmpDir, 3)
47+
48+
backupPath, _ := bm.Backup("test", []byte("original content"))
49+
restored, err := bm.RestorePath(backupPath)
50+
if err != nil {
51+
t.Fatalf("RestorePath failed: %v", err)
52+
}
53+
if string(restored) != "original content" {
54+
t.Error("Restored content mismatch")
55+
}
56+
}
57+
58+
func TestBackupManager_CustomPrefix(t *testing.T) {
59+
tmpDir := t.TempDir()
60+
bm := NewBackupManager(tmpDir, 3)
61+
62+
backupPath, _ := bm.Backup("custom", []byte("data"))
63+
if !strings.HasPrefix(filepath.Base(backupPath), "custom") {
64+
t.Error("Backup path should have custom prefix")
65+
}
66+
}

internal/safety/validator_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package safety
2+
3+
import (
4+
"testing"
5+
6+
"github.com/meru143/dbdiff/pkg/types"
7+
)
8+
9+
func TestValidator_NewValidator(t *testing.T) {
10+
protected := []string{"users", "accounts"}
11+
v := NewValidator(protected, false)
12+
13+
if v == nil {
14+
t.Fatal("Expected validator, got nil")
15+
}
16+
}
17+
18+
func TestValidator_IsProtected(t *testing.T) {
19+
protected := []string{"users", "accounts"}
20+
v := NewValidator(protected, false)
21+
22+
tests := []struct {
23+
name string
24+
objName string
25+
expected bool
26+
}{
27+
{"users table", "users", true},
28+
{"accounts table", "accounts", true},
29+
{"other table", "posts", false},
30+
{"empty", "", false},
31+
}
32+
33+
for _, tt := range tests {
34+
t.Run(tt.name, func(t *testing.T) {
35+
result := v.IsProtected(tt.objName)
36+
if result != tt.expected {
37+
t.Errorf("IsProtected(%s) = %v, want %v", tt.objName, result, tt.expected)
38+
}
39+
})
40+
}
41+
}
42+
43+
func TestValidator_ValidateDiff(t *testing.T) {
44+
protected := []string{"users"}
45+
v := NewValidator(protected, false)
46+
47+
diffs := types.DiffList{
48+
{Type: types.DiffDrop, Object: types.ObjectTable, Name: "users"},
49+
{Type: types.DiffDrop, Object: types.ObjectTable, Name: "posts"},
50+
}
51+
52+
result := v.ValidateDiff(&diffs[0])
53+
if len(result.Errors) == 0 {
54+
t.Error("Expected error for protected table 'users'")
55+
}
56+
57+
result = v.ValidateDiff(&diffs[1])
58+
if len(result.Errors) > 0 {
59+
t.Error("Expected no error for non-protected table 'posts'")
60+
}
61+
}
62+
63+
func TestValidator_SetProtectedObjects(t *testing.T) {
64+
v := NewValidator([]string{}, false)
65+
v.SetProtectedObjects([]string{"admin", "config"})
66+
67+
if !v.IsProtected("admin") {
68+
t.Error("Expected 'admin' to be protected")
69+
}
70+
if v.IsProtected("users") {
71+
t.Error("Expected 'users' not to be protected")
72+
}
73+
}

0 commit comments

Comments
 (0)