Skip to content

Commit

Permalink
feedback and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
jimjbrettj committed Oct 15, 2024
1 parent 38ae1e9 commit 9b839a0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 62 deletions.
5 changes: 3 additions & 2 deletions lib/grandpa/message_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ func (h *MessageHandler) handleMessage(from peer.ID, m GrandpaMessage) (network.
}
}

func (h *MessageHandler) handleNeighbourMessage(packet *NeighbourPacketV1, from peer.ID) error {
logger.Debugf("handling neighbor message from peer %v with set id %v and round %v", from.ShortString(), packet.SetID, packet.Round)
func (h *MessageHandler) handleNeighbourMessage(packet *NeighbourPacketV1, from peer.ID) error { //nolint
logger.Debugf("handling neighbour message from peer %v with set id %v and round %v",
from.ShortString(), packet.SetID, packet.Round)
h.grandpa.neighborMsgChan <- neighborData{
peer: from,
neighborMsg: packet,
Expand Down
47 changes: 21 additions & 26 deletions lib/grandpa/neighbor_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package grandpa

import (
"fmt"
"time"

"github.com/ChainSafe/gossamer/dot/types"
"github.com/libp2p/go-libp2p/core/peer"
"time"
)

// NeighborBroadcastPeriod See https://github.com/paritytech/polkadot-sdk/blob/08498f5473351c3d2f8eacbe1bfd7bc6d3a2ef8d/substrate/client/consensus/grandpa/src/communication/mod.rs#L73
const NeighborBroadcastPeriod = time.Minute * 2
// https://github.com/paritytech/polkadot-sdk/blob/08498f5473351c3d2f8eacbe1bfd7bc6d3a2ef8d/substrate/client/consensus/grandpa/src/communication/mod.rs#L73 //nolint
const neighbourBroadcastPeriod = time.Minute * 2

type neighborData struct {
peer peer.ID
Expand Down Expand Up @@ -54,59 +55,54 @@ func (nt *NeighborTracker) Stop() {
}

func (nt *NeighborTracker) run() {
logger.Info("starting neighbor tracker")
ticker := time.NewTicker(NeighborBroadcastPeriod)
logger.Info("starting neighbour tracker")
ticker := time.NewTicker(neighbourBroadcastPeriod)
defer ticker.Stop()

for {
select {
case <-ticker.C:
logger.Debugf("neighbor message broadcast triggered by ticker")
logger.Debugf("neighbour message broadcast triggered by ticker")
err := nt.BroadcastNeighborMsg()
if err != nil {
logger.Errorf("broadcasting neighbor message: %v", err)
logger.Errorf("broadcasting neighbour message: %v", err)
}

case block := <-nt.finalizationCha:
if block != nil {
err := nt.UpdateState(block.SetID, block.Round, uint32(block.Header.Number))
if err != nil {
logger.Errorf("updating neighbor state: %v", err)
}
err = nt.BroadcastNeighborMsg()
nt.UpdateState(block.SetID, block.Round, uint32(block.Header.Number)) //nolint
err := nt.BroadcastNeighborMsg()
if err != nil {
logger.Errorf("broadcasting neighbor message: %v", err)
logger.Errorf("broadcasting neighbour message: %v", err)
}
ticker.Reset(NeighborBroadcastPeriod)
ticker.Reset(neighbourBroadcastPeriod)
}
case neighborData := <-nt.neighborMsgChan:
if neighborData.neighborMsg.Number > nt.peerview[neighborData.peer].highestFinalized {
err := nt.UpdatePeer(neighborData.peer, neighborData.neighborMsg.SetID, neighborData.neighborMsg.Round, neighborData.neighborMsg.Number)
err := nt.UpdatePeer(
neighborData.peer,
neighborData.neighborMsg.SetID,
neighborData.neighborMsg.Round,
neighborData.neighborMsg.Number,
)
if err != nil {
logger.Errorf("updating neighbor: %v", err)
logger.Errorf("updating neighbour: %v", err)
}
}
case <-nt.stoppedNeighbor:
logger.Info("stopping neighbor tracker")
logger.Info("stopping neighbour tracker")
return
}
}
}

func (nt *NeighborTracker) UpdateState(setID uint64, round uint64, highestFinalized uint32) error {
if nt == nil {
return fmt.Errorf("neighbor tracker is nil")
}
func (nt *NeighborTracker) UpdateState(setID uint64, round uint64, highestFinalized uint32) {
nt.currentSetID = setID
nt.currentRound = round
nt.highestFinalized = highestFinalized
return nil
}

func (nt *NeighborTracker) UpdatePeer(p peer.ID, setID uint64, round uint64, highestFinalized uint32) error {
if nt == nil {
return fmt.Errorf("neighbor tracker is nil")
}
if nt.peerview == nil {
return fmt.Errorf("neighbour tracker has nil peer tracker")
}
Expand All @@ -116,7 +112,6 @@ func (nt *NeighborTracker) UpdatePeer(p peer.ID, setID uint64, round uint64, hig
}

func (nt *NeighborTracker) BroadcastNeighborMsg() error {
logger.Warnf("braodcasting neighbor message to relevant peers")
packet := NeighbourPacketV1{
Round: nt.currentRound,
SetID: nt.currentSetID,
Expand Down
44 changes: 10 additions & 34 deletions lib/grandpa/neighbor_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ package grandpa

import (
"fmt"
"testing"

"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
"testing"
)

func TestNeighborTracker_UpdatePeer(t *testing.T) {
Expand Down Expand Up @@ -72,25 +73,13 @@ func TestNeighborTracker_UpdatePeer(t *testing.T) {
highestFinalized: 6,
},
},
{
name: "nil tracker",
args: args{
p: "testPeer",
setID: 1,
round: 2,
highestFinalized: 3,
},
expectedErr: fmt.Errorf("neighbor tracker is nil"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nt := tt.tracker
err := nt.UpdatePeer(tt.args.p, tt.args.setID, tt.args.round, tt.args.highestFinalized)
require.Equal(t, err, tt.expectedErr)
if nt != nil {
require.Equal(t, tt.expectedState, nt.peerview[tt.args.p])
}
require.Equal(t, tt.expectedState, nt.peerview[tt.args.p])
})
}
}
Expand All @@ -102,20 +91,10 @@ func TestNeighborTracker_UpdateState(t *testing.T) {
highestFinalized uint32
}
tests := []struct {
name string
tracker *NeighborTracker
args args
expectedErr error
name string
tracker *NeighborTracker
args args
}{
{
name: "nil tracker",
args: args{
setID: 1,
round: 2,
highestFinalized: 3,
},
expectedErr: fmt.Errorf("neighbor tracker is nil"),
},
{
name: "happy path",
tracker: &NeighborTracker{},
Expand All @@ -129,13 +108,10 @@ func TestNeighborTracker_UpdateState(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nt := tt.tracker
err := nt.UpdateState(tt.args.setID, tt.args.round, tt.args.highestFinalized)
require.Equal(t, err, tt.expectedErr)
if nt != nil {
require.Equal(t, nt.currentSetID, tt.args.setID)
require.Equal(t, nt.currentRound, tt.args.round)
require.Equal(t, nt.highestFinalized, tt.args.highestFinalized)
}
nt.UpdateState(tt.args.setID, tt.args.round, tt.args.highestFinalized)
require.Equal(t, nt.currentSetID, tt.args.setID)
require.Equal(t, nt.currentRound, tt.args.round)
require.Equal(t, nt.highestFinalized, tt.args.highestFinalized)
})
}
}

0 comments on commit 9b839a0

Please sign in to comment.