Skip to content

Commit

Permalink
fix race conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
jimjbrettj committed Oct 17, 2024
1 parent eb08a05 commit 4f39d8f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
21 changes: 17 additions & 4 deletions lib/grandpa/neighbor_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package grandpa

import (
"fmt"
"sync"
"time"

"github.com/ChainSafe/gossamer/dot/types"
Expand All @@ -26,6 +27,7 @@ type neighborState struct {
}

type neighborTracker struct {
sync.Mutex
grandpa *Service

peerview map[peer.ID]neighborState
Expand Down Expand Up @@ -73,7 +75,7 @@ func (nt *neighborTracker) run() {

case block := <-nt.finalizationCha:
if block != nil {
nt.UpdateState(block.SetID, block.Round, uint32(block.Header.Number)) //nolint
nt.updateState(block.SetID, block.Round, uint32(block.Header.Number)) //nolint
err := nt.BroadcastNeighborMsg()
if err != nil {
logger.Errorf("broadcasting neighbour message: %v", err)
Expand All @@ -82,7 +84,7 @@ func (nt *neighborTracker) run() {
}
case neighborData := <-nt.neighborMsgChan:
if neighborData.neighborMsg.Number > nt.peerview[neighborData.peer].highestFinalized {
nt.UpdatePeer(
nt.updatePeer(
neighborData.peer,
neighborData.neighborMsg.SetID,
neighborData.neighborMsg.Round,
Expand All @@ -96,17 +98,28 @@ func (nt *neighborTracker) run() {
}
}

func (nt *neighborTracker) UpdateState(setID uint64, round uint64, highestFinalized uint32) {
func (nt *neighborTracker) updateState(setID uint64, round uint64, highestFinalized uint32) {
nt.Lock()
defer nt.Unlock()

nt.currentSetID = setID
nt.currentRound = round
nt.highestFinalized = highestFinalized
}

func (nt *neighborTracker) UpdatePeer(p peer.ID, setID uint64, round uint64, highestFinalized uint32) {
func (nt *neighborTracker) updatePeer(p peer.ID, setID uint64, round uint64, highestFinalized uint32) {
nt.Lock()
defer nt.Unlock()
peerState := neighborState{setID, round, highestFinalized}
nt.peerview[p] = peerState
}

func (nt *neighborTracker) getPeer(p peer.ID) neighborState {
nt.Lock()
defer nt.Unlock()
return nt.peerview[p]
}

func (nt *neighborTracker) BroadcastNeighborMsg() error {
packet := NeighbourPacketV1{
Round: nt.currentRound,
Expand Down
11 changes: 6 additions & 5 deletions lib/grandpa/neighbor_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestNeighbourTracker_UpdatePeer(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nt := tt.tracker
nt.UpdatePeer(tt.args.p, tt.args.setID, tt.args.round, tt.args.highestFinalized)
nt.updatePeer(tt.args.p, tt.args.setID, tt.args.round, tt.args.highestFinalized)
require.Equal(t, tt.expectedState, nt.peerview[tt.args.p])
})
}
Expand Down Expand Up @@ -101,7 +101,7 @@ func TestNeighbourTracker_UpdateState(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
nt := tt.tracker
nt.UpdateState(tt.args.setID, tt.args.round, tt.args.highestFinalized)
nt.updateState(tt.args.setID, tt.args.round, tt.args.highestFinalized)
require.Equal(t, nt.currentSetID, tt.args.setID)
require.Equal(t, nt.currentRound, tt.args.round)
require.Equal(t, nt.highestFinalized, tt.args.highestFinalized)
Expand Down Expand Up @@ -235,9 +235,10 @@ func TestNeighbourTracker_UpdatePeer_viaChannel(t *testing.T) {

time.Sleep(100 * time.Millisecond)

require.Equal(t, uint64(5), nt.peerview["testPeer"].round)
require.Equal(t, uint64(6), nt.peerview["testPeer"].setID)
require.Equal(t, uint32(7), nt.peerview["testPeer"].highestFinalized)
testPeer := nt.getPeer("testPeer")
require.Equal(t, uint64(5), testPeer.round)
require.Equal(t, uint64(6), testPeer.setID)
require.Equal(t, uint32(7), testPeer.highestFinalized)

nt.Stop()
}

0 comments on commit 4f39d8f

Please sign in to comment.