-
Notifications
You must be signed in to change notification settings - Fork 892
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GODRIVER-3302 Handle malformatted message length properly. #1758
Changes from 4 commits
93d8ee1
be66ac6
ddbd3e9
b565a23
2d988ed
9bea0eb
03aa027
39f3021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -79,9 +79,9 @@ type connection struct { | |||||||||||
driverConnectionID uint64 | ||||||||||||
generation uint64 | ||||||||||||
|
||||||||||||
// awaitingResponse indicates that the server response was not completely | ||||||||||||
// awaitRemainingBytes indicates the size of server response that was not completely | ||||||||||||
// read before returning the connection to the pool. | ||||||||||||
awaitingResponse bool | ||||||||||||
awaitRemainingBytes *int32 | ||||||||||||
|
||||||||||||
// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate | ||||||||||||
// accessTokens in the OIDC authenticator cache. | ||||||||||||
|
@@ -115,12 +115,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection { | |||||||||||
return c | ||||||||||||
} | ||||||||||||
|
||||||||||||
// DriverConnectionID returns the driver connection ID. | ||||||||||||
// TODO(GODRIVER-2824): change return type to int64. | ||||||||||||
func (c *connection) DriverConnectionID() uint64 { | ||||||||||||
return c.driverConnectionID | ||||||||||||
} | ||||||||||||
|
||||||||||||
// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection | ||||||||||||
// configuration. | ||||||||||||
func (c *connection) setGenerationNumber() { | ||||||||||||
|
@@ -142,6 +136,39 @@ func (c *connection) hasGenerationNumber() bool { | |||||||||||
return c.desc.LoadBalanced() | ||||||||||||
} | ||||||||||||
|
||||||||||||
func configureTLS(ctx context.Context, | ||||||||||||
tlsConnSource tlsConnectionSource, | ||||||||||||
nc net.Conn, | ||||||||||||
addr address.Address, | ||||||||||||
config *tls.Config, | ||||||||||||
ocspOpts *ocsp.VerifyOptions, | ||||||||||||
) (net.Conn, error) { | ||||||||||||
// Ensure config.ServerName is always set for SNI. | ||||||||||||
if config.ServerName == "" { | ||||||||||||
hostname := addr.String() | ||||||||||||
colonPos := strings.LastIndex(hostname, ":") | ||||||||||||
if colonPos == -1 { | ||||||||||||
colonPos = len(hostname) | ||||||||||||
} | ||||||||||||
|
||||||||||||
hostname = hostname[:colonPos] | ||||||||||||
config.ServerName = hostname | ||||||||||||
} | ||||||||||||
|
||||||||||||
client := tlsConnSource.Client(nc, config) | ||||||||||||
if err := clientHandshake(ctx, client); err != nil { | ||||||||||||
return nil, err | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Only do OCSP verification if TLS verification is requested. | ||||||||||||
if !config.InsecureSkipVerify { | ||||||||||||
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil { | ||||||||||||
return nil, ocspErr | ||||||||||||
} | ||||||||||||
} | ||||||||||||
return client, nil | ||||||||||||
} | ||||||||||||
|
||||||||||||
// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization | ||||||||||||
// handshakes. All errors returned by connect are considered "before the handshake completes" and | ||||||||||||
// must be handled by calling the appropriate SDAM handshake error handler. | ||||||||||||
|
@@ -317,6 +344,10 @@ func (c *connection) closeConnectContext() { | |||||||||||
} | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) cancellationListenerCallback() { | ||||||||||||
_ = c.close() | ||||||||||||
} | ||||||||||||
|
||||||||||||
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error { | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved |
||||||||||||
if originalError == nil { | ||||||||||||
return nil | ||||||||||||
|
@@ -339,10 +370,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead | |||||||||||
return originalError | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) cancellationListenerCallback() { | ||||||||||||
_ = c.close() | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error { | ||||||||||||
var err error | ||||||||||||
if atomic.LoadInt64(&c.state) != connConnected { | ||||||||||||
|
@@ -423,15 +450,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { | |||||||||||
|
||||||||||||
dst, errMsg, err := c.read(ctx) | ||||||||||||
if err != nil { | ||||||||||||
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) { | ||||||||||||
// If the error was a timeout error and CSOT is enabled, instead of | ||||||||||||
// closing the connection mark it as awaiting response so the pool | ||||||||||||
// can read the response before making it available to other | ||||||||||||
// operations. | ||||||||||||
c.awaitingResponse = true | ||||||||||||
} else { | ||||||||||||
// Otherwise, use the pre-CSOT behavior and close the connection | ||||||||||||
// because we don't know if there are other bytes left to read. | ||||||||||||
if c.awaitRemainingBytes == nil { | ||||||||||||
// If the connection was not marked as awaiting response, use the | ||||||||||||
// pre-CSOT behavior and close the connection because we don't know | ||||||||||||
// if there are other bytes left to read. | ||||||||||||
c.close() | ||||||||||||
} | ||||||||||||
message := errMsg | ||||||||||||
|
@@ -448,6 +470,29 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { | |||||||||||
return dst, nil | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) { | ||||||||||||
// read the length as an int32 | ||||||||||||
size := (int32(wmSizeBytes[0])) | | ||||||||||||
(int32(wmSizeBytes[1]) << 8) | | ||||||||||||
(int32(wmSizeBytes[2]) << 16) | | ||||||||||||
(int32(wmSizeBytes[3]) << 24) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optional: Use the
Suggested change
|
||||||||||||
|
||||||||||||
if size < 4 { | ||||||||||||
return 0, fmt.Errorf("malformed message length: %d", size) | ||||||||||||
} | ||||||||||||
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded | ||||||||||||
// defaultMaxMessageSize instead. | ||||||||||||
maxMessageSize := c.desc.MaxMessageSize | ||||||||||||
if maxMessageSize == 0 { | ||||||||||||
maxMessageSize = defaultMaxMessageSize | ||||||||||||
} | ||||||||||||
if uint32(size) > maxMessageSize { | ||||||||||||
return 0, errResponseTooLarge | ||||||||||||
} | ||||||||||||
|
||||||||||||
return size, nil | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) { | ||||||||||||
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback) | ||||||||||||
defer func() { | ||||||||||||
|
@@ -461,36 +506,43 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, | |||||||||||
} | ||||||||||||
}() | ||||||||||||
|
||||||||||||
needToWait := func(err error) bool { | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optional: Consider a more descriptive name, like |
||||||||||||
// If the error was a timeout error and CSOT is enabled, instead of | ||||||||||||
// closing the connection mark it as awaiting response so the pool | ||||||||||||
// can read the response before making it available to other | ||||||||||||
// operations. | ||||||||||||
nerr := net.Error(nil) | ||||||||||||
return errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) | ||||||||||||
} | ||||||||||||
|
||||||||||||
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to | ||||||||||||
// reslice dst once instead of twice. | ||||||||||||
var sizeBuf [4]byte | ||||||||||||
|
||||||||||||
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst | ||||||||||||
// because there might be more than one wire message waiting to be read, for example when | ||||||||||||
// reading messages from an exhaust cursor. | ||||||||||||
_, err = io.ReadFull(c.nc, sizeBuf[:]) | ||||||||||||
n, err := io.ReadFull(c.nc, sizeBuf[:]) | ||||||||||||
if err != nil { | ||||||||||||
if l := int32(n); l == 0 && needToWait(err) { | ||||||||||||
c.awaitRemainingBytes = &l | ||||||||||||
} | ||||||||||||
return nil, "incomplete read of message header", err | ||||||||||||
} | ||||||||||||
|
||||||||||||
// read the length as an int32 | ||||||||||||
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) | ||||||||||||
|
||||||||||||
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded | ||||||||||||
// defaultMaxMessageSize instead. | ||||||||||||
maxMessageSize := c.desc.MaxMessageSize | ||||||||||||
if maxMessageSize == 0 { | ||||||||||||
maxMessageSize = defaultMaxMessageSize | ||||||||||||
} | ||||||||||||
if uint32(size) > maxMessageSize { | ||||||||||||
return nil, errResponseTooLarge.Error(), errResponseTooLarge | ||||||||||||
size, err := c.parseWmSizeBytes(sizeBuf) | ||||||||||||
if err != nil { | ||||||||||||
return nil, err.Error(), err | ||||||||||||
} | ||||||||||||
|
||||||||||||
dst := make([]byte, size) | ||||||||||||
copy(dst, sizeBuf[:]) | ||||||||||||
|
||||||||||||
_, err = io.ReadFull(c.nc, dst[4:]) | ||||||||||||
n, err = io.ReadFull(c.nc, dst[4:]) | ||||||||||||
if err != nil { | ||||||||||||
remainingBytes := size - 4 - int32(n) | ||||||||||||
if remainingBytes > 0 && needToWait(err) { | ||||||||||||
c.awaitRemainingBytes = &remainingBytes | ||||||||||||
} | ||||||||||||
return dst, "incomplete read of full message", err | ||||||||||||
} | ||||||||||||
|
||||||||||||
|
@@ -537,10 +589,6 @@ func (c *connection) setCanStream(canStream bool) { | |||||||||||
c.canStream = canStream | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c initConnection) supportsStreaming() bool { | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Merged in |
||||||||||||
return c.canStream | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) setStreaming(streaming bool) { | ||||||||||||
c.currentlyStreaming = streaming | ||||||||||||
} | ||||||||||||
|
@@ -554,6 +602,12 @@ func (c *connection) setSocketTimeout(timeout time.Duration) { | |||||||||||
c.writeTimeout = timeout | ||||||||||||
} | ||||||||||||
|
||||||||||||
// DriverConnectionID returns the driver connection ID. | ||||||||||||
// TODO(GODRIVER-2824): change return type to int64. | ||||||||||||
func (c *connection) DriverConnectionID() uint64 { | ||||||||||||
return c.driverConnectionID | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) ID() string { | ||||||||||||
return c.id | ||||||||||||
} | ||||||||||||
|
@@ -562,6 +616,14 @@ func (c *connection) ServerConnectionID() *int64 { | |||||||||||
return c.serverConnectionID | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) OIDCTokenGenID() uint64 { | ||||||||||||
return c.oidcTokenGenID | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) SetOIDCTokenGenID(genID uint64) { | ||||||||||||
c.oidcTokenGenID = genID | ||||||||||||
} | ||||||||||||
|
||||||||||||
// initConnection is an adapter used during connection initialization. It has the minimum | ||||||||||||
// functionality necessary to implement the driver.Connection interface, which is required to pass a | ||||||||||||
// *connection to a Handshaker. | ||||||||||||
|
@@ -599,7 +661,7 @@ func (c initConnection) CurrentlyStreaming() bool { | |||||||||||
return c.getCurrentlyStreaming() | ||||||||||||
} | ||||||||||||
func (c initConnection) SupportsStreaming() bool { | ||||||||||||
return c.supportsStreaming() | ||||||||||||
return c.canStream | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Connection implements the driver.Connection interface to allow reading and writing wire | ||||||||||||
|
@@ -833,39 +895,6 @@ func (c *Connection) DriverConnectionID() uint64 { | |||||||||||
return c.connection.DriverConnectionID() | ||||||||||||
} | ||||||||||||
|
||||||||||||
func configureTLS(ctx context.Context, | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved closer to the caller. |
||||||||||||
tlsConnSource tlsConnectionSource, | ||||||||||||
nc net.Conn, | ||||||||||||
addr address.Address, | ||||||||||||
config *tls.Config, | ||||||||||||
ocspOpts *ocsp.VerifyOptions, | ||||||||||||
) (net.Conn, error) { | ||||||||||||
// Ensure config.ServerName is always set for SNI. | ||||||||||||
if config.ServerName == "" { | ||||||||||||
hostname := addr.String() | ||||||||||||
colonPos := strings.LastIndex(hostname, ":") | ||||||||||||
if colonPos == -1 { | ||||||||||||
colonPos = len(hostname) | ||||||||||||
} | ||||||||||||
|
||||||||||||
hostname = hostname[:colonPos] | ||||||||||||
config.ServerName = hostname | ||||||||||||
} | ||||||||||||
|
||||||||||||
client := tlsConnSource.Client(nc, config) | ||||||||||||
if err := clientHandshake(ctx, client); err != nil { | ||||||||||||
return nil, err | ||||||||||||
} | ||||||||||||
|
||||||||||||
// Only do OCSP verification if TLS verification is requested. | ||||||||||||
if !config.InsecureSkipVerify { | ||||||||||||
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil { | ||||||||||||
return nil, ocspErr | ||||||||||||
} | ||||||||||||
} | ||||||||||||
return client, nil | ||||||||||||
} | ||||||||||||
|
||||||||||||
// OIDCTokenGenID returns the OIDC token generation ID. | ||||||||||||
func (c *Connection) OIDCTokenGenID() uint64 { | ||||||||||||
return c.oidcTokenGenID | ||||||||||||
|
@@ -919,11 +948,3 @@ func (c *cancellListener) StopListening() bool { | |||||||||||
c.done <- struct{}{} | ||||||||||||
return c.aborted | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) OIDCTokenGenID() uint64 { | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moved closer with other |
||||||||||||
return c.oidcTokenGenID | ||||||||||||
} | ||||||||||||
|
||||||||||||
func (c *connection) SetOIDCTokenGenID(genID uint64) { | ||||||||||||
c.oidcTokenGenID = genID | ||||||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved
DriverConnectionID
down with other public methods.