diff --git a/core/state/snapshot/iterator.go b/core/state/snapshot/iterator.go index b07c2938f3..c9b98353fb 100644 --- a/core/state/snapshot/iterator.go +++ b/core/state/snapshot/iterator.go @@ -351,10 +351,16 @@ type diskStorageIterator struct { // is always false. func (dl *diskLayer) StorageIterator(account common.Hash, seek common.Hash) (StorageIterator, bool) { pos := common.TrimRightZeroes(seek[:]) + + // create prefix to be rawdb.SnapshotStoragePrefix + account[:] + prefix := make([]byte, len(rawdb.SnapshotStoragePrefix)+common.HashLength) + copy(prefix, rawdb.SnapshotStoragePrefix) + copy(prefix[len(rawdb.SnapshotStoragePrefix):], account[:]) + return &diskStorageIterator{ layer: dl, account: account, - it: dl.diskdb.NewIterator(append(rawdb.SnapshotStoragePrefix, account.Bytes()...), pos), + it: dl.diskdb.NewIterator(prefix, pos), }, false } diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go index abfa2c2772..0fd824f512 100644 --- a/core/state/snapshot/snapshot.go +++ b/core/state/snapshot/snapshot.go @@ -912,6 +912,21 @@ func (t *Tree) DiskRoot() common.Hash { return t.diskRoot() } +func (t *Tree) DiskAccountIterator(seek common.Hash) AccountIterator { + t.lock.Lock() + defer t.lock.Unlock() + + return t.disklayer().AccountIterator(seek) +} + +func (t *Tree) DiskStorageIterator(account common.Hash, seek common.Hash) StorageIterator { + t.lock.Lock() + defer t.lock.Unlock() + + it, _ := t.disklayer().StorageIterator(account, seek) + return it +} + // NewTestTree creates a *Tree with a pre-populated diskLayer func NewTestTree(diskdb ethdb.KeyValueStore, blockHash, root common.Hash) *Tree { base := &diskLayer{ diff --git a/plugin/evm/atomic_syncer_test.go b/plugin/evm/atomic_syncer_test.go index 7206a52a26..63e729d76f 100644 --- a/plugin/evm/atomic_syncer_test.go +++ b/plugin/evm/atomic_syncer_test.go @@ -41,7 +41,7 @@ func testAtomicSyncer(t *testing.T, serverTrieDB *trie.Database, targetHeight ui numLeaves := 0 mockClient := syncclient.NewMockClient( message.Codec, - handlers.NewLeafsRequestHandler(serverTrieDB, message.Codec, handlerstats.NewNoopHandlerStats()), + handlers.NewLeafsRequestHandler(serverTrieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()), nil, nil, ) diff --git a/plugin/evm/message/leafs_request.go b/plugin/evm/message/leafs_request.go index 6e266477ae..380b24ebb7 100644 --- a/plugin/evm/message/leafs_request.go +++ b/plugin/evm/message/leafs_request.go @@ -42,6 +42,7 @@ func (nt NodeType) String() string { // NodeType outlines which trie to read from state/atomic. 
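// Sketch (not part of the patch above): a minimal illustration of the aliasing
// hazard the new prefix copy in diskLayer.StorageIterator likely avoids. Appending
// to a shared prefix slice that has spare capacity can reuse its backing array
// across calls, so later appends may overwrite earlier keys; copying into an
// exactly sized buffer, as the patch now does, cannot. Names below are hypothetical.
package main

import "fmt"

var sharedPrefix = make([]byte, 2, 64) // a package-level key prefix with extra capacity

func keyWithAppend(suffix []byte) []byte {
	// may write the suffix into sharedPrefix's backing array and return a view of it
	return append(sharedPrefix, suffix...)
}

func keyWithCopy(suffix []byte) []byte {
	k := make([]byte, len(sharedPrefix)+len(suffix))
	copy(k, sharedPrefix)
	copy(k[len(sharedPrefix):], suffix)
	return k
}

func main() {
	a := keyWithAppend([]byte{0xaa})
	_ = keyWithAppend([]byte{0xbb}) // reuses the same backing array, clobbering a
	fmt.Printf("append: % x\n", a) // suffix now reads 0xbb

	c := keyWithCopy([]byte{0xaa})
	_ = keyWithCopy([]byte{0xbb})
	fmt.Printf("copy:   % x\n", c) // still 0xaa
}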
type LeafsRequest struct { Root common.Hash `serialize:"true"` + Account common.Hash `serialize:"true"` Start []byte `serialize:"true"` End []byte `serialize:"true"` Limit uint16 `serialize:"true"` @@ -50,8 +51,8 @@ type LeafsRequest struct { func (l LeafsRequest) String() string { return fmt.Sprintf( - "LeafsRequest(Root=%s, Start=%s, End=%s, Limit=%d, NodeType=%s)", - l.Root, common.Bytes2Hex(l.Start), common.Bytes2Hex(l.End), l.Limit, l.NodeType, + "LeafsRequest(Root=%s, Account=%s, Start=%s, End=%s, Limit=%d, NodeType=%s)", + l.Root, l.Account, common.Bytes2Hex(l.Start), common.Bytes2Hex(l.End), l.Limit, l.NodeType, ) } diff --git a/plugin/evm/message/leafs_request_test.go b/plugin/evm/message/leafs_request_test.go index f773d91925..543222619c 100644 --- a/plugin/evm/message/leafs_request_test.go +++ b/plugin/evm/message/leafs_request_test.go @@ -40,7 +40,7 @@ func TestMarshalLeafsRequest(t *testing.T) { NodeType: StateTrieNode, } - base64LeafsRequest := "AAAAAAAAAAAAAAAAAAAAAABpbSBST09UaW5nIGZvciB5YQAAACBS/fwHIYJlTxY/Xw+aYh1ylWbHTRADfE17uwQH0eLGSQAAACCBhVrYaB0NhtHpHgAWeTnLZpTSxCKs0gigByk5SH9pmQQAAQ==" + base64LeafsRequest := "AAAAAAAAAAAAAAAAAAAAAABpbSBST09UaW5nIGZvciB5YQAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIFL9/AchgmVPFj9fD5piHXKVZsdNEAN8TXu7BAfR4sZJAAAAIIGFWthoHQ2G0ekeABZ5OctmlNLEIqzSCKAHKTlIf2mZBAAB" leafsRequestBytes, err := Codec.Marshal(Version, leafsRequest) assert.NoError(t, err) diff --git a/plugin/evm/vm.go b/plugin/evm/vm.go index 9f8d834675..0744f22775 100644 --- a/plugin/evm/vm.go +++ b/plugin/evm/vm.go @@ -930,7 +930,13 @@ func (vm *VM) setAppRequestHandlers() { Cache: vm.config.StateSyncServerTrieCache, }, ) - syncRequestHandler := handlers.NewSyncHandler(vm.chain.BlockChain().GetBlock, evmTrieDB, vm.atomicTrie.TrieDB(), vm.networkCodec, handlerstats.NewHandlerStats(metrics.Enabled)) + syncRequestHandler := handlers.NewSyncHandler( + vm.chain.BlockChain(), + evmTrieDB, + vm.atomicTrie.TrieDB(), + vm.networkCodec, + handlerstats.NewHandlerStats(metrics.Enabled), + ) vm.Network.SetRequestHandler(syncRequestHandler) } diff --git a/sync/client/client.go b/sync/client/client.go index 19baa79c68..55f29341ce 100644 --- a/sync/client/client.go +++ b/sync/client/client.go @@ -32,7 +32,7 @@ import ( ) var ( - StateSyncVersion = version.NewDefaultApplication(constants.PlatformName, 1, 7, 11) + StateSyncVersion = version.NewDefaultApplication(constants.PlatformName, 1, 7, 12) errEmptyResponse = errors.New("empty response") errTooManyBlocks = errors.New("response contains more blocks than requested") errHashMismatch = errors.New("hash does not match expected value") diff --git a/sync/client/client_test.go b/sync/client/client_test.go index 378d4860e6..91814a30fb 100644 --- a/sync/client/client_test.go +++ b/sync/client/client_test.go @@ -375,15 +375,16 @@ func TestGetBlocks(t *testing.T) { } } -func buildGetter(blocks []*types.Block) func(hash common.Hash, height uint64) *types.Block { - return func(blockHash common.Hash, blockHeight uint64) *types.Block { - requestedBlock := blocks[blockHeight] - if requestedBlock.Hash() != blockHash { - fmt.Printf("ERROR height=%d, hash=%s, parentHash=%s, reqHash=%s\n", blockHeight, blockHash, requestedBlock.ParentHash(), requestedBlock.Hash()) - return nil - } - - return requestedBlock +func buildGetter(blocks []*types.Block) handlers.BlockProvider { + return &handlers.TestBlockProvider{ + GetBlockFn: func(blockHash common.Hash, blockHeight uint64) *types.Block { + requestedBlock := blocks[blockHeight] + if requestedBlock.Hash() != 
blockHash { + fmt.Printf("ERROR height=%d, hash=%s, parentHash=%s, reqHash=%s\n", blockHeight, blockHash, requestedBlock.ParentHash(), requestedBlock.Hash()) + return nil + } + return requestedBlock + }, } } @@ -396,7 +397,7 @@ func TestGetLeafs(t *testing.T) { largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength) smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, leafsLimit, common.HashLength) - handler := handlers.NewLeafsRequestHandler(trieDB, message.Codec, handlerstats.NewNoopHandlerStats()) + handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) client := NewClient(&ClientConfig{ NetworkClient: &mockNetwork{}, Codec: message.Codec, @@ -611,21 +612,11 @@ func TestGetLeafs(t *testing.T) { if _, err := message.Codec.Unmarshal(response, &leafResponse); err != nil { t.Fatal(err) } - leafResponse.Keys = leafResponse.Keys[1:] - leafResponse.Vals = leafResponse.Vals[1:] - - tr, err := trie.New(largeTrieRoot, trieDB) - if err != nil { - t.Fatal(err) - } - leafResponse.ProofKeys, leafResponse.ProofVals, err = handlers.GenerateRangeProof(tr, leafResponse.Keys[0], leafResponse.Keys[len(leafResponse.Keys)-1]) - if err != nil { - t.Fatal(err) - } - - modifiedResponse, err := message.Codec.Marshal(message.Version, leafResponse) + modifiedRequest := request + modifiedRequest.Start = leafResponse.Keys[1] + modifiedResponse, err := handler.OnLeafsRequest(context.Background(), ids.GenerateTestNodeID(), 2, modifiedRequest) if err != nil { - t.Fatal(err) + t.Fatal("unexpected error in calling leafs request handler", err) } return modifiedResponse }, @@ -791,7 +782,7 @@ func TestGetLeafsRetries(t *testing.T) { trieDB := trie.NewDatabase(memorydb.New()) root, _, _ := trie.GenerateTrie(t, trieDB, 100_000, common.HashLength) - handler := handlers.NewLeafsRequestHandler(trieDB, message.Codec, handlerstats.NewNoopHandlerStats()) + handler := handlers.NewLeafsRequestHandler(trieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) mockNetClient := &mockNetwork{} const maxAttempts = 8 diff --git a/sync/client/leaf_syncer.go b/sync/client/leaf_syncer.go index 57f765530b..4035a8022e 100644 --- a/sync/client/leaf_syncer.go +++ b/sync/client/leaf_syncer.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/ava-labs/coreth/plugin/evm/message" + "github.com/ava-labs/coreth/utils" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" "golang.org/x/sync/errgroup" @@ -40,6 +41,7 @@ type OnSyncFailure func(error) error // LeafSyncTask represents a complete task to be completed by the leaf syncer. type LeafSyncTask struct { Root common.Hash // Root of the trie to sync + Account common.Hash // Account hash of the trie to sync (only applicable to storage tries) Start []byte // Starting key to request new leaves NodeType message.NodeType // Specifies the message type (atomic/state trie) for the leaf syncer to send OnStart OnStart // Callback when tasks begins, returns true if work can be skipped @@ -115,6 +117,7 @@ func (c *CallbackLeafSyncer) syncTask(ctx context.Context, task *LeafSyncTask) e leafsResponse, err := c.client.GetLeafs(message.LeafsRequest{ Root: root, + Account: task.Account, Start: start, End: nil, // will request until the end of the trie Limit: defaultLeafRequestLimit, @@ -147,7 +150,7 @@ func (c *CallbackLeafSyncer) syncTask(ctx context.Context, task *LeafSyncTask) e // Update start to be one bit past the last returned key for the next request. 
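// Sketch (not part of the patch above): behavior of utils.IncrOne, which the
// syncer now uses (moved out of the sync client, see the removal below) to
// advance the next request's Start to the key one past the last leaf received.
// The key is treated as a big-endian integer and incremented in place with carry.
package main

import "fmt"

// incrOne mirrors the IncrOne implementation removed from leaf_syncer.go below.
func incrOne(b []byte) {
	for i := len(b) - 1; i >= 0; i-- {
		if b[i] < 255 {
			b[i]++
			return
		}
		b[i] = 0 // carry into the next more significant byte
	}
}

func main() {
	lastKey := []byte{0x12, 0x34, 0xff}
	next := append([]byte(nil), lastKey...) // copy so the original key is untouched
	incrOne(next)
	fmt.Printf("last key returned:  % x\n", lastKey)
	fmt.Printf("next request start: % x\n", next) // 12 35 00
}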
// Note: since more was true, this cannot cause an overflow. start = leafsResponse.Keys[len(leafsResponse.Keys)-1] - IncrOne(start) + utils.IncrOne(start) } } @@ -222,17 +225,3 @@ func (c *CallbackLeafSyncer) addTasks(ctx context.Context, tasks []*LeafSyncTask } return nil } - -// IncrOne increments bytes value by one -func IncrOne(bytes []byte) { - index := len(bytes) - 1 - for index >= 0 { - if bytes[index] < 255 { - bytes[index]++ - break - } else { - bytes[index] = 0 - index-- - } - } -} diff --git a/sync/handlers/block_request.go b/sync/handlers/block_request.go index c6e7afd920..db99d9b847 100644 --- a/sync/handlers/block_request.go +++ b/sync/handlers/block_request.go @@ -11,7 +11,6 @@ import ( "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" - "github.com/ava-labs/coreth/core/types" "github.com/ava-labs/coreth/peer" "github.com/ava-labs/coreth/plugin/evm/message" "github.com/ava-labs/coreth/sync/handlers/stats" @@ -26,17 +25,17 @@ const parentLimit = uint16(64) // BlockRequestHandler is a peer.RequestHandler for message.BlockRequest // serving requested blocks starting at specified hash type BlockRequestHandler struct { - stats stats.BlockRequestHandlerStats - network peer.Network - getter func(common.Hash, uint64) *types.Block - codec codec.Manager + stats stats.BlockRequestHandlerStats + network peer.Network + blockProvider BlockProvider + codec codec.Manager } -func NewBlockRequestHandler(getter func(common.Hash, uint64) *types.Block, codec codec.Manager, handlerStats stats.BlockRequestHandlerStats) *BlockRequestHandler { +func NewBlockRequestHandler(blockProvider BlockProvider, codec codec.Manager, handlerStats stats.BlockRequestHandlerStats) *BlockRequestHandler { return &BlockRequestHandler{ - getter: getter, - codec: codec, - stats: handlerStats, + blockProvider: blockProvider, + codec: codec, + stats: handlerStats, } } @@ -75,7 +74,7 @@ func (b *BlockRequestHandler) OnBlockRequest(ctx context.Context, nodeID ids.Nod break } - block := b.getter(hash, height) + block := b.blockProvider.GetBlock(hash, height) if block == nil { b.stats.IncMissingBlockHash() break diff --git a/sync/handlers/block_request_test.go b/sync/handlers/block_request_test.go index bfd1f35fe2..4930d3f230 100644 --- a/sync/handlers/block_request_test.go +++ b/sync/handlers/block_request_test.go @@ -41,13 +41,16 @@ func TestBlockRequestHandler(t *testing.T) { } mockHandlerStats := &stats.MockHandlerStats{} - blockRequestHandler := NewBlockRequestHandler(func(hash common.Hash, height uint64) *types.Block { - blk, ok := blocksDB[hash] - if !ok || blk.NumberU64() != height { - return nil - } - return blk - }, message.Codec, mockHandlerStats) + blockProvider := &TestBlockProvider{ + GetBlockFn: func(hash common.Hash, height uint64) *types.Block { + blk, ok := blocksDB[hash] + if !ok || blk.NumberU64() != height { + return nil + } + return blk + }, + } + blockRequestHandler := NewBlockRequestHandler(blockProvider, message.Codec, mockHandlerStats) tests := []struct { name string @@ -163,18 +166,21 @@ func TestBlockRequestHandlerCtxExpires(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() blockRequestCallCount := 0 - blockRequestHandler := NewBlockRequestHandler(func(hash common.Hash, height uint64) *types.Block { - blockRequestCallCount++ - // cancel ctx after the 2nd call to simulate ctx expiring due to deadline exceeding - if blockRequestCallCount >= cancelAfterNumRequests { - cancel() - } - blk, ok := blocksDB[hash] - if !ok || blk.NumberU64() 
!= height { - return nil - } - return blk - }, message.Codec, stats.NewNoopHandlerStats()) + blockProvider := &TestBlockProvider{ + GetBlockFn: func(hash common.Hash, height uint64) *types.Block { + blockRequestCallCount++ + // cancel ctx after the 2nd call to simulate ctx expiring due to deadline exceeding + if blockRequestCallCount >= cancelAfterNumRequests { + cancel() + } + blk, ok := blocksDB[hash] + if !ok || blk.NumberU64() != height { + return nil + } + return blk + }, + } + blockRequestHandler := NewBlockRequestHandler(blockProvider, message.Codec, stats.NewNoopHandlerStats()) responseBytes, err := blockRequestHandler.OnBlockRequest(ctx, ids.GenerateTestNodeID(), 1, message.BlockRequest{ Hash: blocks[10].Hash(), diff --git a/sync/handlers/handler.go b/sync/handlers/handler.go index 95189f65e8..bc872e1c49 100644 --- a/sync/handlers/handler.go +++ b/sync/handlers/handler.go @@ -8,6 +8,7 @@ import ( "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/coreth/core/state/snapshot" "github.com/ava-labs/coreth/core/types" "github.com/ava-labs/coreth/plugin/evm/message" "github.com/ava-labs/coreth/sync/handlers/stats" @@ -17,6 +18,19 @@ import ( var _ message.RequestHandler = &syncHandler{} +type BlockProvider interface { + GetBlock(common.Hash, uint64) *types.Block +} + +type SnapshotProvider interface { + Snapshots() *snapshot.Tree +} + +type SyncDataProvider interface { + BlockProvider + SnapshotProvider +} + type syncHandler struct { stateTrieLeafsRequestHandler *LeafsRequestHandler atomicTrieLeafsRequestHandler *LeafsRequestHandler @@ -26,16 +40,16 @@ type syncHandler struct { // NewSyncHandler constructs the handler for serving state sync. func NewSyncHandler( - getBlock func(common.Hash, uint64) *types.Block, + provider SyncDataProvider, evmTrieDB *trie.Database, atomicTrieDB *trie.Database, networkCodec codec.Manager, stats stats.HandlerStats, ) message.RequestHandler { return &syncHandler{ - stateTrieLeafsRequestHandler: NewLeafsRequestHandler(evmTrieDB, networkCodec, stats), - atomicTrieLeafsRequestHandler: NewLeafsRequestHandler(atomicTrieDB, networkCodec, stats), - blockRequestHandler: NewBlockRequestHandler(getBlock, networkCodec, stats), + stateTrieLeafsRequestHandler: NewLeafsRequestHandler(evmTrieDB, provider, networkCodec, stats), + atomicTrieLeafsRequestHandler: NewLeafsRequestHandler(atomicTrieDB, nil, networkCodec, stats), + blockRequestHandler: NewBlockRequestHandler(provider, networkCodec, stats), codeRequestHandler: NewCodeRequestHandler(evmTrieDB.DiskDB(), networkCodec, stats), } } diff --git a/sync/handlers/iterators.go b/sync/handlers/iterators.go new file mode 100644 index 0000000000..53125c9e99 --- /dev/null +++ b/sync/handlers/iterators.go @@ -0,0 +1,68 @@ +// (c) 2021-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package handlers + +import ( + "github.com/ava-labs/coreth/core/state/snapshot" + "github.com/ava-labs/coreth/ethdb" +) + +var ( + _ ethdb.Iterator = &accountIt{} + _ ethdb.Iterator = &storageIt{} +) + +// accountIt wraps a [snapshot.AccountIterator] to conform to [ethdb.Iterator] +// accounts will be returned in consensus (FullRLP) format for compatibility with trie data. 
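// Sketch (not part of the patch above): any value that can serve blocks and expose
// the snapshot tree satisfies the new SyncDataProvider interface, which is why
// vm.go can now hand vm.chain.BlockChain() directly to NewSyncHandler. A combined
// test double (hypothetical, in the spirit of the TestBlockProvider and
// TestSnapshotProvider added in this patch) might look like this.
package handlers

import (
	"github.com/ava-labs/coreth/core/state/snapshot"
	"github.com/ava-labs/coreth/core/types"
	"github.com/ethereum/go-ethereum/common"
)

// testSyncDataProvider bundles a block lookup table with a snapshot tree.
type testSyncDataProvider struct {
	blocks map[common.Hash]*types.Block
	snaps  *snapshot.Tree
}

func (p *testSyncDataProvider) GetBlock(hash common.Hash, height uint64) *types.Block {
	blk, ok := p.blocks[hash]
	if !ok || blk.NumberU64() != height {
		return nil
	}
	return blk
}

func (p *testSyncDataProvider) Snapshots() *snapshot.Tree { return p.snaps }

// compile-time check that the double implements both halves of SyncDataProvider
var _ SyncDataProvider = (*testSyncDataProvider)(nil)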
+type accountIt struct { + snapshot.AccountIterator + err error + val []byte +} + +func (it *accountIt) Next() bool { + if it.err != nil { + return false + } + for it.AccountIterator.Next() { + it.val, it.err = snapshot.FullAccountRLP(it.Account()) + return it.err == nil + } + it.val = nil + return false +} + +func (it *accountIt) Key() []byte { + if it.err != nil { + return nil + } + return it.Hash().Bytes() +} + +func (it *accountIt) Value() []byte { + if it.err != nil { + return nil + } + return it.val +} + +func (it *accountIt) Error() error { + if it.err != nil { + return it.err + } + return it.AccountIterator.Error() +} + +// storageIt wraps a [snapshot.StorageIterator] to conform to [ethdb.Iterator] +type storageIt struct { + snapshot.StorageIterator +} + +func (it *storageIt) Key() []byte { + return it.Hash().Bytes() +} + +func (it *storageIt) Value() []byte { + return it.Slot() +} diff --git a/sync/handlers/leafs_request.go b/sync/handlers/leafs_request.go index df8f5defc4..99ade42346 100644 --- a/sync/handlers/leafs_request.go +++ b/sync/handlers/leafs_request.go @@ -7,38 +7,53 @@ import ( "bytes" "context" "fmt" + "sync" "time" "github.com/ava-labs/avalanchego/codec" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/utils/math" "github.com/ava-labs/avalanchego/utils/wrappers" + "github.com/ava-labs/coreth/core/state/snapshot" "github.com/ava-labs/coreth/core/types" + "github.com/ava-labs/coreth/ethdb" "github.com/ava-labs/coreth/ethdb/memorydb" "github.com/ava-labs/coreth/plugin/evm/message" "github.com/ava-labs/coreth/sync/handlers/stats" "github.com/ava-labs/coreth/trie" + "github.com/ava-labs/coreth/utils" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/log" ) -// Maximum number of leaves to return in a message.LeafsResponse -// This parameter overrides any other Limit specified -// in message.LeafsRequest if it is greater than this value -const maxLeavesLimit = uint16(1024) +const ( + // Maximum number of leaves to return in a message.LeafsResponse + // This parameter overrides any other Limit specified + // in message.LeafsRequest if it is greater than this value + maxLeavesLimit = uint16(1024) + + segmentLen = 64 // divide data from snapshot to segments of this size +) // LeafsRequestHandler is a peer.RequestHandler for types.LeafsRequest // serving requested trie data type LeafsRequestHandler struct { - trieDB *trie.Database - codec codec.Manager - stats stats.LeafsRequestHandlerStats + trieDB *trie.Database + snapshotProvider SnapshotProvider + codec codec.Manager + stats stats.LeafsRequestHandlerStats + pool sync.Pool } -func NewLeafsRequestHandler(trieDB *trie.Database, codec codec.Manager, syncerStats stats.LeafsRequestHandlerStats) *LeafsRequestHandler { +func NewLeafsRequestHandler(trieDB *trie.Database, snapshotProvider SnapshotProvider, codec codec.Manager, syncerStats stats.LeafsRequestHandlerStats) *LeafsRequestHandler { return &LeafsRequestHandler{ - trieDB: trieDB, - codec: codec, - stats: syncerStats, + trieDB: trieDB, + snapshotProvider: snapshotProvider, + codec: codec, + stats: syncerStats, + pool: sync.Pool{ + New: func() interface{} { return make([][]byte, 0, maxLeavesLimit) }, + }, } } @@ -85,125 +100,336 @@ func (lrh *LeafsRequestHandler) OnLeafsRequest(ctx context.Context, nodeID ids.N lrh.stats.IncMissingRoot() return nil, nil } + // override limit if it is greater than the configured maxLeavesLimit + limit := leafsRequest.Limit + if limit > maxLeavesLimit { + limit = maxLeavesLimit + } + + var 
leafsResponse message.LeafsResponse + // pool response's key/val allocations + leafsResponse.Keys = lrh.pool.Get().([][]byte) + leafsResponse.Vals = lrh.pool.Get().([][]byte) + defer func() { + for i := range leafsResponse.Keys { + // clear out slices before returning them to the pool + // to avoid memory leak. + leafsResponse.Keys[i] = nil + leafsResponse.Vals[i] = nil + } + lrh.pool.Put(leafsResponse.Keys[:0]) + lrh.pool.Put(leafsResponse.Vals[:0]) + }() + + responseBuilder := &responseBuilder{ + request: &leafsRequest, + response: &leafsResponse, + t: t, + keyLength: keyLength, + limit: limit, + stats: lrh.stats, + } + // pass snapshot to responseBuilder if non-nil snapshot getter provided + if lrh.snapshotProvider != nil { + responseBuilder.snap = lrh.snapshotProvider.Snapshots() + } + err = responseBuilder.handleRequest(ctx) // ensure metrics are captured properly on all return paths - leafCount := uint16(0) defer func() { lrh.stats.UpdateLeafsRequestProcessingTime(time.Since(startTime)) - lrh.stats.UpdateLeafsReturned(leafCount) + lrh.stats.UpdateLeafsReturned(uint16(len(leafsResponse.Keys))) + lrh.stats.UpdateRangeProofKeysReturned(int64(len(leafsResponse.ProofKeys))) + lrh.stats.UpdateGenerateRangeProofTime(responseBuilder.proofTime) + lrh.stats.UpdateReadLeafsTime(responseBuilder.trieReadTime) }() + if err != nil { + log.Debug("failed to serve leafs request", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "err", err) + return nil, nil + } + if len(leafsResponse.Keys) == 0 && ctx.Err() != nil { + log.Debug("context err set before any leafs were iterated", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "ctxErr", ctx.Err()) + return nil, nil + } - // create iterator to iterate the trie - // Note that leafsRequest.Start could be an original start point - // or leafsResponse.NextKey from partial response to previous request - it := trie.NewIterator(t.NodeIterator(leafsRequest.Start)) - - // override limit if it is greater than the configured maxLeavesLimit - limit := leafsRequest.Limit - if limit > maxLeavesLimit { - limit = maxLeavesLimit + responseBytes, err := lrh.codec.Marshal(message.Version, leafsResponse) + if err != nil { + log.Debug("failed to marshal LeafsResponse, dropping request", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "err", err) + return nil, nil } - var leafsResponse message.LeafsResponse + log.Debug("handled leafsRequest", "time", time.Since(startTime), "leafs", len(leafsResponse.Keys), "proofLen", len(leafsResponse.ProofKeys)) + return responseBytes, nil +} - // more indicates whether there are more leaves in the trie - more := false - for it.Next() { - // if we're at the end, break this loop - if len(leafsRequest.End) > 0 && bytes.Compare(it.Key, leafsRequest.End) > 0 { - more = true - break +type responseBuilder struct { + request *message.LeafsRequest + response *message.LeafsResponse + t *trie.Trie + snap *snapshot.Tree + keyLength int + limit uint16 + + // stats + trieReadTime time.Duration + proofTime time.Duration + stats stats.LeafsRequestHandlerStats +} + +func (rb *responseBuilder) handleRequest(ctx context.Context) error { + // Read from snapshot if a [snapshot.Tree] was provided in initialization + if rb.snap != nil { + if done, err := rb.fillFromSnapshot(ctx); err != nil { + return err + } else if done { + return nil } + } - // If we've returned enough data or run out of time, set the more flag and exit - // this flag will determine if the proof is generated or not - if leafCount >= limit 
|| ctx.Err() != nil { - if leafCount == 0 { - log.Debug("context err set before any leafs were iterated", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "ctxErr", ctx.Err()) - return nil, nil - } - more = true - break + if len(rb.response.Keys) < int(rb.limit) { + // more indicates whether there are more leaves in the trie + more, err := rb.fillFromTrie(ctx, rb.request.End) + if err != nil { + rb.stats.IncTrieError() + return err + } + if len(rb.request.Start) == 0 && !more { + // omit proof via early return + return nil } + } - // collect data to return - leafCount++ - leafsResponse.Keys = append(leafsResponse.Keys, it.Key) - leafsResponse.Vals = append(leafsResponse.Vals, it.Value) + // Generate the proof and add it to the response. + proof, err := rb.generateRangeProof(rb.request.Start, rb.response.Keys) + if err != nil { + rb.stats.IncProofError() + return err } - // Update read leafs time here, so that we include the case that an error occurred. - lrh.stats.UpdateReadLeafsTime(time.Since(startTime)) + defer proof.Close() // closing memdb does not error - if it.Err != nil { - log.Debug("failed to iterate trie, dropping request", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "err", it.Err) - lrh.stats.IncTrieError() - return nil, nil + rb.response.ProofKeys, rb.response.ProofVals, err = iterateKeyVals(proof) + if err != nil { + rb.stats.IncProofError() + return err + } + return nil +} + +// fillFromSnapshot reads data from snapshot and returns true if the response is complete +// (otherwise the trie must be iterated further and a range proof may be needed) +func (rb *responseBuilder) fillFromSnapshot(ctx context.Context) (bool, error) { + snapshotReadStart := time.Now() + rb.stats.IncSnapshotReadAttempt() + + // Optimistically read leafs from the snapshot, assuming they have not been + // modified since the requested root. If this assumption can be verified with + // range proofs and data from the trie, we can skip iterating the trie as + // an optimization. + snapKeys, snapVals, more, err := rb.readLeafsFromSnapshot(ctx) + // Update read snapshot time here, so that we include the case that an error occurred. + rb.stats.UpdateSnapshotReadTime(time.Since(snapshotReadStart)) + if err != nil { + rb.stats.IncSnapshotReadError() + return false, err } - // only generate proof if we're not returning the full trie - // we determine this based on if the starting point is nil and if the iterator - // indicates that are more leaves in the trie. - if len(leafsRequest.Start) > 0 || more { - start := leafsRequest.Start - // If [start] in the request is empty, populate it with the appropriate length - // key starting at 0. - if len(start) == 0 { - start = bytes.Repeat([]byte{0x00}, keyLength) + // Check if the entire range read from the snapshot is valid according to the trie. + proof, ok, err := rb.isRangeValid(snapKeys, snapVals, false) + if err != nil { + rb.stats.IncProofError() + return false, err + } + defer proof.Close() // closing memdb does not error + if ok { + rb.response.Keys, rb.response.Vals = snapKeys, snapVals + if len(rb.request.Start) == 0 && !more { + // omit proof via early return + rb.stats.IncSnapshotReadSuccess() + return true, nil } - // If there is a non-zero number of keys, set [end] for the range proof to the - // last key included in the response. 
- end := leafsRequest.End - if len(leafsResponse.Keys) > 0 { - end = leafsResponse.Keys[len(leafsResponse.Keys)-1] + rb.response.ProofKeys, rb.response.ProofVals, err = iterateKeyVals(proof) + if err != nil { + rb.stats.IncProofError() + return false, err } - rangeProofStart := time.Now() - leafsResponse.ProofKeys, leafsResponse.ProofVals, err = GenerateRangeProof(t, start, end) - lrh.stats.UpdateGenerateRangeProofTime(time.Since(rangeProofStart)) - lrh.stats.UpdateRangeProofKeysReturned(int64(len(leafsResponse.Keys))) - // Generate the proof and add it to the response. + rb.stats.IncSnapshotReadSuccess() + return true, nil + } + // The data from the snapshot could not be validated as a whole. It is still likely + // most of the data from the snapshot is useable, so we try to validate smaller + // segments of the data and use them in the response. + hasGap := false + for i := 0; i < len(snapKeys); i += segmentLen { + segmentEnd := math.Min(i+segmentLen, len(snapKeys)) + proof, ok, err := rb.isRangeValid(snapKeys[i:segmentEnd], snapVals[i:segmentEnd], hasGap) if err != nil { - log.Debug("failed to create valid proof serving leafs request", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "err", err) - lrh.stats.IncTrieError() - return nil, nil + rb.stats.IncProofError() + return false, err + } + _ = proof.Close() // we don't need this proof + if !ok { + // segment is not valid + rb.stats.IncSnapshotSegmentInvalid() + hasGap = true + continue + } + + // segment is valid + rb.stats.IncSnapshotSegmentValid() + if hasGap { + // if there is a gap between valid segments, fill the gap with data from the trie + _, err := rb.fillFromTrie(ctx, snapKeys[i]) + if err != nil { + rb.stats.IncTrieError() + return false, err + } + if len(rb.response.Keys) >= int(rb.limit) || ctx.Err() != nil { + break + } + // remove the last key added since it is snapKeys[i] and will be added back + // Note: this is safe because we were able to verify the range proof that + // shows snapKeys[i] is part of the trie. + rb.response.Keys = rb.response.Keys[:len(rb.response.Keys)-1] + rb.response.Vals = rb.response.Vals[:len(rb.response.Vals)-1] + } + hasGap = false + // all the key/vals in the segment are valid, but possibly shorten segmentEnd + // here to respect limit. this is necessary in case the number of leafs we read + // from the trie is more than the length of a segment which cannot be validated. limit + segmentEnd = math.Min(segmentEnd, i+int(rb.limit)-len(rb.response.Keys)) + rb.response.Keys = append(rb.response.Keys, snapKeys[i:segmentEnd]...) + rb.response.Vals = append(rb.response.Vals, snapVals[i:segmentEnd]...) + + if len(rb.response.Keys) >= int(rb.limit) { + break } } + return false, nil +} - responseBytes, err := lrh.codec.Marshal(message.Version, leafsResponse) - if err != nil { - log.Debug("failed to marshal LeafsResponse, dropping request", "nodeID", nodeID, "requestID", requestID, "request", leafsRequest, "err", err) - return nil, nil +// generateRangeProof returns a range proof for the range specified by [start] and [keys] using [t]. +func (rb *responseBuilder) generateRangeProof(start []byte, keys [][]byte) (*memorydb.Database, error) { + proof := memorydb.New() + startTime := time.Now() + defer func() { rb.proofTime += time.Since(startTime) }() + + // If [start] is empty, populate it with the appropriate length key starting at 0. 
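// Sketch (not part of the patch above): a toy, pure-Go illustration of the
// segment strategy in fillFromSnapshot. When the snapshot range as a whole fails
// verification, keys are re-checked in fixed-size segments; segments that verify
// are served from the snapshot, and gaps left by invalid segments are re-read
// from the trie. The validity check here is a stand-in for the range proofs the
// handler actually uses.
package main

import "fmt"

const segmentLen = 4 // the handler uses 64; kept small here for readability

func main() {
	keys := []string{"k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7", "k8", "k9"}
	stale := map[string]bool{"k5": true} // pretend this snapshot entry no longer matches the trie

	segmentValid := func(seg []string) bool {
		for _, k := range seg {
			if stale[k] {
				return false
			}
		}
		return true
	}

	for i := 0; i < len(keys); i += segmentLen {
		end := i + segmentLen
		if end > len(keys) {
			end = len(keys)
		}
		if seg := keys[i:end]; segmentValid(seg) {
			fmt.Printf("segment %v: valid, served from snapshot\n", seg)
		} else {
			fmt.Printf("segment %v: invalid, filled from the trie instead\n", seg)
		}
	}
}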
+ if len(start) == 0 { + start = bytes.Repeat([]byte{0x00}, rb.keyLength) } - log.Debug("handled leafsRequest", "time", time.Since(startTime), "leafs", leafCount, "proofLen", len(leafsResponse.ProofKeys)) - return responseBytes, nil + if err := rb.t.Prove(start, 0, proof); err != nil { + _ = proof.Close() // closing memdb does not error + return nil, err + } + if len(keys) > 0 { + // If there is a non-zero number of keys, set [end] for the range proof to the last key. + end := keys[len(keys)-1] + if err := rb.t.Prove(end, 0, proof); err != nil { + _ = proof.Close() // closing memdb does not error + return nil, err + } + } + return proof, nil } -// GenerateRangeProof returns the required proof key-values pairs for the range proof of -// [t] from [start, end]. -func GenerateRangeProof(t *trie.Trie, start, end []byte) ([][]byte, [][]byte, error) { - proof := memorydb.New() - defer proof.Close() // Closing the memorydb should never error +// verifyRangeProof verifies the provided range proof with [keys/vals], starting at [start]. +// Returns nil on success. +func (rb *responseBuilder) verifyRangeProof(keys, vals [][]byte, start []byte, proof *memorydb.Database) error { + startTime := time.Now() + defer func() { rb.proofTime += time.Since(startTime) }() + + // If [start] is empty, populate it with the appropriate length key starting at 0. + if len(start) == 0 { + start = bytes.Repeat([]byte{0x00}, rb.keyLength) + } + var end []byte + if len(keys) > 0 { + end = keys[len(keys)-1] + } + _, err := trie.VerifyRangeProof(rb.request.Root, start, end, keys, vals, proof) + return err +} + +// iterateKeyVals returns the key-value pairs contained in [db] +func iterateKeyVals(db *memorydb.Database) ([][]byte, [][]byte, error) { + if db == nil { + return nil, nil, nil + } + // iterate db into [][]byte and return + it := db.NewIterator(nil, nil) + defer it.Release() - if err := t.Prove(start, 0, proof); err != nil { - return nil, nil, err + keys := make([][]byte, 0, db.Len()) + vals := make([][]byte, 0, db.Len()) + for it.Next() { + keys = append(keys, it.Key()) + vals = append(vals, it.Value()) } - if err := t.Prove(end, 0, proof); err != nil { - return nil, nil, err + return keys, vals, it.Error() +} + +// isRangeValid generates and verifies a range proof, returning true if keys/vals are +// part of the trie. If [hasGap] is true, the range is validated independent of the +// existing response. If [hasGap] is false, the range proof begins at a key which +// guarantees the range can be appended to the response. +func (rb *responseBuilder) isRangeValid(keys, vals [][]byte, hasGap bool) (*memorydb.Database, bool, error) { + var startKey []byte + if hasGap { + startKey = keys[0] + } else { + startKey = rb.nextKey() } - // dump proof into response - proofIt := proof.NewIterator(nil, nil) - defer proofIt.Release() + proof, err := rb.generateRangeProof(startKey, keys) + if err != nil { + return nil, false, err + } + return proof, rb.verifyRangeProof(keys, vals, startKey, proof) == nil, nil +} - keys := make([][]byte, 0, proof.Len()) - values := make([][]byte, 0, proof.Len()) - for proofIt.Next() { - keys = append(keys, proofIt.Key()) - values = append(values, proofIt.Value()) +// nextKey returns the nextKey that could potentially be part of the response. 
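// Sketch (not part of the patch above): the proof flow condensed into one
// hypothetical helper using the same calls the responseBuilder makes, assuming
// 32-byte state-trie keys. The start key and the last returned key are proven
// into a memorydb, and the key/value range is then checked against the requested
// root with trie.VerifyRangeProof (the updated tests pass a nil proof when the
// response covers the entire trie).
package handlers

import (
	"bytes"

	"github.com/ava-labs/coreth/ethdb/memorydb"
	"github.com/ava-labs/coreth/trie"
	"github.com/ethereum/go-ethereum/common"
)

// proveAndVerifyRange builds a range proof for [start, keys[len(keys)-1]] from
// [tr] and verifies [keys]/[vals] against [root]. It returns true if the trie
// holds more leaves beyond the proven range.
func proveAndVerifyRange(tr *trie.Trie, root common.Hash, start []byte, keys, vals [][]byte) (bool, error) {
	proof := memorydb.New()
	defer proof.Close() // closing the in-memory proof db does not error

	if len(start) == 0 {
		// an empty start means the beginning of the key space
		start = bytes.Repeat([]byte{0x00}, common.HashLength)
	}
	if err := tr.Prove(start, 0, proof); err != nil {
		return false, err
	}
	var end []byte
	if len(keys) > 0 {
		end = keys[len(keys)-1]
		if err := tr.Prove(end, 0, proof); err != nil {
			return false, err
		}
	}
	return trie.VerifyRangeProof(root, start, end, keys, vals, proof)
}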
+func (rb *responseBuilder) nextKey() []byte { + if len(rb.response.Keys) == 0 { + return rb.request.Start } + nextKey := common.CopyBytes(rb.response.Keys[len(rb.response.Keys)-1]) + utils.IncrOne(nextKey) + return nextKey +} - return keys, values, proofIt.Error() +// fillFromTrie iterates key/values from the response builder's trie and appends +// them to the response. Iteration begins from the last key already in the response, +// or the request start if the response is empty. Iteration ends at [end] or if +// the number of leafs reaches the builder's limit. +// Returns true if there are more keys in the trie. +func (rb *responseBuilder) fillFromTrie(ctx context.Context, end []byte) (bool, error) { + startTime := time.Now() + defer func() { rb.trieReadTime += time.Since(startTime) }() + + // create iterator to iterate the trie + it := trie.NewIterator(rb.t.NodeIterator(rb.nextKey())) + more := false + for it.Next() { + // if we're at the end, break this loop + if len(end) > 0 && bytes.Compare(it.Key, end) > 0 { + more = true + break + } + + // If we've returned enough data or run out of time, set the more flag and exit + // this flag will determine if the proof is generated or not + if len(rb.response.Keys) >= int(rb.limit) || ctx.Err() != nil { + more = true + break + } + + // append key/vals to the response + rb.response.Keys = append(rb.response.Keys, it.Key) + rb.response.Vals = append(rb.response.Vals, it.Value) + } + return more, it.Err } // getKeyLength returns trie key length for given nodeType @@ -217,3 +443,42 @@ func getKeyLength(nodeType message.NodeType) (int, error) { } return 0, fmt.Errorf("cannot get key length for unknown node type: %s", nodeType) } + +// readLeafsFromSnapshot iterates the storage snapshot of the requested account +// (or the main account trie if account is empty). Returns up to [rb.limit] key/value +// pairs with for keys that are in the request's range (inclusive), and a boolean +// indicating if there are more keys in the snapshot. +func (rb *responseBuilder) readLeafsFromSnapshot(ctx context.Context) ([][]byte, [][]byte, bool, error) { + var ( + snapIt ethdb.Iterator + startHash = common.BytesToHash(rb.request.Start) + more = false + keys = make([][]byte, 0, rb.limit) + vals = make([][]byte, 0, rb.limit) + ) + + // Get an iterator into the storage or the main account snapshot. 
+ if rb.request.Account == (common.Hash{}) { + snapIt = &accountIt{AccountIterator: rb.snap.DiskAccountIterator(startHash)} + } else { + snapIt = &storageIt{StorageIterator: rb.snap.DiskStorageIterator(rb.request.Account, startHash)} + } + defer snapIt.Release() + for snapIt.Next() { + // if we're at the end, break this loop + if len(rb.request.End) > 0 && bytes.Compare(snapIt.Key(), rb.request.End) > 0 { + more = true + break + } + // If we've returned enough data or run out of time, set the more flag and exit + // this flag will determine if the proof is generated or not + if len(keys) >= int(rb.limit) || ctx.Err() != nil { + more = true + break + } + + keys = append(keys, snapIt.Key()) + vals = append(vals, snapIt.Value()) + } + return keys, vals, more, snapIt.Error() +} diff --git a/sync/handlers/leafs_request_test.go b/sync/handlers/leafs_request_test.go index 682d11a455..858fd9649f 100644 --- a/sync/handlers/leafs_request_test.go +++ b/sync/handlers/leafs_request_test.go @@ -10,12 +10,17 @@ import ( "testing" "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/coreth/core/rawdb" + "github.com/ava-labs/coreth/core/state/snapshot" "github.com/ava-labs/coreth/core/types" + "github.com/ava-labs/coreth/ethdb" "github.com/ava-labs/coreth/ethdb/memorydb" "github.com/ava-labs/coreth/plugin/evm/message" "github.com/ava-labs/coreth/sync/handlers/stats" "github.com/ava-labs/coreth/trie" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/rlp" "github.com/stretchr/testify/assert" ) @@ -26,13 +31,34 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { trieDB := trie.NewDatabase(memdb) corruptedTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, 100, common.HashLength) - largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 10_000, common.HashLength) - smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, 500, common.HashLength) - // Corrupt [corruptedTrieRoot] trie.CorruptTrie(t, trieDB, corruptedTrieRoot, 5) - leafsHandler := NewLeafsRequestHandler(trieDB, message.Codec, mockHandlerStats) + largeTrieRoot, largeTrieKeys, _ := trie.GenerateTrie(t, trieDB, 10_000, common.HashLength) + smallTrieRoot, _, _ := trie.GenerateTrie(t, trieDB, 500, common.HashLength) + accountTrieRoot, accounts := trie.FillAccounts( + t, + trieDB, + common.Hash{}, + 10_000, + func(t *testing.T, i int, acc types.StateAccount) types.StateAccount { + if i == 0 { + // set the storage trie root for a single account + acc.Root = largeTrieRoot + } + return acc + }) + + // find the hash of the account we set to have a storage + var accHash common.Hash + for key, account := range accounts { + if account.Root == largeTrieRoot { + accHash = crypto.Keccak256Hash(key.Address[:]) + break + } + } + snapshotProvider := &TestSnapshotProvider{} + leafsHandler := NewLeafsRequestHandler(trieDB, snapshotProvider, message.Codec, mockHandlerStats) tests := map[string]struct { prepareTestFn func() (context.Context, message.LeafsRequest) @@ -282,18 +308,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit) assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) - - proofDB := memorydb.New() - defer proofDB.Close() - for i, proofKey := range leafsResponse.ProofKeys { - if err = proofDB.Put(proofKey, leafsResponse.ProofVals[i]); err != nil { - t.Fatal(err) - } - } - - more, err := 
trie.VerifyRangeProof(largeTrieRoot, bytes.Repeat([]byte{0x00}, common.HashLength), leafsResponse.Keys[len(leafsResponse.Keys)-1], leafsResponse.Keys, leafsResponse.Vals, proofDB) - assert.NoError(t, err) - assert.True(t, more) + assertRangeProofIsValid(t, &request, &leafsResponse, true) }, }, "full range with 0x00 start": { @@ -315,18 +330,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assert.EqualValues(t, len(leafsResponse.Vals), maxLeavesLimit) assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) - - proofDB := memorydb.New() - defer proofDB.Close() - for i, proofKey := range leafsResponse.ProofKeys { - if err = proofDB.Put(proofKey, leafsResponse.ProofVals[i]); err != nil { - t.Fatal(err) - } - } - - more, err := trie.VerifyRangeProof(largeTrieRoot, bytes.Repeat([]byte{0x00}, common.HashLength), leafsResponse.Keys[len(leafsResponse.Keys)-1], leafsResponse.Keys, leafsResponse.Vals, proofDB) - assert.NoError(t, err) - assert.True(t, more) + assertRangeProofIsValid(t, &request, &leafsResponse, true) }, }, "partial mid range": { @@ -351,18 +355,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assert.EqualValues(t, 40, len(leafsResponse.Vals)) assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) - - proofDB := memorydb.New() - defer proofDB.Close() - for i, proofKey := range leafsResponse.ProofKeys { - if err = proofDB.Put(proofKey, leafsResponse.ProofVals[i]); err != nil { - t.Fatal(err) - } - } - - more, err := trie.VerifyRangeProof(largeTrieRoot, request.Start, request.End, leafsResponse.Keys, leafsResponse.Vals, proofDB) - assert.NoError(t, err) - assert.True(t, more) + assertRangeProofIsValid(t, &request, &leafsResponse, true) }, }, "partial end range": { @@ -384,18 +377,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assert.EqualValues(t, 600, len(leafsResponse.Vals)) assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) - - proofDB := memorydb.New() - defer proofDB.Close() - for i, proofKey := range leafsResponse.ProofKeys { - if err = proofDB.Put(proofKey, leafsResponse.ProofVals[i]); err != nil { - t.Fatal(err) - } - } - - more, err := trie.VerifyRangeProof(largeTrieRoot, request.Start, request.End, leafsResponse.Keys, leafsResponse.Vals, proofDB) - assert.NoError(t, err) - assert.False(t, more) + assertRangeProofIsValid(t, &request, &leafsResponse, false) }, }, "final end range": { @@ -417,18 +399,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assert.EqualValues(t, len(leafsResponse.Vals), 0) assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) - - proofDB := memorydb.New() - defer proofDB.Close() - for i, proofKey := range leafsResponse.ProofKeys { - if err = proofDB.Put(proofKey, leafsResponse.ProofVals[i]); err != nil { - t.Fatal(err) - } - } - - more, err := trie.VerifyRangeProof(request.Root, request.Start, request.End, leafsResponse.Keys, leafsResponse.Vals, proofDB) - assert.NoError(t, err) - assert.False(t, more) + assertRangeProofIsValid(t, &request, &leafsResponse, false) }, }, "small trie root": { @@ -455,22 +426,212 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { assert.Empty(t, leafsResponse.ProofVals) assert.EqualValues(t, 1, 
mockHandlerStats.LeafsRequestCount) assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) + assertRangeProofIsValid(t, &request, &leafsResponse, false) + }, + }, + "account data served from snapshot": { + prepareTestFn: func() (context.Context, message.LeafsRequest) { + snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false) + if err != nil { + t.Fatal(err) + } + snapshotProvider.Snapshot = snap + return context.Background(), message.LeafsRequest{ + Root: accountTrieRoot, + Limit: maxLeavesLimit, + NodeType: message.StateTrieNode, + } + }, + assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { + assert.NoError(t, err) + var leafsResponse message.LeafsResponse + _, err = message.Codec.Unmarshal(response, &leafsResponse) + assert.NoError(t, err) + assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys)) + assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals)) + assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) + assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) + assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount) + assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadSuccessCount) + assertRangeProofIsValid(t, &request, &leafsResponse, true) + }, + }, + "partial account data served from snapshot": { + prepareTestFn: func() (context.Context, message.LeafsRequest) { + snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false) + if err != nil { + t.Fatal(err) + } + snapshotProvider.Snapshot = snap + it := snap.DiskAccountIterator(common.Hash{}) + defer it.Release() + i := 0 + for it.Next() { + if i > int(maxLeavesLimit) { + // no need to modify beyond the request limit + break + } + // modify one entry of 1 in 4 segments + if i%(segmentLen*4) == 0 { + var acc snapshot.Account + if err := rlp.DecodeBytes(it.Account(), &acc); err != nil { + t.Fatalf("could not parse snapshot account: %v", err) + } + acc.Nonce++ + bytes, err := rlp.EncodeToBytes(acc) + if err != nil { + t.Fatalf("coult not encode snapshot account to bytes: %v", err) + } + rawdb.WriteAccountSnapshot(memdb, it.Hash(), bytes) + } + i++ + } - firstKey := bytes.Repeat([]byte{0x00}, common.HashLength) - lastKey := leafsResponse.Keys[len(leafsResponse.Keys)-1] + return context.Background(), message.LeafsRequest{ + Root: accountTrieRoot, + Limit: maxLeavesLimit, + NodeType: message.StateTrieNode, + } + }, + assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { + assert.NoError(t, err) + var leafsResponse message.LeafsResponse + _, err = message.Codec.Unmarshal(response, &leafsResponse) + assert.NoError(t, err) + assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys)) + assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals)) + assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) + assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) + assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount) + assert.EqualValues(t, 0, mockHandlerStats.SnapshotReadSuccessCount) + assertRangeProofIsValid(t, &request, &leafsResponse, true) - more, err := trie.VerifyRangeProof(smallTrieRoot, firstKey, lastKey, leafsResponse.Keys, leafsResponse.Vals, nil) + // expect 1/4th of segments to be invalid + numSegments := maxLeavesLimit / segmentLen + assert.EqualValues(t, numSegments/4, mockHandlerStats.SnapshotSegmentInvalidCount) + 
assert.EqualValues(t, 3*numSegments/4, mockHandlerStats.SnapshotSegmentValidCount) + }, + }, + "storage data served from snapshot": { + prepareTestFn: func() (context.Context, message.LeafsRequest) { + snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false) + if err != nil { + t.Fatal(err) + } + snapshotProvider.Snapshot = snap + return context.Background(), message.LeafsRequest{ + Root: largeTrieRoot, + Account: accHash, + Limit: maxLeavesLimit, + NodeType: message.StateTrieNode, + } + }, + assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { + assert.NoError(t, err) + var leafsResponse message.LeafsResponse + _, err = message.Codec.Unmarshal(response, &leafsResponse) assert.NoError(t, err) - assert.False(t, more) + assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys)) + assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals)) + assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) + assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) + assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount) + assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadSuccessCount) + assertRangeProofIsValid(t, &request, &leafsResponse, true) + }, + }, + "partial storage data served from snapshot": { + prepareTestFn: func() (context.Context, message.LeafsRequest) { + snap, err := snapshot.New(memdb, trieDB, 64, common.Hash{}, accountTrieRoot, false, true, false) + if err != nil { + t.Fatal(err) + } + snapshotProvider.Snapshot = snap + it := snap.DiskStorageIterator(accHash, common.Hash{}) + defer it.Release() + i := 0 + for it.Next() { + if i > int(maxLeavesLimit) { + // no need to modify beyond the request limit + break + } + // modify one entry of 1 in 4 segments + if i%(segmentLen*4) == 0 { + randomBytes := make([]byte, 5) + _, err := rand.Read(randomBytes) + assert.NoError(t, err) + rawdb.WriteStorageSnapshot(memdb, accHash, it.Hash(), randomBytes) + } + i++ + } + + return context.Background(), message.LeafsRequest{ + Root: largeTrieRoot, + Account: accHash, + Limit: maxLeavesLimit, + NodeType: message.StateTrieNode, + } + }, + assertResponseFn: func(t *testing.T, request message.LeafsRequest, response []byte, err error) { + assert.NoError(t, err) + var leafsResponse message.LeafsResponse + _, err = message.Codec.Unmarshal(response, &leafsResponse) + assert.NoError(t, err) + assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Keys)) + assert.EqualValues(t, maxLeavesLimit, len(leafsResponse.Vals)) + assert.EqualValues(t, 1, mockHandlerStats.LeafsRequestCount) + assert.EqualValues(t, len(leafsResponse.Keys), mockHandlerStats.LeafsReturnedSum) + assert.EqualValues(t, 1, mockHandlerStats.SnapshotReadAttemptCount) + assert.EqualValues(t, 0, mockHandlerStats.SnapshotReadSuccessCount) + assertRangeProofIsValid(t, &request, &leafsResponse, true) + + // expect 1/4th of segments to be invalid + numSegments := maxLeavesLimit / segmentLen + assert.EqualValues(t, numSegments/4, mockHandlerStats.SnapshotSegmentInvalidCount) + assert.EqualValues(t, 3*numSegments/4, mockHandlerStats.SnapshotSegmentValidCount) }, }, } for name, test := range tests { t.Run(name, func(t *testing.T) { ctx, request := test.prepareTestFn() + t.Cleanup(func() { + <-snapshot.WipeSnapshot(memdb, true) + mockHandlerStats.Reset() + snapshotProvider.Snapshot = nil // reset the snapshot to nil + }) + response, err := leafsHandler.OnLeafsRequest(ctx, ids.GenerateTestNodeID(), 1, request) 
test.assertResponseFn(t, request, response, err) - mockHandlerStats.Reset() }) } } + +func assertRangeProofIsValid(t *testing.T, request *message.LeafsRequest, response *message.LeafsResponse, expectMore bool) { + t.Helper() + + var start, end []byte + if len(request.Start) == 0 { + start = bytes.Repeat([]byte{0x00}, common.HashLength) + } else { + start = request.Start + } + if len(response.Keys) > 0 { + end = response.Keys[len(response.Vals)-1] + } + + var proof ethdb.Database + if len(response.ProofKeys) > 0 { + proof = memorydb.New() + defer proof.Close() + for i, proofKey := range response.ProofKeys { + if err := proof.Put(proofKey, response.ProofVals[i]); err != nil { + t.Fatal(err) + } + } + } + + more, err := trie.VerifyRangeProof(request.Root, start, end, response.Keys, response.Vals, proof) + assert.NoError(t, err) + assert.Equal(t, expectMore, more) +} diff --git a/sync/handlers/stats/mock_stats.go b/sync/handlers/stats/mock_stats.go index 011b913081..14de4db440 100644 --- a/sync/handlers/stats/mock_stats.go +++ b/sync/handlers/stats/mock_stats.go @@ -30,9 +30,16 @@ type MockHandlerStats struct { InvalidLeafsRequestCount, LeafsReturnedSum, MissingRootCount, - TrieErrorCount uint32 + TrieErrorCount, + ProofErrorCount, + SnapshotReadErrorCount, + SnapshotReadAttemptCount, + SnapshotReadSuccessCount, + SnapshotSegmentValidCount, + SnapshotSegmentInvalidCount uint32 ProofKeysReturned int64 LeafsReadTime, + SnapshotReadTime, GenerateRangeProofTime, LeafRequestProcessingTimeSum time.Duration } @@ -55,10 +62,17 @@ func (m *MockHandlerStats) Reset() { m.LeafsReturnedSum = 0 m.MissingRootCount = 0 m.TrieErrorCount = 0 - m.LeafRequestProcessingTimeSum = 0 + m.ProofErrorCount = 0 + m.SnapshotReadErrorCount = 0 + m.SnapshotReadAttemptCount = 0 + m.SnapshotReadSuccessCount = 0 + m.SnapshotSegmentValidCount = 0 + m.SnapshotSegmentInvalidCount = 0 m.ProofKeysReturned = 0 - m.GenerateRangeProofTime = 0 m.LeafsReadTime = 0 + m.SnapshotReadTime = 0 + m.GenerateRangeProofTime = 0 + m.LeafRequestProcessingTimeSum = 0 } func (m *MockHandlerStats) IncBlockRequest() { @@ -157,6 +171,12 @@ func (m *MockHandlerStats) UpdateGenerateRangeProofTime(duration time.Duration) m.GenerateRangeProofTime += duration } +func (m *MockHandlerStats) UpdateSnapshotReadTime(duration time.Duration) { + m.lock.Lock() + defer m.lock.Unlock() + m.SnapshotReadTime += duration +} + func (m *MockHandlerStats) UpdateRangeProofKeysReturned(numProofKeys int64) { m.lock.Lock() defer m.lock.Unlock() @@ -174,3 +194,39 @@ func (m *MockHandlerStats) IncTrieError() { defer m.lock.Unlock() m.TrieErrorCount++ } + +func (m *MockHandlerStats) IncProofError() { + m.lock.Lock() + defer m.lock.Unlock() + m.ProofErrorCount++ +} + +func (m *MockHandlerStats) IncSnapshotReadError() { + m.lock.Lock() + defer m.lock.Unlock() + m.SnapshotReadErrorCount++ +} + +func (m *MockHandlerStats) IncSnapshotReadAttempt() { + m.lock.Lock() + defer m.lock.Unlock() + m.SnapshotReadAttemptCount++ +} + +func (m *MockHandlerStats) IncSnapshotReadSuccess() { + m.lock.Lock() + defer m.lock.Unlock() + m.SnapshotReadSuccessCount++ +} + +func (m *MockHandlerStats) IncSnapshotSegmentValid() { + m.lock.Lock() + defer m.lock.Unlock() + m.SnapshotSegmentValidCount++ +} + +func (m *MockHandlerStats) IncSnapshotSegmentInvalid() { + m.lock.Lock() + defer m.lock.Unlock() + m.SnapshotSegmentInvalidCount++ +} diff --git a/sync/handlers/stats/stats.go b/sync/handlers/stats/stats.go index 5b2c9a7949..d4a3cf85a4 100644 --- a/sync/handlers/stats/stats.go +++ 
b/sync/handlers/stats/stats.go @@ -38,10 +38,17 @@ type LeafsRequestHandlerStats interface { UpdateLeafsReturned(numLeafs uint16) UpdateLeafsRequestProcessingTime(duration time.Duration) UpdateReadLeafsTime(duration time.Duration) + UpdateSnapshotReadTime(duration time.Duration) UpdateGenerateRangeProofTime(duration time.Duration) UpdateRangeProofKeysReturned(numProofKeys int64) IncMissingRoot() IncTrieError() + IncProofError() + IncSnapshotReadError() + IncSnapshotReadAttempt() + IncSnapshotReadSuccess() + IncSnapshotSegmentValid() + IncSnapshotSegmentInvalid() } type handlerStats struct { @@ -65,10 +72,17 @@ type handlerStats struct { leafsReturned metrics.Histogram leafsRequestProcessingTime metrics.Timer leafsReadTime metrics.Timer + snapshotReadTime metrics.Timer generateRangeProofTime metrics.Timer proofKeysReturned metrics.Histogram missingRoot metrics.Counter trieError metrics.Counter + proofError metrics.Counter + snapshotReadError metrics.Counter + snapshotReadAttempt metrics.Counter + snapshotReadSuccess metrics.Counter + snapshotSegmentValid metrics.Counter + snapshotSegmentInvalid metrics.Counter } func (h *handlerStats) IncBlockRequest() { @@ -131,6 +145,10 @@ func (h *handlerStats) UpdateReadLeafsTime(duration time.Duration) { h.leafsReadTime.Update(duration) } +func (h *handlerStats) UpdateSnapshotReadTime(duration time.Duration) { + h.snapshotReadTime.Update(duration) +} + func (h *handlerStats) UpdateGenerateRangeProofTime(duration time.Duration) { h.generateRangeProofTime.Update(duration) } @@ -139,13 +157,14 @@ func (h *handlerStats) UpdateRangeProofKeysReturned(numProofKeys int64) { h.proofKeysReturned.Update(numProofKeys) } -func (h *handlerStats) IncMissingRoot() { - h.missingRoot.Inc(1) -} - -func (h *handlerStats) IncTrieError() { - h.trieError.Inc(1) -} +func (h *handlerStats) IncMissingRoot() { h.missingRoot.Inc(1) } +func (h *handlerStats) IncTrieError() { h.trieError.Inc(1) } +func (h *handlerStats) IncProofError() { h.proofError.Inc(1) } +func (h *handlerStats) IncSnapshotReadError() { h.snapshotReadError.Inc(1) } +func (h *handlerStats) IncSnapshotReadAttempt() { h.snapshotReadAttempt.Inc(1) } +func (h *handlerStats) IncSnapshotReadSuccess() { h.snapshotReadSuccess.Inc(1) } +func (h *handlerStats) IncSnapshotSegmentValid() { h.snapshotSegmentValid.Inc(1) } +func (h *handlerStats) IncSnapshotSegmentInvalid() { h.snapshotSegmentInvalid.Inc(1) } func NewHandlerStats(enabled bool) HandlerStats { if !enabled { @@ -172,10 +191,17 @@ func NewHandlerStats(enabled bool) HandlerStats { leafsRequestProcessingTime: metrics.GetOrRegisterTimer("leafs_request_processing_time", nil), leafsReturned: metrics.GetOrRegisterHistogram("leafs_request_total_leafs", nil, metrics.NewExpDecaySample(1028, 0.015)), leafsReadTime: metrics.GetOrRegisterTimer("leafs_read_time", nil), + snapshotReadTime: metrics.GetOrRegisterTimer("snapshot_read_time", nil), generateRangeProofTime: metrics.GetOrRegisterTimer("generate_range_proof_time", nil), proofKeysReturned: metrics.GetOrRegisterHistogram("proof_keys_returned", nil, metrics.NewExpDecaySample(1028, 0.015)), missingRoot: metrics.GetOrRegisterCounter("leafs_request_missing_root", nil), trieError: metrics.GetOrRegisterCounter("leafs_request_trie_error", nil), + proofError: metrics.GetOrRegisterCounter("leafs_request_proof_error", nil), + snapshotReadError: metrics.GetOrRegisterCounter("snapshot_read_error", nil), + snapshotReadAttempt: metrics.GetOrRegisterCounter("snapshot_read_attempt", nil), + snapshotReadSuccess: 
metrics.GetOrRegisterCounter("snapshot_read_success", nil), + snapshotSegmentValid: metrics.GetOrRegisterCounter("snapshot_segment_valid", nil), + snapshotSegmentInvalid: metrics.GetOrRegisterCounter("snapshot_segment_invalid", nil), } } @@ -202,7 +228,14 @@ func (n *noopHandlerStats) IncInvalidLeafsRequest() func (n *noopHandlerStats) UpdateLeafsRequestProcessingTime(time.Duration) {} func (n *noopHandlerStats) UpdateLeafsReturned(uint16) {} func (n *noopHandlerStats) UpdateReadLeafsTime(duration time.Duration) {} +func (n *noopHandlerStats) UpdateSnapshotReadTime(duration time.Duration) {} func (n *noopHandlerStats) UpdateGenerateRangeProofTime(duration time.Duration) {} func (n *noopHandlerStats) UpdateRangeProofKeysReturned(numProofKeys int64) {} func (n *noopHandlerStats) IncMissingRoot() {} func (n *noopHandlerStats) IncTrieError() {} +func (n *noopHandlerStats) IncProofError() {} +func (n *noopHandlerStats) IncSnapshotReadError() {} +func (n *noopHandlerStats) IncSnapshotReadAttempt() {} +func (n *noopHandlerStats) IncSnapshotReadSuccess() {} +func (n *noopHandlerStats) IncSnapshotSegmentValid() {} +func (n *noopHandlerStats) IncSnapshotSegmentInvalid() {} diff --git a/sync/handlers/test_providers.go b/sync/handlers/test_providers.go new file mode 100644 index 0000000000..81dafbfd00 --- /dev/null +++ b/sync/handlers/test_providers.go @@ -0,0 +1,31 @@ +// (c) 2021-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package handlers + +import ( + "github.com/ava-labs/coreth/core/state/snapshot" + "github.com/ava-labs/coreth/core/types" + "github.com/ethereum/go-ethereum/common" +) + +var ( + _ BlockProvider = &TestBlockProvider{} + _ SnapshotProvider = &TestSnapshotProvider{} +) + +type TestBlockProvider struct { + GetBlockFn func(common.Hash, uint64) *types.Block +} + +func (t *TestBlockProvider) GetBlock(hash common.Hash, number uint64) *types.Block { + return t.GetBlockFn(hash, number) +} + +type TestSnapshotProvider struct { + Snapshot *snapshot.Tree +} + +func (t *TestSnapshotProvider) Snapshots() *snapshot.Tree { + return t.Snapshot +} diff --git a/sync/statesync/state_syncer.go b/sync/statesync/state_syncer.go index 6c27901ca9..18aee528cc 100644 --- a/sync/statesync/state_syncer.go +++ b/sync/statesync/state_syncer.go @@ -146,6 +146,7 @@ func (s *stateSyncer) Start(ctx context.Context) { for storageRoot, storageTrieProgress := range s.progressMarker.StorageTries { storageTasks = append(storageTasks, &syncclient.LeafSyncTask{ Root: storageRoot, + Account: storageTrieProgress.Account, Start: storageTrieProgress.startFrom, NodeType: message.StateTrieNode, OnLeafs: storageTrieProgress.handleLeafs, @@ -267,6 +268,7 @@ func (s *stateSyncer) createStorageTrieTask(accountHash common.Hash, storageRoot s.progressMarker.StorageTries[storageRoot] = progress return &syncclient.LeafSyncTask{ Root: storageRoot, + Account: accountHash, NodeType: message.StateTrieNode, OnLeafs: progress.handleLeafs, OnFinish: s.onFinish, diff --git a/sync/statesync/sync_helpers.go b/sync/statesync/sync_helpers.go index 5ac28c2e68..8cecdc1df6 100644 --- a/sync/statesync/sync_helpers.go +++ b/sync/statesync/sync_helpers.go @@ -10,8 +10,8 @@ import ( "github.com/ava-labs/coreth/core/state/snapshot" "github.com/ava-labs/coreth/core/types" "github.com/ava-labs/coreth/ethdb" - syncclient "github.com/ava-labs/coreth/sync/client" "github.com/ava-labs/coreth/trie" + "github.com/ava-labs/coreth/utils" "github.com/ethereum/go-ethereum/common" ) @@ -102,7 +102,7 @@ func 
restoreMainTrieProgressFromSnapshot(db ethdb.Iteratee, tr *TrieProgress) er // since lastKey is already added to the stack trie, // we should start syncing from the next key. tr.startFrom = lastKey - syncclient.IncrOne(tr.startFrom) + utils.IncrOne(tr.startFrom) } return it.Error() } @@ -136,7 +136,7 @@ func restoreStorageTrieProgressFromSnapshot(db ethdb.Iteratee, tr *TrieProgress, // since lastKey is already added to the stack trie, // we should start syncing from the next key. tr.startFrom = lastKey - syncclient.IncrOne(tr.startFrom) + utils.IncrOne(tr.startFrom) } return it.Error() } diff --git a/sync/statesync/sync_test.go b/sync/statesync/sync_test.go index 596ab9dc9d..e2f9ef70c6 100644 --- a/sync/statesync/sync_test.go +++ b/sync/statesync/sync_test.go @@ -42,7 +42,7 @@ func testSync(t *testing.T, test syncTest) { ctx = test.ctx } clientDB, serverTrieDB, root := test.prepareForTest(t) - leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, message.Codec, handlerstats.NewNoopHandlerStats()) + leafsRequestHandler := handlers.NewLeafsRequestHandler(serverTrieDB, nil, message.Codec, handlerstats.NewNoopHandlerStats()) codeRequestHandler := handlers.NewCodeRequestHandler(serverTrieDB.DiskDB(), message.Codec, handlerstats.NewNoopHandlerStats()) mockClient := statesyncclient.NewMockClient(message.Codec, leafsRequestHandler, codeRequestHandler, nil) // Set intercept functions for the mock client @@ -101,14 +101,14 @@ func TestSimpleSyncCases(t *testing.T) { "accounts": { prepareForTest: func(t *testing.T) (ethdb.Database, *trie.Database, common.Hash) { serverTrieDB := trie.NewDatabase(memorydb.New()) - root, _ := FillAccounts(t, serverTrieDB, common.Hash{}, 1000, nil) + root, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 1000, nil) return memorydb.New(), serverTrieDB, root }, }, "accounts with code": { prepareForTest: func(t *testing.T) (ethdb.Database, *trie.Database, common.Hash) { serverTrieDB := trie.NewDatabase(memorydb.New()) - root, _ := FillAccounts(t, serverTrieDB, common.Hash{}, 1000, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { + root, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 1000, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { if index%3 == 0 { codeBytes := make([]byte, 256) _, err := rand.Read(codeBytes) @@ -135,7 +135,7 @@ func TestSimpleSyncCases(t *testing.T) { "accounts with storage": { prepareForTest: func(t *testing.T) (ethdb.Database, *trie.Database, common.Hash) { serverTrieDB := trie.NewDatabase(memorydb.New()) - root, _ := FillAccounts(t, serverTrieDB, common.Hash{}, 1000, func(t *testing.T, i int, account types.StateAccount) types.StateAccount { + root, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 1000, func(t *testing.T, i int, account types.StateAccount) types.StateAccount { if i%5 == 0 { account.Root, _, _ = trie.GenerateTrie(t, serverTrieDB, 16, common.HashLength) } @@ -155,7 +155,7 @@ func TestSimpleSyncCases(t *testing.T) { "failed to fetch leafs": { prepareForTest: func(t *testing.T) (ethdb.Database, *trie.Database, common.Hash) { serverTrieDB := trie.NewDatabase(memorydb.New()) - root, _ := FillAccounts(t, serverTrieDB, common.Hash{}, 100, nil) + root, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 100, nil) return memorydb.New(), serverTrieDB, root }, GetLeafsIntercept: func(_ message.LeafsRequest, _ message.LeafsResponse) (message.LeafsResponse, error) { @@ -237,7 +237,7 @@ func TestResumeSyncLargeStorageTrieInterrupted(t *testing.T) { 
serverTrieDB := trie.NewDatabase(memorydb.New()) largeStorageRoot, _, _ := trie.GenerateTrie(t, serverTrieDB, 2000, common.HashLength) - root, _ := FillAccounts(t, serverTrieDB, common.Hash{}, 2000, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { + root, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 2000, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { // Set the root for a single account if index == 10 { account.Root = largeStorageRoot @@ -275,14 +275,14 @@ func TestResumeSyncToNewRootAfterLargeStorageTrieInterrupted(t *testing.T) { largeStorageRoot1, _, _ := trie.GenerateTrie(t, serverTrieDB, 2000, common.HashLength) largeStorageRoot2, _, _ := trie.GenerateTrie(t, serverTrieDB, 2000, common.HashLength) - root1, _ := FillAccounts(t, serverTrieDB, common.Hash{}, 2000, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { + root1, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 2000, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { // Set the root for a single account if index == 10 { account.Root = largeStorageRoot1 } return account }) - root2, _ := FillAccounts(t, serverTrieDB, root1, 100, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { + root2, _ := trie.FillAccounts(t, serverTrieDB, root1, 100, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { if index == 20 { account.Root = largeStorageRoot2 } @@ -320,7 +320,7 @@ func TestResumeSyncLargeStorageTrieWithConsecutiveDuplicatesInterrupted(t *testi serverTrieDB := trie.NewDatabase(memorydb.New()) largeStorageRoot, _, _ := trie.GenerateTrie(t, serverTrieDB, 2000, common.HashLength) - root, _ := FillAccounts(t, serverTrieDB, common.Hash{}, 100, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { + root, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 100, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { // Set the root for 2 successive accounts if index == 10 || index == 11 { account.Root = largeStorageRoot @@ -357,7 +357,7 @@ func TestResumeSyncLargeStorageTrieWithSpreadOutDuplicatesInterrupted(t *testing serverTrieDB := trie.NewDatabase(memorydb.New()) largeStorageRoot, _, _ := trie.GenerateTrie(t, serverTrieDB, 2000, common.HashLength) - root, _ := FillAccounts(t, serverTrieDB, common.Hash{}, 100, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { + root, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 100, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { if index == 10 || index == 90 { account.Root = largeStorageRoot } diff --git a/sync/statesync/test_sync.go b/sync/statesync/test_sync.go index 45dc53c461..425507a8cc 100644 --- a/sync/statesync/test_sync.go +++ b/sync/statesync/test_sync.go @@ -5,12 +5,9 @@ package statesync import ( "bytes" - cryptoRand "crypto/rand" - "math/big" "math/rand" "testing" - "github.com/ava-labs/avalanchego/utils/wrappers" "github.com/ava-labs/coreth/accounts/keystore" "github.com/ava-labs/coreth/core/rawdb" "github.com/ava-labs/coreth/core/state/snapshot" @@ -92,7 +89,7 @@ func assertDBConsistency(t testing.TB, root common.Hash, serverTrieDB, clientTri } func fillAccountsWithStorage(t *testing.T, serverTrieDB *trie.Database, root common.Hash, numAccounts int) common.Hash { - newRoot, _ := FillAccounts(t, serverTrieDB, root, numAccounts, func(t *testing.T, index int, account types.StateAccount) 
types.StateAccount { + newRoot, _ := trie.FillAccounts(t, serverTrieDB, root, numAccounts, func(t *testing.T, index int, account types.StateAccount) types.StateAccount { codeBytes := make([]byte, 256) _, err := rand.Read(codeBytes) if err != nil { @@ -105,67 +102,12 @@ func fillAccountsWithStorage(t *testing.T, serverTrieDB *trie.Database, root com // now create state trie numKeys := 16 - account.Root, _, _ = trie.GenerateTrie(t, serverTrieDB, numKeys, wrappers.LongLen+1) + account.Root, _, _ = trie.GenerateTrie(t, serverTrieDB, numKeys, common.HashLength) return account }) return newRoot } -// FillAccounts adds [numAccounts] randomly generated accounts to the secure trie at [root] and commits it to [trieDB]. -// [onAccount] is called if non-nil (so the caller can modify the account before it is stored in the secure trie). -// returns the new trie root and a map of funded keys to StateAccount structs. -func FillAccounts( - t *testing.T, trieDB *trie.Database, root common.Hash, numAccounts int, - onAccount func(*testing.T, int, types.StateAccount) types.StateAccount, -) (common.Hash, map[*keystore.Key]*types.StateAccount) { - var ( - minBalance = big.NewInt(3000000000000000000) - randBalance = big.NewInt(1000000000000000000) - maxNonce = 10 - accounts = make(map[*keystore.Key]*types.StateAccount, numAccounts) - ) - - tr, err := trie.NewSecure(root, trieDB) - if err != nil { - t.Fatalf("error opening trie: %v", err) - } - - for i := 0; i < numAccounts; i++ { - acc := types.StateAccount{ - Nonce: uint64(rand.Intn(maxNonce)), - Balance: new(big.Int).Add(minBalance, randBalance), - CodeHash: types.EmptyCodeHash[:], - Root: types.EmptyRootHash, - } - if onAccount != nil { - acc = onAccount(t, i, acc) - } - - accBytes, err := rlp.EncodeToBytes(acc) - if err != nil { - t.Fatalf("failed to rlp encode account: %v", err) - } - - key, err := keystore.NewKey(cryptoRand.Reader) - if err != nil { - t.Fatal(err) - } - if err = tr.TryUpdate(key.Address[:], accBytes); err != nil { - t.Fatalf("error updating trie with account, address=%s, err=%v", key.Address, err) - } - accounts[key] = &acc - } - - newRoot, _, err := tr.Commit(nil) - if err != nil { - t.Fatalf("error committing trie: %v", err) - } - if err := trieDB.Commit(newRoot, false, nil); err != nil { - t.Fatalf("error committing trieDB: %v", err) - } - return newRoot, accounts -} - // FillAccountsWithOverlappingStorage adds [numAccounts] randomly generated accounts to the secure trie at [root] // and commits it to [trieDB]. 
For each 3 accounts created: // - One does not have a storage trie, @@ -181,7 +123,7 @@ func FillAccountsWithOverlappingStorage( storageRoots = append(storageRoots, storageRoot) } storageRootIndex := 0 - return FillAccounts(t, trieDB, root, numAccounts, func(t *testing.T, i int, account types.StateAccount) types.StateAccount { + return trie.FillAccounts(t, trieDB, root, numAccounts, func(t *testing.T, i int, account types.StateAccount) types.StateAccount { switch i % 3 { case 0: // unmodified account case 1: // account with overlapping storage root diff --git a/trie/test_trie.go b/trie/test_trie.go index 79d992b642..d464ccc790 100644 --- a/trie/test_trie.go +++ b/trie/test_trie.go @@ -4,13 +4,18 @@ package trie import ( + cryptoRand "crypto/rand" "encoding/binary" + "math/big" "math/rand" "testing" "github.com/ava-labs/avalanchego/utils/wrappers" + "github.com/ava-labs/coreth/accounts/keystore" + "github.com/ava-labs/coreth/core/types" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/rlp" "github.com/stretchr/testify/assert" ) @@ -124,3 +129,58 @@ func CorruptTrie(t *testing.T, trieDB *Database, root common.Hash, n int) { t.Fatal(err) } } + +// FillAccounts adds [numAccounts] randomly generated accounts to the secure trie at [root] and commits it to [trieDB]. +// [onAccount] is called if non-nil (so the caller can modify the account before it is stored in the secure trie). +// returns the new trie root and a map of funded keys to StateAccount structs. +func FillAccounts( + t *testing.T, trieDB *Database, root common.Hash, numAccounts int, + onAccount func(*testing.T, int, types.StateAccount) types.StateAccount, +) (common.Hash, map[*keystore.Key]*types.StateAccount) { + var ( + minBalance = big.NewInt(3000000000000000000) + randBalance = big.NewInt(1000000000000000000) + maxNonce = 10 + accounts = make(map[*keystore.Key]*types.StateAccount, numAccounts) + ) + + tr, err := NewSecure(root, trieDB) + if err != nil { + t.Fatalf("error opening trie: %v", err) + } + + for i := 0; i < numAccounts; i++ { + acc := types.StateAccount{ + Nonce: uint64(rand.Intn(maxNonce)), + Balance: new(big.Int).Add(minBalance, randBalance), + CodeHash: types.EmptyCodeHash[:], + Root: types.EmptyRootHash, + } + if onAccount != nil { + acc = onAccount(t, i, acc) + } + + accBytes, err := rlp.EncodeToBytes(acc) + if err != nil { + t.Fatalf("failed to rlp encode account: %v", err) + } + + key, err := keystore.NewKey(cryptoRand.Reader) + if err != nil { + t.Fatal(err) + } + if err = tr.TryUpdate(key.Address[:], accBytes); err != nil { + t.Fatalf("error updating trie with account, address=%s, err=%v", key.Address, err) + } + accounts[key] = &acc + } + + newRoot, _, err := tr.Commit(nil) + if err != nil { + t.Fatalf("error committing trie: %v", err) + } + if err := trieDB.Commit(newRoot, false, nil); err != nil { + t.Fatalf("error committing trieDB: %v", err) + } + return newRoot, accounts +} diff --git a/utils/bytes.go b/utils/bytes.go new file mode 100644 index 0000000000..186e3c41ef --- /dev/null +++ b/utils/bytes.go @@ -0,0 +1,18 @@ +// (c) 2021-2022, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package utils + +// IncrOne increments bytes value by one +func IncrOne(bytes []byte) { + index := len(bytes) - 1 + for index >= 0 { + if bytes[index] < 255 { + bytes[index]++ + break + } else { + bytes[index] = 0 + index-- + } + } +}
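A quick illustration of the new utils.IncrOne helper used by the resume logic above: it treats the slice as a big-endian counter, incrementing the last byte and carrying through any trailing 0xff bytes, so syncing restarts at the key immediately after the last one already committed. A minimal, self-contained sketch (the sample values are arbitrary):

package main

import (
	"fmt"

	"github.com/ava-labs/coreth/utils"
)

func main() {
	// Common case: bump the last synced key so the next request starts one past it.
	lastKey := []byte{0x00, 0x01, 0xfe}
	utils.IncrOne(lastKey)
	fmt.Printf("%x\n", lastKey) // prints 0001ff

	// Carry case: trailing 0xff bytes roll over to 0x00 and the increment moves left.
	key := []byte{0x00, 0xff, 0xff}
	utils.IncrOne(key)
	fmt.Printf("%x\n", key) // prints 010000

	// Edge case: an all-0xff slice wraps around to all zeroes.
	max := []byte{0xff, 0xff}
	utils.IncrOne(max)
	fmt.Printf("%x\n", max) // prints 0000
}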
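With FillAccounts relocated into the trie package, server-side fixtures can be assembled entirely from trie test helpers. The sketch below follows the pattern of the updated call sites; fillWithSmallStorage, the account count, and the storage-trie size are illustrative and not part of this change.

package statesync

import (
	"testing"

	"github.com/ava-labs/coreth/core/types"
	"github.com/ava-labs/coreth/trie"
	"github.com/ethereum/go-ethereum/common"
)

// fillWithSmallStorage seeds [serverTrieDB] with accounts, attaching a 16-leaf
// storage trie to every fifth account via the onAccount callback, and returns
// the resulting account trie root.
func fillWithSmallStorage(t *testing.T, serverTrieDB *trie.Database) common.Hash {
	root, _ := trie.FillAccounts(t, serverTrieDB, common.Hash{}, 500,
		func(t *testing.T, index int, account types.StateAccount) types.StateAccount {
			if index%5 == 0 {
				account.Root, _, _ = trie.GenerateTrie(t, serverTrieDB, 16, common.HashLength)
			}
			return account
		})
	return root
}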
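The TestBlockProvider and TestSnapshotProvider shims added in sync/handlers/test_providers.go let tests satisfy the handler constructors without a real blockchain or snapshot tree; the tests above pass nil where a snapshot provider is accepted. A hedged sketch of wiring a populated provider instead, assuming the second argument of NewLeafsRequestHandler is that snapshot provider (newMockClientWithSnapshot and the *statesyncclient.MockClient type name are inferred, not shown in this diff):

package statesync

import (
	"testing"

	"github.com/ava-labs/coreth/core/state/snapshot"
	"github.com/ava-labs/coreth/plugin/evm/message"
	statesyncclient "github.com/ava-labs/coreth/sync/client"
	"github.com/ava-labs/coreth/sync/handlers"
	handlerstats "github.com/ava-labs/coreth/sync/handlers/stats"
	"github.com/ava-labs/coreth/trie"
)

// newMockClientWithSnapshot mirrors the testSync setup but hands the leafs
// handler a TestSnapshotProvider so the snapshot read path can be exercised.
func newMockClientWithSnapshot(t *testing.T, serverTrieDB *trie.Database, snap *snapshot.Tree) *statesyncclient.MockClient {
	t.Helper()
	leafsHandler := handlers.NewLeafsRequestHandler(
		serverTrieDB,
		&handlers.TestSnapshotProvider{Snapshot: snap}, // instead of nil
		message.Codec,
		handlerstats.NewNoopHandlerStats(),
	)
	codeHandler := handlers.NewCodeRequestHandler(serverTrieDB.DiskDB(), message.Codec, handlerstats.NewNoopHandlerStats())
	return statesyncclient.NewMockClient(message.Codec, leafsHandler, codeHandler, nil)
}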
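Taken together, the SnapshotProvider hook and the new snapshot_* counters and timer describe the intended server flow: try to serve a leaf segment from the flat snapshot, time and count the attempt, and let the caller range-proof the segment against the requested root (recording it valid or invalid) before falling back to a trie read. The handler implementation itself is outside this hunk, so the following is only a rough sketch of that shape under those assumptions; snapshotStorageSegment and its exact signature are illustrative, not the code in this change.

package handlers

import (
	"time"

	"github.com/ava-labs/coreth/sync/handlers/stats"
	"github.com/ethereum/go-ethereum/common"
)

// snapshotStorageSegment attempts to read up to [limit] storage leaves for
// [account] from the snapshot, starting at [seek]. It returns ok=false when
// there is no snapshot or the read fails, signalling the caller to fall back
// to reading the trie directly.
func snapshotStorageSegment(
	provider SnapshotProvider,
	handlerStats stats.LeafsRequestHandlerStats,
	account common.Hash,
	seek common.Hash,
	limit uint16,
) (keys, vals [][]byte, ok bool) {
	snapshotTree := provider.Snapshots()
	if snapshotTree == nil {
		return nil, nil, false
	}

	handlerStats.IncSnapshotReadAttempt()
	start := time.Now()
	defer func() { handlerStats.UpdateSnapshotReadTime(time.Since(start)) }()

	it := snapshotTree.DiskStorageIterator(account, seek)
	defer it.Release()
	for it.Next() {
		keys = append(keys, it.Hash().Bytes())
		vals = append(vals, it.Slot())
		if len(keys) >= int(limit) {
			break
		}
	}
	if err := it.Error(); err != nil {
		handlerStats.IncSnapshotReadError()
		return nil, nil, false
	}
	handlerStats.IncSnapshotReadSuccess()
	// The caller still range-proofs this segment against the requested root and
	// records the outcome via IncSnapshotSegmentValid / IncSnapshotSegmentInvalid;
	// an invalid segment falls back to the trie.
	return keys, vals, true
}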