Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3302 Handle malformatted message length properly. #1758

Merged
merged 8 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ type connection struct {
driverConnectionID uint64
generation uint64

// awaitingResponse indicates that the server response was not completely
// awaitingResponse indicates the size of server response that was not completely
// read before returning the connection to the pool.
awaitingResponse bool
awaitingResponse *int32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest renaming awaitingResponse to awaitRemainingBytes.


// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
// accessTokens in the OIDC authenticator cache.
Expand Down Expand Up @@ -423,15 +423,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {

dst, errMsg, err := c.read(ctx)
if err != nil {
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) {
// If the error was a timeout error and CSOT is enabled, instead of
// closing the connection mark it as awaiting response so the pool
// can read the response before making it available to other
// operations.
c.awaitingResponse = true
} else {
// Otherwise, use the pre-CSOT behavior and close the connection
// because we don't know if there are other bytes left to read.
if c.awaitingResponse == nil {
// If the connection was not marked as awaiting response, use the
// pre-CSOT behavior and close the connection because we don't know
// if there are other bytes left to read.
c.close()
}
message := errMsg
Expand Down Expand Up @@ -461,21 +456,37 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
}
}()

needToWait := func(err error) bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: Consider a more descriptive name, like isCSOTTimeout.

// If the error was a timeout error and CSOT is enabled, instead of
// closing the connection mark it as awaiting response so the pool
// can read the response before making it available to other
// operations.
nerr := net.Error(nil)
return errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx)
}

// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
// reslice dst once instead of twice.
var sizeBuf [4]byte

// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
// because there might be more than one wire message waiting to be read, for example when
// reading messages from an exhaust cursor.
_, err = io.ReadFull(c.nc, sizeBuf[:])
n, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
if l := int32(n); l == 0 && needToWait(err) {
c.awaitingResponse = &l
}
return nil, "incomplete read of message header", err
}

// read the length as an int32
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)

if size < 4 {
err = fmt.Errorf("malformed message length: %d", size)
return nil, err.Error(), err
}
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
Expand All @@ -489,8 +500,11 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
dst := make([]byte, size)
copy(dst, sizeBuf[:])

_, err = io.ReadFull(c.nc, dst[4:])
n, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
if l := size - 4 - int32(n); l > 0 && needToWait(err) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest renaming l to remainingBytes.

c.awaitingResponse = &l
}
return dst, "incomplete read of full message", err
}

Expand Down
17 changes: 17 additions & 0 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,23 @@ func TestConnection(t *testing.T) {
}
listener.assertCalledOnce(t)
})
t.Run("size too small errors", func(t *testing.T) {
err := errors.New("malformed message length: 3")
tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()}
_, got := conn.readWireMessage(context.Background())
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
if !tnc.closed {
t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
}
listener.assertCalledOnce(t)
})
t.Run("full message read errors", func(t *testing.T) {
err := errors.New("Read error")
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
Expand Down
78 changes: 47 additions & 31 deletions x/mongo/driver/topology/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ package topology
import (
"context"
"fmt"
"io"
"io/ioutil"
"net"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -788,17 +790,27 @@ var (
//
// It calls the package-global BGReadCallback function, if set, with the
// address, timings, and any errors that occurred.
func bgRead(pool *pool, conn *connection) {
var start, read time.Time
start = time.Now()
errs := make([]error, 0)
connClosed := false
func bgRead(pool *pool, conn *connection, size int32) {
var err error
start := time.Now()

defer func() {
read := time.Now()
errs := make([]error, 0)
connClosed := false
if err != nil {
errs = append(errs, err)
connClosed = true
err = conn.close()
if err != nil {
errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err))
}
}

// No matter what happens, always check the connection back into the
// pool, which will either make it available for other operations or
// remove it from the pool if it was closed.
err := pool.checkInNoEvent(conn)
err = pool.checkInNoEvent(conn)
if err != nil {
errs = append(errs, fmt.Errorf("error checking in: %w", err))
}
Expand All @@ -808,34 +820,37 @@ func bgRead(pool *pool, conn *connection) {
}
}()

err := conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout))
err = conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout))
if err != nil {
errs = append(errs, fmt.Errorf("error setting a read deadline: %w", err))

connClosed = true
err := conn.close()
if err != nil {
errs = append(errs, fmt.Errorf("error closing conn after setting read deadline: %w", err))
}

err = fmt.Errorf("error setting a read deadline: %w", err)
return
}

// The context here is only used for cancellation, not deadline timeout, so
// use context.Background(). The read timeout is set by calling
// SetReadDeadline above.
_, _, err = conn.read(context.Background())
read = time.Now()
if err != nil {
errs = append(errs, fmt.Errorf("error reading: %w", err))

connClosed = true
err := conn.close()
if size == 0 {
var sizeBuf [4]byte
_, err = io.ReadFull(conn.nc, sizeBuf[:])
Comment on lines +829 to +830
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest making this logic a function that can be used here and in the connection read method:

func readWMSize(r io.Reader) (int32, error) {
	const wireMessageSizePrefix = 4

	var wmSizeBytes [wireMessageSizePrefix]byte
	if _, err := io.ReadFull(r, wmSizeBytes[:]); err != nil {
		return 0, fmt.Errorf("error reading the message size: %w", err)
	}

	size := (int32(wmSizeBytes[0])) |
		(int32(wmSizeBytes[1]) << 8) |
		(int32(wmSizeBytes[2]) << 16) |
		(int32(wmSizeBytes[3]) << 24)

	if size < 4 {
		return 0, fmt.Errorf("malformed message length: %d", size)
	}

	return size, nil
}

if err != nil {
errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err))
err = fmt.Errorf("error reading the message size: %w", err)
return
}

return
size = (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
if size < 4 {
err = fmt.Errorf("malformed message length: %d", size)
return
}
maxMessageSize := conn.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
err = errResponseTooLarge
return
}
size -= 4
}
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ioutil is deprecated, should use io.Discard

Suggested change
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size))
_, err = io.CopyN(io.Discard, conn.nc, int64(size))

if err != nil {
err = fmt.Errorf("error reading message of %d: %w", size, err)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional: This error message is a bit confusing. Consider a clearer error message.

Suggested change
err = fmt.Errorf("error reading message of %d: %w", size, err)
err = fmt.Errorf("error discarding %d byte message: %w", size, err)

}
}

Expand Down Expand Up @@ -886,9 +901,10 @@ func (p *pool) checkInNoEvent(conn *connection) error {
// means that connections in "awaiting response" state are checked in but
// not usable, which is not covered by the current pool events. We may need
// to add pool event information in the future to communicate that.
if conn.awaitingResponse {
conn.awaitingResponse = false
go bgRead(p, conn)
if conn.awaitingResponse != nil {
size := *conn.awaitingResponse
conn.awaitingResponse = nil
go bgRead(p, conn, size)
return nil
}

Expand Down
Loading
Loading