Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions go/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,32 @@ func run() error {
recovererPublisher = &clients.GCPPublisher{Publisher: pubsubClient.Publisher(recovererTopic)}
}

var healthInterval time.Duration
if t := os.Getenv("OSV_HEALTH_CHECK_INTERVAL"); t != "" {
if d, err := time.ParseDuration(t); err == nil {
healthInterval = d
} else {
logger.ErrorContext(ctx, "Invalid OSV_HEALTH_CHECK_INTERVAL, using default", slog.Any("error", err))
}
}
var healthThreshold int
if m := os.Getenv("OSV_HEALTH_CHECK_THRESHOLD"); m != "" {
if val, err := strconv.Atoi(m); err == nil {
healthThreshold = val
} else {
logger.ErrorContext(ctx, "Invalid OSV_HEALTH_CHECK_THRESHOLD, using default", slog.Any("error", err))
}
}

return api.RunServer(ctx, api.ServerOptions{
Port: *port,
VerboseLogs: verboseLogs,
VulnStore: vulnStore,
RelationsStore: relationsStore,
ImportFindingsStore: importFindingsStore,
RepoIndexStore: repoIndexStore,
RecovererPublisher: recovererPublisher,
Port: *port,
VerboseLogs: verboseLogs,
VulnStore: vulnStore,
RelationsStore: relationsStore,
ImportFindingsStore: importFindingsStore,
RepoIndexStore: repoIndexStore,
RecovererPublisher: recovererPublisher,
HealthCheckInterval: healthInterval,
HealthCheckThreshold: healthThreshold,
})
}
2 changes: 1 addition & 1 deletion go/internal/api/determine_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (s *server) DetermineVersion(ctx context.Context, req *pb.DetermineVersionP
// Filter and prepare file hashes
var validHashes []*pb.FileHash
for _, fh := range query.GetFileHashes() {
if fh.Hash != nil && len(fh.GetHash()) <= 100 {
if fh.GetHash() != nil && len(fh.GetHash()) <= 100 {
validHashes = append(validHashes, fh)
}
}
Expand Down
61 changes: 51 additions & 10 deletions go/internal/api/query_affected.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"cloud.google.com/go/pubsub/v2"
"github.com/google/osv.dev/go/internal/models"
"github.com/google/osv.dev/go/internal/osvutil/safe"
"github.com/google/osv.dev/go/internal/osvutil/schema"
"github.com/google/osv.dev/go/logger"
"github.com/google/osv.dev/go/purl"
Expand Down Expand Up @@ -84,6 +85,16 @@ func (s *server) QueryAffected(ctx context.Context, params *pb.QueryAffectedPara
estimatedSizeBytes,
)
if err != nil {
var panicErr *safe.PanicError
if errors.As(err, &panicErr) {
logger.ErrorContext(ctx, "recovered panic in background worker",
slog.Any("panic", panicErr.Value),
slog.String("stack", string(panicErr.Stack)),
)

return nil, status.Error(codes.Internal, "internal server error")
}

return nil, err
}
if s.verboseLogs {
Expand Down Expand Up @@ -181,8 +192,8 @@ func (s *server) QueryAffectedBatch(ctx context.Context, params *pb.QueryAffecte
// Create a buffered channel so workers can exit even if we return early on error.
resultsChan := make(chan *queryAndHydrateResult, len(queries))

pipelineCtx, cancelPipelines := context.WithCancel(ctx)
defer cancelPipelines()
pipelineCtx, cancelPipelines := context.WithCancelCause(ctx)
defer cancelPipelines(nil)

batchCtx, matchCancel := context.WithTimeout(pipelineCtx, s.getBatchQueryTimeout())
defer matchCancel()
Expand Down Expand Up @@ -243,7 +254,12 @@ func (s *server) QueryAffectedBatch(ctx context.Context, params *pb.QueryAffecte
}

for i, matcherIter := range iters {
go func() {
go safe.Func(func(r any, stack []byte) {
resultsChan <- &queryAndHydrateResult{
idx: i,
err: &safe.PanicError{Value: r, Stack: stack},
}
}, func() {
if queryInfos[i] == nil {
// handling unknown PURL types
resultsChan <- &queryAndHydrateResult{
Expand All @@ -270,7 +286,7 @@ func (s *server) QueryAffectedBatch(ctx context.Context, params *pb.QueryAffecte
nextToken: nextToken,
err: err,
}
}()
})()
}

list := &pb.BatchVulnerabilityList{}
Expand All @@ -279,7 +295,17 @@ func (s *server) QueryAffectedBatch(ctx context.Context, params *pb.QueryAffecte
for range queryInfos {
result := <-resultsChan
if result.err != nil {
cancelPipelines() // Abort all other running pipelines in the background
cancelPipelines(result.err) // Abort all other running pipelines in the background
var panicErr *safe.PanicError
if errors.As(result.err, &panicErr) {
logger.ErrorContext(ctx, "recovered panic in batch worker",
slog.Any("panic", panicErr.Value),
slog.String("stack", string(panicErr.Stack)),
)

return nil, status.Error(codes.Internal, "internal server error")
}

return nil, fmt.Errorf("error in query at index %d: %w", result.idx, result.err)
}
list.Results[result.idx] = &pb.VulnerabilityList{
Expand Down Expand Up @@ -501,7 +527,7 @@ func (s *server) runMatcher(
if startTok == "" {
currentCursor = func() string { return startCursor }
}
go func() {
safe.GoCancel(cancel, func() {
defer close(done)
defer close(resultIDs)
idx := 0
Expand All @@ -515,7 +541,11 @@ func (s *server) runMatcher(

return
}
currentCursor = match.Cursor
if match.Cursor != nil {
currentCursor = match.Cursor
} else {
currentCursor = func() string { return "" }
}
if !match.IsMatch {
continue
}
Expand All @@ -531,7 +561,7 @@ func (s *server) runMatcher(
}
// We finished the entire query only if context was not cancelled (meaning loop finished naturally)
currentCursor = func() string { return "" }
}()
})

return matcherResult{
resultIDsCh: resultIDs,
Expand All @@ -551,8 +581,13 @@ func (s *server) runMatcher(
func (s *server) hydrateParallel(ctx context.Context, resultIDs <-chan matchVuln, hydrate hydrateFunc) <-chan hydratedResult {
hydrated := make(chan hydratedResult, numParallelHydration)
var wg sync.WaitGroup
onPanic := func(r any, stack []byte) {
hydrated <- hydratedResult{
err: &safe.PanicError{Value: r, Stack: stack},
}
}
for range numParallelHydration {
wg.Go(func() {
wg.Go(safe.Func(onPanic, func() {
for mv := range resultIDs {
v, err := hydrate(ctx, mv.id)
if err != nil {
Expand All @@ -561,7 +596,7 @@ func (s *server) hydrateParallel(ctx context.Context, resultIDs <-chan matchVuln
}
hydrated <- hydratedResult{index: mv.index, v: v, id: mv.id}
}
})
}))
}
go func() {
wg.Wait()
Expand Down Expand Up @@ -635,6 +670,12 @@ func (s *server) collectAndSort(ctx context.Context,
if errors.Is(err, models.ErrInvalidCursor) {
return nil, status.Error(codes.InvalidArgument, "invalid cursor")
}
var panicErr *safe.PanicError
if errors.As(err, &panicErr) {
// Return the raw PanicError so the caller handlers can detect it,
// log the stack trace, and obscure it into a clean "internal server error".
return nil, err
}

return nil, status.Error(codes.Internal, err.Error())
}
Expand Down
125 changes: 125 additions & 0 deletions go/internal/api/query_affected_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"errors"
"fmt"
"iter"
"strings"
"sync"
Expand Down Expand Up @@ -783,3 +784,127 @@ func TestQueryAffected_HydrationNotFound_PublishesRecovery(t *testing.T) {
t.Errorf("Expected message id 'VULN-1', got %q", msg.Attributes["id"])
}
}

func TestQueryAffected_NilCursorSafety(t *testing.T) {
ctx := context.Background()

store := &mockQueryVulnStore{
matchPackages: func(_ context.Context, _, _, _, _ string) iter.Seq2[models.MatchResult, error] {
return func(yield func(models.MatchResult, error) bool) {
for i := range 3000 {
if !yield(models.MatchResult{
IsMatch: true,
ID: fmt.Sprintf("VULN-%d", i),
Cursor: nil, // Explicitly nil
}, nil) {
return
}
}
}
},
get: func(_ context.Context, id string) (*osvschema.Vulnerability, error) {
return &osvschema.Vulnerability{Id: id}, nil
},
}

s := &server{
vulnStore: store,
}

params := &pb.QueryAffectedParameters{
Query: &pb.Query{
Param: &pb.Query_Version{Version: "1.0.0"},
Package: &osvschema.Package{Name: "pkg-1", Ecosystem: "npm"},
},
}

got, err := s.QueryAffected(ctx, params)
if err != nil {
t.Fatalf("QueryAffected() unexpected error: %v", err)
}

// It should succeed and return the next page token as empty string (since cursor was nil)
if got.GetNextPageToken() != "" {
t.Errorf("Expected empty next page token, got %q", got.GetNextPageToken())
}
}

func TestQueryAffected_MatcherPanicPropagation(t *testing.T) {
ctx := context.Background()
store := &mockQueryVulnStore{
matchPackages: func(_ context.Context, _, _, _, _ string) iter.Seq2[models.MatchResult, error] {
return func(_ func(models.MatchResult, error) bool) {
// Simulates a panic inside the database iterator loop (runMatcher)
panic("matcher database crash")
}
},
}

s := &server{
vulnStore: store,
}

params := &pb.QueryAffectedParameters{
Query: &pb.Query{
Param: &pb.Query_Version{Version: "1.0.0"},
Package: &osvschema.Package{Name: "pkg-1", Ecosystem: "npm"},
},
}

_, err := s.QueryAffected(ctx, params)
if err == nil {
t.Fatalf("Expected error, got nil")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("Expected gRPC status error, got %v", err)
}
if st.Code() != codes.Internal {
t.Errorf("Expected gRPC code Internal, got %v", st.Code())
}
if !strings.Contains(st.Message(), "internal server error") {
t.Errorf("Expected error message to contain 'internal server error', got %q", st.Message())
}
}

func TestQueryAffectedBatch_WorkerPanicPropagation(t *testing.T) {
ctx := context.Background()
store := &mockQueryVulnStore{
matchPackages: func(_ context.Context, _, _, _, _ string) iter.Seq2[models.MatchResult, error] {
return func(_ func(models.MatchResult, error) bool) {
// Simulates a panic inside the batch worker pipeline (QueryAffectedBatch worker)
panic("batch worker crash")
}
},
}

s := &server{
vulnStore: store,
}

params := &pb.QueryAffectedBatchParameters{
Query: &pb.BatchQuery{
Queries: []*pb.Query{
{
Param: &pb.Query_Version{Version: "1.0.0"},
Package: &osvschema.Package{Name: "pkg-1", Ecosystem: "npm"},
},
},
},
}

_, err := s.QueryAffectedBatch(ctx, params)
if err == nil {
t.Fatalf("Expected error, got nil")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("Expected gRPC status error, got %v", err)
}
if st.Code() != codes.Internal {
t.Errorf("Expected gRPC code Internal, got %v", st.Code())
}
if !strings.Contains(st.Message(), "internal server error") {
t.Errorf("Expected error message to contain 'internal server error', got %q", st.Message())
}
}
Loading
Loading