diff --git a/dot/network/helpers_test.go b/dot/network/helpers_test.go index 24d4abba46..d48c5d4176 100644 --- a/dot/network/helpers_test.go +++ b/dot/network/helpers_test.go @@ -14,7 +14,6 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/log" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" libp2pnetwork "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" @@ -121,7 +120,7 @@ func (s *testStreamHandler) readStream(stream libp2pnetwork.Stream, } } -var starting, _ = variadic.NewUint32OrHash(uint32(1)) +var starting = messages.NewFromBlock(uint(1)) var one = uint32(1) diff --git a/dot/network/message_test.go b/dot/network/message_test.go index 1e4a6872fc..c9b492e030 100644 --- a/dot/network/message_test.go +++ b/dot/network/message_test.go @@ -5,125 +5,15 @@ package network import ( "encoding/hex" - "regexp" "testing" "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/stretchr/testify/require" ) -func TestEncodeBlockRequestMessage(t *testing.T) { - t.Parallel() - - expected := common.MustHexToBytes("0x0880808008280130011220dcd1346701ca8396496e52" + - "aa2785b1748deb6db09551b72159dcb3e08991025b") - genesisHash := common.MustHexToBytes("0xdcd1346701ca8396496e52aa2785b1748deb6db09551b72159dcb3e08991025b") - - var one uint32 = 1 - bm := &messages.BlockRequestMessage{ - RequestedData: 1, - StartingBlock: *variadic.NewUint32OrHashFromBytes(append([]byte{0}, genesisHash...)), - Direction: 1, - Max: &one, - } - - encMsg, err := bm.Encode() - require.NoError(t, err) - - require.Equal(t, expected, encMsg) - - res := new(messages.BlockRequestMessage) - err = res.Decode(encMsg) - require.NoError(t, err) - require.Equal(t, bm, res) -} - -func TestEncodeBlockRequestMessage_BlockHash(t *testing.T) { - t.Parallel() - - genesisHash := common.MustHexToBytes("0xdcd1346701ca8396496e52aa2785b1748deb6db09551b72159dcb3e08991025b") - - var one uint32 = 1 - bm := &messages.BlockRequestMessage{ - RequestedData: 1, - StartingBlock: *variadic.NewUint32OrHashFromBytes(append([]byte{0}, genesisHash...)), - Direction: 1, - Max: &one, - } - - encMsg, err := bm.Encode() - require.NoError(t, err) - - res := new(messages.BlockRequestMessage) - err = res.Decode(encMsg) - require.NoError(t, err) - require.Equal(t, bm, res) -} - -func TestEncodeBlockRequestMessage_BlockNumber(t *testing.T) { - t.Parallel() - - var one uint32 = 1 - bm := &messages.BlockRequestMessage{ - RequestedData: 1, - StartingBlock: *variadic.NewUint32OrHashFromBytes([]byte{1, 1}), - Direction: 1, - Max: &one, - } - - encMsg, err := bm.Encode() - require.NoError(t, err) - - res := new(messages.BlockRequestMessage) - err = res.Decode(encMsg) - require.NoError(t, err) - require.Equal(t, bm, res) -} - -func TestBlockRequestString(t *testing.T) { - t.Parallel() - - genesisHash := common.MustHexToBytes("0xdcd1346701ca8396496e52aa2785b1748deb6db09551b72159dcb3e08991025b") - - bm := &messages.BlockRequestMessage{ - RequestedData: 1, - StartingBlock: *variadic.NewUint32OrHashFromBytes(append([]byte{0}, genesisHash...)), - Direction: 1, - Max: nil, - } - - var blockRequestStringRegex = regexp.MustCompile( - `^\ABlockRequestMessage RequestedData=[0-9]* StartingBlock={[\[0-9(\s?)]+\]} Direction=[0-9]* Max=[0-9]*\z$`) //nolint:lll - - match := blockRequestStringRegex.MatchString(bm.String()) - require.True(t, match) -} - -func TestEncodeBlockRequestMessage_NoOptionals(t *testing.T) { - t.Parallel() - - genesisHash := common.MustHexToBytes("0xdcd1346701ca8396496e52aa2785b1748deb6db09551b72159dcb3e08991025b") - - bm := &messages.BlockRequestMessage{ - RequestedData: 1, - StartingBlock: *variadic.NewUint32OrHashFromBytes(append([]byte{0}, genesisHash...)), - Direction: 1, - Max: nil, - } - - encMsg, err := bm.Encode() - require.NoError(t, err) - - res := new(messages.BlockRequestMessage) - err = res.Decode(encMsg) - require.NoError(t, err) - require.Equal(t, bm, res) -} - func TestEncodeBlockResponseMessage_Empty(t *testing.T) { t.Parallel() @@ -446,7 +336,7 @@ func TestAscendingBlockRequest(t *testing.T) { expectedBlockRequestMessage: []*messages.BlockRequestMessage{ { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(10)), + StartingBlock: *messages.NewFromBlock(uint(10)), Direction: messages.Ascending, Max: &one, }, @@ -461,7 +351,7 @@ func TestAscendingBlockRequest(t *testing.T) { expectedBlockRequestMessage: []*messages.BlockRequestMessage{ { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(1)), + StartingBlock: *messages.NewFromBlock(uint(1)), Direction: messages.Ascending, Max: &maxResponseSize, }, @@ -475,25 +365,25 @@ func TestAscendingBlockRequest(t *testing.T) { expectedBlockRequestMessage: []*messages.BlockRequestMessage{ { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(1)), + StartingBlock: *messages.NewFromBlock(uint(1)), Direction: messages.Ascending, Max: &maxResponseSize, }, { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(129)), + StartingBlock: *messages.NewFromBlock(uint(129)), Direction: messages.Ascending, Max: &maxResponseSize, }, { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(257)), + StartingBlock: *messages.NewFromBlock(uint(257)), Direction: messages.Ascending, Max: &maxResponseSize, }, { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(385)), + StartingBlock: *messages.NewFromBlock(uint(385)), Direction: messages.Ascending, Max: &maxResponseSize, }, @@ -507,31 +397,31 @@ func TestAscendingBlockRequest(t *testing.T) { expectedBlockRequestMessage: []*messages.BlockRequestMessage{ { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(1)), + StartingBlock: *messages.NewFromBlock(uint(1)), Direction: messages.Ascending, Max: &maxResponseSize, }, { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(129)), + StartingBlock: *messages.NewFromBlock(uint(129)), Direction: messages.Ascending, Max: &maxResponseSize, }, { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(257)), + StartingBlock: *messages.NewFromBlock(uint(257)), Direction: messages.Ascending, Max: &maxResponseSize, }, { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(385)), + StartingBlock: *messages.NewFromBlock(uint(385)), Direction: messages.Ascending, Max: &maxResponseSize, }, { RequestedData: messages.BootstrapRequestData, - StartingBlock: *variadic.MustNewUint32OrHash(uint32(513)), + StartingBlock: *messages.NewFromBlock(uint(513)), Direction: messages.Ascending, Max: &three, }, diff --git a/dot/network/messages/block.go b/dot/network/messages/block.go index 90c7461d9f..44923dd980 100644 --- a/dot/network/messages/block.go +++ b/dot/network/messages/block.go @@ -7,11 +7,11 @@ import ( "encoding/binary" "errors" "fmt" + "math" pb "github.com/ChainSafe/gossamer/dot/network/proto" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/ChainSafe/gossamer/pkg/scale" "google.golang.org/protobuf/proto" ) @@ -30,6 +30,17 @@ const ( Descending ) +func (s SyncDirection) String() string { + switch s { + case Ascending: + return "ascending" + case Descending: + return "descending" + default: + return "undefined direction" + } +} + // The following defines the fields that will needs to be // in the response message const ( @@ -50,19 +61,63 @@ var ( var ( errBlockRequestFromNumberInvalid = errors.New("block request message From number is not valid") - errInvalidStartingBlockType = errors.New("invalid StartingBlock in messsage") ErrNilBlockInResponse = errors.New("nil block in response") ) +type fromBlockType byte + +const ( + fromBlockNumber fromBlockType = iota + fromBlockHash +) + +type FromBlock struct { + value any +} + +// NewFromBlock returns a new FromBlock given an uint or Hash +// to be used while issuing a block request or while decoding +// a received block request message +func NewFromBlock[T common.Hash | ~uint](value T) *FromBlock { + return &FromBlock{ + value: value, + } +} + +// RawValue returns the inner uint or hash value +func (x *FromBlock) RawValue() any { + return x.value +} + +// Encode will encode a FromBlock into a 4 bytes representation +func (x *FromBlock) Encode() (fromBlockType, []byte) { + switch rawValue := x.value.(type) { + case uint: + encoded := make([]byte, 4) + if rawValue > uint(math.MaxUint32) { + rawValue = math.MaxUint32 + } + binary.LittleEndian.PutUint32(encoded, uint32(rawValue)) + return fromBlockNumber, encoded + case common.Hash: + return fromBlockHash, rawValue.ToBytes() + default: + panic(fmt.Sprintf("unsupported FromBlock type: %T", x.value)) + } +} + // BlockRequestMessage is sent to request some blocks from a peer type BlockRequestMessage struct { RequestedData byte - StartingBlock variadic.Uint32OrHash // first byte 0 = block hash (32 byte), first byte 1 = block number (uint32) - Direction SyncDirection // 0 = ascending, 1 = descending + + // starting block represents a protobuf "oneof" data type + // which means that this field can be either a number or hash + StartingBlock FromBlock + Direction SyncDirection // 0 = ascending, 1 = descending Max *uint32 } -func NewBlockRequest(startingBlock variadic.Uint32OrHash, amount uint32, +func NewBlockRequest(startingBlock FromBlock, amount uint32, requestedData byte, direction SyncDirection) *BlockRequestMessage { return &BlockRequestMessage{ RequestedData: requestedData, @@ -82,7 +137,7 @@ func NewAscendingBlockRequests(startNumber, targetNumber uint, requestedData byt // start and end block are the same, just request 1 block if diff == 0 { return []*BlockRequestMessage{ - NewBlockRequest(*variadic.MustNewUint32OrHash(uint32(startNumber)), 1, requestedData, Ascending), + NewBlockRequest(*NewFromBlock(startNumber), 1, requestedData, Ascending), } } @@ -107,8 +162,7 @@ func NewAscendingBlockRequests(startNumber, targetNumber uint, requestedData byt max = uint32(missingBlocks) } - start := variadic.MustNewUint32OrHash(startNumber) - reqs[i] = NewBlockRequest(*start, max, requestedData, Ascending) + reqs[i] = NewBlockRequest(*NewFromBlock(startNumber), max, requestedData, Ascending) startNumber += uint(max) } @@ -141,19 +195,16 @@ func (bm *BlockRequestMessage) Encode() ([]byte, error) { MaxBlocks: max, } - if bm.StartingBlock.IsHash() { - hash := bm.StartingBlock.Hash() + protoType, encoded := bm.StartingBlock.Encode() + switch protoType { + case fromBlockHash: msg.FromBlock = &pb.BlockRequest_Hash{ - Hash: hash[:], + Hash: encoded, } - } else if bm.StartingBlock.IsUint32() { - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, bm.StartingBlock.Uint32()) + case fromBlockNumber: msg.FromBlock = &pb.BlockRequest_Number{ - Number: buf, + Number: encoded, } - } else { - return nil, errInvalidStartingBlockType } return proto.Marshal(msg) @@ -168,20 +219,20 @@ func (bm *BlockRequestMessage) Decode(in []byte) error { } var ( - startingBlock *variadic.Uint32OrHash + startingBlock *FromBlock max *uint32 ) switch from := msg.FromBlock.(type) { case *pb.BlockRequest_Hash: - startingBlock, err = variadic.NewUint32OrHash(common.BytesToHash(from.Hash)) + startingBlock = NewFromBlock(common.BytesToHash(from.Hash)) case *pb.BlockRequest_Number: if len(from.Number) != 4 { return fmt.Errorf("%w expected 4 bytes, got %d bytes", errBlockRequestFromNumberInvalid, len(from.Number)) } - number := binary.LittleEndian.Uint32(from.Number) - startingBlock, err = variadic.NewUint32OrHash(number) + number := uint(binary.LittleEndian.Uint32(from.Number)) + startingBlock = NewFromBlock(number) default: err = errors.New("invalid StartingBlock") } diff --git a/dot/sync/chain_sync.go b/dot/sync/chain_sync.go index a37240138a..75683b1b4b 100644 --- a/dot/sync/chain_sync.go +++ b/dot/sync/chain_sync.go @@ -24,7 +24,6 @@ import ( "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/database" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" ) var _ ChainSync = (*chainSync)(nil) @@ -399,7 +398,7 @@ func (cs *chainSync) requestChainBlocks(announcedHeader, bestBlockHeader *types. totalBlocks := uint32(1) var request *messages.BlockRequestMessage - startingBlock := *variadic.MustNewUint32OrHash(announcedHeader.Hash()) + startingBlock := *messages.NewFromBlock(announcedHeader.Hash()) if gapLength > 1 { request = messages.NewBlockRequest(startingBlock, gapLength, @@ -444,7 +443,7 @@ func (cs *chainSync) requestForkBlocks(bestBlockHeader, highestFinalizedHeader, startAtBlock := announcedHeader.Number announcedHash := announcedHeader.Hash() var request *messages.BlockRequestMessage - startingBlock := *variadic.MustNewUint32OrHash(announcedHash) + startingBlock := *messages.NewFromBlock(announcedHash) if parentExists { request = messages.NewBlockRequest(startingBlock, 1, messages.BootstrapRequestData, messages.Descending) @@ -503,7 +502,7 @@ func (cs *chainSync) requestPendingBlocks(highestFinalizedHeader *types.Header) gapLength = 128 } - descendingGapRequest := messages.NewBlockRequest(*variadic.MustNewUint32OrHash(pendingBlock.hash), + descendingGapRequest := messages.NewBlockRequest(*messages.NewFromBlock(pendingBlock.hash), uint32(gapLength), messages.BootstrapRequestData, messages.Descending) startAtBlock := pendingBlock.number - uint(*descendingGapRequest.Max) + 1 @@ -754,11 +753,8 @@ taskResultLoop: difference := uint32(int(*request.Max) - len(response.BlockData)) lastItem := response.BlockData[len(response.BlockData)-1] - startRequestNumber := uint32(lastItem.Header.Number + 1) - startAt, err := variadic.NewUint32OrHash(startRequestNumber) - if err != nil { - panic(err) - } + startRequestNumber := lastItem.Header.Number + 1 + startAt := messages.NewFromBlock(startRequestNumber) taskResult.request = &messages.BlockRequestMessage{ RequestedData: messages.BootstrapRequestData, diff --git a/dot/sync/chain_sync_test.go b/dot/sync/chain_sync_test.go index 4af6deac79..e6e2ddd077 100644 --- a/dot/sync/chain_sync_test.go +++ b/dot/sync/chain_sync_test.go @@ -16,7 +16,6 @@ import ( "github.com/ChainSafe/gossamer/dot/telemetry" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/ChainSafe/gossamer/lib/runtime/storage" "github.com/ChainSafe/gossamer/pkg/trie" inmemory_trie "github.com/ChainSafe/gossamer/pkg/trie/inmemory" @@ -159,7 +158,7 @@ func Test_chainSync_onBlockAnnounce(t *testing.T) { Return(block2AnnounceHeader, nil). Times(2) - expectedRequest := messages.NewBlockRequest(*variadic.MustNewUint32OrHash(block2AnnounceHeader.Hash()), + expectedRequest := messages.NewBlockRequest(*messages.NewFromBlock(block2AnnounceHeader.Hash()), 1, messages.BootstrapRequestData, messages.Descending) fakeBlockBody := types.Body([]types.Extrinsic{}) @@ -514,7 +513,7 @@ func TestChainSync_BootstrapSync_SuccessfulSync_WithOneWorker(t *testing.T) { mockedNetwork := NewMockNetwork(ctrl) workerPeerID := peer.ID("noot") - startingBlock := variadic.MustNewUint32OrHash(1) + startingBlock := messages.NewFromBlock(uint(1)) max := uint32(128) mockedRequestMaker := NewMockRequestMaker(ctrl) diff --git a/dot/sync/message.go b/dot/sync/message.go index e8fdce6db8..1bf7922c1a 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -6,6 +6,7 @@ package sync import ( "bytes" "fmt" + "slices" "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/peerset" @@ -21,8 +22,8 @@ func (s *Service) CreateBlockResponse(from peer.ID, req *messages.BlockRequestMe *messages.BlockResponseMessage, error) { logger.Debugf("sync request from %s: %s", from, req.String()) - if !req.StartingBlock.IsUint32() && !req.StartingBlock.IsHash() { - return nil, ErrInvalidBlockRequest + if req.RequestedData == 0 { + return nil, fmt.Errorf("%w: invalid requested data %v", ErrInvalidBlockRequest, req.RequestedData) } encodedRequest, err := req.Encode() @@ -77,18 +78,18 @@ func (s *Service) handleAscendingRequest(req *messages.BlockRequestMessage) (*me return nil, fmt.Errorf("getting best block for request: %w", err) } - switch startBlock := req.StartingBlock.Value().(type) { - case uint32: + switch startBlock := req.StartingBlock.RawValue().(type) { + case uint: if startBlock == 0 { startBlock = 1 } // if request start is higher than our best block, return error - if bestBlockNumber < uint(startBlock) { + if bestBlockNumber < startBlock { return nil, errRequestStartTooHigh } - startNumber = uint(startBlock) + startNumber = startBlock case common.Hash: startHash = &startBlock @@ -147,18 +148,18 @@ func (s *Service) handleDescendingRequest(req *messages.BlockRequestMessage) (*m max = uint(*req.Max) } - switch startBlock := req.StartingBlock.Value().(type) { - case uint32: + switch startBlock := req.StartingBlock.RawValue().(type) { + case uint: bestBlockNumber, err := s.blockState.BestBlockNumber() if err != nil { return nil, fmt.Errorf("failed to get best block %d for request: %w", bestBlockNumber, err) } // if request start is higher than our best block, only return blocks from our best block and below - if bestBlockNumber < uint(startBlock) { + if bestBlockNumber < startBlock { startNumber = bestBlockNumber } else { - startNumber = uint(startBlock) + startNumber = startBlock } case common.Hash: startHash = &startBlock @@ -192,15 +193,15 @@ func (s *Service) handleDescendingRequest(req *messages.BlockRequestMessage) (*m } if startHash == nil || endHash == nil { - logger.Debugf("handling BlockRequestMessage with direction %s "+ - "from start block with number %d to end block with number %d", - req.Direction, startNumber, endNumber) + logger.Infof("handling block request message with direction %s "+ + "from number %d to number %d\n", + req.Direction.String(), startNumber, endNumber) return s.handleDescendingByNumber(startNumber, endNumber, req.RequestedData) } - logger.Debugf("handling block request message with direction %s "+ - "from start block with hash %s to end block with hash %s", - req.Direction, *startHash, *endHash) + logger.Infof("handling block request message with direction %s "+ + "from hash %s to end block with hash %s", + req.Direction.String(), *startHash, *endHash) return s.handleChainByHash(*endHash, *startHash, max, req.RequestedData, req.Direction) } @@ -305,19 +306,20 @@ func (s *Service) handleAscendingByNumber(start, end uint, func (s *Service) handleDescendingByNumber(start, end uint, requestedData byte) (*messages.BlockResponseMessage, error) { var err error - data := make([]*types.BlockData, (start-end)+1) + + response := &messages.BlockResponseMessage{ + BlockData: make([]*types.BlockData, (start-end)+1), + } for i := uint(0); start-i >= end; i++ { blockNumber := start - i - data[i], err = s.getBlockDataByNumber(blockNumber, requestedData) + response.BlockData[i], err = s.getBlockDataByNumber(blockNumber, requestedData) if err != nil { return nil, err } } - return &messages.BlockResponseMessage{ - BlockData: data, - }, nil + return response, nil } func (s *Service) handleChainByHash(ancestor, descendant common.Hash, @@ -338,10 +340,12 @@ func (s *Service) handleChainByHash(ancestor, descendant common.Hash, } } - data := make([]*types.BlockData, len(subchain)) + response := &messages.BlockResponseMessage{ + BlockData: make([]*types.BlockData, len(subchain)), + } for i, hash := range subchain { - data[i], err = s.getBlockData(hash, requestedData) + response.BlockData[i], err = s.getBlockData(hash, requestedData) if err != nil { return nil, err } @@ -349,12 +353,10 @@ func (s *Service) handleChainByHash(ancestor, descendant common.Hash, // reverse BlockData, if descending request if direction == messages.Descending { - reverseBlockData(data) + slices.Reverse(response.BlockData) } - return &messages.BlockResponseMessage{ - BlockData: data, - }, nil + return response, nil } func (s *Service) getBlockDataByNumber(num uint, requestedData byte) (*types.BlockData, error) { @@ -372,10 +374,6 @@ func (s *Service) getBlockData(hash common.Hash, requestedData byte) (*types.Blo Hash: hash, } - if requestedData == 0 { - return blockData, nil - } - if (requestedData & messages.RequestedDataHeader) == 1 { blockData.Header, err = s.blockState.GetHeader(hash) if err != nil { diff --git a/dot/sync/message_integration_test.go b/dot/sync/message_integration_test.go index d10465525a..4ad7f442a4 100644 --- a/dot/sync/message_integration_test.go +++ b/dot/sync/message_integration_test.go @@ -11,7 +11,6 @@ import ( "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" - "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/ChainSafe/gossamer/pkg/trie" "github.com/libp2p/go-libp2p/core/peer" @@ -52,8 +51,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { addTestBlocksToState(t, messages.MaxBlocksInResponse*2, s.blockState) // test ascending - start, err := variadic.NewUint32OrHash(1) - require.NoError(t, err) + start := messages.NewFromBlock(uint(1)) req := &messages.BlockRequestMessage{ RequestedData: 3, @@ -97,9 +95,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { require.Equal(t, uint(16), resp.BlockData[15].Number()) // test descending - start, err = variadic.NewUint32OrHash(uint32(128)) - require.NoError(t, err) - + start = messages.NewFromBlock(uint(128)) req = &messages.BlockRequestMessage{ RequestedData: 3, StartingBlock: *start, @@ -114,8 +110,7 @@ func TestService_CreateBlockResponse_MaxSize(t *testing.T) { require.Equal(t, uint(1), resp.BlockData[127].Number()) max = uint32(messages.MaxBlocksInResponse + 100) - start, err = variadic.NewUint32OrHash(uint32(256)) - require.NoError(t, err) + start = messages.NewFromBlock(uint(256)) req = &messages.BlockRequestMessage{ RequestedData: 3, @@ -153,12 +148,9 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { startHash, err := s.blockState.GetHashByNumber(1) require.NoError(t, err) - start, err := variadic.NewUint32OrHash(startHash) - require.NoError(t, err) - req := &messages.BlockRequestMessage{ RequestedData: 3, - StartingBlock: *start, + StartingBlock: *messages.NewFromBlock(startHash), Direction: messages.Ascending, Max: nil, } @@ -173,8 +165,7 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { startHash, err = s.blockState.GetHashByNumber(16) require.NoError(t, err) - start, err = variadic.NewUint32OrHash(startHash) - require.NoError(t, err) + start := messages.NewFromBlock(startHash) req = &messages.BlockRequestMessage{ RequestedData: 3, @@ -206,8 +197,7 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { startHash, err = s.blockState.GetHashByNumber(256) require.NoError(t, err) - start, err = variadic.NewUint32OrHash(startHash) - require.NoError(t, err) + start = messages.NewFromBlock(startHash) req = &messages.BlockRequestMessage{ RequestedData: 3, @@ -225,8 +215,7 @@ func TestService_CreateBlockResponse_StartHash(t *testing.T) { startHash, err = s.blockState.GetHashByNumber(128) require.NoError(t, err) - start, err = variadic.NewUint32OrHash(startHash) - require.NoError(t, err) + start = messages.NewFromBlock(startHash) req = &messages.BlockRequestMessage{ RequestedData: 3, @@ -356,9 +345,7 @@ func TestService_CreateBlockResponse_Fields(t *testing.T) { Justification: &c, } - start, err := variadic.NewUint32OrHash(uint32(1)) - require.NoError(t, err) - + start := messages.NewFromBlock(uint(1)) err = s.blockState.CompareAndSetBlockData(bds) require.NoError(t, err) diff --git a/dot/sync/message_test.go b/dot/sync/message_test.go index ff4413ba32..59953520cb 100644 --- a/dot/sync/message_test.go +++ b/dot/sync/message_test.go @@ -5,12 +5,12 @@ package sync import ( "errors" + "fmt" "testing" "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" lrucache "github.com/ChainSafe/gossamer/lib/utils/lru-cache" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/assert" @@ -34,22 +34,28 @@ func TestService_CreateBlockResponse(t *testing.T) { mockBlockState := NewMockBlockState(ctrl) return mockBlockState }, - args: args{req: &messages.BlockRequestMessage{}}, - err: ErrInvalidBlockRequest, + args: args{req: &messages.BlockRequestMessage{ + StartingBlock: *messages.NewFromBlock(uint(0)), + }}, + err: fmt.Errorf("%w: invalid requested data %v", ErrInvalidBlockRequest, 0), }, "ascending_request_nil_startHash": { blockStateBuilder: func(ctrl *gomock.Controller) BlockState { + dummyHeader := &types.Header{Number: 2} mockBlockState := NewMockBlockState(ctrl) mockBlockState.EXPECT().BestBlockNumber().Return(uint(1), nil) mockBlockState.EXPECT().GetHashByNumber(uint(1)).Return(common.Hash{1, 2}, nil) + mockBlockState.EXPECT().GetHeader(common.Hash{1, 2}).Return(dummyHeader, nil) return mockBlockState }, args: args{req: &messages.BlockRequestMessage{ - StartingBlock: *variadic.MustNewUint32OrHash(uint32(0)), + RequestedData: messages.RequestedDataHeader, + StartingBlock: *messages.NewFromBlock(uint(0)), Direction: messages.Ascending, }}, want: &messages.BlockResponseMessage{BlockData: []*types.BlockData{{ - Hash: common.Hash{1, 2}, + Hash: common.Hash{1, 2}, + Header: &types.Header{Number: 2}, }}}, }, "ascending_request_start_number_higher": { @@ -59,7 +65,8 @@ func TestService_CreateBlockResponse(t *testing.T) { return mockBlockState }, args: args{req: &messages.BlockRequestMessage{ - StartingBlock: *variadic.MustNewUint32OrHash(2), + RequestedData: messages.BootstrapRequestData, + StartingBlock: *messages.NewFromBlock(uint(2)), Direction: messages.Ascending, }}, err: errRequestStartTooHigh, @@ -72,29 +79,48 @@ func TestService_CreateBlockResponse(t *testing.T) { return mockBlockState }, args: args{req: &messages.BlockRequestMessage{ - StartingBlock: *variadic.MustNewUint32OrHash(0), + RequestedData: messages.BootstrapRequestData, + StartingBlock: *messages.NewFromBlock(uint(0)), Direction: messages.Descending, }}, want: &messages.BlockResponseMessage{BlockData: []*types.BlockData{}}, }, "descending_request_start_number_higher": { blockStateBuilder: func(ctrl *gomock.Controller) BlockState { + dummyBody := types.NewBody([]types.Extrinsic{ + {0, 1, 2, 3}, + {5, 5, 5, 5}, + }) + mockBlockState := NewMockBlockState(ctrl) mockBlockState.EXPECT().BestBlockNumber().Return(uint(1), nil) mockBlockState.EXPECT().GetHashByNumber(uint(1)).Return(common.Hash{1, 2}, nil) + mockBlockState.EXPECT(). + GetBlockBody(common.Hash{1, 2}). + Return(dummyBody, nil) return mockBlockState }, args: args{req: &messages.BlockRequestMessage{ - StartingBlock: *variadic.MustNewUint32OrHash(2), + RequestedData: messages.RequestedDataBody, + StartingBlock: *messages.NewFromBlock(uint(2)), Direction: messages.Descending, }}, err: nil, want: &messages.BlockResponseMessage{BlockData: []*types.BlockData{{ Hash: common.Hash{1, 2}, + Body: types.NewBody([]types.Extrinsic{ + {0, 1, 2, 3}, + {5, 5, 5, 5}, + }), }}}, }, "ascending_request_startHash": { blockStateBuilder: func(ctrl *gomock.Controller) BlockState { + dummyBody := types.NewBody([]types.Extrinsic{ + {0, 1, 2, 3}, + {5, 5, 5, 5}, + }) + mockBlockState := NewMockBlockState(ctrl) mockBlockState.EXPECT().GetHeader(common.Hash{}).Return(&types.Header{ Number: 1, @@ -103,21 +129,34 @@ func TestService_CreateBlockResponse(t *testing.T) { mockBlockState.EXPECT().GetHashByNumber(uint(2)).Return(common.Hash{1, 2, 3}, nil) mockBlockState.EXPECT().IsDescendantOf(common.Hash{}, common.Hash{1, 2, 3}).Return(true, nil) - mockBlockState.EXPECT().Range(common.Hash{}, common.Hash{1, 2, 3}).Return([]common.Hash{{1, - 2}}, - nil) + mockBlockState.EXPECT(). + Range(common.Hash{}, common.Hash{1, 2, 3}). + Return([]common.Hash{{1, 2}}, nil) + mockBlockState.EXPECT(). + GetBlockBody(common.Hash{1, 2}). + Return(dummyBody, nil) return mockBlockState }, args: args{req: &messages.BlockRequestMessage{ - StartingBlock: *variadic.MustNewUint32OrHash(common.Hash{}), + RequestedData: messages.RequestedDataBody, + StartingBlock: *messages.NewFromBlock(common.Hash{}), Direction: messages.Ascending, }}, want: &messages.BlockResponseMessage{BlockData: []*types.BlockData{{ Hash: common.Hash{1, 2}, + Body: types.NewBody([]types.Extrinsic{ + {0, 1, 2, 3}, + {5, 5, 5, 5}, + }), }}}, }, "descending_request_startHash": { blockStateBuilder: func(ctrl *gomock.Controller) BlockState { + dummyBody := types.NewBody([]types.Extrinsic{ + {0, 1, 2, 3}, + {5, 5, 5, 5}, + }) + mockBlockState := NewMockBlockState(ctrl) mockBlockState.EXPECT().GetHeader(common.Hash{}).Return(&types.Header{ Number: 1, @@ -128,14 +167,23 @@ func TestService_CreateBlockResponse(t *testing.T) { mockBlockState.EXPECT().Range(common.MustHexToHash( "0x6443a0b46e0412e626363028115a9f2cf963eeed526b8b33e5316f08b50d0dc3"), common.Hash{}).Return([]common.Hash{{1, 2}}, nil) + + mockBlockState.EXPECT(). + GetBlockBody(common.Hash{1, 2}). + Return(dummyBody, nil) return mockBlockState }, args: args{req: &messages.BlockRequestMessage{ - StartingBlock: *variadic.MustNewUint32OrHash(common.Hash{}), + RequestedData: messages.RequestedDataBody, + StartingBlock: *messages.NewFromBlock(common.Hash{}), Direction: messages.Descending, }}, want: &messages.BlockResponseMessage{BlockData: []*types.BlockData{{ Hash: common.Hash{1, 2}, + Body: types.NewBody([]types.Extrinsic{ + {0, 1, 2, 3}, + {5, 5, 5, 5}, + }), }}}, }, "invalid_direction": { @@ -144,7 +192,8 @@ func TestService_CreateBlockResponse(t *testing.T) { }, args: args{ req: &messages.BlockRequestMessage{ - StartingBlock: *variadic.MustNewUint32OrHash(common.Hash{}), + RequestedData: messages.BootstrapRequestData, + StartingBlock: *messages.NewFromBlock(common.Hash{}), Direction: messages.SyncDirection(3), }}, err: errInvalidRequestDirection, diff --git a/dot/sync/worker_pool_test.go b/dot/sync/worker_pool_test.go index 14785b28c1..d28cb80b12 100644 --- a/dot/sync/worker_pool_test.go +++ b/dot/sync/worker_pool_test.go @@ -10,7 +10,6 @@ import ( "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -130,7 +129,7 @@ func TestSyncWorkerPool_listenForRequests_submitRequest(t *testing.T) { defer workerPool.stop() blockHash := common.MustHexToHash("0x750646b852a29e5f3668959916a03d6243a3137e91d0cd36870364931030f707") - blockRequest := messages.NewBlockRequest(*variadic.MustNewUint32OrHash(blockHash), + blockRequest := messages.NewBlockRequest(*messages.NewFromBlock(blockHash), 1, messages.BootstrapRequestData, messages.Ascending) mockedBlockResponse := &messages.BlockResponseMessage{ BlockData: []*types.BlockData{ @@ -178,11 +177,11 @@ func TestSyncWorkerPool_singleWorker_multipleRequests(t *testing.T) { workerPool.newPeer(availablePeer) firstRequestBlockHash := common.MustHexToHash("0x750646b852a29e5f3668959916a03d6243a3137e91d0cd36870364931030f707") - firstBlockRequest := messages.NewBlockRequest(*variadic.MustNewUint32OrHash(firstRequestBlockHash), + firstBlockRequest := messages.NewBlockRequest(*messages.NewFromBlock(firstRequestBlockHash), 1, messages.BootstrapRequestData, messages.Ascending) secondRequestBlockHash := common.MustHexToHash("0x897646b852a29e5f3668959916a03d6243a3137e91d0cd36870364931030f707") - secondBlockRequest := messages.NewBlockRequest(*variadic.MustNewUint32OrHash(firstRequestBlockHash), + secondBlockRequest := messages.NewBlockRequest(*messages.NewFromBlock(firstRequestBlockHash), 1, messages.BootstrapRequestData, messages.Ascending) firstMockedBlockResponse := &messages.BlockResponseMessage{ diff --git a/lib/common/variadic/uint32OrHash.go b/lib/common/variadic/uint32OrHash.go deleted file mode 100644 index 4a4c7a4025..0000000000 --- a/lib/common/variadic/uint32OrHash.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package variadic - -import ( - "encoding/binary" - "errors" - "io" - - "github.com/ChainSafe/gossamer/lib/common" -) - -// Uint32OrHash represents a variadic type that is either uint32 or common.Hash. -type Uint32OrHash struct { - value interface{} -} - -// NewUint32OrHash returns a new variadic.Uint32OrHash given an int, uint32, or Hash -func NewUint32OrHash(value interface{}) (*Uint32OrHash, error) { - switch v := value.(type) { - case int: // in order to accept constants int such as `NewUint32OrHash(1)` - return &Uint32OrHash{ - value: uint32(v), - }, nil - case uint: - return &Uint32OrHash{ - value: uint32(v), - }, nil - case uint32: - return &Uint32OrHash{ - value: v, - }, nil - case common.Hash: - return &Uint32OrHash{ - value: v, - }, nil - default: - return nil, errors.New("value is not uint32 or common.Hash") - } -} - -// MustNewUint32OrHash returns a new variadic.Uint32OrHash given an int, uint32, or Hash -// It panics if the input value is invalid -func MustNewUint32OrHash(value interface{}) *Uint32OrHash { - val, err := NewUint32OrHash(value) - if err != nil { - panic(err) - } - - return val -} - -// NewUint32OrHashFromBytes returns a new variadic.Uint32OrHash from an encoded variadic uint32 or hash -func NewUint32OrHashFromBytes(data []byte) *Uint32OrHash { - firstByte := data[0] - if firstByte == 0 { - return &Uint32OrHash{ - value: common.NewHash(data[1:]), - } - } else if firstByte == 1 { - num := data[1:] - if len(num) < 4 { - num = common.AppendZeroes(num, 4) - } - return &Uint32OrHash{ - value: binary.LittleEndian.Uint32(num), - } - } else { - return nil - } -} - -// Value returns the interface value. -func (x *Uint32OrHash) Value() interface{} { - if x == nil { - return nil - } - return x.value -} - -// IsHash returns true if the value is a hash -func (x *Uint32OrHash) IsHash() bool { - if x == nil { - return false - } - _, is := x.value.(common.Hash) - return is -} - -// Hash returns the value as a common.Hash. It panics if the value is not a hash. -func (x *Uint32OrHash) Hash() common.Hash { - if !x.IsHash() { - panic("value is not common.Hash") - } - - return x.value.(common.Hash) -} - -// IsUint32 returns true if the value is a uint32 -func (x *Uint32OrHash) IsUint32() bool { - if x == nil { - return false - } - - _, is := x.value.(uint32) - return is -} - -// Uint32 returns the value as a uint32. It panics if the value is not a uint32. -func (x *Uint32OrHash) Uint32() uint32 { - if !x.IsUint32() { - panic("value is not uint32") - } - - return x.value.(uint32) -} - -// Encode will encode a Uint32OrHash using SCALE -func (x *Uint32OrHash) Encode() ([]byte, error) { - var encMsg []byte - switch c := x.Value().(type) { - case uint32: - startingBlockByteArray := make([]byte, 4) - binary.LittleEndian.PutUint32(startingBlockByteArray, c) - encMsg = append(encMsg, append([]byte{1}, startingBlockByteArray...)...) - case common.Hash: - encMsg = append(encMsg, append([]byte{0}, c.ToBytes()...)...) - } - return encMsg, nil -} - -// Decode decodes a value into a Uint32OrHash -func (x *Uint32OrHash) Decode(r io.Reader) error { - startingBlockType, err := common.ReadByte(r) - if err != nil { - return err - } - if startingBlockType == 0 { - hash := make([]byte, 32) - _, err = r.Read(hash) - if err != nil { - return err - } - x.value = common.NewHash(hash) - } else { - num := make([]byte, 4) - _, err = r.Read(num) - if err != nil { - return err - } - x.value = binary.LittleEndian.Uint32(num) - } - return nil -} diff --git a/lib/common/variadic/uint32OrHash_test.go b/lib/common/variadic/uint32OrHash_test.go deleted file mode 100644 index f5a77d4916..0000000000 --- a/lib/common/variadic/uint32OrHash_test.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2021 ChainSafe Systems (ON) -// SPDX-License-Identifier: LGPL-3.0-only - -package variadic - -import ( - "encoding/binary" - "testing" - - "github.com/ChainSafe/gossamer/lib/common" - "github.com/stretchr/testify/require" -) - -func TestNewUint32OrHash(t *testing.T) { - hash, err := common.HexToHash("0xdcd1346701ca8396496e52aa2785b1748deb6db09551b72159dcb3e08991025b") - require.NoError(t, err) - - res, err := NewUint32OrHash(hash) - require.NoError(t, err) - require.Equal(t, res.Value(), hash) - - num := 77 - - res, err = NewUint32OrHash(num) - require.NoError(t, err) - require.Equal(t, uint32(num), res.Value()) - - res, err = NewUint32OrHash(uint(num)) - require.NoError(t, err) - require.Equal(t, uint32(num), res.Value()) - - res, err = NewUint32OrHash(uint32(num)) - require.NoError(t, err) - require.Equal(t, uint32(num), res.Value()) -} - -func TestNewUint32OrHashFromBytes(t *testing.T) { - genesisHash, err := common.HexToBytes("0xdcd1346701ca8396496e52aa2785b1748deb6db09551b72159dcb3e08991025b") - require.NoError(t, err) - - buf := make([]byte, 4) - binary.LittleEndian.PutUint32(buf, uint32(1)) - - for _, x := range []struct { - description string - targetHash []byte - targetFirstByte uint8 - expectedType interface{} - }{ - { - description: "block request with genesis hash type 0", - targetHash: genesisHash, - targetFirstByte: 0, - expectedType: common.Hash{}, - }, - { - description: "block request with Block Number int type 1", - targetHash: buf, - targetFirstByte: 1, - expectedType: (uint32)(0), - }, - } { - t.Run(x.description, func(t *testing.T) { - data := append([]byte{x.targetFirstByte}, x.targetHash...) - - val := NewUint32OrHashFromBytes(data) - require.NoError(t, err) - require.IsType(t, x.expectedType, val.Value()) - - if x.expectedType == (uint32)(0) { - startingBlockByteArray := make([]byte, 4) - binary.LittleEndian.PutUint32(startingBlockByteArray, val.Value().(uint32)) - require.Equal(t, x.targetHash, startingBlockByteArray) - } else { - require.Equal(t, common.NewHash(x.targetHash), val.Value()) - } - }) - } -} diff --git a/scripts/retrieve_block/retrieve_block.go b/scripts/retrieve_block/retrieve_block.go index dc2098efa9..a9958669a5 100644 --- a/scripts/retrieve_block/retrieve_block.go +++ b/scripts/retrieve_block/retrieve_block.go @@ -13,7 +13,6 @@ import ( "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/ChainSafe/gossamer/scripts/p2p" lip2pnetwork "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/protocol" @@ -46,19 +45,17 @@ func buildRequestMessage(arg string) *messages.BlockRequestMessage { return nil } -func parseTargetBlock(arg string) variadic.Uint32OrHash { - var value any - value, err := strconv.Atoi(arg) - if err != nil { - value = common.MustHexToHash(arg) +func parseTargetBlock(arg string) messages.FromBlock { + if strings.HasPrefix(arg, "0x") { + return *messages.NewFromBlock(common.MustHexToHash(arg)) } - v, err := variadic.NewUint32OrHash(value) + value, err := strconv.Atoi(arg) if err != nil { - log.Fatalf("\ncannot parse variadic type: %s", err.Error()) + log.Fatalf("\ntrying to convert %v to number: %s", arg, err.Error()) } - return *v + return *messages.NewFromBlock(uint(value)) } func waitAndStoreResponse(stream lip2pnetwork.Stream, outputFile string) bool { diff --git a/scripts/retrieve_block/retrieve_block_test.go b/scripts/retrieve_block/retrieve_block_test.go index 54bd348c67..2153e879bb 100644 --- a/scripts/retrieve_block/retrieve_block_test.go +++ b/scripts/retrieve_block/retrieve_block_test.go @@ -8,7 +8,6 @@ import ( "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/lib/common/variadic" "github.com/stretchr/testify/require" ) @@ -20,35 +19,35 @@ func TestBuildRequestMessage(t *testing.T) { { arg: "10", expected: messages.NewBlockRequest( - *variadic.MustNewUint32OrHash(uint(10)), 1, + *messages.NewFromBlock(uint(10)), 1, messages.BootstrapRequestData, messages.Ascending), }, { arg: "0x9b0211aadcef4bb65e69346cfd256ddd2abcb674271326b08f0975dac7c17bc7", - expected: messages.NewBlockRequest(*variadic.MustNewUint32OrHash( + expected: messages.NewBlockRequest(*messages.NewFromBlock( common.MustHexToHash("0x9b0211aadcef4bb65e69346cfd256ddd2abcb674271326b08f0975dac7c17bc7")), 1, messages.BootstrapRequestData, messages.Ascending), }, { arg: "0x9b0211aadcef4bb65e69346cfd256ddd2abcb674271326b08f0975dac7c17bc7,asc,20", - expected: messages.NewBlockRequest(*variadic.MustNewUint32OrHash( + expected: messages.NewBlockRequest(*messages.NewFromBlock( common.MustHexToHash("0x9b0211aadcef4bb65e69346cfd256ddd2abcb674271326b08f0975dac7c17bc7")), 20, messages.BootstrapRequestData, messages.Ascending), }, { arg: "1,asc,20", - expected: messages.NewBlockRequest(*variadic.MustNewUint32OrHash(uint(1)), + expected: messages.NewBlockRequest(*messages.NewFromBlock(uint(1)), 20, messages.BootstrapRequestData, messages.Ascending), }, { arg: "0x9b0211aadcef4bb65e69346cfd256ddd2abcb674271326b08f0975dac7c17bc7,desc,20", - expected: messages.NewBlockRequest(*variadic.MustNewUint32OrHash( + expected: messages.NewBlockRequest(*messages.NewFromBlock( common.MustHexToHash("0x9b0211aadcef4bb65e69346cfd256ddd2abcb674271326b08f0975dac7c17bc7")), 20, messages.BootstrapRequestData, messages.Descending), }, { arg: "1,desc,20", - expected: messages.NewBlockRequest(*variadic.MustNewUint32OrHash(uint(1)), + expected: messages.NewBlockRequest(*messages.NewFromBlock(uint(1)), 20, messages.BootstrapRequestData, messages.Descending), }, }