Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ coverage.out
coverage-report.txt
report.xml
junit.xml
test.log
/tmp/*
/examples/tmp/*
/bin/serve/docker/prometheus/data
62 changes: 46 additions & 16 deletions circuitbreaker/circuitbreaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,24 @@
}
}

type Logger interface {
Debug(...interface{})
}

type Config struct {
FailureThreshold int
OpenTimeout time.Duration
HalfOpenMaxCalls int
OnStateChange func(from, to State)
Logger Logger
}

type Breaker interface {
Allow() bool
GetState() State
RecordStart() bool
RecordResult(success bool)
Reset()
}

type CircuitBreaker struct {
Expand All @@ -46,21 +59,31 @@
timer *time.Timer
}

func NewDefault() (*CircuitBreaker, error) {
return New(Config{
FailureThreshold: 5,
OpenTimeout: 5 * time.Second,
HalfOpenMaxCalls: 5,
})
}

func New(config Config) (*CircuitBreaker, error) {
if err := validateConfig(config); err != nil {
if err := validateConfig(&config); err != nil {
return nil, err
}

logger.Debug("creating new circuit breaker", logger.Fields{
breaker := &CircuitBreaker{
config: config,
state: StateClosed,
}

breaker.config.Logger.Debug("creating new circuit breaker", logger.Fields{
"failure_threshold": config.FailureThreshold,
"open_timeout": config.OpenTimeout.String(),
"half_open_max_calls": config.HalfOpenMaxCalls,
})

return &CircuitBreaker{
config: config,
state: StateClosed,
}, nil
return breaker, nil
}

func (cb *CircuitBreaker) Allow() bool {
Expand Down Expand Up @@ -115,7 +138,7 @@
}

cb.attempts++
logger.Debug("attempt started", logger.Fields{
cb.config.Logger.Debug("attempt started", logger.Fields{
"state": cb.state.String(),
"attempts": cb.attempts,
})
Expand All @@ -127,7 +150,7 @@
cb.mutex.Lock()
defer cb.mutex.Unlock()

logger.Debug("recording attempt result", logger.Fields{
cb.config.Logger.Debug("recording attempt result", logger.Fields{
"success": success,
"state": cb.state.String(),
"attempts": cb.attempts,
Expand All @@ -143,7 +166,7 @@
switch cb.state {
case StateHalfOpen:
cb.successes++
logger.Debug("recorded success in half-open state", logger.Fields{
cb.config.Logger.Debug("recorded success in half-open state", logger.Fields{
"attempts": cb.attempts,
"successes": cb.successes,
"max_calls": cb.config.HalfOpenMaxCalls,
Expand All @@ -153,7 +176,7 @@
}
case StateClosed:
cb.failures = 0
logger.Debug("recorded success in closed state", logger.Fields{
cb.config.Logger.Debug("recorded success in closed state", logger.Fields{

Check warning on line 179 in circuitbreaker/circuitbreaker.go

View check run for this annotation

Codecov / codecov/patch

circuitbreaker/circuitbreaker.go#L179

Added line #L179 was not covered by tests
"failures": cb.failures,
})
}
Expand All @@ -163,7 +186,7 @@
cb.mutex.Lock()
defer cb.mutex.Unlock()

logger.Debug("resetting circuit breaker", logger.Fields{
cb.config.Logger.Debug("resetting circuit breaker", logger.Fields{

Check warning on line 189 in circuitbreaker/circuitbreaker.go

View check run for this annotation

Codecov / codecov/patch

circuitbreaker/circuitbreaker.go#L189

Added line #L189 was not covered by tests
"from_state": cb.state.String(),
})

Expand All @@ -180,7 +203,7 @@
func (cb *CircuitBreaker) recordFailure() {
cb.failures++

logger.Debug("recorded failure", logger.Fields{
cb.config.Logger.Debug("recorded failure", logger.Fields{
"state": cb.state.String(),
"failures": cb.failures,
"threshold": cb.config.FailureThreshold,
Expand All @@ -197,7 +220,7 @@
}

func (cb *CircuitBreaker) openCircuit() {
logger.Debug("opening circuit", logger.Fields{
cb.config.Logger.Debug("opening circuit", logger.Fields{
"from_state": cb.state.String(),
"open_timeout": cb.config.OpenTimeout.String(),
})
Expand All @@ -212,7 +235,7 @@
cb.mutex.Lock()
defer cb.mutex.Unlock()

logger.Debug("open timeout elapsed", logger.Fields{
cb.config.Logger.Debug("open timeout elapsed", logger.Fields{
"current_state": cb.state.String(),
})

Expand All @@ -232,7 +255,7 @@
cb.attempts = 0
cb.successes = 0

logger.Debug("state transition", logger.Fields{
cb.config.Logger.Debug("state transition", logger.Fields{
"from_state": oldState.String(),
"to_state": newState.String(),
"attempts": cb.attempts,
Expand All @@ -244,15 +267,22 @@
}
}

func validateConfig(config Config) error {
func validateConfig(config *Config) error {
if config.FailureThreshold <= 0 {
return fmt.Errorf("failure threshold must be greater than 0")
}

if config.OpenTimeout <= 0 {
return fmt.Errorf("open timeout must be greater than 0")
}

if config.HalfOpenMaxCalls <= 0 {
return fmt.Errorf("half-open max calls must be greater than 0")
}

if config.Logger == nil {
config.Logger = logger.New()
}

return nil
}
16 changes: 11 additions & 5 deletions messaging/natsjscm/natsjscm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@ type Connector interface {
Disconnect() error
IsConnected() bool
GetConnection() *nats.Conn
GetJetStream() jetstream.JetStream
EnsureStream(ctx context.Context, config jetstream.StreamConfig) (jetstream.JetStream, error)
GetJetStream() JetStream
EnsureStream(ctx context.Context, config jetstream.StreamConfig) (JetStream, error)
}

type JetStream interface {
Stream(ctx context.Context, name string) (jetstream.Stream, error)
CreateStream(ctx context.Context, config jetstream.StreamConfig) (jetstream.Stream, error)
PublishMsg(context.Context, *nats.Msg, ...jetstream.PublishOpt) (*jetstream.PubAck, error)
}

// ConnectionConfig holds the configuration for NATS connection
Expand All @@ -45,7 +51,7 @@ type ConnectionManager struct {
}

// NewConnectionManager creates a new connection manager
func NewConnectionManager(config ConnectionConfig) (*ConnectionManager, error) {
func NewConnectionManager(config ConnectionConfig) (Connector, error) {
if config.URL == "" {
return nil, fmt.Errorf("NATS URL is required")
}
Expand Down Expand Up @@ -179,7 +185,7 @@ func (cm *ConnectionManager) GetConnection() *nats.Conn {
}

// GetJetStream returns the JetStream context
func (cm *ConnectionManager) GetJetStream() jetstream.JetStream {
func (cm *ConnectionManager) GetJetStream() JetStream {
cm.mu.RLock()
defer cm.mu.RUnlock()
return cm.js
Expand All @@ -189,7 +195,7 @@ func (cm *ConnectionManager) GetJetStream() jetstream.JetStream {
func (cm *ConnectionManager) EnsureStream(
ctx context.Context,
config jetstream.StreamConfig,
) (jetstream.JetStream, error) {
) (JetStream, error) {
cm.mu.RLock()
js := cm.js
cm.mu.RUnlock()
Expand Down
5 changes: 3 additions & 2 deletions messaging/natsjsdlq/natsjsdlq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/nats-io/nats.go"
"github.com/nats-io/nats.go/jetstream"
"github.com/simiancreative/simiango/messaging/natsjscm"
"github.com/simiancreative/simiango/messaging/natsjsdlq"
"github.com/stretchr/testify/mock"
"github.com/tj/assert"
Expand All @@ -28,15 +29,15 @@ func (m *MockConnectionManager) GetConnection() *nats.Conn {
return args.Get(0).(*nats.Conn)
}

func (m *MockConnectionManager) GetJetStream() jetstream.JetStream {
func (m *MockConnectionManager) GetJetStream() natsjscm.JetStream {
args := m.Called()
return args.Get(0).(jetstream.JetStream)
}

func (m *MockConnectionManager) EnsureStream(
ctx context.Context,
config jetstream.StreamConfig,
) (jetstream.JetStream, error) {
) (natsjscm.JetStream, error) {
args := m.Called(ctx, config)
if args.Get(0) == nil {
return nil, args.Error(1)
Expand Down
46 changes: 27 additions & 19 deletions messaging/natsjspub/natsjspub.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@
) (*jetstream.PubAck, error)
}

type PublishMulti interface {
JsonPublisher
Publisher
}

// Config holds publisher configuration
type Config struct {
// Stream name to publish to
Expand All @@ -37,9 +42,6 @@
// Subject to publish on
Subject string

// CircuitBreaker configuration (optional)
CircuitBreaker *circuitbreaker.Config

// Publish timeout (default 5s)
Timeout time.Duration

Expand All @@ -50,24 +52,30 @@
// Dependencies for the publisher
type Dependencies struct {
// ConnectionManager for NATS
ConnectionManager *natsjscm.ConnectionManager
Connector natsjscm.Connector
Breaker circuitbreaker.Breaker
Logger Logger
}

// Publisher is a JetStream publisher with circuit breaker capabilities
type PublishManager struct {
config Config
cm *natsjscm.ConnectionManager
cb *circuitbreaker.CircuitBreaker
cm natsjscm.Connector
cb circuitbreaker.Breaker
log Logger
}

// NewPublisher creates a new JetStream publisher
func NewPublisher(deps Dependencies, config Config) (*PublishManager, error) {
func NewPublisher(deps Dependencies, config Config) (PublishMulti, error) {
// Validation
if deps.ConnectionManager == nil {
if deps.Connector == nil {
return nil, fmt.Errorf("connection manager is required")
}

if deps.Breaker == nil {
deps.Breaker, _ = circuitbreaker.NewDefault()
}

if config.StreamName == "" {
return nil, fmt.Errorf("stream name is required")
}
Expand All @@ -80,19 +88,15 @@
config.Timeout = 5 * time.Second
}

pub := &PublishManager{
config: config,
cm: deps.ConnectionManager,
log: logger.New(),
if deps.Logger == nil {
deps.Logger = logger.New()
}

// Initialize circuit breaker if configured
if config.CircuitBreaker != nil {
cb, err := circuitbreaker.New(*config.CircuitBreaker)
if err != nil {
return nil, fmt.Errorf("failed to create circuit breaker: %w", err)
}
pub.cb = cb
pub := &PublishManager{
config: config,
cm: deps.Connector,
cb: deps.Breaker,
log: deps.Logger,
}

// Ensure the stream exists
Expand All @@ -106,8 +110,11 @@

// ensureStream makes sure the configured stream exists
func (p *PublishManager) ensureStream(ctx context.Context) error {
p.log.Debugf("ensuring stream is connected")

// Get JetStream connection
if !p.cm.IsConnected() {
p.log.Debugf("connecting to NATS")

Check warning on line 117 in messaging/natsjspub/natsjspub.go

View check run for this annotation

Codecov / codecov/patch

messaging/natsjspub/natsjspub.go#L117

Added line #L117 was not covered by tests
if err := p.cm.Connect(); err != nil {
return fmt.Errorf("failed to connect to NATS: %w", err)
}
Expand All @@ -127,6 +134,7 @@
}

// Ensure stream exists
p.log.Debugf("ensuring stream %s exists", p.config.StreamName)
_, err := p.cm.EnsureStream(ctx, streamConfig)
return err
}
Expand Down
35 changes: 35 additions & 0 deletions messaging/natsjspub/natsjspub_mcb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package natsjspub_test

import (
"github.com/sanity-io/litter"
"github.com/simiancreative/simiango/circuitbreaker"
"github.com/stretchr/testify/mock"
)

type MockCircuitBreaker struct {
mock.Mock
}

func (m *MockCircuitBreaker) Allow() bool {
args := m.Called()
litter.Dump(args)
return args.Bool(0)
}

func (m *MockCircuitBreaker) GetState() circuitbreaker.State {
args := m.Called()
return args.Get(0).(circuitbreaker.State)
}

func (m *MockCircuitBreaker) RecordStart() bool {
args := m.Called()
return args.Bool(0)
}

func (m *MockCircuitBreaker) RecordResult(success bool) {
m.Called(success)
}

func (m *MockCircuitBreaker) Reset() {
m.Called()
}
Loading