Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 58 additions & 39 deletions internal/merkle/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"math/big"
"slices"

"github.com/davecgh/go-spew/spew"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
)
Expand Down Expand Up @@ -120,15 +121,15 @@ func (inner *InnerNode) Valid() bool {
return (isPair || isIterated) && !(isPair && isIterated) // xor
}

func (inner *InnerNode) Children() (*Tree, *Tree) {
func (inner *InnerNode) Children() (*Tree, *Tree, error) {
if !inner.Valid() {
panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner))
return nil, nil, fmt.Errorf("invalid InnerNode state: %v\n", inner)
}

if inner.Child != nil {
return inner.Child, inner.Child
return inner.Child, inner.Child, nil
} else {
return inner.LHS, inner.RHS
return inner.LHS, inner.RHS, nil
}
}

Expand All @@ -144,33 +145,42 @@ func (tree *Tree) GetRootHash() common.Hash {
return tree.RootHash
}

func (tree *Tree) FindChildByHash(hash common.Hash) *Tree {
func (tree *Tree) FindChildByHash(hash common.Hash) (*Tree, error) {
if tree.RootHash == hash {
return tree
return tree, nil
}
if inner := tree.Subtrees; inner != nil {
if !inner.Valid() {
panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner))
return nil, fmt.Errorf("invalid InnerNode state: %v\n", inner)
}

if inner.Child != nil {
child := inner.Child.FindChildByHash(hash)
child, err := inner.Child.FindChildByHash(hash)
if err != nil {
return nil, err
}
if child != nil {
return child
return child, nil
}
} else {
lhs := inner.LHS.FindChildByHash(hash)
lhs, err := inner.LHS.FindChildByHash(hash)
if err != nil {
return nil, err
}
if lhs != nil {
return lhs
return lhs, nil
}

rhs := inner.RHS.FindChildByHash(hash)
rhs, err := inner.RHS.FindChildByHash(hash)
if err != nil {
return nil, err
}
if rhs != nil {
return rhs
return rhs, nil
}
}
}
return nil // not found
return nil, nil // not found
}

func (tree *Tree) Join(other *Tree) *Tree {
Expand Down Expand Up @@ -198,11 +208,11 @@ func (tree *Tree) Iterated(rep uint64) *Tree {
return root
}

func (tree *Tree) ProveLeaf(index *big.Int) *Proof {
func (tree *Tree) ProveLeaf(index *big.Int) (*Proof, error) {
return tree.ProveLeafRec(index)
}

func (tree *Tree) ProveLast() *Proof {
func (tree *Tree) ProveLast() (*Proof, error) {
// index = (1 << height) - 1
index := new(big.Int).Sub(
new(big.Int).Lsh(
Expand All @@ -214,21 +224,21 @@ func (tree *Tree) ProveLast() *Proof {
return tree.ProveLeaf(index)
}

func (tree *Tree) ProveLeafRec(index *big.Int) *Proof {
func (tree *Tree) ProveLeafRec(index *big.Int) (*Proof, error) {
numLeafs := new(big.Int).Lsh(one, uint(tree.Height))
if numLeafs.Cmp(index) <= 0 {
panic(fmt.Sprintf("index out of bounds: %v, %v", numLeafs, index))
return nil, fmt.Errorf("index out of bounds: %v, %v", numLeafs, index)
}

subtree := tree.Subtrees
if subtree == nil {
if index.Cmp(zero) != 0 {
panic(fmt.Sprintf("invalid Tree state: %v", tree))
return nil, fmt.Errorf("invalid Tree state: %v", tree)
}
if tree.Height != 0 {
panic(fmt.Sprintf("invalid Tree state: %v", tree))
return nil, fmt.Errorf("invalid Tree state: %v", tree)
}
return Leaf(tree.RootHash, index)
return Leaf(tree.RootHash, index), nil
}

shiftAmount := uint(tree.Height - 1)
Expand All @@ -245,17 +255,21 @@ func (tree *Tree) ProveLeafRec(index *big.Int) *Proof {
),
)

lhs, rhs := subtree.Children()
lhs, rhs, err := subtree.Children()
if err != nil {
return nil, err
}

if isLeftLeaf {
proof := lhs.ProveLeafRec(innerIndex)
proof, err := lhs.ProveLeafRec(innerIndex)
proof.PushHash(rhs.RootHash)
proof.Pos = index
return proof
return proof, err
} else {
proof := rhs.ProveLeafRec(innerIndex)
proof, err := rhs.ProveLeafRec(innerIndex)
proof.PushHash(lhs.RootHash)
proof.Pos = index
return proof
return proof, err
}
}

Expand Down Expand Up @@ -303,52 +317,57 @@ func (b *Builder) AppendRepeatedUint64(leaf *Tree, reps uint64) {
b.AppendRepeated(leaf, new(big.Int).SetUint64(reps))
}

func (b *Builder) AppendRepeated(leaf *Tree, reps *big.Int) {
func (b *Builder) AppendRepeated(leaf *Tree, reps *big.Int) error {
if reps.Cmp(zero) <= 0 {
panic("invalid repetitions")
return fmt.Errorf("invalid repetitions: %v", reps)
}

accumulatedCount, err := b.CalculateAccumulatedCount(reps)
if err != nil {
return err
}

accumulatedCount := b.CalculateAccumulatedCount(reps)
if height, ok := b.Height(); ok {
if height != leaf.Height {
panic("mismatched tree size")
return fmt.Errorf("mismatched tree sizes, height: %v and leaf height: %v", height, leaf.Height)
}
}
b.Trees = append(b.Trees, Node{
Tree: leaf,
AccumulatedCount: accumulatedCount,
})
return nil
}

func (b *Builder) Build() *Tree {
func (b *Builder) Build() (*Tree, error) {
if count, ok := b.Count(); ok {
if !isCountPow2(count) {
panic(fmt.Sprintf("builder has %v leafs, which is not a power of two", count))
return nil, fmt.Errorf("builder has %v leafs, which is not a power of two", count)
}
log2Size := countTrailingZeroes(count)
return buildMerkle(b.Trees, log2Size, big.NewInt(0))
return buildMerkle(b.Trees, log2Size, big.NewInt(0)), nil
} else {
panic("no leafs in the merkle builder")
return nil, fmt.Errorf("no leafs in the merkle builder: %v", spew.Sprint(b))
}
}

func (b *Builder) CalculateAccumulatedCount(reps *big.Int) *big.Int {
func (b *Builder) CalculateAccumulatedCount(reps *big.Int) (*big.Int, error) {
n := len(b.Trees)
if n != 0 {
if reps.Cmp(zero) == 0 {
panic("merkle builder is full")
return nil, fmt.Errorf("merkle builder is full")
}

accumulatedCount := new(big.Int).And(
new(big.Int).Add(reps, b.Trees[n-1].AccumulatedCount),
overflowMask,
)
if reps.Cmp(accumulatedCount) >= 0 {
panic("merkle tree overflow")
return nil, fmt.Errorf("merkle tree overflow")
}
return accumulatedCount
return accumulatedCount, nil
} else {
return reps
return reps, nil
}
}

Expand Down
Loading
Loading