Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0
github.com/kelseyhightower/envconfig v1.4.0
github.com/prometheus/client_golang v1.23.2
github.com/sony/gobreaker/v2 v2.4.0
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0
go.opentelemetry.io/otel v1.42.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.42.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTU
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
github.com/sony/gobreaker/v2 v2.4.0 h1:g2KJRW1Ubty3+ZOcSEUN7K+REQJdN6yo6XvaML+jptg=
github.com/sony/gobreaker/v2 v2.4.0/go.mod h1:pTyFJgcZ3h2tdQVLZZruK2C0eoFL1fb/G83wK1ZQl+s=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
Expand Down
123 changes: 123 additions & 0 deletions pkg/interceptors/circuitbreaker/circuitbreaker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
Copyright 2026 Chainguard, Inc.
SPDX-License-Identifier: Apache-2.0
*/

// Package circuitbreaker provides a gRPC client interceptor that wraps calls
// with a circuit breaker. When a downstream service returns too many errors,
// the circuit opens and subsequent calls fail fast with codes.Unavailable
// instead of adding load to the failing service.
package circuitbreaker

import (
"context"
"errors"
"time"

"github.com/chainguard-dev/clog"
"github.com/sony/gobreaker/v2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// DefaultSettings returns circuit breaker settings tuned for gRPC
// service-to-service calls on Cloud Run.
//
// - Opens after 5 consecutive failures
// - Half-open after 15s (sends a probe request)
// - Allows 10 probe requests in half-open state
// - Resets failure count every 30s if no trip
func DefaultSettings(name string) gobreaker.Settings {
return gobreaker.Settings{
Name: name,
MaxRequests: 10,
Interval: 30 * time.Second,
Timeout: 15 * time.Second,
ReadyToTrip: func(counts gobreaker.Counts) bool {
return counts.ConsecutiveFailures >= 5
},
OnStateChange: func(name string, from, to gobreaker.State) {
// OnStateChange is called outside any request context, so use
// context.Background() to get the process-level logger.
clog.InfoContextf(context.Background(), "circuit breaker %s: %s -> %s", name, from, to)
},
IsSuccessful: func(err error) bool {
if err == nil {
return true
}
// Treat client-side errors as successes (the server didn't fail).
// Canceled is included because cancellation is typically initiated
// by the client (context timeout or user abort), not a server failure.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖: Comment says "context timeout" but codes.Canceled is explicit context cancellation (e.g., ctx.Cancel()), not a timeout. Timeouts produce codes.DeadlineExceeded which is already handled separately above. Consider: "cancelled by the client (explicit cancellation), not a server failure."

code := status.Code(err)
switch code {
case codes.InvalidArgument, codes.NotFound, codes.AlreadyExists,
codes.PermissionDenied, codes.Unauthenticated, codes.FailedPrecondition,
codes.OutOfRange, codes.Unimplemented, codes.Canceled:
return true
default:
return false
}
},
}
}

// UnaryClientInterceptor returns a gRPC unary client interceptor that
// wraps each call with the provided circuit breaker. When the circuit is
// open, calls fail immediately with codes.Unavailable.
func UnaryClientInterceptor(cb *gobreaker.CircuitBreaker[any]) grpc.UnaryClientInterceptor {
return func(
ctx context.Context,
method string,
req, reply any,
cc *grpc.ClientConn,
invoker grpc.UnaryInvoker,
opts ...grpc.CallOption,
) error {
_, err := cb.Execute(func() (any, error) {
err := invoker(ctx, method, req, reply, cc, opts...)
return nil, err
})
if err != nil {
// Map gobreaker's sentinel errors to gRPC status codes.
if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) {
return status.Errorf(codes.Unavailable,
"circuit breaker %s is open: %v", cb.Name(), err)
}
}
return err
}
}

// StreamClientInterceptor returns a gRPC stream client interceptor that
// checks the circuit breaker before establishing a stream. When the circuit
// is open, the stream fails immediately with codes.Unavailable.
//
// Note: only stream establishment is tracked by the circuit breaker.
// Errors on Send/Recv after the stream is established are not tracked,
// so a downstream that accepts connections but fails on every message
// will not trip the breaker.
func StreamClientInterceptor(cb *gobreaker.CircuitBreaker[any]) grpc.StreamClientInterceptor {
return func(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string,
streamer grpc.Streamer,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
result, err := cb.Execute(func() (any, error) {
stream, err := streamer(ctx, desc, cc, method, opts...)
return stream, err
})
if err != nil {
if errors.Is(err, gobreaker.ErrOpenState) || errors.Is(err, gobreaker.ErrTooManyRequests) {
return nil, status.Errorf(codes.Unavailable,
"circuit breaker %s is open: %v", cb.Name(), err)
}
return nil, err
}
stream, _ := result.(grpc.ClientStream)
return stream, nil
}
}
252 changes: 252 additions & 0 deletions pkg/interceptors/circuitbreaker/circuitbreaker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
/*
Copyright 2026 Chainguard, Inc.
SPDX-License-Identifier: Apache-2.0
*/

package circuitbreaker

import (
"context"
"net"
"sync/atomic"
"testing"
"time"

"github.com/sony/gobreaker/v2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
)

// flakyServer fails with the given code for the first N calls, then succeeds.
type flakyServer struct {
healthpb.UnimplementedHealthServer
failCode codes.Code
failCount int32
calls atomic.Int32
}

func (s *flakyServer) Check(_ context.Context, _ *healthpb.HealthCheckRequest) (*healthpb.HealthCheckResponse, error) {
n := s.calls.Add(1)
if n <= s.failCount {
return nil, status.Error(s.failCode, "error")
}
return &healthpb.HealthCheckResponse{Status: healthpb.HealthCheckResponse_SERVING}, nil
}

func startServer(t *testing.T, srv healthpb.HealthServer) (string, func()) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
healthpb.RegisterHealthServer(s, srv)
go s.Serve(lis)
return lis.Addr().String(), s.Stop
}

func dial(t *testing.T, addr string, cb *gobreaker.CircuitBreaker[any]) healthpb.HealthClient {
t.Helper()
conn, err := grpc.NewClient(addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(UnaryClientInterceptor(cb)),
)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })
return healthpb.NewHealthClient(conn)
}

func TestCircuitBreaker_TripsAfterConsecutiveFailures(t *testing.T) {
// Server always fails with Internal.
addr, stop := startServer(t, &flakyServer{failCode: codes.Internal, failCount: 100})
defer stop()

settings := DefaultSettings("test")
settings.ReadyToTrip = func(counts gobreaker.Counts) bool {
return counts.ConsecutiveFailures > 2 // Trip after 2 for faster test
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖: Comment says "Trip after 2" but the condition is > 2, which trips after 3 consecutive failures. Should be "Trip after 3" or change to > 1.

}
cb := gobreaker.NewCircuitBreaker[any](settings)
client := dial(t, addr, cb)

// First 3 calls hit the server and get Internal errors.
for i := range 3 {
_, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{})
if err == nil {
t.Fatalf("call %d: expected error, got nil", i)
}
if got := status.Code(err); got != codes.Internal {
t.Fatalf("call %d: expected Internal, got %v", i, got)
}
}

// Circuit should now be open. Next call should fail fast with Unavailable.
_, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{})
if err == nil {
t.Fatal("expected circuit open error, got nil")
}
if got := status.Code(err); got != codes.Unavailable {
t.Errorf("expected Unavailable (circuit open), got %v", got)
}
}

func TestCircuitBreaker_ClientErrorsDoNotTrip(t *testing.T) {
// Server always fails with NotFound (a client-side error).
addr, stop := startServer(t, &flakyServer{failCode: codes.NotFound, failCount: 100})
defer stop()

settings := DefaultSettings("test")
settings.ReadyToTrip = func(counts gobreaker.Counts) bool {
return counts.ConsecutiveFailures > 2
}
cb := gobreaker.NewCircuitBreaker[any](settings)
client := dial(t, addr, cb)

// NotFound is classified as successful (client error), so the circuit
// should NOT trip even after many calls.
for i := range 10 {
_, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{})
if err == nil {
t.Fatalf("call %d: expected NotFound error, got nil", i)
}
if got := status.Code(err); got != codes.NotFound {
t.Fatalf("call %d: expected NotFound, got %v", i, got)
}
}

// Circuit should still be closed.
if cb.State() != gobreaker.StateClosed {
t.Errorf("expected circuit closed, got %v", cb.State())
}
}

func TestCircuitBreaker_RecoversThroughHalfOpen(t *testing.T) {
// Server fails first 3 calls, then succeeds.
srv := &flakyServer{failCode: codes.Unavailable, failCount: 3}
addr, stop := startServer(t, srv)
defer stop()

settings := DefaultSettings("test")
settings.ReadyToTrip = func(counts gobreaker.Counts) bool {
return counts.ConsecutiveFailures > 2
}
settings.Timeout = 1 * time.Millisecond // Transition to half-open almost immediately
settings.MaxRequests = 1
cb := gobreaker.NewCircuitBreaker[any](settings)
client := dial(t, addr, cb)

// Trip the breaker with 3 failures.
for range 3 {
client.Check(context.Background(), &healthpb.HealthCheckRequest{})
}
if cb.State() != gobreaker.StateOpen {
t.Fatalf("expected open, got %v", cb.State())
}

// Wait for the timeout to elapse so the breaker transitions to half-open.
time.Sleep(5 * time.Millisecond)

// The next call is a probe — server now succeeds (failCount=3, we're past it).
resp, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{})
if err != nil {
t.Fatalf("half-open probe should succeed: %v", err)
}
if resp.Status != healthpb.HealthCheckResponse_SERVING {
t.Errorf("got %v, want SERVING", resp.Status)
}

// Circuit should now be closed again.
if cb.State() != gobreaker.StateClosed {
t.Errorf("expected closed after recovery, got %v", cb.State())
}
}

func TestStreamClientInterceptor_TripsAndFailsFast(t *testing.T) {
// Server always fails with Internal.
addr, stop := startServer(t, &flakyServer{failCode: codes.Internal, failCount: 100})
defer stop()

settings := DefaultSettings("stream-test")
settings.ReadyToTrip = func(counts gobreaker.Counts) bool {
return counts.ConsecutiveFailures >= 2
}
cb := gobreaker.NewCircuitBreaker[any](settings)

// Dial with both unary and stream interceptors.
conn, err := grpc.NewClient(addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithUnaryInterceptor(UnaryClientInterceptor(cb)),
grpc.WithStreamInterceptor(StreamClientInterceptor(cb)),
)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { conn.Close() })

client := healthpb.NewHealthClient(conn)

// Trip the breaker via unary calls (2 failures).
for range 2 {
_, err := client.Check(context.Background(), &healthpb.HealthCheckRequest{})
if err == nil {
t.Fatal("expected error")
}
}
if cb.State() != gobreaker.StateOpen {
t.Fatalf("expected open, got %v", cb.State())
}

// Stream call should fail fast with Unavailable (circuit is open).
stream, err := client.Watch(context.Background(), &healthpb.HealthCheckRequest{})
if err == nil && stream != nil {
// Some gRPC versions defer the error to Recv.
_, err = stream.Recv()
}
if err == nil {
t.Fatal("expected circuit open error on stream, got nil")
}
if got := status.Code(err); got != codes.Unavailable {
t.Errorf("expected Unavailable (circuit open) on stream, got %v", got)
}
}

func TestDefaultSettings_IsSuccessful(t *testing.T) {
settings := DefaultSettings("test")

tests := []struct {
code codes.Code
want bool
}{
{codes.OK, true},
{codes.InvalidArgument, true},
{codes.NotFound, true},
{codes.AlreadyExists, true},
{codes.PermissionDenied, true},
{codes.Unauthenticated, true},
{codes.FailedPrecondition, true},
{codes.OutOfRange, true},
{codes.Unimplemented, true},
{codes.Canceled, true},
{codes.Internal, false},
{codes.Unavailable, false},
{codes.DeadlineExceeded, false},
{codes.ResourceExhausted, false},
{codes.Unknown, false},
{codes.Aborted, false},
{codes.DataLoss, false},
}

for _, tt := range tests {
var err error
if tt.code != codes.OK {
err = status.Error(tt.code, "test")
}
if got := settings.IsSuccessful(err); got != tt.want {
t.Errorf("IsSuccessful(%v) = %v, want %v", tt.code, got, tt.want)
}
}
}
Loading