Skip to content

Commit

Permalink
[Enhancement] Add features related to the usage of the ClosestProof
Browse files Browse the repository at this point in the history
… method (#43)
  • Loading branch information
h5law authored Mar 21, 2024
1 parent 8682379 commit e3dbbbd
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 5 deletions.
3 changes: 3 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ var (
ErrBadProof = errors.New("bad proof")
// ErrKeyNotFound is returned when a key is not found in the tree.
ErrKeyNotFound = errors.New("key not found")
// ErrInvalidClosestPath is returned when the path used in the ClosestProof
// method does not match the size of the trie's PathHasher
ErrInvalidClosestPath = errors.New("invalid path does not match path hasher size")
)
11 changes: 11 additions & 0 deletions hasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ type PathHasher interface {
type ValueHasher interface {
// HashValue hashes value data to produce the digest stored in leaf node.
HashValue([]byte) []byte
// ValueHashSize returns the length (in bytes) of digests produced by this hasher.
ValueHashSize() int
}

type trieHasher struct {
Expand Down Expand Up @@ -59,10 +61,19 @@ func (ph *pathHasher) PathSize() int {
return ph.hasher.Size()
}

// HashValue hashes the producdes a digest of the data provided by the value hasher
func (vh *valueHasher) HashValue(data []byte) []byte {
return vh.digest(data)
}

// ValueHashSize returns the length (in bytes) of digests produced by the value hasher
func (vh *valueHasher) ValueHashSize() int {
if vh.hasher == nil {
return 0
}
return vh.hasher.Size()
}

func (th *trieHasher) digest(data []byte) []byte {
th.hasher.Write(data)
sum := th.hasher.Sum(nil)
Expand Down
58 changes: 53 additions & 5 deletions proofs.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ func (proof *SparseMerkleProof) validateBasic(spec *TrieSpec) error {
// Check that leaf data for non-membership proofs is a valid size.
lps := len(leafPrefix) + spec.ph.PathSize()
if proof.NonMembershipLeafData != nil && len(proof.NonMembershipLeafData) < lps {
return fmt.Errorf("invalid non-membership leaf data size: got %d but min is %d", len(proof.NonMembershipLeafData), lps)
return fmt.Errorf(
"invalid non-membership leaf data size: got %d but min is %d",
len(proof.NonMembershipLeafData),
lps,
)
}

// Check that all supplied sidenodes are the correct size.
Expand Down Expand Up @@ -133,7 +137,11 @@ func (proof *SparseCompactMerkleProof) validateBasic(spec *TrieSpec) error {

// Compact proofs: check that NumSideNodes is within the right range.
if proof.NumSideNodes < 0 || proof.NumSideNodes > spec.ph.PathSize()*8 {
return fmt.Errorf("invalid number of side nodes: got %d, min is 0 and max is %d", len(proof.SideNodes), spec.ph.PathSize()*8)
return fmt.Errorf(
"invalid number of side nodes: got %d, min is 0 and max is %d",
len(proof.SideNodes),
spec.ph.PathSize()*8,
)
}

// Compact proofs: check that the length of the bit mask is as expected
Expand Down Expand Up @@ -185,7 +193,24 @@ func (proof *SparseMerkleClosestProof) Unmarshal(bz []byte) error {
return dec.Decode(proof)
}

// GetValueHash returns the value hash of the closest proof.
func (proof *SparseMerkleClosestProof) GetValueHash(spec *TrieSpec) []byte {
if proof.ClosestValueHash == nil {
return nil
}
if spec.sumTrie {
return proof.ClosestValueHash[:len(proof.ClosestValueHash)-sumSize]
}
return proof.ClosestValueHash
}

func (proof *SparseMerkleClosestProof) validateBasic(spec *TrieSpec) error {
// ensure the proof length is the same size (in bytes) as the path
// hasher of the spec provided
if len(proof.Path) != spec.PathHasherSize() {
return fmt.Errorf("invalid path length: got %d, want %d", len(proof.Path), spec.PathHasherSize())
}

// ensure the depth of the leaf node being proven is within the path size
if proof.Depth < 0 || proof.Depth > spec.ph.PathSize()*8 {
return fmt.Errorf("invalid depth: got %d, outside of [0, %d]", proof.Depth, spec.ph.PathSize()*8)
Expand Down Expand Up @@ -231,6 +256,12 @@ type SparseCompactMerkleClosestProof struct {
}

func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) error {
// Ensure the proof length is the same size (in bytes) as the path
// hasher of the spec provided
if len(proof.Path) != spec.PathHasherSize() {
return fmt.Errorf("invalid path length: got %d, want %d", len(proof.Path), spec.PathHasherSize())
}

// Do a basic sanity check on the proof on the fields of the proof specific to
// the compact proof only.
//
Expand All @@ -246,7 +277,12 @@ func (proof *SparseCompactMerkleClosestProof) validateBasic(spec *TrieSpec) erro
}
for i, b := range proof.FlippedBits {
if len(b) > maxSliceLen {
return fmt.Errorf("invalid compressed flipped bit index %d: got length %d, max is %d]", i, bytesToInt(b), maxSliceLen)
return fmt.Errorf(
"invalid compressed flipped bit index %d: got length %d, max is %d]",
i,
bytesToInt(b),
maxSliceLen,
)
}
}
// perform a sanity check on the closest proof
Expand Down Expand Up @@ -320,7 +356,13 @@ func VerifyClosestProof(proof *SparseMerkleClosestProof, root []byte, spec *Trie
return VerifySumProof(proof.ClosestProof, root, proof.ClosestPath, valueHash, sum, spec)
}

func verifyProofWithUpdates(proof *SparseMerkleProof, root []byte, key []byte, value []byte, spec *TrieSpec) (bool, [][][]byte, error) {
func verifyProofWithUpdates(
proof *SparseMerkleProof,
root []byte,
key []byte,
value []byte,
spec *TrieSpec,
) (bool, [][][]byte, error) {
path := spec.ph.Path(key)

if err := proof.validateBasic(spec); err != nil {
Expand Down Expand Up @@ -384,7 +426,13 @@ func VerifyCompactProof(proof *SparseCompactMerkleProof, root []byte, key, value
}

// VerifyCompactSumProof is similar to VerifySumProof but for a compacted Merkle proof.
func VerifyCompactSumProof(proof *SparseCompactMerkleProof, root []byte, key, value []byte, sum uint64, spec *TrieSpec) (bool, error) {
func VerifyCompactSumProof(
proof *SparseCompactMerkleProof,
root []byte,
key, value []byte,
sum uint64,
spec *TrieSpec,
) (bool, error) {
decompactedProof, err := DecompactProof(proof, spec)
if err != nil {
return false, errors.Join(ErrBadProof, err)
Expand Down
5 changes: 5 additions & 0 deletions smt.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,11 @@ func (smt *SMT) ProveClosest(path []byte) (
proof *SparseMerkleClosestProof, // proof of the key-value pair found
err error, // the error value encountered
) {
// Ensure the path provided is the correct length for the path hasher.
if len(path) != smt.Spec().PathHasherSize() {
return nil, ErrInvalidClosestPath
}

workingPath := make([]byte, len(path))
copy(workingPath, path)
var siblings []trieNode
Expand Down
12 changes: 12 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,18 @@ func newTrieSpec(hasher hash.Hash, sumTrie bool) TrieSpec {
// Spec returns the TrieSpec associated with the given trie
func (spec *TrieSpec) Spec() *TrieSpec { return spec }

// PathHasherSize returns the length (in bytes) of digests produced by the
// path hasher
func (spec *TrieSpec) PathHasherSize() int { return spec.ph.PathSize() }

// ValueHasherSize returns the length (in bytes) of digests produced by the
// value hasher
func (spec *TrieSpec) ValueHasherSize() int { return spec.vh.ValueHashSize() }

// TrieHasherSize returns the length (in bytes) of digests produced by the
// trie hasher
func (spec *TrieSpec) TrieHasherSize() int { return spec.th.hashSize() }

func (spec *TrieSpec) depth() int { return spec.ph.PathSize() * 8 }
func (spec *TrieSpec) digestValue(data []byte) []byte {
if spec.vh == nil {
Expand Down

0 comments on commit e3dbbbd

Please sign in to comment.