From 007f3066f6e41531ce2d9b96f60f6f555359a520 Mon Sep 17 00:00:00 2001 From: David Vilaverde Date: Tue, 7 May 2024 03:38:15 -0400 Subject: [PATCH] fixing bad connection error when reading large compressed packets (#863) * fixing bad connection error when reading large compressed packets * fixing linting errors * minor cleanup and some more comments * minor cleanup and some more comments * fixing issue when net_buffer_length=1024 * fixing packet reader lookup condition * handle possible nil access violation when attempting to read next compressed packet * removed deprecated linters that no longer exist in golangci-lint 1.58.0 * addressing PR feedback * addressing PR feedback * removed compressedReaderActive --------- Co-authored-by: dvilaverde Co-authored-by: lance6716 --- packet/conn.go | 125 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 42 deletions(-) diff --git a/packet/conn.go b/packet/conn.go index 866e9ac81..7250623fa 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -9,6 +9,7 @@ import ( "crypto/sha1" "crypto/x509" "encoding/pem" + goErrors "errors" "io" "net" "sync" @@ -65,8 +66,6 @@ type Conn struct { compressedHeader [7]byte - compressedReaderActive bool - compressedReader io.Reader } @@ -107,42 +106,17 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { }() if c.Compression != MYSQL_COMPRESS_NONE { - if !c.compressedReaderActive { - if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { - return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) - } - - compressedSequence := c.compressedHeader[3] - uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16) - if compressedSequence != c.CompressedSequence { - return nil, errors.Errorf("invalid compressed sequence %d != %d", - compressedSequence, c.CompressedSequence) - } - - if uncompressedLength > 0 { - var err error - switch c.Compression { - case MYSQL_COMPRESS_ZLIB: - c.compressedReader, err = zlib.NewReader(c.reader) - case MYSQL_COMPRESS_ZSTD: - c.compressedReader, err = zstd.NewReader(c.reader) - } - if err != nil { - return nil, err - } + if c.compressedReader == nil { + var err error + c.compressedReader, err = c.newCompressedPacketReader() + if err != nil { + return nil, err } - c.compressedReaderActive = true } } - if c.compressedReader != nil { - if err := c.ReadPacketTo(buf, c.compressedReader); err != nil { - return nil, errors.Trace(err) - } - } else { - if err := c.ReadPacketTo(buf, c.reader); err != nil { - return nil, errors.Trace(err) - } + if err := c.ReadPacketTo(buf); err != nil { + return nil, errors.Trace(err) } readBytes := buf.Bytes() @@ -167,7 +141,44 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { return result, nil } -func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) { +// newCompressedPacketReader creates a new compressed packet reader. +func (c *Conn) newCompressedPacketReader() (io.Reader, error) { + if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { + return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) + } + + compressedSequence := c.compressedHeader[3] + if compressedSequence != c.CompressedSequence { + return nil, errors.Errorf("invalid compressed sequence %d != %d", + compressedSequence, c.CompressedSequence) + } + + compressedLength := int(uint32(c.compressedHeader[0]) | uint32(c.compressedHeader[1])<<8 | uint32(c.compressedHeader[2])<<16) + uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16) + if uncompressedLength > 0 { + limitedReader := io.LimitReader(c.reader, int64(compressedLength)) + switch c.Compression { + case MYSQL_COMPRESS_ZLIB: + return zlib.NewReader(limitedReader) + case MYSQL_COMPRESS_ZSTD: + return zstd.NewReader(limitedReader) + } + } + + return nil, nil +} + +func (c *Conn) currentPacketReader() io.Reader { + if c.Compression == MYSQL_COMPRESS_NONE || c.compressedReader == nil { + return c.reader + } else { + return c.compressedReader + } +} + +func (c *Conn) copyN(dst io.Writer, n int64) (int64, error) { + var written int64 + for n > 0 { bcap := cap(c.copyNBuf) if int64(bcap) > n { @@ -175,14 +186,33 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err } buf := c.copyNBuf[:bcap] - rd, err := io.ReadAtLeast(src, buf, bcap) + // Call ReadAtLeast with the currentPacketReader as it may change on every iteration + // of this loop. + rd, err := io.ReadAtLeast(c.currentPacketReader(), buf, bcap) + n -= int64(rd) + // ReadAtLeast will return EOF or ErrUnexpectedEOF when fewer than the min + // bytes are read. In this case, and when we have compression then advance + // the sequence number and reset the compressed reader to continue reading + // the remaining bytes in the next compressed packet. + if c.Compression != MYSQL_COMPRESS_NONE && + (goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF)) { + // we have read to EOF and read an incomplete uncompressed packet + // so advance the compressed sequence number and reset the compressed reader + // to get the remaining unread uncompressed bytes from the next compressed packet. + c.CompressedSequence++ + if c.compressedReader, err = c.newCompressedPacketReader(); err != nil { + return written, errors.Trace(err) + } + } + if err != nil { return written, errors.Trace(err) } - wr, err := dst.Write(buf) + // careful to only write from the buffer the number of bytes read + wr, err := dst.Write(buf[:rd]) written += int64(wr) if err != nil { return written, errors.Trace(err) @@ -192,9 +222,21 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err return written, nil } -func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { - if _, err := io.ReadFull(r, c.header[:4]); err != nil { +func (c *Conn) ReadPacketTo(w io.Writer) error { + b := utils.BytesBufferGet() + defer func() { + utils.BytesBufferPut(b) + }() + + // packets that come in a compressed packet may be partial + // so use the copyN function to read the packet header into a + // buffer, since copyN is capable of getting the next compressed + // packet and updating the Conn state with a new compressedReader. + if _, err := c.copyN(b, 4); err != nil { return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err) + } else { + // copy was successful so copy the 4 bytes from the buffer to the header + copy(c.header[:4], b.Bytes()[:4]) } length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16) @@ -211,7 +253,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { buf.Grow(length) } - if n, err := c.copyN(w, r, int64(length)); err != nil { + if n, err := c.copyN(w, int64(length)); err != nil { return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length) } else if n != int64(length) { return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length) @@ -220,7 +262,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { return nil } - if err = c.ReadPacketTo(w, r); err != nil { + if err = c.ReadPacketTo(w); err != nil { return errors.Wrap(err, "ReadPacketTo failed") } } @@ -270,7 +312,6 @@ func (c *Conn) WritePacket(data []byte) error { return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) } c.compressedReader = nil - c.compressedReaderActive = false default: return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set") }