From f200d7e12ae094f600bfec84a1ff2404bc45dc6e Mon Sep 17 00:00:00 2001 From: michaelkedar Date: Tue, 16 Jun 2026 06:17:19 +0000 Subject: [PATCH 1/2] panic catchers --- go/internal/api/determine_version.go | 2 +- go/internal/api/query_affected.go | 55 ++++++-- go/internal/api/query_affected_test.go | 125 ++++++++++++++++++ go/internal/database/datastore/repo_index.go | 5 +- .../database/datastore/repo_index_test.go | 58 ++++++++ go/internal/osvutil/batcher/batcher.go | 29 +++- go/internal/osvutil/batcher/batcher_test.go | 45 +++++++ go/internal/osvutil/safe/safe.go | 78 +++++++++++ 8 files changed, 383 insertions(+), 14 deletions(-) create mode 100644 go/internal/database/datastore/repo_index_test.go create mode 100644 go/internal/osvutil/safe/safe.go diff --git a/go/internal/api/determine_version.go b/go/internal/api/determine_version.go index 59930cfc153..3ddcd479506 100644 --- a/go/internal/api/determine_version.go +++ b/go/internal/api/determine_version.go @@ -176,7 +176,7 @@ func (s *server) DetermineVersion(ctx context.Context, req *pb.DetermineVersionP // Filter and prepare file hashes var validHashes []*pb.FileHash for _, fh := range query.GetFileHashes() { - if fh.Hash != nil && len(fh.GetHash()) <= 100 { + if fh.GetHash() != nil && len(fh.GetHash()) <= 100 { validHashes = append(validHashes, fh) } } diff --git a/go/internal/api/query_affected.go b/go/internal/api/query_affected.go index dbbb12e067b..3c65a131052 100644 --- a/go/internal/api/query_affected.go +++ b/go/internal/api/query_affected.go @@ -14,6 +14,7 @@ import ( "cloud.google.com/go/pubsub/v2" "github.com/google/osv.dev/go/internal/models" + "github.com/google/osv.dev/go/internal/osvutil/safe" "github.com/google/osv.dev/go/internal/osvutil/schema" "github.com/google/osv.dev/go/logger" "github.com/google/osv.dev/go/purl" @@ -84,6 +85,16 @@ func (s *server) QueryAffected(ctx context.Context, params *pb.QueryAffectedPara estimatedSizeBytes, ) if err != nil { + var panicErr *safe.PanicError + if errors.As(err, &panicErr) { + logger.ErrorContext(ctx, "recovered panic in background worker", + slog.Any("panic", panicErr.Value), + slog.String("stack", string(panicErr.Stack)), + ) + + return nil, status.Error(codes.Internal, "internal server error") + } + return nil, err } if s.verboseLogs { @@ -181,8 +192,8 @@ func (s *server) QueryAffectedBatch(ctx context.Context, params *pb.QueryAffecte // Create a buffered channel so workers can exit even if we return early on error. resultsChan := make(chan *queryAndHydrateResult, len(queries)) - pipelineCtx, cancelPipelines := context.WithCancel(ctx) - defer cancelPipelines() + pipelineCtx, cancelPipelines := context.WithCancelCause(ctx) + defer cancelPipelines(nil) batchCtx, matchCancel := context.WithTimeout(pipelineCtx, s.getBatchQueryTimeout()) defer matchCancel() @@ -243,7 +254,12 @@ func (s *server) QueryAffectedBatch(ctx context.Context, params *pb.QueryAffecte } for i, matcherIter := range iters { - go func() { + go safe.Func(func(r any, stack []byte) { + resultsChan <- &queryAndHydrateResult{ + idx: i, + err: &safe.PanicError{Value: r, Stack: stack}, + } + }, func() { if queryInfos[i] == nil { // handling unknown PURL types resultsChan <- &queryAndHydrateResult{ @@ -270,7 +286,7 @@ func (s *server) QueryAffectedBatch(ctx context.Context, params *pb.QueryAffecte nextToken: nextToken, err: err, } - }() + })() } list := &pb.BatchVulnerabilityList{} @@ -279,7 +295,17 @@ func (s *server) QueryAffectedBatch(ctx context.Context, params *pb.QueryAffecte for range queryInfos { result := <-resultsChan if result.err != nil { - cancelPipelines() // Abort all other running pipelines in the background + cancelPipelines(result.err) // Abort all other running pipelines in the background + var panicErr *safe.PanicError + if errors.As(result.err, &panicErr) { + logger.ErrorContext(ctx, "recovered panic in batch worker", + slog.Any("panic", panicErr.Value), + slog.String("stack", string(panicErr.Stack)), + ) + + return nil, status.Error(codes.Internal, "internal server error") + } + return nil, fmt.Errorf("error in query at index %d: %w", result.idx, result.err) } list.Results[result.idx] = &pb.VulnerabilityList{ @@ -501,7 +527,7 @@ func (s *server) runMatcher( if startTok == "" { currentCursor = func() string { return startCursor } } - go func() { + safe.GoCancel(cancel, func() { defer close(done) defer close(resultIDs) idx := 0 @@ -515,7 +541,11 @@ func (s *server) runMatcher( return } - currentCursor = match.Cursor + if match.Cursor != nil { + currentCursor = match.Cursor + } else { + currentCursor = func() string { return "" } + } if !match.IsMatch { continue } @@ -531,7 +561,7 @@ func (s *server) runMatcher( } // We finished the entire query only if context was not cancelled (meaning loop finished naturally) currentCursor = func() string { return "" } - }() + }) return matcherResult{ resultIDsCh: resultIDs, @@ -551,8 +581,13 @@ func (s *server) runMatcher( func (s *server) hydrateParallel(ctx context.Context, resultIDs <-chan matchVuln, hydrate hydrateFunc) <-chan hydratedResult { hydrated := make(chan hydratedResult, numParallelHydration) var wg sync.WaitGroup + onPanic := func(r any, stack []byte) { + hydrated <- hydratedResult{ + err: &safe.PanicError{Value: r, Stack: stack}, + } + } for range numParallelHydration { - wg.Go(func() { + wg.Go(safe.Func(onPanic, func() { for mv := range resultIDs { v, err := hydrate(ctx, mv.id) if err != nil { @@ -561,7 +596,7 @@ func (s *server) hydrateParallel(ctx context.Context, resultIDs <-chan matchVuln } hydrated <- hydratedResult{index: mv.index, v: v, id: mv.id} } - }) + })) } go func() { wg.Wait() diff --git a/go/internal/api/query_affected_test.go b/go/internal/api/query_affected_test.go index 7ad79c0a97f..e3164427de2 100644 --- a/go/internal/api/query_affected_test.go +++ b/go/internal/api/query_affected_test.go @@ -3,6 +3,7 @@ package api import ( "context" "errors" + "fmt" "iter" "strings" "sync" @@ -783,3 +784,127 @@ func TestQueryAffected_HydrationNotFound_PublishesRecovery(t *testing.T) { t.Errorf("Expected message id 'VULN-1', got %q", msg.Attributes["id"]) } } + +func TestQueryAffected_NilCursorSafety(t *testing.T) { + ctx := context.Background() + + store := &mockQueryVulnStore{ + matchPackages: func(_ context.Context, _, _, _, _ string) iter.Seq2[models.MatchResult, error] { + return func(yield func(models.MatchResult, error) bool) { + for i := range 3000 { + if !yield(models.MatchResult{ + IsMatch: true, + ID: fmt.Sprintf("VULN-%d", i), + Cursor: nil, // Explicitly nil + }, nil) { + return + } + } + } + }, + get: func(_ context.Context, id string) (*osvschema.Vulnerability, error) { + return &osvschema.Vulnerability{Id: id}, nil + }, + } + + s := &server{ + vulnStore: store, + } + + params := &pb.QueryAffectedParameters{ + Query: &pb.Query{ + Param: &pb.Query_Version{Version: "1.0.0"}, + Package: &osvschema.Package{Name: "pkg-1", Ecosystem: "npm"}, + }, + } + + got, err := s.QueryAffected(ctx, params) + if err != nil { + t.Fatalf("QueryAffected() unexpected error: %v", err) + } + + // It should succeed and return the next page token as empty string (since cursor was nil) + if got.GetNextPageToken() != "" { + t.Errorf("Expected empty next page token, got %q", got.GetNextPageToken()) + } +} + +func TestQueryAffected_MatcherPanicPropagation(t *testing.T) { + ctx := context.Background() + store := &mockQueryVulnStore{ + matchPackages: func(_ context.Context, _, _, _, _ string) iter.Seq2[models.MatchResult, error] { + return func(_ func(models.MatchResult, error) bool) { + // Simulates a panic inside the database iterator loop (runMatcher) + panic("matcher database crash") + } + }, + } + + s := &server{ + vulnStore: store, + } + + params := &pb.QueryAffectedParameters{ + Query: &pb.Query{ + Param: &pb.Query_Version{Version: "1.0.0"}, + Package: &osvschema.Package{Name: "pkg-1", Ecosystem: "npm"}, + }, + } + + _, err := s.QueryAffected(ctx, params) + if err == nil { + t.Fatalf("Expected error, got nil") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("Expected gRPC status error, got %v", err) + } + if st.Code() != codes.Internal { + t.Errorf("Expected gRPC code Internal, got %v", st.Code()) + } + if !strings.Contains(st.Message(), "internal server error") { + t.Errorf("Expected error message to contain 'internal server error', got %q", st.Message()) + } +} + +func TestQueryAffectedBatch_WorkerPanicPropagation(t *testing.T) { + ctx := context.Background() + store := &mockQueryVulnStore{ + matchPackages: func(_ context.Context, _, _, _, _ string) iter.Seq2[models.MatchResult, error] { + return func(_ func(models.MatchResult, error) bool) { + // Simulates a panic inside the batch worker pipeline (QueryAffectedBatch worker) + panic("batch worker crash") + } + }, + } + + s := &server{ + vulnStore: store, + } + + params := &pb.QueryAffectedBatchParameters{ + Query: &pb.BatchQuery{ + Queries: []*pb.Query{ + { + Param: &pb.Query_Version{Version: "1.0.0"}, + Package: &osvschema.Package{Name: "pkg-1", Ecosystem: "npm"}, + }, + }, + }, + } + + _, err := s.QueryAffectedBatch(ctx, params) + if err == nil { + t.Fatalf("Expected error, got nil") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("Expected gRPC status error, got %v", err) + } + if st.Code() != codes.Internal { + t.Errorf("Expected gRPC code Internal, got %v", st.Code()) + } + if !strings.Contains(st.Message(), "internal server error") { + t.Errorf("Expected error message to contain 'internal server error', got %q", st.Message()) + } +} diff --git a/go/internal/database/datastore/repo_index.go b/go/internal/database/datastore/repo_index.go index 3c153b6b918..7dba2e01f72 100644 --- a/go/internal/database/datastore/repo_index.go +++ b/go/internal/database/datastore/repo_index.go @@ -23,6 +23,7 @@ import ( "cloud.google.com/go/datastore" "github.com/google/osv.dev/go/internal/models" + "github.com/google/osv.dev/go/internal/osvutil/safe" "golang.org/x/sync/errgroup" ) @@ -60,7 +61,7 @@ func (s *RepoIndexStore) QueryBuckets(ctx context.Context, nodeHashes [][]byte) for _, hash := range nodeHashes { h := hash // capture loop variable - g.Go(func() error { + g.Go(safe.ErrgroupFunc(func() error { q := datastore.NewQuery(RepoIndexBucketKind). FilterField("node_hash", "=", h). Limit(models.MaxMatchesToCare) @@ -92,7 +93,7 @@ func (s *RepoIndexStore) QueryBuckets(ctx context.Context, nodeHashes [][]byte) mu.Unlock() return nil - }) + })) } if err := g.Wait(); err != nil { diff --git a/go/internal/database/datastore/repo_index_test.go b/go/internal/database/datastore/repo_index_test.go new file mode 100644 index 00000000000..6b551f6d558 --- /dev/null +++ b/go/internal/database/datastore/repo_index_test.go @@ -0,0 +1,58 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package datastore + +import ( + "context" + "errors" + "testing" + + "github.com/google/osv.dev/go/internal/osvutil/safe" +) + +func TestRepoIndexStore_QueryBucketsPanicPropagation(t *testing.T) { + ctx := context.Background() + + // Initialize the store with a nil client. + // Any call to s.client.GetAll inside the errgroup tasks will trigger a nil-pointer panic. + store := NewRepoIndexStore(nil) + + // Trigger QueryBuckets with some mock node hashes + _, err := store.QueryBuckets(ctx, [][]byte{ + {0xab, 0xcd}, + {0x12, 0x34}, + }) + + if err == nil { + t.Fatalf("Expected error due to nil-pointer panic, got nil") + } + + // Verify that the panic was recovered and returned as a *safe.PanicError + var panicErr *safe.PanicError + if !errors.As(err, &panicErr) { + t.Fatalf("Expected error to be *safe.PanicError, got %T: %v", err, err) + } + + // The panic value should indicate a nil pointer dereference (runtime error) + if panicErr.Value == nil { + t.Errorf("Expected non-nil panic value") + } + + if len(panicErr.Stack) == 0 { + t.Errorf("Expected stack trace to be populated, got empty") + } + + t.Logf("Successfully recovered and propagated nil-pointer panic from errgroup: %v", panicErr) +} diff --git a/go/internal/osvutil/batcher/batcher.go b/go/internal/osvutil/batcher/batcher.go index 675cb4d9ade..4fade33cf5d 100644 --- a/go/internal/osvutil/batcher/batcher.go +++ b/go/internal/osvutil/batcher/batcher.go @@ -25,8 +25,11 @@ package batcher import ( "context" "errors" + "runtime/debug" "sync" "time" + + "github.com/google/osv.dev/go/internal/osvutil/safe" ) // Result wraps the value and error returned for a single key in the batch. @@ -167,6 +170,30 @@ func (b *Batcher[K, R]) Get(ctx context.Context, key K) (R, error) { } func (b *Batcher[K, R]) runBatchLoop() { + var batch []*request[K, R] + defer func() { + if r := recover(); r != nil { + panicErr := &safe.PanicError{ + Value: r, + Stack: debug.Stack(), + } + b.mu.Lock() + // If we panicked before stealing the pending slice, steal it now + if len(batch) == 0 && len(b.pending) > 0 { + batch = b.pending + b.pending = nil + } + b.mu.Unlock() + + for _, req := range batch { + select { + case req.resultChan <- Result[R]{Err: panicErr}: + default: + } + } + } + }() + // Wait for the batch to fill up or the timeout to expire. select { case <-time.After(b.timeout): @@ -174,7 +201,7 @@ func (b *Batcher[K, R]) runBatchLoop() { } b.mu.Lock() - batch := b.pending + batch = b.pending b.pending = nil // Reset pending so the next request starts a new batch. // Drain triggerChan to avoid stale triggers for the next batch. select { diff --git a/go/internal/osvutil/batcher/batcher_test.go b/go/internal/osvutil/batcher/batcher_test.go index 5e7f6c4fe44..2e3aa15aa14 100644 --- a/go/internal/osvutil/batcher/batcher_test.go +++ b/go/internal/osvutil/batcher/batcher_test.go @@ -20,6 +20,8 @@ import ( "sync" "testing" "time" + + "github.com/google/osv.dev/go/internal/osvutil/safe" ) func TestBatcher_Get(t *testing.T) { @@ -207,3 +209,46 @@ func TestBatcher_Cancellation(t *testing.T) { t.Logf("Total time: %v", time.Since(start)) } + +func TestBatcher_PanicPropagation(t *testing.T) { + ctx := context.Background() + + // A batch function that panics + batchFunc := func(_ context.Context, _ []string) []Result[string] { + panic("database connection lost") + } + + b := New(10*time.Millisecond, 2, batchFunc) + + var wg sync.WaitGroup + var err1, err2 error + + wg.Go(func() { + _, err1 = b.Get(ctx, "key1") + }) + wg.Go(func() { + _, err2 = b.Get(ctx, "key2") + }) + + wg.Wait() + + // Both callers should have received the PanicError + for i, err := range []error{err1, err2} { + if err == nil { + t.Errorf("Expected error for caller %d, got nil", i+1) + continue + } + var panicErr *safe.PanicError + if !errors.As(err, &panicErr) { + t.Errorf("Expected error to be *safe.PanicError, got %T: %v", err, err) + continue + } + if panicErr.Value != "database connection lost" { + t.Errorf("Expected panic value 'database connection lost', got %v", panicErr.Value) + } + if len(panicErr.Stack) == 0 { + t.Errorf("Expected stack trace to be populated, got empty") + } + t.Logf("Caller %d successfully received propagated PanicError: %v", i+1, panicErr) + } +} diff --git a/go/internal/osvutil/safe/safe.go b/go/internal/osvutil/safe/safe.go new file mode 100644 index 00000000000..5c572532f7b --- /dev/null +++ b/go/internal/osvutil/safe/safe.go @@ -0,0 +1,78 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package safe provides simple utilities to spawn goroutines with panic recovery +// and propagate those panics back to the main thread as error values. +package safe + +import ( + "context" + "fmt" + "runtime/debug" +) + +// PanicError wraps a panic value and its stack trace. +type PanicError struct { + Value any + Stack []byte +} + +func (p *PanicError) Error() string { + return fmt.Sprintf("panic recovered: %v", p.Value) +} + +// GoCancel spawns a goroutine. If it panics, it cancels the context with a PanicError. +func GoCancel(cancel context.CancelCauseFunc, f func()) { + go func() { + defer func() { + if r := recover(); r != nil { + cancel(&PanicError{ + Value: r, + Stack: debug.Stack(), + }) + } + }() + f() + }() +} + +// Func wraps any function with panic recovery, propagating it to a callback. +func Func(onPanic func(r any, stack []byte), f func()) func() { + return func() { + defer func() { + if r := recover(); r != nil { + onPanic(r, debug.Stack()) + } + }() + f() + } +} + +// ErrgroupFunc wraps a function returning an error with panic recovery. +// If the function panics, it recovers the panic and returns a PanicError as the error, +// allowing it to propagate through errgroup.Group tasks safely. +func ErrgroupFunc(f func() error) func() error { + return func() (err error) { + defer func() { + if r := recover(); r != nil { + err = &PanicError{ + Value: r, + Stack: debug.Stack(), + } + } + }() + + return f() + } +} From d4237617e46b45830fa4697e22d51987e04a6a9a Mon Sep 17 00:00:00 2001 From: michaelkedar Date: Wed, 17 Jun 2026 00:15:31 +0000 Subject: [PATCH 2/2] dr health --- go/cmd/api/main.go | 33 +++++++++--- go/internal/api/query_affected.go | 6 +++ go/internal/api/server.go | 52 +++++++++++++++++++ .../database/datastore/vulnerability.go | 21 ++++++++ .../database/datastore/vulnerability_test.go | 41 +++++++++++++++ go/internal/models/vulnerability.go | 8 +++ go/testutils/gcs.go | 9 ++++ 7 files changed, 163 insertions(+), 7 deletions(-) diff --git a/go/cmd/api/main.go b/go/cmd/api/main.go index 9d9a366cbe7..60117288427 100644 --- a/go/cmd/api/main.go +++ b/go/cmd/api/main.go @@ -107,13 +107,32 @@ func run() error { recovererPublisher = &clients.GCPPublisher{Publisher: pubsubClient.Publisher(recovererTopic)} } + var healthInterval time.Duration + if t := os.Getenv("OSV_HEALTH_CHECK_INTERVAL"); t != "" { + if d, err := time.ParseDuration(t); err == nil { + healthInterval = d + } else { + logger.ErrorContext(ctx, "Invalid OSV_HEALTH_CHECK_INTERVAL, using default", slog.Any("error", err)) + } + } + var healthThreshold int + if m := os.Getenv("OSV_HEALTH_CHECK_THRESHOLD"); m != "" { + if val, err := strconv.Atoi(m); err == nil { + healthThreshold = val + } else { + logger.ErrorContext(ctx, "Invalid OSV_HEALTH_CHECK_THRESHOLD, using default", slog.Any("error", err)) + } + } + return api.RunServer(ctx, api.ServerOptions{ - Port: *port, - VerboseLogs: verboseLogs, - VulnStore: vulnStore, - RelationsStore: relationsStore, - ImportFindingsStore: importFindingsStore, - RepoIndexStore: repoIndexStore, - RecovererPublisher: recovererPublisher, + Port: *port, + VerboseLogs: verboseLogs, + VulnStore: vulnStore, + RelationsStore: relationsStore, + ImportFindingsStore: importFindingsStore, + RepoIndexStore: repoIndexStore, + RecovererPublisher: recovererPublisher, + HealthCheckInterval: healthInterval, + HealthCheckThreshold: healthThreshold, }) } diff --git a/go/internal/api/query_affected.go b/go/internal/api/query_affected.go index 3c65a131052..d42c5c20025 100644 --- a/go/internal/api/query_affected.go +++ b/go/internal/api/query_affected.go @@ -670,6 +670,12 @@ func (s *server) collectAndSort(ctx context.Context, if errors.Is(err, models.ErrInvalidCursor) { return nil, status.Error(codes.InvalidArgument, "invalid cursor") } + var panicErr *safe.PanicError + if errors.As(err, &panicErr) { + // Return the raw PanicError so the caller handlers can detect it, + // log the stack trace, and obscure it into a clean "internal server error". + return nil, err + } return nil, status.Error(codes.Internal, err.Error()) } diff --git a/go/internal/api/server.go b/go/internal/api/server.go index 026899012d4..671cb185c0a 100644 --- a/go/internal/api/server.go +++ b/go/internal/api/server.go @@ -42,6 +42,9 @@ type ServerOptions struct { ImportFindingsStore models.ImportFindingsStore RepoIndexStore models.RepoIndexStore RecovererPublisher clients.Publisher + + HealthCheckInterval time.Duration + HealthCheckThreshold int } // RunServer starts the gRPC server and handles graceful shutdown. @@ -65,6 +68,9 @@ func RunServer(ctx context.Context, opts ServerOptions) error { healthServer := health.NewServer() healthgrpc.RegisterHealthServer(s, healthServer) + // Start background dependency health monitor + go monitorDatabaseHealth(ctx, healthServer, opts.VulnStore, opts.HealthCheckInterval, opts.HealthCheckThreshold) + logger.InfoContext(ctx, "server listening", "port", opts.Port) serveErr := make(chan error, 1) @@ -90,3 +96,49 @@ func RunServer(ctx context.Context, opts ServerOptions) error { return nil } + +// monitorDatabaseHealth runs a background loop to passively monitor critical backend dependencies (Datastore, GCS, Batcher) +// and updates the gRPC serving status. It runs at the configured interval and requires a configured number of consecutive +// failures to mark the server as unhealthy, preventing transient network noise from causing false-positive outages. +func monitorDatabaseHealth(ctx context.Context, healthServer *health.Server, store models.VulnerabilityStore, interval time.Duration, threshold int) { + if interval <= 0 { + interval = 10 * time.Second // Sensible default + } + if threshold <= 0 { + threshold = 3 // Sensible default + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + // Default to SERVING on startup + healthServer.SetServingStatus("", healthgrpc.HealthCheckResponse_SERVING) + + consecutiveFailures := 0 + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + pingCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + err := store.Ping(pingCtx) + cancel() + + if err != nil { + consecutiveFailures++ + logger.ErrorContext(ctx, "Dependency health check failed", "error", err, "failures", consecutiveFailures) + + if consecutiveFailures >= threshold { + healthServer.SetServingStatus("", healthgrpc.HealthCheckResponse_NOT_SERVING) + } + } else { + if consecutiveFailures > 0 { + logger.InfoContext(ctx, "Dependency health restored") + } + consecutiveFailures = 0 + healthServer.SetServingStatus("", healthgrpc.HealthCheckResponse_SERVING) + } + } + } +} diff --git a/go/internal/database/datastore/vulnerability.go b/go/internal/database/datastore/vulnerability.go index 75c125f9b7d..732f31f1b5b 100644 --- a/go/internal/database/datastore/vulnerability.go +++ b/go/internal/database/datastore/vulnerability.go @@ -448,3 +448,24 @@ func (s *VulnerabilityStore) uploadToGCS(ctx context.Context, id string, enriche return nil } + +// Ping performs a cheap, end-to-end health check of GCS, Datastore, and the internal Batcher. +func (s *VulnerabilityStore) Ping(ctx context.Context) error { + // 1. Test GCS (Cheap metadata check) + // We read attributes of a dummy path. If GCS is healthy, it will return clients.ErrNotFound. + // If GCS is unreachable or permissions are broken, it will return a real connection/IAM error. + _, gcsErr := s.gcsStore.ReadObjectAttrs(ctx, "health-check-ping") + if gcsErr != nil && !errors.Is(gcsErr, clients.ErrNotFound) { + return fmt.Errorf("gcs unreachable: %w", gcsErr) + } + + // 2. Test Batcher + Datastore + // This queues a request in modifiedBatcher, verifying the background + // worker loop is alive and Datastore is responding. + _, dbErr := s.GetModified(ctx, "OSV-PING-NON-EXISTENT") + if dbErr != nil && !errors.Is(dbErr, models.ErrNotFound) { + return fmt.Errorf("datastore/batcher unreachable: %w", dbErr) + } + + return nil +} diff --git a/go/internal/database/datastore/vulnerability_test.go b/go/internal/database/datastore/vulnerability_test.go index 2f3b58a1424..9a4b5404886 100644 --- a/go/internal/database/datastore/vulnerability_test.go +++ b/go/internal/database/datastore/vulnerability_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "strings" "sync" "testing" "time" @@ -533,3 +534,43 @@ func TestVulnerabilityStore_GetModified_BatchTriggerEarly(t *testing.T) { t.Errorf("Request 2 got %v, want %v", got2, want2) } } + +func TestVulnerabilityStore_Ping(t *testing.T) { + ctx := context.Background() + dsClient := testutils.MustNewDatastoreClientForTesting(t) + mockGCS := testutils.NewMockStorage() + + store := NewVulnerabilityStore(VulnStoreConfig{ + Client: dsClient, + GCS: mockGCS, + }) + + // 1. Test Happy Path: Both GCS and Datastore are healthy (but empty). + // Ping should return nil because it correctly filters out ErrNotFound. + if err := store.Ping(ctx); err != nil { + t.Errorf("Expected Ping to succeed on healthy stores, got: %v", err) + } + + // 2. Test GCS Outage: Mock GCS returning a real connection error. + mockGCS.ReadError = errors.New("GCS connection timeout") + if err := store.Ping(ctx); err == nil { + t.Errorf("Expected Ping to fail when GCS is unreachable, got nil") + } else if !strings.Contains(err.Error(), "gcs unreachable") { + t.Errorf("Expected GCS unreachable error, got: %v", err) + } + + // 3. Test Datastore Outage: If we pass a nil client, Ping should fail + // cleanly without panicking (due to our batcher panic-safety). + brokenStore := NewVulnerabilityStore(VulnStoreConfig{ + Client: nil, + GCS: mockGCS, + }) + // Reset the GCS error so only Datastore is failing + mockGCS.ReadError = nil + + if err := brokenStore.Ping(ctx); err == nil { + t.Errorf("Expected Ping to fail when Datastore client is nil, got nil") + } else if !strings.Contains(err.Error(), "datastore/batcher unreachable") { + t.Errorf("Expected Datastore/Batcher unreachable error, got: %v", err) + } +} diff --git a/go/internal/models/vulnerability.go b/go/internal/models/vulnerability.go index b250bcad8ba..906e1f93a0b 100644 --- a/go/internal/models/vulnerability.go +++ b/go/internal/models/vulnerability.go @@ -54,6 +54,7 @@ type MatchResult struct { var ErrInvalidCursor = errors.New("invalid cursor for query") +//nolint:interfacebloat // The store represents a rich domain repository with multiple query and match interfaces. type VulnerabilityStore interface { // ListBySource returns an iterator over vulnerabilities for a given source. ListBySource(ctx context.Context, source string, skipWithdrawn bool) iter.Seq2[*VulnSourceRef, error] @@ -85,6 +86,9 @@ type VulnerabilityStore interface { // MatchCommitsBatch returns iterators of vulnerability IDs for a batch of commit queries. MatchCommitsBatch(ctx context.Context, queries []CommitQuery) ([]iter.Seq2[MatchResult, error], error) + + // Ping verifies connection health to the underlying store dependencies (e.g. database, storage). + Ping(ctx context.Context) error } type UnimplementedVulnerabilityStore struct{} @@ -130,3 +134,7 @@ func (s UnimplementedVulnerabilityStore) MatchPackagesBatch(_ context.Context, _ func (s UnimplementedVulnerabilityStore) MatchCommitsBatch(_ context.Context, _ []CommitQuery) ([]iter.Seq2[MatchResult, error], error) { panic("not implemented") } + +func (s UnimplementedVulnerabilityStore) Ping(_ context.Context) error { + panic("not implemented") +} diff --git a/go/testutils/gcs.go b/go/testutils/gcs.go index 48ea21a01e2..7e6df550223 100644 --- a/go/testutils/gcs.go +++ b/go/testutils/gcs.go @@ -30,6 +30,7 @@ type MockStorage struct { objects map[string]*mockObject // object path -> object data WriteError error // Global error to return on WriteObject WriteErrors map[string]error // Per-path error to return on WriteObject + ReadError error // Global error to return on ReadObject and ReadObjectAttrs } // NewMockStorage creates a new mock storage client. @@ -44,6 +45,10 @@ func (c *MockStorage) ReadObject(_ context.Context, path string) ([]byte, error) c.mu.RLock() defer c.mu.RUnlock() + if c.ReadError != nil { + return nil, c.ReadError + } + obj, ok := c.objects[path] if !ok { return nil, clients.ErrNotFound @@ -60,6 +65,10 @@ func (c *MockStorage) ReadObjectAttrs(_ context.Context, path string) (*clients. c.mu.RLock() defer c.mu.RUnlock() + if c.ReadError != nil { + return nil, c.ReadError + } + obj, ok := c.objects[path] if !ok { return nil, clients.ErrNotFound