-
Notifications
You must be signed in to change notification settings - Fork 892
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GODRIVER-3302 Handle malformatted message length properly. #1758
Changes from 3 commits
93d8ee1
be66ac6
ddbd3e9
b565a23
2d988ed
9bea0eb
03aa027
39f3021
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,9 +79,9 @@ type connection struct { | |
driverConnectionID uint64 | ||
generation uint64 | ||
|
||
// awaitingResponse indicates that the server response was not completely | ||
// awaitingResponse indicates the size of server response that was not completely | ||
// read before returning the connection to the pool. | ||
awaitingResponse bool | ||
awaitingResponse *int32 | ||
|
||
// oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate | ||
// accessTokens in the OIDC authenticator cache. | ||
|
@@ -423,15 +423,10 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { | |
|
||
dst, errMsg, err := c.read(ctx) | ||
if err != nil { | ||
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) { | ||
// If the error was a timeout error and CSOT is enabled, instead of | ||
// closing the connection mark it as awaiting response so the pool | ||
// can read the response before making it available to other | ||
// operations. | ||
c.awaitingResponse = true | ||
} else { | ||
// Otherwise, use the pre-CSOT behavior and close the connection | ||
// because we don't know if there are other bytes left to read. | ||
if c.awaitingResponse == nil { | ||
// If the connection was not marked as awaiting response, use the | ||
// pre-CSOT behavior and close the connection because we don't know | ||
// if there are other bytes left to read. | ||
c.close() | ||
} | ||
message := errMsg | ||
|
@@ -461,21 +456,37 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, | |
} | ||
}() | ||
|
||
needToWait := func(err error) bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optional: Consider a more descriptive name, like |
||
// If the error was a timeout error and CSOT is enabled, instead of | ||
// closing the connection mark it as awaiting response so the pool | ||
// can read the response before making it available to other | ||
// operations. | ||
nerr := net.Error(nil) | ||
return errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) | ||
} | ||
|
||
// We use an array here because it only costs 4 bytes on the stack and means we'll only need to | ||
// reslice dst once instead of twice. | ||
var sizeBuf [4]byte | ||
|
||
// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst | ||
// because there might be more than one wire message waiting to be read, for example when | ||
// reading messages from an exhaust cursor. | ||
_, err = io.ReadFull(c.nc, sizeBuf[:]) | ||
n, err := io.ReadFull(c.nc, sizeBuf[:]) | ||
if err != nil { | ||
if l := int32(n); l == 0 && needToWait(err) { | ||
c.awaitingResponse = &l | ||
} | ||
return nil, "incomplete read of message header", err | ||
} | ||
|
||
// read the length as an int32 | ||
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) | ||
|
||
if size < 4 { | ||
err = fmt.Errorf("malformed message length: %d", size) | ||
return nil, err.Error(), err | ||
} | ||
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded | ||
// defaultMaxMessageSize instead. | ||
maxMessageSize := c.desc.MaxMessageSize | ||
|
@@ -489,8 +500,11 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, | |
dst := make([]byte, size) | ||
copy(dst, sizeBuf[:]) | ||
|
||
_, err = io.ReadFull(c.nc, dst[4:]) | ||
n, err = io.ReadFull(c.nc, dst[4:]) | ||
if err != nil { | ||
if l := size - 4 - int32(n); l > 0 && needToWait(err) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggest renaming |
||
c.awaitingResponse = &l | ||
} | ||
return dst, "incomplete read of full message", err | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -9,6 +9,8 @@ package topology | |||||
import ( | ||||||
"context" | ||||||
"fmt" | ||||||
"io" | ||||||
"io/ioutil" | ||||||
"net" | ||||||
"sync" | ||||||
"sync/atomic" | ||||||
|
@@ -788,17 +790,27 @@ var ( | |||||
// | ||||||
// It calls the package-global BGReadCallback function, if set, with the | ||||||
// address, timings, and any errors that occurred. | ||||||
func bgRead(pool *pool, conn *connection) { | ||||||
var start, read time.Time | ||||||
start = time.Now() | ||||||
errs := make([]error, 0) | ||||||
connClosed := false | ||||||
func bgRead(pool *pool, conn *connection, size int32) { | ||||||
var err error | ||||||
start := time.Now() | ||||||
|
||||||
defer func() { | ||||||
read := time.Now() | ||||||
errs := make([]error, 0) | ||||||
connClosed := false | ||||||
if err != nil { | ||||||
errs = append(errs, err) | ||||||
connClosed = true | ||||||
err = conn.close() | ||||||
if err != nil { | ||||||
errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err)) | ||||||
} | ||||||
} | ||||||
|
||||||
// No matter what happens, always check the connection back into the | ||||||
// pool, which will either make it available for other operations or | ||||||
// remove it from the pool if it was closed. | ||||||
err := pool.checkInNoEvent(conn) | ||||||
err = pool.checkInNoEvent(conn) | ||||||
if err != nil { | ||||||
errs = append(errs, fmt.Errorf("error checking in: %w", err)) | ||||||
} | ||||||
|
@@ -808,34 +820,37 @@ func bgRead(pool *pool, conn *connection) { | |||||
} | ||||||
}() | ||||||
|
||||||
err := conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout)) | ||||||
err = conn.nc.SetReadDeadline(time.Now().Add(BGReadTimeout)) | ||||||
if err != nil { | ||||||
errs = append(errs, fmt.Errorf("error setting a read deadline: %w", err)) | ||||||
|
||||||
connClosed = true | ||||||
err := conn.close() | ||||||
if err != nil { | ||||||
errs = append(errs, fmt.Errorf("error closing conn after setting read deadline: %w", err)) | ||||||
} | ||||||
|
||||||
err = fmt.Errorf("error setting a read deadline: %w", err) | ||||||
return | ||||||
} | ||||||
|
||||||
// The context here is only used for cancellation, not deadline timeout, so | ||||||
// use context.Background(). The read timeout is set by calling | ||||||
// SetReadDeadline above. | ||||||
_, _, err = conn.read(context.Background()) | ||||||
read = time.Now() | ||||||
if err != nil { | ||||||
errs = append(errs, fmt.Errorf("error reading: %w", err)) | ||||||
|
||||||
connClosed = true | ||||||
err := conn.close() | ||||||
if size == 0 { | ||||||
var sizeBuf [4]byte | ||||||
_, err = io.ReadFull(conn.nc, sizeBuf[:]) | ||||||
Comment on lines
+829
to
+830
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggest making this logic a function that can be used here and in the connection read method: func readWMSize(r io.Reader) (int32, error) {
const wireMessageSizePrefix = 4
var wmSizeBytes [wireMessageSizePrefix]byte
if _, err := io.ReadFull(r, wmSizeBytes[:]); err != nil {
return 0, fmt.Errorf("error reading the message size: %w", err)
}
size := (int32(wmSizeBytes[0])) |
(int32(wmSizeBytes[1]) << 8) |
(int32(wmSizeBytes[2]) << 16) |
(int32(wmSizeBytes[3]) << 24)
if size < 4 {
return 0, fmt.Errorf("malformed message length: %d", size)
}
return size, nil
} |
||||||
if err != nil { | ||||||
errs = append(errs, fmt.Errorf("error closing conn after reading: %w", err)) | ||||||
err = fmt.Errorf("error reading the message size: %w", err) | ||||||
return | ||||||
} | ||||||
|
||||||
return | ||||||
size = (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24) | ||||||
if size < 4 { | ||||||
err = fmt.Errorf("malformed message length: %d", size) | ||||||
return | ||||||
} | ||||||
maxMessageSize := conn.desc.MaxMessageSize | ||||||
if maxMessageSize == 0 { | ||||||
maxMessageSize = defaultMaxMessageSize | ||||||
} | ||||||
if uint32(size) > maxMessageSize { | ||||||
err = errResponseTooLarge | ||||||
return | ||||||
} | ||||||
size -= 4 | ||||||
} | ||||||
_, err = io.CopyN(ioutil.Discard, conn.nc, int64(size)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
if err != nil { | ||||||
err = fmt.Errorf("error reading message of %d: %w", size, err) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Optional: This error message is a bit confusing. Consider a clearer error message.
Suggested change
|
||||||
} | ||||||
} | ||||||
|
||||||
|
@@ -886,9 +901,10 @@ func (p *pool) checkInNoEvent(conn *connection) error { | |||||
// means that connections in "awaiting response" state are checked in but | ||||||
// not usable, which is not covered by the current pool events. We may need | ||||||
// to add pool event information in the future to communicate that. | ||||||
if conn.awaitingResponse { | ||||||
conn.awaitingResponse = false | ||||||
go bgRead(p, conn) | ||||||
if conn.awaitingResponse != nil { | ||||||
size := *conn.awaitingResponse | ||||||
conn.awaitingResponse = nil | ||||||
go bgRead(p, conn, size) | ||||||
return nil | ||||||
} | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Suggest renaming
awaitingResponse
toawaitRemainingBytes
.