Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions pkg/workflows/ring/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"context"
"errors"

"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/services"
"github.com/smartcontractkit/chainlink-common/pkg/types/core"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator"
)

const (
Expand All @@ -20,15 +22,16 @@ const (
var _ core.OCR3ReportingPluginFactory = &Factory{}

type Factory struct {
store *Store
arbiterScaler pb.ArbiterScalerClient
config *ConsensusConfig
lggr logger.Logger
ringStore *Store
shardOrchestratorStore *shardorchestrator.Store
arbiterScaler pb.ArbiterScalerClient
config *ConsensusConfig
lggr logger.Logger

services.StateMachine
}

func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logger, cfg *ConsensusConfig) (*Factory, error) {
func NewFactory(s *Store, shardOrchestratorStore *shardorchestrator.Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logger, cfg *ConsensusConfig) (*Factory, error) {
if arbiterScaler == nil {
return nil, errors.New("arbiterScaler is required")
}
Expand All @@ -38,15 +41,16 @@ func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logg
}
}
return &Factory{
store: s,
arbiterScaler: arbiterScaler,
config: cfg,
lggr: logger.Named(lggr, "RingPluginFactory"),
ringStore: s,
shardOrchestratorStore: shardOrchestratorStore,
arbiterScaler: arbiterScaler,
config: cfg,
lggr: logger.Named(lggr, "RingPluginFactory"),
}, nil
}

func (o *Factory) NewReportingPlugin(_ context.Context, config ocr3types.ReportingPluginConfig) (ocr3types.ReportingPlugin[[]byte], ocr3types.ReportingPluginInfo, error) {
plugin, err := NewPlugin(o.store, o.arbiterScaler, config, o.lggr, o.config)
plugin, err := NewPlugin(o.ringStore, o.arbiterScaler, config, o.lggr, o.config)
pluginInfo := ocr3types.ReportingPluginInfo{
Name: "RingPlugin",
Limits: ocr3types.ReportingPluginLimits{
Expand Down
13 changes: 7 additions & 6 deletions pkg/workflows/ring/factory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import (
"context"
"testing"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
"github.com/stretchr/testify/require"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
)

func TestFactory_NewFactory(t *testing.T) {
Expand All @@ -15,20 +16,20 @@ func TestFactory_NewFactory(t *testing.T) {
arbiter := &mockArbiter{}

t.Run("with_nil_config", func(t *testing.T) {
f, err := NewFactory(store, arbiter, lggr, nil)
f, err := NewFactory(store, nil, arbiter, lggr, nil)
require.NoError(t, err)
require.NotNil(t, f)
})

t.Run("with_custom_config", func(t *testing.T) {
cfg := &ConsensusConfig{BatchSize: 50}
f, err := NewFactory(store, arbiter, lggr, cfg)
f, err := NewFactory(store, nil, arbiter, lggr, cfg)
require.NoError(t, err)
require.NotNil(t, f)
})

t.Run("nil_arbiter_returns_error", func(t *testing.T) {
_, err := NewFactory(store, nil, lggr, nil)
_, err := NewFactory(store, nil, nil, lggr, nil)
require.Error(t, err)
require.Contains(t, err.Error(), "arbiterScaler is required")
})
Expand All @@ -37,7 +38,7 @@ func TestFactory_NewFactory(t *testing.T) {
func TestFactory_NewReportingPlugin(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
f, err := NewFactory(store, &mockArbiter{}, lggr, nil)
f, err := NewFactory(store, nil, &mockArbiter{}, lggr, nil)
require.NoError(t, err)

config := ocr3types.ReportingPluginConfig{N: 4, F: 1}
Expand All @@ -52,7 +53,7 @@ func TestFactory_NewReportingPlugin(t *testing.T) {
func TestFactory_Lifecycle(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
f, err := NewFactory(store, &mockArbiter{}, lggr, nil)
f, err := NewFactory(store, nil, &mockArbiter{}, lggr, nil)
require.NoError(t, err)

err = f.Start(context.Background())
Expand Down
7 changes: 4 additions & 3 deletions pkg/workflows/ring/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import (
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"
"github.com/smartcontractkit/libocr/offchainreporting2/types"
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"
)

type mockArbiter struct {
Expand Down Expand Up @@ -443,7 +444,7 @@ func TestPlugin_NoHealthyShardsFallbackToShardZero(t *testing.T) {
})
require.NoError(t, err)

transmitter := NewTransmitter(lggr, store, arbiter, "test-account")
transmitter := NewTransmitter(lggr, store, nil, arbiter, "test-account")

ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
Expand Down
67 changes: 57 additions & 10 deletions pkg/workflows/ring/transmitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,33 @@ import (

"google.golang.org/protobuf/proto"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
"github.com/smartcontractkit/libocr/offchainreporting2plus/types"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator"
)

var _ ocr3types.ContractTransmitter[[]byte] = (*Transmitter)(nil)

// Transmitter handles transmission of shard orchestration outcomes
type Transmitter struct {
lggr logger.Logger
store *Store
arbiterScaler pb.ArbiterScalerClient
fromAccount types.Account
lggr logger.Logger
ringStore *Store
shardOrchestratorStore *shardorchestrator.Store
arbiterScaler pb.ArbiterScalerClient
fromAccount types.Account
}

func NewTransmitter(lggr logger.Logger, store *Store, arbiterScaler pb.ArbiterScalerClient, fromAccount types.Account) *Transmitter {
return &Transmitter{lggr: lggr, store: store, arbiterScaler: arbiterScaler, fromAccount: fromAccount}
func NewTransmitter(lggr logger.Logger, ringStore *Store, shardOrchestratorStore *shardorchestrator.Store, arbiterScaler pb.ArbiterScalerClient, fromAccount types.Account) *Transmitter {
return &Transmitter{
lggr: lggr,
ringStore: ringStore,
shardOrchestratorStore: shardOrchestratorStore,
arbiterScaler: arbiterScaler,
fromAccount: fromAccount,
}
}

func (t *Transmitter) Transmit(ctx context.Context, _ types.ConfigDigest, _ uint64, r ocr3types.ReportWithInfo[[]byte], _ []types.AttributedOnchainSignature) error {
Expand All @@ -37,10 +46,48 @@ func (t *Transmitter) Transmit(ctx context.Context, _ types.ConfigDigest, _ uint
return err
}

t.store.SetRoutingState(outcome.State)
// Update Ring Store
t.ringStore.SetRoutingState(outcome.State)

// Determine transition state for ShardOrchestrator
transitionState := "steady"
var oldShardCount, newShardCount uint32

if outcome.State != nil {
switch s := outcome.State.State.(type) {
case *pb.RoutingState_RoutableShards:
newShardCount = s.RoutableShards
oldShardCount = newShardCount
case *pb.RoutingState_Transition:
transitionState = "transitioning"
oldShardCount = s.Transition.LastStableCount
newShardCount = s.Transition.WantShards
}
}

// Update ShardOrchestrator store if available
if t.shardOrchestratorStore != nil {
mappings := make([]*shardorchestrator.WorkflowMappingState, 0, len(outcome.Routes))
for workflowID, route := range outcome.Routes {
mappings = append(mappings, &shardorchestrator.WorkflowMappingState{
WorkflowID: workflowID,
OldShardID: oldShardCount,
NewShardID: route.Shard,
TransitionState: transitionState,
})
}

if err := t.shardOrchestratorStore.BatchUpdateWorkflowMappings(ctx, mappings); err != nil {
t.lggr.Errorw("failed to update ShardOrchestrator store", "err", err, "workflowCount", len(mappings))
// Don't fail the entire transmission if ShardOrchestrator update fails
} else {
t.lggr.Debugw("Updated ShardOrchestrator store", "workflowCount", len(mappings), "state", transitionState)
}
}

// Update Ring Store workflow mappings
for workflowID, route := range outcome.Routes {
t.store.SetShardForWorkflow(workflowID, route.Shard)
t.ringStore.SetShardForWorkflow(workflowID, route.Shard)
t.lggr.Debugw("Updated workflow shard mapping", "workflowID", workflowID, "shard", route.Shard)
}

Expand Down
21 changes: 11 additions & 10 deletions pkg/workflows/ring/transmitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"
"github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types"
"github.com/smartcontractkit/libocr/offchainreporting2plus/types"

"github.com/smartcontractkit/chainlink-common/pkg/logger"
"github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"
)

type mockArbiterScaler struct {
Expand All @@ -37,14 +38,14 @@ func (m *mockArbiterScaler) ConsensusWantShards(ctx context.Context, req *pb.Con
func TestTransmitter_NewTransmitter(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
tx := NewTransmitter(lggr, store, nil, "test-account")
tx := NewTransmitter(lggr, store, nil, nil, "test-account")
require.NotNil(t, tx)
}

func TestTransmitter_FromAccount(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
tx := NewTransmitter(lggr, store, nil, "my-account")
tx := NewTransmitter(lggr, store, nil, nil, "my-account")

account, err := tx.FromAccount(context.Background())
require.NoError(t, err)
Expand All @@ -55,7 +56,7 @@ func TestTransmitter_Transmit(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
mock := &mockArbiterScaler{}
tx := NewTransmitter(lggr, store, mock, "test-account")
tx := NewTransmitter(lggr, store, nil, mock, "test-account")

outcome := &pb.Outcome{
State: &pb.RoutingState{
Expand Down Expand Up @@ -88,7 +89,7 @@ func TestTransmitter_Transmit(t *testing.T) {
func TestTransmitter_Transmit_NilArbiter(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
tx := NewTransmitter(lggr, store, nil, "test-account")
tx := NewTransmitter(lggr, store, nil, nil, "test-account")

outcome := &pb.Outcome{
State: &pb.RoutingState{
Expand All @@ -107,7 +108,7 @@ func TestTransmitter_Transmit_TransitionState(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
mock := &mockArbiterScaler{}
tx := NewTransmitter(lggr, store, mock, "test-account")
tx := NewTransmitter(lggr, store, nil, mock, "test-account")

outcome := &pb.Outcome{
State: &pb.RoutingState{
Expand All @@ -127,7 +128,7 @@ func TestTransmitter_Transmit_TransitionState(t *testing.T) {
func TestTransmitter_Transmit_InvalidReport(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
tx := NewTransmitter(lggr, store, nil, "test-account")
tx := NewTransmitter(lggr, store, nil, nil, "test-account")

// Send invalid protobuf data
report := ocr3types.ReportWithInfo[[]byte]{Report: []byte("invalid protobuf")}
Expand All @@ -139,7 +140,7 @@ func TestTransmitter_Transmit_ArbiterError(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
mock := &mockArbiterScaler{err: context.DeadlineExceeded}
tx := NewTransmitter(lggr, store, mock, "test-account")
tx := NewTransmitter(lggr, store, nil, mock, "test-account")

outcome := &pb.Outcome{
State: &pb.RoutingState{
Expand All @@ -156,7 +157,7 @@ func TestTransmitter_Transmit_ArbiterError(t *testing.T) {
func TestTransmitter_Transmit_NilState(t *testing.T) {
lggr := logger.Test(t)
store := NewStore()
tx := NewTransmitter(lggr, store, nil, "test-account")
tx := NewTransmitter(lggr, store, nil, nil, "test-account")

outcome := &pb.Outcome{
State: nil,
Expand Down
Loading
Loading