diff --git a/dot/mock_node_builder_test.go b/dot/mock_node_builder_test.go index ad64164c1c..06bfba5fba 100644 --- a/dot/mock_node_builder_test.go +++ b/dot/mock_node_builder_test.go @@ -228,7 +228,11 @@ func (mr *MocknodeBuilderIfaceMockRecorder) loadRuntime(config, ns, stateSrvc, k } // newSyncService mocks base method. +<<<<<<< HEAD func (m *MocknodeBuilderIface) newSyncService(config *config.Config, st *state.Service, finalityGadget BlockJustificationVerifier, verifier *babe.VerificationManager, cs *core.Service, net *network.Service, telemetryMailer Telemetry) (network.Syncer, error) { +======= +func (m *MocknodeBuilderIface) newSyncService(config *config.Config, st *state.Service, finalityGadget sync.FinalityGadget, verifier *babe.VerificationManager, cs *core.Service, net *network.Service, telemetryMailer Telemetry) (*sync.Service, error) { +>>>>>>> development m.ctrl.T.Helper() ret := m.ctrl.Call(m, "newSyncService", config, st, finalityGadget, verifier, cs, net, telemetryMailer) ret0, _ := ret[0].(network.Syncer) diff --git a/dot/node.go b/dot/node.go index 9e99ac30be..aa6e656956 100644 --- a/dot/node.go +++ b/dot/node.go @@ -61,7 +61,7 @@ type nodeBuilderIface interface { ) (*core.Service, error) createGRANDPAService(config *cfg.Config, st *state.Service, ks KeyStore, net *network.Service, telemetryMailer Telemetry) (*grandpa.Service, error) - newSyncService(config *cfg.Config, st *state.Service, finalityGadget BlockJustificationVerifier, + newSyncService(config *cfg.Config, st *state.Service, finalityGadget dotsync.FinalityGadget, verifier *babe.VerificationManager, cs *core.Service, net *network.Service, telemetryMailer Telemetry) (network.Syncer, error) createBABEService(config *cfg.Config, st *state.Service, ks KeyStore, cs *core.Service, diff --git a/dot/services.go b/dot/services.go index fc6645cecd..8dca45d357 100644 --- a/dot/services.go +++ b/dot/services.go @@ -497,7 +497,7 @@ func (nodeBuilder) createBlockVerifier(st *state.Service) *babe.VerificationMana return babe.NewVerificationManager(st.Block, st.Slot, st.Epoch) } -func (nodeBuilder) newSyncService(config *cfg.Config, st *state.Service, fg BlockJustificationVerifier, +func (nodeBuilder) newSyncService(config *cfg.Config, st *state.Service, fg sync.FinalityGadget, verifier *babe.VerificationManager, cs *core.Service, net *network.Service, telemetryMailer Telemetry) ( network.Syncer, error) { slotDuration, err := st.Epoch.GetSlotDuration() diff --git a/dot/services_integration_test.go b/dot/services_integration_test.go index 6b7261f52a..a922cc0298 100644 --- a/dot/services_integration_test.go +++ b/dot/services_integration_test.go @@ -17,6 +17,7 @@ import ( "github.com/ChainSafe/gossamer/dot/network" rpc "github.com/ChainSafe/gossamer/dot/rpc" "github.com/ChainSafe/gossamer/dot/state" + "github.com/ChainSafe/gossamer/dot/sync" "github.com/ChainSafe/gossamer/dot/telemetry" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/log" @@ -373,7 +374,7 @@ func Test_nodeBuilder_newSyncService(t *testing.T) { require.NoError(t, err) type args struct { - fg BlockJustificationVerifier + fg sync.FinalityGadget verifier *babe.VerificationManager cs *core.Service net *network.Service diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go new file mode 100644 index 0000000000..a37240138a --- /dev/null +++ b/dot/sync/chain_sync.go @@ -0,0 +1,1087 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package sync + +import ( + "bytes" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "golang.org/x/exp/slices" + + "github.com/ChainSafe/gossamer/dot/network" + "github.com/ChainSafe/gossamer/dot/network/messages" + "github.com/ChainSafe/gossamer/dot/peerset" + "github.com/ChainSafe/gossamer/dot/telemetry" + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/internal/database" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/common/variadic" +) + +var _ ChainSync = (*chainSync)(nil) + +type chainSyncState byte + +const ( + bootstrap chainSyncState = iota + tip +) + +type blockOrigin byte + +const ( + networkInitialSync blockOrigin = iota + networkBroadcast +) + +func (s chainSyncState) String() string { + switch s { + case bootstrap: + return "bootstrap" + case tip: + return "tip" + default: + return "unknown" + } +} + +var ( + pendingBlocksLimit = messages.MaxBlocksInResponse * 32 + isSyncedGauge = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: "gossamer_network_syncer", + Name: "is_synced", + Help: "bool representing whether the node is synced to the head of the chain", + }) + + blockSizeGauge = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: "gossamer_sync", + Name: "block_size", + Help: "represent the size of blocks synced", + }) +) + +// ChainSync contains the methods used by the high-level service into the `chainSync` module +type ChainSync interface { + start() + stop() error + + // called upon receiving a BlockAnnounceHandshake + onBlockAnnounceHandshake(p peer.ID, hash common.Hash, number uint) error + + // getSyncMode returns the current syncing state + getSyncMode() chainSyncState + + // getHighestBlock returns the highest block or an error + getHighestBlock() (highestBlock uint, err error) + + onBlockAnnounce(announcedBlock) error +} + +type announcedBlock struct { + who peer.ID + header *types.Header +} + +type chainSync struct { + wg sync.WaitGroup + stopCh chan struct{} + + blockState BlockState + network Network + + workerPool *syncWorkerPool + + // tracks the latest state we know of from our peers, + // ie. their best block hash and number + peerViewSet *peerViewSet + + // disjoint set of blocks which are known but not ready to be processed + // ie. we only know the hash, number, or the parent block is unknown, or the body is unknown + // note: the block may have empty fields, as some data about it may be unknown + pendingBlocks DisjointBlockSet + + syncMode atomic.Value + + finalisedCh <-chan *types.FinalisationInfo + + minPeers int + slotDuration time.Duration + + storageState StorageState + transactionState TransactionState + babeVerifier BabeVerifier + finalityGadget FinalityGadget + blockImportHandler BlockImportHandler + telemetry Telemetry + badBlocks []string + requestMaker network.RequestMaker + waitPeersDuration time.Duration +} + +type chainSyncConfig struct { + bs BlockState + net Network + requestMaker network.RequestMaker + pendingBlocks DisjointBlockSet + minPeers, maxPeers int + slotDuration time.Duration + storageState StorageState + transactionState TransactionState + babeVerifier BabeVerifier + finalityGadget FinalityGadget + blockImportHandler BlockImportHandler + telemetry Telemetry + badBlocks []string + waitPeersDuration time.Duration +} + +func newChainSync(cfg chainSyncConfig) *chainSync { + atomicState := atomic.Value{} + atomicState.Store(tip) + return &chainSync{ + stopCh: make(chan struct{}), + storageState: cfg.storageState, + transactionState: cfg.transactionState, + babeVerifier: cfg.babeVerifier, + finalityGadget: cfg.finalityGadget, + blockImportHandler: cfg.blockImportHandler, + telemetry: cfg.telemetry, + blockState: cfg.bs, + network: cfg.net, + peerViewSet: newPeerViewSet(cfg.maxPeers), + pendingBlocks: cfg.pendingBlocks, + syncMode: atomicState, + finalisedCh: cfg.bs.GetFinalisedNotifierChannel(), + minPeers: cfg.minPeers, + slotDuration: cfg.slotDuration, + workerPool: newSyncWorkerPool(cfg.net, cfg.requestMaker), + badBlocks: cfg.badBlocks, + requestMaker: cfg.requestMaker, + waitPeersDuration: cfg.waitPeersDuration, + } +} + +func (cs *chainSync) waitWorkersAndTarget() { + waitPeersTimer := time.NewTimer(cs.waitPeersDuration) + + highestFinalizedHeader, err := cs.blockState.GetHighestFinalisedHeader() + if err != nil { + panic(fmt.Sprintf("failed to get highest finalised header: %v", err)) + } + + for { + cs.workerPool.useConnectedPeers() + totalAvailable := cs.workerPool.totalWorkers() + + if totalAvailable >= uint(cs.minPeers) && + cs.peerViewSet.getTarget() > 0 { + return + } + + err := cs.network.BlockAnnounceHandshake(highestFinalizedHeader) + if err != nil && !errors.Is(err, network.ErrNoPeersConnected) { + logger.Errorf("retrieving target info from peers: %v", err) + } + + select { + case <-waitPeersTimer.C: + waitPeersTimer.Reset(cs.waitPeersDuration) + + case <-cs.stopCh: + return + } + } +} + +func (cs *chainSync) start() { + // since the default status from sync mode is syncMode(tip) + isSyncedGauge.Set(1) + + cs.wg.Add(1) + go cs.pendingBlocks.run(cs.finalisedCh, cs.stopCh, &cs.wg) + + // wait until we have a minimal workers in the sync worker pool + cs.waitWorkersAndTarget() +} + +func (cs *chainSync) stop() error { + err := cs.workerPool.stop() + if err != nil { + return fmt.Errorf("stopping worker poll: %w", err) + } + + close(cs.stopCh) + allStopCh := make(chan struct{}) + go func() { + defer close(allStopCh) + cs.wg.Wait() + }() + + timeoutTimer := time.NewTimer(30 * time.Second) + + select { + case <-allStopCh: + if !timeoutTimer.Stop() { + <-timeoutTimer.C + } + return nil + case <-timeoutTimer.C: + return ErrStopTimeout + } +} + +func (cs *chainSync) isBootstrapSync(currentBlockNumber uint) bool { + syncTarget := cs.peerViewSet.getTarget() + return currentBlockNumber+messages.MaxBlocksInResponse < syncTarget +} + +func (cs *chainSync) bootstrapSync() { + defer cs.wg.Done() + currentBlock, err := cs.blockState.GetHighestFinalisedHeader() + if err != nil { + panic("cannot find highest finalised header") + } + + for { + select { + case <-cs.stopCh: + logger.Warn("ending bootstrap sync, chain sync stop channel triggered") + return + default: + } + + isBootstrap := cs.isBootstrapSync(currentBlock.Number) + if isBootstrap { + cs.workerPool.useConnectedPeers() + err = cs.requestMaxBlocksFrom(currentBlock, networkInitialSync) + if err != nil { + if errors.Is(err, errBlockStatePaused) { + logger.Debugf("exiting bootstrap sync: %s", err) + return + } + logger.Errorf("requesting max blocks from best block header: %s", err) + } + + currentBlock, err = cs.blockState.BestBlockHeader() + if err != nil { + logger.Errorf("getting best block header: %v", err) + } + } else { + // we are less than 128 blocks behind the target we can use tip sync + cs.syncMode.Store(tip) + isSyncedGauge.Set(1) + logger.Infof("🔁 switched sync mode to %s", tip.String()) + return + } + } +} + +func (cs *chainSync) getSyncMode() chainSyncState { + return cs.syncMode.Load().(chainSyncState) +} + +// onBlockAnnounceHandshake sets a peer's best known block +func (cs *chainSync) onBlockAnnounceHandshake(who peer.ID, bestHash common.Hash, bestNumber uint) error { + cs.workerPool.fromBlockAnnounce(who) + cs.peerViewSet.update(who, bestHash, bestNumber) + + if cs.getSyncMode() == bootstrap { + return nil + } + + bestBlockHeader, err := cs.blockState.BestBlockHeader() + if err != nil { + return err + } + + isBootstrap := cs.isBootstrapSync(bestBlockHeader.Number) + if !isBootstrap { + return nil + } + + // we are more than 128 blocks behind the head, switch to bootstrap + cs.syncMode.Store(bootstrap) + isSyncedGauge.Set(0) + logger.Infof("🔁 switched sync mode to %s", bootstrap.String()) + + cs.wg.Add(1) + go cs.bootstrapSync() + return nil +} + +func (cs *chainSync) onBlockAnnounce(announced announcedBlock) error { + // TODO: https://github.com/ChainSafe/gossamer/issues/3432 + if cs.pendingBlocks.hasBlock(announced.header.Hash()) { + return fmt.Errorf("%w: block #%d (%s)", + errAlreadyInDisjointSet, announced.header.Number, announced.header.Hash()) + } + + err := cs.pendingBlocks.addHeader(announced.header) + if err != nil { + return fmt.Errorf("while adding pending block header: %w", err) + } + + if cs.getSyncMode() == bootstrap { + return nil + } + + bestBlockHeader, err := cs.blockState.BestBlockHeader() + if err != nil { + return fmt.Errorf("getting best block header: %w", err) + } + + isBootstrap := cs.isBootstrapSync(bestBlockHeader.Number) + if !isBootstrap { + return cs.requestAnnouncedBlock(bestBlockHeader, announced) + } + + return nil +} + +func (cs *chainSync) requestAnnouncedBlock(bestBlockHeader *types.Header, announce announcedBlock) error { + peerWhoAnnounced := announce.who + announcedHash := announce.header.Hash() + announcedNumber := announce.header.Number + + has, err := cs.blockState.HasHeader(announcedHash) + if err != nil { + return fmt.Errorf("checking if header exists: %s", err) + } + + if has { + return nil + } + + highestFinalizedHeader, err := cs.blockState.GetHighestFinalisedHeader() + if err != nil { + return fmt.Errorf("getting highest finalized header") + } + + // if the announced block contains a lower number than our best + // block header, let's check if it is greater than our latests + // finalized header, if so this block belongs to a fork chain + if announcedNumber < bestBlockHeader.Number { + // ignore the block if it has the same or lower number + // TODO: is it following the protocol to send a blockAnnounce with number < highestFinalized number? + if announcedNumber <= highestFinalizedHeader.Number { + return nil + } + + return cs.requestForkBlocks(bestBlockHeader, highestFinalizedHeader, announce.header, announce.who) + } + + err = cs.requestChainBlocks(announce.header, bestBlockHeader, peerWhoAnnounced) + if err != nil { + return fmt.Errorf("requesting chain blocks: %w", err) + } + + err = cs.requestPendingBlocks(highestFinalizedHeader) + if err != nil { + return fmt.Errorf("while requesting pending blocks") + } + + return nil +} + +func (cs *chainSync) requestChainBlocks(announcedHeader, bestBlockHeader *types.Header, + peerWhoAnnounced peer.ID) error { + gapLength := uint32(announcedHeader.Number - bestBlockHeader.Number) + startAtBlock := announcedHeader.Number + totalBlocks := uint32(1) + + var request *messages.BlockRequestMessage + startingBlock := *variadic.MustNewUint32OrHash(announcedHeader.Hash()) + + if gapLength > 1 { + request = messages.NewBlockRequest(startingBlock, gapLength, + messages.BootstrapRequestData, messages.Descending) + + startAtBlock = announcedHeader.Number - uint(*request.Max) + 1 + totalBlocks = *request.Max + + logger.Infof("requesting %d blocks from peer: %v, descending request from #%d (%s)", + gapLength, peerWhoAnnounced, announcedHeader.Number, announcedHeader.Hash().Short()) + } else { + request = messages.NewBlockRequest(startingBlock, 1, messages.BootstrapRequestData, messages.Descending) + logger.Infof("requesting a single block from peer: %v with Number: #%d and Hash: (%s)", + peerWhoAnnounced, announcedHeader.Number, announcedHeader.Hash().Short()) + } + + resultsQueue := make(chan *syncTaskResult) + err := cs.submitRequest(request, &peerWhoAnnounced, resultsQueue) + if err != nil { + return err + } + err = cs.handleWorkersResults(resultsQueue, networkBroadcast, startAtBlock, totalBlocks) + if err != nil { + return fmt.Errorf("while handling workers results: %w", err) + } + + return nil +} + +func (cs *chainSync) requestForkBlocks(bestBlockHeader, highestFinalizedHeader, announcedHeader *types.Header, + peerWhoAnnounced peer.ID) error { + logger.Infof("block announce lower than best block #%d (%s) and greater highest finalized #%d (%s)", + bestBlockHeader.Number, bestBlockHeader.Hash().Short(), + highestFinalizedHeader.Number, highestFinalizedHeader.Hash().Short()) + + parentExists, err := cs.blockState.HasHeader(announcedHeader.ParentHash) + if err != nil && !errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("while checking header exists: %w", err) + } + + gapLength := uint32(1) + startAtBlock := announcedHeader.Number + announcedHash := announcedHeader.Hash() + var request *messages.BlockRequestMessage + startingBlock := *variadic.MustNewUint32OrHash(announcedHash) + + if parentExists { + request = messages.NewBlockRequest(startingBlock, 1, messages.BootstrapRequestData, messages.Descending) + } else { + gapLength = uint32(announcedHeader.Number - highestFinalizedHeader.Number) + startAtBlock = highestFinalizedHeader.Number + 1 + request = messages.NewBlockRequest(startingBlock, gapLength, messages.BootstrapRequestData, messages.Descending) + } + + logger.Infof("requesting %d fork blocks from peer: %v starting at #%d (%s)", + gapLength, peerWhoAnnounced, announcedHeader.Number, announcedHash.Short()) + + resultsQueue := make(chan *syncTaskResult) + err = cs.submitRequest(request, &peerWhoAnnounced, resultsQueue) + if err != nil { + return err + } + err = cs.handleWorkersResults(resultsQueue, networkBroadcast, startAtBlock, gapLength) + if err != nil { + return fmt.Errorf("while handling workers results: %w", err) + } + + return nil +} + +func (cs *chainSync) requestPendingBlocks(highestFinalizedHeader *types.Header) error { + pendingBlocksTotal := cs.pendingBlocks.size() + logger.Infof("total of pending blocks: %d", pendingBlocksTotal) + if pendingBlocksTotal < 1 { + return nil + } + + pendingBlocks := cs.pendingBlocks.getBlocks() + for _, pendingBlock := range pendingBlocks { + if pendingBlock.number <= highestFinalizedHeader.Number { + cs.pendingBlocks.removeBlock(pendingBlock.hash) + continue + } + + parentExists, err := cs.blockState.HasHeader(pendingBlock.header.ParentHash) + if err != nil { + return fmt.Errorf("getting pending block parent header: %w", err) + } + + if parentExists { + err := cs.handleReadyBlock(pendingBlock.toBlockData(), networkBroadcast) + if err != nil { + return fmt.Errorf("handling ready block: %w", err) + } + continue + } + + gapLength := pendingBlock.number - highestFinalizedHeader.Number + if gapLength > 128 { + logger.Warnf("gap of %d blocks, max expected: 128 block", gapLength) + gapLength = 128 + } + + descendingGapRequest := messages.NewBlockRequest(*variadic.MustNewUint32OrHash(pendingBlock.hash), + uint32(gapLength), messages.BootstrapRequestData, messages.Descending) + startAtBlock := pendingBlock.number - uint(*descendingGapRequest.Max) + 1 + + // the `requests` in the tip sync are not related necessarily + // this is why we need to treat them separately + resultsQueue := make(chan *syncTaskResult) + err = cs.submitRequest(descendingGapRequest, nil, resultsQueue) + if err != nil { + return err + } + // TODO: we should handle the requests concurrently + // a way of achieve that is by constructing a new `handleWorkersResults` for + // handling only tip sync requests + err = cs.handleWorkersResults(resultsQueue, networkBroadcast, startAtBlock, *descendingGapRequest.Max) + if err != nil { + return fmt.Errorf("while handling workers results: %w", err) + } + } + + return nil +} + +func (cs *chainSync) requestMaxBlocksFrom(bestBlockHeader *types.Header, origin blockOrigin) error { //nolint:unparam + startRequestAt := bestBlockHeader.Number + 1 + + // targetBlockNumber is the virtual target we will request, however + // we should bound it to the real target which is collected through + // block announces received from other peers + targetBlockNumber := startRequestAt + maxRequestsAllowed*128 + realTarget := cs.peerViewSet.getTarget() + + if targetBlockNumber > realTarget { + targetBlockNumber = realTarget + } + + requests := messages.NewAscendingBlockRequests(startRequestAt, targetBlockNumber, + messages.BootstrapRequestData) + + var expectedAmountOfBlocks uint32 + for _, request := range requests { + if request.Max != nil { + expectedAmountOfBlocks += *request.Max + } + } + + resultsQueue, err := cs.submitRequests(requests) + if err != nil { + return err + } + err = cs.handleWorkersResults(resultsQueue, origin, startRequestAt, expectedAmountOfBlocks) + if err != nil { + return fmt.Errorf("while handling workers results: %w", err) + } + + return nil +} + +func (cs *chainSync) submitRequest( + request *messages.BlockRequestMessage, + who *peer.ID, + resultCh chan<- *syncTaskResult, +) error { + if !cs.blockState.IsPaused() { + cs.workerPool.submitRequest(request, who, resultCh) + return nil + } + return fmt.Errorf("submitting request: %w", errBlockStatePaused) +} + +func (cs *chainSync) submitRequests(requests []*messages.BlockRequestMessage) ( + resultCh chan *syncTaskResult, err error) { + if !cs.blockState.IsPaused() { + return cs.workerPool.submitRequests(requests), nil + } + return nil, fmt.Errorf("submitting requests: %w", errBlockStatePaused) +} + +func (cs *chainSync) showSyncStats(syncBegin time.Time, syncedBlocks int) { + finalisedHeader, err := cs.blockState.GetHighestFinalisedHeader() + if err != nil { + logger.Criticalf("getting highest finalized header: %w", err) + return + } + + totalSyncAndImportSeconds := time.Since(syncBegin).Seconds() + bps := float64(syncedBlocks) / totalSyncAndImportSeconds + logger.Infof("⛓️ synced %d blocks, "+ + "took: %.2f seconds, bps: %.2f blocks/second", + syncedBlocks, totalSyncAndImportSeconds, bps) + + logger.Infof( + "🚣 currently syncing, %d peers connected, "+ + "%d available workers, "+ + "target block number %d, "+ + "finalised #%d (%s) "+ + "sync mode: %s", + len(cs.network.Peers()), + cs.workerPool.totalWorkers(), + cs.peerViewSet.getTarget(), + finalisedHeader.Number, + finalisedHeader.Hash().Short(), + cs.getSyncMode().String(), + ) +} + +// handleWorkersResults, every time we submit requests to workers they results should be computed here +// and every cicle we should endup with a complete chain, whenever we identify +// any error from a worker we should evaluate the error and re-insert the request +// in the queue and wait for it to completes +// TODO: handle only justification requests +func (cs *chainSync) handleWorkersResults( + workersResults chan *syncTaskResult, origin blockOrigin, startAtBlock uint, expectedSyncedBlocks uint32) error { + startTime := time.Now() + syncingChain := make([]*types.BlockData, expectedSyncedBlocks) + // the total numbers of blocks is missing in the syncing chain + waitingBlocks := expectedSyncedBlocks + +taskResultLoop: + for waitingBlocks > 0 { + // in a case where we don't handle workers results we should check the pool + idleDuration := time.Minute + idleTimer := time.NewTimer(idleDuration) + + select { + case <-cs.stopCh: + return nil + + case <-idleTimer.C: + logger.Warnf("idle ticker triggered! checking pool") + cs.workerPool.useConnectedPeers() + continue + + case taskResult := <-workersResults: + if !idleTimer.Stop() { + <-idleTimer.C + } + + who := taskResult.who + request := taskResult.request + response := taskResult.response + + logger.Debugf("task result: peer(%s), with error: %v, with response: %v", + taskResult.who, taskResult.err != nil, taskResult.response != nil) + + if taskResult.err != nil { + if !errors.Is(taskResult.err, network.ErrReceivedEmptyMessage) { + logger.Errorf("task result: peer(%s) error: %s", + taskResult.who, taskResult.err) + + if errors.Is(taskResult.err, messages.ErrNilBlockInResponse) { + cs.network.ReportPeer(peerset.ReputationChange{ + Value: peerset.BadMessageValue, + Reason: peerset.BadMessageReason, + }, who) + } + + if strings.Contains(taskResult.err.Error(), "protocols not supported") { + cs.network.ReportPeer(peerset.ReputationChange{ + Value: peerset.BadProtocolValue, + Reason: peerset.BadProtocolReason, + }, who) + } + } + + err := cs.submitRequest(request, nil, workersResults) + if err != nil { + return err + } + continue + } + + if request.Direction == messages.Descending { + // reverse blocks before pre-validating and placing in ready queue + reverseBlockData(response.BlockData) + } + + err := validateResponseFields(request.RequestedData, response.BlockData) + if err != nil { + logger.Criticalf("validating fields: %s", err) + // TODO: check the reputation change for nil body in response + // and nil justification in response + if errors.Is(err, errNilHeaderInResponse) { + cs.network.ReportPeer(peerset.ReputationChange{ + Value: peerset.IncompleteHeaderValue, + Reason: peerset.IncompleteHeaderReason, + }, who) + } + + err = cs.submitRequest(taskResult.request, nil, workersResults) + if err != nil { + return err + } + continue taskResultLoop + } + + isChain := isResponseAChain(response.BlockData) + if !isChain { + logger.Criticalf("response from %s is not a chain", who) + err = cs.submitRequest(taskResult.request, nil, workersResults) + if err != nil { + return err + } + continue taskResultLoop + } + + grows := doResponseGrowsTheChain(response.BlockData, syncingChain, + startAtBlock, expectedSyncedBlocks) + if !grows { + logger.Criticalf("response from %s does not grows the ongoing chain", who) + err = cs.submitRequest(taskResult.request, nil, workersResults) + if err != nil { + return err + } + continue taskResultLoop + } + + for _, blockInResponse := range response.BlockData { + if slices.Contains(cs.badBlocks, blockInResponse.Hash.String()) { + logger.Criticalf("%s sent a known bad block: %s (#%d)", + who, blockInResponse.Hash.String(), blockInResponse.Number()) + + cs.network.ReportPeer(peerset.ReputationChange{ + Value: peerset.BadBlockAnnouncementValue, + Reason: peerset.BadBlockAnnouncementReason, + }, who) + + cs.workerPool.ignorePeerAsWorker(taskResult.who) + err = cs.submitRequest(taskResult.request, nil, workersResults) + if err != nil { + return err + } + continue taskResultLoop + } + + blockExactIndex := blockInResponse.Header.Number - startAtBlock + if blockExactIndex < uint(expectedSyncedBlocks) { + syncingChain[blockExactIndex] = blockInResponse + } + } + + // we need to check if we've filled all positions + // otherwise we should wait for more responses + waitingBlocks -= uint32(len(response.BlockData)) + + // we received a response without the desired amount of blocks + // we should include a new request to retrieve the missing blocks + if len(response.BlockData) < int(*request.Max) { + difference := uint32(int(*request.Max) - len(response.BlockData)) + lastItem := response.BlockData[len(response.BlockData)-1] + + startRequestNumber := uint32(lastItem.Header.Number + 1) + startAt, err := variadic.NewUint32OrHash(startRequestNumber) + if err != nil { + panic(err) + } + + taskResult.request = &messages.BlockRequestMessage{ + RequestedData: messages.BootstrapRequestData, + StartingBlock: *startAt, + Direction: messages.Ascending, + Max: &difference, + } + err = cs.submitRequest(taskResult.request, nil, workersResults) + if err != nil { + return err + } + continue taskResultLoop + } + } + } + + retreiveBlocksSeconds := time.Since(startTime).Seconds() + logger.Infof("🔽 retrieved %d blocks, took: %.2f seconds, starting process...", + expectedSyncedBlocks, retreiveBlocksSeconds) + + // response was validated! place into ready block queue + for _, bd := range syncingChain { + // block is ready to be processed! + if err := cs.handleReadyBlock(bd, origin); err != nil { + return fmt.Errorf("while handling ready block: %w", err) + } + } + + cs.showSyncStats(startTime, len(syncingChain)) + return nil +} + +func (cs *chainSync) handleReadyBlock(bd *types.BlockData, origin blockOrigin) error { + // if header was not requested, get it from the pending set + // if we're expecting headers, validate should ensure we have a header + if bd.Header == nil { + block := cs.pendingBlocks.getBlock(bd.Hash) + if block == nil { + // block wasn't in the pending set! + // let's check the db as maybe we already processed it + has, err := cs.blockState.HasHeader(bd.Hash) + if err != nil && !errors.Is(err, database.ErrNotFound) { + logger.Debugf("failed to check if header is known for hash %s: %s", bd.Hash, err) + return err + } + + if has { + logger.Tracef("ignoring block we've already processed, hash=%s", bd.Hash) + return err + } + + // this is bad and shouldn't happen + logger.Errorf("block with unknown header is ready: hash=%s", bd.Hash) + return err + } + + if block.header == nil { + logger.Errorf("new ready block number (unknown) with hash %s", bd.Hash) + return nil + } + + bd.Header = block.header + } + + err := cs.processBlockData(*bd, origin) + if err != nil { + // depending on the error, we might want to save this block for later + logger.Errorf("block data processing for block with hash %s failed: %s", bd.Hash, err) + return err + } + + cs.pendingBlocks.removeBlock(bd.Hash) + return nil +} + +// processBlockData processes the BlockData from a BlockResponse and +// returns the index of the last BlockData it handled on success, +// or the index of the block data that errored on failure. +// TODO: https://github.com/ChainSafe/gossamer/issues/3468 +func (cs *chainSync) processBlockData(blockData types.BlockData, origin blockOrigin) error { + // while in bootstrap mode we don't need to broadcast block announcements + announceImportedBlock := cs.getSyncMode() == tip + + if blockData.Header != nil { + var ( + hasJustification = blockData.Justification != nil && len(*blockData.Justification) > 0 + round uint64 + setID uint64 + ) + + if hasJustification { + var err error + round, setID, err = cs.finalityGadget.VerifyBlockJustification( + blockData.Header.Hash(), blockData.Header.Number, *blockData.Justification) + if err != nil { + return fmt.Errorf("verifying justification: %w", err) + } + } + + if blockData.Body != nil { + err := cs.processBlockDataWithHeaderAndBody(blockData, origin, announceImportedBlock) + if err != nil { + return fmt.Errorf("processing block data with header and body: %w", err) + } + } + + if hasJustification { + header := blockData.Header + err := cs.blockState.SetFinalisedHash(header.Hash(), round, setID) + if err != nil { + return fmt.Errorf("setting finalised hash: %w", err) + } + err = cs.blockState.SetJustification(header.Hash(), *blockData.Justification) + if err != nil { + return fmt.Errorf("setting justification for block number %d: %w", header.Number, err) + } + + return nil + } + } + + err := cs.blockState.CompareAndSetBlockData(&blockData) + if err != nil { + return fmt.Errorf("comparing and setting block data: %w", err) + } + + return nil +} + +func (cs *chainSync) processBlockDataWithHeaderAndBody(blockData types.BlockData, + origin blockOrigin, announceImportedBlock bool) (err error) { + + if origin != networkInitialSync { + err = cs.babeVerifier.VerifyBlock(blockData.Header) + if err != nil { + return fmt.Errorf("babe verifying block: %w", err) + } + } + + cs.handleBody(blockData.Body) + + block := &types.Block{ + Header: *blockData.Header, + Body: *blockData.Body, + } + + err = cs.handleBlock(block, announceImportedBlock) + if err != nil { + return fmt.Errorf("handling block: %w", err) + } + + return nil +} + +// handleHeader handles block bodies included in BlockResponses +func (cs *chainSync) handleBody(body *types.Body) { + acc := 0 + for _, ext := range *body { + acc += len(ext) + cs.transactionState.RemoveExtrinsic(ext) + } + + blockSizeGauge.Set(float64(acc)) +} + +// handleHeader handles blocks (header+body) included in BlockResponses +func (cs *chainSync) handleBlock(block *types.Block, announceImportedBlock bool) error { + parent, err := cs.blockState.GetHeader(block.Header.ParentHash) + if err != nil { + return fmt.Errorf("%w: %s", errFailedToGetParent, err) + } + + cs.storageState.Lock() + defer cs.storageState.Unlock() + + ts, err := cs.storageState.TrieState(&parent.StateRoot) + if err != nil { + return err + } + + root := ts.Trie().MustHash() + if !bytes.Equal(parent.StateRoot[:], root[:]) { + panic("parent state root does not match snapshot state root") + } + + rt, err := cs.blockState.GetRuntime(parent.Hash()) + if err != nil { + return err + } + + rt.SetContextStorage(ts) + + _, err = rt.ExecuteBlock(block) + if err != nil { + return fmt.Errorf("failed to execute block %d: %w", block.Header.Number, err) + } + + if err = cs.blockImportHandler.HandleBlockImport(block, ts, announceImportedBlock); err != nil { + return err + } + + blockHash := block.Header.Hash() + cs.telemetry.SendMessage(telemetry.NewBlockImport( + &blockHash, + block.Header.Number, + "NetworkInitialSync")) + + return nil +} + +// validateResponseFields checks that the expected fields are in the block data +func validateResponseFields(requestedData byte, blocks []*types.BlockData) error { + for _, bd := range blocks { + if bd == nil { + return errNilBlockData + } + + if (requestedData&messages.RequestedDataHeader) == messages.RequestedDataHeader && bd.Header == nil { + return fmt.Errorf("%w: %s", errNilHeaderInResponse, bd.Hash) + } + + if (requestedData&messages.RequestedDataBody) == messages.RequestedDataBody && bd.Body == nil { + return fmt.Errorf("%w: %s", errNilBodyInResponse, bd.Hash) + } + + // if we requested strictly justification + if (requestedData|messages.RequestedDataJustification) == messages.RequestedDataJustification && + bd.Justification == nil { + return fmt.Errorf("%w: %s", errNilJustificationInResponse, bd.Hash) + } + } + + return nil +} + +func isResponseAChain(responseBlockData []*types.BlockData) bool { + if len(responseBlockData) < 2 { + return true + } + + previousBlockData := responseBlockData[0] + for _, currBlockData := range responseBlockData[1:] { + previousHash := previousBlockData.Header.Hash() + isParent := previousHash == currBlockData.Header.ParentHash + if !isParent { + return false + } + + previousBlockData = currBlockData + } + + return true +} + +// doResponseGrowsTheChain will check if the acquired blocks grows the current chain +// matching their parent hashes +func doResponseGrowsTheChain(response, ongoingChain []*types.BlockData, startAtBlock uint, expectedTotal uint32) bool { + // the ongoing chain does not have any element, we can safely insert an item in it + if len(ongoingChain) < 1 { + return true + } + + compareParentHash := func(parent, child *types.BlockData) bool { + return parent.Header.Hash() == child.Header.ParentHash + } + + firstBlockInResponse := response[0] + firstBlockExactIndex := firstBlockInResponse.Header.Number - startAtBlock + if firstBlockExactIndex != 0 && firstBlockExactIndex < uint(expectedTotal) { + leftElement := ongoingChain[firstBlockExactIndex-1] + if leftElement != nil && !compareParentHash(leftElement, firstBlockInResponse) { + return false + } + } + + switch { + // if the response contains only one block then we should check both sides + // for example, if the response contains only one block called X we should + // check if its parent hash matches with the left element as well as we should + // check if the right element contains X hash as its parent hash + // ... W <- X -> Y ... + // we can skip left side comparison if X is in the 0 index and we can skip + // right side comparison if X is in the last index + case len(response) == 1: + if uint32(firstBlockExactIndex+1) < expectedTotal { + rightElement := ongoingChain[firstBlockExactIndex+1] + if rightElement != nil && !compareParentHash(firstBlockInResponse, rightElement) { + return false + } + } + // if the response contains more than 1 block then we need to compare + // only the start and the end of the acquired response, for example + // let's say we receive a response [C, D, E] and we need to check + // if those values fits correctly: + // ... B <- C D E -> F + // we skip the left check if its index is equals to 0 and we skip the right + // check if it ends in the latest position of the ongoing array + case len(response) > 1: + lastBlockInResponse := response[len(response)-1] + lastBlockExactIndex := lastBlockInResponse.Header.Number - startAtBlock + + if uint32(lastBlockExactIndex+1) < expectedTotal { + rightElement := ongoingChain[lastBlockExactIndex+1] + if rightElement != nil && !compareParentHash(lastBlockInResponse, rightElement) { + return false + } + } + } + + return true +} + +func (cs *chainSync) getHighestBlock() (highestBlock uint, err error) { + if cs.peerViewSet.size() == 0 { + return 0, errNoPeers + } + + for _, ps := range cs.peerViewSet.values() { + if ps.number < highestBlock { + continue + } + highestBlock = ps.number + } + + return highestBlock, nil +} diff --git a/dot/sync/chain_sync_test.go b/dot/sync/chain_sync_test.go new file mode 100644 index 0000000000..4af6deac79 --- /dev/null +++ b/dot/sync/chain_sync_test.go @@ -0,0 +1,1901 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package sync + +import ( + "errors" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/ChainSafe/gossamer/dot/network" + "github.com/ChainSafe/gossamer/dot/network/messages" + "github.com/ChainSafe/gossamer/dot/peerset" + "github.com/ChainSafe/gossamer/dot/telemetry" + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/common/variadic" + "github.com/ChainSafe/gossamer/lib/runtime/storage" + "github.com/ChainSafe/gossamer/pkg/trie" + inmemory_trie "github.com/ChainSafe/gossamer/pkg/trie/inmemory" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func Test_chainSyncState_String(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + s chainSyncState + want string + }{ + { + name: "case_bootstrap", + s: bootstrap, + want: "bootstrap", + }, + { + name: "case_tip", + s: tip, + want: "tip", + }, + { + name: "case_unknown", + s: 3, + want: "unknown", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := tt.s.String() + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_chainSync_onBlockAnnounce(t *testing.T) { + t.Parallel() + const somePeer = peer.ID("abc") + + errTest := errors.New("test error") + emptyTrieState := storage.NewTrieState(inmemory_trie.NewEmptyTrie()) + block1AnnounceHeader := types.NewHeader(common.Hash{}, emptyTrieState.Trie().MustHash(), + common.Hash{}, 1, nil) + block2AnnounceHeader := types.NewHeader(block1AnnounceHeader.Hash(), + emptyTrieState.Trie().MustHash(), + common.Hash{}, 2, nil) + + testCases := map[string]struct { + waitBootstrapSync bool + chainSyncBuilder func(ctrl *gomock.Controller) *chainSync + peerID peer.ID + blockAnnounceHeader *types.Header + errWrapped error + errMessage string + expectedSyncMode chainSyncState + }{ + "announced_block_already_exists_in_disjoint_set": { + chainSyncBuilder: func(ctrl *gomock.Controller) *chainSync { + pendingBlocks := NewMockDisjointBlockSet(ctrl) + pendingBlocks.EXPECT().hasBlock(block2AnnounceHeader.Hash()).Return(true) + return &chainSync{ + stopCh: make(chan struct{}), + pendingBlocks: pendingBlocks, + peerViewSet: newPeerViewSet(0), + workerPool: newSyncWorkerPool(NewMockNetwork(nil), NewMockRequestMaker(nil)), + } + }, + peerID: somePeer, + blockAnnounceHeader: block2AnnounceHeader, + errWrapped: errAlreadyInDisjointSet, + errMessage: fmt.Sprintf("already in disjoint set: block #%d (%s)", + block2AnnounceHeader.Number, block2AnnounceHeader.Hash()), + }, + "failed_to_add_announced_block_in_disjoint_set": { + chainSyncBuilder: func(ctrl *gomock.Controller) *chainSync { + pendingBlocks := NewMockDisjointBlockSet(ctrl) + pendingBlocks.EXPECT().hasBlock(block2AnnounceHeader.Hash()).Return(false) + pendingBlocks.EXPECT().addHeader(block2AnnounceHeader).Return(errTest) + + return &chainSync{ + stopCh: make(chan struct{}), + pendingBlocks: pendingBlocks, + peerViewSet: newPeerViewSet(0), + workerPool: newSyncWorkerPool(NewMockNetwork(nil), NewMockRequestMaker(nil)), + } + }, + peerID: somePeer, + blockAnnounceHeader: block2AnnounceHeader, + errWrapped: errTest, + errMessage: "while adding pending block header: test error", + }, + "announced_block_while_in_bootstrap_mode": { + chainSyncBuilder: func(ctrl *gomock.Controller) *chainSync { + pendingBlocks := NewMockDisjointBlockSet(ctrl) + pendingBlocks.EXPECT().hasBlock(block2AnnounceHeader.Hash()).Return(false) + pendingBlocks.EXPECT().addHeader(block2AnnounceHeader).Return(nil) + + state := atomic.Value{} + state.Store(bootstrap) + + return &chainSync{ + stopCh: make(chan struct{}), + pendingBlocks: pendingBlocks, + syncMode: state, + peerViewSet: newPeerViewSet(0), + workerPool: newSyncWorkerPool(NewMockNetwork(nil), NewMockRequestMaker(nil)), + } + }, + peerID: somePeer, + blockAnnounceHeader: block2AnnounceHeader, + }, + "announced_block_while_in_tip_mode": { + chainSyncBuilder: func(ctrl *gomock.Controller) *chainSync { + pendingBlocksMock := NewMockDisjointBlockSet(ctrl) + pendingBlocksMock.EXPECT().hasBlock(block2AnnounceHeader.Hash()).Return(false) + pendingBlocksMock.EXPECT().addHeader(block2AnnounceHeader).Return(nil) + pendingBlocksMock.EXPECT().removeBlock(block2AnnounceHeader.Hash()) + pendingBlocksMock.EXPECT().size().Return(0) + + blockStateMock := NewMockBlockState(ctrl) + blockStateMock.EXPECT(). + HasHeader(block2AnnounceHeader.Hash()). + Return(false, nil) + blockStateMock.EXPECT().IsPaused().Return(false) + + blockStateMock.EXPECT(). + BestBlockHeader(). + Return(block1AnnounceHeader, nil) + + blockStateMock.EXPECT(). + GetHighestFinalisedHeader(). + Return(block2AnnounceHeader, nil). + Times(2) + + expectedRequest := messages.NewBlockRequest(*variadic.MustNewUint32OrHash(block2AnnounceHeader.Hash()), + 1, messages.BootstrapRequestData, messages.Descending) + + fakeBlockBody := types.Body([]types.Extrinsic{}) + mockedBlockResponse := &messages.BlockResponseMessage{ + BlockData: []*types.BlockData{ + { + Hash: block2AnnounceHeader.Hash(), + Header: block2AnnounceHeader, + Body: &fakeBlockBody, + }, + }, + } + + networkMock := NewMockNetwork(ctrl) + networkMock.EXPECT().Peers().Return([]common.PeerInfo{}) + + requestMaker := NewMockRequestMaker(ctrl) + requestMaker.EXPECT(). + Do(somePeer, expectedRequest, &messages.BlockResponseMessage{}). + DoAndReturn(func(_, _, response any) any { + responsePtr := response.(*messages.BlockResponseMessage) + *responsePtr = *mockedBlockResponse + return nil + }) + + babeVerifierMock := NewMockBabeVerifier(ctrl) + storageStateMock := NewMockStorageState(ctrl) + importHandlerMock := NewMockBlockImportHandler(ctrl) + telemetryMock := NewMockTelemetry(ctrl) + + const announceBlock = true + ensureSuccessfulBlockImportFlow(t, block1AnnounceHeader, mockedBlockResponse.BlockData, + blockStateMock, babeVerifierMock, storageStateMock, importHandlerMock, telemetryMock, + networkBroadcast, announceBlock) + + workerPool := newSyncWorkerPool(networkMock, requestMaker) + // include the peer who announced the block in the pool + workerPool.newPeer(somePeer) + + state := atomic.Value{} + state.Store(tip) + + return &chainSync{ + stopCh: make(chan struct{}), + pendingBlocks: pendingBlocksMock, + syncMode: state, + workerPool: workerPool, + network: networkMock, + blockState: blockStateMock, + babeVerifier: babeVerifierMock, + telemetry: telemetryMock, + storageState: storageStateMock, + blockImportHandler: importHandlerMock, + peerViewSet: newPeerViewSet(0), + } + }, + peerID: somePeer, + blockAnnounceHeader: block2AnnounceHeader, + }, + } + + for name, tt := range testCases { + tt := tt + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + chainSync := tt.chainSyncBuilder(ctrl) + err := chainSync.onBlockAnnounce(announcedBlock{ + who: tt.peerID, + header: tt.blockAnnounceHeader, + }) + + assert.ErrorIs(t, err, tt.errWrapped) + if tt.errWrapped != nil { + assert.EqualError(t, err, tt.errMessage) + } + + if tt.waitBootstrapSync { + chainSync.wg.Wait() + err = chainSync.workerPool.stop() + require.NoError(t, err) + } + }) + } +} + +func Test_chainSync_onBlockAnnounceHandshake_tipModeNeedToCatchup(t *testing.T) { + ctrl := gomock.NewController(t) + const somePeer = peer.ID("abc") + + emptyTrieState := storage.NewTrieState(inmemory_trie.NewEmptyTrie()) + block1AnnounceHeader := types.NewHeader(common.Hash{}, emptyTrieState.Trie().MustHash(), + common.Hash{}, 1, nil) + block2AnnounceHeader := types.NewHeader(block1AnnounceHeader.Hash(), + emptyTrieState.Trie().MustHash(), + common.Hash{}, 130, nil) + + blockStateMock := NewMockBlockState(ctrl) + blockStateMock.EXPECT(). + BestBlockHeader(). + Return(block1AnnounceHeader, nil). + Times(2) + + blockStateMock.EXPECT(). + BestBlockHeader(). + Return(block2AnnounceHeader, nil). + Times(1) + + blockStateMock.EXPECT(). + GetHighestFinalisedHeader(). + Return(block1AnnounceHeader, nil). + Times(3) + + blockStateMock.EXPECT().IsPaused().Return(false).Times(2) + + expectedRequest := messages.NewAscendingBlockRequests( + block1AnnounceHeader.Number+1, + block2AnnounceHeader.Number, messages.BootstrapRequestData) + + networkMock := NewMockNetwork(ctrl) + networkMock.EXPECT().Peers().Return([]common.PeerInfo{}). + Times(2) + networkMock.EXPECT().AllConnectedPeersIDs().Return([]peer.ID{}).Times(2) + + firstMockedResponse := createSuccesfullBlockResponse(t, block1AnnounceHeader.Hash(), 2, 128) + latestItemFromMockedResponse := firstMockedResponse.BlockData[len(firstMockedResponse.BlockData)-1] + + secondMockedResponse := createSuccesfullBlockResponse(t, latestItemFromMockedResponse.Hash, + int(latestItemFromMockedResponse.Header.Number+1), 1) + + requestMaker := NewMockRequestMaker(ctrl) + requestMaker.EXPECT(). + Do(somePeer, expectedRequest[0], &messages.BlockResponseMessage{}). + DoAndReturn(func(_, _, response any) any { + responsePtr := response.(*messages.BlockResponseMessage) + *responsePtr = *firstMockedResponse + return nil + }).Times(2) + + requestMaker.EXPECT(). + Do(somePeer, expectedRequest[1], &messages.BlockResponseMessage{}). + DoAndReturn(func(_, _, response any) any { + responsePtr := response.(*messages.BlockResponseMessage) + *responsePtr = *secondMockedResponse + return nil + }).Times(2) + + babeVerifierMock := NewMockBabeVerifier(ctrl) + storageStateMock := NewMockStorageState(ctrl) + importHandlerMock := NewMockBlockImportHandler(ctrl) + telemetryMock := NewMockTelemetry(ctrl) + + const announceBlock = false + ensureSuccessfulBlockImportFlow(t, block1AnnounceHeader, firstMockedResponse.BlockData, + blockStateMock, babeVerifierMock, storageStateMock, importHandlerMock, telemetryMock, + networkInitialSync, announceBlock) + ensureSuccessfulBlockImportFlow(t, latestItemFromMockedResponse.Header, secondMockedResponse.BlockData, + blockStateMock, babeVerifierMock, storageStateMock, importHandlerMock, telemetryMock, + networkInitialSync, announceBlock) + + state := atomic.Value{} + state.Store(tip) + + stopCh := make(chan struct{}) + defer close(stopCh) + + chainSync := &chainSync{ + stopCh: stopCh, + peerViewSet: newPeerViewSet(10), + syncMode: state, + pendingBlocks: newDisjointBlockSet(0), + workerPool: newSyncWorkerPool(networkMock, requestMaker), + network: networkMock, + blockState: blockStateMock, + babeVerifier: babeVerifierMock, + telemetry: telemetryMock, + storageState: storageStateMock, + blockImportHandler: importHandlerMock, + } + + err := chainSync.onBlockAnnounceHandshake(somePeer, block2AnnounceHeader.Hash(), block2AnnounceHeader.Number) + require.NoError(t, err) + + chainSync.wg.Wait() + err = chainSync.workerPool.stop() + require.NoError(t, err) + + require.Equal(t, chainSync.getSyncMode(), tip) +} + +func TestChainSync_onBlockAnnounceHandshake_onBootstrapMode(t *testing.T) { + const randomHashString = "0x580d77a9136035a0bc3c3cd86286172f7f81291164c5914266073a30466fba21" + randomHash := common.MustHexToHash(randomHashString) + + testcases := map[string]struct { + newChainSync func(t *testing.T, ctrl *gomock.Controller) *chainSync + peerID peer.ID + bestHash common.Hash + bestNumber uint + shouldBeAWorker bool + workerStatus byte + }{ + "new_peer": { + newChainSync: func(t *testing.T, ctrl *gomock.Controller) *chainSync { + networkMock := NewMockNetwork(ctrl) + workerPool := newSyncWorkerPool(networkMock, NewMockRequestMaker(nil)) + + cs := newChainSyncTest(t, ctrl) + cs.syncMode.Store(bootstrap) + cs.workerPool = workerPool + return cs + }, + peerID: peer.ID("peer-test"), + bestHash: randomHash, + bestNumber: uint(20), + shouldBeAWorker: true, + workerStatus: available, + }, + "ignore_peer_should_not_be_included_in_the_workerpoll": { + newChainSync: func(t *testing.T, ctrl *gomock.Controller) *chainSync { + networkMock := NewMockNetwork(ctrl) + workerPool := newSyncWorkerPool(networkMock, NewMockRequestMaker(nil)) + workerPool.ignorePeers = map[peer.ID]struct{}{ + peer.ID("peer-test"): {}, + } + + cs := newChainSyncTest(t, ctrl) + cs.syncMode.Store(bootstrap) + cs.workerPool = workerPool + return cs + }, + peerID: peer.ID("peer-test"), + bestHash: randomHash, + bestNumber: uint(20), + shouldBeAWorker: false, + }, + "peer_already_exists_in_the_pool": { + newChainSync: func(t *testing.T, ctrl *gomock.Controller) *chainSync { + networkMock := NewMockNetwork(ctrl) + workerPool := newSyncWorkerPool(networkMock, NewMockRequestMaker(nil)) + workerPool.workers = map[peer.ID]*syncWorker{ + peer.ID("peer-test"): { + worker: &worker{status: available}, + }, + } + + cs := newChainSyncTest(t, ctrl) + cs.syncMode.Store(bootstrap) + cs.workerPool = workerPool + return cs + }, + peerID: peer.ID("peer-test"), + bestHash: randomHash, + bestNumber: uint(20), + shouldBeAWorker: true, + workerStatus: available, + }, + } + + for tname, tt := range testcases { + tt := tt + t.Run(tname, func(t *testing.T) { + ctrl := gomock.NewController(t) + cs := tt.newChainSync(t, ctrl) + cs.onBlockAnnounceHandshake(tt.peerID, tt.bestHash, tt.bestNumber) + + view, exists := cs.peerViewSet.find(tt.peerID) + require.True(t, exists) + require.Equal(t, tt.peerID, view.who) + require.Equal(t, tt.bestHash, view.hash) + require.Equal(t, tt.bestNumber, view.number) + + if tt.shouldBeAWorker { + syncWorker, exists := cs.workerPool.workers[tt.peerID] + require.True(t, exists) + require.Equal(t, tt.workerStatus, syncWorker.worker.status) + } else { + _, exists := cs.workerPool.workers[tt.peerID] + require.False(t, exists) + } + }) + } +} + +func newChainSyncTest(t *testing.T, ctrl *gomock.Controller) *chainSync { + t.Helper() + + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + + cfg := chainSyncConfig{ + bs: mockBlockState, + pendingBlocks: newDisjointBlockSet(pendingBlocksLimit), + minPeers: 1, + maxPeers: 5, + slotDuration: 6 * time.Second, + } + + return newChainSync(cfg) +} + +func setupChainSyncToBootstrapMode(t *testing.T, blocksAhead uint, + bs BlockState, net Network, reqMaker network.RequestMaker, babeVerifier BabeVerifier, + storageState StorageState, blockImportHandler BlockImportHandler, telemetry Telemetry) *chainSync { + t.Helper() + mockedPeerID := []peer.ID{ + peer.ID("some_peer_1"), + peer.ID("some_peer_2"), + peer.ID("some_peer_3"), + } + + peerViewMap := map[peer.ID]peerView{} + for _, p := range mockedPeerID { + peerViewMap[p] = peerView{ + who: p, + hash: common.Hash{1, 2, 3}, + number: blocksAhead, + } + } + + cfg := chainSyncConfig{ + pendingBlocks: newDisjointBlockSet(pendingBlocksLimit), + minPeers: 1, + maxPeers: 5, + slotDuration: 6 * time.Second, + bs: bs, + net: net, + requestMaker: reqMaker, + babeVerifier: babeVerifier, + storageState: storageState, + blockImportHandler: blockImportHandler, + telemetry: telemetry, + } + + chainSync := newChainSync(cfg) + chainSync.peerViewSet = &peerViewSet{view: peerViewMap} + chainSync.syncMode.Store(bootstrap) + + return chainSync +} + +func TestChainSync_BootstrapSync_SuccessfulSync_WithOneWorker(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + const blocksAhead = 128 + totalBlockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, blocksAhead) + mockedNetwork := NewMockNetwork(ctrl) + + workerPeerID := peer.ID("noot") + startingBlock := variadic.MustNewUint32OrHash(1) + max := uint32(128) + + mockedRequestMaker := NewMockRequestMaker(ctrl) + + expectedBlockRequestMessage := &messages.BlockRequestMessage{ + RequestedData: messages.BootstrapRequestData, + StartingBlock: *startingBlock, + Direction: messages.Ascending, + Max: &max, + } + + mockedRequestMaker.EXPECT(). + Do(workerPeerID, expectedBlockRequestMessage, &messages.BlockResponseMessage{}). + DoAndReturn(func(_, _, response any) any { + responsePtr := response.(*messages.BlockResponseMessage) + *responsePtr = *totalBlockResponse + return nil + }) + + mockedBlockState := NewMockBlockState(ctrl) + mockedBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockedBlockState.EXPECT().IsPaused().Return(false) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + mockedBlockState.EXPECT().GetHighestFinalisedHeader().Return(types.NewEmptyHeader(), nil).Times(1) + mockedNetwork.EXPECT().Peers().Return([]common.PeerInfo{}).Times(1) + + const announceBlock = false + // setup mocks for new synced blocks that doesn't exists in our local database + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, totalBlockResponse.BlockData, mockedBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block X as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by X blocks, we should execute a bootstrap + // sync request those blocks + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockedBlockState, mockedNetwork, mockedRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(128), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("noot")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) +} + +func TestChainSync_BootstrapSync_SuccessfulSync_WithTwoWorkers(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + mockBlockState.EXPECT().GetHighestFinalisedHeader().Return(types.NewEmptyHeader(), nil).Times(1) + mockBlockState.EXPECT().IsPaused().Return(false) + mockNetwork.EXPECT().Peers().Return([]common.PeerInfo{}).Times(1) + + // this test expects two workers responding each request with 128 blocks which means + // we should import 256 blocks in total + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 256) + + // here we split the whole set in two parts each one will be the "response" for each peer + worker1Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:128], + } + const announceBlock = false + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + worker2Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[128:], + } + // the worker 2 will respond from block 129 to 256 so the ensureBlockImportFlow + // will setup the expectations starting from block 128, from previous worker, until block 256 + parent := worker1Response.BlockData[len(worker1Response.BlockData)-1] + ensureSuccessfulBlockImportFlow(t, parent.Header, worker2Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + // we use gomock.Any since I cannot guarantee which peer picks which request + // but the first call to DoBlockRequest will return the first set and the second + // call will return the second set + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(_, _, response any) any { + responsePtr := response.(*messages.BlockResponseMessage) + *responsePtr = *worker1Response + return nil + }) + + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(_, _, response any) any { + responsePtr := response.(*messages.BlockResponseMessage) + *responsePtr = *worker2Response + return nil + }) + + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block 129 as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by 128 blocks, we should execute a bootstrap + // sync request those blocks + const blocksAhead = 256 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("noot")) + cs.workerPool.fromBlockAnnounce(peer.ID("noot2")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) +} + +func TestChainSync_BootstrapSync_SuccessfulSync_WithOneWorkerFailing(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockBlockState.EXPECT().IsPaused().Return(false).Times(2) + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + mockBlockState.EXPECT().GetHighestFinalisedHeader().Return(types.NewEmptyHeader(), nil).Times(1) + + mockNetwork.EXPECT().Peers().Return([]common.PeerInfo{}).Times(1) + + // this test expects two workers responding each request with 128 blocks which means + // we should import 256 blocks in total + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 256) + const announceBlock = false + + // here we split the whole set in two parts each one will be the "response" for each peer + worker1Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:128], + } + + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + worker2Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[128:], + } + // the worker 2 will respond from block 129 to 256 so the ensureBlockImportFlow + // will setup the expectations starting from block 128, from previous worker, until block 256 + parent := worker1Response.BlockData[len(worker1Response.BlockData)-1] + ensureSuccessfulBlockImportFlow(t, parent.Header, worker2Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + // we use gomock.Any since I cannot guarantee which peer picks which request + // but the first call to DoBlockRequest will return the first set and the second + // call will return the second set + doBlockRequestCount := atomic.Int32{} + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + // lets ensure that the DoBlockRequest is called by + // peer.ID(alice) and peer.ID(bob). When bob calls, this method will fail + // then alice should pick the failed request and re-execute it which will + // be the third call + responsePtr := response.(*messages.BlockResponseMessage) + defer func() { doBlockRequestCount.Add(1) }() + + switch doBlockRequestCount.Load() { + case 0: + *responsePtr = *worker1Response + case 1: + return errors.New("a bad error while getting a response") + default: + *responsePtr = *worker2Response + } + return nil + + }).Times(3) + + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block 129 as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by 128 blocks, we should execute a bootstrap + // sync request those blocks + const blocksAhead = 256 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + cs.workerPool.fromBlockAnnounce(peer.ID("bob")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) +} + +func TestChainSync_BootstrapSync_SuccessfulSync_WithProtocolNotSupported(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockBlockState.EXPECT().IsPaused().Return(false).Times(2) + mockBlockState.EXPECT(). + GetHighestFinalisedHeader(). + Return(types.NewEmptyHeader(), nil). + Times(1) + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockNetwork.EXPECT().Peers().Return([]common.PeerInfo{}) + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + // this test expects two workers responding each request with 128 blocks which means + // we should import 256 blocks in total + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 256) + const announceBlock = false + + // here we split the whole set in two parts each one will be the "response" for each peer + worker1Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:128], + } + + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + worker2Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[128:], + } + // the worker 2 will respond from block 129 to 256 so the ensureBlockImportFlow + // will setup the expectations starting from block 128, from previous worker, until block 256 + parent := worker1Response.BlockData[len(worker1Response.BlockData)-1] + ensureSuccessfulBlockImportFlow(t, parent.Header, worker2Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + // we use gomock.Any since I cannot guarantee which peer picks which request + // but the first call to DoBlockRequest will return the first set and the second + // call will return the second set + doBlockRequestCount := atomic.Int32{} + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + // lets ensure that the DoBlockRequest is called by + // peer.ID(alice) and peer.ID(bob). When bob calls, this method will fail + // then alice should pick the failed request and re-execute it which will + // be the third call + responsePtr := response.(*messages.BlockResponseMessage) + defer func() { doBlockRequestCount.Add(1) }() + + switch doBlockRequestCount.Load() { + case 0: + *responsePtr = *worker1Response + case 1: + return errors.New("protocols not supported") + default: + *responsePtr = *worker2Response + } + + return nil + }).Times(3) + + // since some peer will fail with protocols not supported his + // reputation will be affected and + mockNetwork.EXPECT().ReportPeer(peerset.ReputationChange{ + Value: peerset.BadProtocolValue, + Reason: peerset.BadProtocolReason, + }, gomock.AssignableToTypeOf(peer.ID(""))) + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block 129 as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by 128 blocks, we should execute a bootstrap + // sync request those blocks + const blocksAhead = 256 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + cs.workerPool.fromBlockAnnounce(peer.ID("bob")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) +} + +func TestChainSync_BootstrapSync_SuccessfulSync_WithNilHeaderInResponse(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockBlockState.EXPECT().IsPaused().Return(false).Times(2) + mockBlockState.EXPECT(). + GetHighestFinalisedHeader(). + Return(types.NewEmptyHeader(), nil). + Times(1) + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockNetwork.EXPECT().Peers().Return([]common.PeerInfo{}) + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + // this test expects two workers responding each request with 128 blocks which means + // we should import 256 blocks in total + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 256) + const announceBlock = false + + // here we split the whole set in two parts each one will be the "response" for each peer + worker1Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:128], + } + + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + worker2Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[128:], + } + // the worker 2 will respond from block 129 to 256 so the ensureBlockImportFlow + // will setup the expectations starting from block 128, from previous worker, until block 256 + parent := worker1Response.BlockData[127] + ensureSuccessfulBlockImportFlow(t, parent.Header, worker2Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + // we use gomock.Any since I cannot guarantee which peer picks which request + // but the first call to DoBlockRequest will return the first set and the second + // call will return the second set + doBlockRequestCount := atomic.Int32{} + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + // lets ensure that the DoBlockRequest is called by + // peer.ID(alice) and peer.ID(bob). When bob calls, this method return an + // response item but without header as was requested + responsePtr := response.(*messages.BlockResponseMessage) + defer func() { doBlockRequestCount.Add(1) }() + + switch doBlockRequestCount.Load() { + case 0: + *responsePtr = *worker1Response + case 1: + incompleteBlockData := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 128, 256) + incompleteBlockData.BlockData[0].Header = nil + + *responsePtr = *incompleteBlockData + default: + *responsePtr = *worker2Response + } + + return nil + }).Times(3) + + // since some peer will fail with protocols not supported his + // reputation will be affected and + mockNetwork.EXPECT().ReportPeer(peerset.ReputationChange{ + Value: peerset.IncompleteHeaderValue, + Reason: peerset.IncompleteHeaderReason, + }, gomock.AssignableToTypeOf(peer.ID(""))) + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block 129 as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by 128 blocks, we should execute a bootstrap + // sync request those blocks + const blocksAhead = 256 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + cs.workerPool.fromBlockAnnounce(peer.ID("bob")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) +} + +func TestChainSync_BootstrapSync_SuccessfulSync_WithNilBlockInResponse(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockBlockState.EXPECT().IsPaused().Return(false).Times(2) + mockBlockState.EXPECT(). + GetHighestFinalisedHeader(). + Return(types.NewEmptyHeader(), nil). + Times(1) + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 128) + const announceBlock = false + + workerResponse := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData, + } + + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, workerResponse.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + doBlockRequestCount := atomic.Int32{} + mockRequestMaker := NewMockRequestMaker(ctrl) + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + // lets ensure that the DoBlockRequest is called by + // peer.ID(alice) and peer.ID(bob). When bob calls, this method return an + // response item but without header as was requested + responsePtr := response.(*messages.BlockResponseMessage) + defer func() { doBlockRequestCount.Add(1) }() + + switch doBlockRequestCount.Load() { + case 0: + return messages.ErrNilBlockInResponse + case 1: + *responsePtr = *workerResponse + } + + return nil + }).Times(2) + + mockNetwork := NewMockNetwork(ctrl) + mockNetwork.EXPECT().Peers().Return([]common.PeerInfo{}) + + // reputation will be affected and + mockNetwork.EXPECT().ReportPeer(peerset.ReputationChange{ + Value: peerset.BadMessageValue, + Reason: peerset.BadMessageReason, + }, gomock.AssignableToTypeOf(peer.ID(""))) + + const blocksAhead = 128 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + cs.workerPool.fromBlockAnnounce(peer.ID("bob")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) +} + +func TestChainSync_BootstrapSync_SuccessfulSync_WithResponseIsNotAChain(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockBlockState.EXPECT().IsPaused().Return(false).Times(2) + mockBlockState.EXPECT(). + GetHighestFinalisedHeader(). + Return(types.NewEmptyHeader(), nil). + Times(1) + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockNetwork.EXPECT().Peers().Return([]common.PeerInfo{}) + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + // this test expects two workers responding each request with 128 blocks which means + // we should import 256 blocks in total + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 256) + const announceBlock = false + + // here we split the whole set in two parts each one will be the "response" for each peer + worker1Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:128], + } + + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + worker2Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[128:], + } + // the worker 2 will respond from block 129 to 256 so the ensureBlockImportFlow + // will setup the expectations starting from block 128, from previous worker, until block 256 + parent := worker1Response.BlockData[127] + ensureSuccessfulBlockImportFlow(t, parent.Header, worker2Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + // we use gomock.Any since I cannot guarantee which peer picks which request + // but the first call to DoBlockRequest will return the first set and the second + // call will return the second set + doBlockRequestCount := atomic.Int32{} + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + // lets ensure that the DoBlockRequest is called by + // peer.ID(alice) and peer.ID(bob). When bob calls, this method return an + // response that does not form an chain + responsePtr := response.(*messages.BlockResponseMessage) + defer func() { doBlockRequestCount.Add(1) }() + + switch doBlockRequestCount.Load() { + case 0: + *responsePtr = *worker1Response + case 1: + notAChainBlockData := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 128, 256) + // swap positions to force the problem + notAChainBlockData.BlockData[0], notAChainBlockData.BlockData[130] = + notAChainBlockData.BlockData[130], notAChainBlockData.BlockData[0] + + *responsePtr = *notAChainBlockData + default: + *responsePtr = *worker2Response + } + + return nil + }).Times(3) + + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block 129 as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by 128 blocks, we should execute a bootstrap + // sync request those blocks + const blocksAhead = 256 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + target := cs.peerViewSet.getTarget() + + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + cs.workerPool.fromBlockAnnounce(peer.ID("bob")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) +} + +func TestChainSync_BootstrapSync_SuccessfulSync_WithReceivedBadBlock(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockBlockState.EXPECT().IsPaused().Return(false).Times(2) + mockBlockState.EXPECT(). + GetHighestFinalisedHeader(). + Return(types.NewEmptyHeader(), nil). + Times(1) + + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockNetwork.EXPECT().Peers().Return([]common.PeerInfo{}) + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + // this test expects two workers responding each request with 128 blocks which means + // we should import 256 blocks in total + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 256) + const announceBlock = false + + // here we split the whole set in two parts each one will be the "response" for each peer + worker1Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:128], + } + + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + worker2Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[128:], + } + // the worker 2 will respond from block 129 to 256 so the ensureBlockImportFlow + // will setup the expectations starting from block 128, from previous worker, until block 256 + parent := worker1Response.BlockData[len(worker1Response.BlockData)-1] + ensureSuccessfulBlockImportFlow(t, parent.Header, worker2Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + fakeBadBlockHash := common.MustHexToHash("0x18767cb4bb4cc13bf119f6613aec5487d4c06a2e453de53d34aea6f3f1ee9855") + + // we use gomock.Any since I cannot guarantee which peer picks which request + // but the first call to DoBlockRequest will return the first set and the second + // call will return the second set + doBlockRequestCount := atomic.Int32{} + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + // lets ensure that the DoBlockRequest is called by + // peer.ID(alice) and peer.ID(bob). When bob calls, this method return an + // response that contains a know bad block + responsePtr := response.(*messages.BlockResponseMessage) + defer func() { doBlockRequestCount.Add(1) }() + + switch doBlockRequestCount.Load() { + case 0: + *responsePtr = *worker1Response + case 1: + // use the fisrt response last item hash to produce the second response block data + // so we can guarantee that the second response continues the first response blocks + firstResponseLastItem := worker1Response.BlockData[len(worker1Response.BlockData)-1] + blockDataWithBadBlock := createSuccesfullBlockResponse(t, + firstResponseLastItem.Header.Hash(), + 129, + 128) + + // changes the last item from the second response to be a bad block, so we guarantee that + // this second response is a chain, (changing the hash from a block in the middle of the block + // response brokes the `isAChain` verification) + lastItem := len(blockDataWithBadBlock.BlockData) - 1 + blockDataWithBadBlock.BlockData[lastItem].Hash = fakeBadBlockHash + *responsePtr = *blockDataWithBadBlock + default: + *responsePtr = *worker2Response + } + + return nil + }).Times(3) + + mockNetwork.EXPECT().ReportPeer(peerset.ReputationChange{ + Value: peerset.BadBlockAnnouncementValue, + Reason: peerset.BadBlockAnnouncementReason, + }, gomock.AssignableToTypeOf(peer.ID(""))) + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block 129 as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by 128 blocks, we should execute a bootstrap + // sync request those blocks + const blocksAhead = 256 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + cs.badBlocks = []string{fakeBadBlockHash.String()} + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + cs.workerPool.fromBlockAnnounce(peer.ID("bob")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) + + // peer should be not in the worker pool + // peer should be in the ignore list + require.Len(t, cs.workerPool.workers, 1) + require.Len(t, cs.workerPool.ignorePeers, 1) +} + +func TestChainSync_BootstrapSync_SucessfulSync_ReceivedPartialBlockData(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockBlockState.EXPECT().IsPaused().Return(false).Times(2) + mockBlockState.EXPECT(). + GetHighestFinalisedHeader(). + Return(types.NewEmptyHeader(), nil). + Times(1) + + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockNetwork.EXPECT().Peers().Return([]common.PeerInfo{}) + + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + + // create a set of 128 blocks + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 128) + const announceBlock = false + + // the worker will return a partial size of the set + worker1Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:97], + } + + // the first peer will respond the from the block 1 to 96 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 96 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + worker1MissingBlocksResponse := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[97:], + } + + // last item from the previous response + parent := worker1Response.BlockData[96] + ensureSuccessfulBlockImportFlow(t, parent.Header, worker1MissingBlocksResponse.BlockData, mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + doBlockRequestCount := 0 + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + // lets ensure that the DoBlockRequest is called by + // peer.ID(alice). The first call will return only 97 blocks + // the handler should issue another call to retrieve the missing blocks + responsePtr := response.(*messages.BlockResponseMessage) + defer func() { doBlockRequestCount++ }() + + if doBlockRequestCount == 0 { + *responsePtr = *worker1Response + } else { + *responsePtr = *worker1MissingBlocksResponse + } + + return nil + }).Times(2) + + const blocksAhead = 128 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(blocksAhead), target) + + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.NoError(t, err) + + err = cs.workerPool.stop() + require.NoError(t, err) + + require.Len(t, cs.workerPool.workers, 1) + + _, ok := cs.workerPool.workers[peer.ID("alice")] + require.True(t, ok) +} + +func createSuccesfullBlockResponse(t *testing.T, parentHeader common.Hash, + startingAt, numBlocks int) *messages.BlockResponseMessage { + t.Helper() + + response := new(messages.BlockResponseMessage) + response.BlockData = make([]*types.BlockData, numBlocks) + + emptyTrieState := storage.NewTrieState(inmemory_trie.NewEmptyTrie()) + tsRoot := emptyTrieState.Trie().MustHash() + + firstHeader := types.NewHeader(parentHeader, tsRoot, common.Hash{}, + uint(startingAt), nil) + response.BlockData[0] = &types.BlockData{ + Hash: firstHeader.Hash(), + Header: firstHeader, + Body: types.NewBody([]types.Extrinsic{}), + Justification: nil, + } + + parentHash := firstHeader.Hash() + for idx := 1; idx < numBlocks; idx++ { + blockNumber := idx + startingAt + header := types.NewHeader(parentHash, tsRoot, common.Hash{}, + uint(blockNumber), nil) + response.BlockData[idx] = &types.BlockData{ + Hash: header.Hash(), + Header: header, + Body: types.NewBody([]types.Extrinsic{}), + Justification: nil, + } + parentHash = header.Hash() + } + + return response +} + +// ensureSuccessfulBlockImportFlow will setup the expectations for method calls +// that happens while chain sync imports a block +func ensureSuccessfulBlockImportFlow(t *testing.T, parentHeader *types.Header, + blocksReceived []*types.BlockData, mockBlockState *MockBlockState, + mockBabeVerifier *MockBabeVerifier, mockStorageState *MockStorageState, + mockImportHandler *MockBlockImportHandler, mockTelemetry *MockTelemetry, origin blockOrigin, announceBlock bool) { + t.Helper() + + for idx, blockData := range blocksReceived { + if origin != networkInitialSync { + mockBabeVerifier.EXPECT().VerifyBlock(blockData.Header).Return(nil) + } + + var previousHeader *types.Header + if idx == 0 { + previousHeader = parentHeader + } else { + previousHeader = blocksReceived[idx-1].Header + } + + mockBlockState.EXPECT().GetHeader(blockData.Header.ParentHash).Return(previousHeader, nil).AnyTimes() + mockStorageState.EXPECT().Lock().AnyTimes() + mockStorageState.EXPECT().Unlock().AnyTimes() + + emptyTrieState := storage.NewTrieState(inmemory_trie.NewEmptyTrie()) + parentStateRoot := previousHeader.StateRoot + mockStorageState.EXPECT().TrieState(&parentStateRoot). + Return(emptyTrieState, nil).AnyTimes() + + ctrl := gomock.NewController(t) + mockRuntimeInstance := NewMockInstance(ctrl) + mockBlockState.EXPECT().GetRuntime(previousHeader.Hash()). + Return(mockRuntimeInstance, nil).AnyTimes() + + expectedBlock := &types.Block{ + Header: *blockData.Header, + Body: *blockData.Body, + } + + mockRuntimeInstance.EXPECT().SetContextStorage(emptyTrieState).AnyTimes() + mockRuntimeInstance.EXPECT().ExecuteBlock(expectedBlock). + Return(nil, nil).AnyTimes() + + mockImportHandler.EXPECT().HandleBlockImport(expectedBlock, emptyTrieState, announceBlock). + Return(nil).AnyTimes() + + blockHash := blockData.Header.Hash() + expectedTelemetryMessage := telemetry.NewBlockImport( + &blockHash, + blockData.Header.Number, + "NetworkInitialSync") + mockTelemetry.EXPECT().SendMessage(expectedTelemetryMessage).AnyTimes() + mockBlockState.EXPECT().CompareAndSetBlockData(blockData).Return(nil).AnyTimes() + } +} + +func TestChainSync_validateResponseFields(t *testing.T) { + t.Parallel() + + block1Header := &types.Header{ + ParentHash: common.MustHexToHash("0x00597cb4bb4cc13bf119f6613aec7642d4c06a2e453de53d34aea6f3f1eeb504"), + Number: 2, + } + + block2Header := &types.Header{ + ParentHash: block1Header.Hash(), + Number: 3, + } + + cases := map[string]struct { + wantErr error + errString string + setupChainSync func(t *testing.T) *chainSync + requestedData byte + blockData *types.BlockData + }{ + "requested_bootstrap_data_but_got_nil_header": { + wantErr: errNilHeaderInResponse, + errString: "expected header, received none: " + + block2Header.Hash().String(), + requestedData: messages.BootstrapRequestData, + blockData: &types.BlockData{ + Hash: block2Header.Hash(), + Header: nil, + Body: &types.Body{}, + Justification: &[]byte{0}, + }, + setupChainSync: func(t *testing.T) *chainSync { + ctrl := gomock.NewController(t) + blockStateMock := NewMockBlockState(ctrl) + blockStateMock.EXPECT().HasHeader(block1Header.ParentHash).Return(true, nil) + + networkMock := NewMockNetwork(ctrl) + networkMock.EXPECT().ReportPeer(peerset.ReputationChange{ + Value: peerset.IncompleteHeaderValue, + Reason: peerset.IncompleteHeaderReason, + }, peer.ID("peer")) + + return &chainSync{ + blockState: blockStateMock, + network: networkMock, + } + }, + }, + "requested_bootstrap_data_but_got_nil_body": { + wantErr: errNilBodyInResponse, + errString: "expected body, received none: " + + block2Header.Hash().String(), + requestedData: messages.BootstrapRequestData, + blockData: &types.BlockData{ + Hash: block2Header.Hash(), + Header: block2Header, + Body: nil, + Justification: &[]byte{0}, + }, + setupChainSync: func(t *testing.T) *chainSync { + ctrl := gomock.NewController(t) + blockStateMock := NewMockBlockState(ctrl) + blockStateMock.EXPECT().HasHeader(block1Header.ParentHash).Return(true, nil) + networkMock := NewMockNetwork(ctrl) + + return &chainSync{ + blockState: blockStateMock, + network: networkMock, + } + }, + }, + "requested_only_justification_but_got_nil": { + wantErr: errNilJustificationInResponse, + errString: "expected justification, received none: " + + block2Header.Hash().String(), + requestedData: messages.RequestedDataJustification, + blockData: &types.BlockData{ + Hash: block2Header.Hash(), + Header: block2Header, + Body: nil, + Justification: nil, + }, + setupChainSync: func(t *testing.T) *chainSync { + ctrl := gomock.NewController(t) + blockStateMock := NewMockBlockState(ctrl) + blockStateMock.EXPECT().HasHeader(block1Header.ParentHash).Return(true, nil) + networkMock := NewMockNetwork(ctrl) + + return &chainSync{ + blockState: blockStateMock, + network: networkMock, + } + }, + }, + } + + for tname, tt := range cases { + tt := tt + t.Run(tname, func(t *testing.T) { + t.Parallel() + + err := validateResponseFields(tt.requestedData, []*types.BlockData{tt.blockData}) + require.ErrorIs(t, err, tt.wantErr) + if tt.errString != "" { + require.EqualError(t, err, tt.errString) + } + }) + } +} + +func TestChainSync_isResponseAChain(t *testing.T) { + t.Parallel() + + block1Header := &types.Header{ + ParentHash: common.MustHexToHash("0x00597cb4bb4cc13bf119f6613aec7642d4c06a2e453de53d34aea6f3f1eeb504"), + Number: 2, + } + + block2Header := &types.Header{ + ParentHash: block1Header.Hash(), + Number: 3, + } + + block4Header := &types.Header{ + ParentHash: common.MustHexToHash("0x198616547187613bf119f6613aec7642d4c06a2e453de53d34aea6f390788677"), + Number: 4, + } + + cases := map[string]struct { + expected bool + blockData []*types.BlockData + }{ + "not_a_chain": { + expected: false, + blockData: []*types.BlockData{ + { + Hash: block1Header.Hash(), + Header: block1Header, + Body: &types.Body{}, + Justification: &[]byte{0}, + }, + { + Hash: block2Header.Hash(), + Header: block2Header, + Body: &types.Body{}, + Justification: &[]byte{0}, + }, + { + Hash: block4Header.Hash(), + Header: block4Header, + Body: &types.Body{}, + Justification: &[]byte{0}, + }, + }, + }, + "is_a_chain": { + expected: true, + blockData: []*types.BlockData{ + { + Hash: block1Header.Hash(), + Header: block1Header, + Body: &types.Body{}, + Justification: &[]byte{0}, + }, + { + Hash: block2Header.Hash(), + Header: block2Header, + Body: &types.Body{}, + Justification: &[]byte{0}, + }, + }, + }, + } + + for tname, tt := range cases { + tt := tt + t.Run(tname, func(t *testing.T) { + t.Parallel() + output := isResponseAChain(tt.blockData) + require.Equal(t, tt.expected, output) + }) + } +} + +func TestChainSync_doResponseGrowsTheChain(t *testing.T) { + block1Header := types.NewHeader(common.Hash{}, common.Hash{}, common.Hash{}, 1, types.NewDigest()) + block2Header := types.NewHeader(block1Header.Hash(), common.Hash{}, common.Hash{}, 2, types.NewDigest()) + block3Header := types.NewHeader(block2Header.Hash(), common.Hash{}, common.Hash{}, 3, types.NewDigest()) + block4Header := types.NewHeader(block3Header.Hash(), common.Hash{}, common.Hash{}, 4, types.NewDigest()) + + testcases := map[string]struct { + response []*types.BlockData + ongoingChain []*types.BlockData + startAt uint + exepectedTotal uint32 + expectedOut bool + }{ + // the ongoing chain does not have any data so the response + // can be inserted in the ongoing chain without any problems + "empty_ongoing_chain": { + ongoingChain: []*types.BlockData{}, + expectedOut: true, + }, + + "one_in_response_growing_ongoing_chain_without_check": { + startAt: 1, + exepectedTotal: 3, + // the ongoing chain contains 3 positions, the block number 1 is at position 0 + ongoingChain: []*types.BlockData{ + {Header: types.NewHeader(common.Hash{}, common.Hash{}, common.Hash{}, 1, types.NewDigest())}, + nil, + nil, + }, + + // the response contains the block number 3 which should be placed in position 2 + // in the ongoing chain, which means that no comparison should be done to place + // block number 3 in the ongoing chain + response: []*types.BlockData{ + {Header: types.NewHeader(common.Hash{}, common.Hash{}, common.Hash{}, 3, types.NewDigest())}, + }, + expectedOut: true, + }, + + "one_in_response_growing_ongoing_chain_by_checking_neighbours": { + startAt: 1, + exepectedTotal: 3, + // the ongoing chain contains 3 positions, the block number 1 is at position 0 + ongoingChain: []*types.BlockData{ + {Header: block1Header}, + nil, + {Header: block3Header}, + }, + + // the response contains the block number 2 which should be placed in position 1 + // in the ongoing chain, which means that a comparison should be made to check + // if the parent hash of block 2 is the same hash of block 1 + response: []*types.BlockData{ + {Header: block2Header}, + }, + expectedOut: true, + }, + + "one_in_response_failed_to_grow_ongoing_chain": { + startAt: 1, + exepectedTotal: 3, + ongoingChain: []*types.BlockData{ + {Header: block1Header}, + nil, + nil, + }, + response: []*types.BlockData{ + {Header: types.NewHeader(common.Hash{}, common.Hash{}, common.Hash{}, 2, types.NewDigest())}, + }, + expectedOut: false, + }, + + "many_in_response_grow_ongoing_chain_only_left_check": { + startAt: 1, + exepectedTotal: 3, + ongoingChain: []*types.BlockData{ + {Header: block1Header}, + nil, + nil, + nil, + }, + response: []*types.BlockData{ + {Header: block2Header}, + {Header: block3Header}, + }, + expectedOut: true, + }, + + "many_in_response_grow_ongoing_chain_left_right_check": { + startAt: 1, + exepectedTotal: 3, + ongoingChain: []*types.BlockData{ + {Header: block1Header}, + nil, + nil, + {Header: block4Header}, + }, + response: []*types.BlockData{ + {Header: block2Header}, + {Header: block3Header}, + }, + expectedOut: true, + }, + } + + for tname, tt := range testcases { + tt := tt + + t.Run(tname, func(t *testing.T) { + out := doResponseGrowsTheChain(tt.response, tt.ongoingChain, tt.startAt, tt.exepectedTotal) + require.Equal(t, tt.expectedOut, out) + }) + } +} + +func TestChainSync_getHighestBlock(t *testing.T) { + t.Parallel() + + cases := map[string]struct { + expectedHighestBlock uint + wantErr error + chainSyncPeerViewSet *peerViewSet + }{ + "no_peer_view": { + wantErr: errNoPeers, + expectedHighestBlock: 0, + chainSyncPeerViewSet: newPeerViewSet(10), + }, + "highest_block": { + expectedHighestBlock: 500, + chainSyncPeerViewSet: &peerViewSet{ + view: map[peer.ID]peerView{ + peer.ID("peer-A"): { + number: 100, + }, + peer.ID("peer-B"): { + number: 500, + }, + }, + }, + }, + } + + for tname, tt := range cases { + tt := tt + t.Run(tname, func(t *testing.T) { + t.Parallel() + + chainSync := &chainSync{ + peerViewSet: tt.chainSyncPeerViewSet, + } + + highestBlock, err := chainSync.getHighestBlock() + require.ErrorIs(t, err, tt.wantErr) + require.Equal(t, tt.expectedHighestBlock, highestBlock) + }) + } +} +func TestChainSync_BootstrapSync_SuccessfulSync_WithInvalidJusticationBlock(t *testing.T) { + // TODO: https://github.com/ChainSafe/gossamer/issues/3468 + t.Skip() + t.Parallel() + + ctrl := gomock.NewController(t) + mockBlockState := NewMockBlockState(ctrl) + mockBlockState.EXPECT().GetFinalisedNotifierChannel().Return(make(chan *types.FinalisationInfo)) + mockedGenesisHeader := types.NewHeader(common.NewHash([]byte{0}), trie.EmptyHash, + trie.EmptyHash, 0, types.NewDigest()) + + mockNetwork := NewMockNetwork(ctrl) + mockRequestMaker := NewMockRequestMaker(ctrl) + + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockStorageState := NewMockStorageState(ctrl) + mockImportHandler := NewMockBlockImportHandler(ctrl) + mockTelemetry := NewMockTelemetry(ctrl) + mockFinalityGadget := NewMockFinalityGadget(ctrl) + + // this test expects two workers responding each request with 128 blocks which means + // we should import 256 blocks in total + blockResponse := createSuccesfullBlockResponse(t, mockedGenesisHeader.Hash(), 1, 129) + const announceBlock = false + + invalidJustificationBlock := blockResponse.BlockData[90] + invalidJustification := &[]byte{0x01, 0x01, 0x01, 0x02} + invalidJustificationBlock.Justification = invalidJustification + + // here we split the whole set in two parts each one will be the "response" for each peer + worker1Response := &messages.BlockResponseMessage{ + BlockData: blockResponse.BlockData[:128], + } + + // the first peer will respond the from the block 1 to 128 so the ensureBlockImportFlow + // will setup the expectations starting from the genesis header until block 128 + ensureSuccessfulBlockImportFlow(t, mockedGenesisHeader, worker1Response.BlockData[:90], mockBlockState, + mockBabeVerifier, mockStorageState, mockImportHandler, mockTelemetry, networkInitialSync, announceBlock) + + errVerifyBlockJustification := errors.New("VerifyBlockJustification mock error") + mockFinalityGadget.EXPECT(). + VerifyBlockJustification( + invalidJustificationBlock.Header.Hash(), + invalidJustificationBlock.Header.Number, + *invalidJustification). + Return(uint64(0), uint64(0), errVerifyBlockJustification) + + // we use gomock.Any since I cannot guarantee which peer picks which request + // but the first call to DoBlockRequest will return the first set and the second + // call will return the second set + mockRequestMaker.EXPECT(). + Do(gomock.Any(), gomock.Any(), &messages.BlockResponseMessage{}). + DoAndReturn(func(peerID, _, response any) any { + responsePtr := response.(*messages.BlockResponseMessage) + *responsePtr = *worker1Response + + fmt.Println("mocked request maker") + return nil + }) + + // setup a chain sync which holds in its peer view map + // 3 peers, each one announce block 129 as its best block number. + // We start this test with genesis block being our best block, so + // we're far behind by 128 blocks, we should execute a bootstrap + // sync request those blocks + const blocksAhead = 128 + cs := setupChainSyncToBootstrapMode(t, blocksAhead, + mockBlockState, mockNetwork, mockRequestMaker, mockBabeVerifier, + mockStorageState, mockImportHandler, mockTelemetry) + + cs.finalityGadget = mockFinalityGadget + + target := cs.peerViewSet.getTarget() + require.Equal(t, uint(blocksAhead), target) + + // include a new worker in the worker pool set, this worker + // should be an available peer that will receive a block request + // the worker pool executes the workers management + cs.workerPool.fromBlockAnnounce(peer.ID("alice")) + //cs.workerPool.fromBlockAnnounce(peer.ID("bob")) + + err := cs.requestMaxBlocksFrom(mockedGenesisHeader, networkInitialSync) + require.ErrorIs(t, err, errVerifyBlockJustification) + + err = cs.workerPool.stop() + require.NoError(t, err) + + // peer should be not in the worker pool + // peer should be in the ignore list + require.Len(t, cs.workerPool.workers, 1) +} diff --git a/dot/sync/interfaces.go b/dot/sync/interfaces.go new file mode 100644 index 0000000000..03a03cda8e --- /dev/null +++ b/dot/sync/interfaces.go @@ -0,0 +1,90 @@ +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package sync + +import ( + "encoding/json" + "sync" + + "github.com/ChainSafe/gossamer/dot/peerset" + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/runtime" + rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" + "github.com/libp2p/go-libp2p/core/peer" +) + +// BlockState is the interface for the block state +type BlockState interface { + BestBlockHeader() (*types.Header, error) + BestBlockNumber() (number uint, err error) + CompareAndSetBlockData(bd *types.BlockData) error + GetBlockBody(common.Hash) (*types.Body, error) + GetHeader(common.Hash) (*types.Header, error) + HasHeader(hash common.Hash) (bool, error) + Range(startHash, endHash common.Hash) (hashes []common.Hash, err error) + RangeInMemory(start, end common.Hash) ([]common.Hash, error) + GetReceipt(common.Hash) ([]byte, error) + GetMessageQueue(common.Hash) ([]byte, error) + GetJustification(common.Hash) ([]byte, error) + SetFinalisedHash(hash common.Hash, round uint64, setID uint64) error + SetJustification(hash common.Hash, data []byte) error + GetHashByNumber(blockNumber uint) (common.Hash, error) + GetBlockByHash(common.Hash) (*types.Block, error) + GetRuntime(blockHash common.Hash) (runtime runtime.Instance, err error) + StoreRuntime(blockHash common.Hash, runtime runtime.Instance) + GetHighestFinalisedHeader() (*types.Header, error) + GetFinalisedNotifierChannel() chan *types.FinalisationInfo + GetHeaderByNumber(num uint) (*types.Header, error) + GetAllBlocksAtNumber(num uint) ([]common.Hash, error) + IsDescendantOf(parent, child common.Hash) (bool, error) + + IsPaused() bool + Pause() error +} + +// StorageState is the interface for the storage state +type StorageState interface { + TrieState(root *common.Hash) (*rtstorage.TrieState, error) + sync.Locker +} + +// TransactionState is the interface for transaction queue methods +type TransactionState interface { + RemoveExtrinsic(ext types.Extrinsic) +} + +// BabeVerifier deals with BABE block verification +type BabeVerifier interface { + VerifyBlock(header *types.Header) error +} + +// FinalityGadget implements justification verification functionality +type FinalityGadget interface { + VerifyBlockJustification(finalizedHash common.Hash, finalizedNumber uint, encoded []byte) ( + round uint64, setID uint64, err error) +} + +// BlockImportHandler is the interface for the handler of newly imported blocks +type BlockImportHandler interface { + HandleBlockImport(block *types.Block, state *rtstorage.TrieState, announce bool) error +} + +// Network is the interface for the network +type Network interface { + // Peers returns a list of currently connected peers + Peers() []common.PeerInfo + + // ReportPeer reports peer based on the peer behaviour. + ReportPeer(change peerset.ReputationChange, p peer.ID) + + AllConnectedPeersIDs() []peer.ID + + BlockAnnounceHandshake(*types.Header) error +} + +// Telemetry is the telemetry client to send telemetry messages. +type Telemetry interface { + SendMessage(msg json.Marshaler) +} diff --git a/dot/sync/mocks_test.go b/dot/sync/mocks_test.go index e006ce4493..6ad35f501c 100644 --- a/dot/sync/mocks_test.go +++ b/dot/sync/mocks_test.go @@ -393,6 +393,20 @@ func (mr *MockBlockStateMockRecorder) RangeInMemory(arg0, arg1 any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RangeInMemory", reflect.TypeOf((*MockBlockState)(nil).RangeInMemory), arg0, arg1) } +// SetFinalisedHash mocks base method. +func (m *MockBlockState) SetFinalisedHash(arg0 common.Hash, arg1, arg2 uint64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetFinalisedHash", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetFinalisedHash indicates an expected call of SetFinalisedHash. +func (mr *MockBlockStateMockRecorder) SetFinalisedHash(arg0, arg1, arg2 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFinalisedHash", reflect.TypeOf((*MockBlockState)(nil).SetFinalisedHash), arg0, arg1, arg2) +} + // SetJustification mocks base method. func (m *MockBlockState) SetJustification(arg0 common.Hash, arg1 []byte) error { m.ctrl.T.Helper() @@ -577,17 +591,19 @@ func (m *MockFinalityGadget) EXPECT() *MockFinalityGadgetMockRecorder { } // VerifyBlockJustification mocks base method. -func (m *MockFinalityGadget) VerifyBlockJustification(arg0 common.Hash, arg1 []byte) error { +func (m *MockFinalityGadget) VerifyBlockJustification(arg0 common.Hash, arg1 uint, arg2 []byte) (uint64, uint64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "VerifyBlockJustification", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "VerifyBlockJustification", arg0, arg1, arg2) + ret0, _ := ret[0].(uint64) + ret1, _ := ret[1].(uint64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } // VerifyBlockJustification indicates an expected call of VerifyBlockJustification. -func (mr *MockFinalityGadgetMockRecorder) VerifyBlockJustification(arg0, arg1 any) *gomock.Call { +func (mr *MockFinalityGadgetMockRecorder) VerifyBlockJustification(arg0, arg1, arg2 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyBlockJustification", reflect.TypeOf((*MockFinalityGadget)(nil).VerifyBlockJustification), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyBlockJustification", reflect.TypeOf((*MockFinalityGadget)(nil).VerifyBlockJustification), arg0, arg1, arg2) } // MockBlockImportHandler is a mock of BlockImportHandler interface. diff --git a/dot/sync/syncer_integration_test.go b/dot/sync/syncer_integration_test.go new file mode 100644 index 0000000000..7361a5280e --- /dev/null +++ b/dot/sync/syncer_integration_test.go @@ -0,0 +1,213 @@ +//go:build integration + +// Copyright 2021 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package sync + +import ( + "errors" + "path/filepath" + "testing" + + "github.com/ChainSafe/gossamer/dot/state" + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/internal/database" + "github.com/ChainSafe/gossamer/internal/log" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/genesis" + runtime "github.com/ChainSafe/gossamer/lib/runtime" + rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" + wazero_runtime "github.com/ChainSafe/gossamer/lib/runtime/wazero" + "github.com/ChainSafe/gossamer/lib/utils" + "github.com/ChainSafe/gossamer/pkg/trie" + "github.com/ChainSafe/gossamer/tests/utils/config" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func newTestSyncer(t *testing.T) *Service { + ctrl := gomock.NewController(t) + + mockTelemetryClient := NewMockTelemetry(ctrl) + mockTelemetryClient.EXPECT().SendMessage(gomock.Any()).AnyTimes() + + wazero_runtime.DefaultTestLogLvl = log.Warn + + cfg := &Config{} + testDatadirPath := t.TempDir() + + scfg := state.Config{ + Path: testDatadirPath, + LogLevel: log.Info, + Telemetry: mockTelemetryClient, + GenesisBABEConfig: config.BABEConfigurationTestDefault, + } + stateSrvc := state.NewService(scfg) + stateSrvc.UseMemDB() + + gen, genTrie, genHeader := newWestendDevGenesisWithTrieAndHeader(t) + err := stateSrvc.Initialise(&gen, &genHeader, genTrie) + require.NoError(t, err) + + err = stateSrvc.Start() + require.NoError(t, err) + + if cfg.BlockState == nil { + cfg.BlockState = stateSrvc.Block + } + + if cfg.StorageState == nil { + cfg.StorageState = stateSrvc.Storage + } + + // initialise runtime + genState := rtstorage.NewTrieState(genTrie) + + rtCfg := wazero_runtime.Config{ + Storage: genState, + LogLvl: log.Critical, + } + + if stateSrvc != nil { + rtCfg.NodeStorage.BaseDB = stateSrvc.Base + } else { + rtCfg.NodeStorage.BaseDB, err = database.LoadDatabase(filepath.Join(testDatadirPath, "offline_storage"), false) + require.NoError(t, err) + } + + rtCfg.CodeHash, err = cfg.StorageState.(*state.InmemoryStorageState).LoadCodeHash(nil) + require.NoError(t, err) + + instance, err := wazero_runtime.NewRuntimeFromGenesis(rtCfg) + require.NoError(t, err) + + bestBlockHash := cfg.BlockState.(*state.BlockState).BestBlockHash() + cfg.BlockState.(*state.BlockState).StoreRuntime(bestBlockHash, instance) + blockImportHandler := NewMockBlockImportHandler(ctrl) + blockImportHandler.EXPECT().HandleBlockImport(gomock.AssignableToTypeOf(&types.Block{}), + gomock.AssignableToTypeOf(&rtstorage.TrieState{}), false).DoAndReturn( + func(block *types.Block, ts *rtstorage.TrieState, _ bool) error { + // store updates state trie nodes in database + if err = stateSrvc.Storage.StoreTrie(ts, &block.Header); err != nil { + logger.Warnf("failed to store state trie for imported block %s: %s", block.Header.Hash(), err) + return err + } + + // store block in database + err = stateSrvc.Block.AddBlock(block) + require.NoError(t, err) + + stateSrvc.Block.StoreRuntime(block.Header.Hash(), instance) + logger.Debugf("imported block %s and stored state trie with root %s", + block.Header.Hash(), ts.Trie().MustHash()) + return nil + }).AnyTimes() + cfg.BlockImportHandler = blockImportHandler + + cfg.TransactionState = stateSrvc.Transaction + mockBabeVerifier := NewMockBabeVerifier(ctrl) + mockBabeVerifier.EXPECT().VerifyBlock(gomock.AssignableToTypeOf(&types.Header{})).AnyTimes() + cfg.BabeVerifier = mockBabeVerifier + cfg.LogLvl = log.Trace + mockFinalityGadget := NewMockFinalityGadget(ctrl) + mockFinalityGadget.EXPECT().VerifyBlockJustification(gomock.AssignableToTypeOf(common.Hash{}), + gomock.AssignableToTypeOf(uint(0)), gomock.AssignableToTypeOf([]byte{})). + DoAndReturn(func(hash common.Hash, justification []byte) error { + return nil + }).AnyTimes() + + cfg.FinalityGadget = mockFinalityGadget + cfg.Network = NewMockNetwork(ctrl) + cfg.Telemetry = mockTelemetryClient + cfg.RequestMaker = NewMockRequestMaker(ctrl) + syncer, err := NewService(cfg) + require.NoError(t, err) + return syncer +} + +func newWestendDevGenesisWithTrieAndHeader(t *testing.T) ( + gen genesis.Genesis, genesisTrie trie.Trie, genesisHeader types.Header) { + t.Helper() + + genesisPath := utils.GetWestendDevRawGenesisPath(t) + genesisPtr, err := genesis.NewGenesisFromJSONRaw(genesisPath) + require.NoError(t, err) + gen = *genesisPtr + + genesisTrie, err = runtime.NewTrieFromGenesis(gen) + require.NoError(t, err) + + parentHash := common.NewHash([]byte{0}) + stateRoot := genesisTrie.MustHash() + extrinsicRoot := trie.EmptyHash + const number = 0 + digest := types.NewDigest() + genesisHeaderPtr := types.NewHeader(parentHash, + stateRoot, extrinsicRoot, number, digest) + genesisHeader = *genesisHeaderPtr + + return gen, genesisTrie, genesisHeader +} + +func TestHighestBlock(t *testing.T) { + type input struct { + highestBlock uint + err error + } + type output struct { + highestBlock uint + } + type test struct { + name string + in input + out output + } + tests := []test{ + { + name: "when_*chainSync.getHighestBlock()_returns_0,_error_should_return_0", + in: input{ + highestBlock: 0, + err: errors.New("fake error"), + }, + out: output{ + highestBlock: 0, + }, + }, + { + name: "when_*chainSync.getHighestBlock()_returns_0,_nil_should_return_0", + in: input{ + highestBlock: 0, + err: nil, + }, + out: output{ + highestBlock: 0, + }, + }, + { + name: "when_*chainSync.getHighestBlock()_returns_50,_nil_should_return_50", + in: input{ + highestBlock: 50, + err: nil, + }, + out: output{ + highestBlock: 50, + }, + }, + } + for _, ts := range tests { + t.Run(ts.name, func(t *testing.T) { + s := newTestSyncer(t) + + ctrl := gomock.NewController(t) + chainSync := NewMockChainSync(ctrl) + chainSync.EXPECT().getHighestBlock().Return(ts.in.highestBlock, ts.in.err) + + s.chainSync = chainSync + + result := s.HighestBlock() + require.Equal(t, result, ts.out.highestBlock) + }) + } +} diff --git a/go.mod b/go.mod index 7d59b785d4..7e26919817 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/ethereum/go-ethereum v1.14.8 github.com/fatih/color v1.17.0 github.com/gammazero/deque v0.2.1 - github.com/go-playground/validator/v10 v10.22.0 + github.com/go-playground/validator/v10 v10.22.1 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 @@ -32,7 +32,7 @@ require ( github.com/minio/sha256-simd v1.0.1 github.com/multiformats/go-multiaddr v0.13.0 github.com/nanobox-io/golang-scribble v0.0.0-20190309225732-aa3e7c118975 - github.com/prometheus/client_golang v1.20.2 + github.com/prometheus/client_golang v1.20.3 github.com/prometheus/client_model v0.6.1 github.com/qdm12/gotree v0.2.0 github.com/spf13/cobra v1.8.1 @@ -40,10 +40,11 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tetratelabs/wazero v1.1.0 github.com/tidwall/btree v1.7.0 + github.com/tyler-smith/go-bip39 v1.1.0 go.uber.org/mock v0.4.0 - golang.org/x/crypto v0.26.0 + golang.org/x/crypto v0.27.0 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 - golang.org/x/term v0.23.0 + golang.org/x/term v0.24.0 google.golang.org/protobuf v1.34.2 ) @@ -208,8 +209,8 @@ require ( golang.org/x/mod v0.19.0 // indirect golang.org/x/net v0.27.0 // indirect golang.org/x/sync v0.8.0 // indirect - golang.org/x/sys v0.23.0 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect golang.org/x/tools v0.23.0 // indirect gonum.org/v1/gonum v0.15.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/go.sum b/go.sum index 1d94c1ebb9..1f2e48bc1e 100644 --- a/go.sum +++ b/go.sum @@ -190,8 +190,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4Bx7ia+JlgcnOao= -github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-playground/validator/v10 v10.22.1 h1:40JcKH+bBNGFczGuoBYgX4I6m/i27HYW8P9FDk5PbgA= +github.com/go-playground/validator/v10 v10.22.1/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= @@ -545,8 +545,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH github.com/polydawn/refmt v0.89.0 h1:ADJTApkvkeBZsN0tBTx8QjpD9JkmxbKp0cxfr9qszm4= github.com/polydawn/refmt v0.89.0/go.mod h1:/zvteZs/GwLtCgZ4BL6CBsk9IKIlexP43ObX9AxTqTw= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v1.20.2 h1:5ctymQzZlyOON1666svgwn3s6IKWgfbjsejTMiXIyjg= -github.com/prometheus/client_golang v1.20.2/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_golang v1.20.3 h1:oPksm4K8B+Vt35tUhw6GbSNSgVlVSBH0qELP/7u83l4= +github.com/prometheus/client_golang v1.20.3/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= @@ -665,6 +665,8 @@ github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce h1:fb190+cK2Xz/dvi9Hv8eCYJYvIGUTN2/KLq1pT6CjEc= github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4= +github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8= +github.com/tyler-smith/go-bip39 v1.1.0/go.mod h1:gUYDtqQw1JS3ZJ8UWVcGTGqqr6YIN3CWg+kkNaLt55U= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli v1.22.10/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= @@ -731,8 +733,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= @@ -838,8 +840,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= -golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -847,8 +849,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= -golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= +golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= +golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -859,8 +861,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= diff --git a/internal/client/consensus/grandpa/authorities.go b/internal/client/consensus/grandpa/authorities.go new file mode 100644 index 0000000000..261042fe9a --- /dev/null +++ b/internal/client/consensus/grandpa/authorities.go @@ -0,0 +1,10 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +// generic representation of hash and number tuple +type HashNumber[H, N any] struct { + Hash H + Number N +} diff --git a/internal/client/consensus/grandpa/justification.go b/internal/client/consensus/grandpa/justification.go new file mode 100644 index 0000000000..94fc50bb92 --- /dev/null +++ b/internal/client/consensus/grandpa/justification.go @@ -0,0 +1,302 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "errors" + "fmt" + "io" + "reflect" + + primitives "github.com/ChainSafe/gossamer/internal/primitives/consensus/grandpa" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" + "github.com/ChainSafe/gossamer/internal/primitives/runtime/generic" + grandpa "github.com/ChainSafe/gossamer/pkg/finality-grandpa" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +var ( + errInvalidAuthoritiesSet = errors.New("current state of blockchain has invalid authorities set") + errBadJustification = errors.New("bad justification for header") + errBlockNotDescendentOfBase = errors.New("block not descendent of base") +) + +// A GRANDPA justification for block finality, it includes a commit message and +// an ancestry proof including all headers routing all precommit target blocks +// to the commit target block. Due to the current voting strategy the precommit +// targets should be the same as the commit target, since honest voters don't +// vote past authority set change blocks. +// +// This is meant to be stored in the db and passed around the network to other +// nodes, and are used by syncing nodes to prove authority set handoffs. +type GrandpaJustification[Hash runtime.Hash, N runtime.Number] struct { + // The GRANDPA justification for block finality. + Justification primitives.GrandpaJustification[Hash, N] +} + +// Type used for decoding grandpa justifications (can pass in generic Header type) +type decodeGrandpaJustification[ + Hash runtime.Hash, + N runtime.Number, + Hasher runtime.Hasher[Hash], +] GrandpaJustification[Hash, N] + +func decodeJustification[ + Hash runtime.Hash, + N runtime.Number, + Hasher runtime.Hasher[Hash], +](encodedJustification []byte) (*GrandpaJustification[Hash, N], error) { + newJustificaiton := decodeGrandpaJustification[Hash, N, Hasher]{} + err := scale.Unmarshal(encodedJustification, &newJustificaiton) + if err != nil { + return nil, err + } + return newJustificaiton.GrandpaJustification(), nil +} + +func (dgj *decodeGrandpaJustification[H, N, Hasher]) UnmarshalSCALE(reader io.Reader) (err error) { + type roundCommitHeader struct { + Round uint64 + Commit primitives.Commit[H, N] + Headers []generic.Header[N, H, Hasher] + } + rch := roundCommitHeader{} + decoder := scale.NewDecoder(reader) + err = decoder.Decode(&rch) + if err != nil { + return + } + + dgj.Justification.Round = rch.Round + dgj.Justification.Commit = rch.Commit + dgj.Justification.VoteAncestries = make([]runtime.Header[N, H], len(rch.Headers)) + for i, header := range rch.Headers { + header := header + dgj.Justification.VoteAncestries[i] = &header + } + return +} + +func (dgj decodeGrandpaJustification[Hash, N, Hasher]) GrandpaJustification() *GrandpaJustification[Hash, N] { + return &GrandpaJustification[Hash, N]{ + Justification: primitives.GrandpaJustification[Hash, N]{ + Round: dgj.Justification.Round, + Commit: dgj.Justification.Commit, + VoteAncestries: dgj.Justification.VoteAncestries, + }, + } +} + +// DecodeGrandpaJustificationVerifyFinalizes will decode a GRANDPA justification and validate the commit and +// the votes' ancestry proofs finalize the given block. +func DecodeGrandpaJustificationVerifyFinalizes[ + Hash runtime.Hash, + N runtime.Number, + Hasher runtime.Hasher[Hash], +]( + encoded []byte, + finalizedTarget HashNumber[Hash, N], + setID uint64, + voters grandpa.VoterSet[string], +) (GrandpaJustification[Hash, N], error) { + justification, err := decodeJustification[Hash, N, Hasher](encoded) + if err != nil { + return GrandpaJustification[Hash, N]{}, fmt.Errorf("error decoding justification for header: %s", err) + } + + decodedTarget := HashNumber[Hash, N]{ + Hash: justification.Justification.Commit.TargetHash, + Number: justification.Justification.Commit.TargetNumber, + } + + if decodedTarget != finalizedTarget { + return GrandpaJustification[Hash, N]{}, fmt.Errorf("invalid commit target in grandpa justification") + } + + return *justification, justification.verifyWithVoterSet(setID, voters) +} + +// Verify will validate the commit and the votes' ancestry proofs. +func (j *GrandpaJustification[Hash, N]) Verify(setID uint64, authorities primitives.AuthorityList) error { + var weights []grandpa.IDWeight[string] + for _, authority := range authorities { + weight := grandpa.IDWeight[string]{ + ID: string(authority.AuthorityID.Bytes()), + Weight: uint64(authority.AuthorityWeight), + } + weights = append(weights, weight) + } + + voters := grandpa.NewVoterSet[string](weights) + if voters != nil { + err := j.verifyWithVoterSet(setID, *voters) + return err + } + return fmt.Errorf("%w", errInvalidAuthoritiesSet) +} + +// Validate the commit and the votes' ancestry proofs. +func (j *GrandpaJustification[Hash, N]) verifyWithVoterSet( + setID uint64, + voters grandpa.VoterSet[string], +) error { + ancestryChain := newAncestryChain[Hash, N](j.Justification.VoteAncestries) + signedPrecommits := make([]grandpa.SignedPrecommit[Hash, N, string, string], 0) + for _, pc := range j.Justification.Commit.Precommits { + signedPrecommits = append(signedPrecommits, grandpa.SignedPrecommit[Hash, N, string, string]{ + Precommit: pc.Precommit, + Signature: string(pc.Signature[:]), + ID: string(pc.ID.Bytes()), + }) + } + commitValidationResult, err := grandpa.ValidateCommit[Hash, N, string, string]( + grandpa.Commit[Hash, N, string, string]{ + TargetHash: j.Justification.Commit.TargetHash, + TargetNumber: j.Justification.Commit.TargetNumber, + Precommits: signedPrecommits, + }, + voters, + ancestryChain, + ) + if err != nil { + return fmt.Errorf("%w: invalid commit in grandpa justification", errBadJustification) + } + + if !commitValidationResult.Valid() { + return fmt.Errorf("%w: invalid commit in grandpa justification", errBadJustification) + } + + // we pick the precommit for the lowest block as the base that + // should serve as the root block for populating ancestry (i.e. + // collect all headers from all precommit blocks to the base) + precommits := j.Justification.Commit.Precommits + var minPrecommit *grandpa.SignedPrecommit[Hash, N, primitives.AuthoritySignature, primitives.AuthorityID] + if len(precommits) == 0 { + panic("can only fail if precommits is empty; commit has been validated above; " + + "valid commits must include precommits") + } + for _, precommit := range precommits { + currPrecommit := precommit + if minPrecommit == nil { + minPrecommit = &currPrecommit + } else if currPrecommit.Precommit.TargetNumber <= minPrecommit.Precommit.TargetNumber { + minPrecommit = &currPrecommit + } + } + + baseHash := minPrecommit.Precommit.TargetHash + visitedHashes := make(map[Hash]struct{}) + for _, signed := range precommits { + msg := grandpa.NewMessage(signed.Precommit) + isValidSignature := primitives.CheckMessageSignature[Hash, N]( + msg, + signed.ID, + signed.Signature, + primitives.RoundNumber(j.Justification.Round), + primitives.SetID(setID), + ) + + if !isValidSignature { + return fmt.Errorf("%w: invalid signature for precommit in grandpa justification", + errBadJustification) + } + + if baseHash == signed.Precommit.TargetHash { + continue + } + + route, err := ancestryChain.Ancestry(baseHash, signed.Precommit.TargetHash) + if err != nil { + return fmt.Errorf("%w: invalid precommit ancestry proof in grandpa justification", + errBadJustification) + } + + // ancestry starts from parent HashField but the precommit target HashField has been + // visited + visitedHashes[signed.Precommit.TargetHash] = struct{}{} + for _, hash := range route { + visitedHashes[hash] = struct{}{} + } + } + + ancestryHashes := make(map[Hash]struct{}) + for _, header := range j.Justification.VoteAncestries { + hash := header.Hash() + ancestryHashes[hash] = struct{}{} + } + + if len(visitedHashes) != len(ancestryHashes) { + return fmt.Errorf("%w: invalid precommit ancestries in grandpa justification with unused headers", + errBadJustification) + } + + // Check if maps are equal + if !reflect.DeepEqual(ancestryHashes, visitedHashes) { + return fmt.Errorf("%w: invalid precommit ancestries in grandpa justification with unused headers", + errBadJustification) + } + + return nil +} + +// Target is the target block NumberField and HashField that this justifications proves finality for +func (j *GrandpaJustification[Hash, N]) Target() HashNumber[Hash, N] { + return HashNumber[Hash, N]{ + Number: j.Justification.Commit.TargetNumber, + Hash: j.Justification.Commit.TargetHash, + } +} + +// ancestryChain a utility trait implementing `grandpa.Chain` using a given set of headers. +// This is useful when validating commits, using the given set of headers to +// verify a valid ancestry route to the target commit block. +type ancestryChain[Hash runtime.Hash, N runtime.Number] struct { + ancestry map[Hash]runtime.Header[N, Hash] +} + +func newAncestryChain[Hash runtime.Hash, N runtime.Number]( + headers []runtime.Header[N, Hash], +) ancestryChain[Hash, N] { + ancestry := make(map[Hash]runtime.Header[N, Hash]) + for _, header := range headers { + hash := header.Hash() + ancestry[hash] = header + } + return ancestryChain[Hash, N]{ + ancestry: ancestry, + } +} + +func (ac ancestryChain[Ordered, N]) Ancestry(base Ordered, block Ordered) ([]Ordered, error) { + route := make([]Ordered, 0) + currentHash := block + + for { + if currentHash == base { + break + } + + br, ok := ac.ancestry[currentHash] + if !ok { + return nil, fmt.Errorf("%w", errBlockNotDescendentOfBase) + } + block = br.ParentHash() + currentHash = block + route = append(route, currentHash) + } + + if len(route) != 0 { + route = route[:len(route)-1] + } + return route, nil +} + +func (ac ancestryChain[Ordered, N]) IsEqualOrDescendantOf(base Ordered, block Ordered) bool { + if base == block { + return true + } + + _, err := ac.Ancestry(base, block) + return err == nil +} diff --git a/internal/client/consensus/grandpa/justification_test.go b/internal/client/consensus/grandpa/justification_test.go new file mode 100644 index 0000000000..b5fbb84d75 --- /dev/null +++ b/internal/client/consensus/grandpa/justification_test.go @@ -0,0 +1,523 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "reflect" + "testing" + + primitives "github.com/ChainSafe/gossamer/internal/primitives/consensus/grandpa" + ced25519 "github.com/ChainSafe/gossamer/internal/primitives/core/ed25519" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/keyring/ed25519" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" + "github.com/ChainSafe/gossamer/internal/primitives/runtime/generic" + grandpa "github.com/ChainSafe/gossamer/pkg/finality-grandpa" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func makePrecommit(t *testing.T, + targetHash string, + targetNumber uint64, + round uint64, //nolint:unparam + setID uint64, + voter ed25519.Keyring, +) grandpa.SignedPrecommit[hash.H256, uint64, primitives.AuthoritySignature, primitives.AuthorityID] { + t.Helper() + + precommit := grandpa.Precommit[hash.H256, uint64]{ + TargetHash: hash.H256(targetHash), + TargetNumber: targetNumber, + } + msg := grandpa.NewMessage(precommit) + encoded := primitives.NewLocalizedPayload(primitives.RoundNumber(round), primitives.SetID(setID), msg) + signature := voter.Sign(encoded) + + return grandpa.SignedPrecommit[hash.H256, uint64, primitives.AuthoritySignature, primitives.AuthorityID]{ + Precommit: grandpa.Precommit[hash.H256, uint64]{ + TargetHash: hash.H256(targetHash), + TargetNumber: targetNumber, + }, + Signature: signature, + ID: voter.Pair().Public().(ced25519.Public), + } +} + +func TestJustificationEncoding(t *testing.T) { + var hashA = "a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" //nolint:lll + var precommits []grandpa.SignedPrecommit[hash.H256, uint64, primitives.AuthoritySignature, primitives.AuthorityID] + precommit := makePrecommit(t, hashA, 1, 1, 1, ed25519.Alice) + precommits = append(precommits, precommit) + + expAncestries := make([]runtime.Header[uint64, hash.H256], 0) + expAncestries = append(expAncestries, generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 100, + hash.H256(""), + hash.H256(""), + hash.H256(hashA), + runtime.Digest{}), + ) + + expected := primitives.GrandpaJustification[hash.H256, uint64]{ + Round: 2, + Commit: primitives.Commit[hash.H256, uint64]{ + TargetHash: hash.H256( + "b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", //nolint:lll + ), + TargetNumber: 1, + Precommits: precommits, + }, + VoteAncestries: expAncestries, + } + + encodedJustification, err := scale.Marshal(expected) + require.NoError(t, err) + + justification, err := decodeJustification[hash.H256, uint64, runtime.BlakeTwo256](encodedJustification) + require.NoError(t, err) + require.Equal(t, expected, justification.Justification) +} + +func TestDecodeGrandpaJustificationVerifyFinalizes(t *testing.T) { + var a hash.H256 = "a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" //nolint:lll + + // Invalid Encoding + invalidEncoding := []byte{21} + _, err := DecodeGrandpaJustificationVerifyFinalizes[hash.H256, uint64, runtime.BlakeTwo256]( + invalidEncoding, + HashNumber[hash.H256, uint64]{}, + 2, + grandpa.VoterSet[string]{}) + require.Error(t, err) + + // Invalid target + justification := primitives.GrandpaJustification[hash.H256, uint64]{ + Commit: primitives.Commit[hash.H256, uint64]{ + TargetHash: a, + TargetNumber: 1, + }, + } + + encWrongTarget, err := scale.Marshal(justification) + require.NoError(t, err) + _, err = DecodeGrandpaJustificationVerifyFinalizes[hash.H256, uint64, runtime.BlakeTwo256]( + encWrongTarget, + HashNumber[hash.H256, uint64]{}, + 2, + grandpa.VoterSet[string]{}) + require.Error(t, err) + require.ErrorContains(t, err, "invalid commit target in grandpa justification") + + headerB := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 2, + hash.H256(""), + hash.H256(""), + a, + runtime.Digest{}) + + hederList := []runtime.Header[uint64, hash.H256]{headerB} + + var precommits []grandpa.SignedPrecommit[hash.H256, uint64, primitives.AuthoritySignature, primitives.AuthorityID] + precommits = append(precommits, makePrecommit(t, string(a), 1, 1, 1, ed25519.Alice)) + precommits = append(precommits, makePrecommit(t, string(a), 1, 1, 1, ed25519.Bob)) + precommits = append(precommits, makePrecommit(t, string(headerB.Hash()), 2, 1, 1, ed25519.Charlie)) + + expectedJustification := primitives.GrandpaJustification[hash.H256, uint64]{ + Round: 1, + Commit: primitives.Commit[hash.H256, uint64]{ + TargetHash: a, + TargetNumber: 1, + Precommits: precommits, + }, + VoteAncestries: hederList, + } + + encodedJustification, err := scale.Marshal(expectedJustification) + require.NoError(t, err) + + target := HashNumber[hash.H256, uint64]{ + Hash: a, + Number: 1, + } + + idWeights := make([]grandpa.IDWeight[string], 0) + for i := 1; i <= 4; i++ { + var id ced25519.Public + switch i { + case 1: + id = ed25519.Alice.Pair().Public().(ced25519.Public) + case 2: + id = ed25519.Bob.Pair().Public().(ced25519.Public) + case 3: + id = ed25519.Charlie.Pair().Public().(ced25519.Public) + case 4: + id = ed25519.Ferdie.Pair().Public().(ced25519.Public) + } + idWeights = append(idWeights, grandpa.IDWeight[string]{ + ID: string(id[:]), Weight: 1, + }) + } + voters := grandpa.NewVoterSet(idWeights) + + newJustification, err := DecodeGrandpaJustificationVerifyFinalizes[hash.H256, uint64, runtime.BlakeTwo256]( + encodedJustification, + target, + 1, + *voters) + require.NoError(t, err) + require.Equal(t, expectedJustification, newJustification.Justification) +} + +func TestJustification_verify(t *testing.T) { + // Nil voter case + auths := make(primitives.AuthorityList, 0) + justification := GrandpaJustification[hash.H256, uint64]{} + err := justification.Verify(2, auths) + require.ErrorIs(t, err, errInvalidAuthoritiesSet) + + // happy path + for i := 1; i <= 4; i++ { + var id ced25519.Public + switch i { + case 1: + id = ed25519.Alice.Pair().Public().(ced25519.Public) + case 2: + id = ed25519.Bob.Pair().Public().(ced25519.Public) + case 3: + id = ed25519.Charlie.Pair().Public().(ced25519.Public) + case 4: + id = ed25519.Ferdie.Pair().Public().(ced25519.Public) + } + auths = append(auths, primitives.AuthorityIDWeight{ + AuthorityID: id, + AuthorityWeight: 1, + }) + } + + var a hash.H256 = "a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" //nolint:lll + headerB := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 2, + hash.H256(""), + hash.H256(""), + a, + runtime.Digest{}) + + headerList := []runtime.Header[uint64, hash.H256]{headerB} + + var precommits []grandpa.SignedPrecommit[hash.H256, uint64, primitives.AuthoritySignature, primitives.AuthorityID] + precommits = append(precommits, makePrecommit(t, string(a), 1, 1, 2, ed25519.Alice)) + precommits = append(precommits, makePrecommit(t, string(a), 1, 1, 2, ed25519.Bob)) + precommits = append(precommits, makePrecommit(t, string(headerB.Hash()), 2, 1, 2, ed25519.Charlie)) + + validJustification := GrandpaJustification[hash.H256, uint64]{ + Justification: primitives.GrandpaJustification[hash.H256, uint64]{ + Round: 1, + Commit: primitives.Commit[hash.H256, uint64]{ + TargetHash: a, + TargetNumber: 1, + Precommits: precommits, + }, + VoteAncestries: headerList, + }, + } + + err = validJustification.Verify(2, auths) + require.NoError(t, err) +} + +func TestJustification_verifyWithVoterSet(t *testing.T) { + // 1) invalid commit + idWeights := make([]grandpa.IDWeight[string], 0) + for i := 1; i <= 4; i++ { + var id ced25519.Public + switch i { + case 1: + id = ed25519.Alice.Pair().Public().(ced25519.Public) + case 2: + id = ed25519.Bob.Pair().Public().(ced25519.Public) + case 3: + id = ed25519.Charlie.Pair().Public().(ced25519.Public) + case 4: + id = ed25519.Ferdie.Pair().Public().(ced25519.Public) + } + idWeights = append(idWeights, grandpa.IDWeight[string]{ + ID: string(id[:]), Weight: 1, + }) + } + voters := grandpa.NewVoterSet(idWeights) + + invalidJustification := GrandpaJustification[hash.H256, uint64]{ + primitives.GrandpaJustification[hash.H256, uint64]{ + Commit: primitives.Commit[hash.H256, uint64]{ + TargetHash: "B", + TargetNumber: 2, + }, + }, + } + + err := invalidJustification.verifyWithVoterSet(2, *voters) + require.ErrorIs(t, err, errBadJustification) + require.Equal(t, err.Error(), "bad justification for header: invalid commit in grandpa justification") + + // 2) visitedHashes != ancestryHashes + headerA := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 1, + hash.H256(""), + hash.H256(""), + hash.H256(""), + runtime.Digest{}) + + headerB := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 2, + hash.H256(""), + hash.H256(""), + headerA.Hash(), + runtime.Digest{}) + + headerList := []runtime.Header[uint64, hash.H256]{ + headerA, + headerB, + } + + var precommits []grandpa.SignedPrecommit[hash.H256, uint64, primitives.AuthoritySignature, primitives.AuthorityID] + precommits = append(precommits, makePrecommit(t, string(headerA.Hash()), 1, 1, 2, ed25519.Alice)) + precommits = append(precommits, makePrecommit(t, string(headerA.Hash()), 1, 1, 2, ed25519.Bob)) + precommits = append(precommits, makePrecommit(t, string(headerB.Hash()), 2, 1, 2, ed25519.Charlie)) + + validJustification := GrandpaJustification[hash.H256, uint64]{ + primitives.GrandpaJustification[hash.H256, uint64]{ + Commit: primitives.Commit[hash.H256, uint64]{ + TargetHash: headerA.Hash(), + TargetNumber: 1, + Precommits: precommits, + }, + VoteAncestries: headerList, + Round: 1, + }, + } + + err = validJustification.verifyWithVoterSet(2, *voters) + require.ErrorIs(t, err, errBadJustification) + require.Equal(t, err.Error(), "bad justification for header: "+ + "invalid precommit ancestries in grandpa justification with unused headers") + + // Valid case + headerList = []runtime.Header[uint64, hash.H256]{ + headerB, + } + + validJustification = GrandpaJustification[hash.H256, uint64]{ + primitives.GrandpaJustification[hash.H256, uint64]{ + Commit: primitives.Commit[hash.H256, uint64]{ + TargetHash: headerA.Hash(), + TargetNumber: 1, + Precommits: precommits, + }, + VoteAncestries: headerList, + Round: 1, + }, + } + + err = validJustification.verifyWithVoterSet(2, *voters) + require.NoError(t, err) +} + +func Test_newAncestryChain(t *testing.T) { + dummyHeader := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 1, + hash.H256(""), + hash.H256(""), + hash.H256(""), + runtime.Digest{}) + + expAncestryMap := make(map[hash.H256]runtime.Header[uint64, hash.H256]) + expAncestryMap[dummyHeader.Hash()] = dummyHeader + type testCase struct { + name string + headers []runtime.Header[uint64, hash.H256] + want ancestryChain[hash.H256, uint64] + } + tests := []testCase{ + { + name: "noInputHeaders", + headers: []runtime.Header[uint64, hash.H256]{}, + want: ancestryChain[hash.H256, uint64]{ + ancestry: make(map[hash.H256]runtime.Header[uint64, hash.H256]), + }, + }, + { + name: "validInput", + headers: []runtime.Header[uint64, hash.H256]{ + dummyHeader, + }, + want: ancestryChain[hash.H256, uint64]{ + ancestry: expAncestryMap, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newAncestryChain[hash.H256, uint64](tt.headers); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newAncestryChain() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAncestryChain_Ancestry(t *testing.T) { + headerA := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 1, + hash.H256(""), + hash.H256(""), + hash.H256(""), + runtime.Digest{}) + + headerB := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 2, + hash.H256(""), + hash.H256(""), + headerA.Hash(), + runtime.Digest{}) + + headerC := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 3, + hash.H256(""), + hash.H256(""), + headerB.Hash(), + runtime.Digest{}) + + invalidParentHeader := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 2, + hash.H256(""), + hash.H256(""), + hash.H256("invalid"), + runtime.Digest{}) + + headerList := []runtime.Header[uint64, hash.H256]{ + headerA, + headerB, + headerC, + } + invalidHeaderList := []runtime.Header[uint64, hash.H256]{ + invalidParentHeader, + } + validAncestryMap := newAncestryChain[hash.H256, uint64](headerList) + invalidAncestryMap := newAncestryChain[hash.H256, uint64](invalidHeaderList) + + type testCase struct { + name string + chain ancestryChain[hash.H256, uint64] + base hash.H256 + block hash.H256 + want []hash.H256 + expErr error + } + tests := []testCase{ + { + name: "baseEqualsBlock", + chain: validAncestryMap, + base: headerA.Hash(), + block: headerA.Hash(), + want: []hash.H256{}, + }, + { + name: "baseEqualsBlock", + chain: validAncestryMap, + base: headerA.Hash(), + block: "notDescendant", + expErr: errBlockNotDescendentOfBase, + }, + { + name: "invalidParentHashField", + chain: invalidAncestryMap, + base: headerA.Hash(), + block: "notDescendant", + expErr: errBlockNotDescendentOfBase, + }, + { + name: "validRoute", + chain: validAncestryMap, + base: headerA.Hash(), + block: headerC.Hash(), + want: []hash.H256{headerB.Hash()}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.chain.Ancestry(tt.base, tt.block) + assert.ErrorIs(t, err, tt.expErr) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestAncestryChain_IsEqualOrDescendantOf(t *testing.T) { + headerA := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 1, + hash.H256(""), + hash.H256(""), + hash.H256(""), + runtime.Digest{}) + + headerB := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 2, + hash.H256(""), + hash.H256(""), + headerA.Hash(), + runtime.Digest{}) + + headerC := generic.NewHeader[uint64, hash.H256, runtime.BlakeTwo256]( + 3, + hash.H256(""), + hash.H256(""), + headerB.Hash(), + runtime.Digest{}) + + headerList := []runtime.Header[uint64, hash.H256]{ + headerA, + headerB, + headerC, + } + + validAncestryMap := newAncestryChain[hash.H256, uint64](headerList) + + type testCase struct { + name string + chain ancestryChain[hash.H256, uint64] + base hash.H256 + block hash.H256 + want bool + } + tests := []testCase{ + { + name: "baseEqualsBlock", + chain: validAncestryMap, + base: headerA.Hash(), + block: headerA.Hash(), + want: true, + }, + { + name: "baseEqualsBlock", + chain: validAncestryMap, + base: headerA.Hash(), + block: "someInvalidBLock", + want: false, + }, + { + name: "validRoute", + chain: validAncestryMap, + base: headerA.Hash(), + block: headerC.Hash(), + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.chain.IsEqualOrDescendantOf(tt.base, tt.block) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/primitives/consensus/grandpa/app/app.go b/internal/primitives/consensus/grandpa/app/app.go new file mode 100644 index 0000000000..9ffa7079bc --- /dev/null +++ b/internal/primitives/consensus/grandpa/app/app.go @@ -0,0 +1,29 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package app + +import ( + "fmt" + + "github.com/ChainSafe/gossamer/internal/primitives/core/crypto" + "github.com/ChainSafe/gossamer/internal/primitives/core/ed25519" +) + +// Public key used in grandpa +type Public = ed25519.Public + +var _ crypto.Public[Signature] = Public{} + +// NewPublic is constructor for Public +func NewPublic(data []byte) (Public, error) { + if len(data) != 32 { + return Public{}, fmt.Errorf("invalid public key from data: %v", data) + } + pub := Public{} + copy(pub[:], data) + return pub, nil +} + +// Signature is signature type used in grandpa +type Signature = ed25519.Signature diff --git a/internal/primitives/consensus/grandpa/grandpa.go b/internal/primitives/consensus/grandpa/grandpa.go new file mode 100644 index 0000000000..43bc7cd95f --- /dev/null +++ b/internal/primitives/consensus/grandpa/grandpa.go @@ -0,0 +1,98 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "github.com/ChainSafe/gossamer/internal/log" + "github.com/ChainSafe/gossamer/internal/primitives/consensus/grandpa/app" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" + grandpa "github.com/ChainSafe/gossamer/pkg/finality-grandpa" + "github.com/ChainSafe/gossamer/pkg/scale" + "golang.org/x/exp/constraints" +) + +var logger = log.NewFromGlobal(log.AddContext("consensus", "grandpa")) + +// AuthorityID is the identity of a Grandpa authority. +type AuthorityID = app.Public + +// NewAuthorityID is constructor for AuthorityID +func NewAuthorityID(data []byte) (AuthorityID, error) { + return app.NewPublic(data) +} + +// AuthoritySignature is the signature for a Grandpa authority. +type AuthoritySignature = app.Signature + +// GrandpaEngineID is the ConsensusEngineID of GRANDPA. +var GrandpaEngineID = runtime.ConsensusEngineID{'F', 'R', 'N', 'K'} + +// AuthorityWeight is the weight of an authority. +type AuthorityWeight uint64 + +// AuthorityIndex is the index of an authority. +type AuthorityIndex uint64 + +// SetID is the monotonic identifier of a GRANDPA set of authorities. +type SetID uint64 + +// RoundNumber is the round indicator. +type RoundNumber uint64 + +// AuthorityIDWeight is struct containing AuthorityID and AuthorityWeight +type AuthorityIDWeight struct { + AuthorityID + AuthorityWeight +} + +// AuthorityList is a list of Grandpa authorities with associated weights. +type AuthorityList []AuthorityIDWeight + +// SignedMessage is a signed message. +type SignedMessage[H, N any] grandpa.SignedMessage[H, N, AuthoritySignature, AuthorityID] + +// Commit is a commit message for this chain's block type. +type Commit[H, N any] grandpa.Commit[H, N, AuthoritySignature, AuthorityID] + +// GrandpaJustification is A GRANDPA justification for block finality, it includes +// a commit message and an ancestry proof including all headers routing all +// precommit target blocks to the commit target block. Due to the current voting +// strategy the precommit targets should be the same as the commit target, since +// honest voters don't vote past authority set change blocks. +// +// This is meant to be stored in the db and passed around the network to other +// nodes, and are used by syncing nodes to prove authority set handoffs. +type GrandpaJustification[Ordered runtime.Hash, N runtime.Number] struct { + Round uint64 + Commit Commit[Ordered, N] + VoteAncestries []runtime.Header[N, Ordered] +} + +// CheckMessageSignature will check a message signature by encoding the message as +// a localised payload and verifying the provided signature using the expected +// authority id. +func CheckMessageSignature[H comparable, N constraints.Unsigned]( + message grandpa.Message[H, N], + id AuthorityID, + signature AuthoritySignature, + round RoundNumber, + setID SetID) bool { + + buf := NewLocalizedPayload(round, setID, message) + valid := id.Verify(signature, buf) + + if !valid { + logger.Debugf("Bad signature on message from %v", id) + } + return valid +} + +// LocalizedPayload will encode round message localised to a given round and set id. +func NewLocalizedPayload(round RoundNumber, setID SetID, message any) []byte { + return scale.MustMarshal(struct { + Message any + RoundNumber + SetID + }{message, round, setID}) +} diff --git a/internal/primitives/consensus/grandpa/grandpa_test.go b/internal/primitives/consensus/grandpa/grandpa_test.go new file mode 100644 index 0000000000..8717972bfe --- /dev/null +++ b/internal/primitives/consensus/grandpa/grandpa_test.go @@ -0,0 +1,65 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "testing" + + ced25519 "github.com/ChainSafe/gossamer/internal/primitives/core/ed25519" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/keyring/ed25519" + grandpa "github.com/ChainSafe/gossamer/pkg/finality-grandpa" + "github.com/stretchr/testify/require" +) + +func makePrecommit(t *testing.T, + precommit grandpa.Precommit[hash.H256, uint64], + round uint64, + setID uint64, + voter ed25519.Keyring, +) grandpa.SignedPrecommit[hash.H256, uint64, AuthoritySignature, AuthorityID] { + t.Helper() + msg := grandpa.NewMessage(precommit) + encoded := NewLocalizedPayload(RoundNumber(round), SetID(setID), msg) + signature := voter.Sign(encoded) + + return grandpa.SignedPrecommit[hash.H256, uint64, AuthoritySignature, AuthorityID]{ + Precommit: precommit, + Signature: signature, + ID: voter.Pair().Public().(ced25519.Public), + } +} + +func TestCheckMessageSignature(t *testing.T) { + precommit := grandpa.Precommit[hash.H256, uint64]{ + TargetHash: hash.H256("a"), + TargetNumber: 1, + } + signedPrecommit := makePrecommit(t, precommit, 1, 1, ed25519.Alice) + valid := CheckMessageSignature[hash.H256, uint64]( + grandpa.NewMessage(precommit), signedPrecommit.ID, signedPrecommit.Signature, 1, 1) + require.True(t, valid) + valid = CheckMessageSignature[hash.H256, uint64]( + grandpa.NewMessage(precommit), signedPrecommit.ID, signedPrecommit.Signature, 2, 1) + require.False(t, valid) + + signedPrecommit = makePrecommit(t, precommit, 2, 1, ed25519.Alice) + valid = CheckMessageSignature[hash.H256, uint64]( + grandpa.NewMessage(precommit), signedPrecommit.ID, signedPrecommit.Signature, 2, 1) + require.True(t, valid) + valid = CheckMessageSignature[hash.H256, uint64]( + grandpa.NewMessage(precommit), signedPrecommit.ID, signedPrecommit.Signature, 1, 1) + require.False(t, valid) + + signedPrecommit = makePrecommit(t, precommit, 3, 3, ed25519.Bob) + valid = CheckMessageSignature[hash.H256, uint64]( + grandpa.NewMessage(precommit), signedPrecommit.ID, signedPrecommit.Signature, 3, 3) + require.True(t, valid) + valid = CheckMessageSignature[hash.H256, uint64]( + grandpa.NewMessage(precommit), ed25519.Bob.Pair().Public().(ced25519.Public), signedPrecommit.Signature, 3, 3) + require.True(t, valid) + valid = CheckMessageSignature[hash.H256, uint64]( + grandpa.NewMessage(precommit), ed25519.Alice.Pair().Public().(ced25519.Public), signedPrecommit.Signature, 3, 3) + require.False(t, valid) +} diff --git a/internal/primitives/core/crypto/crypto.go b/internal/primitives/core/crypto/crypto.go new file mode 100644 index 0000000000..bcdd7427e8 --- /dev/null +++ b/internal/primitives/core/crypto/crypto.go @@ -0,0 +1,214 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package crypto + +import ( + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/ChainSafe/gossamer/internal/primitives/core/hashing" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +// DevPhrase is the root phrase for our publicly known keys. +const DevPhrase = "bottom drive obey lake curtain smoke basket hold race lonely fit walk" + +// DeriveJunction is a since derivation junction description. It is the single parameter +// used when creating a new secret key from an existing secret key and, in the case of +// `SoftRaw` and `SoftIndex` a new public key from an existing public key. +type DeriveJunction struct { + inner any +} +type DeriveJunctions interface { + DeriveJunctionSoft | DeriveJunctionHard +} + +func (dj DeriveJunction) Value() any { + if dj.inner == nil { + panic("nil inner for DeriveJunction") + } + return dj.inner +} + +// DeriveJunctionSoft is a soft (vanilla) derivation. Public keys have a correspondent derivation. +type DeriveJunctionSoft [32]byte + +// DeriveJunctionHard is a hard ("hardened") derivation. Public keys do not have a correspondent derivation. +type DeriveJunctionHard [32]byte + +// Harden will consume self to return a hard derive junction with the same chain code. +func (dj *DeriveJunction) Harden() DeriveJunction { + switch inner := dj.inner.(type) { + case DeriveJunctionSoft: + dj.inner = DeriveJunctionHard(inner) + } + return *dj +} + +// NewDeriveJunctionSoft creates a new soft (vanilla) DeriveJunction from a given, encodable, value. +func NewDeriveJunctionSoft(index any) (DeriveJunctionSoft, error) { + var cc = [32]byte{} + data, err := scale.Marshal(index) + if err != nil { + return DeriveJunctionSoft{}, err + } + + if len(data) > 32 { + cc = hashing.BlakeTwo256(data) + } else { + copy(cc[:], data) + } + return DeriveJunctionSoft(cc), nil +} + +// NewDeriveJunctionFromString is constructor of DeriveJunction from string representation. +func NewDeriveJunctionFromString(j string) DeriveJunction { + hard := false + trimmed := strings.TrimPrefix(j, "/") + if trimmed != j { + hard = true + } + code := trimmed + + var res DeriveJunction + n, err := strconv.Atoi(code) + if err == nil { + soft, err := NewDeriveJunctionSoft(n) + if err != nil { + panic(err) + } + res = DeriveJunction{ + inner: soft, + } + } else { + soft, err := NewDeriveJunctionSoft(code) + if err != nil { + panic(err) + } + res = DeriveJunction{ + inner: soft, + } + } + + if hard { + return res.Harden() + } else { + return res + } +} + +// NewDeriveJunction is constructor for DeriveJunction +func NewDeriveJunction[V DeriveJunctions](value V) DeriveJunction { + return DeriveJunction{ + inner: value, + } +} + +var secretPhraseRegex = regexp.MustCompile(`^(?P[\d\w ]+)?(?P(//?[^/]+)*)(///(?P.*))?$`) + +var junctionRegex = regexp.MustCompile(`/(/?[^/]+)`) + +// Trait used for types that are really just a fixed-length array. +type Bytes interface { + // Return a `Vec` filled with raw data. + Bytes() []byte +} + +// Trait suitable for typical cryptographic key public type. +type Public[Signature any] interface { + Bytes + + // Verify a signature on a message. Returns true if the signature is good. + Verify(sig Signature, message []byte) bool +} + +// SecretURI A secret uri (`SURI`) that can be used to generate a key pair. +// +// The `SURI` can be parsed from a string. The string is interpreted in the following way: +// +// - If `string` is a possibly `0x` prefixed 64-digit hex string, then it will be interpreted +// directly as a secret key (aka "seed" in `subkey`). +// - If `string` is a valid BIP-39 key phrase of 12, 15, 18, 21 or 24 words, then the key will +// be derived from it. In this case: +// - the phrase may be followed by one or more items delimited by `/` characters. +// - the path may be followed by `///`, in which case everything after the `///` is treated +// +// as a password. +// - If `string` begins with a `/` character it is prefixed with the public `DevPhrase` +// and interpreted as above. +// +// In this case they are interpreted as HDKD junctions; purely numeric items are interpreted as +// integers, non-numeric items as strings. Junctions prefixed with `/` are interpreted as soft +// junctions, and with `//` as hard junctions. +// +// There is no correspondence mapping between `SURI` strings and the keys they represent. +// Two different non-identical strings can actually lead to the same secret being derived. +// Notably, integer junction indices may be legally prefixed with arbitrary number of zeros. +// Similarly an empty password (ending the `SURI` with `///`) is perfectly valid and will +// generally be equivalent to no password at all. +type SecretURI struct { + // The phrase to derive the private key. + // This can either be a 64-bit hex string or a BIP-39 key phrase. + Phrase string + // Optional password as given as part of the uri. + Password *string + // The junctions as part of the uri. + Junctions []DeriveJunction +} + +// NewSecretURI is contructor for SecretURI +func NewSecretURI(s string) (SecretURI, error) { + matches := secretPhraseRegex.FindStringSubmatch(s) + if matches == nil { + return SecretURI{}, fmt.Errorf("invalid format") + } + + var ( + junctions []DeriveJunction + phrase = DevPhrase + password *string + ) + for i, name := range secretPhraseRegex.SubexpNames() { + if i == 0 { + continue + } + switch name { + case "path": + junctionMatches := junctionRegex.FindAllString(matches[i], -1) + for _, jm := range junctionMatches { + junctions = append(junctions, NewDeriveJunctionFromString(jm)) + } + case "phrase": + if matches[i] != "" { + phrase = matches[i] + } + case "password": + if matches[i] != "" { + pw := matches[i] + password = &pw + } + } + } + return SecretURI{ + Phrase: phrase, + Password: password, + Junctions: junctions, + }, nil +} + +// Pair is an interface suitable for typical cryptographic PKI key pair type. +// +// For now it just specifies how to create a key from a phrase and derivation path. +type Pair[Seed, Signature any] interface { + // Derive a child key from a series of given junctions. + Derive(path []DeriveJunction, seed *Seed) (Pair[Seed, Signature], Seed, error) + + // Sign a message. + Sign(message []byte) Signature + + // Get the public key. + Public() Public[Signature] +} diff --git a/internal/primitives/core/ed25519/ed25519.go b/internal/primitives/core/ed25519/ed25519.go new file mode 100644 index 0000000000..d7db6c21c7 --- /dev/null +++ b/internal/primitives/core/ed25519/ed25519.go @@ -0,0 +1,275 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package ed25519 + +import ( + gocrypto "crypto" + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "fmt" + "strings" + + "github.com/ChainSafe/go-schnorrkel" + "github.com/ChainSafe/gossamer/internal/primitives/core/crypto" + "github.com/ChainSafe/gossamer/internal/primitives/core/hashing" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/tyler-smith/go-bip39" +) + +// A secret seed. +type seed [32]byte + +// A Public key. +type Public [32]byte + +// Bytes returns a byte slice +func (p Public) Bytes() []byte { + return p[:] +} + +// Verify a signature on a message. Returns true if the signature is good. +func (p Public) Verify(sig Signature, message []byte) bool { + return ed25519.Verify(p[:], message, sig[:]) +} + +// NewPublic creates a new instance from the given 32-byte `data`. +// +// NOTE: No checking goes on to ensure this is a real public key. Only use it if +// you are certain that the array actually is a pubkey. +func NewPublic(data [32]byte) Public { + return Public(data) +} + +var _ crypto.Public[Signature] = Public{} + +// Derive a single hard junction. +func deriveHardJunction(secretSeed seed, cc [32]byte) seed { + tuple := struct { + ID string + SecretSeed seed + CC [32]byte + }{"Ed25519HDKD", secretSeed, cc} + encoded := scale.MustMarshal(tuple) + return hashing.BlakeTwo256(encoded) +} + +// Pair is a key pair. +type Pair struct { + public gocrypto.PublicKey + secret ed25519.PrivateKey +} + +// Derive a child key from a series of given junctions. +func (p Pair) Derive(path []crypto.DeriveJunction, seed *[32]byte) (crypto.Pair[[32]byte, Signature], [32]byte, error) { + var acc [32]byte + copy(acc[:], p.secret.Seed()) + for _, j := range path { + switch cc := j.Value().(type) { + case crypto.DeriveJunctionSoft: + return Pair{}, [32]byte{}, fmt.Errorf("soft key in path") + case crypto.DeriveJunctionHard: + acc = deriveHardJunction(acc, cc) + } + } + pair := NewPairFromSeed(acc) + return pair, acc, nil +} + +// Seed is the seed for this key. +func (p Pair) Seed() [32]byte { + var seed [32]byte + copy(seed[:], p.secret.Seed()) + return seed +} + +// Public will return the public key. +func (p Pair) Public() crypto.Public[Signature] { + pubKey, ok := p.public.(ed25519.PublicKey) + if !ok { + panic("huh?") + } + if len(pubKey) != 32 { + panic("huh?") + } + var pub Public + copy(pub[:], pubKey) + return pub +} + +// Sign a message. +func (p Pair) Sign(message []byte) Signature { + signed := ed25519.Sign(p.secret, message) + if len(signed) != 64 { + panic("huh?") + } + var sig Signature + copy(sig[:], signed) + return sig +} + +// NewGeneratedPair will generate new secure (random) key pair. +// +// This is only for ephemeral keys really, since you won't have access to the secret key +// for storage. If you want a persistent key pair, use `generate_with_phrase` instead. +func NewGeneratedPair() (Pair, [32]byte) { + seedSlice := make([]byte, 32) + _, err := rand.Read(seedSlice) + if err != nil { + panic(err) + } + + var seed [32]byte + copy(seed[:], seedSlice) + return NewPairFromSeed(seed), seed +} + +// NewGeneratedPairWithPhrase will generate new secure (random) key pair and provide the recovery phrase. +// +// You can recover the same key later with `from_phrase`. +// +// This is generally slower than `generate()`, so prefer that unless you need to persist +// the key from the current session. +func NewGeneratedPairWithPhrase(password *string) (Pair, string, [32]byte) { + entropy, err := bip39.NewEntropy(128) + if err != nil { + panic(err) + } + phrase, err := bip39.NewMnemonic(entropy) + if err != nil { + panic(err) + } + pair, seed, err := NewPairFromPhrase(phrase, password) + if err != nil { + panic(err) + } + return pair, phrase, seed +} + +// NewPairFromPhrase returns the KeyPair from the English BIP39 seed `phrase`, or `None` if it's invalid. +func NewPairFromPhrase(phrase string, password *string) (pair Pair, seed [32]byte, err error) { + pass := "" + if password != nil { + pass = *password + } + bigSeed, err := schnorrkel.SeedFromMnemonic(phrase, pass) + if err != nil { + return Pair{}, [32]byte{}, err + } + + if !(32 <= len(bigSeed)) { + panic("huh?") + } + + seedSlice := bigSeed[:][0:32] + copy(seed[:], seedSlice) + return NewPairFromSeedSlice(seedSlice), seed, nil +} + +// NewPairFromSeed will generate new key pair from the provided `seed`. +// +// @WARNING: THIS WILL ONLY BE SECURE IF THE `seed` IS SECURE. If it can be guessed +// by an attacker then they can also derive your key. +func NewPairFromSeed(seed [32]byte) Pair { + return NewPairFromSeedSlice(seed[:]) +} + +// NewPairFromSeedSlice will make a new key pair from secret seed material. The slice must be the correct size or +// it will return `None`. +// +// @WARNING: THIS WILL ONLY BE SECURE IF THE `seed` IS SECURE. If it can be guessed +// by an attacker then they can also derive your key. +func NewPairFromSeedSlice(seedSlice []byte) Pair { + secret := ed25519.NewKeyFromSeed(seedSlice) + public := secret.Public() + return Pair{ + public: public, + secret: secret, + } +} + +// NewPairFromStringWithSeed interprets the string `s` in order to generate a key Pair. Returns +// both the pair and an optional seed, in the case that the pair can be expressed as a direct +// derivation from a seed (some cases, such as Sr25519 derivations with path components, cannot). +// +// This takes a helper function to do the key generation from a phrase, password and +// junction iterator. +// +// - If `s` is a possibly `0x` prefixed 64-digit hex string, then it will be interpreted +// directly as a secret key (aka "seed" in `subkey`). +// - If `s` is a valid BIP-39 key phrase of 12, 15, 18, 21 or 24 words, then the key will +// be derived from it. In this case: +// - the phrase may be followed by one or more items delimited by `/` characters. +// - the path may be followed by `///`, in which case everything after the `///` is treated +// +// as a password. +// - If `s` begins with a `/` character it is prefixed with the Substrate public `DevPhrase` +// and +// +// interpreted as above. +// +// In this case they are interpreted as HDKD junctions; purely numeric items are interpreted as +// integers, non-numeric items as strings. Junctions prefixed with `/` are interpreted as soft +// junctions, and with `//` as hard junctions. +// +// There is no correspondence mapping between SURI strings and the keys they represent. +// Two different non-identical strings can actually lead to the same secret being derived. +// Notably, integer junction indices may be legally prefixed with arbitrary number of zeros. +// Similarly an empty password (ending the SURI with `///`) is perfectly valid and will +// generally be equivalent to no password at all. +// +// `nil` is returned if no matches are found. +func NewPairFromStringWithSeed(s string, passwordOverride *string) ( + pair crypto.Pair[[32]byte, Signature], seed [32]byte, err error, +) { + sURI, err := crypto.NewSecretURI(s) + if err != nil { + return Pair{}, [32]byte{}, err + } + var password *string + if passwordOverride != nil { + password = passwordOverride + } else { + password = sURI.Password + } + + var ( + root Pair + // seed []byte + ) + trimmedPhrase := strings.TrimPrefix(sURI.Phrase, "0x") + if trimmedPhrase != sURI.Phrase { + seedBytes, err := hex.DecodeString(trimmedPhrase) + if err != nil { + return Pair{}, [32]byte{}, err + } + root = NewPairFromSeedSlice(seedBytes) + copy(seed[:], seedBytes) + } else { + root, seed, err = NewPairFromPhrase(sURI.Phrase, password) + if err != nil { + return Pair{}, [32]byte{}, err + } + } + return root.Derive(sURI.Junctions, &seed) +} + +// NewPairFromString interprets the string `s` in order to generate a key pair. +func NewPairFromString(s string, passwordOverride *string) (crypto.Pair[[32]byte, Signature], error) { + pair, _, err := NewPairFromStringWithSeed(s, passwordOverride) + return pair, err +} + +var _ crypto.Pair[[32]byte, Signature] = Pair{} + +// Signature is a signature (a 512-bit value). +type Signature [64]byte + +// NewSignatureFromRaw constructors a new instance from the given 64-byte `data`. +// +// NOTE: No checking goes on to ensure this is a real signature. Only use it if +// you are certain that the array actually is a signature. +func NewSignatureFromRaw(data [64]byte) Signature { + return Signature(data) +} diff --git a/internal/primitives/core/ed25519/ed25519_test.go b/internal/primitives/core/ed25519/ed25519_test.go new file mode 100644 index 0000000000..124a00828a --- /dev/null +++ b/internal/primitives/core/ed25519/ed25519_test.go @@ -0,0 +1,144 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package ed25519_test + +import ( + "encoding/hex" + "fmt" + "testing" + + "github.com/ChainSafe/gossamer/internal/primitives/core/crypto" + "github.com/ChainSafe/gossamer/internal/primitives/core/ed25519" + "github.com/stretchr/testify/require" +) + +func mustHexDecodeString32(t *testing.T, s string) [32]byte { + t.Helper() + seedSlice, err := hex.DecodeString(s) + require.NoError(t, err) + + var seed [32]byte + copy(seed[:], seedSlice) + return seed +} +func mustHexDecodeString64(t *testing.T, s string) [64]byte { + t.Helper() + seedSlice, err := hex.DecodeString(s) + require.NoError(t, err) + + var seed [64]byte + copy(seed[:], seedSlice) + return seed +} + +var password string = "password" + +func TestDefaultPhraseShouldBeUsed(t *testing.T) { + pair, err := ed25519.NewPairFromString("//Alice///password", nil) + require.NoError(t, err) + + pair1, err := ed25519.NewPairFromString( + fmt.Sprintf("%s//Alice", crypto.DevPhrase), &password, + ) + require.NoError(t, err) + + require.Equal(t, pair, pair1) +} + +func TestNewPairFromString_DifferentAliases(t *testing.T) { + pair, err := ed25519.NewPairFromString("//Alice///password", nil) + require.NoError(t, err) + + pair1, err := ed25519.NewPairFromString("//Bob///password", nil) + require.NoError(t, err) + + require.NotEqual(t, pair, pair1) +} + +func TestSeedAndDeriveShouldWork(t *testing.T) { + seed := mustHexDecodeString32(t, "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60") + pair := ed25519.NewPairFromSeed(seed) + require.Equal(t, pair.Seed(), seed) + + path := []crypto.DeriveJunction{crypto.NewDeriveJunction(crypto.DeriveJunctionHard{})} + derived, _, err := pair.Derive(path, nil) + require.NoError(t, err) + + expected := mustHexDecodeString32(t, "ede3354e133f9c8e337ddd6ee5415ed4b4ffe5fc7d21e933f4930a3730e5b21c") + require.Equal(t, expected, derived.(ed25519.Pair).Seed()) +} + +func TestVectorShouldWork(t *testing.T) { + seed := mustHexDecodeString32(t, "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60") + expected := mustHexDecodeString32(t, "d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a") + + pair := ed25519.NewPairFromSeed(seed) + public := pair.Public() + require.Equal(t, public, ed25519.NewPublic(expected)) + + signature := mustHexDecodeString64(t, + "e5564300c360ac729086e2cc806e828a84877f1eb8e5d974d873e065224901555fb8821590a33bacc61e39701cf9b46bd25bf5f0595bbe24655141438e7a100b") //nolint: lll + message := []byte("") + require.Equal(t, ed25519.NewSignatureFromRaw(signature), pair.Sign(message)) + require.True(t, public.Verify(signature, message)) +} + +func TestVectorByStringShouldWork(t *testing.T) { + pair, err := ed25519.NewPairFromString("0x9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60", nil) + require.NoError(t, err) + public := pair.Public() + require.Equal(t, ed25519.NewPublic( + mustHexDecodeString32(t, "d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a"), + ), public) + + signature := mustHexDecodeString64(t, + "e5564300c360ac729086e2cc806e828a84877f1eb8e5d974d873e065224901555fb8821590a33bacc61e39701cf9b46bd25bf5f0595bbe24655141438e7a100b") //nolint: lll + message := []byte("") + require.Equal(t, ed25519.NewSignatureFromRaw(signature), pair.Sign(message)) + require.True(t, public.Verify(signature, message)) +} + +func TestGeneratedPairShouldWork(t *testing.T) { + pair, _ := ed25519.NewGeneratedPair() + public := pair.Public() + message := []byte("Something important") + signature := pair.Sign(message) + require.True(t, public.Verify(signature, message)) + require.False(t, public.Verify(signature, []byte("Something else"))) +} + +func TestSeededPairShouldWork(t *testing.T) { + pair := ed25519.NewPairFromSeedSlice([]byte("12345678901234567890123456789012")) + public := pair.Public() + require.Equal(t, public, ed25519.NewPublic( + mustHexDecodeString32(t, "2f8c6129d816cf51c374bc7f08c3e63ed156cf78aefb4a6550d97b87997977ee"), + )) + message := mustHexDecodeString32(t, "2f8c6129d816cf51c374bc7f08c3e63ed156cf78aefb4a6550d97b87997977ee") + signature := pair.Sign(message[:]) + require.True(t, public.Verify(signature, message[:])) + require.False(t, public.Verify(signature, []byte("Other Message"))) +} + +func TestGenerateWithPhraseRecoveryPossible(t *testing.T) { + pair1, phrase, _ := ed25519.NewGeneratedPairWithPhrase(nil) + pair2, _, err := ed25519.NewPairFromPhrase(phrase, nil) + require.NoError(t, err) + require.Equal(t, pair1.Public(), pair2.Public()) +} + +func TestGenerateWithPasswordPhraseRecoverPossible(t *testing.T) { + password := "password" + pair1, phrase, _ := ed25519.NewGeneratedPairWithPhrase(&password) + pair2, _, err := ed25519.NewPairFromPhrase(phrase, &password) + require.NoError(t, err) + require.Equal(t, pair1.Public(), pair2.Public()) +} + +func TestPasswordDoesSomething(t *testing.T) { + password := "password" + pair1, phrase, _ := ed25519.NewGeneratedPairWithPhrase(&password) + pair2, _, err := ed25519.NewPairFromPhrase(phrase, nil) + require.NoError(t, err) + require.NotEqual(t, pair1.Public(), pair2.Public()) +} diff --git a/internal/primitives/core/hash/hash.go b/internal/primitives/core/hash/hash.go new file mode 100644 index 0000000000..e9235736f0 --- /dev/null +++ b/internal/primitives/core/hash/hash.go @@ -0,0 +1,65 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package hash + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "io" + + "github.com/ChainSafe/gossamer/pkg/scale" +) + +// H256 is a fixed-size uninterpreted hash type with 32 bytes (256 bits) size. +type H256 string + +// Bytes returns a byte slice +func (h256 H256) Bytes() []byte { + return []byte(h256) +} + +// String returns string representation of H256 +func (h256 H256) String() string { + return fmt.Sprintf("%v", h256.Bytes()) +} + +// MarshalSCALE fulfils the SCALE interface for encoding +func (h256 H256) MarshalSCALE() ([]byte, error) { + var arr [32]byte + copy(arr[:], []byte(h256)) + return scale.Marshal(arr) +} + +// UnmarshalSCALE fulfils the SCALE interface for decoding +func (h256 *H256) UnmarshalSCALE(r io.Reader) error { + var arr [32]byte + decoder := scale.NewDecoder(r) + err := decoder.Decode(&arr) + if err != nil { + return err + } + if arr != [32]byte{} { + *h256 = H256(arr[:]) + } + return nil +} + +// NewH256FromLowUint64BigEndian is constructor for H256 from a uint64 +func NewH256FromLowUint64BigEndian(v uint64) H256 { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, v) + full := append(b, make([]byte, 24)...) + return H256(full) +} + +// NewRandomH256 is constructor for a random H256 +func NewRandomH256() H256 { + token := make([]byte, 32) + _, err := rand.Read(token) + if err != nil { + panic(err) + } + return H256(token) +} diff --git a/internal/primitives/core/hashing/hashing.go b/internal/primitives/core/hashing/hashing.go new file mode 100644 index 0000000000..6cb31194e9 --- /dev/null +++ b/internal/primitives/core/hashing/hashing.go @@ -0,0 +1,24 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package hashing + +import ( + "golang.org/x/crypto/blake2b" +) + +// BlakeTwo256 returns a Blake2 256-bit hash of the input data +func BlakeTwo256(data []byte) [32]byte { + h, err := blake2b.New256(nil) + if err != nil { + panic(err) + } + _, err = h.Write(data) + if err != nil { + panic(err) + } + encoded := h.Sum(nil) + var arr [32]byte + copy(arr[:], encoded) + return arr +} diff --git a/internal/primitives/keyring/ed25519/ed25519.go b/internal/primitives/keyring/ed25519/ed25519.go new file mode 100644 index 0000000000..556cfe66ec --- /dev/null +++ b/internal/primitives/keyring/ed25519/ed25519.go @@ -0,0 +1,58 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package ed25519 + +import ( + "fmt" + + "github.com/ChainSafe/gossamer/internal/primitives/core/ed25519" +) + +type Keyring uint + +const ( + Alice Keyring = iota + Bob + Charlie + Dave + Eve + Ferdie + One + Two +) + +func (k Keyring) Sign(msg []byte) ed25519.Signature { + return k.Pair().Sign(msg) +} + +func (k Keyring) Pair() ed25519.Pair { + pair, err := ed25519.NewPairFromString(fmt.Sprintf("//%s", k), nil) + if err != nil { + panic("static values are known good; qed") + } + return pair.(ed25519.Pair) +} + +func (k Keyring) String() string { + switch k { + case Alice: + return "Alice" + case Bob: + return "Bob" + case Charlie: + return "Charlie" + case Dave: + return "Dave" + case Eve: + return "Eve" + case Ferdie: + return "Ferdie" + case One: + return "One" + case Two: + return "Two" + default: + panic("unsupported Keyring") + } +} diff --git a/internal/primitives/runtime/digest.go b/internal/primitives/runtime/digest.go new file mode 100644 index 0000000000..7ed6358c60 --- /dev/null +++ b/internal/primitives/runtime/digest.go @@ -0,0 +1,80 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package runtime + +// Digest item that is able to encode/decode 'system' digest items and +// provide opaque access to other items. +type DigestItemTypes interface { + PreRuntime | Consensus | Seal | Other | RuntimeEnvironmentUpdated +} + +// Digest item that is able to encode/decode 'system' digest items and +// provide opaque access to other items. +// TODO: implement this as scale.VaryingDataType +type DigestItem any + +// NewDigestItem is constructor for DigestItem +func NewDigestItem[T DigestItemTypes](item T) DigestItem { + return NewDigestItem(item) +} + +// A pre-runtime digest. +// +// These are messages from the consensus engine to the runtime, although +// the consensus engine can (and should) read them itself to avoid +// code and state duplication. It is erroneous for a runtime to produce +// these, but this is not (yet) checked. +// +// NOTE: the runtime is not allowed to panic or fail in an `on_initialize` +// call if an expected `PreRuntime` digest is not present. It is the +// responsibility of a external block verifier to check this. Runtime API calls +// will initialize the block without pre-runtime digests, so initialization +// cannot fail when they are missing. +type PreRuntime struct { + ConsensusEngineID + Bytes []byte +} + +// A message from the runtime to the consensus engine. This should *never* +// be generated by the native code of any consensus engine, but this is not +// checked (yet). +type Consensus struct { + ConsensusEngineID + Bytes []byte +} + +// Put a Seal on it. This is only used by native code, and is never seen +// by runtimes. +type Seal struct { + ConsensusEngineID + Bytes []byte +} + +// Some other thing. Unsupported and experimental. +type Other []byte + +// An indication for the light clients that the runtime execution +// environment is updated. +type RuntimeEnvironmentUpdated struct{} + +// Digest is a header digest. +type Digest struct { + // A list of logs in the digest. + Logs []DigestItem +} + +// Push new digest item. +func (d *Digest) Push(item DigestItem) { + d.Logs = append(d.Logs, item) +} + +// Pop a digest item. +func (d *Digest) Pop() DigestItem { + if len(d.Logs) == 0 { + return nil + } + item := d.Logs[len(d.Logs)-1] + d.Logs = d.Logs[:len(d.Logs)-1] + return item +} diff --git a/internal/primitives/runtime/generic/block.go b/internal/primitives/runtime/generic/block.go new file mode 100644 index 0000000000..a08acf42b6 --- /dev/null +++ b/internal/primitives/runtime/generic/block.go @@ -0,0 +1,72 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package generic + +import ( + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" +) + +// Something to identify a block. +type BlockID any + +// BlockIDTypes is the interface constraint of `BlockID`. +type BlockIDTypes[H, N any] interface { + BlockIDHash[H] | BlockIDNumber[N] +} + +// NewBlockID is the constructor for `BlockID`. +func NewBlockID[H, N any, T BlockIDTypes[H, N]](blockID T) BlockID { + return BlockID(blockID) +} + +// BlockIDHash is id by block header hash. +type BlockIDHash[H any] struct { + Inner H +} + +// BlockIDNumber is id by block number. +type BlockIDNumber[N any] struct { + Inner N +} + +// Block is a block. +type Block[N runtime.Number, H runtime.Hash, Hasher runtime.Hasher[H]] struct { + // The block header. + header runtime.Header[N, H] + // The accompanying extrinsics. + extrinsics []runtime.Extrinsic +} + +// Header returns the header. +func (b Block[N, H, Hasher]) Header() runtime.Header[N, H] { + return b.header +} + +// Extrinsics returns the block extrinsics. +func (b Block[N, H, Hasher]) Extrinsics() []runtime.Extrinsic { + return b.extrinsics +} + +// Deconstruct returns both header and extrinsics. +func (b Block[N, H, Hasher]) Deconstruct() (header runtime.Header[N, H], extrinsics []runtime.Extrinsic) { + return b.Header(), b.Extrinsics() +} + +// Hash returns the block hash. +func (b Block[N, H, Hasher]) Hash() H { + hasher := *new(Hasher) + return hasher.HashEncoded(b.header) +} + +// NewBlock is the constructor for `Block`. +func NewBlock[N runtime.Number, H runtime.Hash, Hasher runtime.Hasher[H]]( + header runtime.Header[N, H], extrinsics []runtime.Extrinsic) Block[N, H, Hasher] { + return Block[N, H, Hasher]{ + header: header, + extrinsics: extrinsics, + } +} + +var _ runtime.Block[uint, hash.H256] = Block[uint, hash.H256, runtime.BlakeTwo256]{} diff --git a/internal/primitives/runtime/generic/header.go b/internal/primitives/runtime/generic/header.go new file mode 100644 index 0000000000..ef10c50ab4 --- /dev/null +++ b/internal/primitives/runtime/generic/header.go @@ -0,0 +1,133 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package generic + +import ( + "io" + + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" + "github.com/ChainSafe/gossamer/pkg/scale" +) + +// Header is a block header, and implements a compatible encoding to `sp_runtime::generic::Header` +type Header[N runtime.Number, H runtime.Hash, Hasher runtime.Hasher[H]] struct { + // The parent hash. + parentHash H + // The block number. + number N + // The state trie merkle root + stateRoot H + // The merkle root of the extrinsics. + extrinsicsRoot H + // A chain-specific digest of data useful for light clients or referencing auxiliary data. + digest runtime.Digest +} + +// Number returns the block number. +func (h Header[N, H, Hasher]) Number() N { + return h.number +} + +// SetNumber sets the block number. +func (h *Header[N, H, Hasher]) SetNumber(number N) { + h.number = number +} + +// ExtrinsicsRoot returns the extrinsics root. +func (h Header[N, H, Hasher]) ExtrinsicsRoot() H { + return h.extrinsicsRoot +} + +// SetExtrinsicsRoot sets the extrinsics root. +func (h *Header[N, H, Hasher]) SetExtrinsicsRoot(root H) { + h.extrinsicsRoot = root +} + +// StateRoot returns the state root. +func (h Header[N, H, Hasher]) StateRoot() H { + return h.stateRoot +} + +// SetStateRoot sets the state root. +func (h *Header[N, H, Hasher]) SetStateRoot(root H) { + h.stateRoot = root +} + +// ParentHash returns the parent hash. +func (h Header[N, H, Hasher]) ParentHash() H { + return h.parentHash +} + +// SetParentHash sets the parent hash. +func (h *Header[N, H, Hasher]) SetParentHash(hash H) { + h.parentHash = hash +} + +// Digest returns the digest. +func (h Header[N, H, Hasher]) Digest() runtime.Digest { + return h.digest +} + +// DigestMut returns a mutable reference to the stored digest. +func (h Header[N, H, Hasher]) DigestMut() *runtime.Digest { + return &h.digest +} + +type encodingHelper[H any] struct { + ParentHash H + // uses compact encoding so we need to cast to uint + // https://github.com/paritytech/substrate/blob/e374a33fe1d99d59eb24a08981090bdb4503e81b/primitives/runtime/src/generic/header.rs#L47 + Number uint + StateRoot H + ExtrinsicsRoot H + Digest runtime.Digest +} + +// MarshalSCALE implements custom SCALE encoding. +func (h Header[N, H, Hasher]) MarshalSCALE() ([]byte, error) { + help := encodingHelper[H]{h.parentHash, uint(h.number), h.stateRoot, h.extrinsicsRoot, h.digest} + return scale.Marshal(help) +} + +// UnmarshalSCALE implements custom SCALE decoding. +func (h *Header[N, H, Hasher]) UnmarshalSCALE(r io.Reader) error { + var header encodingHelper[H] + decoder := scale.NewDecoder(r) + err := decoder.Decode(&header) + if err != nil { + return err + } + h.parentHash = header.ParentHash + h.number = N(header.Number) + h.stateRoot = header.StateRoot + h.extrinsicsRoot = header.ExtrinsicsRoot + h.digest = header.Digest + return nil +} + +// Hash returns the hash of the header. +func (h Header[N, H, Hasher]) Hash() H { + hasher := *new(Hasher) + return hasher.HashEncoded(h) +} + +// NewHeader is the constructor for `Header` +func NewHeader[N runtime.Number, H runtime.Hash, Hasher runtime.Hasher[H]]( + number N, + extrinsicsRoot H, + stateRoot H, + parentHash H, + digest runtime.Digest, +) *Header[N, H, Hasher] { + return &Header[N, H, Hasher]{ + number: number, + extrinsicsRoot: extrinsicsRoot, + stateRoot: stateRoot, + parentHash: parentHash, + digest: digest, + } +} + +var _ runtime.Header[uint64, hash.H256] = &Header[uint64, hash.H256, runtime.BlakeTwo256]{} diff --git a/internal/primitives/runtime/interfaces.go b/internal/primitives/runtime/interfaces.go new file mode 100644 index 0000000000..c239aae471 --- /dev/null +++ b/internal/primitives/runtime/interfaces.go @@ -0,0 +1,105 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package runtime + +import ( + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/core/hashing" + "github.com/ChainSafe/gossamer/pkg/scale" + "golang.org/x/exp/constraints" +) + +// Number is the header number type +type Number interface { + ~uint | ~uint32 | ~uint64 +} + +// Hash type +type Hash interface { + constraints.Ordered + // Bytes returns a byte slice representation of Hash + Bytes() []byte + // String returns a unique string representation of the hash + String() string +} + +// Hasher is an interface around hashing +type Hasher[H Hash] interface { + // Produce the hash of some byte-slice. + Hash(s []byte) H + + // Produce the hash of some codec-encodable value. + HashEncoded(s any) H +} + +// Blake2-256 Hash implementation. +type BlakeTwo256 struct{} + +// Produce the hash of some byte-slice. +func (bt256 BlakeTwo256) Hash(s []byte) hash.H256 { + h := hashing.BlakeTwo256(s) + return hash.H256(h[:]) +} + +// Produce the hash of some codec-encodable value. +func (bt256 BlakeTwo256) HashEncoded(s any) hash.H256 { + bytes := scale.MustMarshal(s) + return bt256.Hash(bytes) +} + +var _ Hasher[hash.H256] = BlakeTwo256{} + +// Header is the interface for a header. It has types for a `Number`, +// and `Hash`. It provides access to an `ExtrinsicsRoot`, `StateRoot` and +// `ParentHash`, as well as a `Digest` and a block `Number`. +type Header[N Number, H Hash] interface { + // Returns a reference to the header number. + Number() N + // Sets the header number. + SetNumber(number N) + + // Returns a reference to the extrinsics root. + ExtrinsicsRoot() H + // Sets the extrinsic root. + SetExtrinsicsRoot(root H) + + // Returns a reference to the state root. + StateRoot() H + // Sets the state root. + SetStateRoot(root H) + + // Returns a reference to the parent hash. + ParentHash() H + // Sets the parent hash. + SetParentHash(hash H) + + // Returns a reference to the digest. + Digest() Digest + // Get a mutable reference to the digest. + DigestMut() *Digest + + // Returns the hash of the header. + Hash() H +} + +// Block represents a block. It has types for `Extrinsic` pieces of information as well as a `Header`. +// +// You can iterate over each of the `Extrinsics` and retrieve the `Header`. +type Block[N Number, H Hash] interface { + // Returns a reference to the header. + Header() Header[N, H] + // Returns a reference to the list of extrinsics. + Extrinsics() []Extrinsic + // Split the block into header and list of extrinsics. + Deconstruct() (header Header[N, H], extrinsics []Extrinsic) + // Returns the hash of the block. + Hash() H +} + +// Extrinisic is the interface for an `Extrinsic`. +type Extrinsic interface { + // Is this `Extrinsic` signed? + // If no information are available about signed/unsigned, `nil` should be returned. + IsSigned() *bool +} diff --git a/internal/primitives/runtime/runtime.go b/internal/primitives/runtime/runtime.go new file mode 100644 index 0000000000..71b3668935 --- /dev/null +++ b/internal/primitives/runtime/runtime.go @@ -0,0 +1,7 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package runtime + +// Consensus engine unique ID. +type ConsensusEngineID [4]byte diff --git a/lib/grandpa/errors.go b/lib/grandpa/errors.go index 97d21d0bc6..801e46d03f 100644 --- a/lib/grandpa/errors.go +++ b/lib/grandpa/errors.go @@ -83,9 +83,6 @@ var ( // ErrAuthorityNotInSet is returned when a precommit within a justification is signed by a key not in the authority set ErrAuthorityNotInSet = errors.New("authority is not in set") - // errFinalisedBlocksMismatch is returned when we find another block finalised in the same set id and round - errFinalisedBlocksMismatch = errors.New("already have finalised block with the same setID and round") - errVoteToSignatureMismatch = errors.New("votes and authority count mismatch") errVoteBlockMismatch = errors.New("block in vote is not descendant of previously finalised block") errVoteFromSelf = errors.New("got vote from ourselves") diff --git a/lib/grandpa/message_handler.go b/lib/grandpa/message_handler.go index bbc328108b..c9ce389234 100644 --- a/lib/grandpa/message_handler.go +++ b/lib/grandpa/message_handler.go @@ -9,12 +9,15 @@ import ( "fmt" "github.com/ChainSafe/gossamer/dot/network" - "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/database" + "github.com/ChainSafe/gossamer/internal/primitives/core/hash" + "github.com/ChainSafe/gossamer/internal/primitives/runtime" "github.com/ChainSafe/gossamer/lib/blocktree" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto/ed25519" - "github.com/ChainSafe/gossamer/pkg/scale" + + client_grandpa "github.com/ChainSafe/gossamer/internal/client/consensus/grandpa" + finality_grandpa "github.com/ChainSafe/gossamer/pkg/finality-grandpa" "github.com/libp2p/go-libp2p/core/peer" ) @@ -360,144 +363,42 @@ func (h *MessageHandler) verifyPreCommitJustification(msg *CatchUpResponse) erro // VerifyBlockJustification verifies the finality justification for a block, returns scale encoded justification with // any extra bytes removed. -func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byte) error { - fj := Justification{} - err := scale.Unmarshal(justification, &fj) - if err != nil { - return err - } - - if hash != fj.Commit.Hash { - return fmt.Errorf("%w: justification %s and block hash %s", - ErrJustificationMismatch, fj.Commit.Hash.Short(), hash.Short()) - } - - setID, err := s.grandpaState.GetSetIDByBlockNumber(uint(fj.Commit.Number)) - if err != nil { - return fmt.Errorf("cannot get set ID from block number: %w", err) - } - - has, err := s.blockState.HasFinalisedBlock(fj.Round, setID) +func (s *Service) VerifyBlockJustification(finalizedHash common.Hash, finalizedNumber uint, encoded []byte) ( + round uint64, setID uint64, err error, +) { + setID, err = s.grandpaState.GetSetIDByBlockNumber(finalizedNumber) if err != nil { - return fmt.Errorf("checking if round and set id has finalised block: %w", err) - } - - if has { - storedFinalisedHash, err := s.blockState.GetFinalisedHash(fj.Round, setID) - if err != nil { - return fmt.Errorf("getting finalised hash: %w", err) - } - if storedFinalisedHash != hash { - return fmt.Errorf("%w, setID=%d and round=%d", errFinalisedBlocksMismatch, setID, fj.Round) - } - - return nil - } - - isDescendant, err := isDescendantOfHighestFinalisedBlock(s.blockState, fj.Commit.Hash) - if err != nil { - return fmt.Errorf("checking if descendant of highest block: %w", err) - } - - if !isDescendant { - return errVoteBlockMismatch + return 0, 0, fmt.Errorf("cannot get set ID from block number: %w", err) } auths, err := s.grandpaState.GetAuthorities(setID) if err != nil { - return fmt.Errorf("cannot get authorities for set ID: %w", err) + return 0, 0, fmt.Errorf("cannot get authorities for set ID: %w", err) } - // threshold is two-thirds the number of authorities, - // uses the current set of authorities to define the threshold - threshold := (2 * len(auths) / 3) - - if len(fj.Commit.Precommits) < threshold { - return ErrMinVotesNotMet - } - - authPubKeys := make([]AuthData, len(fj.Commit.Precommits)) - for i, pcj := range fj.Commit.Precommits { - authPubKeys[i] = AuthData{AuthorityID: pcj.AuthorityID} - } - - equivocatoryVoters := getEquivocatoryVoters(authPubKeys) - - var count int - - logger.Debugf( - "verifying justification: set id %d, round %d, hash %s, number %d, sig count %d", - setID, fj.Round, fj.Commit.Hash, fj.Commit.Number, len(fj.Commit.Precommits)) - - for _, just := range fj.Commit.Precommits { - // check if vote was for descendant of committed block - isDescendant, err := s.blockState.IsDescendantOf(hash, just.Vote.Hash) - if err != nil { - return err - } - - if !isDescendant { - return ErrPrecommitBlockMismatch - } - - publicKey, err := ed25519.NewPublicKey(just.AuthorityID[:]) - if err != nil { - return err - } - - if !isInAuthSet(publicKey, auths) { - return ErrAuthorityNotInSet - } + logger.Debugf("verifying justification within set id %d and authorities %d", setID, len(auths)) - // verify signature for each precommit - msg, err := scale.Marshal(FullVote{ - Stage: precommit, - Vote: just.Vote, - Round: fj.Round, - SetID: setID, - }) - if err != nil { - return err + idsAndWeights := make([]finality_grandpa.IDWeight[string], len(auths)) + for idx, auth := range auths { + idsAndWeights[idx] = finality_grandpa.IDWeight[string]{ + ID: string(auth.Key.Encode()), + Weight: 1, } - - ok, err := publicKey.Verify(msg, just.Signature[:]) - if err != nil { - return err - } - - if !ok { - return ErrInvalidSignature - } - - if _, ok := equivocatoryVoters[just.AuthorityID]; ok { - continue - } - - count++ } - if count+len(equivocatoryVoters) < threshold { - return ErrMinVotesNotMet + voters := finality_grandpa.NewVoterSet(idsAndWeights) + target := client_grandpa.HashNumber[hash.H256, uint32]{ + Hash: hash.H256(finalizedHash.ToBytes()), + Number: uint32(finalizedNumber), } - err = verifyBlockHashAgainstBlockNumber(s.blockState, fj.Commit.Hash, uint(fj.Commit.Number)) + justification, err := client_grandpa.DecodeGrandpaJustificationVerifyFinalizes[hash.H256, uint32, runtime.BlakeTwo256]( + encoded, target, setID, *voters) if err != nil { - return fmt.Errorf("verifying block hash against block number: %w", err) + return 0, 0, fmt.Errorf("decoding and verifying justification: %w", err) } - for _, preCommit := range fj.Commit.Precommits { - err := verifyBlockHashAgainstBlockNumber(s.blockState, preCommit.Vote.Hash, uint(preCommit.Vote.Number)) - if err != nil { - return fmt.Errorf("verifying block hash against block number: %w", err) - } - } - - err = s.blockState.SetFinalisedHash(hash, fj.Round, setID) - if err != nil { - return fmt.Errorf("setting finalised hash: %w", err) - } - - return nil + return justification.Justification.Round, setID, nil } func verifyBlockHashAgainstBlockNumber(bs BlockState, hash common.Hash, number uint) error { @@ -512,13 +413,3 @@ func verifyBlockHashAgainstBlockNumber(bs BlockState, hash common.Hash, number u } return nil } - -func isInAuthSet(auth *ed25519.PublicKey, set []types.GrandpaVoter) bool { - for _, a := range set { - if bytes.Equal(a.Key.Encode(), auth.Encode()) { - return true - } - } - - return false -} diff --git a/lib/grandpa/message_handler_integration_test.go b/lib/grandpa/message_handler_integration_test.go index 89a43c184b..907f080de1 100644 --- a/lib/grandpa/message_handler_integration_test.go +++ b/lib/grandpa/message_handler_integration_test.go @@ -6,7 +6,6 @@ package grandpa import ( - "errors" "fmt" "testing" "time" @@ -712,299 +711,6 @@ func TestMessageHandler_HandleCatchUpResponse(t *testing.T) { require.Equal(t, round+1, gs.state.round) } -func TestMessageHandler_VerifyBlockJustification_WithEquivocatoryVotes(t *testing.T) { - kr, err := keystore.NewEd25519Keyring() - require.NoError(t, err) - aliceKeyPair := kr.Alice().(*ed25519.Keypair) - - auths := []types.GrandpaVoter{ - { - Key: *kr.Alice().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Bob().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Charlie().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Dave().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Eve().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Ferdie().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.George().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Heather().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Ian().Public().(*ed25519.PublicKey), - }, - } - - gs, st := newTestService(t, aliceKeyPair) - err = st.Grandpa.SetNextChange(auths, 0) - require.NoError(t, err) - - body, err := types.NewBodyFromBytes([]byte{0}) - require.NoError(t, err) - - block := &types.Block{ - Header: *testHeader, - Body: *body, - } - - err = st.Block.AddBlock(block) - require.NoError(t, err) - - setID, err := st.Grandpa.IncrementSetID() - require.NoError(t, err) - require.Equal(t, uint64(1), setID) - - round := uint64(1) - number := uint32(1) - precommits := buildTestJustification(t, 18, round, setID, kr, precommit) - just := newJustification(round, testHash, number, precommits) - data, err := scale.Marshal(*just) - require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) - require.NoError(t, err) -} - -func TestMessageHandler_VerifyBlockJustification(t *testing.T) { - - kr, err := keystore.NewEd25519Keyring() - require.NoError(t, err) - aliceKeyPair := kr.Alice().(*ed25519.Keypair) - - auths := []types.GrandpaVoter{ - { - Key: *kr.Alice().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Bob().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Charlie().Public().(*ed25519.PublicKey), - }, - } - - gs, st := newTestService(t, aliceKeyPair) - err = st.Grandpa.SetNextChange(auths, 0) - require.NoError(t, err) - - body, err := types.NewBodyFromBytes([]byte{0}) - require.NoError(t, err) - - block := &types.Block{ - Header: *testHeader, - Body: *body, - } - - err = st.Block.AddBlock(block) - require.NoError(t, err) - - digest2 := types.NewDigest() - prd2, _ := types.NewBabeSecondaryPlainPreDigest(0, 2).ToPreRuntimeDigest() - digest2.Add(*prd2) - - testHeader2 := types.Header{ - ParentHash: testGenesisHeader.Hash(), - Number: 1, - Digest: digest2, - } - - block2 := &types.Block{ - Header: testHeader2, - Body: *body, - } - - err = st.Block.AddBlock(block2) - require.NoError(t, err) - - err = st.Block.SetHeader(&testHeader2) - require.NoError(t, err) - - setID, err := st.Grandpa.IncrementSetID() - require.NoError(t, err) - require.Equal(t, uint64(1), setID) - - round := uint64(1) - number := uint32(1) - precommits := buildTestJustification(t, 2, round, setID, kr, precommit) - just := newJustification(round, testHash, number, precommits) - data, err := scale.Marshal(*just) - require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) - require.NoError(t, err) - - // use wrong hash, shouldn't verify - precommits = buildTestJustification(t, 2, round+1, setID, kr, precommit) - just = newJustification(round+1, testHash, number, precommits) - just.Commit.Precommits[0].Vote.Hash = testHeader2.Hash() - data, err = scale.Marshal(*just) - require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) - require.Equal(t, ErrPrecommitBlockMismatch, err) -} - -func TestMessageHandler_VerifyBlockJustification_invalid(t *testing.T) { - kr, err := keystore.NewEd25519Keyring() - require.NoError(t, err) - aliceKeyPair := kr.Alice().(*ed25519.Keypair) - - auths := []types.GrandpaVoter{ - { - Key: *kr.Alice().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Bob().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Charlie().Public().(*ed25519.PublicKey), - }, - } - - gs, st := newTestService(t, aliceKeyPair) - err = st.Grandpa.SetNextChange(auths, 1) - require.NoError(t, err) - - body, err := types.NewBodyFromBytes([]byte{0}) - require.NoError(t, err) - - block := &types.Block{ - Header: *testHeader, - Body: *body, - } - - err = st.Block.AddBlock(block) - require.NoError(t, err) - - setID, err := st.Grandpa.IncrementSetID() - require.NoError(t, err) - require.Equal(t, uint64(1), setID) - - genhash := st.Block.GenesisHash() - round := uint64(2) - number := uint32(2) - - // use wrong hash, shouldn't verify - precommits := buildTestJustification(t, 2, round+1, setID, kr, precommit) - just := newJustification(round+1, testHash, number, precommits) - just.Commit.Precommits[0].Vote.Hash = genhash - data, err := scale.Marshal(*just) - require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) - require.Equal(t, ErrPrecommitBlockMismatch, err) - - // use wrong round, shouldn't verify - precommits = buildTestJustification(t, 2, round+1, setID, kr, precommit) - just = newJustification(round+2, testHash, number, precommits) - data, err = scale.Marshal(*just) - require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) - require.Equal(t, ErrInvalidSignature, err) - - // add authority not in set, shouldn't verify - precommits = buildTestJustification(t, len(auths)+1, round+1, setID, kr, precommit) - just = newJustification(round+1, testHash, number, precommits) - data, err = scale.Marshal(*just) - require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) - require.Equal(t, ErrAuthorityNotInSet, err) - - // not enough signatures, shouldn't verify - precommits = buildTestJustification(t, 1, round+1, setID, kr, precommit) - just = newJustification(round+1, testHash, number, precommits) - data, err = scale.Marshal(*just) - require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) - require.Equal(t, ErrMinVotesNotMet, err) - - // mismatch justification header and block header - precommits = buildTestJustification(t, 1, round+1, setID, kr, precommit) - just = newJustification(round+1, testHash, number, precommits) - data, err = scale.Marshal(*just) - require.NoError(t, err) - otherHeader := types.NewEmptyHeader() - err = gs.VerifyBlockJustification(otherHeader.Hash(), data) - require.ErrorIs(t, err, ErrJustificationMismatch) - - expectedErr := fmt.Sprintf("%s: justification %s and block hash %s", ErrJustificationMismatch, - testHash.Short(), otherHeader.Hash().Short()) - assert.ErrorIs(t, err, ErrJustificationMismatch) - require.EqualError(t, err, expectedErr) -} - -func TestMessageHandler_VerifyBlockJustification_ErrFinalisedBlockMismatch(t *testing.T) { - t.Parallel() - - kr, err := keystore.NewEd25519Keyring() - require.NoError(t, err) - aliceKeyPair := kr.Alice().(*ed25519.Keypair) - - auths := []types.GrandpaVoter{ - { - Key: *kr.Alice().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Bob().Public().(*ed25519.PublicKey), - }, - { - Key: *kr.Charlie().Public().(*ed25519.PublicKey), - }, - } - - gs, st := newTestService(t, aliceKeyPair) - err = st.Grandpa.SetNextChange(auths, 1) - require.NoError(t, err) - - body, err := types.NewBodyFromBytes([]byte{0}) - require.NoError(t, err) - - block := &types.Block{ - Header: *testHeader, - Body: *body, - } - - err = st.Block.AddBlock(block) - require.NoError(t, err) - - setID := uint64(0) - round := uint64(1) - number := uint32(1) - - err = st.Block.SetFinalisedHash(block.Header.Hash(), round, setID) - require.NoError(t, err) - - var testHeader2 = &types.Header{ - ParentHash: testHeader.Hash(), - Number: 2, - Digest: newTestDigest(), - } - - testHash = testHeader2.Hash() - block2 := &types.Block{ - Header: *testHeader2, - Body: *body, - } - err = st.Block.AddBlock(block2) - require.NoError(t, err) - - // justification fails since there is already a block finalised in this round and set id - precommits := buildTestJustification(t, 18, round, setID, kr, precommit) - just := newJustification(round, testHash, number, precommits) - data, err := scale.Marshal(*just) - require.NoError(t, err) - err = gs.VerifyBlockJustification(testHash, data) - require.ErrorIs(t, err, errFinalisedBlocksMismatch) -} - func Test_getEquivocatoryVoters(t *testing.T) { t.Parallel() @@ -1467,123 +1173,3 @@ func signFakeFullVote( return sig } - -func TestService_VerifyBlockJustification(t *testing.T) { //nolint - kr, err := keystore.NewEd25519Keyring() - require.NoError(t, err) - - precommits := buildTestJustification(t, 2, 1, 0, kr, precommit) - justification := newJustification(1, testHash, 1, precommits) - justificationBytes, err := scale.Marshal(*justification) - require.NoError(t, err) - - type fields struct { - blockStateBuilder func(ctrl *gomock.Controller) BlockState - grandpaStateBuilder func(ctrl *gomock.Controller) GrandpaState - } - type args struct { - hash common.Hash - justification []byte - } - tests := map[string]struct { - fields fields - args args - want []byte - wantErr error - }{ - "invalid_justification": { - fields: fields{ - blockStateBuilder: func(ctrl *gomock.Controller) BlockState { - return nil - }, - grandpaStateBuilder: func(ctrl *gomock.Controller) GrandpaState { - return nil - }, - }, - args: args{ - hash: common.Hash{}, - justification: []byte{1, 2, 3}, - }, - want: nil, - wantErr: errors.New("decoding struct: unmarshalling field at index 1: decoding struct: unmarshalling" + - " field at index 0: EOF"), - }, - "valid_justification": { - fields: fields{ - blockStateBuilder: func(ctrl *gomock.Controller) BlockState { - mockBlockState := NewMockBlockState(ctrl) - mockBlockState.EXPECT().HasFinalisedBlock(uint64(1), uint64(0)).Return(false, nil) - mockBlockState.EXPECT().GetHighestFinalisedHeader().Return(testHeader, nil) - mockBlockState.EXPECT().IsDescendantOf(testHash, testHash). - Return(true, nil).Times(3) - mockBlockState.EXPECT().GetHeader(testHash).Return(testHeader, nil).Times(3) - mockBlockState.EXPECT().SetFinalisedHash(testHash, uint64(1), - uint64(0)).Return(nil) - return mockBlockState - }, - grandpaStateBuilder: func(ctrl *gomock.Controller) GrandpaState { - mockGrandpaState := NewMockGrandpaState(ctrl) - mockGrandpaState.EXPECT().GetSetIDByBlockNumber(uint(1)).Return(uint64(0), nil) - mockGrandpaState.EXPECT().GetAuthorities(uint64(0)).Return([]types.GrandpaVoter{ - {Key: *kr.Alice().Public().(*ed25519.PublicKey), ID: 1}, - {Key: *kr.Bob().Public().(*ed25519.PublicKey), ID: 2}, - {Key: *kr.Charlie().Public().(*ed25519.PublicKey), ID: 3}, - }, nil) - return mockGrandpaState - }, - }, - args: args{ - hash: testHash, - justification: justificationBytes, - }, - want: justificationBytes, - }, - "valid_justification_extra_bytes": { - fields: fields{ - blockStateBuilder: func(ctrl *gomock.Controller) BlockState { - mockBlockState := NewMockBlockState(ctrl) - mockBlockState.EXPECT().HasFinalisedBlock(uint64(1), uint64(0)).Return(false, nil) - mockBlockState.EXPECT().GetHighestFinalisedHeader().Return(testHeader, nil) - mockBlockState.EXPECT().IsDescendantOf(testHash, testHash). - Return(true, nil).Times(3) - mockBlockState.EXPECT().GetHeader(testHash).Return(testHeader, nil).Times(3) - mockBlockState.EXPECT().SetFinalisedHash(testHash, uint64(1), - uint64(0)).Return(nil) - return mockBlockState - }, - grandpaStateBuilder: func(ctrl *gomock.Controller) GrandpaState { - mockGrandpaState := NewMockGrandpaState(ctrl) - mockGrandpaState.EXPECT().GetSetIDByBlockNumber(uint(1)).Return(uint64(0), nil) - mockGrandpaState.EXPECT().GetAuthorities(uint64(0)).Return([]types.GrandpaVoter{ - {Key: *kr.Alice().Public().(*ed25519.PublicKey), ID: 1}, - {Key: *kr.Bob().Public().(*ed25519.PublicKey), ID: 2}, - {Key: *kr.Charlie().Public().(*ed25519.PublicKey), ID: 3}, - }, nil) - return mockGrandpaState - }, - }, - args: args{ - hash: testHash, - justification: append(justificationBytes, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}...), - }, - want: justificationBytes, - }, - } - for name, tt := range tests { - tt := tt - t.Run(name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - s := &Service{ - blockState: tt.fields.blockStateBuilder(ctrl), - grandpaState: tt.fields.grandpaStateBuilder(ctrl), - } - err := s.VerifyBlockJustification(tt.args.hash, tt.args.justification) - if tt.wantErr != nil { - assert.ErrorContains(t, err, tt.wantErr.Error()) - } else { - require.NoError(t, err) - } - }) - } -} diff --git a/lib/grandpa/message_handler_test.go b/lib/grandpa/message_handler_test.go new file mode 100644 index 0000000000..7019abad1f --- /dev/null +++ b/lib/grandpa/message_handler_test.go @@ -0,0 +1,63 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "testing" + + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/crypto/ed25519" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +func TestVerify_WestendBlock512_Justification(t *testing.T) { + wndSetID0Voters := make([]types.GrandpaVoter, 0) + wndSetID0Authorities := []string{ + "0x959cebf18fecb305b96fd998c95f850145f52cbbb64b3ef937c0575cc7ebd652", + "0x9fc415cce1d0b2eed702c9e05f476217d23b46a8723fd56f08cddad650be7c2d", + "0xfeca0be2c87141f6074b221c919c0161a1c468d9173c5c1be59b68fab9a0ff93", + } + + for idx, pubkey := range wndSetID0Authorities { + edPubKey, err := ed25519.NewPublicKey(common.MustHexToBytes(pubkey)) + require.NoError(t, err) + + wndSetID0Voters = append(wndSetID0Voters, types.GrandpaVoter{ + ID: uint64(idx), + Key: *edPubKey, + }) + } + + const currentSetID uint64 = 0 + const block512Justification = "0xc9020000000000005895897f12e1a670609929433ac7a69dcae90e0cc2d9c" + + "32c0dce0e2a5e5e614e000200000c5895897f12e1a670609929433ac7a69dcae90e0cc2d9c32c0dce0e2a5e5e" + + "614e000200006216ec969bb5133b13f54a6121ef3a908d0a87d8409e2d471c0cad1c28532b6e27d6a8d746b43" + + "df96c2149915252a846227b060372e3bb6f49e91500d3d8ef0d959cebf18fecb305b96fd998c95f850145f52c" + + "bbb64b3ef937c0575cc7ebd6525895897f12e1a670609929433ac7a69dcae90e0cc2d9c32c0dce0e2a5e5e614" + + "e0002000092820b93ac482089fffc8246b4111da2e2b7adc786938c24eb25fe3b97cd21b946b7e12cb6fa5546" + + "b73c047ffc7c73b17a6a750bc6f2858bb0d0a7fff2fdd2029fc415cce1d0b2eed702c9e05f476217d23b46a87" + + "23fd56f08cddad650be7c2d5895897f12e1a670609929433ac7a69dcae90e0cc2d9c32c0dce0e2a5e5e614e00" + + "02000017a338b777152d2213908ab29f961ebbca04e6bd1e4cfde6cb1a0b7b7f244c2670935cdf4c2acb4dd06" + + "1913848f5865aa887406a3ea0c8d0dcd4d551ff249900feca0be2c87141f6074b221c919c0161a1c468d9173c5c1be59b68fab9a0ff9300" + + ctrl := gomock.NewController(t) + grandpaMockService := NewMockGrandpaState(ctrl) + grandpaMockService.EXPECT().GetSetIDByBlockNumber(uint(512)).Return(currentSetID, nil) + grandpaMockService.EXPECT().GetAuthorities(currentSetID).Return(wndSetID0Voters, nil) + + service := &Service{ + grandpaState: grandpaMockService, + } + + round, setID, err := service.VerifyBlockJustification( + common.MustHexToHash("0x5895897f12e1a670609929433ac7a69dcae90e0cc2d9c32c0dce0e2a5e5e614e"), + 512, + common.MustHexToBytes(block512Justification)) + + require.NoError(t, err) + require.Equal(t, uint64(0), setID) + require.Equal(t, uint64(713), round) +} diff --git a/lib/runtime/storage/storagediff.go b/lib/runtime/storage/storagediff.go index 25e5b64962..d2d3dddf44 100644 --- a/lib/runtime/storage/storagediff.go +++ b/lib/runtime/storage/storagediff.go @@ -141,13 +141,17 @@ func (cs *storageDiff) clearPrefixInChild(keyToChild string, prefix []byte, // optional limit. It returns the number of keys deleted and a boolean // indicating if all keys with the prefix were removed. func (cs *storageDiff) clearPrefix(prefix []byte, trieKeys []string, limit int) (deleted uint32, allDeleted bool) { - allKeys := slices.Clone(trieKeys) newKeys := maps.Keys(cs.upserts) - allKeys = append(allKeys, newKeys...) + keysToClear := maps.Keys(cs.upserts) + for _, k := range trieKeys { + if _, ok := cs.upserts[k]; !ok { + keysToClear = append(keysToClear, k) + } + } deleted = 0 - sort.Strings(allKeys) - for _, k := range allKeys { + sort.Strings(keysToClear) + for _, k := range keysToClear { if limit == 0 { break } @@ -161,7 +165,7 @@ func (cs *storageDiff) clearPrefix(prefix []byte, trieKeys []string, limit int) } } - return deleted, deleted == uint32(len(allKeys)) + return deleted, deleted == uint32(len(keysToClear)) } // getFromChild attempts to retrieve a value associated with a specific key