From ca611978c19af641c70256cd0db0f9faea121faa Mon Sep 17 00:00:00 2001 From: Diego Date: Tue, 26 Mar 2024 14:16:00 +0100 Subject: [PATCH] refactor(lib/runtime/storage): Don't rely on trie snapshots for storage transactions (#3777) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Radosvet M Co-authored-by: Kirill Co-authored-by: JimboJ <40345116+jimjbrettj@users.noreply.github.com> --- .../modules/childstate_integration_test.go | 11 +- dot/rpc/modules/childstate_test.go | 11 +- dot/state/storage_test.go | 17 +- lib/runtime/interfaces.go | 39 +- lib/runtime/storage/storagediff.go | 292 +++++++++ lib/runtime/storage/storagediff_test.go | 557 ++++++++++++++++++ lib/runtime/storage/trie.go | 358 ++++++++--- lib/runtime/storage/trie_test.go | 547 ++++++++++------- lib/runtime/wazero/imports.go | 22 +- lib/runtime/wazero/imports_test.go | 156 +++-- pkg/trie/child_storage.go | 9 +- pkg/trie/child_storage_test.go | 8 +- pkg/trie/database_test.go | 2 +- pkg/trie/layout.go | 10 + 14 files changed, 1602 insertions(+), 437 deletions(-) create mode 100644 lib/runtime/storage/storagediff.go create mode 100644 lib/runtime/storage/storagediff_test.go diff --git a/dot/rpc/modules/childstate_integration_test.go b/dot/rpc/modules/childstate_integration_test.go index ca9576ae45..3cec80fe52 100644 --- a/dot/rpc/modules/childstate_integration_test.go +++ b/dot/rpc/modules/childstate_integration_test.go @@ -247,12 +247,13 @@ func setupChildStateStorage(t *testing.T) (*ChildStateModule, common.Hash) { tr.Put([]byte(":first_key"), []byte(":value1")) tr.Put([]byte(":second_key"), []byte(":second_value")) - childTr := trie.NewEmptyTrie() - childTr.Put([]byte(":child_first"), []byte(":child_first_value")) - childTr.Put([]byte(":child_second"), []byte(":child_second_value")) - childTr.Put([]byte(":another_child"), []byte("value")) + childStorageKey := []byte(":child_storage_key") - err = tr.SetChild([]byte(":child_storage_key"), childTr) + err = tr.SetChildStorage(childStorageKey, []byte(":child_first"), []byte(":child_first_value")) + require.NoError(t, err) + err = tr.SetChildStorage(childStorageKey, []byte(":child_second"), []byte(":child_second_value")) + require.NoError(t, err) + err = tr.SetChildStorage(childStorageKey, []byte(":another_child"), []byte("value")) require.NoError(t, err) stateRoot, err := tr.Root() diff --git a/dot/rpc/modules/childstate_test.go b/dot/rpc/modules/childstate_test.go index be7ba66613..1b3913a3f4 100644 --- a/dot/rpc/modules/childstate_test.go +++ b/dot/rpc/modules/childstate_test.go @@ -27,12 +27,13 @@ func createTestTrieState(t *testing.T) (*trie.Trie, common.Hash) { tr.Put([]byte(":first_key"), []byte(":value1")) tr.Put([]byte(":second_key"), []byte(":second_value")) - childTr := trie.NewEmptyTrie() - childTr.Put([]byte(":child_first"), []byte(":child_first_value")) - childTr.Put([]byte(":child_second"), []byte(":child_second_value")) - childTr.Put([]byte(":another_child"), []byte("value")) + childStorageKey := []byte(":child_storage_key") - err := tr.SetChild([]byte(":child_storage_key"), childTr) + err := tr.SetChildStorage(childStorageKey, []byte(":child_first"), []byte(":child_first_value")) + require.NoError(t, err) + err = tr.SetChildStorage(childStorageKey, []byte(":child_second"), []byte(":child_second_value")) + require.NoError(t, err) + err = tr.SetChildStorage(childStorageKey, []byte(":another_child"), []byte("value")) require.NoError(t, err) stateRoot, err := tr.Root() diff --git a/dot/state/storage_test.go b/dot/state/storage_test.go index d3b401b714..3fdfe0ae40 100644 --- a/dot/state/storage_test.go +++ b/dot/state/storage_test.go @@ -13,7 +13,6 @@ import ( "github.com/ChainSafe/gossamer/lib/common" runtime "github.com/ChainSafe/gossamer/lib/runtime/storage" "github.com/ChainSafe/gossamer/pkg/trie" - "github.com/ChainSafe/gossamer/pkg/trie/node" "go.uber.org/mock/gomock" "github.com/stretchr/testify/require" @@ -44,10 +43,9 @@ func TestStorage_StoreAndLoadTrie(t *testing.T) { trie, err := storage.LoadFromDB(root) require.NoError(t, err) - ts2 := runtime.NewTrieState(trie) - newSnapshot := ts2.Snapshot() + ts2 := runtime.NewTrieState(trie).Trie() - require.True(t, ts.Trie().Equal(newSnapshot)) + require.True(t, trie.Equal(ts2)) } func TestStorage_GetStorageByBlockHash(t *testing.T) { @@ -183,16 +181,9 @@ func TestGetStorageChildAndGetStorageFromChild(t *testing.T) { trieDB := NewMockDatabase(ctrl) trieDB.EXPECT().Get(gomock.Any()).Times(0) - trieRoot := &node.Node{ - PartialKey: []byte{1, 2}, - StorageValue: []byte{3, 4}, - Dirty: true, - } - testChildTrie := trie.NewTrie(trieRoot, trieDB) - - testChildTrie.Put([]byte("keyInsidechild"), []byte("voila")) + genTrie.PutIntoChild([]byte("keyToChild"), []byte{1, 2}, []byte{3, 4}) + genTrie.PutIntoChild([]byte("keyToChild"), []byte("keyInsidechild"), []byte("voila")) - err = genTrie.SetChild([]byte("keyToChild"), testChildTrie) require.NoError(t, err) tries := newTriesEmpty() diff --git a/lib/runtime/interfaces.go b/lib/runtime/interfaces.go index e831fb8c9c..d8b030d15a 100644 --- a/lib/runtime/interfaces.go +++ b/lib/runtime/interfaces.go @@ -9,34 +9,53 @@ import ( "github.com/ChainSafe/gossamer/pkg/trie" ) -// Storage runtime interface. -type Storage interface { +// Trie storage interface. +type Trie interface { + Root() (common.Hash, error) Put(key []byte, value []byte) (err error) Get(key []byte) []byte - Root() (common.Hash, error) - SetChild(keyToChild []byte, child *trie.Trie) error + Delete(key []byte) (err error) + NextKey([]byte) []byte + ClearPrefix(prefix []byte) (err error) + ClearPrefixLimit(prefix []byte, limit uint32) ( + deleted uint32, allDeleted bool, err error) +} + +// ChildTrie storage interface.S +type ChildTrie interface { + GetChildRoot(keyToChild []byte) (common.Hash, error) SetChildStorage(keyToChild, key, value []byte) error GetChildStorage(keyToChild, key []byte) ([]byte, error) - Delete(key []byte) (err error) DeleteChild(keyToChild []byte) (err error) DeleteChildLimit(keyToChild []byte, limit *[]byte) ( deleted uint32, allDeleted bool, err error) ClearChildStorage(keyToChild, key []byte) error - NextKey([]byte) []byte ClearPrefixInChild(keyToChild, prefix []byte) error ClearPrefixInChildWithLimit(keyToChild, prefix []byte, limit uint32) (uint32, bool, error) GetChildNextKey(keyToChild, key []byte) ([]byte, error) - GetChild(keyToChild []byte) (*trie.Trie, error) - ClearPrefix(prefix []byte) (err error) - ClearPrefixLimit(prefix []byte, limit uint32) ( - deleted uint32, allDeleted bool, err error) +} + +// Transactional storage interface. +type Transactional interface { StartTransaction() CommitTransaction() RollbackTransaction() +} + +// Runtime storage interface. +type Runtime interface { LoadCode() []byte SetVersion(v trie.TrieLayout) } +// Storage runtime interface. +type Storage interface { + Trie + ChildTrie + Transactional + Runtime +} + // BasicNetwork interface for functions used by runtime network state function type BasicNetwork interface { NetworkState() common.NetworkState diff --git a/lib/runtime/storage/storagediff.go b/lib/runtime/storage/storagediff.go new file mode 100644 index 0000000000..11edfa9511 --- /dev/null +++ b/lib/runtime/storage/storagediff.go @@ -0,0 +1,292 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package storage + +import ( + "bytes" + "sort" + "strings" + + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + + "github.com/ChainSafe/gossamer/pkg/trie" +) + +// storageDiff is a structure that stores the differences between consecutive +// states of a trie, such as those occurring during the execution of a block. +// It records updates (upserts), deletions, and changes to child tries. +// This mechanism facilitates applying state transitions efficiently. +// Changes accumulated in storageDiff can be applied to a trie using +// the `applyToTrie` method +// Note: this structure is not thread safe, be careful +type storageDiff struct { + upserts map[string][]byte + deletes map[string]bool + childChangeSet map[string]*storageDiff +} + +// newChangeSet initialises and returns a new storageDiff instance +func newStorageDiff() *storageDiff { + return &storageDiff{ + upserts: make(map[string][]byte), + deletes: make(map[string]bool), + childChangeSet: make(map[string]*storageDiff), + } +} + +// get retrieves the value associated with the key if it's present in the +// change set and returns a boolean indicating if the key is marked for deletion +func (cs *storageDiff) get(key string) ([]byte, bool) { + if cs == nil { + return nil, false + } + + // Check in recent upserts if not found check if we want to delete it + if val, ok := cs.upserts[key]; ok { + return val, false + } else if deleted := cs.deletes[key]; deleted { + return nil, true + } + + return nil, false +} + +// upsert records a new value for the key, or updates an existing value. +// If the key was previously marked for deletion, that deletion is undone +func (cs *storageDiff) upsert(key string, value []byte) { + if cs == nil { + return + } + + // If we previously deleted this trie we have to undo that deletion + if cs.deletes[key] { + delete(cs.deletes, key) + } + + cs.upserts[key] = value +} + +// delete marks a key for deletion and removes it from upserts and +// child changesets, if present. +func (cs *storageDiff) delete(key string) { + if cs == nil { + return + } + + delete(cs.childChangeSet, key) + delete(cs.upserts, key) + cs.deletes[key] = true +} + +// deleteChildLimit deletes lexicographical sorted keys from a child trie with +// a maximum limit, potentially marking the entire child trie for deletion +// if the limit is exceeded. +// Keys created during the current block execution do not count toward the limit +// https://spec.polkadot.network/chap-host-api#id-version-2-prototype-2 +func (cs *storageDiff) deleteChildLimit(keyToChild string, + currentChildKeys []string, limit int) ( + deleted uint32, allDeleted bool) { + + childChanges := cs.childChangeSet[keyToChild] + if childChanges == nil { + childChanges = newStorageDiff() + } + + if limit == -1 { + cs.delete(keyToChild) + deletedKeys := len(childChanges.upserts) + len(currentChildKeys) + return uint32(deletedKeys), true + } + + allKeys := slices.Clone(currentChildKeys) + newKeys := maps.Keys(childChanges.upserts) + allKeys = append(allKeys, newKeys...) + sort.Strings(allKeys) + + for _, k := range allKeys { + if limit == 0 { + break + } + childChanges.delete(k) + deleted++ + // Do not consider keys created during actual block execution + if !slices.Contains(newKeys, k) { + limit-- + } + } + cs.childChangeSet[keyToChild] = childChanges + + return deleted, deleted == uint32(len(allKeys)) +} + +// clearPrefixInChild clears keys with a specific prefix within a child trie. +func (cs *storageDiff) clearPrefixInChild(keyToChild string, prefix []byte, + childKeys []string, limit int) (deleted uint32, allDeleted bool) { + childChanges := cs.childChangeSet[keyToChild] + if childChanges == nil { + childChanges = newStorageDiff() + } + deleted, allDeleted = childChanges.clearPrefix(prefix, childKeys, limit) + cs.childChangeSet[keyToChild] = childChanges + + return deleted, allDeleted +} + +// clearPrefix removes all keys matching a specified prefix, within an +// optional limit. It returns the number of keys deleted and a boolean +// indicating if all keys with the prefix were removed. +func (cs *storageDiff) clearPrefix(prefix []byte, trieKeys []string, limit int) (deleted uint32, allDeleted bool) { + allKeys := slices.Clone(trieKeys) + newKeys := maps.Keys(cs.upserts) + allKeys = append(allKeys, newKeys...) + + deleted = 0 + sort.Strings(allKeys) + for _, k := range allKeys { + if limit == 0 { + break + } + keyBytes := []byte(k) + if bytes.HasPrefix(keyBytes, prefix) { + cs.delete(k) + deleted++ + if !slices.Contains(newKeys, k) { + limit-- + } + } + } + + return deleted, deleted == uint32(len(allKeys)) +} + +// getFromChild attempts to retrieve a value associated with a specific key +// from a child trie's change set identified by keyToChild. +// It returns the value and a boolean indicating if it was marked for deletion. +func (cs *storageDiff) getFromChild(keyToChild, key string) ([]byte, bool) { + if cs == nil { + return nil, false + } + + childTrieChanges := cs.childChangeSet[keyToChild] + if childTrieChanges != nil { + value, deleted := childTrieChanges.get(key) + return value, deleted + } + + return nil, false +} + +// upsertChild inserts or updates a value associated with a key within a +// specific child trie. If the child trie or the key was previously marked for +// deletion, this marking is reversed, and the value is updated. +func (cs *storageDiff) upsertChild(keyToChild, key string, value []byte) { + if cs == nil { + return + } + + // If we previously deleted this child trie we have to undo that deletion + if cs.deletes[keyToChild] { + delete(cs.deletes, keyToChild) + } + + childChanges := cs.childChangeSet[keyToChild] + if childChanges == nil { + childChanges = newStorageDiff() + } + + childChanges.upserts[key] = value + cs.childChangeSet[keyToChild] = childChanges +} + +// deleteFromChild marks a key for deletion within a specific child trie. +func (cs *storageDiff) deleteFromChild(keyToChild, key string) { + if cs == nil { + return + } + + childChanges := cs.childChangeSet[keyToChild] + if childChanges == nil { + childChanges = newStorageDiff() + } + + childChanges.delete(key) + cs.childChangeSet[keyToChild] = childChanges +} + +// snapshot creates a deep copy of the current change set, including all upserts, +// deletions, and child trie change sets. +func (cs *storageDiff) snapshot() *storageDiff { + if cs == nil { + panic("Trying to create snapshot from nil change set") + } + + childChangeSetCopy := make(map[string]*storageDiff) + for k, v := range cs.childChangeSet { + childChangeSetCopy[k] = v.snapshot() + } + + return &storageDiff{ + upserts: maps.Clone(cs.upserts), + deletes: maps.Clone(cs.deletes), + childChangeSet: childChangeSetCopy, + } +} + +// applyToTrie applies all accumulated changes in the change set to the +// provided trie. This includes insertions, deletions, and modifications in both +// the main trie and child tries. +// In case of errors during the application of changes, the method will panic +func (cs *storageDiff) applyToTrie(t *trie.Trie) { + if cs == nil { + panic("trying to apply nil change set") + } + + // Apply trie upserts + for k, v := range cs.upserts { + err := t.Put([]byte(k), v) + if err != nil { + panic("Error applying upserts changes to trie") + } + } + + // Apply child trie upserts + for childKeyString, childChangeSet := range cs.childChangeSet { + childKey := []byte(childKeyString) + + for k, v := range childChangeSet.upserts { + err := t.PutIntoChild(childKey, []byte(k), v) + if err != nil { + panic("Error applying child trie changes to trie") + } + } + + for k := range childChangeSet.deletes { + err := t.ClearFromChild(childKey, []byte(k)) + if err != nil { + if !strings.Contains(err.Error(), trie.ErrChildTrieDoesNotExist.Error()) { + panic("Error applying child trie keys deletion to trie") + } + } + } + } + + // Apply trie deletions + for k := range cs.deletes { + key := []byte(k) + child, _ := t.GetChild(key) + if child != nil { + err := t.DeleteChild(key) + if err != nil { + panic("Error deleting child trie from trie") + } + } else { + err := t.Delete([]byte(k)) + if err != nil { + panic("Error deleting key from trie") + } + } + + } +} diff --git a/lib/runtime/storage/storagediff_test.go b/lib/runtime/storage/storagediff_test.go new file mode 100644 index 0000000000..8959f3b935 --- /dev/null +++ b/lib/runtime/storage/storagediff_test.go @@ -0,0 +1,557 @@ +// Copyright 2024 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package storage + +import ( + "testing" + + "github.com/ChainSafe/gossamer/pkg/trie" + "github.com/stretchr/testify/require" +) + +const testKey = "key" + +var testValue = []byte("value") + +func Test_MainTrie(t *testing.T) { + t.Parallel() + t.Run("get", func(t *testing.T) { + t.Parallel() + t.Run("from_empty", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + val, deleted := changes.get("test") + + require.False(t, deleted) + require.Nil(t, val) + }) + + t.Run("found", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + changes.upsert("test", []byte("test")) + + val, deleted := changes.get("test") + + require.False(t, deleted) + require.Equal(t, []byte("test"), val) + }) + + t.Run("not_found", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + changes.upsert("notfound", []byte("test")) + + val, deleted := changes.get("test") + + require.False(t, deleted) + require.Nil(t, val) + }) + }) + t.Run("upsert", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + + changes.upsert(testKey, testValue) + + val, deleted := changes.get(testKey) + require.False(t, deleted) + require.Equal(t, testValue, val) + }) + t.Run("delete", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + changes.upsert(testKey, testValue) + changes.delete(testKey) + + val, deleted := changes.get(testKey) + require.True(t, deleted) + require.Nil(t, val) + }) + t.Run("clearPrefix", func(t *testing.T) { + t.Parallel() + + testEntries := map[string][]byte{ + "pre": []byte("pre"), + "predict": []byte("predict"), + "prediction": []byte("prediction"), + } + + commonPrefix := []byte("pre") + + cases := map[string]struct { + prefix []byte + limit int + trieKeys []string + deleted uint32 + allDelted bool + }{ + "empty_trie_limit_1": { + prefix: commonPrefix, + limit: 1, + deleted: 3, // Since keys during block exec does not count + allDelted: true, + }, + "empty_trie_limit_2": { + prefix: commonPrefix, + limit: 2, + deleted: 3, // Since keys during block exec does not count + allDelted: true, + }, + "empty_trie_same_limit_than_stored_keys": { + prefix: commonPrefix, + limit: 3, + deleted: 3, + allDelted: true, + }, + "empty_trie_no_limit": { + prefix: commonPrefix, + limit: -1, + deleted: 3, + allDelted: true, + }, + "with_previous_state_not_sharing_prefix_limit_1": { + prefix: commonPrefix, + limit: 1, + trieKeys: []string{"bio"}, + deleted: 3, // Since keys during block exec does not count + allDelted: false, + }, + "with_previous_state_not_sharing_prefix_limit_2": { + prefix: commonPrefix, + limit: 2, + trieKeys: []string{"bio"}, + deleted: 3, // Since keys during block exec does not count + allDelted: false, + }, + "with_previous_state_not_sharing_prefix_limit_3": { + prefix: commonPrefix, + limit: 3, + trieKeys: []string{"bio"}, + deleted: 3, + allDelted: false, + }, + "with_previous_state_not_sharing_prefix_with_no_limit": { + prefix: commonPrefix, + limit: -1, + trieKeys: []string{"bio"}, + deleted: 3, + allDelted: false, + }, + "with_previous_state_sharing_prefix_limit_1": { + prefix: []byte("p"), + limit: 1, + trieKeys: []string{"p"}, + deleted: 1, // the "p" key only + allDelted: false, + }, + "with_previous_state_sharing_prefix_limit_2": { + prefix: []byte("p"), + limit: 2, + trieKeys: []string{"p"}, + deleted: 4, // Since keys during block exec does not count + allDelted: true, + }, + "with_previous_state_sharing_prefix_limit_3": { + prefix: []byte("p"), + limit: 3, + trieKeys: []string{"p"}, + deleted: 4, + allDelted: true, + }, + "with_previous_state_sharing_prefix_with_no_limit": { + prefix: []byte("p"), + limit: -1, + trieKeys: []string{"p"}, + deleted: 4, + allDelted: true, + }, + } + + for tname, tt := range cases { + tt := tt + t.Run(tname, func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + + for k, v := range testEntries { + changes.upsert(k, v) + } + + deleted, allDeleted := changes.clearPrefix(tt.prefix, tt.trieKeys, tt.limit) + require.Equal(t, tt.deleted, deleted) + require.Equal(t, tt.allDelted, allDeleted) + }) + } + }) +} + +func Test_ChildTrie(t *testing.T) { + t.Parallel() + t.Run("getFromChild", func(t *testing.T) { + t.Parallel() + + t.Run("empty_storage_diff", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + val, deleted := changes.getFromChild("notFound", "testChildKey") + + require.False(t, deleted) + require.Nil(t, val) + }) + + t.Run("non_existent_child", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + changes.upsertChild("testChild", "testChildKey", []byte("test")) + val, deleted := changes.getFromChild("notFound", "testChildKey") + + require.False(t, deleted) + require.Nil(t, val) + }) + + t.Run("not_found_in_child", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + changes.upsertChild("testChild", "testChildKey", []byte("test")) + val, deleted := changes.getFromChild("testChild", "notFound") + + require.False(t, deleted) + require.Nil(t, val) + }) + + t.Run("found_in_child", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + changes.upsertChild("testChild", "testChildKey", []byte("test")) + val, deleted := changes.getFromChild("testChild", "testChildKey") + + require.False(t, deleted) + require.Equal(t, []byte("test"), val) + }) + }) + + t.Run("upsertChild", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + + childkey := "testChild" + changes.upsertChild(childkey, testKey, testValue) + + val, deleted := changes.getFromChild(childkey, testKey) + + require.False(t, deleted) + require.Equal(t, testValue, val) + }) + + t.Run("deleteFromChild", func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + + childkey := "testChild" + changes.upsertChild(childkey, testKey, testValue) + changes.deleteFromChild(childkey, testKey) + + val, deleted := changes.getFromChild(childkey, testKey) + + require.True(t, deleted) + require.Nil(t, val) + }) + + t.Run("deleteChildLimit", func(t *testing.T) { + t.Parallel() + + testEntries := map[string][]byte{ + "key1": []byte("key1"), + "key2": []byte("key2"), + "key3": []byte("key3"), + } + + cases := map[string]struct { + limit int + currentChildKeys []string + deleted uint32 + allDelted bool + }{ + "empty_child_trie_limit_1": { + limit: 1, + deleted: 3, // Since keys during block exec does not count + allDelted: true, + }, + "empty_child_trie_limit_2": { + limit: 2, + deleted: 3, // Since keys during block exec does not count + allDelted: true, + }, + "empty_child_trie_same_limit_than_stored_keys": { + limit: 3, + deleted: 3, + allDelted: true, + }, + "empty_child_trie_no_limit": { + limit: -1, + deleted: 3, + allDelted: true, + }, + "with_current_child_trie_1_entry_limit_1": { + limit: 1, + currentChildKeys: []string{"currentKey1"}, + deleted: 1, // Deletes currentKey1 only + allDelted: false, + }, + "with_current_child_trie_1_entry_limit_2": { + limit: 2, + currentChildKeys: []string{"currentKey1"}, + deleted: 4, // Since keys during block exec does not count + allDelted: true, + }, + "with_current_child_trie_1_entry_limit_3": { + limit: 3, + currentChildKeys: []string{"currentKey1"}, + deleted: 4, + allDelted: true, + }, + "with_current_child_trie_with_no_limit": { + limit: -1, + currentChildKeys: []string{"currentKey1"}, + deleted: 4, + allDelted: true, + }, + } + + for tname, tt := range cases { + tt := tt + t.Run(tname, func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + + for k, v := range testEntries { + changes.upsertChild(testKey, k, v) + } + + deleted, allDeleted := changes.deleteChildLimit(testKey, tt.currentChildKeys, tt.limit) + require.Equal(t, tt.deleted, deleted) + require.Equal(t, tt.allDelted, allDeleted) + }) + } + }) + + t.Run("clearPrefixInChild", func(t *testing.T) { + t.Parallel() + + testEntries := map[string][]byte{ + "pre": []byte("pre"), + "predict": []byte("predict"), + "prediction": []byte("prediction"), + } + + commonPrefix := []byte("pre") + + cases := map[string]struct { + prefix []byte + limit int + trieKeys []string + deleted uint32 + allDelted bool + }{ + "empty_trie_limit_1": { + prefix: commonPrefix, + limit: 1, + deleted: 3, // Since keys during block exec does not count + allDelted: true, + }, + "empty_trie_limit_2": { + prefix: commonPrefix, + limit: 2, + deleted: 3, // Since keys during block exec does not count + allDelted: true, + }, + "empty_trie_same_limit_than_stored_keys": { + prefix: commonPrefix, + limit: 3, + deleted: 3, + allDelted: true, + }, + "empty_trie_no_limit": { + prefix: commonPrefix, + limit: -1, + deleted: 3, + allDelted: true, + }, + "with_previous_state_not_sharing_prefix_limit_1": { + prefix: commonPrefix, + limit: 1, + trieKeys: []string{"bio"}, + deleted: 3, // Since keys during block exec does not count + allDelted: false, + }, + "with_previous_state_not_sharing_prefix_limit_2": { + prefix: commonPrefix, + limit: 2, + trieKeys: []string{"bio"}, + deleted: 3, // Since keys during block exec does not count + allDelted: false, + }, + "with_previous_state_not_sharing_prefix_limit_3": { + prefix: commonPrefix, + limit: 3, + trieKeys: []string{"bio"}, + deleted: 3, + allDelted: false, + }, + "with_previous_state_not_sharing_prefix_with_no_limit": { + prefix: commonPrefix, + limit: -1, + trieKeys: []string{"bio"}, + deleted: 3, + allDelted: false, + }, + "with_previous_state_sharing_prefix_limit_1": { + prefix: []byte("p"), + limit: 1, + trieKeys: []string{"p"}, + deleted: 1, // the "p" key only + allDelted: false, + }, + "with_previous_state_sharing_prefix_limit_2": { + prefix: []byte("p"), + limit: 2, + trieKeys: []string{"p"}, + deleted: 4, // Since keys during block exec does not count + allDelted: true, + }, + "with_previous_state_sharing_prefix_limit_3": { + prefix: []byte("p"), + limit: 3, + trieKeys: []string{"p"}, + deleted: 4, + allDelted: true, + }, + "with_previous_state_sharing_prefix_with_no_limit": { + prefix: []byte("p"), + limit: -1, + trieKeys: []string{"p"}, + deleted: 4, + allDelted: true, + }, + } + + for tname, tt := range cases { + tt := tt + t.Run(tname, func(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + + for k, v := range testEntries { + changes.upsertChild("child", k, v) + } + + deleted, allDeleted := changes.clearPrefixInChild("child", tt.prefix, tt.trieKeys, tt.limit) + require.Equal(t, tt.deleted, deleted) + require.Equal(t, tt.allDelted, allDeleted) + }) + } + }) +} + +func Test_Snapshot(t *testing.T) { + t.Parallel() + + changes := newStorageDiff() + + changes.upsert("key1", []byte("value1")) + changes.upsert("key2", []byte("value2")) + changes.delete("key2") + changes.upsertChild("childKey", "key1", []byte("value1")) + changes.upsertChild("childKey", "key2", []byte("value2")) + + snapshot := changes.snapshot() + + require.Equal(t, changes, snapshot) +} + +func Test_ApplyToTrie(t *testing.T) { + t.Parallel() + + const key = "key1" + var value = []byte("value1") + + t.Run("add_entries_in_main_trie", func(t *testing.T) { + t.Parallel() + + state := trie.NewEmptyTrie() + + diff := newStorageDiff() + diff.upsert(key, value) + + expected := trie.NewEmptyTrie() + expected.Put([]byte(key), value) + + diff.applyToTrie(state) + require.Equal(t, expected, state) + }) + + t.Run("delete_entries_from_main_trie", func(t *testing.T) { + t.Parallel() + + state := trie.NewEmptyTrie() + state.Put([]byte(key), value) + + diff := newStorageDiff() + diff.delete(key) + + expected := trie.NewEmptyTrie() + + diff.applyToTrie(state) + require.Equal(t, expected, state) + }) + + t.Run("create_new_child_trie", func(t *testing.T) { + t.Parallel() + + state := trie.NewEmptyTrie() + + childKey := "child" + + diff := newStorageDiff() + diff.upsertChild(childKey, key, value) + + expected := trie.NewEmptyTrie() + expected.PutIntoChild([]byte(childKey), []byte(key), value) + + diff.applyToTrie(state) + require.Equal(t, expected, state) + }) + + t.Run("remove_from_child_trie", func(t *testing.T) { + t.Parallel() + + childKey := "child" + + state := trie.NewEmptyTrie() + state.PutIntoChild([]byte(childKey), []byte(key), value) + + expected := trie.NewEmptyTrie() + + diff := newStorageDiff() + diff.deleteFromChild(childKey, key) + + diff.applyToTrie(state) + require.Equal(t, expected, state) + }) +} diff --git a/lib/runtime/storage/trie.go b/lib/runtime/storage/trie.go index 27fb419569..323865661f 100644 --- a/lib/runtime/storage/trie.go +++ b/lib/runtime/storage/trie.go @@ -4,6 +4,7 @@ package storage import ( + "bytes" "container/list" "encoding/binary" "fmt" @@ -15,27 +16,37 @@ import ( "golang.org/x/exp/maps" ) -// TrieState is a wrapper around a transient trie that is used during the course of executing some runtime call. -// If the execution of the call is successful, the trie will be saved in the StorageState. +// TrieState relies on `storageDiff` to perform changes over the current state. +// It has support for transactions using "nested" storageDiff changes +// If the execution of the call is successful, the changes will be applied to +// the current `state` type TrieState struct { mtx sync.RWMutex + state *trie.Trie transactions *list.List } -func NewTrieState(state *trie.Trie) *TrieState { +// NewTrieState initialises and returns a new TrieState instance +func NewTrieState(initialState *trie.Trie) *TrieState { transactions := list.New() - transactions.PushBack(state) return &TrieState{ transactions: transactions, + state: initialState, } } -func (t *TrieState) getCurrentTrie() *trie.Trie { - return t.transactions.Back().Value.(*trie.Trie) +func (t *TrieState) getCurrentTransaction() *storageDiff { + innerTransaction := t.transactions.Back() + if innerTransaction == nil { + return nil + } + return innerTransaction.Value.(*storageDiff) } -func (t *TrieState) updateCurrentTrie(new *trie.Trie) { - t.transactions.Back().Value = new +func (t *TrieState) SetVersion(v trie.TrieLayout) { + t.mtx.Lock() + defer t.mtx.Unlock() + t.state.SetVersion(v) } // StartTransaction begins a new nested storage transaction @@ -44,37 +55,42 @@ func (t *TrieState) StartTransaction() { t.mtx.Lock() defer t.mtx.Unlock() - t.transactions.PushBack(t.getCurrentTrie().Snapshot()) + nextChangeSet := t.getCurrentTransaction() + if nextChangeSet == nil { + nextChangeSet = newStorageDiff() + } + + t.transactions.PushBack(nextChangeSet.snapshot()) } -// Rollback rolls back all storage changes made since StartTransaction was called. +// RollbackTransaction back all storage changes made since StartTransaction was called. func (t *TrieState) RollbackTransaction() { t.mtx.Lock() defer t.mtx.Unlock() - if t.transactions.Len() <= 1 { + if t.transactions.Len() < 1 { panic("no transactions to rollback") } t.transactions.Remove(t.transactions.Back()) } -// Commit commits all storage changes made since StartTransaction was called. +// CommitTransaction all storage changes made since StartTransaction was called. func (t *TrieState) CommitTransaction() { t.mtx.Lock() defer t.mtx.Unlock() - if t.transactions.Len() <= 1 { + if t.transactions.Len() == 0 { panic("no transactions to commit") } - t.transactions.Back().Prev().Value = t.transactions.Remove(t.transactions.Back()) -} - -func (t *TrieState) SetVersion(v trie.TrieLayout) { - t.mtx.Lock() - defer t.mtx.Unlock() - t.getCurrentTrie().SetVersion(v) + if t.transactions.Len() > 1 { + // We merge this transaction with its parent transaction + t.transactions.Back().Prev().Value = t.transactions.Remove(t.transactions.Back()) + } else { + // This is the last transaction so we apply all the changes to our state + t.transactions.Remove(t.transactions.Back()).(*storageDiff).applyToTrie(t.state) + } } // Trie returns the TrieState's underlying trie @@ -82,17 +98,7 @@ func (t *TrieState) Trie() *trie.Trie { t.mtx.RLock() defer t.mtx.RUnlock() - return t.getCurrentTrie() -} - -// Snapshot creates a new "version" of the trie. The trie before Snapshot is called -// can no longer be modified, all further changes are on a new "version" of the trie. -// It returns the new version of the trie. -func (t *TrieState) Snapshot() *trie.Trie { - t.mtx.RLock() - defer t.mtx.RUnlock() - - return t.getCurrentTrie().Snapshot() + return t.state } // Put puts a key-value pair in the trie @@ -100,30 +106,47 @@ func (t *TrieState) Put(key, value []byte) (err error) { t.mtx.Lock() defer t.mtx.Unlock() - return t.getCurrentTrie().Put(key, value) + // If we have running transactions we apply the change there, + // if not, we apply the changes directly on our state trie + if t.getCurrentTransaction() != nil { + t.getCurrentTransaction().upsert(string(key), value) + return nil + } else { + return t.state.Put(key, value) + } } // Get gets a value from the trie func (t *TrieState) Get(key []byte) []byte { t.mtx.RLock() defer t.mtx.RUnlock() - return t.getCurrentTrie().Get(key) + + // If we find the key or it is deleted return from latest transaction + if currentTx := t.getCurrentTransaction(); currentTx != nil { + val, deleted := currentTx.get(string(key)) + if val != nil || deleted { + return val + } + } + + // If we didn't find the key in the latest transactions lookup from state + return t.state.Get(key) } // MustRoot returns the trie's root hash. It panics if it fails to compute the root. func (t *TrieState) MustRoot() common.Hash { - t.mtx.RLock() - defer t.mtx.RUnlock() + hash, err := t.Root() + if err != nil { + panic(err) + } - return t.getCurrentTrie().MustHash() + return hash } // Root returns the trie's root hash func (t *TrieState) Root() (common.Hash, error) { - t.mtx.RLock() - defer t.mtx.RUnlock() - - return t.getCurrentTrie().Hash() + // Since the Root function is called without running transactions we can do: + return t.state.Hash() } // Has returns whether or not a key exists @@ -133,34 +156,51 @@ func (t *TrieState) Has(key []byte) bool { // Delete deletes a key from the trie func (t *TrieState) Delete(key []byte) (err error) { - val := t.getCurrentTrie().Get(key) - if val == nil { - return nil - } - t.mtx.Lock() defer t.mtx.Unlock() - err = t.getCurrentTrie().Delete(key) - if err != nil { - return fmt.Errorf("deleting from trie: %w", err) + if currentTx := t.getCurrentTransaction(); currentTx != nil { + t.getCurrentTransaction().delete(string(key)) + return nil } - return nil + return t.state.Delete(key) } // NextKey returns the next key in the trie in lexicographical order. If it does not exist, it returns nil. func (t *TrieState) NextKey(key []byte) []byte { t.mtx.RLock() defer t.mtx.RUnlock() - return t.getCurrentTrie().NextKey(key) + + if currentTx := t.getCurrentTransaction(); currentTx != nil { + allEntries := t.state.Entries() + maps.Copy(allEntries, currentTx.upserts) + + keys := maps.Keys(allEntries) + sort.Strings(keys) + + for _, k := range keys { + if k > string(key) && !currentTx.deletes[k] { + return allEntries[k] + } + } + } + + return t.state.NextKey(key) } // ClearPrefix deletes all key-value pairs from the trie where the key starts with the given prefix func (t *TrieState) ClearPrefix(prefix []byte) (err error) { t.mtx.Lock() defer t.mtx.Unlock() - return t.getCurrentTrie().ClearPrefix(prefix) + + if currentTx := t.getCurrentTransaction(); currentTx != nil { + trieKeys := t.state.Entries() + currentTx.clearPrefix(prefix, maps.Keys(trieKeys), -1) + return + } + + return t.state.ClearPrefix(prefix) } // ClearPrefixLimit deletes key-value pairs from the trie where the key starts with the given prefix till limit reached @@ -169,21 +209,36 @@ func (t *TrieState) ClearPrefixLimit(prefix []byte, limit uint32) ( t.mtx.Lock() defer t.mtx.Unlock() - return t.getCurrentTrie().ClearPrefixLimit(prefix, limit) + if currentTx := t.getCurrentTransaction(); currentTx != nil { + trieKeys := t.state.Entries() + deleted, allDeleted = currentTx.clearPrefix(prefix, maps.Keys(trieKeys), int(limit)) + return deleted, allDeleted, nil + } + + return t.state.ClearPrefixLimit(prefix, limit) } // TrieEntries returns every key-value pair in the trie func (t *TrieState) TrieEntries() map[string][]byte { t.mtx.RLock() defer t.mtx.RUnlock() - return t.getCurrentTrie().Entries() -} -// SetChild sets the child trie at the given key -func (t *TrieState) SetChild(keyToChild []byte, child *trie.Trie) error { - t.mtx.Lock() - defer t.mtx.Unlock() - return t.getCurrentTrie().SetChild(keyToChild, child) + entries := make(map[string][]byte) + + // Get entries from original trie + maps.Copy(entries, t.state.Entries()) + + if currentTx := t.getCurrentTransaction(); currentTx != nil { + // Overwrite it with last changes + maps.Copy(entries, t.getCurrentTransaction().upserts) + + // Remove deleted keys + for k := range t.getCurrentTransaction().deletes { + delete(entries, k) + } + } + + return entries } // SetChildStorage sets a key-value pair in a child trie @@ -191,15 +246,26 @@ func (t *TrieState) SetChildStorage(keyToChild, key, value []byte) error { t.mtx.Lock() defer t.mtx.Unlock() - return t.getCurrentTrie().PutIntoChild(keyToChild, key, value) + if currentTx := t.getCurrentTransaction(); currentTx != nil { + keyToChildStr := string(keyToChild) + keyString := string(key) + currentTx.upsertChild(keyToChildStr, keyString, value) + return nil + } + + return t.state.PutIntoChild(keyToChild, key, value) } -// GetChild returns the child trie at the given key -func (t *TrieState) GetChild(keyToChild []byte) (*trie.Trie, error) { +func (t *TrieState) GetChildRoot(keyToChild []byte) (common.Hash, error) { t.mtx.RLock() defer t.mtx.RUnlock() - return t.getCurrentTrie().GetChild(keyToChild) + child, err := t.state.GetChild(keyToChild) + if err != nil { + return common.EmptyHash, err + } + + return child.Hash() } // GetChildStorage returns a value from a child trie @@ -207,15 +273,28 @@ func (t *TrieState) GetChildStorage(keyToChild, key []byte) ([]byte, error) { t.mtx.RLock() defer t.mtx.RUnlock() - return t.getCurrentTrie().GetFromChild(keyToChild, key) + if currentTx := t.getCurrentTransaction(); currentTx != nil { + val, deleted := currentTx.getFromChild(string(keyToChild), string(key)) + if val != nil || deleted { + return val, nil + } + } + + // If we didnt find the key in the latest transactions lookup from state + return t.state.GetFromChild(keyToChild, key) } // DeleteChild deletes a child trie from the main trie -func (t *TrieState) DeleteChild(key []byte) (err error) { +func (t *TrieState) DeleteChild(keyToChild []byte) (err error) { t.mtx.Lock() defer t.mtx.Unlock() - return t.getCurrentTrie().DeleteChild(key) + if currentTx := t.getCurrentTransaction(); currentTx != nil { + currentTx.delete(string(keyToChild)) + return nil + } + + return t.state.DeleteChild(keyToChild) } // DeleteChildLimit deletes up to limit of database entries by lexicographic order. @@ -224,22 +303,43 @@ func (t *TrieState) DeleteChildLimit(key []byte, limit *[]byte) ( t.mtx.Lock() defer t.mtx.Unlock() - trieSnapshot := t.getCurrentTrie().Snapshot() + if currentTx := t.getCurrentTransaction(); currentTx != nil { + deleteLimit := -1 + if limit != nil { + deleteLimit = int(binary.LittleEndian.Uint32(*limit)) + } + + childKey := string(key) + + child, err := t.state.GetChild(key) + + childEntriesKeys := make([]string, 0) + if err != nil { + // If child trie does not exists and won't be created return err + if currentTx.childChangeSet[childKey] == nil { + return 0, false, err + } + } else { + childEntriesKeys = maps.Keys(child.Entries()) + } - tr, err := trieSnapshot.GetChild(key) + deleted, allDeleted = currentTx.deleteChildLimit(childKey, childEntriesKeys, deleteLimit) + return deleted, allDeleted, nil + } + + child, err := t.state.GetChild(key) if err != nil { return 0, false, err } - childTrieEntries := tr.Entries() + childTrieEntries := child.Entries() qtyEntries := uint32(len(childTrieEntries)) if limit == nil { - err = trieSnapshot.DeleteChild(key) + err = t.state.DeleteChild(key) if err != nil { return 0, false, fmt.Errorf("deleting child trie: %w", err) } - t.updateCurrentTrie(trieSnapshot) return qtyEntries, true, nil } limitUint := binary.LittleEndian.Uint32(*limit) @@ -252,7 +352,7 @@ func (t *TrieState) DeleteChildLimit(key []byte, limit *[]byte) ( // a bad intermediary state. Take also care of the caching of deleted Merkle // values within the tries, which is used for online pruning. // See https://github.com/ChainSafe/gossamer/issues/3032 - err = tr.Delete([]byte(k)) + err = child.Delete([]byte(k)) if err != nil { return deleted, allDeleted, fmt.Errorf("deleting from child trie located at key 0x%x: %w", key, err) } @@ -262,7 +362,7 @@ func (t *TrieState) DeleteChildLimit(key []byte, limit *[]byte) ( break } } - t.updateCurrentTrie(trieSnapshot) + allDeleted = deleted == qtyEntries return deleted, allDeleted, nil } @@ -271,7 +371,15 @@ func (t *TrieState) DeleteChildLimit(key []byte, limit *[]byte) ( func (t *TrieState) ClearChildStorage(keyToChild, key []byte) error { t.mtx.Lock() defer t.mtx.Unlock() - return t.getCurrentTrie().ClearFromChild(keyToChild, key) + + if currentTx := t.getCurrentTransaction(); currentTx != nil { + keyToChildStr := string(keyToChild) + keyStr := string(key) + currentTx.deleteFromChild(keyToChildStr, keyStr) + return nil + } + + return t.state.ClearFromChild(keyToChild, key) } // ClearPrefixInChild clears all the keys from the child trie that have the given prefix @@ -279,7 +387,18 @@ func (t *TrieState) ClearPrefixInChild(keyToChild, prefix []byte) error { t.mtx.Lock() defer t.mtx.Unlock() - child, err := t.getCurrentTrie().GetChild(keyToChild) + if currentTx := t.getCurrentTransaction(); currentTx != nil { + child, err := t.state.GetChild(keyToChild) + childKeys := make([]string, 0) + if err == nil { + childKeys = maps.Keys(child.Entries()) + } + + currentTx.clearPrefixInChild(string(keyToChild), prefix, childKeys, -1) + return nil + } + + child, err := t.state.GetChild(keyToChild) if err != nil { return err } @@ -299,7 +418,18 @@ func (t *TrieState) ClearPrefixInChildWithLimit(keyToChild, prefix []byte, limit t.mtx.Lock() defer t.mtx.Unlock() - child, err := t.getCurrentTrie().GetChild(keyToChild) + if currentTx := t.getCurrentTransaction(); currentTx != nil { + child, err := t.state.GetChild(keyToChild) + childKeys := make([]string, 0) + if err == nil { + childKeys = maps.Keys(child.Entries()) + } + + deleted, allDeleted := currentTx.clearPrefixInChild(string(keyToChild), prefix, childKeys, int(limit)) + return deleted, allDeleted, nil + } + + child, err := t.state.GetChild(keyToChild) if err != nil || child == nil { return 0, false, err } @@ -311,19 +441,88 @@ func (t *TrieState) ClearPrefixInChildWithLimit(keyToChild, prefix []byte, limit func (t *TrieState) GetChildNextKey(keyToChild, key []byte) ([]byte, error) { t.mtx.RLock() defer t.mtx.RUnlock() - child, err := t.getCurrentTrie().GetChild(keyToChild) + + if currentTx := t.getCurrentTransaction(); currentTx != nil { + // If we are going to delete this child we return error + if currentTx.deletes[string(keyToChild)] { + return nil, trie.ErrChildTrieDoesNotExist + } + + if childChanges := currentTx.childChangeSet[string(keyToChild)]; childChanges != nil { + allEntries := make(map[string][]byte) + + maps.Copy(allEntries, childChanges.upserts) + child, err := t.state.GetChild(keyToChild) + if err != nil { + // Child trie does not exists and won't exists in the future + if len(allEntries) == 0 { + return nil, err + } + } else { + allEntries = child.Entries() + } + keys := maps.Keys(allEntries) + sort.Strings(keys) + + for _, k := range keys { + if k > string(key) && !childChanges.deletes[k] { + return allEntries[k], nil + } + } + return nil, nil + } + } + + child, err := t.state.GetChild(keyToChild) if err != nil { return nil, err } if child == nil { return nil, nil } + return child.NextKey(key), nil } // GetKeysWithPrefixFromChild ... func (t *TrieState) GetKeysWithPrefixFromChild(keyToChild, prefix []byte) ([][]byte, error) { - child, err := t.GetChild(keyToChild) + t.mtx.RLock() + defer t.mtx.RUnlock() + + if currentTx := t.getCurrentTransaction(); currentTx != nil { + // If we are going to delete this child we return error + if currentTx.deletes[string(keyToChild)] { + return nil, trie.ErrChildTrieDoesNotExist + } + + if childChanges := currentTx.childChangeSet[string(keyToChild)]; childChanges != nil { + allEntries := make(map[string][]byte) + + maps.Copy(allEntries, childChanges.upserts) + child, err := t.state.GetChild(keyToChild) + if err != nil { + // Child trie does not exists and won't exists in the future + if len(allEntries) == 0 { + return nil, err + } + } else { + allEntries = child.Entries() + } + keys := maps.Keys(allEntries) + + values := make([][]byte, 0) + + for _, k := range keys { + if bytes.HasPrefix([]byte(k), prefix) { + values = append(values, allEntries[k]) + } + } + + return values, nil + } + } + + child, err := t.state.GetChild(keyToChild) if err != nil { return nil, err } @@ -349,5 +548,6 @@ func (t *TrieState) LoadCodeHash() (common.Hash, error) { func (t *TrieState) GetChangedNodeHashes() (inserted, deleted map[common.Hash]struct{}, err error) { t.mtx.RLock() defer t.mtx.RUnlock() - return t.getCurrentTrie().GetChangedNodeHashes() + + return t.state.GetChangedNodeHashes() } diff --git a/lib/runtime/storage/trie_test.go b/lib/runtime/storage/trie_test.go index da797731e7..0696fa08b1 100644 --- a/lib/runtime/storage/trie_test.go +++ b/lib/runtime/storage/trie_test.go @@ -4,10 +4,8 @@ package storage import ( - "bytes" "encoding/binary" "fmt" - "sort" "testing" "github.com/ChainSafe/gossamer/lib/common" @@ -24,202 +22,380 @@ var testCases = []string{ "bnm", } -func TestTrieState_SetGet(t *testing.T) { - testFunc := func(ts *TrieState) { - for _, tc := range testCases { - ts.Put([]byte(tc), []byte(tc)) - } +func TestTrieState_WithAndWithoutTransactions(t *testing.T) { + t.Parallel() - for _, tc := range testCases { - res := ts.Get([]byte(tc)) - require.Equal(t, []byte(tc), res) - } + prefixedKeys := [][]byte{ + []byte("noot"), + []byte("noodle"), + []byte("other"), } - - ts := NewTrieState(trie.NewEmptyTrie()) - testFunc(ts) -} - -func TestTrieState_SetGetChildStorage(t *testing.T) { - ts := NewTrieState(trie.NewEmptyTrie()) - - for _, tc := range testCases { - childTrie := trie.NewEmptyTrie() - err := ts.SetChild([]byte(tc), childTrie) - require.NoError(t, err) - - err = ts.SetChildStorage([]byte(tc), []byte(tc), []byte(tc)) - require.NoError(t, err) + sortedKeys := [][]byte{ + []byte("key1"), + []byte("key2"), + []byte("key3"), } - for _, tc := range testCases { - res, err := ts.GetChildStorage([]byte(tc), []byte(tc)) - require.NoError(t, err) - require.Equal(t, []byte(tc), res) - } -} - -func TestTrieState_SetAndClearFromChild(t *testing.T) { - testFunc := func(ts *TrieState) { - for _, tc := range testCases { - childTrie := trie.NewEmptyTrie() - err := ts.SetChild([]byte(tc), childTrie) - require.NoError(t, err) - - err = ts.SetChildStorage([]byte(tc), []byte(tc), []byte(tc)) - require.NoError(t, err) - } - - for _, tc := range testCases { - err := ts.ClearChildStorage([]byte(tc), []byte(tc)) - require.NoError(t, err) - - _, err = ts.GetChildStorage([]byte(tc), []byte(tc)) - require.ErrorContains(t, err, "child trie does not exist at key") - } - } - - ts := NewTrieState(trie.NewEmptyTrie()) - testFunc(ts) -} + keyToChild := []byte("keytochild") -func TestTrieState_Delete(t *testing.T) { - testFunc := func(ts *TrieState) { - for _, tc := range testCases { - ts.Put([]byte(tc), []byte(tc)) - } + cases := map[string]struct { + changes func(t *testing.T, ts *TrieState) + checks func(t *testing.T, ts *TrieState, isTransactionRunning bool) + }{ + "set_get": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range testCases { + err := ts.Put([]byte(tc), []byte(tc)) + require.NoError(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + for _, tc := range testCases { + res := ts.Get([]byte(tc)) + require.Equal(t, []byte(tc), res) + } + }, + }, + "set_child_storage": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range testCases { + err := ts.SetChildStorage([]byte(tc), []byte(tc), []byte(tc)) + require.NoError(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + for _, tc := range testCases { + res, err := ts.GetChildStorage([]byte(tc), []byte(tc)) + require.NoError(t, err) + require.Equal(t, []byte(tc), res) + } + }, + }, + "set_and_clear_from_child": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range testCases { + err := ts.SetChildStorage([]byte(tc), []byte(tc), []byte(tc)) + require.NoError(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, isTransactionRunning bool) { + for _, tc := range testCases { + err := ts.ClearChildStorage([]byte(tc), []byte(tc)) + require.NoError(t, err) - ts.Delete([]byte(testCases[0])) - has := ts.Has([]byte(testCases[0])) - require.False(t, has) - } + val, err := ts.GetChildStorage([]byte(tc), []byte(tc)) - ts := NewTrieState(trie.NewEmptyTrie()) - testFunc(ts) -} + require.Nil(t, val) -func TestTrieState_Root(t *testing.T) { - testFunc := func(ts *TrieState) { - for _, tc := range testCases { - ts.Put([]byte(tc), []byte(tc)) - } + if isTransactionRunning { + require.Nil(t, err) + } else { + require.ErrorContains(t, err, "child trie does not exist at key") + } + } + }, + }, + "delete": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range testCases { + ts.Put([]byte(tc), []byte(tc)) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + ts.Delete([]byte(testCases[0])) + has := ts.Has([]byte(testCases[0])) + require.False(t, has) + }, + }, + "delete_child": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range prefixedKeys { + ts.SetChildStorage(keyToChild, tc, tc) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + err := ts.DeleteChild(keyToChild) + require.Nil(t, err) - expected := ts.MustRoot() - require.Equal(t, expected, ts.MustRoot()) - } + root, err := ts.GetChildStorage(keyToChild, prefixedKeys[0]) + require.NotNil(t, err) + require.Nil(t, root) + }, + }, + "clear_prefix": { + changes: func(t *testing.T, ts *TrieState) { + for i, key := range prefixedKeys { + err := ts.Put(key, []byte{byte(i)}) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + err := ts.ClearPrefix([]byte("noo")) + require.Nil(t, err) + + for i, key := range prefixedKeys { + val := ts.Get(key) + if i < 2 { + require.Nil(t, val) + } else { + require.NotNil(t, val) + } + } + }, + }, + "clear_prefix_with_limit_1": { + changes: func(t *testing.T, ts *TrieState) { + for i, key := range prefixedKeys { + err := ts.Put(key, []byte{byte(i)}) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, isTransactionRunning bool) { + deleted, allDeleted, err := ts.ClearPrefixLimit([]byte("noo"), uint32(1)) + require.Nil(t, err) + + if isTransactionRunning { + // New keys are not considered towards the limit + require.Equal(t, uint32(2), deleted) + require.False(t, allDeleted) + } else { + require.Equal(t, uint32(1), deleted) + require.False(t, allDeleted) + } + }, + }, + "clear_prefix_in_child": { + changes: func(t *testing.T, ts *TrieState) { + for i, key := range prefixedKeys { + err := ts.SetChildStorage(keyToChild, key, []byte{byte(i)}) + require.NoError(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + err := ts.ClearPrefixInChild(keyToChild, []byte("noo")) + require.NoError(t, err) + + for i, key := range prefixedKeys { + val, err := ts.GetChildStorage(keyToChild, key) + require.NoError(t, err) + if i < 2 { + require.Nil(t, val) + } else { + require.NotNil(t, val) + } + } + }, + }, + "clear_prefix_in_child_with_limit_1": { + changes: func(t *testing.T, ts *TrieState) { + for i, key := range prefixedKeys { + err := ts.SetChildStorage(keyToChild, key, []byte{byte(i)}) + require.NoError(t, err) + } - ts := NewTrieState(trie.NewEmptyTrie()) - testFunc(ts) -} + }, + checks: func(t *testing.T, ts *TrieState, isTransactionRunning bool) { + deleted, allDeleted, err := ts.ClearPrefixInChildWithLimit(keyToChild, []byte("noo"), uint32(1)) -func TestTrieState_ClearPrefix(t *testing.T) { - ts := NewTrieState(trie.NewEmptyTrie()) + require.NoError(t, err) + require.False(t, allDeleted) - keys := []string{ - "noot", - "noodle", - "other", - } + if isTransactionRunning { + require.Equal(t, uint32(2), deleted) + } else { + require.Equal(t, uint32(1), deleted) + } + }, + }, + "delete_child_limit_child_not_exists": { + changes: func(t *testing.T, ts *TrieState) { + for i, key := range sortedKeys { + err := ts.SetChildStorage(keyToChild, key, []byte{byte(i)}) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, isTransactionRunning bool) { + testLimitBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(testLimitBytes, uint32(2)) + optLimit2 := &testLimitBytes - for i, key := range keys { - ts.Put([]byte(key), []byte{byte(i)}) - } + errMsg := fmt.Sprintf("child trie does not exist at key 0x%x", ":child_storage:default:fakekey") - ts.ClearPrefix([]byte("noo")) + _, _, err := ts.DeleteChildLimit([]byte("fakekey"), optLimit2) + require.Error(t, err) + require.EqualError(t, err, errMsg) - for i, key := range keys { - val := ts.Get([]byte(key)) - if i < 2 { - require.Nil(t, val) - } else { - require.NotNil(t, val) - } - } -} + }, + }, + "delete_child_limit_with_limit": { + changes: func(t *testing.T, ts *TrieState) { + for i, key := range sortedKeys { + err := ts.SetChildStorage(keyToChild, key, []byte{byte(i)}) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, isTransactionRunning bool) { + testLimitBytes := make([]byte, 4) + binary.LittleEndian.PutUint32(testLimitBytes, uint32(2)) + optLimit2 := &testLimitBytes + + deleted, all, err := ts.DeleteChildLimit(keyToChild, optLimit2) + require.NoError(t, err) + + if isTransactionRunning { + require.Equal(t, uint32(3), deleted) + require.Equal(t, true, all) + } else { + require.Equal(t, uint32(2), deleted) + require.Equal(t, false, all) + } + }, + }, + "delete_child_limit_nil": { + changes: func(t *testing.T, ts *TrieState) { + for i, key := range sortedKeys { + err := ts.SetChildStorage(keyToChild, key, []byte{byte(i)}) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, isTransactionRunning bool) { + deleted, all, err := ts.DeleteChildLimit(keyToChild, nil) -func TestTrieState_ClearPrefixInChild(t *testing.T) { - ts := NewTrieState(trie.NewEmptyTrie()) - child := trie.NewEmptyTrie() + require.Nil(t, err) + require.Equal(t, uint32(3), deleted) + require.Equal(t, true, all) + }, + }, + "next_key": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range sortedKeys { + err := ts.Put(tc, tc) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + for i, tc := range sortedKeys { + next := ts.NextKey(tc) + if i == len(sortedKeys)-1 { + require.Nil(t, next) + } else { + require.Equal(t, sortedKeys[i+1], next, common.BytesToHex(tc)) + } + } + }, + }, + "child_next_key": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range sortedKeys { + err := ts.SetChildStorage(keyToChild, tc, tc) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + for i, tc := range sortedKeys { + next, err := ts.GetChildNextKey(keyToChild, tc) + require.Nil(t, err) + + if i == len(sortedKeys)-1 { + require.Nil(t, next) + } else { + require.Equal(t, sortedKeys[i+1], next, common.BytesToHex(tc)) + } + } + }, + }, + "entries": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range testCases { + err := ts.Put([]byte(tc), []byte(tc)) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + entries := ts.TrieEntries() + require.Len(t, entries, len(testCases)) - keys := []string{ - "noot", - "noodle", - "other", - } + for _, tc := range testCases { + require.Contains(t, entries, tc) + } + }, + }, + "get_keys_with_prefix_from_child": { + changes: func(t *testing.T, ts *TrieState) { + for _, tc := range prefixedKeys { + err := ts.SetChildStorage(keyToChild, tc, tc) + require.Nil(t, err) + } + }, + checks: func(t *testing.T, ts *TrieState, _ bool) { + values, err := ts.GetKeysWithPrefixFromChild(keyToChild, []byte("noo")) - for i, key := range keys { - child.Put([]byte(key), []byte{byte(i)}) + require.Nil(t, err) + require.Len(t, values, 2) + require.Contains(t, values, []byte("noot")) + require.Contains(t, values, []byte("noodle")) + }, + }, } - keyToChild := []byte("keytochild") + for tname, tt := range cases { + tt := tt + t.Run(tname, func(t *testing.T) { + t.Parallel() + t.Run("without_transactions", func(t *testing.T) { + t.Parallel() - err := ts.SetChild(keyToChild, child) - require.NoError(t, err) + ts := NewTrieState(trie.NewEmptyTrie()) + tt.changes(t, ts) + tt.checks(t, ts, false) + }) - err = ts.ClearPrefixInChild(keyToChild, []byte("noo")) - require.NoError(t, err) + t.Run("during_transaction", func(t *testing.T) { + t.Parallel() - for i, key := range keys { - val, err := ts.GetChildStorage(keyToChild, []byte(key)) - require.NoError(t, err) - if i < 2 { - require.Nil(t, val) - } else { - require.NotNil(t, val) - } - } -} + ts := NewTrieState(trie.NewEmptyTrie()) + ts.StartTransaction() + tt.changes(t, ts) + tt.checks(t, ts, true) + ts.CommitTransaction() + }) -func TestTrieState_NextKey(t *testing.T) { - ts := NewTrieState(trie.NewEmptyTrie()) - for _, tc := range testCases { - ts.Put([]byte(tc), []byte(tc)) - } + t.Run("after_transaction_committed", func(t *testing.T) { + t.Parallel() - sort.Slice(testCases, func(i, j int) bool { - return bytes.Compare([]byte(testCases[i]), []byte(testCases[j])) == -1 - }) - - for i, tc := range testCases { - next := ts.NextKey([]byte(tc)) - if i == len(testCases)-1 { - require.Nil(t, next) - } else { - require.Equal(t, []byte(testCases[i+1]), next, common.BytesToHex([]byte(tc))) - } + ts := NewTrieState(trie.NewEmptyTrie()) + ts.StartTransaction() + tt.changes(t, ts) + ts.CommitTransaction() + tt.checks(t, ts, false) + }) + }) } } -func TestTrieState_CommitStorageTransaction(t *testing.T) { +func TestTrieState_Root(t *testing.T) { ts := NewTrieState(trie.NewEmptyTrie()) for _, tc := range testCases { ts.Put([]byte(tc), []byte(tc)) } - ts.StartTransaction() - testValue := []byte("noot") - ts.Put([]byte(testCases[0]), testValue) - ts.CommitTransaction() - - val := ts.Get([]byte(testCases[0])) - require.Equal(t, testValue, val) + expected := ts.MustRoot() + require.Equal(t, expected, ts.MustRoot()) } -func TestTrieState_RollbackStorageTransaction(t *testing.T) { +func TestTrieState_ChildRoot(t *testing.T) { ts := NewTrieState(trie.NewEmptyTrie()) + keyToChild := []byte("child") + for _, tc := range testCases { - ts.Put([]byte(tc), []byte(tc)) + ts.SetChildStorage(keyToChild, []byte(tc), []byte(tc)) } - ts.StartTransaction() - testValue := []byte("noot") - ts.Put([]byte(testCases[0]), testValue) - ts.RollbackTransaction() - - val := ts.Get([]byte(testCases[0])) - require.Equal(t, []byte(testCases[0]), val) + root, err := ts.GetChildRoot(keyToChild) + require.Nil(t, err) + require.NotNil(t, root) } func TestTrieState_NestedTransactions(t *testing.T) { @@ -258,7 +434,7 @@ func TestTrieState_NestedTransactions(t *testing.T) { require.NotNil(t, ts.Get([]byte("key-3"))) require.Nil(t, ts.Get([]byte("key-4"))) - require.Equal(t, 1, ts.transactions.Len()) + require.Equal(t, 0, ts.transactions.Len()) }, }, "committing_all_nested_transactions": { @@ -295,7 +471,7 @@ func TestTrieState_NestedTransactions(t *testing.T) { require.NotNil(t, ts.Get([]byte("key-1"))) require.NotNil(t, ts.Get([]byte("key-2"))) require.NotNil(t, ts.Get([]byte("key-4"))) - require.Equal(t, 1, ts.transactions.Len()) + require.Equal(t, 0, ts.transactions.Len()) }, }, "rollback_without_transaction_should_panic": { @@ -324,56 +500,3 @@ func TestTrieState_NestedTransactions(t *testing.T) { }) } } - -func TestTrieState_DeleteChildLimit(t *testing.T) { - ts := NewTrieState(trie.NewEmptyTrie()) - child := trie.NewEmptyTrie() - - keys := []string{ - "key3", - "key1", - "key2", - } - - for i, key := range keys { - child.Put([]byte(key), []byte{byte(i)}) - } - - keyToChild := []byte("keytochild") - - err := ts.SetChild(keyToChild, child) - require.NoError(t, err) - - testLimitBytes := make([]byte, 4) - binary.LittleEndian.PutUint32(testLimitBytes, uint32(2)) - optLimit2 := &testLimitBytes - - testCases := []struct { - key []byte - limit *[]byte - expectedDeleted uint32 - expectedDelAll bool - errMsg string - }{ - { - key: []byte("fakekey"), - limit: optLimit2, - expectedDeleted: 0, - expectedDelAll: false, - errMsg: fmt.Sprintf("child trie does not exist at key 0x%x", ":child_storage:default:fakekey"), - }, - {key: []byte("keytochild"), limit: optLimit2, expectedDeleted: 2, expectedDelAll: false}, - {key: []byte("keytochild"), limit: nil, expectedDeleted: 1, expectedDelAll: true}, - } - for _, test := range testCases { - deleted, all, err := ts.DeleteChildLimit(test.key, test.limit) - if test.errMsg != "" { - require.Error(t, err) - require.EqualError(t, err, test.errMsg) - continue - } - require.NoError(t, err) - require.Equal(t, test.expectedDeleted, deleted) - require.Equal(t, test.expectedDelAll, all) - } -} diff --git a/lib/runtime/wazero/imports.go b/lib/runtime/wazero/imports.go index 4390e14457..3a399f7ebf 100644 --- a/lib/runtime/wazero/imports.go +++ b/lib/runtime/wazero/imports.go @@ -1241,17 +1241,12 @@ func ext_default_child_storage_root_version_1( panic("nil runtime context") } storage := rtCtx.Storage - child, err := storage.GetChild(read(m, childStorageKey)) - if err != nil { - logger.Errorf("failed to retrieve child: %s", err) - return 0 - } - - childRoot, err := trie.V0.Hash(child) + childRoot, err := storage.GetChildRoot(read(m, childStorageKey)) if err != nil { logger.Errorf("failed to encode child root: %s", err) return 0 } + childRootSlice := childRoot[:] ret, err := write(m, rtCtx.Allocator, scale.MustMarshal(&childRootSlice)) @@ -1270,19 +1265,8 @@ func ext_default_child_storage_root_version_2(ctx context.Context, m api.Module, } storage := rtCtx.Storage key := read(m, childStorageKey) - child, err := storage.GetChild(key) - if err != nil { - logger.Errorf("failed to retrieve child: %s", err) - return mustWrite(m, rtCtx.Allocator, emptyByteVectorEncoded) - } - - stateVersion, err := trie.ParseVersion(uint8(version)) - if err != nil { - logger.Errorf("failed parsing state version: %s", err) - return 0 - } - childRoot, err := stateVersion.Hash(child) + childRoot, err := storage.GetChildRoot(key) if err != nil { logger.Errorf("failed to encode child root: %s", err) return mustWrite(m, rtCtx.Allocator, emptyByteVectorEncoded) diff --git a/lib/runtime/wazero/imports_test.go b/lib/runtime/wazero/imports_test.go index 5c857a89c4..e731e6dbb0 100644 --- a/lib/runtime/wazero/imports_test.go +++ b/lib/runtime/wazero/imports_test.go @@ -934,10 +934,7 @@ func Test_ext_default_child_storage_read_version_1(t *testing.T) { setupInstance: func(t *testing.T) *Instance { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - - err = inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) + err := inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) require.NoError(t, err) return inst }, @@ -995,7 +992,7 @@ func Test_ext_default_child_storage_set_version_1(t *testing.T) { setupInstance: func(t *testing.T) *Instance { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) + err := inst.Context.Storage.SetChildStorage(testChildKey, []byte("exists"), []byte("exists")) require.NoError(t, err) return inst @@ -1074,10 +1071,7 @@ func Test_ext_default_child_storage_set_version_1(t *testing.T) { func Test_ext_default_child_storage_clear_version_1(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - - err = inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) + err := inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) require.NoError(t, err) // Confirm if value is set @@ -1114,11 +1108,8 @@ func Test_ext_default_child_storage_clear_prefix_version_1(t *testing.T) { {[]byte("keyThree"), []byte("value3")}, } - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - for _, kv := range testKeyValuePair { - err = inst.Context.Storage.SetChildStorage(testChildKey, kv.key, kv.value) + err := inst.Context.Storage.SetChildStorage(testChildKey, kv.key, kv.value) require.NoError(t, err) } @@ -1144,10 +1135,7 @@ func Test_ext_default_child_storage_clear_prefix_version_1(t *testing.T) { func Test_ext_default_child_storage_exists_version_1(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - - err = inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) + err := inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) require.NoError(t, err) encChildKey, err := scale.Marshal(testChildKey) @@ -1173,10 +1161,8 @@ func Test_ext_default_child_storage_get_version_1(t *testing.T) { "value_exists_expected_value": { setupInstance: func(t *testing.T) *Instance { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - err = inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) + err := inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) require.NoError(t, err) return inst }, @@ -1231,11 +1217,8 @@ func Test_ext_default_child_storage_next_key_version_1(t *testing.T) { setupInstance: func(t *testing.T) *Instance { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - for _, kv := range testKeyValuePair { - err = inst.Context.Storage.SetChildStorage(testChildKey, kv.key, kv.value) + err := inst.Context.Storage.SetChildStorage(testChildKey, kv.key, kv.value) require.NoError(t, err) } @@ -1253,11 +1236,8 @@ func Test_ext_default_child_storage_next_key_version_1(t *testing.T) { setupInstance: func(t *testing.T) *Instance { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - kv := testKeyValuePair[0] - err = inst.Context.Storage.SetChildStorage(testChildKey, kv.key, kv.value) + err := inst.Context.Storage.SetChildStorage(testChildKey, kv.key, kv.value) require.NoError(t, err) return inst @@ -1295,18 +1275,10 @@ func Test_ext_default_child_storage_next_key_version_1(t *testing.T) { func Test_ext_default_child_storage_root_version_1(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - - err = inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) + err := inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) require.NoError(t, err) - child, err := inst.Context.Storage.GetChild(testChildKey) - require.NoError(t, err) - - stateVersion := trie.V0 - - rootHash, err := stateVersion.Hash(child) + rootHash, err := inst.Context.Storage.GetChildRoot(testChildKey) require.NoError(t, err) encChildKey, err := scale.Marshal(testChildKey) @@ -1331,16 +1303,10 @@ func Test_ext_default_child_storage_root_version_2(t *testing.T) { stateVersion := trie.V1 - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - - err = inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) + err := inst.Context.Storage.SetChildStorage(testChildKey, testKey, testValue) require.NoError(t, err) - child, err := inst.Context.Storage.GetChild(testChildKey) - require.NoError(t, err) - - rootHash, err := stateVersion.Hash(child) + rootHash, err := inst.Context.Storage.GetChildRoot(testChildKey) require.NoError(t, err) encChildKey, err := scale.Marshal(testChildKey) @@ -1368,13 +1334,13 @@ func Test_ext_default_child_storage_root_version_2(t *testing.T) { func Test_ext_default_child_storage_storage_kill_version_1(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) + err := inst.Context.Storage.SetChildStorage(testChildKey, []byte("test"), []byte("test")) require.NoError(t, err) // Confirm if value is set - child, err := inst.Context.Storage.GetChild(testChildKey) + value, err := inst.Context.Storage.GetChildStorage(testChildKey, []byte("test")) require.NoError(t, err) - require.NotNil(t, child) + require.Equal(t, []byte("test"), value) encChildKey, err := scale.Marshal(testChildKey) require.NoError(t, err) @@ -1382,23 +1348,28 @@ func Test_ext_default_child_storage_storage_kill_version_1(t *testing.T) { _, err = inst.Exec("rtm_ext_default_child_storage_storage_kill_version_1", encChildKey) require.NoError(t, err) - child, _ = inst.Context.Storage.GetChild(testChildKey) - require.Nil(t, child) + value, err = inst.Context.Storage.GetChildStorage(testChildKey, []byte("test")) + require.NotNil(t, err) + require.Nil(t, value) } func Test_ext_default_child_storage_storage_kill_version_2_limit_all(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - tr := trie.NewEmptyTrie() - tr.Put([]byte(`key2`), []byte(`value2`)) - tr.Put([]byte(`key1`), []byte(`value1`)) - err := inst.Context.Storage.SetChild(testChildKey, tr) + err := inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key2`), []byte(`value2`)) + require.NoError(t, err) + + err = inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key1`), []byte(`value1`)) require.NoError(t, err) // Confirm if value is set - child, err := inst.Context.Storage.GetChild(testChildKey) + value, err := inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key1`)) + require.NoError(t, err) + require.NotNil(t, value) + + value, err = inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key2`)) require.NoError(t, err) - require.NotNil(t, child) + require.NotNil(t, value) encChildKey, err := scale.Marshal(testChildKey) require.NoError(t, err) @@ -1414,24 +1385,32 @@ func Test_ext_default_child_storage_storage_kill_version_2_limit_all(t *testing. require.NoError(t, err) require.Equal(t, []byte{1, 0, 0, 0}, res) - child, err = inst.Context.Storage.GetChild(testChildKey) + value, err = inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key1`)) + require.NoError(t, err) + require.Nil(t, value) + + value, err = inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key2`)) require.NoError(t, err) - require.Empty(t, child.Entries()) + require.Nil(t, value) } func Test_ext_default_child_storage_storage_kill_version_2_limit_1(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - tr := trie.NewEmptyTrie() - tr.Put([]byte(`key2`), []byte(`value2`)) - tr.Put([]byte(`key1`), []byte(`value1`)) - err := inst.Context.Storage.SetChild(testChildKey, tr) + err := inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key2`), []byte(`value2`)) + require.NoError(t, err) + + err = inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key1`), []byte(`value1`)) require.NoError(t, err) // Confirm if value is set - child, err := inst.Context.Storage.GetChild(testChildKey) + value, err := inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key1`)) require.NoError(t, err) - require.NotNil(t, child) + require.NotNil(t, value) + + value, err = inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key2`)) + require.NoError(t, err) + require.NotNil(t, value) encChildKey, err := scale.Marshal(testChildKey) require.NoError(t, err) @@ -1447,24 +1426,32 @@ func Test_ext_default_child_storage_storage_kill_version_2_limit_1(t *testing.T) require.NoError(t, err) require.Equal(t, []byte{0, 0, 0, 0}, res) - child, err = inst.Context.Storage.GetChild(testChildKey) + value, err = inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key1`)) require.NoError(t, err) - require.Equal(t, 1, len(child.Entries())) + require.Nil(t, value) + + value, err = inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key2`)) + require.NoError(t, err) + require.NotNil(t, value) } func Test_ext_default_child_storage_storage_kill_version_2_limit_none(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - tr := trie.NewEmptyTrie() - tr.Put([]byte(`key2`), []byte(`value2`)) - tr.Put([]byte(`key1`), []byte(`value1`)) - err := inst.Context.Storage.SetChild(testChildKey, tr) + err := inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key2`), []byte(`value2`)) + require.NoError(t, err) + + err = inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key1`), []byte(`value1`)) require.NoError(t, err) // Confirm if value is set - child, err := inst.Context.Storage.GetChild(testChildKey) + value, err := inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key1`)) require.NoError(t, err) - require.NotNil(t, child) + require.NotNil(t, value) + + value, err = inst.Context.Storage.GetChildStorage(testChildKey, []byte(`key2`)) + require.NoError(t, err) + require.NotNil(t, value) encChildKey, err := scale.Marshal(testChildKey) require.NoError(t, err) @@ -1477,19 +1464,21 @@ func Test_ext_default_child_storage_storage_kill_version_2_limit_none(t *testing require.NoError(t, err) require.Equal(t, []byte{1, 0, 0, 0}, res) - child, err = inst.Context.Storage.GetChild(testChildKey) - require.Error(t, err) - require.Nil(t, child) + hash, err := inst.Context.Storage.GetChildRoot(testChildKey) + require.Error(t, err, trie.ErrChildTrieDoesNotExist) + require.Equal(t, common.EmptyHash, hash) } func Test_ext_default_child_storage_storage_kill_version_3(t *testing.T) { inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME, TestWithVersion(DefaultVersion)) - tr := trie.NewEmptyTrie() - tr.Put([]byte(`key2`), []byte(`value2`)) - tr.Put([]byte(`key1`), []byte(`value1`)) - tr.Put([]byte(`key3`), []byte(`value3`)) - err := inst.Context.Storage.SetChild(testChildKey, tr) + err := inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key2`), []byte(`value2`)) + require.NoError(t, err) + + err = inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key1`), []byte(`value1`)) + require.NoError(t, err) + + err = inst.Context.Storage.SetChildStorage(testChildKey, []byte(`key3`), []byte(`value3`)) require.NoError(t, err) testLimitBytes := make([]byte, 4) @@ -1687,11 +1676,8 @@ func Test_ext_default_child_storage_clear_prefix_version_2(t *testing.T) { {[]byte("keyThree"), []byte("value3")}, } - err := inst.Context.Storage.SetChild(testChildKey, trie.NewEmptyTrie()) - require.NoError(t, err) - for _, kv := range testKeyValuePair { - err = inst.Context.Storage.SetChildStorage(testChildKey, kv.key, kv.value) + err := inst.Context.Storage.SetChildStorage(testChildKey, kv.key, kv.value) require.NoError(t, err) } diff --git a/pkg/trie/child_storage.go b/pkg/trie/child_storage.go index bd7fcd0bf9..75222cefda 100644 --- a/pkg/trie/child_storage.go +++ b/pkg/trie/child_storage.go @@ -15,10 +15,10 @@ var ChildStorageKeyPrefix = []byte(":child_storage:default:") var ErrChildTrieDoesNotExist = errors.New("child trie does not exist") -// SetChild inserts a child trie into the main trie at key :child_storage:[keyToChild] +// setChild inserts a child trie into the main trie at key :child_storage:[keyToChild] // A child trie is added as a node (K, V) in the main trie. K is the child storage key // associated to the child trie, and V is the root hash of the child trie. -func (t *Trie) SetChild(keyToChild []byte, child *Trie) error { +func (t *Trie) setChild(keyToChild []byte, child *Trie) error { childHash, err := child.Hash() if err != nil { return err @@ -38,6 +38,7 @@ func (t *Trie) SetChild(keyToChild []byte, child *Trie) error { } // GetChild returns the child trie at key :child_storage:[keyToChild] +// TODO: do we need to return an error when the child trie does not exist? func (t *Trie) GetChild(keyToChild []byte) (*Trie, error) { key := make([]byte, len(ChildStorageKeyPrefix)+len(keyToChild)) copy(key, ChildStorageKeyPrefix) @@ -74,7 +75,7 @@ func (t *Trie) PutIntoChild(keyToChild, key, value []byte) error { } delete(t.childTries, origChildHash) - return t.SetChild(keyToChild, child) + return t.setChild(keyToChild, child) } // GetFromChild retrieves a key-value pair from the child trie located @@ -132,5 +133,5 @@ func (t *Trie) ClearFromChild(keyToChild, key []byte) error { return t.DeleteChild(keyToChild) } - return t.SetChild(keyToChild, child) + return t.setChild(keyToChild, child) } diff --git a/pkg/trie/child_storage_test.go b/pkg/trie/child_storage_test.go index d75dda0eaf..a58d89d348 100644 --- a/pkg/trie/child_storage_test.go +++ b/pkg/trie/child_storage_test.go @@ -16,7 +16,7 @@ func TestPutAndGetChild(t *testing.T) { childTrie := buildSmallTrie() parentTrie := NewEmptyTrie() - err := parentTrie.SetChild(childKey, childTrie) + err := parentTrie.setChild(childKey, childTrie) assert.NoError(t, err) childTrieRes, err := parentTrie.GetChild(childKey) @@ -30,7 +30,7 @@ func TestPutAndDeleteChild(t *testing.T) { childTrie := buildSmallTrie() parentTrie := NewEmptyTrie() - err := parentTrie.SetChild(childKey, childTrie) + err := parentTrie.setChild(childKey, childTrie) assert.NoError(t, err) err = parentTrie.DeleteChild(childKey) @@ -46,7 +46,7 @@ func TestPutAndClearFromChild(t *testing.T) { childTrie := buildSmallTrie() parentTrie := NewEmptyTrie() - err := parentTrie.SetChild(childKey, childTrie) + err := parentTrie.setChild(childKey, childTrie) assert.NoError(t, err) err = parentTrie.ClearFromChild(childKey, keyInChild) @@ -64,7 +64,7 @@ func TestPutAndGetFromChild(t *testing.T) { childTrie := buildSmallTrie() parentTrie := NewEmptyTrie() - err := parentTrie.SetChild(childKey, childTrie) + err := parentTrie.setChild(childKey, childTrie) assert.NoError(t, err) testKey := []byte("child_key") diff --git a/pkg/trie/database_test.go b/pkg/trie/database_test.go index 3303c33a3b..833f218931 100644 --- a/pkg/trie/database_test.go +++ b/pkg/trie/database_test.go @@ -318,7 +318,7 @@ func Test_Trie_PutChild_Store_Load(t *testing.T) { } for _, keyToChildTrie := range keysToChildTries { - err := trie.SetChild(keyToChildTrie, childTrie) + err := trie.setChild(keyToChildTrie, childTrie) require.NoError(t, err) err = trie.WriteDirty(db) diff --git a/pkg/trie/layout.go b/pkg/trie/layout.go index 53d5e17f9e..e89bc8ae38 100644 --- a/pkg/trie/layout.go +++ b/pkg/trie/layout.go @@ -47,6 +47,16 @@ type Entry struct{ Key, Value []byte } // Entries is a list of entry used to build a trie type Entries []Entry +func NewEntriesFromMap(source map[string][]byte) Entries { + entries := Entries{} + + for k, v := range source { + entries = append(entries, Entry{[]byte(k), v}) + } + + return entries +} + // String returns a string representation of trie version func (v TrieLayout) String() string { switch v {