Skip to content

Commit

Permalink
fix(lib/blocktree): improve blocktree.GetHashesAtNumber (#3799)
Browse files Browse the repository at this point in the history
  • Loading branch information
EclesioMeloJunior authored Mar 22, 2024
1 parent 5629c17 commit f9ab505
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 65 deletions.
34 changes: 12 additions & 22 deletions dot/state/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/ChainSafe/gossamer/pkg/scale"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"golang.org/x/exp/slices"

rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage"
wazero_runtime "github.com/ChainSafe/gossamer/lib/runtime/wazero"
Expand Down Expand Up @@ -255,19 +254,20 @@ func (bs *BlockState) GetHashByNumber(num uint) (common.Hash, error) {

// GetHashesByNumber returns the block hashes with the given number
func (bs *BlockState) GetHashesByNumber(blockNumber uint) ([]common.Hash, error) {
block, err := bs.GetBlockByNumber(blockNumber)
if err != nil {
return nil, fmt.Errorf("getting block by number: %w", err)
}

blockHashes := bs.bt.GetAllBlocksAtNumber(block.Header.ParentHash)
inMemoryBlockHashes := bs.bt.GetHashesAtNumber(blockNumber)
if len(inMemoryBlockHashes) == 0 {
bh, err := bs.db.Get(headerHashKey(uint64(blockNumber)))
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return []common.Hash{}, nil
}
return []common.Hash{}, fmt.Errorf("cannot get block by its number %d: %w", blockNumber, err)
}

hash := block.Header.Hash()
if !slices.Contains(blockHashes, hash) {
blockHashes = append(blockHashes, hash)
return []common.Hash{common.NewHash(bh)}, nil
}

return blockHashes, nil
return inMemoryBlockHashes, nil
}

// GetAllDescendants gets all the descendants for a given block hash (including itself), by first checking in memory
Expand Down Expand Up @@ -496,17 +496,7 @@ func (bs *BlockState) AddBlockWithArrivalTime(block *types.Block, arrivalTime ti

// GetAllBlocksAtNumber returns all unfinalised blocks with the given number
func (bs *BlockState) GetAllBlocksAtNumber(num uint) ([]common.Hash, error) {
header, err := bs.GetHeaderByNumber(num)
if err != nil {
return nil, err
}

return bs.GetAllBlocksAtDepth(header.ParentHash), nil
}

// GetAllBlocksAtDepth returns all hashes with the depth of the given hash plus one
func (bs *BlockState) GetAllBlocksAtDepth(hash common.Hash) []common.Hash {
return bs.bt.GetAllBlocksAtNumber(hash)
return bs.bt.GetHashesAtNumber(num), nil
}

func (bs *BlockState) isBlockOnCurrentChain(header *types.Header) (bool, error) {
Expand Down
4 changes: 1 addition & 3 deletions dot/state/grandpa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func testBlockState(t *testing.T, db database.Database) *BlockState {
return bs
}

func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) {
func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) { //nolint:tparallel
t.Parallel()

keyring, err := keystore.NewSr25519Keyring()
Expand Down Expand Up @@ -220,8 +220,6 @@ func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) {
for tname, tt := range tests {
tt := tt
t.Run(tname, func(t *testing.T) {
t.Parallel()

// clear the scheduledChangeRoots after the test ends
// this does not cause race condition because t.Run without
// t.Parallel() blocks until this function returns
Expand Down
14 changes: 0 additions & 14 deletions lib/babe/mock_state_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion lib/babe/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type BlockState interface {
BestBlockHash() common.Hash
BestBlockHeader() (*types.Header, error)
AddBlock(*types.Block) error
GetAllBlocksAtDepth(hash common.Hash) []common.Hash
GetHeader(common.Hash) (*types.Header, error)
GetBlockByNumber(blockNumber uint) (*types.Block, error)
GetBlockHashesBySlot(slot uint64) (blockHashes []common.Hash, err error)
Expand Down
20 changes: 10 additions & 10 deletions lib/blocktree/blocktree.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,24 @@ func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime time.Time) (err
return nil
}

// GetAllBlocksAtNumber will return all blocks hashes with the number of the given hash plus one.
// GetHashesAtNumber will return all blocks hashes that contains the number of the given hash plus one.
// To find all blocks at a number matching a certain block, pass in that block's parent hash
func (bt *BlockTree) GetAllBlocksAtNumber(hash common.Hash) (hashes []common.Hash) {
func (bt *BlockTree) GetHashesAtNumber(number uint) (hashes []common.Hash) {
bt.RLock()
defer bt.RUnlock()

if bt.getNode(hash) == nil {
return hashes
if number < bt.root.number {
return []common.Hash{}
}

number := bt.getNode(hash).number + 1

if bt.root.number == number {
hashes = append(hashes, bt.root.hash)
return hashes
bestLeave := bt.leaves.bestBlock()
if number > bestLeave.number {
return []common.Hash{}
}

return bt.root.getNodesWithNumber(number, hashes)
possibleNumOfBlocks := len(bt.leaves.nodes())
hashes = make([]common.Hash, 0, possibleNumOfBlocks)
return bt.root.hashesAtNumber(number, hashes)
}

var ErrStartGreaterThanEnd = errors.New("start greater than end")
Expand Down
8 changes: 5 additions & 3 deletions lib/blocktree/blocktree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ func Test_BlockTree_GetNode(t *testing.T) {
require.NotNil(t, block)
}

func Test_BlockTree_GetAllBlocksAtNumber(t *testing.T) {
func Test_BlockTree_GetHashesAtNumber(t *testing.T) {
bt, _ := createTestBlockTree(t, testHeader, 8)
hashes := bt.root.getNodesWithNumber(10, []common.Hash{})
hashes := make([]common.Hash, 0)
hashes = bt.root.hashesAtNumber(10, hashes)

require.Empty(t, hashes)

Expand Down Expand Up @@ -194,7 +195,8 @@ func Test_BlockTree_GetAllBlocksAtNumber(t *testing.T) {
}
}

hashes = bt.root.getNodesWithNumber(desiredNumber, []common.Hash{})
hashes = make([]common.Hash, 0, 100)
hashes = bt.root.hashesAtNumber(desiredNumber, hashes)
require.Equal(t, expected, hashes)
}

Expand Down
27 changes: 15 additions & 12 deletions lib/blocktree/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,23 @@ func (n *node) getNode(h common.Hash) *node {
return nil
}

// getNodesWithNumber returns all descendent nodes with the desired number
func (n *node) getNodesWithNumber(number uint, hashes []common.Hash) []common.Hash {
for _, child := range n.children {
// number matches
if child.number == number {
hashes = append(hashes, child.hash)
}

// are deeper than desired number, return
if child.number > number {
return hashes
// hashesAtNumber returns all nodes in the chain that contains the desired number
func (n *node) hashesAtNumber(number uint, hashes []common.Hash) []common.Hash {
// there is no need to go furthen in the node's children
// since they have a greater number at least
if number == n.number {
hashes = append(hashes, n.hash)
return hashes
}

// if the number is greater than current node,
// then search among its children
if number > n.number {
for _, children := range n.children {
hashes = children.hashesAtNumber(number, hashes)
}

hashes = child.getNodesWithNumber(number, hashes)
return hashes
}

return hashes
Expand Down

0 comments on commit f9ab505

Please sign in to comment.