Skip to content

Commit

Permalink
feat: state.{Get,Set}Extra[SA any](*StateDB,types.ExtraPayloads,...) (
Browse files Browse the repository at this point in the history
#48)

* feat: `state.{Get,Set}Extra[SA any](*StateDB,types.ExtraPayloads,...)`

* test: `GetExtra()` at each point in `CreateAccount()` + `SetExtra()` lifecycle

* test: reverting extras to snapshot

* test: `GetExtra()` after `StateDB.Copy()` and writes to original
  • Loading branch information
ARR4N authored Oct 9, 2024
1 parent 51cd795 commit 77c5571
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 0 deletions.
64 changes: 64 additions & 0 deletions core/state/state.libevm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright 2024 the libevm authors.
//
// The libevm additions to go-ethereum are free software: you can redistribute
// them and/or modify them under the terms of the GNU Lesser General Public License
// as published by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// The libevm additions are distributed in the hope that they will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see
// <http://www.gnu.org/licenses/>.

package state

import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
)

// GetExtra returns the extra payload from the [types.StateAccount] associated
// with the address, or a zero-value `SA` if not found. The
// [types.ExtraPayloads] MUST be sourced from [types.RegisterExtras].
func GetExtra[SA any](s *StateDB, p types.ExtraPayloads[SA], addr common.Address) SA {
stateObject := s.getStateObject(addr)
if stateObject != nil {
return p.FromStateAccount(&stateObject.data)
}
var zero SA
return zero
}

// SetExtra sets the extra payload for the address. See [GetExtra] for details.
func SetExtra[SA any](s *StateDB, p types.ExtraPayloads[SA], addr common.Address, extra SA) {
stateObject := s.getOrNewStateObject(addr)
if stateObject != nil {
setExtraOnObject(stateObject, p, addr, extra)
}
}

func setExtraOnObject[SA any](s *stateObject, p types.ExtraPayloads[SA], addr common.Address, extra SA) {
s.db.journal.append(extraChange[SA]{
payloads: p,
account: &addr,
prev: p.FromStateAccount(&s.data),
})
p.SetOnStateAccount(&s.data, extra)
}

// extraChange is a [journalEntry] for [SetExtra] / [setExtraOnObject].
type extraChange[SA any] struct {
payloads types.ExtraPayloads[SA]
account *common.Address
prev SA
}

func (e extraChange[SA]) dirtied() *common.Address { return e.account }

func (e extraChange[SA]) revert(s *StateDB) {
e.payloads.SetOnStateAccount(&s.getStateObject(*e.account).data, e.prev)
}
172 changes: 172 additions & 0 deletions core/state/state.libevm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Copyright 2024 the libevm authors.
//
// The libevm additions to go-ethereum are free software: you can redistribute
// them and/or modify them under the terms of the GNU Lesser General Public License
// as published by the Free Software Foundation, either version 3 of the License,
// or (at your option) any later version.
//
// The libevm additions are distributed in the hope that they will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see
// <http://www.gnu.org/licenses/>.

package state_test

import (
"fmt"
"reflect"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/rawdb"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/state/snapshot"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethdb/memorydb"
"github.com/ethereum/go-ethereum/libevm/ethtest"
"github.com/ethereum/go-ethereum/triedb"
)

func TestGetSetExtra(t *testing.T) {
type accountExtra struct {
// Data is a pointer to test deep copying.
Data *[]byte // MUST be exported; I spent 20 minutes investigating failing tests because I'm an idiot
}

types.TestOnlyClearRegisteredExtras()
t.Cleanup(types.TestOnlyClearRegisteredExtras)
// Just as its Data field is a pointer, the registered type is a pointer to
// test deep copying.
payloads := types.RegisterExtras[*accountExtra]()

rng := ethtest.NewPseudoRand(42)
addr := rng.Address()
nonce := rng.Uint64()
balance := rng.Uint256()
buf := rng.Bytes(8)
extra := &accountExtra{Data: &buf}

views := newWithSnaps(t)
stateDB := views.newStateDB(t, types.EmptyRootHash)

assert.Nilf(t, state.GetExtra(stateDB, payloads, addr), "state.GetExtra() returns zero-value %T if before account creation", extra)
stateDB.CreateAccount(addr)
stateDB.SetNonce(addr, nonce)
stateDB.SetBalance(addr, balance)
assert.Nilf(t, state.GetExtra(stateDB, payloads, addr), "state.GetExtra() returns zero-value %T if after account creation but before SetExtra()", extra)
state.SetExtra(stateDB, payloads, addr, extra)
require.Equal(t, extra, state.GetExtra(stateDB, payloads, addr), "state.GetExtra() immediately after SetExtra()")

root, err := stateDB.Commit(1, false) // arbitrary block number
require.NoErrorf(t, err, "%T.Commit(1, false)", stateDB)
require.NotEqualf(t, types.EmptyRootHash, root, "root hash returned by %T.Commit() is not the empty root", stateDB)

t.Run(fmt.Sprintf("retrieve from %T", views.snaps), func(t *testing.T) {
iter, err := views.snaps.AccountIterator(root, common.Hash{})
require.NoErrorf(t, err, "%T.AccountIterator(...)", views.snaps)
defer iter.Release()

require.Truef(t, iter.Next(), "%T.Next() (i.e. at least one account)", iter)
require.NoErrorf(t, iter.Error(), "%T.Error()", iter)

t.Run("types.FullAccount()", func(t *testing.T) {
got, err := types.FullAccount(iter.Account())
require.NoErrorf(t, err, "types.FullAccount(%T.Account())", iter)

want := &types.StateAccount{
Nonce: nonce,
Balance: balance,
Root: types.EmptyRootHash,
CodeHash: types.EmptyCodeHash[:],
}
payloads.SetOnStateAccount(want, extra)

if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("types.FullAccount(%T.Account()) diff (-want +got):\n%s", iter, diff)
}
})

require.Falsef(t, iter.Next(), "%T.Next() after first account (i.e. only one)", iter)
})

t.Run(fmt.Sprintf("retrieve from new %T", stateDB), func(t *testing.T) {
s := views.newStateDB(t, root)
assert.Equalf(t, nonce, s.GetNonce(addr), "%T.GetNonce()", s)
assert.Equalf(t, balance, s.GetBalance(addr), "%T.GetBalance()", s)
assert.Equal(t, extra, state.GetExtra(s, payloads, addr), "state.GetExtra()")
})

t.Run("reverting to snapshot", func(t *testing.T) {
s := views.newStateDB(t, root)
snap := s.Snapshot()

oldExtra := extra
buf := append(*oldExtra.Data, rng.Bytes(8)...)
newExtra := &accountExtra{Data: &buf}

state.SetExtra(s, payloads, addr, newExtra)
assert.Equalf(t, newExtra, state.GetExtra(s, payloads, addr), "state.GetExtra() after overwriting with new value")
s.RevertToSnapshot(snap)
assert.Equalf(t, oldExtra, state.GetExtra(s, payloads, addr), "state.GetExtra() after reverting to snapshot")
})

t.Run(fmt.Sprintf("%T.Copy()", stateDB), func(t *testing.T) {
require.Equalf(t, reflect.Pointer, reflect.TypeOf(extra).Kind(), "extra-payload type")
require.Equalf(t, reflect.Pointer, reflect.TypeOf(extra.Data).Kind(), "extra-payload field")

orig := views.newStateDB(t, root)
cp := orig.Copy()

oldExtra := extra
buf := append(*oldExtra.Data, rng.Bytes(8)...)
newExtra := &accountExtra{Data: &buf}

assert.Equalf(t, oldExtra, state.GetExtra(orig, payloads, addr), "GetExtra([original %T]) before setting", orig)
assert.Equalf(t, oldExtra, state.GetExtra(cp, payloads, addr), "GetExtra([copy of %T]) returns the same payload", orig)
state.SetExtra(orig, payloads, addr, newExtra)
assert.Equalf(t, newExtra, state.GetExtra(orig, payloads, addr), "GetExtra([original %T]) returns overwritten payload", orig)
assert.Equalf(t, oldExtra, state.GetExtra(cp, payloads, addr), "GetExtra([copy of %T]) returns original payload despite overwriting on original", orig)
})
}

// stateViews are different ways to access the same data.
type stateViews struct {
snaps *snapshot.Tree
database state.Database
}

func (v stateViews) newStateDB(t *testing.T, root common.Hash) *state.StateDB {
t.Helper()
s, err := state.New(root, v.database, v.snaps)
require.NoError(t, err, "state.New()")
return s
}

func newWithSnaps(t *testing.T) stateViews {
t.Helper()
empty := types.EmptyRootHash
kvStore := memorydb.New()
ethDB := rawdb.NewDatabase(kvStore)
snaps, err := snapshot.New(
snapshot.Config{
CacheSize: 16, // Mb (arbitrary but non-zero)
},
kvStore,
triedb.NewDatabase(ethDB, nil),
empty,
)
require.NoError(t, err, "snapshot.New()")

return stateViews{
snaps: snaps,
database: state.NewDatabase(ethDB),
}
}

0 comments on commit 77c5571

Please sign in to comment.