@@ -76,17 +76,6 @@ func (i *importer) usesType(typ string) bool {
7676 return false
7777}
7878
79- func (i * importer ) usesArrays () bool {
80- for _ , strct := range i .Structs {
81- for _ , f := range strct .Fields {
82- if strings .HasPrefix (f .Type , "[]" ) {
83- return true
84- }
85- }
86- }
87- return false
88- }
89-
9079func (i * importer ) Imports (filename string ) [][]ImportSpec {
9180 dbFileName := "db.go"
9281 if i .Settings .Go .OutputDBFileName != "" {
@@ -143,34 +132,16 @@ var stdlibTypes = map[string]string{
143132 "net.HardwareAddr" : "net" ,
144133}
145134
146- func (i * importer ) interfaceImports () fileImports {
147- uses := func (name string ) bool {
148- for _ , q := range i .Queries {
149- if q .hasRetType () {
150- if strings .HasPrefix (q .Ret .Type (), name ) {
151- return true
152- }
153- }
154- if ! q .Arg .isEmpty () {
155- if strings .HasPrefix (q .Arg .Type (), name ) {
156- return true
157- }
158- }
159- }
160- return false
161- }
135+ func buildImports (settings config.CombinedSettings , queries []Query , uses func (string ) bool ) (map [string ]struct {}, map [ImportSpec ]struct {}) {
136+ pkg := make (map [ImportSpec ]struct {})
137+ std := make (map [string ]struct {})
162138
163- std := map [string ]struct {}{
164- "context" : {},
165- }
166139 if uses ("sql.Null" ) {
167140 std ["database/sql" ] = struct {}{}
168141 }
169142
170- pkg := make (map [ImportSpec ]struct {})
171-
172- sqlpkg := SQLPackageFromString (i .Settings .Go .SQLPackage )
173- for _ , q := range i .Queries {
143+ sqlpkg := SQLPackageFromString (settings .Go .SQLPackage )
144+ for _ , q := range queries {
174145 if q .Cmd == metadata .CmdExecResult {
175146 switch sqlpkg {
176147 case SQLPackagePGX :
@@ -180,14 +151,15 @@ func (i *importer) interfaceImports() fileImports {
180151 }
181152 }
182153 }
154+
183155 for typeName , pkg := range stdlibTypes {
184156 if uses (typeName ) {
185157 std [pkg ] = struct {}{}
186158 }
187159 }
188160
189161 overrideTypes := map [string ]string {}
190- for _ , o := range i . Settings .Overrides {
162+ for _ , o := range settings .Overrides {
191163 if o .GoBasicType || o .GoTypeName == "" {
192164 continue
193165 }
@@ -208,7 +180,7 @@ func (i *importer) interfaceImports() fileImports {
208180 }
209181
210182 // Custom imports
211- for _ , o := range i . Settings .Overrides {
183+ for _ , o := range settings .Overrides {
212184 if o .GoBasicType || o .GoTypeName == "" {
213185 continue
214186 }
@@ -219,80 +191,52 @@ func (i *importer) interfaceImports() fileImports {
219191 }
220192 }
221193
222- pkgs := make ([]ImportSpec , 0 , len (pkg ))
223- for spec := range pkg {
224- pkgs = append (pkgs , spec )
225- }
194+ return std , pkg
195+ }
226196
227- stds := make ([]ImportSpec , 0 , len (std ))
228- for path := range std {
229- stds = append (stds , ImportSpec {Path : path })
230- }
197+ func (i * importer ) interfaceImports () fileImports {
198+ std , pkg := buildImports (i .Settings , i .Queries , func (name string ) bool {
199+ for _ , q := range i .Queries {
200+ if q .hasRetType () {
201+ if strings .HasPrefix (q .Ret .Type (), name ) {
202+ return true
203+ }
204+ }
205+ if ! q .Arg .isEmpty () {
206+ if strings .HasPrefix (q .Arg .Type (), name ) {
207+ return true
208+ }
209+ }
210+ }
211+ return false
212+ })
231213
232- sort . Slice ( stds , func ( i , j int ) bool { return stds [ i ]. Path < stds [ j ]. Path })
233- sort . Slice ( pkgs , func ( i , j int ) bool { return pkgs [ i ]. Path < pkgs [ j ]. Path })
234- return fileImports { stds , pkgs }
214+ std [ "context" ] = struct {}{}
215+
216+ return sortedImports ( std , pkg )
235217}
236218
237219func (i * importer ) modelImports () fileImports {
238- std := make (map [string ]struct {})
239- if i .usesType ("sql.Null" ) {
240- std ["database/sql" ] = struct {}{}
241- }
242- for typeName , pkg := range stdlibTypes {
243- if i .usesType (typeName ) {
244- std [pkg ] = struct {}{}
245- }
246- }
220+ std , pkg := buildImports (i .Settings , nil , func (prefix string ) bool {
221+ return i .usesType (prefix )
222+ })
223+
247224 if len (i .Enums ) > 0 {
248225 std ["fmt" ] = struct {}{}
249226 }
250227
251- // Custom imports
252- pkg := make (map [ImportSpec ]struct {})
253- overrideTypes := map [string ]string {}
254- for _ , o := range i .Settings .Overrides {
255- if o .GoBasicType || o .GoTypeName == "" {
256- continue
257- }
258- overrideTypes [o .GoTypeName ] = o .GoImportPath
259- }
260-
261- _ , overrideNullTime := overrideTypes ["pq.NullTime" ]
262- if i .usesType ("pq.NullTime" ) && ! overrideNullTime {
263- pkg [ImportSpec {Path : "github.com/lib/pq" }] = struct {}{}
264- }
265-
266- _ , overrideUUID := overrideTypes ["uuid.UUID" ]
267- if i .usesType ("uuid.UUID" ) && ! overrideUUID {
268- pkg [ImportSpec {Path : "github.com/google/uuid" }] = struct {}{}
269- }
270- _ , overrideNullUUID := overrideTypes ["uuid.NullUUID" ]
271- if i .usesType ("uuid.NullUUID" ) && ! overrideNullUUID {
272- pkg [ImportSpec {Path : "github.com/google/uuid" }] = struct {}{}
273- }
274-
275- for _ , o := range i .Settings .Overrides {
276- if o .GoBasicType || o .GoTypeName == "" {
277- continue
278- }
279- _ , alreadyImported := std [o .GoImportPath ]
280- hasPackageAlias := o .GoPackage != ""
281- if (! alreadyImported || hasPackageAlias ) && i .usesType (o .GoTypeName ) {
282- pkg [ImportSpec {Path : o .GoImportPath , ID : o .GoPackage }] = struct {}{}
283- }
284- }
228+ return sortedImports (std , pkg )
229+ }
285230
231+ func sortedImports (std map [string ]struct {}, pkg map [ImportSpec ]struct {}) fileImports {
286232 pkgs := make ([]ImportSpec , 0 , len (pkg ))
287233 for spec := range pkg {
288234 pkgs = append (pkgs , spec )
289235 }
290-
291236 stds := make ([]ImportSpec , 0 , len (std ))
292237 for path := range std {
293238 stds = append (stds , ImportSpec {Path : path })
294239 }
295-
296240 sort .Slice (stds , func (i , j int ) bool { return stds [i ].Path < stds [j ].Path })
297241 sort .Slice (pkgs , func (i , j int ) bool { return pkgs [i ].Path < pkgs [j ].Path })
298242 return fileImports {stds , pkgs }
@@ -306,7 +250,7 @@ func (i *importer) queryImports(filename string) fileImports {
306250 }
307251 }
308252
309- uses := func (name string ) bool {
253+ std , pkg := buildImports ( i . Settings , gq , func (name string ) bool {
310254 for _ , q := range gq {
311255 if q .hasRetType () {
312256 if q .Ret .EmitStruct () {
@@ -336,7 +280,7 @@ func (i *importer) queryImports(filename string) fileImports {
336280 }
337281 }
338282 return false
339- }
283+ })
340284
341285 sliceScan := func () bool {
342286 for _ , q := range gq {
@@ -370,80 +314,12 @@ func (i *importer) queryImports(filename string) fileImports {
370314 return false
371315 }
372316
373- pkg := make (map [ImportSpec ]struct {})
374- std := map [string ]struct {}{
375- "context" : {},
376- }
377- if uses ("sql.Null" ) {
378- std ["database/sql" ] = struct {}{}
379- }
317+ std ["context" ] = struct {}{}
380318
381319 sqlpkg := SQLPackageFromString (i .Settings .Go .SQLPackage )
382-
383- for _ , q := range gq {
384- if q .Cmd == metadata .CmdExecResult {
385- switch sqlpkg {
386- case SQLPackagePGX :
387- pkg [ImportSpec {Path : "github.com/jackc/pgconn" }] = struct {}{}
388- default :
389- std ["database/sql" ] = struct {}{}
390- }
391- }
392- }
393- for typeName , pkg := range stdlibTypes {
394- if uses (typeName ) {
395- std [pkg ] = struct {}{}
396- }
397- }
398-
399- overrideTypes := map [string ]string {}
400- for _ , o := range i .Settings .Overrides {
401- if o .GoBasicType || o .GoTypeName == "" {
402- continue
403- }
404- overrideTypes [o .GoTypeName ] = o .GoImportPath
405- }
406-
407320 if sliceScan () && sqlpkg != SQLPackagePGX {
408321 pkg [ImportSpec {Path : "github.com/lib/pq" }] = struct {}{}
409322 }
410323
411- _ , overrideNullTime := overrideTypes ["pq.NullTime" ]
412- if uses ("pq.NullTime" ) && ! overrideNullTime {
413- pkg [ImportSpec {Path : "github.com/lib/pq" }] = struct {}{}
414- }
415- _ , overrideUUID := overrideTypes ["uuid.UUID" ]
416- if uses ("uuid.UUID" ) && ! overrideUUID {
417- pkg [ImportSpec {Path : "github.com/google/uuid" }] = struct {}{}
418- }
419- _ , overrideNullUUID := overrideTypes ["uuid.NullUUID" ]
420- if uses ("uuid.NullUUID" ) && ! overrideNullUUID {
421- pkg [ImportSpec {Path : "github.com/google/uuid" }] = struct {}{}
422- }
423-
424- // Custom imports
425- for _ , o := range i .Settings .Overrides {
426- if o .GoBasicType || o .GoTypeName == "" {
427- continue
428- }
429- _ , alreadyImported := std [o .GoImportPath ]
430- hasPackageAlias := o .GoPackage != ""
431- if (! alreadyImported || hasPackageAlias ) && uses (o .GoTypeName ) {
432- pkg [ImportSpec {Path : o .GoImportPath , ID : o .GoPackage }] = struct {}{}
433- }
434- }
435-
436- pkgs := make ([]ImportSpec , 0 , len (pkg ))
437- for spec := range pkg {
438- pkgs = append (pkgs , spec )
439- }
440-
441- stds := make ([]ImportSpec , 0 , len (std ))
442- for path := range std {
443- stds = append (stds , ImportSpec {Path : path })
444- }
445-
446- sort .Slice (stds , func (i , j int ) bool { return stds [i ].Path < stds [j ].Path })
447- sort .Slice (pkgs , func (i , j int ) bool { return pkgs [i ].Path < pkgs [j ].Path })
448- return fileImports {stds , pkgs }
324+ return sortedImports (std , pkg )
449325}
0 commit comments