From d59d0536cbcc30f430fd1adbe2316069ebaa2202 Mon Sep 17 00:00:00 2001 From: Bui Quang Minh Date: Mon, 25 Mar 2024 11:02:01 +0700 Subject: [PATCH] consortium-v2/snapshot: make FindAncientHeader more readable This commit refactors FindAncientHeader, changes its name to findCheckpointHeader, adds some comments and unit test to make the code more readable. --- consensus/consortium/v2/snapshot.go | 83 +++++++++++------- consensus/consortium/v2/snapshot_test.go | 102 +++++++++++++++++++++++ 2 files changed, 155 insertions(+), 30 deletions(-) create mode 100644 consensus/consortium/v2/snapshot_test.go diff --git a/consensus/consortium/v2/snapshot.go b/consensus/consortium/v2/snapshot.go index 5931f0713a..de023cc8b3 100644 --- a/consensus/consortium/v2/snapshot.go +++ b/consensus/consortium/v2/snapshot.go @@ -15,6 +15,7 @@ import ( blsCommon "github.com/ethereum/go-ethereum/crypto/bls/common" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/internal/ethapi" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" lru "github.com/hashicorp/golang-lru" ) @@ -225,7 +226,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainHeaderRea // Change the validator set base on the size of the validators set if number > 0 && number%s.config.EpochV2 == uint64(len(snap.validators())/2) { // Get the most recent checkpoint header - checkpointHeader := FindAncientHeader(header, uint64(len(snap.validators())/2), chain, parents) + checkpointHeader := findCheckpointHeader(number-uint64(len(snap.validators())/2), chain, parents, header) if checkpointHeader == nil { return nil, consensus.ErrUnknownAncestor } @@ -362,36 +363,58 @@ func (s *Snapshot) IsRecentlySigned(validator common.Address) bool { return false } -// FindAncientHeader finds the most recent checkpoint header -// Travel through the candidateParents to find the ancient header. -// If all headers in candidateParents have the number is larger than the header number, -// the search function will return the index, but it is not valid if we check with the -// header since the number and hash is not equals. The candidateParents is -// only available when it downloads blocks from the network. -// Otherwise, the candidateParents is nil, and it will be found by header hash and number. -func FindAncientHeader(header *types.Header, ite uint64, chain consensus.ChainHeaderReader, candidateParents []*types.Header) *types.Header { - ancient := header - for i := uint64(1); i <= ite; i++ { - parentHash := ancient.ParentHash - parentHeight := ancient.Number.Uint64() - 1 - found := false - if len(candidateParents) > 0 { - index := sort.Search(len(candidateParents), func(i int) bool { - return candidateParents[i].Number.Uint64() >= parentHeight - }) - if index < len(candidateParents) && candidateParents[index].Number.Uint64() == parentHeight && - candidateParents[index].Hash() == parentHash { - ancient = candidateParents[index] - found = true - } - } - if !found { - ancient = chain.GetHeader(parentHash, parentHeight) - found = true +// findCheckpointHeader traverses back to look for the most recent checkpoint +// header in parents list or in chaindata +// +// parents are guaranteed to be ordered and linked by the check when InsertChain +// +// There are 2 possible cases: +// Case 1: checkpoint header is in parents list +// <- parents -> +// [ checkpointHeader ] +// +// Case 2: checkpoint header's height is lower than parents list +// <- parents -> +// checkpointHeader ... [ ] + +func findCheckpointHeader( + checkpointBlockNumber uint64, + chain consensus.ChainHeaderReader, + parents []*types.Header, + header *types.Header, +) *types.Header { + // Find the first header in parents list that is higher or equal to checkpoint block + index := sort.Search(len(parents), func(i int) bool { + return parents[i].Number.Uint64() >= checkpointBlockNumber + }) + + // This must not happen, checkpoint header's height cannot be higher the parents list + if len(parents) != 0 && index >= len(parents) { + log.Warn( + "Checkpoint header's height is higher than parents list", + "checkpointNumber", checkpointBlockNumber, + "last parent", parents[len(parents)-1].Number, + ) + return nil + } + + if len(parents) != 0 && parents[index].Number.Uint64() == checkpointBlockNumber { + // Case 1: checkpoint header is in parents list + return parents[index] + } else { + // Case 2: checkpoint header's height is lower than parents list + var headerIterator *types.Header + if len(parents) != 0 { + headerIterator = parents[0] + } else { + headerIterator = header } - if ancient == nil || !found { - return nil + for headerIterator.Number.Uint64() != checkpointBlockNumber { + headerIterator = chain.GetHeader(headerIterator.ParentHash, headerIterator.Number.Uint64()-1) + if headerIterator == nil { + return nil + } } + return headerIterator } - return ancient } diff --git a/consensus/consortium/v2/snapshot_test.go b/consensus/consortium/v2/snapshot_test.go new file mode 100644 index 0000000000..63f7c2cf5d --- /dev/null +++ b/consensus/consortium/v2/snapshot_test.go @@ -0,0 +1,102 @@ +package v2 + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/core/vm" + "github.com/ethereum/go-ethereum/ethdb" + "github.com/ethereum/go-ethereum/params" +) + +type mockChainReader struct { + headerMapping map[common.Hash]*types.Header +} + +func (chainReader *mockChainReader) Config() *params.ChainConfig { return nil } +func (chainReader *mockChainReader) CurrentHeader() *types.Header { return nil } +func (chainReader *mockChainReader) GetHeader(hash common.Hash, number uint64) *types.Header { + return chainReader.headerMapping[hash] +} +func (chainReader *mockChainReader) GetHeaderByNumber(number uint64) *types.Header { return nil } +func (chainReader *mockChainReader) GetHeaderByHash(hash common.Hash) *types.Header { return nil } +func (chainReader *mockChainReader) DB() ethdb.Database { return nil } +func (chainReader *mockChainReader) StateCache() state.Database { return nil } +func (chainReader *mockChainReader) OpEvents() []*vm.PublishEvent { return nil } + +func TestFindCheckpointHeader(t *testing.T) { + // Case 1: checkpoint header is in parent list + parents := make([]*types.Header, 10) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))} + } + + checkpointHeader := findCheckpointHeader(5, nil, parents, nil) + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.Coinbase != common.BigToAddress(big.NewInt(5)) { + t.Fatalf("Expect checkpoint header number: %d, got: %d", 5, checkpointHeader.Number.Int64()) + } + + // Case 2: checkpoint header is higher than parent list, this must not happen + // but the function must not crash in this case + checkpointHeader = findCheckpointHeader(11, nil, parents, nil) + if checkpointHeader != nil { + t.Fatalf("Expect %v checkpoint header, got %v", nil, checkpointHeader) + } + + // Case 3: checkpoint header is lower than parent list + // parent list ranges from [10, 20) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i + 10)), ParentHash: common.BigToHash(big.NewInt(int64(i + 10 - 1)))} + } + mockChain := mockChainReader{ + headerMapping: make(map[common.Hash]*types.Header), + } + // create mock chain 1 + for i := 5; i < 10; i++ { + mockChain.headerMapping[common.BigToHash(big.NewInt(int64(100+i)))] = &types.Header{ + Number: big.NewInt(int64(i)), + ParentHash: common.BigToHash(big.NewInt(int64(100 + i - 1))), + } + } + + // create mock chain 2 + for i := 5; i < 10; i++ { + mockChain.headerMapping[common.BigToHash(big.NewInt(int64(i)))] = &types.Header{ + Number: big.NewInt(int64(i)), + ParentHash: common.BigToHash(big.NewInt(int64(i - 1))), + } + } + + // Must traverse and get the correct header in chain 2 + checkpointHeader = findCheckpointHeader(5, &mockChain, parents, nil) + if checkpointHeader == nil { + t.Fatal("Failed to find checkpoint header") + } + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(4))) { + t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s", + 5, common.BigToHash(big.NewInt(int64(4))), + checkpointHeader.Number.Int64(), checkpointHeader.ParentHash, + ) + } + + // Case 4: find checkpoint header with nil parent list + checkpointHeader = findCheckpointHeader( + 5, + &mockChain, + nil, + &types.Header{Number: big.NewInt(10), ParentHash: common.BigToHash(big.NewInt(109))}, + ) + // Must traverse and get the correct header in chain 1 + if checkpointHeader == nil { + t.Fatal("Failed to find checkpoint header") + } + if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.ParentHash != common.BigToHash(big.NewInt(int64(104))) { + t.Fatalf("Expect checkpoint header number %d, parent hash: %s, got number: %d, parent hash: %s", + 5, common.BigToHash(big.NewInt(int64(104))), + checkpointHeader.Number.Int64(), checkpointHeader.ParentHash, + ) + } +}