diff --git a/client/client.go b/client/client.go index c29933d..74e612f 100644 --- a/client/client.go +++ b/client/client.go @@ -77,28 +77,35 @@ func FetchCheckpoint(ctx context.Context, f Fetcher, v note.Verifier, origin str // Since the tiles commit only to immutable nodes, the job of building proofs is slightly // more complex as proofs can touch "ephemeral" nodes, so these need to be synthesized. type ProofBuilder struct { - cp log.Checkpoint + treeSize uint64 nodeCache nodeCache h compact.HashFn } -// NewProofBuilder creates a new ProofBuilder object for a given tree size. +// NewProofBuilderForsize returns a new ProofBuilding for the given tree size. +// +// Unlike NewProofBuilder below, no correctness checking of the root hash for the given tree size is performed. +func NewProofBuilderForSize(ctx context.Context, size uint64, h compact.HashFn, f Fetcher) *ProofBuilder { + tf := newTileFetcher(f, size) + return &ProofBuilder{ + treeSize: size, + nodeCache: newNodeCache(tf, size), + h: h, + } +} + // The returned ProofBuilder can be re-used for proofs related to a given tree size, but // it is not thread-safe and should not be accessed concurrently. func NewProofBuilder(ctx context.Context, cp log.Checkpoint, h compact.HashFn, f Fetcher) (*ProofBuilder, error) { - tf := newTileFetcher(f, cp.Size) - pb := &ProofBuilder{ - cp: cp, - nodeCache: newNodeCache(tf, cp.Size), - h: h, - } + pb := NewProofBuilderForSize(ctx, cp.Size, h, f) + // Can't re-create the root of a zero size checkpoint other than by convention, // so return early here in that case. if cp.Size == 0 { return pb, nil } - hashes, err := FetchRangeNodes(ctx, cp.Size, tf) + hashes, err := FetchRangeNodes(ctx, cp.Size, pb.nodeCache.getTile) if err != nil { return nil, fmt.Errorf("failed to fetch range nodes: %w", err) } @@ -127,7 +134,7 @@ func NewProofBuilder(ctx context.Context, cp log.Checkpoint, h compact.HashFn, f // This function uses the passed-in function to retrieve tiles containing any log tree // nodes necessary to build the proof. func (pb *ProofBuilder) InclusionProof(ctx context.Context, index uint64) ([][]byte, error) { - nodes, err := proof.Inclusion(index, pb.cp.Size) + nodes, err := proof.Inclusion(index, pb.treeSize) if err != nil { return nil, fmt.Errorf("failed to calculate inclusion proof node list: %w", err) } diff --git a/client/client_test.go b/client/client_test.go index b3625a4..0f746af 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -25,7 +25,9 @@ import ( "testing" "github.com/transparency-dev/formats/log" + "github.com/transparency-dev/merkle" "github.com/transparency-dev/merkle/compact" + "github.com/transparency-dev/merkle/proof" "github.com/transparency-dev/merkle/rfc6962" "github.com/transparency-dev/serverless-log/api" "golang.org/x/mod/sumdb/note" @@ -348,3 +350,45 @@ func TestHandleZeroRoot(t *testing.T) { t.Fatalf("NewProofBuilder: %v", err) } } + +func doTestProofBuilder(t *testing.T, pb *ProofBuilder, h merkle.LogHasher) { + t.Helper() + for _, from := range testCheckpoints { + if from.Size > pb.treeSize { + return + } + cp, err := pb.ConsistencyProof(t.Context(), from.Size, pb.treeSize) + if err != nil { + t.Fatalf("pb.ConsistencyProof(%d, %d): %v", from.Size, pb.treeSize, err) + } + if err := proof.VerifyConsistency(h, from.Size, pb.treeSize, cp, from.Hash, testCheckpoints[pb.treeSize].Hash); err != nil { + t.Fatalf("pb generated invalid consistency proof between %d and %d: %v", from.Size, pb.treeSize, err) + } + } + for i := range pb.treeSize { + leaf, err := GetLeaf(t.Context(), testLogFetcher, i) + if err != nil { + t.Fatalf("GetLeaf(%d): %v", i, err) + } + ip, err := pb.InclusionProof(t.Context(), i) + if err != nil { + t.Fatalf("pb.InclusionProof(%d): %v", i, err) + } + if err := proof.VerifyInclusion(h, i, pb.treeSize, h.HashLeaf(leaf), ip, testCheckpoints[pb.treeSize].Hash); err != nil { + t.Fatalf("pb generated invalid inclusion proof for leaf %d: %v", i, err) + } + } +} + +func TestProofBuilder(t *testing.T) { + pb, err := NewProofBuilder(t.Context(), testCheckpoints[len(testCheckpoints)-1], rfc6962.DefaultHasher.HashChildren, testLogFetcher) + if err != nil { + t.Fatalf("NewProofBuilder: %v", err) + } + doTestProofBuilder(t, pb, rfc6962.DefaultHasher) +} + +func TestProofBuilderForSize(t *testing.T) { + pb := NewProofBuilderForSize(t.Context(), testCheckpoints[len(testCheckpoints)-1].Size, rfc6962.DefaultHasher.HashChildren, testLogFetcher) + doTestProofBuilder(t, pb, rfc6962.DefaultHasher) +}