Skip to content

Commit d5ad92e

Browse files
committed
merge compressedReader and compressedWriter
1 parent 876af07 commit d5ad92e

File tree

4 files changed

+23
-50
lines changed

4 files changed

+23
-50
lines changed

compress.go

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,41 +11,29 @@ import (
1111
// for debugging wire protocol.
1212
const debugTrace = false
1313

14-
type compressedReader struct {
15-
buf packetReader
14+
type compressor struct {
15+
mc *mysqlConn
16+
// for reader
1617
bytesBuf []byte
17-
mc *mysqlConn
1818
zr io.ReadCloser
19-
}
20-
21-
type compressedWriter struct {
19+
// for writer
2220
connWriter io.Writer
23-
mc *mysqlConn
2421
zw *zlib.Writer
2522
}
2623

27-
func newCompressedReader(buf packetReader, mc *mysqlConn) *compressedReader {
28-
return &compressedReader{
29-
buf: buf,
30-
bytesBuf: make([]byte, 0),
31-
mc: mc,
32-
}
33-
}
34-
35-
func newCompressedWriter(connWriter io.Writer, mc *mysqlConn) *compressedWriter {
36-
// level 1 or 2 is the best trade-off between speed and compression ratio
24+
func newCompressor(mc *mysqlConn, w io.Writer) *compressor {
3725
zw, err := zlib.NewWriterLevel(new(bytes.Buffer), 2)
3826
if err != nil {
3927
panic(err) // compress/zlib return non-nil error only if level is invalid
4028
}
41-
return &compressedWriter{
42-
connWriter: connWriter,
29+
return &compressor{
4330
mc: mc,
31+
connWriter: w,
4432
zw: zw,
4533
}
4634
}
4735

48-
func (r *compressedReader) readNext(need int) ([]byte, error) {
36+
func (r *compressor) readNext(need int) ([]byte, error) {
4937
for len(r.bytesBuf) < need {
5038
if err := r.uncompressPacket(); err != nil {
5139
return nil, err
@@ -57,8 +45,8 @@ func (r *compressedReader) readNext(need int) ([]byte, error) {
5745
return data, nil
5846
}
5947

60-
func (r *compressedReader) uncompressPacket() error {
61-
header, err := r.buf.readNext(7) // size of compressed header
48+
func (r *compressor) uncompressPacket() error {
49+
header, err := r.mc.buf.readNext(7) // size of compressed header
6250
if err != nil {
6351
return err
6452
}
@@ -76,7 +64,7 @@ func (r *compressedReader) uncompressPacket() error {
7664
}
7765
r.mc.compressionSequence++
7866

79-
comprData, err := r.buf.readNext(comprLength)
67+
comprData, err := r.mc.buf.readNext(comprLength)
8068
if err != nil {
8169
return err
8270
}
@@ -138,7 +126,7 @@ const maxPayloadLen = maxPacketSize - 4
138126

139127
var blankHeader = make([]byte, 7)
140128

141-
func (w *compressedWriter) Write(data []byte) (int, error) {
129+
func (w *compressor) Write(data []byte) (int, error) {
142130
totalBytes := len(data)
143131
dataLen := len(data)
144132
var buf bytes.Buffer
@@ -183,7 +171,7 @@ func (w *compressedWriter) Write(data []byte) (int, error) {
183171

184172
// writeCompressedPacket writes a compressed packet with header.
185173
// data should start with 7 size space for header followed by payload.
186-
func (w *compressedWriter) writeCompressedPacket(data []byte, uncompressedLen int) error {
174+
func (w *compressor) writeCompressedPacket(data []byte, uncompressedLen int) error {
187175
comprLength := len(data) - 7
188176

189177
// compression header

compress_test.go

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,13 @@ func newMockConn() *mysqlConn {
1919
return newConn
2020
}
2121

22-
type mockBuf struct {
23-
reader io.Reader
24-
}
25-
26-
func newMockBuf(reader io.Reader) *mockBuf {
27-
return &mockBuf{
28-
reader: reader,
22+
func newMockBuf(data []byte) buffer {
23+
return buffer{
24+
buf: data,
25+
length: len(data),
2926
}
3027
}
3128

32-
func (mb *mockBuf) readNext(need int) ([]byte, error) {
33-
data := make([]byte, need)
34-
_, err := mb.reader.Read(data)
35-
if err != nil {
36-
return nil, err
37-
}
38-
return data, nil
39-
}
40-
4129
// compressHelper compresses uncompressedPacket and checks state variables
4230
func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []byte {
4331
// get status variables
@@ -47,7 +35,7 @@ func compressHelper(t *testing.T, mc *mysqlConn, uncompressedPacket []byte) []by
4735
var b bytes.Buffer
4836
connWriter := &b
4937

50-
cw := newCompressedWriter(connWriter, mc)
38+
cw := newCompressor(mc, connWriter)
5139

5240
n, err := cw.Write(uncompressedPacket)
5341

@@ -79,10 +67,8 @@ func uncompressHelper(t *testing.T, mc *mysqlConn, compressedPacket []byte, expS
7967
cs := mc.compressionSequence
8068

8169
// mocking out buf variable
82-
mockConnReader := bytes.NewReader(compressedPacket)
83-
mockBuf := newMockBuf(mockConnReader)
84-
85-
cr := newCompressedReader(mockBuf, mc)
70+
mc.buf = newMockBuf(compressedPacket)
71+
cr := newCompressor(mc, nil)
8672

8773
uncompressedPacket, err := cr.readNext(expSize)
8874
if err != nil {

connector.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,9 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
169169
}
170170

171171
if mc.compress {
172-
mc.packetReader = newCompressedReader(&mc.buf, mc)
173-
mc.packetWriter = newCompressedWriter(mc.packetWriter, mc)
172+
cmpr := newCompressor(mc, mc.packetWriter)
173+
mc.packetReader = cmpr
174+
mc.packetWriter = cmpr
174175
}
175176
if mc.cfg.MaxAllowedPacket > 0 {
176177
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket

packets.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,6 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
206206
if mc.flags&clientProtocol41 == 0 {
207207
return nil, "", ErrOldProtocol
208208
}
209-
210-
// TODO(methane): writing to mc.cfg.XXX is bad idea. Fix it later.
211209
if mc.flags&clientSSL == 0 && mc.cfg.TLS != nil {
212210
if mc.cfg.AllowFallbackToPlaintext {
213211
mc.cfg.TLS = nil

0 commit comments

Comments
 (0)