@@ -6,47 +6,108 @@ import (
66 "fmt"
77 "io"
88 "os"
9+ "sync"
910)
1011
12+ var (
13+ zrPool * sync.Pool // Do not use directly. Use zDecompress() instead.
14+ zwPool * sync.Pool // Do not use directly. Use zCompress() instead.
15+ )
16+
17+ func init () {
18+ zrPool = & sync.Pool {
19+ New : func () any { return nil },
20+ }
21+ zwPool = & sync.Pool {
22+ New : func () any {
23+ zw , err := zlib .NewWriterLevel (new (bytes.Buffer ), 2 )
24+ if err != nil {
25+ panic (err ) // compress/zlib return non-nil error only if level is invalid
26+ }
27+ return zw
28+ },
29+ }
30+ }
31+
32+ func zDecompress (src , dst []byte ) (int , error ) {
33+ br := bytes .NewReader (src )
34+ var zr io.ReadCloser
35+ var err error
36+
37+ if a := zrPool .Get (); a == nil {
38+ if zr , err = zlib .NewReader (br ); err != nil {
39+ return 0 , err
40+ }
41+ } else {
42+ zr = a .(io.ReadCloser )
43+ if zr .(zlib.Resetter ).Reset (br , nil ); err != nil {
44+ return 0 , err
45+ }
46+ }
47+ defer func () {
48+ zr .Close ()
49+ zrPool .Put (zr )
50+ }()
51+
52+ lenRead := 0
53+ size := len (dst )
54+
55+ for lenRead < size {
56+ n , err := zr .Read (dst [lenRead :])
57+ lenRead += n
58+
59+ if err == io .EOF {
60+ if lenRead < size {
61+ return lenRead , io .ErrUnexpectedEOF
62+ }
63+ } else if err != nil {
64+ return lenRead , err
65+ }
66+ }
67+ return lenRead , nil
68+ }
69+
70+ func zCompress (src []byte , dst io.Writer ) error {
71+ zw := zwPool .Get ().(* zlib.Writer )
72+ zw .Reset (dst )
73+ if _ , err := zw .Write (src ); err != nil {
74+ return err
75+ }
76+ zw .Close ()
77+ zwPool .Put (zw )
78+ return nil
79+ }
80+
1181// for debugging wire protocol.
1282const debugTrace = false
1383
1484type compressor struct {
15- mc * mysqlConn
16- // for reader
17- bytesBuf []byte
18- zr io.ReadCloser
19- // for writer
85+ mc * mysqlConn
86+ bytesBuf []byte // read buffer (FIFO)
2087 connWriter io.Writer
21- zw * zlib.Writer
2288}
2389
2490func newCompressor (mc * mysqlConn , w io.Writer ) * compressor {
25- zw , err := zlib .NewWriterLevel (new (bytes.Buffer ), 2 )
26- if err != nil {
27- panic (err ) // compress/zlib return non-nil error only if level is invalid
28- }
2991 return & compressor {
3092 mc : mc ,
3193 connWriter : w ,
32- zw : zw ,
3394 }
3495}
3596
36- func (r * compressor ) readNext (need int ) ([]byte , error ) {
37- for len (r .bytesBuf ) < need {
38- if err := r .uncompressPacket (); err != nil {
97+ func (c * compressor ) readNext (need int ) ([]byte , error ) {
98+ for len (c .bytesBuf ) < need {
99+ if err := c .uncompressPacket (); err != nil {
39100 return nil , err
40101 }
41102 }
42103
43- data := r .bytesBuf [:need :need ] // prevent caller writes into r.bytesBuf
44- r .bytesBuf = r .bytesBuf [need :]
104+ data := c .bytesBuf [:need :need ] // prevent caller writes into r.bytesBuf
105+ c .bytesBuf = c .bytesBuf [need :]
45106 return data , nil
46107}
47108
48- func (r * compressor ) uncompressPacket () error {
49- header , err := r .mc .buf .readNext (7 ) // size of compressed header
109+ func (c * compressor ) uncompressPacket () error {
110+ header , err := c .mc .buf .readNext (7 ) // size of compressed header
50111 if err != nil {
51112 return err
52113 }
@@ -59,74 +120,48 @@ func (r *compressor) uncompressPacket() error {
59120 fmt .Fprintf (os .Stderr , "uncompress cmplen=%v uncomplen=%v seq=%v\n " ,
60121 comprLength , uncompressedLength , compressionSequence )
61122 }
62- if compressionSequence != r .mc .compressionSequence {
123+ if compressionSequence != c .mc .compressionSequence {
63124 return ErrPktSync
64125 }
65- r .mc .compressionSequence ++
126+ c .mc .compressionSequence ++
66127
67- comprData , err := r .mc .buf .readNext (comprLength )
128+ comprData , err := c .mc .buf .readNext (comprLength )
68129 if err != nil {
69130 return err
70131 }
71132
72133 // if payload is uncompressed, its length will be specified as zero, and its
73134 // true length is contained in comprLength
74135 if uncompressedLength == 0 {
75- r .bytesBuf = append (r .bytesBuf , comprData ... )
136+ c .bytesBuf = append (c .bytesBuf , comprData ... )
76137 return nil
77138 }
78139
79- // write comprData to a bytes.buffer, then read it using zlib into data
80- br := bytes .NewReader (comprData )
81- if r .zr == nil {
82- if r .zr , err = zlib .NewReader (br ); err != nil {
83- return err
84- }
85- } else {
86- if err = r .zr .(zlib.Resetter ).Reset (br , nil ); err != nil {
87- return err
88- }
89- }
90- defer r .zr .Close ()
91-
92140 // use existing capacity in bytesBuf if possible
93- offset := len (r .bytesBuf )
94- if cap (r .bytesBuf )- offset < uncompressedLength {
95- old := r .bytesBuf
96- r .bytesBuf = make ([]byte , offset , offset + uncompressedLength )
97- copy (r .bytesBuf , old )
141+ offset := len (c .bytesBuf )
142+ if cap (c .bytesBuf )- offset < uncompressedLength {
143+ old := c .bytesBuf
144+ c .bytesBuf = make ([]byte , offset , offset + uncompressedLength )
145+ copy (c .bytesBuf , old )
98146 }
99147
100- data := r .bytesBuf [offset : offset + uncompressedLength ]
101- lenRead := 0
102-
103- // http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate
104- for lenRead < uncompressedLength {
105- n , err := r .zr .Read (data [lenRead :])
106- lenRead += n
107-
108- if err == io .EOF {
109- if lenRead < uncompressedLength {
110- return io .ErrUnexpectedEOF
111- }
112- break
113- } else if err != nil {
114- return err
115- }
148+ lenRead , err := zDecompress (comprData , c .bytesBuf [offset :offset + uncompressedLength ])
149+ if err != nil {
150+ return err
116151 }
117152 if lenRead != uncompressedLength {
118153 return fmt .Errorf ("invalid compressed packet: uncompressed length in header is %d, actual %d" ,
119154 uncompressedLength , lenRead )
120155 }
121- r .bytesBuf = r .bytesBuf [:offset + uncompressedLength ]
156+ c .bytesBuf = c .bytesBuf [:offset + uncompressedLength ]
122157 return nil
123158}
124159
125160const maxPayloadLen = maxPacketSize - 4
126161
127162var blankHeader = make ([]byte , 7 )
128163
129- func (w * compressor ) Write (data []byte ) (int , error ) {
164+ func (c * compressor ) Write (data []byte ) (int , error ) {
130165 totalBytes := len (data )
131166 dataLen := len (data )
132167 var buf bytes.Buffer
@@ -150,17 +185,12 @@ func (w *compressor) Write(data []byte) (int, error) {
150185 }
151186 uncompressedLen = 0
152187 } else {
153- w .zw .Reset (& buf )
154- if _ , err := w .zw .Write (payload ); err != nil {
155- return 0 , err
156- }
157- w .zw .Close ()
188+ zCompress (payload , & buf )
158189 }
159190
160- if err := w .writeCompressedPacket (buf .Bytes (), uncompressedLen ); err != nil {
191+ if err := c .writeCompressedPacket (buf .Bytes (), uncompressedLen ); err != nil {
161192 return 0 , err
162193 }
163-
164194 dataLen -= payloadLen
165195 data = data [payloadLen :]
166196 buf .Reset ()
@@ -171,33 +201,32 @@ func (w *compressor) Write(data []byte) (int, error) {
171201
172202// writeCompressedPacket writes a compressed packet with header.
173203// data should start with 7 size space for header followed by payload.
174- func (w * compressor ) writeCompressedPacket (data []byte , uncompressedLen int ) error {
204+ func (c * compressor ) writeCompressedPacket (data []byte , uncompressedLen int ) error {
175205 comprLength := len (data ) - 7
206+ if debugTrace {
207+ c .mc .cfg .Logger .Print (
208+ fmt .Sprintf (
209+ "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v" ,
210+ comprLength , uncompressedLen , c .mc .compressionSequence ))
211+ }
176212
177213 // compression header
178214 data [0 ] = byte (0xff & comprLength )
179215 data [1 ] = byte (0xff & (comprLength >> 8 ))
180216 data [2 ] = byte (0xff & (comprLength >> 16 ))
181217
182- data [3 ] = w .mc .compressionSequence
218+ data [3 ] = c .mc .compressionSequence
183219
184220 // this value is never greater than maxPayloadLength
185221 data [4 ] = byte (0xff & uncompressedLen )
186222 data [5 ] = byte (0xff & (uncompressedLen >> 8 ))
187223 data [6 ] = byte (0xff & (uncompressedLen >> 16 ))
188224
189- if debugTrace {
190- w .mc .cfg .Logger .Print (
191- fmt .Sprintf (
192- "writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v" ,
193- comprLength , uncompressedLen , int (data [3 ])))
194- }
195-
196- if _ , err := w .connWriter .Write (data ); err != nil {
197- w .mc .cfg .Logger .Print (err )
225+ if _ , err := c .connWriter .Write (data ); err != nil {
226+ c .mc .cfg .Logger .Print (err )
198227 return err
199228 }
200229
201- w .mc .compressionSequence ++
230+ c .mc .compressionSequence ++
202231 return nil
203232}
0 commit comments