diff --git a/pkg/mmr/memstorage.go b/pkg/mmr/memstorage.go index f2eee411d47..5dd3fc0ffb3 100644 --- a/pkg/mmr/memstorage.go +++ b/pkg/mmr/memstorage.go @@ -8,32 +8,34 @@ import ( ) // MemStorage provides an in-memory storage mechanism for an MMR. -type MemStorage struct { - storage *btree.Map[uint64, MMRElement] +type MemStorage[T any] struct { + storage *btree.Map[uint64, T] } // NewMemStorage initialises a new instance of MemStorage with an empty storage. -func NewMemStorage() *MemStorage { - return &MemStorage{ - storage: btree.NewMap[uint64, MMRElement](0), +func NewMemStorage[T any]() *MemStorage[T] { + return &MemStorage[T]{ + storage: btree.NewMap[uint64, T](0), } } -func (s *MemStorage) getElement(pos uint64) (*MMRElement, error) { +//nolint:unparam +func (s *MemStorage[T]) getElement(pos uint64) (*T, error) { if element, ok := s.storage.Get(pos); ok { return &element, nil } return nil, nil } -func (s *MemStorage) append(pos uint64, elements []MMRElement) error { +func (s *MemStorage[T]) append(pos uint64, elements []T) error { for i, element := range elements { s.storage.Set(pos+uint64(i), element) } return nil } -func (s *MemStorage) commit() error { +//nolint:unused +func (s *MemStorage[T]) commit() error { // Do nothing since all changes are automatically committed return nil } diff --git a/pkg/mmr/memstorage_test.go b/pkg/mmr/memstorage_test.go index 0a0795065a2..4b5c45d7345 100644 --- a/pkg/mmr/memstorage_test.go +++ b/pkg/mmr/memstorage_test.go @@ -10,8 +10,10 @@ import ( "github.com/stretchr/testify/assert" ) +type MMRElement []byte + func TestGetElement(t *testing.T) { - memStorage := NewMemStorage() + memStorage := NewMemStorage[MMRElement]() elements := make(map[uint64]MMRElement) for i := uint64(1); i < 100; i++ { @@ -30,7 +32,7 @@ func TestGetElement(t *testing.T) { } func TestGetNotFoundElement(t *testing.T) { - memStorage := NewMemStorage() + memStorage := NewMemStorage[MMRElement]() element, err := memStorage.getElement(100) assert.NoError(t, err) diff --git a/pkg/mmr/mmr.go b/pkg/mmr/mmr.go index 387a45416af..f62cd772144 100644 --- a/pkg/mmr/mmr.go +++ b/pkg/mmr/mmr.go @@ -4,47 +4,44 @@ package mmr import ( "errors" - "hash" "math/bits" - "sync" ) var ( errorInconsistentStore = errors.New("inconsistent store") errorGetRootOnEmpty = errors.New("get root on empty MMR") + errorNotEnoughPeeks = errors.New("not enough peaks") ) -// MMRElement is an alias to easily change the MMR element type in case we need it -type MMRElement []byte - -type MMRStorage interface { - getElement(pos uint64) (*MMRElement, error) - append(pos uint64, elements []MMRElement) error +type MMRStorage[T any] interface { + getElement(pos uint64) (*T, error) + append(pos uint64, elements []T) error commit() error } +type MergeFunc[T any] func(left, right T) (*T, error) + // MMR represents a Merkle Mountain Range (MMR) which is a persistent, // append-only data structure that allows for efficient cryptographic proofs of // inclusion for any piece of data added to it. -type MMR struct { +type MMR[T any] struct { size uint64 - storage MMRStorage - hasher hash.Hash - mtx sync.Mutex + storage MMRStorage[T] + merge MergeFunc[T] } // NewMMR initialises and returns a new MMR instance. -func NewMMR(size uint64, storage MMRStorage, hasher hash.Hash) *MMR { - return &MMR{ +func NewMMR[T any](size uint64, storage MMRStorage[T], merger MergeFunc[T]) *MMR[T] { + return &MMR[T]{ size: size, storage: storage, - hasher: hasher, + merge: merger, } } // Push adds a new leaf to the MMR returning its position. -func (mmr *MMR) Push(leaf MMRElement) (uint64, error) { - elements := []MMRElement{leaf} +func (mmr *MMR[T]) Push(leaf T) (uint64, error) { + elements := []T{leaf} peakMap := mmr.peakMap() elemPosition := mmr.size position := mmr.size @@ -62,13 +59,16 @@ func (mmr *MMR) Push(leaf MMRElement) (uint64, error) { rightElement := elements[len(elements)-1] - parentElement := mmr.merge(leftElement, rightElement) + parentElement, err := mmr.merge(*leftElement, rightElement) + if err != nil { + return 0, err + } if err != nil { return 0, err } - elements = append(elements, parentElement) + elements = append(elements, *parentElement) } err := mmr.storage.append(elemPosition, elements) @@ -81,7 +81,7 @@ func (mmr *MMR) Push(leaf MMRElement) (uint64, error) { } // Root returns the root of the MMR by merging the peaks. -func (mmr *MMR) Root() (MMRElement, error) { +func (mmr *MMR[T]) Root() (*T, error) { if mmr.size == 0 { return nil, errorGetRootOnEmpty } else if mmr.size == 1 { @@ -89,11 +89,11 @@ func (mmr *MMR) Root() (MMRElement, error) { if err != nil || root == nil { return nil, errorInconsistentStore } - return *root, nil + return root, nil } peaksPosition := mmr.getPeaks() - peaks := make([]MMRElement, 0) + peaks := make([]T, 0) for _, pos := range peaksPosition { peak, err := mmr.storage.getElement(pos) @@ -103,18 +103,18 @@ func (mmr *MMR) Root() (MMRElement, error) { peaks = append(peaks, *peak) } - return mmr.bagPeaks(peaks), nil + return mmr.bagPeaks(peaks) } // Commit commits the current state of the MMR to underlying storage. -func (mmr *MMR) Commit() error { +func (mmr *MMR[T]) Commit() error { return mmr.storage.commit() } -func (mmr *MMR) findElement(position uint64, values []MMRElement) (MMRElement, error) { +func (mmr *MMR[T]) findElement(position uint64, values []T) (*T, error) { if position > mmr.size { positionOffset := position - mmr.size - return values[positionOffset], nil + return &values[positionOffset], nil } value, err := mmr.storage.getElement(position) @@ -122,25 +122,14 @@ func (mmr *MMR) findElement(position uint64, values []MMRElement) (MMRElement, e return nil, errorInconsistentStore } - return *value, nil -} - -func (mmr *MMR) merge(left, right MMRElement) MMRElement { - // Since we could share mmr.hash instance in multiple goroutines - defer mmr.mtx.Unlock() - mmr.mtx.Lock() - - mmr.hasher.Reset() - mmr.hasher.Write(left) - mmr.hasher.Write(right) - return mmr.hasher.Sum(nil) + return value, nil } /* Returns a bitmap of the peaks in the MMR. Eg: 0b11 means that the MMR has 2 peaks at position 0 and at position 1 */ -func (mmr *MMR) peakMap() uint64 { +func (mmr *MMR[T]) peakMap() uint64 { if mmr.size == 0 { return 0 } @@ -164,7 +153,7 @@ func (mmr *MMR) peakMap() uint64 { /* getPeaks() the positions of the peaks in the MMR. */ -func (mmr *MMR) getPeaks() []uint64 { +func (mmr *MMR[T]) getPeaks() []uint64 { if mmr.size == 0 { return []uint64{} } @@ -185,21 +174,24 @@ func (mmr *MMR) getPeaks() []uint64 { return peaks } -func (mmr *MMR) bagPeaks(peaks []MMRElement) MMRElement { +func (mmr *MMR[T]) bagPeaks(peaks []T) (*T, error) { for len(peaks) > 1 { - var rightPeak, leftPeak MMRElement + var rightPeak, leftPeak T rightPeak, peaks = peaks[len(peaks)-1], peaks[:len(peaks)-1] leftPeak, peaks = peaks[len(peaks)-1], peaks[:len(peaks)-1] - mergedPeak := mmr.merge(rightPeak, leftPeak) - peaks = append(peaks, mergedPeak) + mergedPeak, err := mmr.merge(rightPeak, leftPeak) + if err != nil { + return nil, err + } + peaks = append(peaks, *mergedPeak) } if len(peaks) < 1 { - return nil + return nil, errorNotEnoughPeeks } // #nosec G602 - return peaks[0] + return &peaks[0], nil } diff --git a/pkg/mmr/mmr_test.go b/pkg/mmr/mmr_test.go index 7cdae719fed..876b1c935f4 100644 --- a/pkg/mmr/mmr_test.go +++ b/pkg/mmr/mmr_test.go @@ -5,18 +5,32 @@ package mmr import ( "encoding/binary" - "hash" "testing" "github.com/stretchr/testify/assert" "golang.org/x/crypto/blake2b" ) -func newInMemMMR(hasher hash.Hash) *MMR { - return NewMMR(0, NewMemStorage(), hasher) +type HashedNumber []byte + +func mergeHashedNumbers(left, right HashedNumber) (*HashedNumber, error) { + hasher, err := blake2b.New256(nil) + if err != nil { + return nil, err + } + hasher.Write(left[:]) + hasher.Write(right[:]) + + res := HashedNumber(hasher.Sum(nil)) + return &res, nil +} + +func newInMemMMR[T any](merger MergeFunc[T]) *MMR[T] { + var storage MMRStorage[T] = NewMemStorage[T]() + return NewMMR(0, storage, merger) } -func hashNumber(number uint32) MMRElement { +func hashNumber(number uint32) HashedNumber { hasher, _ := blake2b.New256(nil) var numBytes [4]byte binary.LittleEndian.PutUint32(numBytes[:], number) @@ -28,27 +42,21 @@ func hashNumber(number uint32) MMRElement { } func TestPushOneElement_RootShouldBeSameLeaf(t *testing.T) { - hasher, err := blake2b.New256(nil) - assert.NoError(t, err) - - inMemMMR := newInMemMMR(hasher) + inMemMMR := newInMemMMR(mergeHashedNumbers) leaf := hashNumber(0) - _, err = inMemMMR.Push(leaf) + _, err := inMemMMR.Push(leaf) assert.NoError(t, err) root, err := inMemMMR.Root() assert.NoError(t, err) - assert.Equal(t, root, leaf) + assert.Equal(t, *root, leaf) } // Compared with the same MMR using substrate's implementation func TestPushManyElementsGetRootOk(t *testing.T) { - hasher, err := blake2b.New256(nil) - assert.NoError(t, err) - - inMemMMR := newInMemMMR(hasher) + inMemMMR := newInMemMMR(mergeHashedNumbers) for i := uint32(0); i < 100; i++ { leaf := hashNumber(i) @@ -63,5 +71,5 @@ func TestPushManyElementsGetRootOk(t *testing.T) { 0x5, 0x0, 0xd0, 0xeb, 0xdb, 0xca, 0xd3, 0x6a, 0x79, 0xd3, 0x32, 0x5d, 0xbd, 0x2a, 0x4b, 0x2b, 0x97, 0x30, 0x1d, 0x8e, 0x48, 0x2a, 0x9b, 0xe2, 0x2, 0x1, 0x6e, 0x9f, 0x1c, 0xaa, 0xe1, 0x3f, - }, []byte(root)) + }, []byte(*root)) }