diff --git a/go.mod b/go.mod index 697c98d..3ce017f 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/prometheus/client_golang v1.23.2 + github.com/sony/gobreaker/v2 v2.4.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 go.opentelemetry.io/otel v1.42.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0 diff --git a/go.sum b/go.sum index 94ad4a6..7a7acc0 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,8 @@ github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTU github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc= github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo= +github.com/sony/gobreaker/v2 v2.4.0 h1:g2KJRW1Ubty3+ZOcSEUN7K+REQJdN6yo6XvaML+jptg= +github.com/sony/gobreaker/v2 v2.4.0/go.mod h1:pTyFJgcZ3h2tdQVLZZruK2C0eoFL1fb/G83wK1ZQl+s= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= diff --git a/pkg/interceptors/circuitbreaker/circuitbreaker.go b/pkg/interceptors/circuitbreaker/circuitbreaker.go new file mode 100644 index 0000000..1c6b366 --- /dev/null +++ b/pkg/interceptors/circuitbreaker/circuitbreaker.go @@ -0,0 +1,123 @@ +/* +Copyright 2026 Chainguard, Inc. +SPDX-License-Identifier: Apache-2.0 +*/ + +// Package circuitbreaker provides a gRPC client interceptor that wraps calls +// with a circuit breaker. When a downstream service returns too many errors, +// the circuit opens and subsequent calls fail fast with codes.Unavailable +// instead of adding load to the failing service. +package circuitbreaker + +import ( + "context" + "errors" + "time" + + "github.com/chainguard-dev/clog" + "github.com/sony/gobreaker/v2" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// DefaultSettings returns circuit breaker settings tuned for gRPC +// service-to-service calls on Cloud Run. +// +// - Opens after 5 consecutive failures +// - Half-open after 15s (sends a probe request) +// - Allows 10 probe requests in half-open state +// - Resets failure count every 30s if no trip +func DefaultSettings(name string) gobreaker.Settings { + return gobreaker.Settings{ + Name: name, + MaxRequests: 10, + Interval: 30 * time.Second, + Timeout: 15 * time.Second, + ReadyToTrip: func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= 5 + }, + OnStateChange: func(name string, from, to gobreaker.State) { + // OnStateChange is called outside any request context, so use + // context.Background() to get the process-level logger. + clog.InfoContextf(context.Background(), "circuit breaker %s: %s -> %s", name, from, to) + }, + IsSuccessful: func(err error) bool { + if err == nil { + return true + } + // Treat client-side errors as successes (the server didn't fail). + // Canceled is included because cancellation is typically initiated + // by the client (context timeout or user abort), not a server failure. + code := status.Code(err) + switch code { + case codes.InvalidArgument, codes.NotFound, codes.AlreadyExists, + codes.PermissionDenied, codes.Unauthenticated, codes.FailedPrecondition, + codes.OutOfRange, codes.Unimplemented, codes.Canceled: + return true + default: + return false + } + }, + } +} + +// UnaryClientInterceptor returns a gRPC unary client interceptor that +// wraps each call with the provided circuit breaker. When the circuit is +// open, calls fail immediately with codes.Unavailable. +func UnaryClientInterceptor(cb *gobreaker.CircuitBreaker[any]) grpc.UnaryClientInterceptor { + return func( + ctx context.Context, + method string, + req, reply any, + cc *grpc.ClientConn, + invoker grpc.UnaryInvoker, + opts ...grpc.CallOption, + ) error { + _, err := cb.Execute(func() (any, error) { + err := invoker(ctx, method, req, reply, cc, opts...) + return nil, err + }) + if err != nil { + // Map gobreaker's sentinel errors to gRPC status codes. + if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { + return status.Errorf(codes.Unavailable, + "circuit breaker %s is open: %v", cb.Name(), err) + } + } + return err + } +} + +// StreamClientInterceptor returns a gRPC stream client interceptor that +// checks the circuit breaker before establishing a stream. When the circuit +// is open, the stream fails immediately with codes.Unavailable. +// +// Note: only stream establishment is tracked by the circuit breaker. +// Errors on Send/Recv after the stream is established are not tracked, +// so a downstream that accepts connections but fails on every message +// will not trip the breaker. +func StreamClientInterceptor(cb *gobreaker.CircuitBreaker[any]) grpc.StreamClientInterceptor { + return func( + ctx context.Context, + desc *grpc.StreamDesc, + cc *grpc.ClientConn, + method string, + streamer grpc.Streamer, + opts ...grpc.CallOption, + ) (grpc.ClientStream, error) { + result, err := cb.Execute(func() (any, error) { + stream, err := streamer(ctx, desc, cc, method, opts...) + return stream, err + }) + if err != nil { + if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) { + return nil, status.Errorf(codes.Unavailable, + "circuit breaker %s is open: %v", cb.Name(), err) + } + return nil, err + } + stream, _ := result.(grpc.ClientStream) + return stream, nil + } +} diff --git a/pkg/interceptors/circuitbreaker/circuitbreaker_test.go b/pkg/interceptors/circuitbreaker/circuitbreaker_test.go new file mode 100644 index 0000000..2ce064d --- /dev/null +++ b/pkg/interceptors/circuitbreaker/circuitbreaker_test.go @@ -0,0 +1,252 @@ +/* +Copyright 2026 Chainguard, Inc. +SPDX-License-Identifier: Apache-2.0 +*/ + +package circuitbreaker + +import ( + "context" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/sony/gobreaker/v2" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" +) + +// flakyServer fails with the given code for the first N calls, then succeeds. +type flakyServer struct { + healthpb.UnimplementedHealthServer + failCode codes.Code + failCount int32 + calls atomic.Int32 +} + +func (s *flakyServer) Check(_ context.Context, _ *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) { + n := s.calls.Add(1) + if n <= s.failCount { + return nil, status.Error(s.failCode, "error") + } + return &healthpb.HealthCheckResponse{Status: healthpb.HealthCheckResponse_SERVING}, nil +} + +func startServer(t *testing.T, srv healthpb.HealthServer) (string, func()) { + t.Helper() + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + s := grpc.NewServer() + healthpb.RegisterHealthServer(s, srv) + go s.Serve(lis) + return lis.Addr().String(), s.Stop +} + +func dial(t *testing.T, addr string, cb *gobreaker.CircuitBreaker[any]) healthpb.HealthClient { + t.Helper() + conn, err := grpc.NewClient(addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUnaryInterceptor(UnaryClientInterceptor(cb)), + ) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + return healthpb.NewHealthClient(conn) +} + +func TestCircuitBreaker_TripsAfterConsecutiveFailures(t *testing.T) { + // Server always fails with Internal. + addr, stop := startServer(t, &flakyServer{failCode: codes.Internal, failCount: 100}) + defer stop() + + settings := DefaultSettings("test") + settings.ReadyToTrip = func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures > 2 // Trip after 2 for faster test + } + cb := gobreaker.NewCircuitBreaker[any](settings) + client := dial(t, addr, cb) + + // First 3 calls hit the server and get Internal errors. + for i := range 3 { + _, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{}) + if err == nil { + t.Fatalf("call %d: expected error, got nil", i) + } + if got := status.Code(err); got != codes.Internal { + t.Fatalf("call %d: expected Internal, got %v", i, got) + } + } + + // Circuit should now be open. Next call should fail fast with Unavailable. + _, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{}) + if err == nil { + t.Fatal("expected circuit open error, got nil") + } + if got := status.Code(err); got != codes.Unavailable { + t.Errorf("expected Unavailable (circuit open), got %v", got) + } +} + +func TestCircuitBreaker_ClientErrorsDoNotTrip(t *testing.T) { + // Server always fails with NotFound (a client-side error). + addr, stop := startServer(t, &flakyServer{failCode: codes.NotFound, failCount: 100}) + defer stop() + + settings := DefaultSettings("test") + settings.ReadyToTrip = func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures > 2 + } + cb := gobreaker.NewCircuitBreaker[any](settings) + client := dial(t, addr, cb) + + // NotFound is classified as successful (client error), so the circuit + // should NOT trip even after many calls. + for i := range 10 { + _, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{}) + if err == nil { + t.Fatalf("call %d: expected NotFound error, got nil", i) + } + if got := status.Code(err); got != codes.NotFound { + t.Fatalf("call %d: expected NotFound, got %v", i, got) + } + } + + // Circuit should still be closed. + if cb.State() != gobreaker.StateClosed { + t.Errorf("expected circuit closed, got %v", cb.State()) + } +} + +func TestCircuitBreaker_RecoversThroughHalfOpen(t *testing.T) { + // Server fails first 3 calls, then succeeds. + srv := &flakyServer{failCode: codes.Unavailable, failCount: 3} + addr, stop := startServer(t, srv) + defer stop() + + settings := DefaultSettings("test") + settings.ReadyToTrip = func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures > 2 + } + settings.Timeout = 1 * time.Millisecond // Transition to half-open almost immediately + settings.MaxRequests = 1 + cb := gobreaker.NewCircuitBreaker[any](settings) + client := dial(t, addr, cb) + + // Trip the breaker with 3 failures. + for range 3 { + client.Check(context.Background(), &healthpb.HealthCheckRequest{}) + } + if cb.State() != gobreaker.StateOpen { + t.Fatalf("expected open, got %v", cb.State()) + } + + // Wait for the timeout to elapse so the breaker transitions to half-open. + time.Sleep(5 * time.Millisecond) + + // The next call is a probe — server now succeeds (failCount=3, we're past it). + resp, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{}) + if err != nil { + t.Fatalf("half-open probe should succeed: %v", err) + } + if resp.Status != healthpb.HealthCheckResponse_SERVING { + t.Errorf("got %v, want SERVING", resp.Status) + } + + // Circuit should now be closed again. + if cb.State() != gobreaker.StateClosed { + t.Errorf("expected closed after recovery, got %v", cb.State()) + } +} + +func TestStreamClientInterceptor_TripsAndFailsFast(t *testing.T) { + // Server always fails with Internal. + addr, stop := startServer(t, &flakyServer{failCode: codes.Internal, failCount: 100}) + defer stop() + + settings := DefaultSettings("stream-test") + settings.ReadyToTrip = func(counts gobreaker.Counts) bool { + return counts.ConsecutiveFailures >= 2 + } + cb := gobreaker.NewCircuitBreaker[any](settings) + + // Dial with both unary and stream interceptors. + conn, err := grpc.NewClient(addr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithUnaryInterceptor(UnaryClientInterceptor(cb)), + grpc.WithStreamInterceptor(StreamClientInterceptor(cb)), + ) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { conn.Close() }) + + client := healthpb.NewHealthClient(conn) + + // Trip the breaker via unary calls (2 failures). + for range 2 { + _, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{}) + if err == nil { + t.Fatal("expected error") + } + } + if cb.State() != gobreaker.StateOpen { + t.Fatalf("expected open, got %v", cb.State()) + } + + // Stream call should fail fast with Unavailable (circuit is open). + stream, err := client.Watch(context.Background(), &healthpb.HealthCheckRequest{}) + if err == nil && stream != nil { + // Some gRPC versions defer the error to Recv. + _, err = stream.Recv() + } + if err == nil { + t.Fatal("expected circuit open error on stream, got nil") + } + if got := status.Code(err); got != codes.Unavailable { + t.Errorf("expected Unavailable (circuit open) on stream, got %v", got) + } +} + +func TestDefaultSettings_IsSuccessful(t *testing.T) { + settings := DefaultSettings("test") + + tests := []struct { + code codes.Code + want bool + }{ + {codes.OK, true}, + {codes.InvalidArgument, true}, + {codes.NotFound, true}, + {codes.AlreadyExists, true}, + {codes.PermissionDenied, true}, + {codes.Unauthenticated, true}, + {codes.FailedPrecondition, true}, + {codes.OutOfRange, true}, + {codes.Unimplemented, true}, + {codes.Canceled, true}, + {codes.Internal, false}, + {codes.Unavailable, false}, + {codes.DeadlineExceeded, false}, + {codes.ResourceExhausted, false}, + {codes.Unknown, false}, + {codes.Aborted, false}, + {codes.DataLoss, false}, + } + + for _, tt := range tests { + var err error + if tt.code != codes.OK { + err = status.Error(tt.code, "test") + } + if got := settings.IsSuccessful(err); got != tt.want { + t.Errorf("IsSuccessful(%v) = %v, want %v", tt.code, got, tt.want) + } + } +}