Skip to content

Commit

Permalink
Make leaves generic
Browse files Browse the repository at this point in the history
  • Loading branch information
dimartiro committed Sep 6, 2024
1 parent d2044a9 commit 5d902bc
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 71 deletions.
18 changes: 10 additions & 8 deletions pkg/mmr/memstorage.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,34 @@ import (
)

// MemStorage provides an in-memory storage mechanism for an MMR.
type MemStorage struct {
storage *btree.Map[uint64, MMRElement]
type MemStorage[T any] struct {
storage *btree.Map[uint64, T]
}

// NewMemStorage initialises a new instance of MemStorage with an empty storage.
func NewMemStorage() *MemStorage {
return &MemStorage{
storage: btree.NewMap[uint64, MMRElement](0),
func NewMemStorage[T any]() *MemStorage[T] {
return &MemStorage[T]{
storage: btree.NewMap[uint64, T](0),
}
}

func (s *MemStorage) getElement(pos uint64) (*MMRElement, error) {
//nolint:unparam
func (s *MemStorage[T]) getElement(pos uint64) (*T, error) {
if element, ok := s.storage.Get(pos); ok {
return &element, nil
}
return nil, nil
}

func (s *MemStorage) append(pos uint64, elements []MMRElement) error {
func (s *MemStorage[T]) append(pos uint64, elements []T) error {
for i, element := range elements {
s.storage.Set(pos+uint64(i), element)
}
return nil
}

func (s *MemStorage) commit() error {
//nolint:unused
func (s *MemStorage[T]) commit() error {
// Do nothing since all changes are automatically committed
return nil
}
6 changes: 4 additions & 2 deletions pkg/mmr/memstorage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"github.com/stretchr/testify/assert"
)

type MMRElement []byte

func TestGetElement(t *testing.T) {
memStorage := NewMemStorage()
memStorage := NewMemStorage[MMRElement]()
elements := make(map[uint64]MMRElement)

for i := uint64(1); i < 100; i++ {
Expand All @@ -30,7 +32,7 @@ func TestGetElement(t *testing.T) {
}

func TestGetNotFoundElement(t *testing.T) {
memStorage := NewMemStorage()
memStorage := NewMemStorage[MMRElement]()

element, err := memStorage.getElement(100)
assert.NoError(t, err)
Expand Down
84 changes: 38 additions & 46 deletions pkg/mmr/mmr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,44 @@ package mmr

import (
"errors"
"hash"
"math/bits"
"sync"
)

var (
errorInconsistentStore = errors.New("inconsistent store")
errorGetRootOnEmpty = errors.New("get root on empty MMR")
errorNotEnoughPeeks = errors.New("not enough peaks")
)

// MMRElement is an alias to easily change the MMR element type in case we need it
type MMRElement []byte

type MMRStorage interface {
getElement(pos uint64) (*MMRElement, error)
append(pos uint64, elements []MMRElement) error
type MMRStorage[T any] interface {
getElement(pos uint64) (*T, error)
append(pos uint64, elements []T) error
commit() error
}

type MergeFunc[T any] func(left, right T) (*T, error)

// MMR represents a Merkle Mountain Range (MMR) which is a persistent,
// append-only data structure that allows for efficient cryptographic proofs of
// inclusion for any piece of data added to it.
type MMR struct {
type MMR[T any] struct {
size uint64
storage MMRStorage
hasher hash.Hash
mtx sync.Mutex
storage MMRStorage[T]
merge MergeFunc[T]
}

// NewMMR initialises and returns a new MMR instance.
func NewMMR(size uint64, storage MMRStorage, hasher hash.Hash) *MMR {
return &MMR{
func NewMMR[T any](size uint64, storage MMRStorage[T], merger MergeFunc[T]) *MMR[T] {
return &MMR[T]{
size: size,
storage: storage,
hasher: hasher,
merge: merger,
}
}

// Push adds a new leaf to the MMR returning its position.
func (mmr *MMR) Push(leaf MMRElement) (uint64, error) {
elements := []MMRElement{leaf}
func (mmr *MMR[T]) Push(leaf T) (uint64, error) {
elements := []T{leaf}
peakMap := mmr.peakMap()
elemPosition := mmr.size
position := mmr.size
Expand All @@ -62,13 +59,16 @@ func (mmr *MMR) Push(leaf MMRElement) (uint64, error) {

rightElement := elements[len(elements)-1]

parentElement := mmr.merge(leftElement, rightElement)
parentElement, err := mmr.merge(*leftElement, rightElement)
if err != nil {
return 0, err
}

if err != nil {
return 0, err
}

elements = append(elements, parentElement)
elements = append(elements, *parentElement)
}

err := mmr.storage.append(elemPosition, elements)
Expand All @@ -81,19 +81,19 @@ func (mmr *MMR) Push(leaf MMRElement) (uint64, error) {
}

// Root returns the root of the MMR by merging the peaks.
func (mmr *MMR) Root() (MMRElement, error) {
func (mmr *MMR[T]) Root() (*T, error) {
if mmr.size == 0 {
return nil, errorGetRootOnEmpty
} else if mmr.size == 1 {
root, err := mmr.storage.getElement(0)
if err != nil || root == nil {
return nil, errorInconsistentStore
}
return *root, nil
return root, nil
}

peaksPosition := mmr.getPeaks()
peaks := make([]MMRElement, 0)
peaks := make([]T, 0)

for _, pos := range peaksPosition {
peak, err := mmr.storage.getElement(pos)
Expand All @@ -103,44 +103,33 @@ func (mmr *MMR) Root() (MMRElement, error) {
peaks = append(peaks, *peak)
}

return mmr.bagPeaks(peaks), nil
return mmr.bagPeaks(peaks)
}

// Commit commits the current state of the MMR to underlying storage.
func (mmr *MMR) Commit() error {
func (mmr *MMR[T]) Commit() error {
return mmr.storage.commit()
}

func (mmr *MMR) findElement(position uint64, values []MMRElement) (MMRElement, error) {
func (mmr *MMR[T]) findElement(position uint64, values []T) (*T, error) {
if position > mmr.size {
positionOffset := position - mmr.size
return values[positionOffset], nil
return &values[positionOffset], nil
}

value, err := mmr.storage.getElement(position)
if err != nil || value == nil {
return nil, errorInconsistentStore
}

return *value, nil
}

func (mmr *MMR) merge(left, right MMRElement) MMRElement {
// Since we could share mmr.hash instance in multiple goroutines
defer mmr.mtx.Unlock()
mmr.mtx.Lock()

mmr.hasher.Reset()
mmr.hasher.Write(left)
mmr.hasher.Write(right)
return mmr.hasher.Sum(nil)
return value, nil
}

/*
Returns a bitmap of the peaks in the MMR.
Eg: 0b11 means that the MMR has 2 peaks at position 0 and at position 1
*/
func (mmr *MMR) peakMap() uint64 {
func (mmr *MMR[T]) peakMap() uint64 {
if mmr.size == 0 {
return 0
}
Expand All @@ -164,7 +153,7 @@ func (mmr *MMR) peakMap() uint64 {
/*
getPeaks() the positions of the peaks in the MMR.
*/
func (mmr *MMR) getPeaks() []uint64 {
func (mmr *MMR[T]) getPeaks() []uint64 {
if mmr.size == 0 {
return []uint64{}
}
Expand All @@ -185,21 +174,24 @@ func (mmr *MMR) getPeaks() []uint64 {
return peaks
}

func (mmr *MMR) bagPeaks(peaks []MMRElement) MMRElement {
func (mmr *MMR[T]) bagPeaks(peaks []T) (*T, error) {
for len(peaks) > 1 {
var rightPeak, leftPeak MMRElement
var rightPeak, leftPeak T

rightPeak, peaks = peaks[len(peaks)-1], peaks[:len(peaks)-1]
leftPeak, peaks = peaks[len(peaks)-1], peaks[:len(peaks)-1]

mergedPeak := mmr.merge(rightPeak, leftPeak)
peaks = append(peaks, mergedPeak)
mergedPeak, err := mmr.merge(rightPeak, leftPeak)
if err != nil {
return nil, err
}
peaks = append(peaks, *mergedPeak)
}

if len(peaks) < 1 {
return nil
return nil, errorNotEnoughPeeks
}

// #nosec G602
return peaks[0]
return &peaks[0], nil
}
38 changes: 23 additions & 15 deletions pkg/mmr/mmr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,32 @@ package mmr

import (
"encoding/binary"
"hash"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/crypto/blake2b"
)

func newInMemMMR(hasher hash.Hash) *MMR {
return NewMMR(0, NewMemStorage(), hasher)
type HashedNumber []byte

func mergeHashedNumbers(left, right HashedNumber) (*HashedNumber, error) {
hasher, err := blake2b.New256(nil)
if err != nil {
return nil, err
}
hasher.Write(left[:])
hasher.Write(right[:])

res := HashedNumber(hasher.Sum(nil))
return &res, nil
}

func newInMemMMR[T any](merger MergeFunc[T]) *MMR[T] {
var storage MMRStorage[T] = NewMemStorage[T]()
return NewMMR(0, storage, merger)
}

func hashNumber(number uint32) MMRElement {
func hashNumber(number uint32) HashedNumber {
hasher, _ := blake2b.New256(nil)
var numBytes [4]byte
binary.LittleEndian.PutUint32(numBytes[:], number)
Expand All @@ -28,27 +42,21 @@ func hashNumber(number uint32) MMRElement {
}

func TestPushOneElement_RootShouldBeSameLeaf(t *testing.T) {
hasher, err := blake2b.New256(nil)
assert.NoError(t, err)

inMemMMR := newInMemMMR(hasher)
inMemMMR := newInMemMMR(mergeHashedNumbers)

leaf := hashNumber(0)
_, err = inMemMMR.Push(leaf)
_, err := inMemMMR.Push(leaf)
assert.NoError(t, err)

root, err := inMemMMR.Root()
assert.NoError(t, err)

assert.Equal(t, root, leaf)
assert.Equal(t, *root, leaf)
}

// Compared with the same MMR using substrate's implementation
func TestPushManyElementsGetRootOk(t *testing.T) {
hasher, err := blake2b.New256(nil)
assert.NoError(t, err)

inMemMMR := newInMemMMR(hasher)
inMemMMR := newInMemMMR(mergeHashedNumbers)

for i := uint32(0); i < 100; i++ {
leaf := hashNumber(i)
Expand All @@ -63,5 +71,5 @@ func TestPushManyElementsGetRootOk(t *testing.T) {
0x5, 0x0, 0xd0, 0xeb, 0xdb, 0xca, 0xd3, 0x6a, 0x79, 0xd3, 0x32, 0x5d,
0xbd, 0x2a, 0x4b, 0x2b, 0x97, 0x30, 0x1d, 0x8e, 0x48, 0x2a, 0x9b, 0xe2,
0x2, 0x1, 0x6e, 0x9f, 0x1c, 0xaa, 0xe1, 0x3f,
}, []byte(root))
}, []byte(*root))
}

0 comments on commit 5d902bc

Please sign in to comment.