Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3302 Handle malformatted message length properly. #1758

Merged
merged 8 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 100 additions & 81 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package topology
import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -79,9 +80,9 @@ type connection struct {
driverConnectionID uint64
generation uint64

// awaitingResponse indicates that the server response was not completely
// awaitRemainingBytes indicates the size of server response that was not completely
// read before returning the connection to the pool.
awaitingResponse bool
awaitRemainingBytes *int32

// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate
// accessTokens in the OIDC authenticator cache.
Expand Down Expand Up @@ -115,12 +116,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
return c
}

// DriverConnectionID returns the driver connection ID.
Copy link
Collaborator Author

@qingyang-hu qingyang-hu Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved DriverConnectionID down with other public methods.

// TODO(GODRIVER-2824): change return type to int64.
func (c *connection) DriverConnectionID() uint64 {
return c.driverConnectionID
}

// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
// configuration.
func (c *connection) setGenerationNumber() {
Expand All @@ -142,6 +137,39 @@ func (c *connection) hasGenerationNumber() bool {
return c.desc.LoadBalanced()
}

func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}

hostname = hostname[:colonPos]
config.ServerName = hostname
}

client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}

// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
// handshakes. All errors returned by connect are considered "before the handshake completes" and
// must be handled by calling the appropriate SDAM handshake error handler.
Expand Down Expand Up @@ -317,6 +345,10 @@ func (c *connection) closeConnectContext() {
}
}

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved transformNetworkError closer to the caller.

if originalError == nil {
return nil
Expand All @@ -339,10 +371,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
return originalError
}

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if atomic.LoadInt64(&c.state) != connConnected {
Expand Down Expand Up @@ -423,15 +451,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {

dst, errMsg, err := c.read(ctx)
if err != nil {
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) {
// If the error was a timeout error and CSOT is enabled, instead of
// closing the connection mark it as awaiting response so the pool
// can read the response before making it available to other
// operations.
c.awaitingResponse = true
} else {
// Otherwise, use the pre-CSOT behavior and close the connection
// because we don't know if there are other bytes left to read.
if c.awaitRemainingBytes == nil {
// If the connection was not marked as awaiting response, use the
// pre-CSOT behavior and close the connection because we don't know
// if there are other bytes left to read.
c.close()
}
message := errMsg
Expand All @@ -448,6 +471,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
return dst, nil
}

func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
// read the length as an int32
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:]))

if size < 4 {
return 0, fmt.Errorf("malformed message length: %d", size)
}
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return 0, errResponseTooLarge
}

return size, nil
}

func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
defer func() {
Expand All @@ -461,36 +504,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
}
}()

isCSOTTimeout := func(err error) bool {
// If the error was a timeout error and CSOT is enabled, instead of
// closing the connection mark it as awaiting response so the pool
// can read the response before making it available to other
// operations.
nerr := net.Error(nil)
return errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx)
}

// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
// reslice dst once instead of twice.
var sizeBuf [4]byte

// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
// because there might be more than one wire message waiting to be read, for example when
// reading messages from an exhaust cursor.
_, err = io.ReadFull(c.nc, sizeBuf[:])
n, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
if l := int32(n); l == 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &l
}
return nil, "incomplete read of message header", err
}

// read the length as an int32
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)

// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return nil, errResponseTooLarge.Error(), errResponseTooLarge
size, err := c.parseWmSizeBytes(sizeBuf)
if err != nil {
return nil, err.Error(), err
}

dst := make([]byte, size)
copy(dst, sizeBuf[:])

_, err = io.ReadFull(c.nc, dst[4:])
n, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
remainingBytes := size - 4 - int32(n)
if remainingBytes > 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &remainingBytes
}
return dst, "incomplete read of full message", err
}

Expand Down Expand Up @@ -537,10 +587,6 @@ func (c *connection) setCanStream(canStream bool) {
c.canStream = canStream
}

func (c initConnection) supportsStreaming() bool {
Copy link
Collaborator Author

@qingyang-hu qingyang-hu Aug 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merged in (initConnection).SupportsStreaming().

return c.canStream
}

func (c *connection) setStreaming(streaming bool) {
c.currentlyStreaming = streaming
}
Expand All @@ -554,6 +600,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) {
c.writeTimeout = timeout
}

// DriverConnectionID returns the driver connection ID.
// TODO(GODRIVER-2824): change return type to int64.
func (c *connection) DriverConnectionID() uint64 {
return c.driverConnectionID
}

func (c *connection) ID() string {
return c.id
}
Expand All @@ -562,6 +614,14 @@ func (c *connection) ServerConnectionID() *int64 {
return c.serverConnectionID
}

func (c *connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
}

func (c *connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}

// initConnection is an adapter used during connection initialization. It has the minimum
// functionality necessary to implement the driver.Connection interface, which is required to pass a
// *connection to a Handshaker.
Expand Down Expand Up @@ -599,7 +659,7 @@ func (c initConnection) CurrentlyStreaming() bool {
return c.getCurrentlyStreaming()
}
func (c initConnection) SupportsStreaming() bool {
return c.supportsStreaming()
return c.canStream
}

// Connection implements the driver.Connection interface to allow reading and writing wire
Expand Down Expand Up @@ -833,39 +893,6 @@ func (c *Connection) DriverConnectionID() uint64 {
return c.connection.DriverConnectionID()
}

func configureTLS(ctx context.Context,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved closer to the caller.

tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}

hostname = hostname[:colonPos]
config.ServerName = hostname
}

client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}

// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

// OIDCTokenGenID returns the OIDC token generation ID.
func (c *Connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
Expand Down Expand Up @@ -919,11 +946,3 @@ func (c *cancellListener) StopListening() bool {
c.done <- struct{}{}
return c.aborted
}

func (c *connection) OIDCTokenGenID() uint64 {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved closer with other *connection methods.

return c.oidcTokenGenID
}

func (c *connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}
17 changes: 17 additions & 0 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,23 @@ func TestConnection(t *testing.T) {
}
listener.assertCalledOnce(t)
})
t.Run("size too small errors", func(t *testing.T) {
err := errors.New("malformed message length: 3")
tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()}
_, got := conn.readWireMessage(context.Background())
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
if !tnc.closed {
t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
}
listener.assertCalledOnce(t)
})
t.Run("full message read errors", func(t *testing.T) {
err := errors.New("Read error")
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
Expand Down
Loading
Loading