Skip to content

Commit d1be4e4

Browse files
committed
fix: update CI Go version to match go.mod, add DSN parser package
1 parent c95ff69 commit d1be4e4

File tree

3 files changed

+500
-1
lines changed

3 files changed

+500
-1
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- name: Set up Go
1616
uses: actions/setup-go@v5
1717
with:
18-
go-version: '1.21'
18+
go-version-file: 'go.mod'
1919

2020
- name: Download dependencies
2121
run: go mod download

internal/db/dsn/parser.go

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
// Package dsn provides DSN parsing utilities for multiple database backends.
2+
// It converts various URL formats into the native DSN strings expected by
3+
// each database driver.
4+
package dsn
5+
6+
import (
7+
"fmt"
8+
"net/url"
9+
"strings"
10+
)
11+
12+
// Config holds parsed connection parameters
13+
type Config struct {
14+
Driver string // "postgres", "mysql", "sqlserver"
15+
Host string
16+
Port string
17+
User string
18+
Password string
19+
Database string
20+
Params map[string]string
21+
}
22+
23+
// Parse converts a connection string into a Config struct.
24+
// Supported formats:
25+
// - postgres://user:pass@host:port/dbname?sslmode=disable
26+
// - mysql://user:pass@tcp(host:port)/dbname?parseTime=true
27+
// - sqlserver://user:pass@host:port?database=dbname
28+
func Parse(connStr string) (*Config, error) {
29+
lower := strings.ToLower(connStr)
30+
31+
switch {
32+
case strings.HasPrefix(lower, "postgres://"), strings.HasPrefix(lower, "postgresql://"):
33+
return parsePostgres(connStr)
34+
case strings.HasPrefix(lower, "mysql://"), strings.HasPrefix(lower, "mysql+tcp://"):
35+
return parseMySQL(connStr)
36+
case strings.HasPrefix(lower, "sqlserver://"):
37+
return parseSQLServer(connStr)
38+
default:
39+
return nil, fmt.Errorf("unsupported connection string format: %s", connStr)
40+
}
41+
}
42+
43+
// ToNativeDSN converts a Config back to the native DSN expected by the driver.
44+
func (c *Config) ToNativeDSN() string {
45+
switch c.Driver {
46+
case "postgres":
47+
return c.toPostgresDSN()
48+
case "mysql":
49+
return c.toMySQLDSN()
50+
case "sqlserver":
51+
return c.toSQLServerDSN()
52+
default:
53+
return ""
54+
}
55+
}
56+
57+
func parsePostgres(connStr string) (*Config, error) {
58+
u, err := url.Parse(connStr)
59+
if err != nil {
60+
return nil, fmt.Errorf("invalid postgres URL: %w", err)
61+
}
62+
63+
cfg := &Config{
64+
Driver: "postgres",
65+
Host: u.Hostname(),
66+
Port: u.Port(),
67+
Params: make(map[string]string),
68+
}
69+
70+
if cfg.Port == "" {
71+
cfg.Port = "5432"
72+
}
73+
74+
if u.User != nil {
75+
cfg.User = u.User.Username()
76+
cfg.Password, _ = u.User.Password()
77+
}
78+
79+
cfg.Database = strings.TrimPrefix(u.Path, "/")
80+
81+
for k, v := range u.Query() {
82+
cfg.Params[k] = v[0]
83+
}
84+
85+
return cfg, nil
86+
}
87+
88+
func parseMySQL(connStr string) (*Config, error) {
89+
// MySQL URLs: mysql://user:pass@tcp(host:port)/dbname?params
90+
// or mysql://user:pass@host:port/dbname?params
91+
92+
// Strip the mysql:// prefix
93+
raw := connStr
94+
if strings.HasPrefix(strings.ToLower(raw), "mysql+tcp://") {
95+
raw = "mysql://" + raw[12:]
96+
}
97+
98+
// Check for tcp() format: mysql://user:pass@tcp(host:port)/dbname
99+
if idx := strings.Index(raw, "tcp("); idx != -1 {
100+
return parseMySQLTCP(raw)
101+
}
102+
103+
// Standard URL format: mysql://user:pass@host:port/dbname
104+
u, err := url.Parse(raw)
105+
if err != nil {
106+
return nil, fmt.Errorf("invalid mysql URL: %w", err)
107+
}
108+
109+
cfg := &Config{
110+
Driver: "mysql",
111+
Host: u.Hostname(),
112+
Port: u.Port(),
113+
Params: make(map[string]string),
114+
}
115+
116+
if cfg.Port == "" {
117+
cfg.Port = "3306"
118+
}
119+
120+
if u.User != nil {
121+
cfg.User = u.User.Username()
122+
cfg.Password, _ = u.User.Password()
123+
}
124+
125+
cfg.Database = strings.TrimPrefix(u.Path, "/")
126+
127+
for k, v := range u.Query() {
128+
cfg.Params[k] = v[0]
129+
}
130+
131+
return cfg, nil
132+
}
133+
134+
func parseMySQLTCP(raw string) (*Config, error) {
135+
// Format: mysql://user:pass@tcp(host:port)/dbname?params
136+
cfg := &Config{
137+
Driver: "mysql",
138+
Params: make(map[string]string),
139+
}
140+
141+
// Remove mysql://
142+
raw = raw[8:]
143+
144+
// Split on @tcp(
145+
atIdx := strings.Index(raw, "@tcp(")
146+
if atIdx == -1 {
147+
return nil, fmt.Errorf("invalid mysql tcp format")
148+
}
149+
150+
userPart := raw[:atIdx]
151+
rest := raw[atIdx+5:] // after "@tcp("
152+
153+
// Parse user:pass
154+
if colonIdx := strings.Index(userPart, ":"); colonIdx != -1 {
155+
cfg.User = userPart[:colonIdx]
156+
cfg.Password = userPart[colonIdx+1:]
157+
} else {
158+
cfg.User = userPart
159+
}
160+
161+
// Parse host:port)/dbname?params
162+
closeIdx := strings.Index(rest, ")")
163+
if closeIdx == -1 {
164+
return nil, fmt.Errorf("invalid mysql tcp format: missing closing paren")
165+
}
166+
167+
hostPort := rest[:closeIdx]
168+
afterParen := rest[closeIdx+1:]
169+
170+
if colonIdx := strings.LastIndex(hostPort, ":"); colonIdx != -1 {
171+
cfg.Host = hostPort[:colonIdx]
172+
cfg.Port = hostPort[colonIdx+1:]
173+
} else {
174+
cfg.Host = hostPort
175+
cfg.Port = "3306"
176+
}
177+
178+
// Parse /dbname?params
179+
if strings.HasPrefix(afterParen, "/") {
180+
afterParen = afterParen[1:]
181+
}
182+
183+
if qIdx := strings.Index(afterParen, "?"); qIdx != -1 {
184+
cfg.Database = afterParen[:qIdx]
185+
queryStr := afterParen[qIdx+1:]
186+
187+
for _, pair := range strings.Split(queryStr, "&") {
188+
if eqIdx := strings.Index(pair, "="); eqIdx != -1 {
189+
cfg.Params[pair[:eqIdx]] = pair[eqIdx+1:]
190+
}
191+
}
192+
} else {
193+
cfg.Database = afterParen
194+
}
195+
196+
return cfg, nil
197+
}
198+
199+
func parseSQLServer(connStr string) (*Config, error) {
200+
u, err := url.Parse(connStr)
201+
if err != nil {
202+
return nil, fmt.Errorf("invalid sqlserver URL: %w", err)
203+
}
204+
205+
cfg := &Config{
206+
Driver: "sqlserver",
207+
Host: u.Hostname(),
208+
Port: u.Port(),
209+
Params: make(map[string]string),
210+
}
211+
212+
if cfg.Port == "" {
213+
cfg.Port = "1433"
214+
}
215+
216+
if u.User != nil {
217+
cfg.User = u.User.Username()
218+
cfg.Password, _ = u.User.Password()
219+
}
220+
221+
// SQL Server uses ?database=dbname instead of /dbname
222+
for k, v := range u.Query() {
223+
if strings.ToLower(k) == "database" {
224+
cfg.Database = v[0]
225+
} else {
226+
cfg.Params[k] = v[0]
227+
}
228+
}
229+
230+
// Also check the path for /dbname if no database param
231+
if cfg.Database == "" {
232+
cfg.Database = strings.TrimPrefix(u.Path, "/")
233+
}
234+
235+
return cfg, nil
236+
}
237+
238+
func (c *Config) toPostgresDSN() string {
239+
u := &url.URL{
240+
Scheme: "postgres",
241+
Host: fmt.Sprintf("%s:%s", c.Host, c.Port),
242+
Path: c.Database,
243+
}
244+
if c.User != "" {
245+
if c.Password != "" {
246+
u.User = url.UserPassword(c.User, c.Password)
247+
} else {
248+
u.User = url.User(c.User)
249+
}
250+
}
251+
q := url.Values{}
252+
for k, v := range c.Params {
253+
q.Set(k, v)
254+
}
255+
u.RawQuery = q.Encode()
256+
return u.String()
257+
}
258+
259+
func (c *Config) toMySQLDSN() string {
260+
// go-sql-driver format: user:pass@tcp(host:port)/dbname?params
261+
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s",
262+
c.User, c.Password, c.Host, c.Port, c.Database)
263+
264+
if len(c.Params) > 0 {
265+
params := make([]string, 0, len(c.Params))
266+
for k, v := range c.Params {
267+
params = append(params, fmt.Sprintf("%s=%s", k, v))
268+
}
269+
dsn += "?" + strings.Join(params, "&")
270+
}
271+
return dsn
272+
}
273+
274+
func (c *Config) toSQLServerDSN() string {
275+
u := &url.URL{
276+
Scheme: "sqlserver",
277+
Host: fmt.Sprintf("%s:%s", c.Host, c.Port),
278+
}
279+
if c.User != "" {
280+
if c.Password != "" {
281+
u.User = url.UserPassword(c.User, c.Password)
282+
} else {
283+
u.User = url.User(c.User)
284+
}
285+
}
286+
q := url.Values{}
287+
if c.Database != "" {
288+
q.Set("database", c.Database)
289+
}
290+
for k, v := range c.Params {
291+
q.Set(k, v)
292+
}
293+
u.RawQuery = q.Encode()
294+
return u.String()
295+
}

0 commit comments

Comments
 (0)