Skip to content

Commit

Permalink
consortium-v2/snapshot: make FindAncientHeader more readable
Browse files Browse the repository at this point in the history
This commit refactors FindAncientHeader, changes its name to
findAncestorHeader, adds some comments and unit test to make the code more
readable.
  • Loading branch information
minh-bq committed Nov 14, 2024
1 parent 8e06b07 commit 3212b56
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 30 deletions.
83 changes: 53 additions & 30 deletions consensus/consortium/v2/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
blsCommon "github.com/ethereum/go-ethereum/crypto/bls/common"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/internal/ethapi"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/params"
"github.com/hashicorp/golang-lru/arc/v2"
)
Expand Down Expand Up @@ -243,7 +244,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainHeaderRea
// Change the validator set base on the size of the validators set
if number > 0 && number%s.config.EpochV2 == uint64(len(snap.validators())/2) {
// Get the most recent checkpoint header
checkpointHeader := FindAncientHeader(header, uint64(len(snap.validators())/2), chain, parents)
checkpointHeader := findAncestorHeader(number-uint64(len(snap.validators())/2), chain, parents, header)
if checkpointHeader == nil {
return nil, consensus.ErrUnknownAncestor
}
Expand Down Expand Up @@ -420,36 +421,58 @@ func (s *Snapshot) IsRecentlySigned(validator common.Address) bool {
return false
}

// FindAncientHeader finds the most recent checkpoint header
// Travel through the candidateParents to find the ancient header.
// If all headers in candidateParents have the number is larger than the header number,
// the search function will return the index, but it is not valid if we check with the
// header since the number and hash is not equals. The candidateParents is
// only available when it downloads blocks from the network.
// Otherwise, the candidateParents is nil, and it will be found by header hash and number.
func FindAncientHeader(header *types.Header, ite uint64, chain consensus.ChainHeaderReader, candidateParents []*types.Header) *types.Header {
ancient := header
for i := uint64(1); i <= ite; i++ {
parentHash := ancient.ParentHash
parentHeight := ancient.Number.Uint64() - 1
found := false
if len(candidateParents) > 0 {
index := sort.Search(len(candidateParents), func(i int) bool {
return candidateParents[i].Number.Uint64() >= parentHeight
})
if index < len(candidateParents) && candidateParents[index].Number.Uint64() == parentHeight &&
candidateParents[index].Hash() == parentHash {
ancient = candidateParents[index]
found = true
}
}
if !found {
ancient = chain.GetHeader(parentHash, parentHeight)
found = true
// findAncestorHeader traverses back to look for the requested ancestor header
// in parents list or in chaindata
//
// parents are guaranteed to be ordered and linked by the check when InsertChain
//
// There are 2 possible cases:
// Case 1: ancestor header is in parents list
// <- parents ->
// [ ancestorHeader ]
//
// Case 2: ancestor header's height is lower than parents list
// <- parents ->
// ancestorHeader ... [ ]

func findAncestorHeader(
ancestorBlockNumber uint64,
chain consensus.ChainHeaderReader,
parents []*types.Header,
currentHeader *types.Header,
) *types.Header {
// Find the first header in parents list that is higher or equal to checkpoint block
index := sort.Search(len(parents), func(i int) bool {
return parents[i].Number.Uint64() >= ancestorBlockNumber
})

// This must not happen, checkpoint header's height cannot be higher the parents list
if len(parents) != 0 && index >= len(parents) {
log.Warn(
"Checkpoint header's height is higher than parents list",
"checkpointNumber", ancestorBlockNumber,
"last parent", parents[len(parents)-1].Number,
)
return nil
}

if len(parents) != 0 && parents[index].Number.Uint64() == ancestorBlockNumber {
// Case 1: checkpoint header is in parents list
return parents[index]
} else {
// Case 2: checkpoint header's height is lower than parents list
var headerIterator *types.Header
if len(parents) != 0 {
headerIterator = parents[0]
} else {
headerIterator = currentHeader
}
if ancient == nil || !found {
return nil
for headerIterator.Number.Uint64() != ancestorBlockNumber {
headerIterator = chain.GetHeader(headerIterator.ParentHash, headerIterator.Number.Uint64()-1)
if headerIterator == nil {
return nil
}
}
return headerIterator
}
return ancient
}
102 changes: 102 additions & 0 deletions consensus/consortium/v2/snapshot_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package v2

import (
"math/big"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/params"
)

type mockChainReader struct {
headerMapping map[common.Hash]*types.Header
}

func (chainReader *mockChainReader) Config() *params.ChainConfig { return nil }
func (chainReader *mockChainReader) CurrentHeader() *types.Header { return nil }
func (chainReader *mockChainReader) GetHeader(hash common.Hash, number uint64) *types.Header {
return chainReader.headerMapping[hash]
}
func (chainReader *mockChainReader) GetHeaderByNumber(number uint64) *types.Header { return nil }
func (chainReader *mockChainReader) GetHeaderByHash(hash common.Hash) *types.Header { return nil }
func (chainReader *mockChainReader) DB() ethdb.Database { return nil }
func (chainReader *mockChainReader) StateCache() state.Database { return nil }
func (chainReader *mockChainReader) OpEvents() []*vm.PublishEvent { return nil }

func TestFindCheckpointHeader(t *testing.T) {
// Case 1: checkpoint header is in parent list
parents := make([]*types.Header, 10)
for i := range parents {
parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))}
}

checkpointHeader := findAncestorHeader(5, nil, parents, nil)
if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.Coinbase != common.BigToAddress(big.NewInt(5)) {
t.Fatalf("Expect checkpoint header number: %d, got: %d", 5, checkpointHeader.Number.Int64())
}

// Case 2: checkpoint header is higher than parent list, this must not happen
// but the function must not crash in this case
checkpointHeader = findAncestorHeader(11, nil, parents, nil)
if checkpointHeader != nil {
t.Fatalf("Expect %v checkpoint header, got %v", nil, checkpointHeader)
}

// Case 3: checkpoint header is lower than parent list
// parent list ranges from [10, 20)
for i := range parents {
parents[i] = &types.Header{Number: big.NewInt(int64(i + 10)), ParentHash: common.BigToHash(big.NewInt(int64(i + 10 - 1)))}
}
mockChain := mockChainReader{
headerMapping: make(map[common.Hash]*types.Header),
}
// create mock chain 1
for i := 5; i < 10; i++ {
mockChain.headerMapping[common.BigToHash(big.NewInt(int64(100+i)))] = &types.Header{
Number: big.NewInt(int64(i)),
ParentHash: common.BigToHash(big.NewInt(int64(100 + i - 1))),
}
}

// create mock chain 2
for i := 5; i < 10; i++ {
mockChain.headerMapping[common.BigToHash(big.NewInt(int64(i)))] = &types.Header{
Number: big.NewInt(int64(i)),
ParentHash: common.BigToHash(big.NewInt(int64(i - 1))),
}
}

// Must traverse and get the correct header in chain 2
checkpointHeader = findAncestorHeader(5, &mockChain, parents, nil)
if checkpointHeader == nil {
t.Fatal("Failed to find checkpoint header")
}
if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(4))) {
t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s",
5, common.BigToHash(big.NewInt(int64(4))),
checkpointHeader.Number.Int64(), checkpointHeader.ParentHash,
)
}

// Case 4: find checkpoint header with nil parent list
checkpointHeader = findAncestorHeader(
5,
&mockChain,
nil,
&types.Header{Number: big.NewInt(10), ParentHash: common.BigToHash(big.NewInt(109))},
)
// Must traverse and get the correct header in chain 1
if checkpointHeader == nil {
t.Fatal("Failed to find checkpoint header")
}
if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(104))) {
t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s",
5, common.BigToHash(big.NewInt(int64(104))),
checkpointHeader.Number.Int64(), checkpointHeader.ParentHash,
)
}
}

0 comments on commit 3212b56

Please sign in to comment.