Skip to content

Commit

Permalink
fix(pkg/trie): Fix prefixed trie iterator (#4278)
Browse files Browse the repository at this point in the history
  • Loading branch information
dimartiro authored Oct 28, 2024
1 parent 2577544 commit 920a215
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 60 deletions.
35 changes: 19 additions & 16 deletions lib/runtime/storage/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,14 @@ func (t *TrieState) NextKey(key []byte) []byte {
nextKey = []byte(currentTx.sortedKeys[pos])
}

nextKeyOnState := t.state.PrefixedIter(key).NextKeyFunc(func(nextKey []byte) bool {
_, deleted := currentTx.deletes[string(nextKey)]
return !deleted
})
var nextKeyOnState []byte
for k := range t.state.KeysFrom(key) {
if _, deleted := currentTx.deletes[string(k)]; !deleted {
nextKeyOnState = k
break
}
}

if nextKeyOnState == nil {
return nextKey
}
Expand All @@ -214,8 +218,7 @@ func (t *TrieState) ClearPrefix(prefix []byte) error {
if currentTx := t.getCurrentTransaction(); currentTx != nil {
keysOnState := make([]string, 0)

iter := t.state.PrefixedIter(prefix)
for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() {
for key := range t.state.PrefixedKeys(prefix) {
keysOnState = append(keysOnState, string(key))
}

Expand All @@ -235,8 +238,7 @@ func (t *TrieState) ClearPrefixLimit(prefix []byte, limit uint32) (
if currentTx := t.getCurrentTransaction(); currentTx != nil {
keysOnState := make([]string, 0)

iter := t.state.PrefixedIter(prefix)
for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() {
for key := range t.state.PrefixedKeys(prefix) {
keysOnState = append(keysOnState, string(key))
}

Expand Down Expand Up @@ -430,8 +432,7 @@ func (t *TrieState) ClearPrefixInChild(keyToChild, prefix []byte) error {
}

var onStateKeys []string
iter := child.PrefixedIter(prefix)
for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() {
for key := range child.PrefixedKeys(prefix) {
onStateKeys = append(onStateKeys, string(key))
}

Expand Down Expand Up @@ -466,8 +467,7 @@ func (t *TrieState) ClearPrefixInChildWithLimit(keyToChild, prefix []byte, limit
}

var onStateKeys []string
iter := child.PrefixedIter(prefix)
for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() {
for key := range child.PrefixedKeys(prefix) {
onStateKeys = append(onStateKeys, string(key))
}

Expand Down Expand Up @@ -516,10 +516,13 @@ func (t *TrieState) GetChildNextKey(keyToChild, key []byte) ([]byte, error) {
return nil, err
}

nextKeyOnState := childTrie.PrefixedIter(key).NextKeyFunc(func(nextKey []byte) bool {
_, deleted := childChanges.deletes[string(nextKey)]
return !deleted
})
var nextKeyOnState []byte
for k := range childTrie.KeysFrom(key) {
if _, deleted := childChanges.deletes[string(k)]; !deleted {
nextKeyOnState = k
break
}
}

if nextKeyOnState == nil {
return nextKey, nil
Expand Down
53 changes: 34 additions & 19 deletions pkg/trie/inmemory/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package inmemory
import (
"bytes"
"fmt"
"iter"

"github.com/ChainSafe/gossamer/pkg/trie"
"github.com/ChainSafe/gossamer/pkg/trie/codec"
Expand Down Expand Up @@ -46,7 +47,7 @@ func NewInMemoryTrieIterator(opts ...IterOpts) *InMemoryTrieIterator {
return iter
}

func (t *InMemoryTrieIterator) NextEntry() *trie.Entry {
func (t *InMemoryTrieIterator) nextEntry() *trie.Entry {
found := findNextNode(t.trie.root, []byte(nil), t.cursorAtKey)
if found != nil {
t.cursorAtKey = found.Key
Expand All @@ -55,30 +56,13 @@ func (t *InMemoryTrieIterator) NextEntry() *trie.Entry {
}

func (t *InMemoryTrieIterator) NextKey() []byte {
entry := t.NextEntry()
entry := t.nextEntry()
if entry != nil {
return codec.NibblesToKeyLE(entry.Key)
}
return nil
}

// NextKeyFunc advance the iterator until the predicate condition meets
func (t *InMemoryTrieIterator) NextKeyFunc(predicate func(nextKey []byte) bool) (nextKey []byte) {
for entry := t.NextEntry(); entry != nil; entry = t.NextEntry() {
key := codec.NibblesToKeyLE(entry.Key)
if predicate(key) {
return key
}
}
return nil
}

func (t *InMemoryTrieIterator) Seek(targetKey []byte) {
t.NextKeyFunc(func(nextKey []byte) bool {
return bytes.Compare(nextKey, targetKey) >= 0
})
}

// Entries returns all the key-value pairs in the trie as a map of keys to values
// where the keys are encoded in Little Endian.
func (t *InMemoryTrie) Entries() (keyValueMap map[string][]byte) {
Expand All @@ -87,6 +71,37 @@ func (t *InMemoryTrie) Entries() (keyValueMap map[string][]byte) {
return keyValueMap
}

// KeysFrom returns an iterator over all keys in the trie that are greater than the given key.
func (t *InMemoryTrie) KeysFrom(key []byte) iter.Seq[[]byte] {
iter := NewInMemoryTrieIterator(WithTrie(t), WithCursorAt(codec.KeyLEToNibbles(key)))

return func(yield func([]byte) bool) {
for key := iter.NextKey(); key != nil; key = iter.NextKey() {
if !yield(key) {
return
}
}
}
}

// PrefixedKeys returns an iterator over all keys in the trie that have the given prefix.
func (t *InMemoryTrie) PrefixedKeys(prefix []byte) iter.Seq[[]byte] {
iter := NewInMemoryTrieIterator(WithTrie(t), WithCursorAt(codec.KeyLEToNibbles(prefix)))

return func(yield func([]byte) bool) {
// Return same prefix as first key if it's present in trie
if t.Get(prefix) != nil && !yield(prefix) {
return
}

for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() {
if !yield(key) {
return
}
}
}
}

// NextKey returns the next key in the trie in lexicographic order.
// It returns nil if no next key is found.
func (t *InMemoryTrie) NextKey(keyLE []byte) (nextKeyLE []byte) {
Expand Down
48 changes: 37 additions & 11 deletions pkg/trie/inmemory/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
package inmemory

import (
"bytes"
"testing"

"github.com/ChainSafe/gossamer/pkg/trie/codec"
"github.com/stretchr/testify/require"
)

Expand All @@ -22,13 +20,13 @@ func TestInMemoryTrieIterator(t *testing.T) {
tt.Put([]byte("account_storage:JJK:EEE"), []byte("0x10"))

iter := NewInMemoryTrieIterator(WithTrie(tt))
require.Equal(t, []byte("account_storage:ABC:AAA"), codec.NibblesToKeyLE((iter.NextEntry().Key)))
require.Equal(t, []byte("account_storage:ABC:CCC"), codec.NibblesToKeyLE((iter.NextEntry().Key)))
require.Equal(t, []byte("account_storage:ABC:DDD"), codec.NibblesToKeyLE((iter.NextEntry().Key)))
require.Equal(t, []byte("account_storage:JJK:EEE"), codec.NibblesToKeyLE((iter.NextEntry().Key)))
require.Equal(t, []byte("some_other_storage:XCC:ZZZ"), codec.NibblesToKeyLE((iter.NextEntry().Key)))
require.Equal(t, []byte("yet_another_storage:BLABLA:YYY:JJJ"), codec.NibblesToKeyLE((iter.NextEntry().Key)))
require.Nil(t, iter.NextEntry())
require.Equal(t, []byte("account_storage:ABC:AAA"), iter.NextKey())
require.Equal(t, []byte("account_storage:ABC:CCC"), iter.NextKey())
require.Equal(t, []byte("account_storage:ABC:DDD"), iter.NextKey())
require.Equal(t, []byte("account_storage:JJK:EEE"), iter.NextKey())
require.Equal(t, []byte("some_other_storage:XCC:ZZZ"), iter.NextKey())
require.Equal(t, []byte("yet_another_storage:BLABLA:YYY:JJJ"), iter.NextKey())
require.Nil(t, iter.NextKey())
}

func TestInMemoryIteratorGetAllKeysWithPrefix(t *testing.T) {
Expand All @@ -42,10 +40,9 @@ func TestInMemoryIteratorGetAllKeysWithPrefix(t *testing.T) {
tt.Put([]byte("account_storage:JJK:EEE"), []byte("0x10"))

prefix := []byte("account_storage")
iter := tt.PrefixedIter(prefix)

keys := make([][]byte, 0)
for key := iter.NextKey(); bytes.HasPrefix(key, prefix); key = iter.NextKey() {
for key := range tt.PrefixedKeys(prefix) {
keys = append(keys, key)
}

Expand All @@ -58,3 +55,32 @@ func TestInMemoryIteratorGetAllKeysWithPrefix(t *testing.T) {

require.Equal(t, expectedKeys, keys)
}

func TestInMemoryIteratorGetAllKeysWithPrefixIncluded(t *testing.T) {
tt := NewEmptyTrie()

tt.Put([]byte("services_storage:serviceA:19090"), []byte("0x10"))
tt.Put([]byte("services_storage:serviceB:22222"), []byte("0x10"))
tt.Put([]byte("account_storage"), []byte("0x10"))
tt.Put([]byte("account_storage:ABC:AAA"), []byte("0x10"))
tt.Put([]byte("account_storage:ABC:CCC"), []byte("0x10"))
tt.Put([]byte("account_storage:ABC:DDD"), []byte("0x10"))
tt.Put([]byte("account_storage:JJK:EEE"), []byte("0x10"))

prefix := []byte("account_storage")

keys := make([][]byte, 0)
for key := range tt.PrefixedKeys(prefix) {
keys = append(keys, key)
}

expectedKeys := [][]byte{
[]byte("account_storage"),
[]byte("account_storage:ABC:AAA"),
[]byte("account_storage:ABC:CCC"),
[]byte("account_storage:ABC:DDD"),
[]byte("account_storage:JJK:EEE"),
}

require.Equal(t, expectedKeys, keys)
}
17 changes: 3 additions & 14 deletions pkg/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package trie

import (
"fmt"
"iter"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/pkg/trie/tracking"
Expand Down Expand Up @@ -35,20 +36,9 @@ type KVStoreWrite interface {
}

type TrieIterator interface {
// NextKey performs a depth-first search on the trie and returns the next key
// and value based on the current state of the iterator.
NextEntry() (entry *Entry)

// NextKey performs a depth-first search on the trie and returns the next key
// based on the current state of the iterator.
NextKey() (nextKey []byte)

// NextKeyFunc performs a depth-first search on the trie and returns the next key
// that satisfies the predicate based on the current state of the iterator.
NextKeyFunc(predicate func(nextKey []byte) bool) (nextKey []byte)

// Seek moves the iterator to the first key that is greater than the target key.
Seek(targetKey []byte)
}

type PrefixTrieWrite interface {
Expand Down Expand Up @@ -78,12 +68,11 @@ type TrieRead interface {
Hashable
ChildTriesRead

Iter() TrieIterator
PrefixedIter(prefix []byte) TrieIterator

Entries() (keyValueMap map[string][]byte)
NextKey(key []byte) []byte
GetKeysWithPrefix(prefix []byte) (keysLE [][]byte)
PrefixedKeys(prefix []byte) iter.Seq[[]byte]
KeysFrom(key []byte) iter.Seq[[]byte]
}

type Trie interface {
Expand Down

0 comments on commit 920a215

Please sign in to comment.