From c5ef8c99685377c418e1dece5f779c62dcae04e2 Mon Sep 17 00:00:00 2001 From: "jeremy.spriet" Date: Tue, 19 Dec 2023 18:19:43 +0100 Subject: [PATCH] feat(pgproto3/backend): add a SetMaxBodyLen to limit the max body length for the receive --- pgproto3/backend.go | 13 +++++++++++++ pgproto3/backend_test.go | 18 ++++++++++++++++++ pgproto3/pgproto3.go | 9 +++++++++ 3 files changed, 40 insertions(+) diff --git a/pgproto3/backend.go b/pgproto3/backend.go index 6db77e4a2..efa909c3a 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -38,6 +38,7 @@ type Backend struct { terminate Terminate bodyLen int + maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. msgType byte partialMsg bool authType uint32 @@ -158,6 +159,9 @@ func (b *Backend) Receive() (FrontendMessage, error) { b.msgType = header[0] b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen { + return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen} + } b.partialMsg = true } @@ -260,3 +264,12 @@ func (b *Backend) SetAuthType(authType uint32) error { return nil } + +// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return +// an error. This is useful for protecting against malicious clients that send large messages with the intent of +// causing memory exhaustion. +// The default value is 0. +// If maxBodyLen is 0, then no maximum is enforced. +func (b *Backend) SetMaxBodyLen(maxBodyLen int) { + b.maxBodyLen = maxBodyLen +} diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index 596245ddf..5655122a8 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -120,3 +120,21 @@ func TestStartupMessage(t *testing.T) { } }) } + +func TestBackendReceiveExceededMaxBodyLen(t *testing.T) { + t.Parallel() + + server := &interruptReader{} + server.push([]byte{'Q', 0, 0, 10, 10}) + + backend := pgproto3.NewBackend(server, nil) + + // Set max body len to 5 + backend.SetMaxBodyLen(5) + + // Receive regular msg + msg, err := backend.Receive() + assert.Nil(t, msg) + var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr + assert.ErrorAs(t, err, &invalidBodyLenErr) +} diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index ef5a54896..04be291cb 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -70,6 +70,15 @@ func (e *writeError) Unwrap() error { return e.err } +type ExceededMaxBodyLenErr struct { + maxExpectedBodyLen int + actualBodyLen int +} + +func (e *ExceededMaxBodyLenErr) Error() string { + return fmt.Sprintf("invalid body length: expected at most %d, but got %d", e.maxExpectedBodyLen, e.actualBodyLen) +} + // getValueFromJSON gets the value from a protocol message representation in JSON. func getValueFromJSON(v map[string]string) ([]byte, error) { if v == nil {