@@ -229,7 +229,8 @@ func ParseQueries(c core.Catalog, pkg PackageSettings) (*Result, error) {
229229 continue
230230 }
231231 for _ , stmt := range tree .Statements {
232- query , err := parseQuery (c , stmt , source )
232+ rewriteParameters := pkg .rewriteParams
233+ query , err := parseQuery (c , stmt , source , rewriteParameters )
233234 if err == errUnsupportedStatementType {
234235 continue
235236 }
@@ -407,7 +408,7 @@ func validateCmd(n nodes.Node, name, cmd string) error {
407408
408409var errUnsupportedStatementType = errors .New ("parseQuery: unsupported statement type" )
409410
410- func parseQuery (c core.Catalog , stmt nodes.Node , source string ) (* Query , error ) {
411+ func parseQuery (c core.Catalog , stmt nodes.Node , source string , rewriteParameters bool ) (* Query , error ) {
411412 if err := validateParamRef (stmt ); err != nil {
412413 return nil , err
413414 }
@@ -443,6 +444,16 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
443444 }
444445 rvs := rangeVars (raw .Stmt )
445446 refs := findParameters (raw .Stmt )
447+ var edits []edit
448+ if rewriteParameters {
449+ edits , err = rewriteNumberedParameters (refs , raw , rawSQL )
450+ if err != nil {
451+ return nil , err
452+ }
453+ } else {
454+ refs = uniqueParamRefs (refs )
455+ sort .Slice (refs , func (i , j int ) bool { return refs [i ].ref .Number < refs [j ].ref .Number })
456+ }
446457 params , err := resolveCatalogRefs (c , rvs , refs )
447458 if err != nil {
448459 return nil , err
@@ -452,7 +463,13 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
452463 if err != nil {
453464 return nil , err
454465 }
455- expanded , err := expand (c , raw , rawSQL )
466+ expandEdits , err := expand (c , raw , rawSQL )
467+ if err != nil {
468+ return nil , err
469+ }
470+ edits = append (edits , expandEdits ... )
471+
472+ expanded , err := editQuery (rawSQL , edits )
456473 if err != nil {
457474 return nil , err
458475 }
@@ -472,6 +489,18 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
472489 }, nil
473490}
474491
492+ func rewriteNumberedParameters (refs []paramRef , raw nodes.RawStmt , sql string ) ([]edit , error ) {
493+ edits := make ([]edit , len (refs ))
494+ for i , ref := range refs {
495+ edits [i ] = edit {
496+ Location : ref .ref .Location - raw .StmtLocation ,
497+ Old : fmt .Sprintf ("$%d" , ref .ref .Number ),
498+ New : "?" ,
499+ }
500+ }
501+ return edits , nil
502+ }
503+
475504func stripComments (sql string ) (string , []string , error ) {
476505 s := bufio .NewScanner (strings .NewReader (sql ))
477506 var lines , comments []string
@@ -494,7 +523,7 @@ type edit struct {
494523 New string
495524}
496525
497- func expand (c core.Catalog , raw nodes.RawStmt , sql string ) (string , error ) {
526+ func expand (c core.Catalog , raw nodes.RawStmt , sql string ) ([] edit , error ) {
498527 list := search (raw , func (node nodes.Node ) bool {
499528 switch node .(type ) {
500529 case nodes.DeleteStmt :
@@ -507,17 +536,17 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
507536 return true
508537 })
509538 if len (list .Items ) == 0 {
510- return sql , nil
539+ return nil , nil
511540 }
512541 var edits []edit
513542 for _ , item := range list .Items {
514543 edit , err := expandStmt (c , raw , item )
515544 if err != nil {
516- return "" , err
545+ return nil , err
517546 }
518547 edits = append (edits , edit ... )
519548 }
520- return editQuery ( sql , edits )
549+ return edits , nil
521550}
522551
523552func expandStmt (c core.Catalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
@@ -958,7 +987,8 @@ type paramRef struct {
958987type paramSearch struct {
959988 parent nodes.Node
960989 rangeVar * nodes.RangeVar
961- refs map [int ]paramRef
990+ refs * []paramRef
991+ seen map [int ]struct {}
962992
963993 // XXX: Gross state hack for limit
964994 limitCount nodes.Node
@@ -1005,7 +1035,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
10051035 continue
10061036 }
10071037 // TODO: Out-of-bounds panic
1008- p .refs [ref .Number ] = paramRef {parent : n .Cols .Items [i ], ref : ref , rv : p .rangeVar }
1038+ * p .refs = append (* p .refs , paramRef {parent : n .Cols .Items [i ], ref : ref , rv : p .rangeVar })
1039+ p .seen [ref .Location ] = struct {}{}
10091040 }
10101041 for _ , vl := range s .ValuesLists {
10111042 for i , v := range vl {
@@ -1014,7 +1045,8 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
10141045 continue
10151046 }
10161047 // TODO: Out-of-bounds panic
1017- p .refs [ref .Number ] = paramRef {parent : n .Cols .Items [i ], ref : ref , rv : p .rangeVar }
1048+ * p .refs = append (* p .refs , paramRef {parent : n .Cols .Items [i ], ref : ref , rv : p .rangeVar })
1049+ p .seen [ref .Location ] = struct {}{}
10181050 }
10191051 }
10201052 }
@@ -1050,7 +1082,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
10501082 parent = limitOffset {}
10511083 }
10521084 }
1053- if _ , found := p .refs [n .Number ]; found {
1085+ if _ , found := p .seen [n .Location ]; found {
10541086 break
10551087 }
10561088
@@ -1072,21 +1104,18 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
10721104 }
10731105
10741106 if set {
1075- p .refs [n .Number ] = paramRef {parent : parent , ref : n , rv : p .rangeVar }
1107+ * p .refs = append (* p .refs , paramRef {parent : parent , ref : n , rv : p .rangeVar })
1108+ p .seen [n .Location ] = struct {}{}
10761109 }
10771110 return nil
10781111 }
10791112 return p
10801113}
10811114
10821115func findParameters (root nodes.Node ) []paramRef {
1083- v := paramSearch {refs : map [int ]paramRef {}}
1084- Walk (v , root )
10851116 refs := make ([]paramRef , 0 )
1086- for _ , r := range v .refs {
1087- refs = append (refs , r )
1088- }
1089- sort .Slice (refs , func (i , j int ) bool { return refs [i ].ref .Number < refs [j ].ref .Number })
1117+ v := paramSearch {seen : make (map [int ]struct {}), refs : & refs }
1118+ Walk (v , root )
10901119 return refs
10911120}
10921121
@@ -1348,3 +1377,15 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13481377 }
13491378 return a , nil
13501379}
1380+
1381+ func uniqueParamRefs (in []paramRef ) []paramRef {
1382+ m := make (map [int ]struct {}, len (in ))
1383+ o := make ([]paramRef , 0 , len (in ))
1384+ for _ , v := range in {
1385+ if _ , ok := m [v .ref .Number ]; ! ok {
1386+ m [v .ref .Number ] = struct {}{}
1387+ o = append (o , v )
1388+ }
1389+ }
1390+ return o
1391+ }
0 commit comments