diff --git a/pkg/mmr/mmr.go b/pkg/mmr/mmr.go index 74f1f627011..f55c2092992 100644 --- a/pkg/mmr/mmr.go +++ b/pkg/mmr/mmr.go @@ -29,17 +29,17 @@ type MMRNode struct { } type MMR struct { - size uint64 - batch *MMRBatch - hasher hash.Hash - mtx sync.Mutex + size uint64 + storage MMRStorage + hasher hash.Hash + mtx sync.Mutex } -func NewMMR(size uint64, batch *MMRBatch, hasher hash.Hash) *MMR { +func NewMMR(size uint64, storage MMRStorage, hasher hash.Hash) *MMR { return &MMR{ - size: size, - batch: batch, - hasher: hasher, + size: size, + storage: storage, + hasher: hasher, } } @@ -72,7 +72,7 @@ func (mmr *MMR) Push(leaf MMRElement) (uint64, error) { elements = append(elements, parentElement) } - mmr.batch.append(elemPosition, elements) + mmr.storage.append(elemPosition, elements) mmr.size = position + 1 return position, nil } @@ -83,7 +83,7 @@ func (mmr *MMR) Root() (MMRElement, error) { if mmr.size == 0 { return nil, errorGetRootOnEmpty } else if mmr.size == 1 { - root, err := mmr.batch.getElement(0) + root, err := mmr.storage.getElement(0) if err != nil || root == nil { return nil, errorInconsistentStore } @@ -94,7 +94,7 @@ func (mmr *MMR) Root() (MMRElement, error) { peaks := make([]MMRElement, 0) for _, pos := range peaksPosition { - peak, err := mmr.batch.getElement(pos) + peak, err := mmr.storage.getElement(pos) if err != nil || peak == nil { return nil, errorInconsistentStore } @@ -105,7 +105,7 @@ func (mmr *MMR) Root() (MMRElement, error) { } func (mmr *MMR) Commit() error { - return mmr.batch.commit() + return mmr.storage.commit() } func (mmr *MMR) findElement(position uint64, values []MMRElement) (MMRElement, error) { @@ -114,7 +114,7 @@ func (mmr *MMR) findElement(position uint64, values []MMRElement) (MMRElement, e return values[positionOffset], nil } - value, err := mmr.batch.getElement(position) + value, err := mmr.storage.getElement(position) if err != nil || value == nil { return nil, errorInconsistentStore } diff --git a/pkg/mmr/mmr_batch.go b/pkg/mmr/mmr_batch.go index ece2cdc81cf..0a6cc34a7fd 100644 --- a/pkg/mmr/mmr_batch.go +++ b/pkg/mmr/mmr_batch.go @@ -7,7 +7,8 @@ import "slices" type MMRStorage interface { getElement(pos uint64) (*MMRElement, error) - append(pos uint64, items []MMRElement) error + append(pos uint64, elements []MMRElement) error + commit() error } type MMRBatch struct { @@ -22,15 +23,16 @@ func NewMMRBatch(storage MMRStorage) *MMRBatch { } } -func (b *MMRBatch) append(pos uint64, elements []MMRElement) { +func (b *MMRBatch) append(pos uint64, elements []MMRElement) error { b.nodes = append(b.nodes, MMRNode{ pos: pos, elements: elements, }) + return nil } func (b *MMRBatch) getElement(pos uint64) (*MMRElement, error) { - revNodes := b.nodes[:] + revNodes := b.nodes slices.Reverse(revNodes) for _, node := range revNodes { if pos < node.pos { @@ -64,3 +66,5 @@ func (b *MMRBatch) drain() []MMRNode { return nodes } + +var _ MMRStorage = (*MMRBatch)(nil) diff --git a/pkg/mmr/utils.go b/pkg/mmr/utils.go index e3c43cc0c9a..79e4cb88399 100644 --- a/pkg/mmr/utils.go +++ b/pkg/mmr/utils.go @@ -33,6 +33,12 @@ func (s *MemStorage) append(pos uint64, elements []MMRElement) error { return nil } +func (s *MemStorage) commit() error { + // Do nothing since all changes are automatically commited + return nil +} + func NewInMemMMR(hasher hash.Hash) *MMR { - return NewMMR(0, NewMMRBatch(NewMemStorage()), hasher) + storage := NewMemStorage() + return NewMMR(0, storage, hasher) }