@@ -188,29 +188,74 @@ func UsesArrays(r Generateable, settings config.CombinedSettings) bool {
188188 return false
189189}
190190
191+ type fileImports struct {
192+ Std []string
193+ Dep []string
194+ }
195+
196+ func mergeImports (imps ... fileImports ) [][]string {
197+ if len (imps ) == 1 {
198+ return [][]string {imps [0 ].Std , imps [0 ].Dep }
199+ }
200+
201+ var stds , pkgs []string
202+ seenStd := map [string ]struct {}{}
203+ seenPkg := map [string ]struct {}{}
204+ for i := range imps {
205+ for _ , std := range imps [i ].Std {
206+ if _ , ok := seenStd [std ]; ok {
207+ continue
208+ }
209+ stds = append (stds , std )
210+ seenStd [std ] = struct {}{}
211+ }
212+ for _ , pkg := range imps [i ].Dep {
213+ if _ , ok := seenPkg [pkg ]; ok {
214+ continue
215+ }
216+ pkgs = append (pkgs , pkg )
217+ seenPkg [pkg ] = struct {}{}
218+ }
219+ }
220+ return [][]string {stds , pkgs }
221+ }
222+
191223func Imports (r Generateable , settings config.CombinedSettings ) func (string ) [][]string {
192224 return func (filename string ) [][]string {
225+ if filename == "all.go" {
226+ var imps []fileImports
227+ imps = append (imps , dbImports (r , settings ))
228+ imps = append (imps , modelImports (r , settings ))
229+ imps = append (imps , interfaceImports (r , settings ))
230+ imps = append (imps , queryImports (r , settings , filename ))
231+ return mergeImports (imps ... )
232+ }
233+
193234 if filename == "db.go" {
194- imps := []string {"context" , "database/sql" }
195- if settings .Go .EmitPreparedQueries {
196- imps = append (imps , "fmt" )
197- }
198- return [][]string {imps }
235+ return mergeImports (dbImports (r , settings ))
199236 }
200237
201238 if filename == "models.go" {
202- return ModelImports ( r , settings )
239+ return mergeImports ( modelImports ( r , settings ) )
203240 }
204241
205242 if filename == "querier.go" {
206- return InterfaceImports ( r , settings )
243+ return mergeImports ( interfaceImports ( r , settings ) )
207244 }
208245
209- return QueryImports ( r , settings , filename )
246+ return mergeImports ( queryImports ( r , settings , filename ) )
210247 }
211248}
212249
213- func InterfaceImports (r Generateable , settings config.CombinedSettings ) [][]string {
250+ func dbImports (r Generateable , settings config.CombinedSettings ) fileImports {
251+ std := []string {"context" , "database/sql" }
252+ if settings .Go .EmitPreparedQueries {
253+ std = append (std , "fmt" )
254+ }
255+ return fileImports {Std : std }
256+ }
257+
258+ func interfaceImports (r Generateable , settings config.CombinedSettings ) fileImports {
214259 gq := r .GoQueries (settings )
215260 uses := func (name string ) bool {
216261 for _ , q := range gq {
@@ -284,10 +329,10 @@ func InterfaceImports(r Generateable, settings config.CombinedSettings) [][]stri
284329
285330 sort .Strings (stds )
286331 sort .Strings (pkgs )
287- return [][] string {stds , pkgs }
332+ return fileImports {stds , pkgs }
288333}
289334
290- func ModelImports (r Generateable , settings config.CombinedSettings ) [][] string {
335+ func modelImports (r Generateable , settings config.CombinedSettings ) fileImports {
291336 std := make (map [string ]struct {})
292337 if UsesType (r , "sql.Null" , settings ) {
293338 std ["database/sql" ] = struct {}{}
@@ -343,10 +388,10 @@ func ModelImports(r Generateable, settings config.CombinedSettings) [][]string {
343388
344389 sort .Strings (stds )
345390 sort .Strings (pkgs )
346- return [][] string {stds , pkgs }
391+ return fileImports {stds , pkgs }
347392}
348393
349- func QueryImports (r Generateable , settings config.CombinedSettings , filename string ) [][] string {
394+ func queryImports (r Generateable , settings config.CombinedSettings , filename string ) fileImports {
350395 // for _, strct := range r.Structs() {
351396 // for _, f := range strct.Fields {
352397 // if strings.HasPrefix(f.Type, "[]") {
@@ -356,7 +401,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str
356401 // }
357402 var gq []GoQuery
358403 for _ , query := range r .GoQueries (settings ) {
359- if query .SourceName == filename {
404+ if query .SourceName == filename || settings . Go . EmitSingleFile {
360405 gq = append (gq , query )
361406 }
362407 }
@@ -481,7 +526,7 @@ func QueryImports(r Generateable, settings config.CombinedSettings, filename str
481526
482527 sort .Strings (stds )
483528 sort .Strings (pkgs )
484- return [][] string {stds , pkgs }
529+ return fileImports {stds , pkgs }
485530}
486531
487532func enumValueName (value string ) string {
@@ -924,7 +969,8 @@ func (r Result) GoQueries(settings config.CombinedSettings) []GoQuery {
924969 return qs
925970}
926971
927- var dbTmpl = `// Code generated by sqlc. DO NOT EDIT.
972+ var templateSet = `
973+ {{define "dbFile"}}// Code generated by sqlc. DO NOT EDIT.
928974
929975package {{.Package}}
930976
@@ -935,6 +981,10 @@ import (
935981 {{end}}
936982)
937983
984+ {{template "dbCode" . }}
985+ {{end}}
986+
987+ {{define "dbCode"}}
938988type DBTX interface {
939989 ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
940990 PrepareContext(context.Context, string) (*sql.Stmt, error)
@@ -1029,9 +1079,9 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
10291079 {{- end}}
10301080 }
10311081}
1032- `
1082+ {{end}}
10331083
1034- var ifaceTmpl = ` // Code generated by sqlc. DO NOT EDIT.
1084+ {{define "interfaceFile"}} // Code generated by sqlc. DO NOT EDIT.
10351085
10361086package {{.Package}}
10371087
@@ -1042,6 +1092,10 @@ import (
10421092 {{end}}
10431093)
10441094
1095+ {{template "interfaceCode" . }}
1096+ {{end}}
1097+
1098+ {{define "interfaceCode"}}
10451099type Querier interface {
10461100 {{- range .GoQueries}}
10471101 {{- if eq .Cmd ":one"}}
@@ -1060,9 +1114,9 @@ type Querier interface {
10601114}
10611115
10621116var _ Querier = (*Queries)(nil)
1063- `
1117+ {{end}}
10641118
1065- var modelsTmpl = ` // Code generated by sqlc. DO NOT EDIT.
1119+ {{define "modelsFile"}} // Code generated by sqlc. DO NOT EDIT.
10661120
10671121package {{.Package}}
10681122
@@ -1073,6 +1127,10 @@ import (
10731127 {{end}}
10741128)
10751129
1130+ {{template "modelsCode" . }}
1131+ {{end}}
1132+
1133+ {{define "modelsCode"}}
10761134{{range .Enums}}
10771135{{if .Comment}}{{comment .Comment}}{{end}}
10781136type {{.Name}} string
@@ -1099,9 +1157,9 @@ type {{.Name}} struct { {{- range .Fields}}
10991157 {{- end}}
11001158}
11011159{{end}}
1102- `
1160+ {{end}}
11031161
1104- var sqlTmpl = ` // Code generated by sqlc. DO NOT EDIT.
1162+ {{define "queryFile"}} // Code generated by sqlc. DO NOT EDIT.
11051163// source: {{.SourceName}}
11061164
11071165package {{.Package}}
@@ -1113,8 +1171,12 @@ import (
11131171 {{end}}
11141172)
11151173
1174+ {{template "queryCode" . }}
1175+ {{end}}
1176+
1177+ {{define "queryCode"}}
11161178{{range .GoQueries}}
1117- {{if eq .SourceName $ .SourceName}}
1179+ {{if $.OutputQuery .SourceName}}
11181180const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
11191181{{.SQL}}
11201182{{$.Q}}
@@ -1209,6 +1271,27 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, er
12091271{{end}}
12101272{{end}}
12111273{{end}}
1274+ {{end}}
1275+
1276+ {{define "singleFile"}}// Code generated by sqlc. DO NOT EDIT.
1277+
1278+ package {{.Package}}
1279+
1280+ import (
1281+ {{range imports "all.go"}}
1282+ {{range .}}"{{.}}"
1283+ {{end}}
1284+ {{end}}
1285+ )
1286+
1287+ {{template "modelsCode" . }}
1288+
1289+ {{template "queryCode" . }}
1290+
1291+ {{template "dbCode" . }}
1292+
1293+ {{template "interfaceCode" . }}
1294+ {{end}}
12121295`
12131296
12141297type tmplCtx struct {
@@ -1225,6 +1308,11 @@ type tmplCtx struct {
12251308 EmitJSONTags bool
12261309 EmitPreparedQueries bool
12271310 EmitInterface bool
1311+ EmitSingleFile bool
1312+ }
1313+
1314+ func (t * tmplCtx ) OutputQuery (sourceName string ) bool {
1315+ return t .SourceName == sourceName || t .EmitSingleFile
12281316}
12291317
12301318func LowerTitle (s string ) string {
@@ -1244,17 +1332,15 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
12441332 "imports" : Imports (r , settings ),
12451333 }
12461334
1247- dbFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (dbTmpl ))
1248- modelsFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (modelsTmpl ))
1249- sqlFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (sqlTmpl ))
1250- ifaceFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (ifaceTmpl ))
1335+ tmpl := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (templateSet ))
12511336
12521337 golang := settings .Go
12531338 tctx := tmplCtx {
12541339 Settings : settings .Global ,
12551340 EmitInterface : golang .EmitInterface ,
12561341 EmitJSONTags : golang .EmitJSONTags ,
12571342 EmitPreparedQueries : golang .EmitPreparedQueries ,
1343+ EmitSingleFile : golang .EmitSingleFile ,
12581344 Q : "`" ,
12591345 Package : golang .Package ,
12601346 GoQueries : r .GoQueries (settings ),
@@ -1264,11 +1350,11 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
12641350
12651351 output := map [string ]string {}
12661352
1267- execute := func (name string , t * template. Template ) error {
1353+ execute := func (name , templateName string ) error {
12681354 var b bytes.Buffer
12691355 w := bufio .NewWriter (& b )
12701356 tctx .SourceName = name
1271- err := t . Execute (w , tctx )
1357+ err := tmpl . ExecuteTemplate (w , templateName , & tctx )
12721358 w .Flush ()
12731359 if err != nil {
12741360 return err
@@ -1285,14 +1371,22 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
12851371 return nil
12861372 }
12871373
1288- if err := execute ("db.go" , dbFile ); err != nil {
1374+ // Output a single file with all code
1375+ if golang .EmitSingleFile {
1376+ if err := execute ("db.go" , "singleFile" ); err != nil {
1377+ return nil , err
1378+ }
1379+ return output , nil
1380+ }
1381+
1382+ if err := execute ("db.go" , "dbFile" ); err != nil {
12891383 return nil , err
12901384 }
1291- if err := execute ("models.go" , modelsFile ); err != nil {
1385+ if err := execute ("models.go" , " modelsFile" ); err != nil {
12921386 return nil , err
12931387 }
12941388 if golang .EmitInterface {
1295- if err := execute ("querier.go" , ifaceFile ); err != nil {
1389+ if err := execute ("querier.go" , "interfaceFile" ); err != nil {
12961390 return nil , err
12971391 }
12981392 }
@@ -1303,7 +1397,7 @@ func Generate(r Generateable, settings config.CombinedSettings) (map[string]stri
13031397 }
13041398
13051399 for source := range files {
1306- if err := execute (source , sqlFile ); err != nil {
1400+ if err := execute (source , "queryFile" ); err != nil {
13071401 return nil , err
13081402 }
13091403 }
0 commit comments