@@ -47,13 +47,10 @@ func (mc *mysqlConn) handleParams() (err error) {
4747 charsets := strings .Split (val , "," )
4848 for _ , charset := range charsets {
4949 err = mc .exec ("SET NAMES " + charset )
50- if err = = nil {
51- break
50+ if err ! = nil {
51+ return
5252 }
5353 }
54- if err != nil {
55- return
56- }
5754
5855 // TLS-Encryption
5956 case "tls" :
@@ -78,11 +75,11 @@ func (mc *mysqlConn) handleParams() (err error) {
7875
7976func (mc * mysqlConn ) Begin () (driver.Tx , error ) {
8077 err := mc .exec ("START TRANSACTION" )
81- if err ! = nil {
82- return nil , err
78+ if err = = nil {
79+ return & mysqlTx { mc } , err
8380 }
8481
85- return & mysqlTx { mc } , err
82+ return nil , err
8683}
8784
8885func (mc * mysqlConn ) Close () (err error ) {
@@ -96,7 +93,7 @@ func (mc *mysqlConn) Close() (err error) {
9693
9794func (mc * mysqlConn ) Prepare (query string ) (driver.Stmt , error ) {
9895 // Send command
99- err := mc .writeCommandPacket (COM_STMT_PREPARE , query )
96+ err := mc .writeCommandPacketStr (COM_STMT_PREPARE , query )
10097 if err != nil {
10198 return nil , err
10299 }
@@ -106,52 +103,54 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
106103 }
107104
108105 // Read Result
109- var columnCount uint16
110- columnCount , err = stmt .readPrepareResultPacket ()
111- if err != nil {
112- return nil , err
113- }
114-
115- if stmt .paramCount > 0 {
116- stmt .params , err = stmt .mc .readColumns (stmt .paramCount )
117- if err != nil {
118- return nil , err
106+ columnCount , err := stmt .readPrepareResultPacket ()
107+ if err == nil {
108+ if stmt .paramCount > 0 {
109+ stmt .params , err = stmt .mc .readColumns (stmt .paramCount )
110+ if err != nil {
111+ return nil , err
112+ }
119113 }
120- }
121114
122- if columnCount > 0 {
123- _ , err = stmt .mc .readUntilEOF ()
124- if err != nil {
125- return nil , err
115+ if columnCount > 0 {
116+ _ , err = stmt .mc .readUntilEOF ()
126117 }
127118 }
128119
129120 return stmt , err
130121}
131122
132- func (mc * mysqlConn ) Exec (query string , args []driver.Value ) (driver.Result , error ) {
133- if len (args ) > 0 {
134- return nil , driver .ErrSkip
135- }
136-
137- mc .affectedRows = 0
138- mc .insertId = 0
139-
140- err := mc .exec (query )
141- if err != nil {
142- return nil , err
123+ func (mc * mysqlConn ) Exec (query string , args []driver.Value ) (_ driver.Result , err error ) {
124+ if len (args ) > 0 { // with args, must use prepared stmt
125+ var res driver.Result
126+ var stmt driver.Stmt
127+ stmt , err = mc .Prepare (query )
128+ if err == nil {
129+ res , err = stmt .Exec (args )
130+ if err == nil {
131+ return res , stmt .Close ()
132+ }
133+ }
134+ } else { // no args, fastpath
135+ mc .affectedRows = 0
136+ mc .insertId = 0
137+
138+ err = mc .exec (query )
139+ if err == nil {
140+ return & mysqlResult {
141+ affectedRows : int64 (mc .affectedRows ),
142+ insertId : int64 (mc .insertId ),
143+ }, err
144+ }
143145 }
146+ return nil , err
144147
145- return & mysqlResult {
146- affectedRows : int64 (mc .affectedRows ),
147- insertId : int64 (mc .insertId ),
148- }, err
149148}
150149
151150// Internal function to execute commands
152151func (mc * mysqlConn ) exec (query string ) (err error ) {
153152 // Send command
154- err = mc .writeCommandPacket (COM_QUERY , query )
153+ err = mc .writeCommandPacketStr (COM_QUERY , query )
155154 if err != nil {
156155 return
157156 }
@@ -175,39 +174,42 @@ func (mc *mysqlConn) exec(query string) (err error) {
175174 }
176175
177176 mc .affectedRows , err = mc .readUntilEOF ()
178- return
179177 }
180178
181179 return
182180}
183181
184- func (mc * mysqlConn ) Query (query string , args []driver.Value ) (driver.Rows , error ) {
185- if len (args ) > 0 {
186- return nil , driver .ErrSkip
187- }
188-
189- // Send command
190- err := mc .writeCommandPacket (COM_QUERY , query )
191- if err != nil {
192- return nil , err
193- }
194-
195- // Read Result
196- var resLen int
197- resLen , err = mc .readResultSetHeaderPacket ()
198- if err != nil {
199- return nil , err
200- }
201-
202- rows := & mysqlRows {mc , false , nil , false }
203-
204- if resLen > 0 {
205- // Columns
206- rows .columns , err = mc .readColumns (resLen )
207- if err != nil {
208- return nil , err
182+ func (mc * mysqlConn ) Query (query string , args []driver.Value ) (_ driver.Rows , err error ) {
183+ if len (args ) > 0 { // with args, must use prepared stmt
184+ var rows driver.Rows
185+ var stmt driver.Stmt
186+ stmt , err = mc .Prepare (query )
187+ if err == nil {
188+ rows , err = stmt .Query (args )
189+ if err == nil {
190+ return rows , stmt .Close ()
191+ }
192+ }
193+ return
194+ } else { // no args, fastpath
195+ var rows * mysqlRows
196+ // Send command
197+ err = mc .writeCommandPacketStr (COM_QUERY , query )
198+ if err == nil {
199+ // Read Result
200+ var resLen int
201+ resLen , err = mc .readResultSetHeaderPacket ()
202+ if err == nil {
203+ rows = & mysqlRows {mc , false , nil , false }
204+
205+ if resLen > 0 {
206+ // Columns
207+ rows .columns , err = mc .readColumns (resLen )
208+ }
209+ return rows , err
210+ }
209211 }
210212 }
211213
212- return rows , err
214+ return nil , err
213215}
0 commit comments