From f9ab505f7b1cdd128708cf72ef4982e693f365d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ecl=C3=A9sio=20Junior?= Date: Fri, 22 Mar 2024 10:19:11 -0400 Subject: [PATCH] fix(lib/blocktree): improve `blocktree.GetHashesAtNumber` (#3799) --- dot/state/block.go | 34 ++++++++++++--------------------- dot/state/grandpa_test.go | 4 +--- lib/babe/mock_state_test.go | 14 -------------- lib/babe/state.go | 1 - lib/blocktree/blocktree.go | 20 +++++++++---------- lib/blocktree/blocktree_test.go | 8 +++++--- lib/blocktree/node.go | 27 ++++++++++++++------------ 7 files changed, 43 insertions(+), 65 deletions(-) diff --git a/dot/state/block.go b/dot/state/block.go index 690f4fcdd5..d57df0beb0 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -19,7 +19,6 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" - "golang.org/x/exp/slices" rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" wazero_runtime "github.com/ChainSafe/gossamer/lib/runtime/wazero" @@ -255,19 +254,20 @@ func (bs *BlockState) GetHashByNumber(num uint) (common.Hash, error) { // GetHashesByNumber returns the block hashes with the given number func (bs *BlockState) GetHashesByNumber(blockNumber uint) ([]common.Hash, error) { - block, err := bs.GetBlockByNumber(blockNumber) - if err != nil { - return nil, fmt.Errorf("getting block by number: %w", err) - } - - blockHashes := bs.bt.GetAllBlocksAtNumber(block.Header.ParentHash) + inMemoryBlockHashes := bs.bt.GetHashesAtNumber(blockNumber) + if len(inMemoryBlockHashes) == 0 { + bh, err := bs.db.Get(headerHashKey(uint64(blockNumber))) + if err != nil { + if errors.Is(err, database.ErrNotFound) { + return []common.Hash{}, nil + } + return []common.Hash{}, fmt.Errorf("cannot get block by its number %d: %w", blockNumber, err) + } - hash := block.Header.Hash() - if !slices.Contains(blockHashes, hash) { - blockHashes = append(blockHashes, hash) + return []common.Hash{common.NewHash(bh)}, nil } - return blockHashes, nil + return inMemoryBlockHashes, nil } // GetAllDescendants gets all the descendants for a given block hash (including itself), by first checking in memory @@ -496,17 +496,7 @@ func (bs *BlockState) AddBlockWithArrivalTime(block *types.Block, arrivalTime ti // GetAllBlocksAtNumber returns all unfinalised blocks with the given number func (bs *BlockState) GetAllBlocksAtNumber(num uint) ([]common.Hash, error) { - header, err := bs.GetHeaderByNumber(num) - if err != nil { - return nil, err - } - - return bs.GetAllBlocksAtDepth(header.ParentHash), nil -} - -// GetAllBlocksAtDepth returns all hashes with the depth of the given hash plus one -func (bs *BlockState) GetAllBlocksAtDepth(hash common.Hash) []common.Hash { - return bs.bt.GetAllBlocksAtNumber(hash) + return bs.bt.GetHashesAtNumber(num), nil } func (bs *BlockState) isBlockOnCurrentChain(header *types.Header) (bool, error) { diff --git a/dot/state/grandpa_test.go b/dot/state/grandpa_test.go index 8018308c0b..7043cac47b 100644 --- a/dot/state/grandpa_test.go +++ b/dot/state/grandpa_test.go @@ -143,7 +143,7 @@ func testBlockState(t *testing.T, db database.Database) *BlockState { return bs } -func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) { +func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) { //nolint:tparallel t.Parallel() keyring, err := keystore.NewSr25519Keyring() @@ -220,8 +220,6 @@ func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) { for tname, tt := range tests { tt := tt t.Run(tname, func(t *testing.T) { - t.Parallel() - // clear the scheduledChangeRoots after the test ends // this does not cause race condition because t.Run without // t.Parallel() blocks until this function returns diff --git a/lib/babe/mock_state_test.go b/lib/babe/mock_state_test.go index c9b3b83da2..ab5e132028 100644 --- a/lib/babe/mock_state_test.go +++ b/lib/babe/mock_state_test.go @@ -113,20 +113,6 @@ func (mr *MockBlockStateMockRecorder) GenesisHash() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GenesisHash", reflect.TypeOf((*MockBlockState)(nil).GenesisHash)) } -// GetAllBlocksAtDepth mocks base method. -func (m *MockBlockState) GetAllBlocksAtDepth(arg0 common.Hash) []common.Hash { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAllBlocksAtDepth", arg0) - ret0, _ := ret[0].([]common.Hash) - return ret0 -} - -// GetAllBlocksAtDepth indicates an expected call of GetAllBlocksAtDepth. -func (mr *MockBlockStateMockRecorder) GetAllBlocksAtDepth(arg0 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllBlocksAtDepth", reflect.TypeOf((*MockBlockState)(nil).GetAllBlocksAtDepth), arg0) -} - // GetBlockByNumber mocks base method. func (m *MockBlockState) GetBlockByNumber(arg0 uint) (*types.Block, error) { m.ctrl.T.Helper() diff --git a/lib/babe/state.go b/lib/babe/state.go index 47de36a930..f4aee6f4d4 100644 --- a/lib/babe/state.go +++ b/lib/babe/state.go @@ -24,7 +24,6 @@ type BlockState interface { BestBlockHash() common.Hash BestBlockHeader() (*types.Header, error) AddBlock(*types.Block) error - GetAllBlocksAtDepth(hash common.Hash) []common.Hash GetHeader(common.Hash) (*types.Header, error) GetBlockByNumber(blockNumber uint) (*types.Block, error) GetBlockHashesBySlot(slot uint64) (blockHashes []common.Hash, err error) diff --git a/lib/blocktree/blocktree.go b/lib/blocktree/blocktree.go index 0f1faf21cc..0fa4e4d06e 100644 --- a/lib/blocktree/blocktree.go +++ b/lib/blocktree/blocktree.go @@ -120,24 +120,24 @@ func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime time.Time) (err return nil } -// GetAllBlocksAtNumber will return all blocks hashes with the number of the given hash plus one. +// GetHashesAtNumber will return all blocks hashes that contains the number of the given hash plus one. // To find all blocks at a number matching a certain block, pass in that block's parent hash -func (bt *BlockTree) GetAllBlocksAtNumber(hash common.Hash) (hashes []common.Hash) { +func (bt *BlockTree) GetHashesAtNumber(number uint) (hashes []common.Hash) { bt.RLock() defer bt.RUnlock() - if bt.getNode(hash) == nil { - return hashes + if number < bt.root.number { + return []common.Hash{} } - number := bt.getNode(hash).number + 1 - - if bt.root.number == number { - hashes = append(hashes, bt.root.hash) - return hashes + bestLeave := bt.leaves.bestBlock() + if number > bestLeave.number { + return []common.Hash{} } - return bt.root.getNodesWithNumber(number, hashes) + possibleNumOfBlocks := len(bt.leaves.nodes()) + hashes = make([]common.Hash, 0, possibleNumOfBlocks) + return bt.root.hashesAtNumber(number, hashes) } var ErrStartGreaterThanEnd = errors.New("start greater than end") diff --git a/lib/blocktree/blocktree_test.go b/lib/blocktree/blocktree_test.go index f7833eeec5..bca3a93e16 100644 --- a/lib/blocktree/blocktree_test.go +++ b/lib/blocktree/blocktree_test.go @@ -139,9 +139,10 @@ func Test_BlockTree_GetNode(t *testing.T) { require.NotNil(t, block) } -func Test_BlockTree_GetAllBlocksAtNumber(t *testing.T) { +func Test_BlockTree_GetHashesAtNumber(t *testing.T) { bt, _ := createTestBlockTree(t, testHeader, 8) - hashes := bt.root.getNodesWithNumber(10, []common.Hash{}) + hashes := make([]common.Hash, 0) + hashes = bt.root.hashesAtNumber(10, hashes) require.Empty(t, hashes) @@ -194,7 +195,8 @@ func Test_BlockTree_GetAllBlocksAtNumber(t *testing.T) { } } - hashes = bt.root.getNodesWithNumber(desiredNumber, []common.Hash{}) + hashes = make([]common.Hash, 0, 100) + hashes = bt.root.hashesAtNumber(desiredNumber, hashes) require.Equal(t, expected, hashes) } diff --git a/lib/blocktree/node.go b/lib/blocktree/node.go index 557f805066..5eb4556860 100644 --- a/lib/blocktree/node.go +++ b/lib/blocktree/node.go @@ -59,20 +59,23 @@ func (n *node) getNode(h common.Hash) *node { return nil } -// getNodesWithNumber returns all descendent nodes with the desired number -func (n *node) getNodesWithNumber(number uint, hashes []common.Hash) []common.Hash { - for _, child := range n.children { - // number matches - if child.number == number { - hashes = append(hashes, child.hash) - } - - // are deeper than desired number, return - if child.number > number { - return hashes +// hashesAtNumber returns all nodes in the chain that contains the desired number +func (n *node) hashesAtNumber(number uint, hashes []common.Hash) []common.Hash { + // there is no need to go furthen in the node's children + // since they have a greater number at least + if number == n.number { + hashes = append(hashes, n.hash) + return hashes + } + + // if the number is greater than current node, + // then search among its children + if number > n.number { + for _, children := range n.children { + hashes = children.hashesAtNumber(number, hashes) } - hashes = child.getNodesWithNumber(number, hashes) + return hashes } return hashes