Skip to content

Commit

Permalink
Introduce and use database.WithDefault (#3478)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Oct 17, 2024
1 parent abf76fd commit d7c9423
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 33 deletions.
15 changes: 15 additions & 0 deletions database/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,21 @@ func GetBool(db KeyValueReader, key []byte) (bool, error) {
return b[0] == BoolTrue, nil
}

// WithDefault returns the value at [key] in [db]. If the key doesn't exist, it
// returns [def].
func WithDefault[V any](
get func(KeyValueReader, []byte) (V, error),
db KeyValueReader,
key []byte,
def V,
) (V, error) {
v, err := get(db, key)
if err == ErrNotFound {
return def, nil
}
return v, err
}

func Count(db Iteratee) (int, error) {
iterator := db.NewIterator()
defer iterator.Release()
Expand Down
26 changes: 25 additions & 1 deletion database/helpers_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package database
package database_test

import (
"math/rand"
Expand All @@ -11,7 +11,10 @@ import (

"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/database/memdb"
"github.com/ava-labs/avalanchego/utils"

. "github.com/ava-labs/avalanchego/database"
)

func TestSortednessUint64(t *testing.T) {
Expand Down Expand Up @@ -49,3 +52,24 @@ func TestSortednessUint32(t *testing.T) {
}
require.True(t, utils.IsSortedBytes(intBytes))
}

func TestOrDefault(t *testing.T) {
require := require.New(t)

var (
db = memdb.New()
key = utils.RandomBytes(32)
)

// Key doesn't exist
v, err := WithDefault(GetUInt64, db, key, 1)
require.NoError(err)
require.Equal(uint64(1), v)

require.NoError(PutUInt64(db, key, 2))

// Key does exist
v, err = WithDefault(GetUInt64, db, key, 1)
require.NoError(err)
require.Equal(uint64(2), v)
}
15 changes: 7 additions & 8 deletions indexer/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,16 @@ func newIndex(
}

// Get next accepted index from db
nextAcceptedIndex, err := database.GetUInt64(i.vDB, nextAcceptedIndexKey)
if err == database.ErrNotFound {
// Couldn't find it in the database. Must not have accepted any containers in previous runs.
i.log.Info("created new index",
zap.Uint64("nextAcceptedIndex", i.nextAcceptedIndex),
)
return i, nil
}
nextAcceptedIndex, err := database.WithDefault(
database.GetUInt64,
i.vDB,
nextAcceptedIndexKey,
0,
)
if err != nil {
return nil, fmt.Errorf("couldn't get next accepted index from database: %w", err)
}

i.nextAcceptedIndex = nextAcceptedIndex
i.log.Info("created new index",
zap.Uint64("nextAcceptedIndex", i.nextAcceptedIndex),
Expand Down
18 changes: 3 additions & 15 deletions vms/example/xsvm/state/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,7 @@ func AddBlock(db database.KeyValueWriter, height uint64, blkID ids.ID, blk []byt

func GetNonce(db database.KeyValueReader, address ids.ShortID) (uint64, error) {
key := Flatten(addressPrefix, address[:])
nonce, err := database.GetUInt64(db, key)
if errors.Is(err, database.ErrNotFound) {
return 0, nil
}
return nonce, err
return database.WithDefault(database.GetUInt64, db, key, 0)
}

func SetNonce(db database.KeyValueWriter, address ids.ShortID, nonce uint64) error {
Expand All @@ -102,11 +98,7 @@ func IncrementNonce(db database.KeyValueReaderWriter, address ids.ShortID, nonce

func GetBalance(db database.KeyValueReader, address ids.ShortID, chainID ids.ID) (uint64, error) {
key := Flatten(addressPrefix, address[:], chainID[:])
balance, err := database.GetUInt64(db, key)
if errors.Is(err, database.ErrNotFound) {
return 0, nil
}
return balance, err
return database.WithDefault(database.GetUInt64, db, key, 0)
}

func SetBalance(db database.KeyValueWriterDeleter, address ids.ShortID, chainID ids.ID, balance uint64) error {
Expand Down Expand Up @@ -154,11 +146,7 @@ func AddLoanID(db database.KeyValueWriter, chainID ids.ID, loanID ids.ID) error

func GetLoan(db database.KeyValueReader, chainID ids.ID) (uint64, error) {
key := Flatten(chainPrefix, chainID[:])
balance, err := database.GetUInt64(db, key)
if errors.Is(err, database.ErrNotFound) {
return 0, nil
}
return balance, err
return database.WithDefault(database.GetUInt64, db, key, 0)
}

func SetLoan(db database.KeyValueWriterDeleter, chainID ids.ID, balance uint64) error {
Expand Down
10 changes: 1 addition & 9 deletions vms/platformvm/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ func (s *state) loadMetadata() error {
s.persistedFeeState = feeState
s.SetFeeState(feeState)

accruedFees, err := getAccruedFees(s.singletonDB)
accruedFees, err := database.WithDefault(database.GetUInt64, s.singletonDB, AccruedFeesKey, 0)
if err != nil {
return err
}
Expand Down Expand Up @@ -2665,11 +2665,3 @@ func getFeeState(db database.KeyValueReader) (gas.State, error) {
}
return feeState, nil
}

func getAccruedFees(db database.KeyValueReader) (uint64, error) {
accruedFees, err := database.GetUInt64(db, AccruedFeesKey)
if err == database.ErrNotFound {
return 0, nil
}
return accruedFees, err
}

0 comments on commit d7c9423

Please sign in to comment.