Skip to content

Commit 1c05916

Browse files
committed
use sync.Pool for zlib
1 parent d5ad92e commit 1c05916

File tree

1 file changed

+106
-77
lines changed

1 file changed

+106
-77
lines changed

compress.go

Lines changed: 106 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,108 @@ import (
66
"fmt"
77
"io"
88
"os"
9+
"sync"
910
)
1011

12+
var (
13+
zrPool *sync.Pool // Do not use directly. Use zDecompress() instead.
14+
zwPool *sync.Pool // Do not use directly. Use zCompress() instead.
15+
)
16+
17+
func init() {
18+
zrPool = &sync.Pool{
19+
New: func() any { return nil },
20+
}
21+
zwPool = &sync.Pool{
22+
New: func() any {
23+
zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2)
24+
if err != nil {
25+
panic(err) // compress/zlib return non-nil error only if level is invalid
26+
}
27+
return zw
28+
},
29+
}
30+
}
31+
32+
func zDecompress(src, dst []byte) (int, error) {
33+
br := bytes.NewReader(src)
34+
var zr io.ReadCloser
35+
var err error
36+
37+
if a := zrPool.Get(); a == nil {
38+
if zr, err = zlib.NewReader(br); err != nil {
39+
return 0, err
40+
}
41+
} else {
42+
zr = a.(io.ReadCloser)
43+
if zr.(zlib.Resetter).Reset(br, nil); err != nil {
44+
return 0, err
45+
}
46+
}
47+
defer func() {
48+
zr.Close()
49+
zrPool.Put(zr)
50+
}()
51+
52+
lenRead := 0
53+
size := len(dst)
54+
55+
for lenRead < size {
56+
n, err := zr.Read(dst[lenRead:])
57+
lenRead += n
58+
59+
if err == io.EOF {
60+
if lenRead < size {
61+
return lenRead, io.ErrUnexpectedEOF
62+
}
63+
} else if err != nil {
64+
return lenRead, err
65+
}
66+
}
67+
return lenRead, nil
68+
}
69+
70+
func zCompress(src []byte, dst io.Writer) error {
71+
zw := zwPool.Get().(*zlib.Writer)
72+
zw.Reset(dst)
73+
if _, err := zw.Write(src); err != nil {
74+
return err
75+
}
76+
zw.Close()
77+
zwPool.Put(zw)
78+
return nil
79+
}
80+
1181
// for debugging wire protocol.
1282
const debugTrace = false
1383

1484
type compressor struct {
15-
mc *mysqlConn
16-
// for reader
17-
bytesBuf []byte
18-
zr io.ReadCloser
19-
// for writer
85+
mc *mysqlConn
86+
bytesBuf []byte // read buffer (FIFO)
2087
connWriter io.Writer
21-
zw *zlib.Writer
2288
}
2389

2490
func newCompressor(mc *mysqlConn, w io.Writer) *compressor {
25-
zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2)
26-
if err != nil {
27-
panic(err) // compress/zlib return non-nil error only if level is invalid
28-
}
2991
return &compressor{
3092
mc: mc,
3193
connWriter: w,
32-
zw: zw,
3394
}
3495
}
3596

36-
func (r *compressor) readNext(need int) ([]byte, error) {
37-
for len(r.bytesBuf) < need {
38-
if err := r.uncompressPacket(); err != nil {
97+
func (c *compressor) readNext(need int) ([]byte, error) {
98+
for len(c.bytesBuf) < need {
99+
if err := c.uncompressPacket(); err != nil {
39100
return nil, err
40101
}
41102
}
42103

43-
data := r.bytesBuf[:need:need] // prevent caller writes into r.bytesBuf
44-
r.bytesBuf = r.bytesBuf[need:]
104+
data := c.bytesBuf[:need:need] // prevent caller writes into r.bytesBuf
105+
c.bytesBuf = c.bytesBuf[need:]
45106
return data, nil
46107
}
47108

48-
func (r *compressor) uncompressPacket() error {
49-
header, err := r.mc.buf.readNext(7) // size of compressed header
109+
func (c *compressor) uncompressPacket() error {
110+
header, err := c.mc.buf.readNext(7) // size of compressed header
50111
if err != nil {
51112
return err
52113
}
@@ -59,74 +120,48 @@ func (r *compressor) uncompressPacket() error {
59120
fmt.Fprintf(os.Stderr, "uncompress cmplen=%v uncomplen=%v seq=%v\n",
60121
comprLength, uncompressedLength, compressionSequence)
61122
}
62-
if compressionSequence != r.mc.compressionSequence {
123+
if compressionSequence != c.mc.compressionSequence {
63124
return ErrPktSync
64125
}
65-
r.mc.compressionSequence++
126+
c.mc.compressionSequence++
66127

67-
comprData, err := r.mc.buf.readNext(comprLength)
128+
comprData, err := c.mc.buf.readNext(comprLength)
68129
if err != nil {
69130
return err
70131
}
71132

72133
// if payload is uncompressed, its length will be specified as zero, and its
73134
// true length is contained in comprLength
74135
if uncompressedLength == 0 {
75-
r.bytesBuf = append(r.bytesBuf, comprData...)
136+
c.bytesBuf = append(c.bytesBuf, comprData...)
76137
return nil
77138
}
78139

79-
// write comprData to a bytes.buffer, then read it using zlib into data
80-
br := bytes.NewReader(comprData)
81-
if r.zr == nil {
82-
if r.zr, err = zlib.NewReader(br); err != nil {
83-
return err
84-
}
85-
} else {
86-
if err = r.zr.(zlib.Resetter).Reset(br, nil); err != nil {
87-
return err
88-
}
89-
}
90-
defer r.zr.Close()
91-
92140
// use existing capacity in bytesBuf if possible
93-
offset := len(r.bytesBuf)
94-
if cap(r.bytesBuf)-offset < uncompressedLength {
95-
old := r.bytesBuf
96-
r.bytesBuf = make([]byte, offset, offset+uncompressedLength)
97-
copy(r.bytesBuf, old)
141+
offset := len(c.bytesBuf)
142+
if cap(c.bytesBuf)-offset < uncompressedLength {
143+
old := c.bytesBuf
144+
c.bytesBuf = make([]byte, offset, offset+uncompressedLength)
145+
copy(c.bytesBuf, old)
98146
}
99147

100-
data := r.bytesBuf[offset : offset+uncompressedLength]
101-
lenRead := 0
102-
103-
// http://grokbase.com/t/gg/golang-nuts/146y9ppn6b/go-nuts-stream-compression-with-compress-flate
104-
for lenRead < uncompressedLength {
105-
n, err := r.zr.Read(data[lenRead:])
106-
lenRead += n
107-
108-
if err == io.EOF {
109-
if lenRead < uncompressedLength {
110-
return io.ErrUnexpectedEOF
111-
}
112-
break
113-
} else if err != nil {
114-
return err
115-
}
148+
lenRead, err := zDecompress(comprData, c.bytesBuf[offset:offset+uncompressedLength])
149+
if err != nil {
150+
return err
116151
}
117152
if lenRead != uncompressedLength {
118153
return fmt.Errorf("invalid compressed packet: uncompressed length in header is %d, actual %d",
119154
uncompressedLength, lenRead)
120155
}
121-
r.bytesBuf = r.bytesBuf[:offset+uncompressedLength]
156+
c.bytesBuf = c.bytesBuf[:offset+uncompressedLength]
122157
return nil
123158
}
124159

125160
const maxPayloadLen = maxPacketSize - 4
126161

127162
var blankHeader = make([]byte, 7)
128163

129-
func (w *compressor) Write(data []byte) (int, error) {
164+
func (c *compressor) Write(data []byte) (int, error) {
130165
totalBytes := len(data)
131166
dataLen := len(data)
132167
var buf bytes.Buffer
@@ -150,17 +185,12 @@ func (w *compressor) Write(data []byte) (int, error) {
150185
}
151186
uncompressedLen = 0
152187
} else {
153-
w.zw.Reset(&buf)
154-
if _, err := w.zw.Write(payload); err != nil {
155-
return 0, err
156-
}
157-
w.zw.Close()
188+
zCompress(payload, &buf)
158189
}
159190

160-
if err := w.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
191+
if err := c.writeCompressedPacket(buf.Bytes(), uncompressedLen); err != nil {
161192
return 0, err
162193
}
163-
164194
dataLen -= payloadLen
165195
data = data[payloadLen:]
166196
buf.Reset()
@@ -171,33 +201,32 @@ func (w *compressor) Write(data []byte) (int, error) {
171201

172202
// writeCompressedPacket writes a compressed packet with header.
173203
// data should start with 7 size space for header followed by payload.
174-
func (w *compressor) writeCompressedPacket(data []byte, uncompressedLen int) error {
204+
func (c *compressor) writeCompressedPacket(data []byte, uncompressedLen int) error {
175205
comprLength := len(data) - 7
206+
if debugTrace {
207+
c.mc.cfg.Logger.Print(
208+
fmt.Sprintf(
209+
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
210+
comprLength, uncompressedLen, c.mc.compressionSequence))
211+
}
176212

177213
// compression header
178214
data[0] = byte(0xff & comprLength)
179215
data[1] = byte(0xff & (comprLength >> 8))
180216
data[2] = byte(0xff & (comprLength >> 16))
181217

182-
data[3] = w.mc.compressionSequence
218+
data[3] = c.mc.compressionSequence
183219

184220
// this value is never greater than maxPayloadLength
185221
data[4] = byte(0xff & uncompressedLen)
186222
data[5] = byte(0xff & (uncompressedLen >> 8))
187223
data[6] = byte(0xff & (uncompressedLen >> 16))
188224

189-
if debugTrace {
190-
w.mc.cfg.Logger.Print(
191-
fmt.Sprintf(
192-
"writeCompressedPacket: comprLength=%v, uncompressedLen=%v, seq=%v",
193-
comprLength, uncompressedLen, int(data[3])))
194-
}
195-
196-
if _, err := w.connWriter.Write(data); err != nil {
197-
w.mc.cfg.Logger.Print(err)
225+
if _, err := c.connWriter.Write(data); err != nil {
226+
c.mc.cfg.Logger.Print(err)
198227
return err
199228
}
200229

201-
w.mc.compressionSequence++
230+
c.mc.compressionSequence++
202231
return nil
203232
}

0 commit comments

Comments
 (0)