diff --git a/dot/network/messages/state.go b/dot/network/messages/state.go index efb48765b3..baae4e1487 100644 --- a/dot/network/messages/state.go +++ b/dot/network/messages/state.go @@ -8,6 +8,7 @@ import ( pb "github.com/ChainSafe/gossamer/dot/network/proto" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/trie" "google.golang.org/protobuf/proto" ) @@ -52,3 +53,48 @@ func (s *StateRequest) Decode(in []byte) error { s.NoProof = message.NoProof return nil } + +type StateResponse struct { + Entries []KeyValueStateEntry + Proof []byte +} + +type KeyValueStateEntry struct { + StateRoot common.Hash + StateEntries trie.Entries + Complete bool +} + +func (s *StateResponse) Decode(in []byte) error { + decodedResponse := &pb.StateResponse{} + err := proto.Unmarshal(in, decodedResponse) + if err != nil { + return err + } + + s.Proof = make([]byte, len(decodedResponse.Proof)) + copy(s.Proof, decodedResponse.Proof) + + s.Entries = make([]KeyValueStateEntry, len(decodedResponse.Entries)) + for idx, entry := range decodedResponse.Entries { + s.Entries[idx] = KeyValueStateEntry{ + Complete: entry.Complete, + StateRoot: common.BytesToHash(entry.StateRoot), + } + + trieFragment := make(trie.Entries, len(entry.Entries)) + for stateEntryIdx, stateEntry := range entry.Entries { + trieFragment[stateEntryIdx] = trie.Entry{ + Key: make([]byte, len(stateEntry.Key)), + Value: make([]byte, len(stateEntry.Value)), + } + + copy(trieFragment[stateEntryIdx].Key, stateEntry.Key) + copy(trieFragment[stateEntryIdx].Value, stateEntry.Value) + } + + s.Entries[idx].StateEntries = trieFragment + } + + return nil +} diff --git a/scripts/p2p/common_p2p.go b/scripts/p2p/common_p2p.go index fd03ca3924..82c97333dd 100644 --- a/scripts/p2p/common_p2p.go +++ b/scripts/p2p/common_p2p.go @@ -5,6 +5,7 @@ package p2p import ( "encoding/json" + "errors" "fmt" "io" "log" @@ -117,20 +118,22 @@ func parsePeerAddress(arg string) peer.AddrInfo { return *p } -func ReadStream(stream lip2pnetwork.Stream) []byte { +var errZeroLength = errors.New("zero length") + +func ReadStream(stream lip2pnetwork.Stream) ([]byte, error) { responseBuf := make([]byte, network.MaxBlockResponseSize) length, _, err := network.ReadLEB128ToUint64(stream) if err != nil { - log.Fatalf("reading response length: %s", err.Error()) + return nil, fmt.Errorf("reading leb128: %w", err) } if length == 0 { - return nil + return nil, errZeroLength } if length > network.MaxBlockResponseSize { - log.Fatalf("%s: max %d, got %d", network.ErrGreaterThanMaxSize, network.MaxBlockResponseSize, length) + return nil, fmt.Errorf("%w: max %d, got %d", network.ErrGreaterThanMaxSize, network.MaxBlockResponseSize, length) } if length > uint64(len(responseBuf)) { @@ -142,22 +145,22 @@ func ReadStream(stream lip2pnetwork.Stream) []byte { for tot < int(length) { n, err := stream.Read(responseBuf[tot:]) if err != nil { - log.Fatalf("reading stream: %s", err.Error()) + return nil, fmt.Errorf("reading stream: %w", err) } tot += n } if tot != int(length) { - log.Fatalf("%s: expected %d bytes, received %d bytes", network.ErrFailedToReadEntireMessage, length, tot) + return nil, fmt.Errorf("%w: expected %d bytes, received %d bytes", network.ErrFailedToReadEntireMessage, length, tot) } - return responseBuf[:tot] + return responseBuf[:tot], nil } -func WriteStream(msg *messages.BlockRequestMessage, stream lip2pnetwork.Stream) { +func WriteStream(msg messages.P2PMessage, stream lip2pnetwork.Stream) error { encMsg, err := msg.Encode() if err != nil { - log.Fatalf("encoding message: %s", err.Error()) + return fmt.Errorf("encoding message: %w", err) } msgLen := uint64(len(encMsg)) @@ -166,6 +169,8 @@ func WriteStream(msg *messages.BlockRequestMessage, stream lip2pnetwork.Stream) _, err = stream.Write(encMsg) if err != nil { - log.Fatalf("writing message: %s", err.Error()) + return fmt.Errorf("writing message: %w", err) } + + return nil } diff --git a/scripts/retrieve_block/retrieve_block.go b/scripts/retrieve_block/retrieve_block.go index 1c837c9966..dc2098efa9 100644 --- a/scripts/retrieve_block/retrieve_block.go +++ b/scripts/retrieve_block/retrieve_block.go @@ -62,13 +62,18 @@ func parseTargetBlock(arg string) variadic.Uint32OrHash { } func waitAndStoreResponse(stream lip2pnetwork.Stream, outputFile string) bool { - output := p2p.ReadStream(stream) + output, err := p2p.ReadStream(stream) if len(output) == 0 { return false } + if err != nil { + log.Println(err.Error()) + return false + } + blockResponse := &messages.BlockResponseMessage{} - err := blockResponse.Decode(output) + err = blockResponse.Decode(output) if err != nil { log.Fatalf("could not decode block response message: %s", err.Error()) } @@ -125,7 +130,12 @@ func main() { } defer stream.Close() //nolint:errcheck - p2p.WriteStream(requestMessage, stream) + err = p2p.WriteStream(requestMessage, stream) + if err != nil { + log.Println(err.Error()) + continue + } + if !waitAndStoreResponse(stream, os.Args[3]) { continue } diff --git a/scripts/retrieve_state/retrieve_state.go b/scripts/retrieve_state/retrieve_state.go new file mode 100644 index 0000000000..236d410b35 --- /dev/null +++ b/scripts/retrieve_state/retrieve_state.go @@ -0,0 +1,223 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package main + +import ( + "context" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "log" + "math/big" + "os" + + "github.com/ChainSafe/gossamer/dot/network/messages" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/ChainSafe/gossamer/pkg/trie" + "github.com/ChainSafe/gossamer/pkg/trie/inmemory" + "github.com/ChainSafe/gossamer/scripts/p2p" + lip2pnetwork "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" +) + +var ( + errZeroLengthResponse = errors.New("zero length response") + errEmptyStateEntries = errors.New("empty state entries") +) + +type StateRequestProvider struct { + lastKeys [][]byte + collectedResponses []*messages.StateResponse + targetHash common.Hash + completed bool +} + +func NewStateRequestProvider(target common.Hash) *StateRequestProvider { + return &StateRequestProvider{ + lastKeys: [][]byte{}, + targetHash: target, + collectedResponses: make([]*messages.StateResponse, 0), + } +} + +func (s *StateRequestProvider) buildRequest() *messages.StateRequest { + return &messages.StateRequest{ + Block: s.targetHash, + Start: s.lastKeys, + NoProof: true, + } +} + +func (s *StateRequestProvider) processResponse(stateResponse *messages.StateResponse) (err error) { + if len(stateResponse.Entries) == 0 { + return errEmptyStateEntries + } + + log.Printf("retrieved %d entries\n", len(stateResponse.Entries)) + for idx, entry := range stateResponse.Entries { + log.Printf("\t#%d with %d entries (complete: %v, root: %s)\n", + idx, len(entry.StateEntries), entry.Complete, entry.StateRoot.String()) + } + + s.collectedResponses = append(s.collectedResponses, stateResponse) + + if len(s.lastKeys) == 2 && len(stateResponse.Entries[0].StateEntries) == 0 { + // pop last item and keep the first + // do not remove the parent trie position. + s.lastKeys = s.lastKeys[:len(s.lastKeys)-1] + } else { + s.lastKeys = [][]byte{} + } + + for _, state := range stateResponse.Entries { + if !state.Complete { + lastItemInResponse := state.StateEntries[len(state.StateEntries)-1] + s.lastKeys = append(s.lastKeys, lastItemInResponse.Key) + s.completed = false + } else { + s.completed = true + } + } + + return nil +} + +func (s *StateRequestProvider) buildTrie(expectedStorageRootHash common.Hash, destination string) error { + tt := inmemory.NewEmptyTrie() + tt.SetVersion(trie.V1) + + entries := make([]string, 0) + + for _, stateResponse := range s.collectedResponses { + for _, stateEntry := range stateResponse.Entries { + for _, kv := range stateEntry.StateEntries { + + trieEntry := trie.Entry{Key: kv.Key, Value: kv.Value} + encodedTrieEntry, err := scale.Marshal(trieEntry) + if err != nil { + return err + } + entries = append(entries, common.BytesToHex(encodedTrieEntry)) + + if err := tt.Put(kv.Key, kv.Value); err != nil { + return err + } + } + } + } + + rootHash := tt.MustHash() + if expectedStorageRootHash != rootHash { + log.Printf("\n\texpected root hash: %s\ngot root hash: %s\n", + expectedStorageRootHash.String(), rootHash.String()) + } + + fmt.Printf("=> trie root hash: %s\n", tt.MustHash().String()) + encodedEntries, err := json.Marshal(entries) + if err != nil { + return err + } + + err = os.WriteFile(destination, encodedEntries, 0o600) + return err +} + +func main() { + if len(os.Args) != 5 { + log.Fatalf(` + script usage: + go run retrieve_state.go [hash] [expected storage root hash] [network chain spec] [output file]`) + } + + targetBlockHash := common.MustHexToHash(os.Args[1]) + expectedStorageRootHash := common.MustHexToHash(os.Args[2]) + chain := p2p.ParseChainSpec(os.Args[3]) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + protocolID := protocol.ID(fmt.Sprintf("/%s/state/2", chain.ProtocolID)) + + p2pHost := p2p.SetupP2PClient() + bootnodes := p2p.ParseBootnodes(chain.Bootnodes) + provider := NewStateRequestProvider(targetBlockHash) + + var ( + pid peer.AddrInfo + refreshPeerID bool = true + ) + + for !provider.completed { + if refreshPeerID { + rng, err := rand.Int(rand.Reader, big.NewInt(int64(len(bootnodes)))) + if err != nil { + panic(err) + } + + pid = bootnodes[rng.Uint64()] + err = p2pHost.Connect(ctx, pid) + if err != nil { + log.Printf("WARN: while connecting: %s\n", err.Error()) + continue + } + + log.Printf("OK: requesting from peer %s\n", pid.String()) + } + + stream, err := p2pHost.NewStream(ctx, pid.ID, protocolID) + if err != nil { + log.Printf("WARN: failed to create stream using protocol %s: %s", protocolID, err.Error()) + refreshPeerID = false + continue + } + + err = sendAndProcessResponse(provider, stream) + if err != nil { + log.Printf("WARN: %s\n", err.Error()) + refreshPeerID = true + continue + } + + // keep using the same peer + refreshPeerID = false + } + + if err := provider.buildTrie(expectedStorageRootHash, os.Args[4]); err != nil { + panic(err) + } +} + +func sendAndProcessResponse(provider *StateRequestProvider, stream lip2pnetwork.Stream) error { + defer stream.Close() //nolint:errcheck + + err := p2p.WriteStream(provider.buildRequest(), stream) + if err != nil { + return err + } + + output, err := p2p.ReadStream(stream) + if err != nil { + return err + } + + if len(output) == 0 { + return errZeroLengthResponse + } + + stateResponse := &messages.StateResponse{} + err = stateResponse.Decode(output) + if err != nil { + return err + } + + err = provider.processResponse(stateResponse) + if err != nil { + return err + } + + return nil +}