Skip to content

Commit

Permalink
add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jimjbrettj committed Oct 8, 2024
1 parent e7c7a03 commit 7c17435
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 5 deletions.
27 changes: 22 additions & 5 deletions lib/grandpa/neighbor_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,22 @@ func (nt *NeighborTracker) run() {

case block := <-nt.finalizationCha:
if block != nil {
nt.UpdateState(block.SetID, block.Round, uint32(block.Header.Number))
err := nt.BroadcastNeighborMsg()
err := nt.UpdateState(block.SetID, block.Round, uint32(block.Header.Number))
if err != nil {
logger.Errorf("updating neighbor state: %v", err)
}
err = nt.BroadcastNeighborMsg()
if err != nil {
logger.Errorf("broadcasting neighbor message: %v", err)
}
ticker.Reset(duration)
}
case neighborData := <-nt.neighborMsgChan:
if neighborData.neighborMsg.Number > nt.peerview[neighborData.peer].highestFinalized {
nt.UpdatePeer(neighborData.peer, neighborData.neighborMsg.SetID, neighborData.neighborMsg.Round, neighborData.neighborMsg.Number)
err := nt.UpdatePeer(neighborData.peer, neighborData.neighborMsg.SetID, neighborData.neighborMsg.Round, neighborData.neighborMsg.Number)
if err != nil {
logger.Errorf("updating neighbor: %v", err)
}
}
case <-nt.stoppedNeighbor:
logger.Info("stopping neighbor tracker")
Expand All @@ -88,15 +94,26 @@ func (nt *NeighborTracker) run() {
}
}

func (nt *NeighborTracker) UpdateState(setID uint64, round uint64, highestFinalized uint32) {
func (nt *NeighborTracker) UpdateState(setID uint64, round uint64, highestFinalized uint32) error {
if nt == nil {
return fmt.Errorf("neighbor tracker is nil")
}
nt.currentSetID = setID
nt.currentRound = round
nt.highestFinalized = highestFinalized
return nil
}

func (nt *NeighborTracker) UpdatePeer(p peer.ID, setID uint64, round uint64, highestFinalized uint32) {
func (nt *NeighborTracker) UpdatePeer(p peer.ID, setID uint64, round uint64, highestFinalized uint32) error {
if nt == nil {
return fmt.Errorf("neighbor tracker is nil")
}
if nt.peerview == nil {
return fmt.Errorf("neighbour tracker has nil peer tracker")
}
peerState := neighborState{setID, round, highestFinalized}
nt.peerview[p] = peerState
return nil
}

func (nt *NeighborTracker) BroadcastNeighborMsg() error {
Expand Down
141 changes: 141 additions & 0 deletions lib/grandpa/neighbor_tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package grandpa

import (
"fmt"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
"testing"
)

func TestNeighborTracker_UpdatePeer(t *testing.T) {
initPeerview := map[peer.ID]neighborState{}
initPeerview["testPeer"] = neighborState{
setID: 1,
round: 2,
highestFinalized: 3,
}
type args struct {
p peer.ID
setID uint64
round uint64
highestFinalized uint32
}
tests := []struct {
name string
tracker *NeighborTracker
args args
expectedState neighborState
expectedErr error
}{
{
name: "simple update",
tracker: &NeighborTracker{
peerview: map[peer.ID]neighborState{},
},
args: args{
p: "testPeer",
setID: 1,
round: 2,
highestFinalized: 3,
},
expectedState: neighborState{
setID: 1,
round: 2,
highestFinalized: 3,
},
},
{
name: "nil peerview",
tracker: &NeighborTracker{},
args: args{
p: "testPeer",
setID: 1,
round: 2,
highestFinalized: 3,
},
expectedErr: fmt.Errorf("neighbour tracker has nil peer tracker"),
},
{
name: "updating existing peer",
tracker: &NeighborTracker{
peerview: map[peer.ID]neighborState{},
},
args: args{
p: "testPeer",
setID: 4,
round: 5,
highestFinalized: 6,
},
expectedState: neighborState{
setID: 4,
round: 5,
highestFinalized: 6,
},
},
{
name: "nil tracker",
args: args{
p: "testPeer",
setID: 1,
round: 2,
highestFinalized: 3,
},
expectedErr: fmt.Errorf("neighbor tracker is nil"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nt := tt.tracker
err := nt.UpdatePeer(tt.args.p, tt.args.setID, tt.args.round, tt.args.highestFinalized)
require.Equal(t, err, tt.expectedErr)
if nt != nil {
require.Equal(t, tt.expectedState, nt.peerview[tt.args.p])
}
})
}
}

func TestNeighborTracker_UpdateState(t *testing.T) {
type args struct {
setID uint64
round uint64
highestFinalized uint32
}
tests := []struct {
name string
tracker *NeighborTracker
args args
expectedErr error
}{
{
name: "nil tracker",
args: args{
setID: 1,
round: 2,
highestFinalized: 3,
},
expectedErr: fmt.Errorf("neighbor tracker is nil"),
},
{
name: "happy path",
tracker: &NeighborTracker{},
args: args{
setID: 1,
round: 2,
highestFinalized: 3,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nt := tt.tracker
err := nt.UpdateState(tt.args.setID, tt.args.round, tt.args.highestFinalized)
require.Equal(t, err, tt.expectedErr)
if nt != nil {
require.Equal(t, nt.currentSetID, tt.args.setID)
require.Equal(t, nt.currentRound, tt.args.round)
require.Equal(t, nt.highestFinalized, tt.args.highestFinalized)
}
})
}
}

0 comments on commit 7c17435

Please sign in to comment.