Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tool(script): state retrieval script #4140

Merged
merged 16 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions dot/network/messages/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

pb "github.com/ChainSafe/gossamer/dot/network/proto"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/trie"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -52,3 +53,48 @@ func (s *StateRequest) Decode(in []byte) error {
s.NoProof = message.NoProof
return nil
}

type StateResponse struct {
Entries []KeyValueStateEntry
Proof []byte
}

type KeyValueStateEntry struct {
StateRoot common.Hash
StateEntries trie.Entries
Complete bool
}

func (s *StateResponse) Decode(in []byte) error {
decodedResponse := &pb.StateResponse{}
err := proto.Unmarshal(in, decodedResponse)
if err != nil {
return err
}

s.Proof = make([]byte, len(decodedResponse.Proof))
copy(s.Proof, decodedResponse.Proof)

s.Entries = make([]KeyValueStateEntry, len(decodedResponse.Entries))
for idx, entry := range decodedResponse.Entries {
s.Entries[idx] = KeyValueStateEntry{
Complete: entry.Complete,
StateRoot: common.BytesToHash(entry.StateRoot),
}

trieFragment := make(trie.Entries, len(entry.Entries))
for stateEntryIdx, stateEntry := range entry.Entries {
trieFragment[stateEntryIdx] = trie.Entry{
Key: make([]byte, len(stateEntry.Key)),
Value: make([]byte, len(stateEntry.Value)),
}

copy(trieFragment[stateEntryIdx].Key, stateEntry.Key)
copy(trieFragment[stateEntryIdx].Value, stateEntry.Value)
}

s.Entries[idx].StateEntries = trieFragment
}

return nil
}
25 changes: 15 additions & 10 deletions scripts/p2p/common_p2p.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package p2p

import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -117,20 +118,22 @@ func parsePeerAddress(arg string) peer.AddrInfo {
return *p
}

func ReadStream(stream lip2pnetwork.Stream) []byte {
var errZeroLength = errors.New("zero length")

func ReadStream(stream lip2pnetwork.Stream) ([]byte, error) {
responseBuf := make([]byte, network.MaxBlockResponseSize)

length, _, err := network.ReadLEB128ToUint64(stream)
if err != nil {
log.Fatalf("reading response length: %s", err.Error())
return nil, fmt.Errorf("reading leb128: %w", err)
}

if length == 0 {
return nil
return nil, errZeroLength
}

if length > network.MaxBlockResponseSize {
log.Fatalf("%s: max %d, got %d", network.ErrGreaterThanMaxSize, network.MaxBlockResponseSize, length)
return nil, fmt.Errorf("%w: max %d, got %d", network.ErrGreaterThanMaxSize, network.MaxBlockResponseSize, length)
}

if length > uint64(len(responseBuf)) {
Expand All @@ -142,22 +145,22 @@ func ReadStream(stream lip2pnetwork.Stream) []byte {
for tot < int(length) {
n, err := stream.Read(responseBuf[tot:])
if err != nil {
log.Fatalf("reading stream: %s", err.Error())
return nil, fmt.Errorf("reading stream: %w", err)
}
tot += n
}

if tot != int(length) {
log.Fatalf("%s: expected %d bytes, received %d bytes", network.ErrFailedToReadEntireMessage, length, tot)
return nil, fmt.Errorf("%w: expected %d bytes, received %d bytes", network.ErrFailedToReadEntireMessage, length, tot)
}

return responseBuf[:tot]
return responseBuf[:tot], nil
}

func WriteStream(msg *messages.BlockRequestMessage, stream lip2pnetwork.Stream) {
func WriteStream(msg messages.P2PMessage, stream lip2pnetwork.Stream) error {
encMsg, err := msg.Encode()
if err != nil {
log.Fatalf("encoding message: %s", err.Error())
return fmt.Errorf("encoding message: %w", err)
}

msgLen := uint64(len(encMsg))
Expand All @@ -166,6 +169,8 @@ func WriteStream(msg *messages.BlockRequestMessage, stream lip2pnetwork.Stream)

_, err = stream.Write(encMsg)
if err != nil {
log.Fatalf("writing message: %s", err.Error())
return fmt.Errorf("writing message: %w", err)
}

return nil
}
16 changes: 13 additions & 3 deletions scripts/retrieve_block/retrieve_block.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,18 @@ func parseTargetBlock(arg string) variadic.Uint32OrHash {
}

func waitAndStoreResponse(stream lip2pnetwork.Stream, outputFile string) bool {
output := p2p.ReadStream(stream)
output, err := p2p.ReadStream(stream)
if len(output) == 0 {
return false
}

if err != nil {
log.Println(err.Error())
return false
}

blockResponse := &messages.BlockResponseMessage{}
err := blockResponse.Decode(output)
err = blockResponse.Decode(output)
if err != nil {
log.Fatalf("could not decode block response message: %s", err.Error())
}
Expand Down Expand Up @@ -125,7 +130,12 @@ func main() {
}

defer stream.Close() //nolint:errcheck
p2p.WriteStream(requestMessage, stream)
err = p2p.WriteStream(requestMessage, stream)
if err != nil {
log.Println(err.Error())
continue
}

if !waitAndStoreResponse(stream, os.Args[3]) {
continue
}
Expand Down
223 changes: 223 additions & 0 deletions scripts/retrieve_state/retrieve_state.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Copyright 2024 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package main

import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"log"
"math/big"
"os"

"github.com/ChainSafe/gossamer/dot/network/messages"
"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/scale"
"github.com/ChainSafe/gossamer/pkg/trie"
"github.com/ChainSafe/gossamer/pkg/trie/inmemory"
"github.com/ChainSafe/gossamer/scripts/p2p"
lip2pnetwork "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
)

var (
errZeroLengthResponse = errors.New("zero length response")
errEmptyStateEntries = errors.New("empty state entries")
)

type StateRequestProvider struct {
lastKeys [][]byte
collectedResponses []*messages.StateResponse
targetHash common.Hash
completed bool
}

func NewStateRequestProvider(target common.Hash) *StateRequestProvider {
return &StateRequestProvider{
lastKeys: [][]byte{},
targetHash: target,
collectedResponses: make([]*messages.StateResponse, 0),
}
}

func (s *StateRequestProvider) buildRequest() *messages.StateRequest {
return &messages.StateRequest{
Block: s.targetHash,
Start: s.lastKeys,
NoProof: true,
}
}

func (s *StateRequestProvider) processResponse(stateResponse *messages.StateResponse) (err error) {
if len(stateResponse.Entries) == 0 {
return errEmptyStateEntries
}

log.Printf("retrieved %d entries\n", len(stateResponse.Entries))
for idx, entry := range stateResponse.Entries {
log.Printf("\t#%d with %d entries (complete: %v, root: %s)\n",
idx, len(entry.StateEntries), entry.Complete, entry.StateRoot.String())
}

s.collectedResponses = append(s.collectedResponses, stateResponse)

if len(s.lastKeys) == 2 && len(stateResponse.Entries[0].StateEntries) == 0 {
// pop last item and keep the first
// do not remove the parent trie position.
s.lastKeys = s.lastKeys[:len(s.lastKeys)-1]
} else {
s.lastKeys = [][]byte{}
}

for _, state := range stateResponse.Entries {
if !state.Complete {
lastItemInResponse := state.StateEntries[len(state.StateEntries)-1]
s.lastKeys = append(s.lastKeys, lastItemInResponse.Key)
s.completed = false
} else {
s.completed = true
}
}

return nil
}

func (s *StateRequestProvider) buildTrie(expectedStorageRootHash common.Hash, destination string) error {
tt := inmemory.NewEmptyTrie()
tt.SetVersion(trie.V1)

entries := make([]string, 0)

for _, stateResponse := range s.collectedResponses {
for _, stateEntry := range stateResponse.Entries {
for _, kv := range stateEntry.StateEntries {

trieEntry := trie.Entry{Key: kv.Key, Value: kv.Value}
encodedTrieEntry, err := scale.Marshal(trieEntry)
if err != nil {
return err
}
entries = append(entries, common.BytesToHex(encodedTrieEntry))

if err := tt.Put(kv.Key, kv.Value); err != nil {
return err
}
}
}
}

rootHash := tt.MustHash()
if expectedStorageRootHash != rootHash {
log.Printf("\n\texpected root hash: %s\ngot root hash: %s\n",
expectedStorageRootHash.String(), rootHash.String())
}

fmt.Printf("=> trie root hash: %s\n", tt.MustHash().String())
encodedEntries, err := json.Marshal(entries)
if err != nil {
return err
}

err = os.WriteFile(destination, encodedEntries, 0o600)
return err
}

func main() {
if len(os.Args) != 5 {
log.Fatalf(`
script usage:
go run retrieve_state.go [hash] [expected storage root hash] [network chain spec] [output file]`)
}

targetBlockHash := common.MustHexToHash(os.Args[1])
expectedStorageRootHash := common.MustHexToHash(os.Args[2])
chain := p2p.ParseChainSpec(os.Args[3])

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

protocolID := protocol.ID(fmt.Sprintf("/%s/state/2", chain.ProtocolID))

p2pHost := p2p.SetupP2PClient()
bootnodes := p2p.ParseBootnodes(chain.Bootnodes)
provider := NewStateRequestProvider(targetBlockHash)

var (
pid peer.AddrInfo
refreshPeerID bool = true
)

for !provider.completed {
if refreshPeerID {
rng, err := rand.Int(rand.Reader, big.NewInt(int64(len(bootnodes))))
if err != nil {
panic(err)
}

pid = bootnodes[rng.Uint64()]
err = p2pHost.Connect(ctx, pid)
if err != nil {
log.Printf("WARN: while connecting: %s\n", err.Error())
continue
}

log.Printf("OK: requesting from peer %s\n", pid.String())
}

stream, err := p2pHost.NewStream(ctx, pid.ID, protocolID)
if err != nil {
log.Printf("WARN: failed to create stream using protocol %s: %s", protocolID, err.Error())
refreshPeerID = false
continue
}

err = sendAndProcessResponse(provider, stream)
if err != nil {
log.Printf("WARN: %s\n", err.Error())
refreshPeerID = true
continue
}

// keep using the same peer
refreshPeerID = false
}

if err := provider.buildTrie(expectedStorageRootHash, os.Args[4]); err != nil {
panic(err)
}
}

func sendAndProcessResponse(provider *StateRequestProvider, stream lip2pnetwork.Stream) error {
defer stream.Close() //nolint:errcheck

err := p2p.WriteStream(provider.buildRequest(), stream)
if err != nil {
return err
}

output, err := p2p.ReadStream(stream)
if err != nil {
return err
}

if len(output) == 0 {
return errZeroLengthResponse
}

stateResponse := &messages.StateResponse{}
err = stateResponse.Decode(output)
if err != nil {
return err
}

err = provider.processResponse(stateResponse)
if err != nil {
return err
}

return nil
}
Loading