Skip to content

Commit

Permalink
Merge pull request #588 from matheusd/some-fixes
Browse files Browse the repository at this point in the history
Some test and general fixes
  • Loading branch information
lthibault authored Aug 20, 2024
2 parents c7c3a76 + cf2a39a commit 396906c
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 6 deletions.
7 changes: 7 additions & 0 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ func MustUnmarshalRoot(data []byte) Ptr {
return p
}

var (
errTooManySegments = errors.New("message has too many segments")
)

// An Encoder represents a framer for serializing a particular Cap'n
// Proto stream.
type Encoder struct {
Expand All @@ -220,6 +224,9 @@ func (e *Encoder) Encode(m *Message) error {
if nsegs == 0 {
return errors.New("encode: message has no segments")
}
if nsegs > 1<<32 {
return exc.WrapError("encode", errTooManySegments)
}
e.bufs = append(e.bufs[:0], nil) // first element is placeholder for header
maxSeg := SegmentID(nsegs - 1)
hdrSize := streamHeaderSize(maxSeg)
Expand Down
37 changes: 37 additions & 0 deletions codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package capnp

import (
"bytes"
"errors"
"io"
"testing"

"github.com/stretchr/testify/require"
)

func TestEncoder(t *testing.T) {
Expand Down Expand Up @@ -72,6 +75,40 @@ func TestDecoder(t *testing.T) {
}
}

type tooManySegsArena struct {
data []byte
}

func (t *tooManySegsArena) NumSegments() int64 { return 1<<32 + 1 }

func (t *tooManySegsArena) Data(id SegmentID) ([]byte, error) {
return nil, errors.New("no data")
}

func (t *tooManySegsArena) Allocate(minsz Size, segs map[SegmentID]*Segment) (SegmentID, []byte, error) {
return 0, nil, errors.New("cannot allocate")
}

func (t *tooManySegsArena) Release() {}

// TestEncoderTooManySegments verifies attempting to encode an arena that has
// more segments than possible.
func TestEncoderTooManySegments(t *testing.T) {
t.Parallel()
zeroWord := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
arena := &tooManySegsArena{data: zeroWord}

// Setup via field because NewMessage checks arena has > 1 segments.
var msg Message
msg.Arena = arena
var buf bytes.Buffer
enc := NewEncoder(&buf)
err := enc.Encode(&msg)

// Encoding should error with a specific error.
require.ErrorIs(t, err, errTooManySegments)
}

func TestDecoder_MaxMessageSize(t *testing.T) {
t.Parallel()

Expand Down
2 changes: 1 addition & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1770,7 +1770,7 @@ func BenchmarkMarshal_ReuseMsg(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
a := data[r.Intn(len(data))]
seg, err := msg.Reset(msg.Arena)
seg, err := msg.Reset(capnp.SingleSegment(nil))
if err != nil {
b.Fatal(err)
}
Expand Down
2 changes: 1 addition & 1 deletion message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -646,11 +646,11 @@ var errReadOnlyArena = errors.New("Allocate called on read-only arena")

func BenchmarkMessageGetFirstSegment(b *testing.B) {
var msg Message
var arena Arena = SingleSegment(nil)

b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
arena := SingleSegment(nil)
_, err := msg.Reset(arena)
if err != nil {
b.Fatal(err)
Expand Down
10 changes: 8 additions & 2 deletions rpc/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,10 @@ type outgoingMsg struct {
}

func (o *outgoingMsg) Release() {
if m := o.message.Message(); !o.released && m != nil {
if o.released {
return
}
if m := o.message.Message(); m != nil {
o.released = true
m.Release()
}
Expand Down Expand Up @@ -246,7 +249,10 @@ func (i *incomingMsg) Message() rpccp.Message {
}

func (i *incomingMsg) Release() {
if m := i.Message().Message(); !i.released && m != nil {
if i.released {
return
}
if m := i.Message().Message(); m != nil {
i.released = true
m.Release()
}
Expand Down
2 changes: 0 additions & 2 deletions rpc/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,10 @@ func testTransport(t *testing.T, makePipe func() (t1, t2 Transport, err error))
if err != nil {
t.Fatal("t1.NewMessage #1:", err)
}
defer callMsg.Release()
bootMsg, err := t1.NewMessage()
if err != nil {
t.Fatal("t1.NewMessage #2:", err)
}
defer bootMsg.Release()

// Fill in bootstrap message
boot, err := bootMsg.Message().NewBootstrap()
Expand Down

0 comments on commit 396906c

Please sign in to comment.