Skip to content

Commit 048b536

Browse files
committed
support opening a connection with custom flags
This introduces the use of the options pattern with the new `ConnectorWithOpts` function. The opt `WithOpenFlags` allows specifying custom flags to use when opening connections. When left unspecified, we use the same `sqliteh.OpenFlagsDefault` as before. Updates tailscale/corp#36592 Signed-off-by: Percy Wegmann <percy@tailscale.com>
1 parent 617c375 commit 048b536

1 file changed

Lines changed: 47 additions & 10 deletions

File tree

sqlite.go

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,15 @@ type drv struct{}
116116

117117
func (drv) Open(name string) (driver.Conn, error) { panic("deprecated, unused") }
118118
func (drv) OpenConnector(name string) (driver.Connector, error) {
119-
return &connector{name: name}, nil
119+
return &connector{name: name, openFlags: sqliteh.OpenFlagsDefault}, nil
120120
}
121121

122122
// Connector returns a [driver.Connector] for the given connection
123123
// parameters.
124124
//
125125
// The tracer may be nil.
126126
func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer) driver.Connector {
127-
return &connector{
128-
name: sqliteURI,
129-
tracer: tracer,
130-
connInitFunc: connInitFunc,
131-
}
127+
return ConnectorWithOpts(sqliteURI, connInitFunc, WithTracer(tracer))
132128
}
133129

134130
// ConnectorWithLogger returns a [driver.Connector] for the given connection
@@ -137,11 +133,48 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace
137133
//
138134
// The tracer may also be nil.
139135
func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector {
140-
return &connector{
136+
return ConnectorWithOpts(sqliteURI, connInitFunc, WithTracer(tracer), WithConnLogger(makeLogger))
137+
}
138+
139+
// ConnectorWithOpts returns a [driver.Connector] for the given connection parameters, optionally
140+
// configured with one or more [ConnectorOpt]s.
141+
func ConnectorWithOpts(sqliteURI string, connInitFunc ConnInitFunc, opts ...ConnectorOpt) driver.Connector {
142+
p := &connector{
141143
name: sqliteURI,
142-
tracer: tracer,
143-
makeLogger: makeLogger,
144144
connInitFunc: connInitFunc,
145+
openFlags: sqliteh.OpenFlagsDefault, // default flags unless [WithOpenFlags] option is used
146+
}
147+
for _, opt := range opts {
148+
opt(p)
149+
}
150+
return p
151+
}
152+
153+
// ConnectorOpt is an option to [ConnectorWithOpts].
154+
type ConnectorOpt func(p *connector)
155+
156+
// WithTracer returns a [ConnectorOpt] that configures the [driver.Connector]
157+
// to enable tracing on new connections using the given [sqliteh.Tracer].
158+
func WithTracer(tracer sqliteh.Tracer) ConnectorOpt {
159+
return func(p *connector) {
160+
p.tracer = tracer
161+
}
162+
}
163+
164+
// WithConnLogger returns a [ConnectorOpt] that configures the [driver.Connector]
165+
// to use a [ConnLogger] returned by the provided makeLogger when opening new
166+
// connections.
167+
func WithConnLogger(makeLogger func() ConnLogger) ConnectorOpt {
168+
return func(p *connector) {
169+
p.makeLogger = makeLogger
170+
}
171+
}
172+
173+
// WithOpenFlags returns a [ConnectorOpt] that configures the [driver.Connector]
174+
// to use the given [sqliteh.OpenFlags] when opening new connections.
175+
func WithOpenFlags(openFlags sqliteh.OpenFlags) ConnectorOpt {
176+
return func(p *connector) {
177+
p.openFlags = openFlags
145178
}
146179
}
147180

@@ -150,11 +183,15 @@ type connector struct {
150183
tracer sqliteh.Tracer // or nil
151184
makeLogger func() ConnLogger // or nil
152185
connInitFunc ConnInitFunc
186+
openFlags sqliteh.OpenFlags
153187
}
154188

189+
// Driver implements [driver.Connector.Driver].
155190
func (p *connector) Driver() driver.Driver { return drv{} }
191+
192+
// Connect implements [driver.Connector.Connect]
156193
func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
157-
db, err := Open(p.name, sqliteh.OpenFlagsDefault, "")
194+
db, err := Open(p.name, p.openFlags, "")
158195
if err != nil {
159196
if ec, ok := err.(sqliteh.ErrCode); ok {
160197
e := &Error{

0 commit comments

Comments
 (0)