Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(lib/blocktree): improve blocktree.GetHashesAtNumber #3799

Merged
merged 9 commits into from
Mar 22, 2024
34 changes: 12 additions & 22 deletions dot/state/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"github.com/ChainSafe/gossamer/pkg/scale"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"golang.org/x/exp/slices"

rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage"
wazero_runtime "github.com/ChainSafe/gossamer/lib/runtime/wazero"
Expand Down Expand Up @@ -255,19 +254,20 @@ func (bs *BlockState) GetHashByNumber(num uint) (common.Hash, error) {

// GetHashesByNumber returns the block hashes with the given number
func (bs *BlockState) GetHashesByNumber(blockNumber uint) ([]common.Hash, error) {
block, err := bs.GetBlockByNumber(blockNumber)
if err != nil {
return nil, fmt.Errorf("getting block by number: %w", err)
}

blockHashes := bs.bt.GetAllBlocksAtNumber(block.Header.ParentHash)
inMemoryBlockHashes := bs.bt.GetHashesAtNumber(blockNumber)
if len(inMemoryBlockHashes) == 0 {
bh, err := bs.db.Get(headerHashKey(uint64(blockNumber)))
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return []common.Hash{}, nil
}
return []common.Hash{}, fmt.Errorf("cannot get block by its number %d: %w", blockNumber, err)
}

hash := block.Header.Hash()
if !slices.Contains(blockHashes, hash) {
blockHashes = append(blockHashes, hash)
return []common.Hash{common.NewHash(bh)}, nil
}

return blockHashes, nil
return inMemoryBlockHashes, nil
}

// GetAllDescendants gets all the descendants for a given block hash (including itself), by first checking in memory
Expand Down Expand Up @@ -496,17 +496,7 @@ func (bs *BlockState) AddBlockWithArrivalTime(block *types.Block, arrivalTime ti

// GetAllBlocksAtNumber returns all unfinalised blocks with the given number
func (bs *BlockState) GetAllBlocksAtNumber(num uint) ([]common.Hash, error) {
header, err := bs.GetHeaderByNumber(num)
if err != nil {
return nil, err
}

return bs.GetAllBlocksAtDepth(header.ParentHash), nil
}

// GetAllBlocksAtDepth returns all hashes with the depth of the given hash plus one
func (bs *BlockState) GetAllBlocksAtDepth(hash common.Hash) []common.Hash {
return bs.bt.GetAllBlocksAtNumber(hash)
return bs.bt.GetHashesAtNumber(num), nil
}

func (bs *BlockState) isBlockOnCurrentChain(header *types.Header) (bool, error) {
Expand Down
4 changes: 1 addition & 3 deletions dot/state/grandpa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func testBlockState(t *testing.T, db database.Database) *BlockState {
return bs
}

func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) {
func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) { //nolint:tparallel
EclesioMeloJunior marked this conversation as resolved.
Show resolved Hide resolved
t.Parallel()

keyring, err := keystore.NewSr25519Keyring()
Expand Down Expand Up @@ -220,8 +220,6 @@ func TestAddScheduledChangesKeepTheRightForkTree(t *testing.T) {
for tname, tt := range tests {
tt := tt
t.Run(tname, func(t *testing.T) {
t.Parallel()

// clear the scheduledChangeRoots after the test ends
// this does not cause race condition because t.Run without
// t.Parallel() blocks until this function returns
Expand Down
1 change: 0 additions & 1 deletion lib/babe/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type BlockState interface {
BestBlockHash() common.Hash
BestBlockHeader() (*types.Header, error)
AddBlock(*types.Block) error
GetAllBlocksAtDepth(hash common.Hash) []common.Hash
GetHeader(common.Hash) (*types.Header, error)
GetBlockByNumber(blockNumber uint) (*types.Block, error)
GetBlockHashesBySlot(slot uint64) (blockHashes []common.Hash, err error)
Expand Down
20 changes: 10 additions & 10 deletions lib/blocktree/blocktree.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,24 @@ func (bt *BlockTree) AddBlock(header *types.Header, arrivalTime time.Time) (err
return nil
}

// GetAllBlocksAtNumber will return all blocks hashes with the number of the given hash plus one.
// GetHashesAtNumber will return all blocks hashes that contains the number of the given hash plus one.
// To find all blocks at a number matching a certain block, pass in that block's parent hash
func (bt *BlockTree) GetAllBlocksAtNumber(hash common.Hash) (hashes []common.Hash) {
func (bt *BlockTree) GetHashesAtNumber(number uint) (hashes []common.Hash) {
EclesioMeloJunior marked this conversation as resolved.
Show resolved Hide resolved
bt.RLock()
defer bt.RUnlock()

if bt.getNode(hash) == nil {
return hashes
if number < bt.root.number {
return []common.Hash{}
EclesioMeloJunior marked this conversation as resolved.
Show resolved Hide resolved
}

number := bt.getNode(hash).number + 1

if bt.root.number == number {
hashes = append(hashes, bt.root.hash)
return hashes
bestLeave := bt.leaves.bestBlock()
EclesioMeloJunior marked this conversation as resolved.
Show resolved Hide resolved
if number > bestLeave.number {
return []common.Hash{}
EclesioMeloJunior marked this conversation as resolved.
Show resolved Hide resolved
}

return bt.root.getNodesWithNumber(number, hashes)
possibleNumOfBlocks := len(bt.leaves.nodes())
hashes = make([]common.Hash, 0, possibleNumOfBlocks)
return bt.root.hashesAtNumber(number, hashes)
}

var ErrStartGreaterThanEnd = errors.New("start greater than end")
Expand Down
8 changes: 5 additions & 3 deletions lib/blocktree/blocktree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ func Test_BlockTree_GetNode(t *testing.T) {
require.NotNil(t, block)
}

func Test_BlockTree_GetAllBlocksAtNumber(t *testing.T) {
func Test_BlockTree_GetHashesAtNumber(t *testing.T) {
bt, _ := createTestBlockTree(t, testHeader, 8)
hashes := bt.root.getNodesWithNumber(10, []common.Hash{})
hashes := make([]common.Hash, 0)
hashes = bt.root.hashesAtNumber(10, hashes)

require.Empty(t, hashes)

Expand Down Expand Up @@ -194,7 +195,8 @@ func Test_BlockTree_GetAllBlocksAtNumber(t *testing.T) {
}
}

hashes = bt.root.getNodesWithNumber(desiredNumber, []common.Hash{})
hashes = make([]common.Hash, 0, 100)
hashes = bt.root.hashesAtNumber(desiredNumber, hashes)
require.Equal(t, expected, hashes)
}

Expand Down
27 changes: 15 additions & 12 deletions lib/blocktree/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,23 @@ func (n *node) getNode(h common.Hash) *node {
return nil
}

// getNodesWithNumber returns all descendent nodes with the desired number
func (n *node) getNodesWithNumber(number uint, hashes []common.Hash) []common.Hash {
for _, child := range n.children {
// number matches
if child.number == number {
hashes = append(hashes, child.hash)
}

// are deeper than desired number, return
if child.number > number {
return hashes
// hashesAtNumber returns all nodes in the chain that contains the desired number
func (n *node) hashesAtNumber(number uint, hashes []common.Hash) []common.Hash {
EclesioMeloJunior marked this conversation as resolved.
Show resolved Hide resolved
// there is no need to go furthen in the node's children
// since they have a greater number at least
if number == n.number {
hashes = append(hashes, n.hash)
return hashes
}

// if the number is greater than current node,
// then search among its children
if number > n.number {
for _, children := range n.children {
hashes = children.hashesAtNumber(number, hashes)
EclesioMeloJunior marked this conversation as resolved.
Show resolved Hide resolved
}

hashes = child.getNodesWithNumber(number, hashes)
return hashes
}

return hashes
Expand Down
Loading