diff --git a/core/state/database.go b/core/state/database.go index e3dadbb0fb..1290a72306 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -84,10 +84,10 @@ type Trie interface { // PrefetchAccount attempts to resolve specific accounts from the database // to accelerate subsequent trie operations. PrefetchAccount([]common.Address) error - - // GetStorage returns the value for key stored in the trie. The value bytes - // must not be modified by the caller. If a node was not found in the database, - // a trie.MissingNodeError is returned. + + // GetStorage returns the value for key stored in the trie. The value bytes must + // not be modified by the caller. If a node was not found in the database, a + // trie.MissingNodeError is returned. GetStorage(addr common.Address, key []byte) ([]byte, error) // PrefetchStorage attempts to resolve specific storage slots from the database diff --git a/core/state/state_object.go b/core/state/state_object.go index 8e7b0a346c..3eda059e24 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -163,6 +163,19 @@ func (s *stateObject) GetState(key common.Hash) common.Hash { return value } +// GetStateWithMeter retrieves a value associated with the given storage key, +// and charges per-node via the meter during trie traversal. +func (s *stateObject) GetStateWithMeter(key common.Hash, meter func(uint64) error) (common.Hash, error) { + origin, err := s.GetCommittedStateWithMeter(key, meter) + if err != nil { + return common.Hash{}, err + } + if value, dirty := s.dirtyStorage[key]; dirty { + return value, nil + } + return origin, nil +} + // getState retrieves a value associated with the given storage key, along with // its original value. func (s *stateObject) getState(key common.Hash) (common.Hash, common.Hash) { @@ -174,6 +187,20 @@ func (s *stateObject) getState(key common.Hash) (common.Hash, common.Hash) { return origin, origin } +// getStateWithMeter retrieves a value associated with the given storage key, along with +// its original value, charging the meter during trie traversal. +func (s *stateObject) getStateWithMeter(key common.Hash, meter func(uint64) error) (common.Hash, common.Hash, error) { + origin, err := s.GetCommittedStateWithMeter(key, meter) + if err != nil { + return common.Hash{}, common.Hash{}, err + } + value, dirty := s.dirtyStorage[key] + if dirty { + return value, origin, nil + } + return origin, origin, nil +} + // GetCommittedState retrieves the value associated with the specific key // without any mutations caused in the current execution. func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { @@ -201,7 +228,7 @@ func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { s.db.StorageLoaded++ start := time.Now() - value, err := s.db.reader.Storage(s.address, key) + value, err := s.resolveStorage(key) if err != nil { s.db.setError(err) return common.Hash{} @@ -219,6 +246,86 @@ func (s *stateObject) GetCommittedState(key common.Hash) common.Hash { return value } +// GetCommittedStateWithMeter retrieves the value associated with the specific key +// without any mutations caused in the current execution, charging the meter +// during trie traversal. +func (s *stateObject) GetCommittedStateWithMeter(key common.Hash, meter func(uint64) error) (common.Hash, error) { + s.storageMutex.Lock() + defer s.storageMutex.Unlock() + // If we have a pending write or clean cached, return that + if value, pending := s.pendingStorage[key]; pending { + return value, nil + } + + if value, cached := s.originStorage[key]; cached { + return value, nil + } + + // If the object was destructed in *this* block (and potentially resurrected), + // the storage has been cleared out, and we should *not* consult the previous + // database about any storage values. The only possible alternatives are: + // 1) resurrect happened, and new slot values were set -- those should + // have been handles via pendingStorage above. + // 2) we don't have new values, and can deliver empty response back + if _, destructed := s.db.stateObjectsDestruct[s.address]; destructed { + s.originStorage[key] = common.Hash{} // track the empty slot as origin value + return common.Hash{}, nil + } + s.db.StorageLoaded++ + + start := time.Now() + value, err := s.resolveStorageWithMeter(key, meter) + if err != nil { + s.db.setError(err) + return common.Hash{}, err + } + s.db.StorageReads += time.Since(start) + + // Schedule the resolved storage slots for prefetching if it's enabled. + if s.db.prefetcher != nil && s.data.Root != types.EmptyRootHash { + if err = s.db.prefetcher.prefetch(s.addrHash, s.origin.Root, s.address, nil, []common.Hash{key}, true); err != nil { + log.Error("Failed to prefetch storage slot", "addr", s.address, "key", key, "err", err) + } + } + s.originStorage[key] = value + + return value, nil +} + +func (s *stateObject) resolveStorage(key common.Hash) (common.Hash, error) { + tr, err := s.getTrie() + if err != nil { + return common.Hash{}, err + } + ret, err := tr.GetStorage(s.address, key.Bytes()) + if err != nil { + return common.Hash{}, err + } + var value common.Hash + value.SetBytes(ret) + return value, nil +} + +func (s *stateObject) resolveStorageWithMeter(key common.Hash, meter func(uint64) error) (common.Hash, error) { + tr, err := s.getTrie() + if err != nil { + return common.Hash{}, err + } + if metered, ok := tr.(interface { + GetStorageWithMeter(common.Address, []byte, func(uint64) error) ([]byte, error) + }); ok { + ret, err := metered.GetStorageWithMeter(s.address, key.Bytes(), meter) + if err != nil { + return common.Hash{}, err + } + var value common.Hash + value.SetBytes(ret) + return value, nil + } + // Fallback to non-metered lookup + return s.resolveStorage(key) +} + // SetState updates a value in account storage. // It returns the previous value func (s *stateObject) SetState(key, value common.Hash) common.Hash { @@ -234,6 +341,28 @@ func (s *stateObject) SetState(key, value common.Hash) common.Hash { return prev } +// SetStateWithMeter updates a value in account storage and returns the previous +// value, charging the meter during trie traversal. +func (s *stateObject) SetStateWithMeter(key, value common.Hash, meter func(uint64) error) (common.Hash, error) { + origin, err := s.GetCommittedStateWithMeter(key, meter) + if err != nil { + return common.Hash{}, err + } + prev := origin + if dirtyValue, dirty := s.dirtyStorage[key]; dirty { + prev = dirtyValue + } + // If the new value is the same as old, don't set. Otherwise, track only the + // dirty changes, supporting reverting all of it back to no change. + if prev == value { + return prev, nil + } + // New value is different, update and journal the change + s.db.journal.storageChange(s.address, key, prev, origin) + s.setState(key, value, origin) + return prev, nil +} + // setState updates a value in account dirty storage. The dirtiness will be // removed if the value being set equals to the original value. func (s *stateObject) setState(key common.Hash, value common.Hash, origin common.Hash) { diff --git a/core/state/statedb.go b/core/state/statedb.go index 2ac293c5e5..c61d475394 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -732,6 +732,16 @@ func (s *StateDB) GetState(addr common.Address, hash common.Hash) common.Hash { }) } +// GetStateWithMeter retrieves the value associated with the specific key and +// charges per-node via the meter during trie traversal. +func (s *StateDB) GetStateWithMeter(addr common.Address, hash common.Hash, meter func(uint64) error) (common.Hash, error) { + stateObject := s.getStateObject(addr) + if stateObject == nil { + return common.Hash{}, nil + } + return stateObject.GetStateWithMeter(hash, meter) +} + // GetCommittedState retrieves the value associated with the specific key // without any mutations caused in the current execution. func (s *StateDB) GetCommittedState(addr common.Address, hash common.Hash) common.Hash { @@ -754,6 +764,16 @@ func (s *StateDB) GetStateAndCommittedState(addr common.Address, hash common.Has return common.Hash{}, common.Hash{} } +// GetStateAndCommittedStateWithMeter returns the current value and the original value, +// charging the meter during trie traversal. +func (s *StateDB) GetStateAndCommittedStateWithMeter(addr common.Address, hash common.Hash, meter func(uint64) error) (common.Hash, common.Hash, error) { + stateObject := s.getStateObject(addr) + if stateObject != nil { + return stateObject.getStateWithMeter(hash, meter) + } + return common.Hash{}, common.Hash{}, nil +} + // Database retrieves the low level database supporting the lower level trie ops. func (s *StateDB) Database() Database { return s.db @@ -880,6 +900,18 @@ func (s *StateDB) SetState(addr common.Address, key, value common.Hash) common.H return common.Hash{} } +// SetStateWithMeter updates storage and returns the previous value, +// charging per-node via the meter during trie traversal. +func (s *StateDB) SetStateWithMeter(addr common.Address, key, value common.Hash, meter func(uint64) error) (common.Hash, error) { + stateObject := s.getOrNewStateObject(addr) + if stateObject != nil { + stateObject = s.mvRecordWritten(stateObject) + MVWrite(s, blockstm.NewStateKey(addr, key)) + return stateObject.SetStateWithMeter(key, value, meter) + } + return common.Hash{}, nil +} + // SetStorage replaces the entire storage for the specified account with given // storage. This function should only be used for debugging and the mutations // must be discarded afterwards. diff --git a/core/state/statedb_hooked.go b/core/state/statedb_hooked.go index ae52ebdfbc..88af5081b8 100644 --- a/core/state/statedb_hooked.go +++ b/core/state/statedb_hooked.go @@ -90,10 +90,18 @@ func (s *hookedStateDB) GetStateAndCommittedState(addr common.Address, hash comm return s.inner.GetStateAndCommittedState(addr, hash) } +func (s *hookedStateDB) GetStateAndCommittedStateWithMeter(addr common.Address, hash common.Hash, meter func(uint64) error) (common.Hash, common.Hash, error) { + return s.inner.GetStateAndCommittedStateWithMeter(addr, hash, meter) +} + func (s *hookedStateDB) GetState(addr common.Address, hash common.Hash) common.Hash { return s.inner.GetState(addr, hash) } +func (s *hookedStateDB) GetStateWithMeter(addr common.Address, hash common.Hash, meter func(uint64) error) (common.Hash, error) { + return s.inner.GetStateWithMeter(addr, hash, meter) +} + func (s *hookedStateDB) GetStorageRoot(addr common.Address) common.Hash { return s.inner.GetStorageRoot(addr) } @@ -225,6 +233,17 @@ func (s *hookedStateDB) SetState(address common.Address, key common.Hash, value return prev } +func (s *hookedStateDB) SetStateWithMeter(address common.Address, key common.Hash, value common.Hash, meter func(uint64) error) (common.Hash, error) { + prev, err := s.inner.SetStateWithMeter(address, key, value, meter) + if err != nil { + return prev, err + } + if s.hooks.OnStorageChange != nil && prev != value { + s.hooks.OnStorageChange(address, key, prev, value) + } + return prev, nil +} + func (s *hookedStateDB) SelfDestruct(address common.Address) uint256.Int { var prevCode []byte var prevCodeHash common.Hash diff --git a/core/vm/evm.go b/core/vm/evm.go index 59c1272995..b6d701d797 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -457,8 +457,6 @@ func (evm *EVM) StaticCall(caller common.Address, addr common.Address, input []b if p, isPrecompile := evm.precompile(addr); isPrecompile { ret, gas, err = RunPrecompiledContract(p, input, gas, evm.Config.Tracer) } else { - // Initialise a new contract and set the code that is to be used by the EVM. - // The contract is a scoped environment for this execution context only. contract := NewContract(caller, addr, new(uint256.Int), gas, evm.jumpDests) contract.SetCallCode(evm.resolveCodeHash(addr), evm.resolveCode(addr)) diff --git a/core/vm/gas_table.go b/core/vm/gas_table.go index dbf9cb7be2..74da6c945c 100644 --- a/core/vm/gas_table.go +++ b/core/vm/gas_table.go @@ -101,9 +101,20 @@ var ( func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { var ( - y, x = stack.Back(1), stack.Back(0) - current, original = evm.StateDB.GetStateAndCommittedState(contract.Address(), x.Bytes32()) + y, x = stack.Back(1), stack.Back(0) + depthGas = uint64(0) ) + // Meter callback to charge per-node during trie traversal. + meter := func(nodeCount uint64) error { + if storageTrieDepthStepGas > 0 && nodeCount > storageTrieDepthFreeLevels { + depthGas += (nodeCount - storageTrieDepthFreeLevels) * storageTrieDepthStepGas + } + return nil + } + current, original, err := evm.StateDB.GetStateAndCommittedStateWithMeter(contract.Address(), x.Bytes32(), meter) + if err != nil { + return 0, err + } // The legacy gas metering only takes into consideration the current state // Legacy rules should be applied if we are in Petersburg (removal of EIP-1283) // OR Constantinople is not active @@ -115,12 +126,12 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi // 3. From a non-zero to a non-zero (CHANGE) switch { case current == (common.Hash{}) && y.Sign() != 0: // 0 => non 0 - return params.SstoreSetGas, nil + return depthGas + params.SstoreSetGas, nil case current != (common.Hash{}) && y.Sign() == 0: // non 0 => 0 evm.StateDB.AddRefund(params.SstoreRefundGas) - return params.SstoreClearGas, nil + return depthGas + params.SstoreClearGas, nil default: // non 0 => non 0 (or 0 => 0) - return params.SstoreResetGas, nil + return depthGas + params.SstoreResetGas, nil } } @@ -140,18 +151,18 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi // (2.2.2.2.) Otherwise, add 4800 gas to refund counter. value := common.Hash(y.Bytes32()) if current == value { // noop (1) - return params.NetSstoreNoopGas, nil + return depthGas + params.NetSstoreNoopGas, nil } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) - return params.NetSstoreInitGas, nil + return depthGas + params.NetSstoreInitGas, nil } if value == (common.Hash{}) { // delete slot (2.1.2b) evm.StateDB.AddRefund(params.NetSstoreClearRefund) } - return params.NetSstoreCleanGas, nil // write existing slot (2.1.2) + return depthGas + params.NetSstoreCleanGas, nil // write existing slot (2.1.2) } if original != (common.Hash{}) { @@ -170,7 +181,7 @@ func gasSStore(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySi } } - return params.NetSstoreDirtyGas, nil + return depthGas + params.NetSstoreDirtyGas, nil } // Here come the EIP2200 rules: @@ -195,25 +206,36 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m } // Gas sentry honoured, do the actual gas calculation based on the stored value var ( - y, x = stack.Back(1), stack.Back(0) - current, original = evm.StateDB.GetStateAndCommittedState(contract.Address(), x.Bytes32()) + y, x = stack.Back(1), stack.Back(0) + depthGas = uint64(0) ) + // Meter callback to charge per-node during trie traversal. + meter := func(nodeCount uint64) error { + if storageTrieDepthStepGas > 0 && nodeCount > storageTrieDepthFreeLevels { + depthGas += (nodeCount - storageTrieDepthFreeLevels) * storageTrieDepthStepGas + } + return nil + } + current, original, err := evm.StateDB.GetStateAndCommittedStateWithMeter(contract.Address(), x.Bytes32(), meter) + if err != nil { + return 0, err + } value := common.Hash(y.Bytes32()) if current == value { // noop (1) - return params.SloadGasEIP2200, nil + return depthGas + params.SloadGasEIP2200, nil } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) - return params.SstoreSetGasEIP2200, nil + return depthGas + params.SstoreSetGasEIP2200, nil } if value == (common.Hash{}) { // delete slot (2.1.2b) evm.StateDB.AddRefund(params.SstoreClearsScheduleRefundEIP2200) } - return params.SstoreResetGasEIP2200, nil // write existing slot (2.1.2) + return depthGas + params.SstoreResetGasEIP2200, nil // write existing slot (2.1.2) } if original != (common.Hash{}) { @@ -232,7 +254,7 @@ func gasSStoreEIP2200(evm *EVM, contract *Contract, stack *Stack, mem *Memory, m } } - return params.SloadGasEIP2200, nil // dirty update (2.2) + return depthGas + params.SloadGasEIP2200, nil // dirty update (2.2) } func makeGasLog(n uint64) gasFunc { diff --git a/core/vm/instructions.go b/core/vm/instructions.go index 2a8e41d601..a9344ad1c9 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -559,7 +559,19 @@ func opMstore8(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { func opSload(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { loc := scope.Stack.peek() hash := common.Hash(loc.Bytes32()) - val := evm.StateDB.GetState(scope.Contract.Address(), hash) + meter := func(nodeCount uint64) error { + if storageTrieDepthStepGas == 0 || nodeCount <= storageTrieDepthFreeLevels { + return nil + } + if !scope.Contract.UseGas(storageTrieDepthStepGas, evm.Config.Tracer, tracing.GasChangeUnspecified) { + return ErrOutOfGas + } + return nil + } + val, err := evm.StateDB.GetStateWithMeter(scope.Contract.Address(), hash, meter) + if err != nil { + return nil, err + } loc.SetBytes(val.Bytes()) return nil, nil @@ -572,7 +584,10 @@ func opSstore(pc *uint64, evm *EVM, scope *ScopeContext) ([]byte, error) { loc := scope.Stack.pop() val := scope.Stack.pop() + // Note: depth-based gas is charged during gas calculation phase in makeGasSStoreFunc/gasSStoreEIP2200, + // so we use the non-metered SetState here. evm.StateDB.SetState(scope.Contract.Address(), loc.Bytes32(), val.Bytes32()) + return nil, nil } diff --git a/core/vm/interface.go b/core/vm/interface.go index 6c78d07171..b9a79d2a3e 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -53,8 +53,11 @@ type StateDB interface { GetRefund() uint64 GetStateAndCommittedState(common.Address, common.Hash) (common.Hash, common.Hash) + GetStateAndCommittedStateWithMeter(common.Address, common.Hash, func(uint64) error) (common.Hash, common.Hash, error) GetState(common.Address, common.Hash) common.Hash + GetStateWithMeter(common.Address, common.Hash, func(uint64) error) (common.Hash, error) SetState(common.Address, common.Hash, common.Hash) common.Hash + SetStateWithMeter(common.Address, common.Hash, common.Hash, func(uint64) error) (common.Hash, error) GetStorageRoot(addr common.Address) common.Hash GetTransientState(addr common.Address, key common.Hash) common.Hash diff --git a/core/vm/operations_acl.go b/core/vm/operations_acl.go index effa46e415..f4bbb0d564 100644 --- a/core/vm/operations_acl.go +++ b/core/vm/operations_acl.go @@ -26,6 +26,21 @@ import ( "github.com/ethereum/go-ethereum/params" ) +// Storage-trie gas charging parameters. +// +// These are consensus-affecting; keep disabled (0) unless activated by fork. +const ( + // storageTrieDepthStepGas is the gas charged per trie-level when accessing a + // particular storage slot. 0 disables depth-based charging. + // + // This is experimental and potentially expensive, since it may require a trie + // traversal to measure the canonical lookup-path node count. + storageTrieDepthStepGas uint64 = 1000 + + // storageTrieDepthFreeLevels is the number of trie levels that are free. + storageTrieDepthFreeLevels uint64 = 2 +) + func makeGasSStoreFunc(clearingRefund uint64) gasFunc { return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { // If we fail the minimum gas availability invariant, fail (0) @@ -34,11 +49,22 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { } // Gas sentry honoured, do the actual gas calculation based on the stored value var ( - y, x = stack.Back(1), stack.peek() - slot = common.Hash(x.Bytes32()) - current, original = evm.StateDB.GetStateAndCommittedState(contract.Address(), slot) - cost = uint64(0) + y, x = stack.Back(1), stack.peek() + slot = common.Hash(x.Bytes32()) + depthGas = uint64(0) + cost = uint64(0) ) + // Meter callback to charge per-node during trie traversal. + meter := func(nodeCount uint64) error { + if storageTrieDepthStepGas > 0 && nodeCount > storageTrieDepthFreeLevels { + depthGas += (nodeCount - storageTrieDepthFreeLevels) * storageTrieDepthStepGas + } + return nil + } + current, original, err := evm.StateDB.GetStateAndCommittedStateWithMeter(contract.Address(), slot, meter) + if err != nil { + return 0, err + } // Check slot presence in the access list if _, slotPresent := evm.StateDB.SlotInAccessList(contract.Address(), slot); !slotPresent { cost = params.ColdSloadCostEIP2929 @@ -51,11 +77,11 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { if current == value { // noop (1) // EIP 2200 original clause: // return params.SloadGasEIP2200, nil - return cost + params.WarmStorageReadCostEIP2929, nil // SLOAD_GAS + return cost + depthGas + params.WarmStorageReadCostEIP2929, nil // SLOAD_GAS } if original == current { if original == (common.Hash{}) { // create slot (2.1.1) - return cost + params.SstoreSetGasEIP2200, nil + return cost + depthGas + params.SstoreSetGasEIP2200, nil } if value == (common.Hash{}) { // delete slot (2.1.2b) @@ -63,7 +89,7 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { } // EIP-2200 original clause: // return params.SstoreResetGasEIP2200, nil // write existing slot (2.1.2) - return cost + (params.SstoreResetGasEIP2200 - params.ColdSloadCostEIP2929), nil // write existing slot (2.1.2) + return cost + depthGas + (params.SstoreResetGasEIP2200 - params.ColdSloadCostEIP2929), nil // write existing slot (2.1.2) } if original != (common.Hash{}) { @@ -90,7 +116,7 @@ func makeGasSStoreFunc(clearingRefund uint64) gasFunc { } // EIP-2200 original clause: //return params.SloadGasEIP2200, nil // dirty update (2.2) - return cost + params.WarmStorageReadCostEIP2929, nil // dirty update (2.2) + return cost + depthGas + params.WarmStorageReadCostEIP2929, nil // dirty update (2.2) } } diff --git a/eth/tracers/js/tracer_test.go b/eth/tracers/js/tracer_test.go index bb79cbe1a2..01c3ca8df8 100644 --- a/eth/tracers/js/tracer_test.go +++ b/eth/tracers/js/tracer_test.go @@ -40,6 +40,9 @@ type dummyStatedb struct { func (*dummyStatedb) GetRefund() uint64 { return 1337 } func (*dummyStatedb) GetBalance(addr common.Address) *uint256.Int { return new(uint256.Int) } +func (*dummyStatedb) GetStateAndCommittedStateWithMeter(_ common.Address, _ common.Hash, _ func(uint64) error) (common.Hash, common.Hash, error) { + return common.Hash{}, common.Hash{}, nil +} type vmContext struct { blockCtx vm.BlockContext diff --git a/eth/tracers/logger/logger_test.go b/eth/tracers/logger/logger_test.go index c3841b4ef2..72d096c1db 100644 --- a/eth/tracers/logger/logger_test.go +++ b/eth/tracers/logger/logger_test.go @@ -44,6 +44,10 @@ func (*dummyStatedb) GetStateAndCommittedState(_ common.Address, _ common.Hash) return common.Hash{}, common.Hash{} } +func (*dummyStatedb) GetStateAndCommittedStateWithMeter(_ common.Address, _ common.Hash, _ func(uint64) error) (common.Hash, common.Hash, error) { + return common.Hash{}, common.Hash{}, nil +} + func TestStoreCapture(t *testing.T) { var ( logger = NewStructLogger(nil) diff --git a/trie/secure_trie.go b/trie/secure_trie.go index a72d2a6deb..b9121929e8 100644 --- a/trie/secure_trie.go +++ b/trie/secure_trie.go @@ -64,7 +64,7 @@ func NewSecure(stateRoot common.Hash, owner common.Hash, root common.Hash, db da // // StateTrie is not safe for concurrent use. type StateTrie struct { - trie Trie + trie *Trie db database.NodeDatabase preimages preimageStore secKeyCache map[common.Hash][]byte @@ -84,8 +84,7 @@ func NewStateTrie(id *ID, db database.NodeDatabase) (*StateTrie, error) { return nil, err } tr := &StateTrie{ - //nolint:govet - trie: *trie, + trie: trie, db: db, secKeyCache: make(map[common.Hash][]byte), } @@ -106,6 +105,30 @@ func (t *StateTrie) MustGet(key []byte) []byte { return t.trie.MustGet(crypto.Keccak256(key)) } +// GetStorage attempts to retrieve a storage slot with provided account address +// and slot key. The value bytes must not be modified by the caller. +// If the specified storage slot is not in the trie, nil will be returned. +// If a trie node is not found in the database, a MissingNodeError is returned. +func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { + enc, err := t.trie.Get(crypto.Keccak256(key)) + if err != nil || len(enc) == 0 { + return nil, err + } + _, content, _, err := rlp.Split(enc) + return content, err +} + +// GetStorageWithMeter retrieves a storage slot and invokes the meter once per +// decoded trie node visited during lookup. +func (t *StateTrie) GetStorageWithMeter(_ common.Address, key []byte, meter func(uint64) error) ([]byte, error) { + enc, err := t.trie.GetWithMeter(crypto.Keccak256(key), meter) + if err != nil || len(enc) == 0 { + return nil, err + } + _, content, _, err := rlp.Split(enc) + return content, err +} + // GetAccount attempts to retrieve an account with provided account address. // If the specified account is not in the trie, nil will be returned. // If a trie node is not found in the database, a MissingNodeError is returned. @@ -142,19 +165,6 @@ func (t *StateTrie) PrefetchAccount(addresses []common.Address) error { return t.trie.Prefetch(keys) } -// GetStorage attempts to retrieve a storage slot with provided account address -// and slot key. The value bytes must not be modified by the caller. -// If the specified storage slot is not in the trie, nil will be returned. -// If a trie node is not found in the database, a MissingNodeError is returned. -func (t *StateTrie) GetStorage(_ common.Address, key []byte) ([]byte, error) { - enc, err := t.trie.Get(crypto.Keccak256(key)) - if err != nil || len(enc) == 0 { - return nil, err - } - _, content, _, err := rlp.Split(enc) - return content, err -} - // PrefetchStorage attempts to resolve specific storage slots from the database // to accelerate subsequent trie operations. func (t *StateTrie) PrefetchStorage(_ common.Address, keys [][]byte) error { @@ -305,8 +315,12 @@ func (t *StateTrie) Hash() common.Hash { // Copy returns a copy of StateTrie. func (t *StateTrie) Copy() *StateTrie { + var copied *Trie + if t.trie != nil { + copied = t.trie.Copy() + } return &StateTrie{ - trie: *t.trie.Copy(), + trie: copied, db: t.db, secKeyCache: make(map[common.Hash][]byte), preimages: t.preimages, diff --git a/trie/trie.go b/trie/trie.go index 08f19cca2e..81958e1493 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -202,6 +202,22 @@ func (t *Trie) Get(key []byte) ([]byte, error) { return value, err } +// GetWithMeter returns the value for key stored in the trie. The meter +// callback is invoked once per decoded node visited, with the current node +// count (1-based). If meter returns an error, traversal stops and the error +// is returned. +func (t *Trie) GetWithMeter(key []byte, meter func(uint64) error) ([]byte, error) { + // Short circuit if the trie is already committed and not usable. + if t.committed { + return nil, ErrCommitted + } + value, newroot, didResolve, _, err := t.getWithMeter(t.root, keybytesToHex(key), 0, 0, meter) + if err == nil && didResolve { + t.root = newroot + } + return value, err +} + func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode node, didResolve bool, err error) { switch n := (origNode).(type) { case nil: @@ -214,14 +230,14 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode no return nil, n, false, nil } - value, newnode, didResolve, err = t.get(n.Val, key, pos+len(n.Key)) + value, newnode, didResolve, err := t.get(n.Val, key, pos+len(n.Key)) if err == nil && didResolve { n.Val = newnode } return value, n, didResolve, err case *fullNode: - value, newnode, didResolve, err = t.get(n.Children[key[pos]], key, pos+1) + value, newnode, didResolve, err := t.get(n.Children[key[pos]], key, pos+1) if err == nil && didResolve { n.Children[key[pos]] = newnode } @@ -241,6 +257,65 @@ func (t *Trie) get(origNode node, key []byte, pos int) (value []byte, newnode no } } +// getWithMeter retrieves the value for a given key from the trie while tracking +// the number of nodes visited using the provided meter function. +func (t *Trie) getWithMeter(origNode node, key []byte, pos int, count uint64, meter func(uint64) error) (value []byte, newnode node, didResolve bool, newCount uint64, err error) { + switch n := (origNode).(type) { + case nil: + return nil, nil, false, count, nil + case valueNode: + count++ + if meter != nil { + if err := meter(count); err != nil { + return nil, n, false, count, err + } + } + return n, n, false, count, nil + case *shortNode: + count++ + if meter != nil { + if err := meter(count); err != nil { + return nil, n, false, count, err + } + } + if !bytes.HasPrefix(key[pos:], n.Key) { + // key not found in trie + return nil, n, false, count, nil + } + + value, newnode, didResolve, count, err = t.getWithMeter(n.Val, key, pos+len(n.Key), count, meter) + if err == nil && didResolve { + n.Val = newnode + } + + return value, n, didResolve, count, err + case *fullNode: + count++ + if meter != nil { + if err := meter(count); err != nil { + return nil, n, false, count, err + } + } + value, newnode, didResolve, count, err = t.getWithMeter(n.Children[key[pos]], key, pos+1, count, meter) + if err == nil && didResolve { + n.Children[key[pos]] = newnode + } + + return value, n, didResolve, count, err + case hashNode: + child, err := t.resolveAndTrack(n, key[:pos]) + if err != nil { + return nil, n, true, count, err + } + + value, newnode, _, count, err := t.getWithMeter(child, key, pos, count, meter) + + return value, newnode, true, count, err + default: + panic(fmt.Sprintf("%T: invalid node: %v", origNode, origNode)) + } +} + // Prefetch attempts to resolve the leaves and intermediate trie nodes // specified by the key list in parallel. The results are silently // discarded to simplify the function. diff --git a/trie/trie_test.go b/trie/trie_test.go index f3f30992f8..e1ffa1c669 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -1569,7 +1569,7 @@ func testTrieCopyNewTrie(t *testing.T, entries []kv) { // Traverse the original tree, the changes made on the copy one shouldn't // affect the old one for _, entry := range entries { - d, _ := trCpy.Get(entry.k) + d, _, _ := trCpy.Get(entry.k) if !bytes.Equal(d, entry.v) { t.Errorf("Unexpected data, key: %v, want: %v, got: %v", entry.k, entry.v, d) } diff --git a/trie/verkle.go b/trie/verkle.go index bf36aad70b..2474c39eaf 100644 --- a/trie/verkle.go +++ b/trie/verkle.go @@ -132,6 +132,13 @@ func (t *VerkleTrie) GetStorage(addr common.Address, key []byte) ([]byte, error) return common.TrimLeftZeroes(val), nil } +// GetStorageWithMeter retrieves a storage slot and invokes the meter for each +// visited node if supported. Depth is not defined for Verkle tries. +func (t *VerkleTrie) GetStorageWithMeter(addr common.Address, key []byte, meter func(uint64) error) ([]byte, error) { + // No trie-node depth semantics for Verkle; ignore meter. + return t.GetStorage(addr, key) +} + // PrefetchStorage attempts to resolve specific storage slots from the database // to accelerate subsequent trie operations. func (t *VerkleTrie) PrefetchStorage(addr common.Address, keys [][]byte) error {