diff --git a/pkg/workflows/ring/factory.go b/pkg/workflows/ring/factory.go index 8b85d02c8..5e8710f9d 100644 --- a/pkg/workflows/ring/factory.go +++ b/pkg/workflows/ring/factory.go @@ -4,11 +4,13 @@ import ( "context" "errors" + "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" "github.com/smartcontractkit/chainlink-common/pkg/types/core" "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" - "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) const ( @@ -20,15 +22,16 @@ const ( var _ core.OCR3ReportingPluginFactory = &Factory{} type Factory struct { - store *Store - arbiterScaler pb.ArbiterScalerClient - config *ConsensusConfig - lggr logger.Logger + ringStore *Store + shardOrchestratorStore *shardorchestrator.Store + arbiterScaler pb.ArbiterScalerClient + config *ConsensusConfig + lggr logger.Logger services.StateMachine } -func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logger, cfg *ConsensusConfig) (*Factory, error) { +func NewFactory(s *Store, shardOrchestratorStore *shardorchestrator.Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logger, cfg *ConsensusConfig) (*Factory, error) { if arbiterScaler == nil { return nil, errors.New("arbiterScaler is required") } @@ -38,15 +41,16 @@ func NewFactory(s *Store, arbiterScaler pb.ArbiterScalerClient, lggr logger.Logg } } return &Factory{ - store: s, - arbiterScaler: arbiterScaler, - config: cfg, - lggr: logger.Named(lggr, "RingPluginFactory"), + ringStore: s, + shardOrchestratorStore: shardOrchestratorStore, + arbiterScaler: arbiterScaler, + config: cfg, + lggr: logger.Named(lggr, "RingPluginFactory"), }, nil } func (o *Factory) NewReportingPlugin(_ context.Context, config ocr3types.ReportingPluginConfig) (ocr3types.ReportingPlugin[[]byte], ocr3types.ReportingPluginInfo, error) { - plugin, err := NewPlugin(o.store, o.arbiterScaler, config, o.lggr, o.config) + plugin, err := NewPlugin(o.ringStore, o.arbiterScaler, config, o.lggr, o.config) pluginInfo := ocr3types.ReportingPluginInfo{ Name: "RingPlugin", Limits: ocr3types.ReportingPluginLimits{ diff --git a/pkg/workflows/ring/factory_test.go b/pkg/workflows/ring/factory_test.go index 557f872b5..f2a1993f4 100644 --- a/pkg/workflows/ring/factory_test.go +++ b/pkg/workflows/ring/factory_test.go @@ -4,15 +4,18 @@ import ( "context" "testing" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) func TestFactory_NewFactory(t *testing.T) { lggr := logger.Test(t) store := NewStore() + shardOrchestratorStore := shardorchestrator.NewStore(lggr) arbiter := &mockArbiter{} tests := []struct { @@ -45,7 +48,7 @@ func TestFactory_NewFactory(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - f, err := NewFactory(store, tt.arbiter, lggr, tt.config) + f, err := NewFactory(store, shardOrchestratorStore, tt.arbiter, lggr, tt.config) if tt.wantErr { require.Error(t, err) require.Contains(t, err.Error(), 
tt.errSubstr) @@ -60,7 +63,7 @@ func TestFactory_NewFactory(t *testing.T) { func TestFactory_NewReportingPlugin(t *testing.T) { lggr := logger.Test(t) store := NewStore() - f, err := NewFactory(store, &mockArbiter{}, lggr, nil) + f, err := NewFactory(store, nil, &mockArbiter{}, lggr, nil) require.NoError(t, err) config := ocr3types.ReportingPluginConfig{N: 4, F: 1} @@ -75,7 +78,7 @@ func TestFactory_NewReportingPlugin(t *testing.T) { func TestFactory_Lifecycle(t *testing.T) { lggr := logger.Test(t) store := NewStore() - f, err := NewFactory(store, &mockArbiter{}, lggr, nil) + f, err := NewFactory(store, nil, &mockArbiter{}, lggr, nil) require.NoError(t, err) err = f.Start(context.Background()) diff --git a/pkg/workflows/ring/pb/generate.go b/pkg/workflows/ring/pb/generate.go index 850f3eeb4..bff63fddd 100644 --- a/pkg/workflows/ring/pb/generate.go +++ b/pkg/workflows/ring/pb/generate.go @@ -1,6 +1,5 @@ //go:generate protoc --go_out=. --go_opt=paths=source_relative shared.proto //go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative arbiter.proto -//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative shard_orchestrator.proto //go:generate protoc --go_out=. --go_opt=paths=source_relative consensus.proto package pb diff --git a/pkg/workflows/ring/plugin_test.go b/pkg/workflows/ring/plugin_test.go index 99f70d1b9..094e32c93 100644 --- a/pkg/workflows/ring/plugin_test.go +++ b/pkg/workflows/ring/plugin_test.go @@ -12,10 +12,12 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2/types" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) type mockArbiter struct { @@ -443,7 +445,7 @@ func TestPlugin_NoHealthyShardsFallbackToShardZero(t *testing.T) { }) require.NoError(t, err) - transmitter := NewTransmitter(lggr, store, arbiter, "test-account") + transmitter := NewTransmitter(lggr, store, nil, arbiter, "test-account") ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) defer cancel() @@ -591,3 +593,158 @@ func TestPlugin_ObservationQuorum(t *testing.T) { require.True(t, quorum) }) } + +func TestPlugin_ShardOrchestratorIntegration(t *testing.T) { + lggr := logger.Test(t) + + // Create both stores + ringStore := NewStore() + orchestratorStore := shardorchestrator.NewStore(lggr) + + // Initialize ring store with healthy shards + ringStore.SetAllShardHealth(map[uint32]bool{0: true, 1: true, 2: true}) + + config := ocr3types.ReportingPluginConfig{ + N: 4, F: 1, + } + + arbiter := &mockArbiter{} + plugin, err := NewPlugin(ringStore, arbiter, config, lggr, &ConsensusConfig{ + BatchSize: 100, + TimeToSync: 1 * time.Second, + }) + require.NoError(t, err) + + // Create transmitter with both stores + transmitter := NewTransmitter(lggr, ringStore, orchestratorStore, arbiter, "test-account") + + ctx := context.Background() + now := time.Now() + + t.Run("initial_workflow_assignments", func(t *testing.T) { + // Create observations with workflows + workflows := []string{"wf-A", "wf-B", "wf-C"} + 
aos := makeObservationsWithWantShards(t, []map[uint32]*pb.ShardStatus{ + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + }, workflows, now, 3) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 1, + PreviousOutcome: nil, + } + + // Generate outcome + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, aos) + require.NoError(t, err) + + // Generate report and transmit + reports, err := plugin.Reports(ctx, 1, outcome) + require.NoError(t, err) + require.Len(t, reports, 1) + + err = transmitter.Transmit(ctx, types.ConfigDigest{}, 1, reports[0].ReportWithInfo, nil) + require.NoError(t, err) + + // Verify ring store was updated + for _, wf := range workflows { + shard, err := ringStore.GetShardForWorkflow(ctx, wf) + require.NoError(t, err) + require.LessOrEqual(t, shard, uint32(2), "workflow should be assigned to valid shard") + t.Logf("Ring store: %s → shard %d", wf, shard) + } + + // Verify orchestrator store was updated with correct state + for _, wf := range workflows { + mapping, err := orchestratorStore.GetWorkflowMapping(ctx, wf) + require.NoError(t, err) + require.Equal(t, wf, mapping.WorkflowID) + require.LessOrEqual(t, mapping.NewShardID, uint32(2)) + require.Equal(t, uint32(0), mapping.OldShardID, "initial assignment should have oldShardID=0") + require.Equal(t, shardorchestrator.StateSteady, mapping.TransitionState, "initial assignment should be steady") + t.Logf("Orchestrator store: %s → shard %d (state: %s)", wf, mapping.NewShardID, mapping.TransitionState.String()) + } + + // Verify version tracking + version := orchestratorStore.GetMappingVersion() + require.Equal(t, uint64(1), version, "version should increment after first update") + }) + + t.Run("workflow_transition_detected", func(t *testing.T) { + // First, establish a baseline with workflows distributed across 3 shards + // Use wantShards=3 to ensure workflows actually get assigned to shard 2 + baselineAos := makeObservationsWithWantShards(t, []map[uint32]*pb.ShardStatus{ + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}, 2: {IsHealthy: true}}, + }, []string{"wf-A", "wf-B", "wf-C", "wf-D", "wf-E"}, now, 3) + + baselineOutcome, err := plugin.Outcome(ctx, ocr3types.OutcomeContext{SeqNr: 2}, nil, baselineAos) + require.NoError(t, err) + + baselineReports, err := plugin.Reports(ctx, 2, baselineOutcome) + require.NoError(t, err) + + err = transmitter.Transmit(ctx, types.ConfigDigest{}, 2, baselineReports[0].ReportWithInfo, nil) + require.NoError(t, err) + + // Parse baseline to see which workflows were on shard 2 + baselineProto := &pb.Outcome{} + err = proto.Unmarshal(baselineOutcome, baselineProto) + require.NoError(t, err) + + workflowsOnShard2 := []string{} + for wfID, route := range baselineProto.Routes { + if route.Shard == 2 { + workflowsOnShard2 = append(workflowsOnShard2, wfID) + } + t.Logf("Baseline: %s on shard %d", wfID, route.Shard) + } + require.NotEmpty(t, workflowsOnShard2, "at least one workflow should be on shard 2 for this test") + + // Now scale down to 2 shards - workflows on shard 2 MUST move + transitionAos := makeObservationsWithWantShards(t, []map[uint32]*pb.ShardStatus{ + {0: {IsHealthy: true}, 1: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: true}}, + {0: {IsHealthy: true}, 1: {IsHealthy: 
true}}, + }, []string{"wf-A", "wf-B", "wf-C", "wf-D", "wf-E"}, now, 2) + + outcomeCtx := ocr3types.OutcomeContext{ + SeqNr: 3, + PreviousOutcome: baselineOutcome, + } + + outcome, err := plugin.Outcome(ctx, outcomeCtx, nil, transitionAos) + require.NoError(t, err) + + reports, err := plugin.Reports(ctx, 3, outcome) + require.NoError(t, err) + + err = transmitter.Transmit(ctx, types.ConfigDigest{}, 3, reports[0].ReportWithInfo, nil) + require.NoError(t, err) + + // Verify orchestrator store shows transition state for workflows that moved from shard 2 + outcomeProto := &pb.Outcome{} + err = proto.Unmarshal(outcome, outcomeProto) + require.NoError(t, err) + + // Workflows that were on shard 2 must have moved and should show TransitionState + for _, wfID := range workflowsOnShard2 { + mapping, err := orchestratorStore.GetWorkflowMapping(ctx, wfID) + require.NoError(t, err) + + newRoute := outcomeProto.Routes[wfID] + require.NotEqual(t, uint32(2), newRoute.Shard, "workflow should have moved from shard 2") + require.Equal(t, shardorchestrator.StateTransitioning, mapping.TransitionState, + "workflow %s moved from shard 2 to shard %d, should be transitioning", wfID, newRoute.Shard) + require.Equal(t, uint32(2), mapping.OldShardID, "should track old shard") + require.Equal(t, newRoute.Shard, mapping.NewShardID, "should track new shard") + t.Logf("Workflow %s transitioned: shard 2 → %d", wfID, newRoute.Shard) + } + + // Verify version incremented + version := orchestratorStore.GetMappingVersion() + require.Equal(t, uint64(3), version, "version should increment after update") + }) +} diff --git a/pkg/workflows/ring/state.go b/pkg/workflows/ring/state.go index 62c26a5b1..f1751bf80 100644 --- a/pkg/workflows/ring/state.go +++ b/pkg/workflows/ring/state.go @@ -7,8 +7,25 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) +// TransitionStateFromBool converts a proto bool (in_transition) to TransitionState +func TransitionStateFromBool(inTransition bool) shardorchestrator.TransitionState { + if inTransition { + return shardorchestrator.StateTransitioning + } + return shardorchestrator.StateSteady +} + +// TransitionStateFromRoutingState returns the TransitionState based on RoutingState +func TransitionStateFromRoutingState(state *pb.RoutingState) shardorchestrator.TransitionState { + if IsInSteadyState(state) { + return shardorchestrator.StateSteady + } + return shardorchestrator.StateTransitioning +} + func IsInSteadyState(state *pb.RoutingState) bool { if state == nil { return false diff --git a/pkg/workflows/ring/transmitter.go b/pkg/workflows/ring/transmitter.go index 524be65be..5e5f72708 100644 --- a/pkg/workflows/ring/transmitter.go +++ b/pkg/workflows/ring/transmitter.go @@ -5,24 +5,33 @@ import ( "google.golang.org/protobuf/proto" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" ) var _ ocr3types.ContractTransmitter[[]byte] = (*Transmitter)(nil) // Transmitter handles transmission of shard orchestration outcomes 
type Transmitter struct { - lggr logger.Logger - store *Store - arbiterScaler pb.ArbiterScalerClient - fromAccount types.Account + lggr logger.Logger + ringStore *Store + shardOrchestratorStore *shardorchestrator.Store + arbiterScaler pb.ArbiterScalerClient + fromAccount types.Account } -func NewTransmitter(lggr logger.Logger, store *Store, arbiterScaler pb.ArbiterScalerClient, fromAccount types.Account) *Transmitter { - return &Transmitter{lggr: lggr, store: store, arbiterScaler: arbiterScaler, fromAccount: fromAccount} +func NewTransmitter(lggr logger.Logger, ringStore *Store, shardOrchestratorStore *shardorchestrator.Store, arbiterScaler pb.ArbiterScalerClient, fromAccount types.Account) *Transmitter { + return &Transmitter{ + lggr: lggr, + ringStore: ringStore, + shardOrchestratorStore: shardOrchestratorStore, + arbiterScaler: arbiterScaler, + fromAccount: fromAccount, + } } func (t *Transmitter) Transmit(ctx context.Context, _ types.ConfigDigest, _ uint64, r ocr3types.ReportWithInfo[[]byte], _ []types.AttributedOnchainSignature) error { @@ -37,10 +46,63 @@ func (t *Transmitter) Transmit(ctx context.Context, _ types.ConfigDigest, _ uint return err } - t.store.SetRoutingState(outcome.State) + // Update Ring Store + t.ringStore.SetRoutingState(outcome.State) + + // Determine if system is in transition state + systemInTransition := false + if outcome.State != nil { + if _, ok := outcome.State.State.(*pb.RoutingState_Transition); ok { + systemInTransition = true + } + } + + // Update ShardOrchestrator store if available + if t.shardOrchestratorStore != nil { + mappings := make([]*shardorchestrator.WorkflowMappingState, 0, len(outcome.Routes)) + for workflowID, route := range outcome.Routes { + // Get the current shard assignment for this workflow to detect changes + var oldShardID uint32 + var transitionState shardorchestrator.TransitionState + + existingMapping, err := t.shardOrchestratorStore.GetWorkflowMapping(ctx, workflowID) + if err != nil { + // New workflow - no previous assignment + oldShardID = 0 + transitionState = shardorchestrator.StateSteady + } else if existingMapping.NewShardID != route.Shard { + // Workflow is moving to a different shard + oldShardID = existingMapping.NewShardID + transitionState = shardorchestrator.StateTransitioning + } else { + // Same shard - but might be in system transition + oldShardID = existingMapping.NewShardID + if systemInTransition { + transitionState = shardorchestrator.StateTransitioning + } else { + transitionState = shardorchestrator.StateSteady + } + } + + mappings = append(mappings, &shardorchestrator.WorkflowMappingState{ + WorkflowID: workflowID, + OldShardID: oldShardID, + NewShardID: route.Shard, + TransitionState: transitionState, + }) + } + + if err := t.shardOrchestratorStore.BatchUpdateWorkflowMappings(ctx, mappings); err != nil { + t.lggr.Errorw("failed to update ShardOrchestrator store", "err", err, "workflowCount", len(mappings)) + // Don't fail the entire transmission if ShardOrchestrator update fails + } else { + t.lggr.Debugw("Updated ShardOrchestrator store", "workflowCount", len(mappings)) + } + } + // Update Ring Store workflow mappings for workflowID, route := range outcome.Routes { - t.store.SetShardForWorkflow(workflowID, route.Shard) + t.ringStore.SetShardForWorkflow(workflowID, route.Shard) t.lggr.Debugw("Updated workflow shard mapping", "workflowID", workflowID, "shard", route.Shard) } diff --git a/pkg/workflows/ring/transmitter_test.go b/pkg/workflows/ring/transmitter_test.go index 9fde00015..087f4aaf5 100644 --- 
a/pkg/workflows/ring/transmitter_test.go +++ b/pkg/workflows/ring/transmitter_test.go @@ -9,10 +9,11 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/emptypb" - "github.com/smartcontractkit/chainlink-common/pkg/logger" - "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" "github.com/smartcontractkit/libocr/offchainreporting2plus/ocr3types" "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb" ) type mockArbiterScaler struct { @@ -37,14 +38,14 @@ func (m *mockArbiterScaler) ConsensusWantShards(ctx context.Context, req *pb.Con func TestTransmitter_NewTransmitter(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "test-account") + tx := NewTransmitter(lggr, store, nil, nil, "test-account") require.NotNil(t, tx) } func TestTransmitter_FromAccount(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "my-account") + tx := NewTransmitter(lggr, store, nil, nil, "my-account") account, err := tx.FromAccount(context.Background()) require.NoError(t, err) @@ -55,7 +56,7 @@ func TestTransmitter_Transmit(t *testing.T) { lggr := logger.Test(t) store := NewStore() mock := &mockArbiterScaler{} - tx := NewTransmitter(lggr, store, mock, "test-account") + tx := NewTransmitter(lggr, store, nil, mock, "test-account") outcome := &pb.Outcome{ State: &pb.RoutingState{ @@ -88,7 +89,7 @@ func TestTransmitter_Transmit(t *testing.T) { func TestTransmitter_Transmit_NilArbiter(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "test-account") + tx := NewTransmitter(lggr, store, nil, nil, "test-account") outcome := &pb.Outcome{ State: &pb.RoutingState{ @@ -107,7 +108,7 @@ func TestTransmitter_Transmit_TransitionState(t *testing.T) { lggr := logger.Test(t) store := NewStore() mock := &mockArbiterScaler{} - tx := NewTransmitter(lggr, store, mock, "test-account") + tx := NewTransmitter(lggr, store, nil, mock, "test-account") outcome := &pb.Outcome{ State: &pb.RoutingState{ @@ -127,7 +128,7 @@ func TestTransmitter_Transmit_TransitionState(t *testing.T) { func TestTransmitter_Transmit_InvalidReport(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "test-account") + tx := NewTransmitter(lggr, store, nil, nil, "test-account") // Send invalid protobuf data report := ocr3types.ReportWithInfo[[]byte]{Report: []byte("invalid protobuf")} @@ -139,7 +140,7 @@ func TestTransmitter_Transmit_ArbiterError(t *testing.T) { lggr := logger.Test(t) store := NewStore() mock := &mockArbiterScaler{err: context.DeadlineExceeded} - tx := NewTransmitter(lggr, store, mock, "test-account") + tx := NewTransmitter(lggr, store, nil, mock, "test-account") outcome := &pb.Outcome{ State: &pb.RoutingState{ @@ -156,7 +157,7 @@ func TestTransmitter_Transmit_ArbiterError(t *testing.T) { func TestTransmitter_Transmit_NilState(t *testing.T) { lggr := logger.Test(t) store := NewStore() - tx := NewTransmitter(lggr, store, nil, "test-account") + tx := NewTransmitter(lggr, store, nil, nil, "test-account") outcome := &pb.Outcome{ State: nil, diff --git a/pkg/workflows/shardorchestrator/client.go b/pkg/workflows/shardorchestrator/client.go new file mode 100644 index 000000000..c7c9ed0f8 --- /dev/null +++ b/pkg/workflows/shardorchestrator/client.go @@ -0,0 +1,78 @@ +package 
shardorchestrator + +import ( + "context" + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb" +) + +// Client wraps gRPC client for communicating with shard 0's orchestrator service +type Client struct { + conn *grpc.ClientConn + client pb.ShardOrchestratorServiceClient + logger logger.Logger +} + +// NewClient creates a new gRPC client to communicate with the shard orchestrator on shard 0 +func NewClient(ctx context.Context, address string, lggr logger.Logger) (*Client, error) { + conn, err := grpc.NewClient(address, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("failed to create shard orchestrator client for %s: %w", address, err) + } + + return &Client{ + conn: conn, + client: pb.NewShardOrchestratorServiceClient(conn), + logger: logger.Named(lggr, "ShardOrchestratorClient"), + }, nil +} + +// GetWorkflowShardMapping queries shard 0 for workflow-to-shard mappings +func (c *Client) GetWorkflowShardMapping(ctx context.Context, workflowIDs []string) (*pb.GetWorkflowShardMappingResponse, error) { + c.logger.Debugw("Calling GetWorkflowShardMapping", "workflowCount", len(workflowIDs)) + + req := &pb.GetWorkflowShardMappingRequest{ + WorkflowIds: workflowIDs, + } + + resp, err := c.client.GetWorkflowShardMapping(ctx, req) + if err != nil { + return nil, fmt.Errorf("gRPC GetWorkflowShardMapping failed: %w", err) + } + + c.logger.Debugw("GetWorkflowShardMapping response received", + "mappingCount", len(resp.Mappings), + "version", resp.MappingVersion) + + return resp, nil +} + +// ReportWorkflowTriggerRegistration reports workflow trigger registration to shard 0 +func (c *Client) ReportWorkflowTriggerRegistration(ctx context.Context, req *pb.ReportWorkflowTriggerRegistrationRequest) (*pb.ReportWorkflowTriggerRegistrationResponse, error) { + c.logger.Debugw("Calling ReportWorkflowTriggerRegistration", + "shardID", req.SourceShardId, + "workflowCount", len(req.RegisteredWorkflows)) + + resp, err := c.client.ReportWorkflowTriggerRegistration(ctx, req) + if err != nil { + return nil, fmt.Errorf("gRPC ReportWorkflowTriggerRegistration failed: %w", err) + } + + c.logger.Debugw("ReportWorkflowTriggerRegistration response received", + "success", resp.Success) + + return resp, nil +} + +// Close closes the gRPC connection +func (c *Client) Close() error { + c.logger.Info("Closing ShardOrchestrator gRPC client") + return c.conn.Close() +} diff --git a/pkg/workflows/shardorchestrator/client_test.go b/pkg/workflows/shardorchestrator/client_test.go new file mode 100644 index 000000000..a090c531f --- /dev/null +++ b/pkg/workflows/shardorchestrator/client_test.go @@ -0,0 +1,212 @@ +package shardorchestrator + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb" +) + +const bufSize = 1024 * 1024 + +// mockShardOrchestratorServer implements the gRPC server for testing +type mockShardOrchestratorServer struct { + pb.UnimplementedShardOrchestratorServiceServer + mappings map[string]uint32 + registrationCalled bool +} + +func (m 
*mockShardOrchestratorServer) GetWorkflowShardMapping(ctx context.Context, req *pb.GetWorkflowShardMappingRequest) (*pb.GetWorkflowShardMappingResponse, error) { + mappings := make(map[string]uint32) + mappingStates := make(map[string]*pb.WorkflowMappingState) + + for _, wfID := range req.WorkflowIds { + if shardID, ok := m.mappings[wfID]; ok { + mappings[wfID] = shardID + mappingStates[wfID] = &pb.WorkflowMappingState{ + OldShardId: 0, + NewShardId: shardID, + InTransition: false, + } + } + } + + return &pb.GetWorkflowShardMappingResponse{ + Mappings: mappings, + MappingStates: mappingStates, + MappingVersion: 1, + }, nil +} + +func (m *mockShardOrchestratorServer) ReportWorkflowTriggerRegistration(ctx context.Context, req *pb.ReportWorkflowTriggerRegistrationRequest) (*pb.ReportWorkflowTriggerRegistrationResponse, error) { + m.registrationCalled = true + return &pb.ReportWorkflowTriggerRegistrationResponse{ + Success: true, + }, nil +} + +// setupTestServer creates a test gRPC server using bufconn +func setupTestServer(t *testing.T, mock *mockShardOrchestratorServer) (*grpc.Server, *bufconn.Listener) { + lis := bufconn.Listen(bufSize) + s := grpc.NewServer() + pb.RegisterShardOrchestratorServiceServer(s, mock) + + go func() { + if err := s.Serve(lis); err != nil { + t.Logf("Server exited with error: %v", err) + } + }() + + return s, lis +} + +// createTestClient creates a client connected to the test server +func createTestClient(t *testing.T, lis *bufconn.Listener) *Client { + conn, err := grpc.NewClient("passthrough://bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return lis.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + + lggr := logger.Test(t) + return &Client{ + conn: conn, + client: pb.NewShardOrchestratorServiceClient(conn), + logger: logger.Named(lggr, "TestClient"), + } +} + +func TestClient_GetWorkflowShardMapping(t *testing.T) { + ctx := context.Background() + + mock := &mockShardOrchestratorServer{ + mappings: map[string]uint32{ + "workflow-1": 0, + "workflow-2": 1, + "workflow-3": 2, + }, + } + + grpcServer, lis := setupTestServer(t, mock) + defer grpcServer.Stop() + + client := createTestClient(t, lis) + defer client.Close() + + t.Run("successful mapping query", func(t *testing.T) { + workflowIDs := []string{"workflow-1", "workflow-2", "workflow-3"} + resp, err := client.GetWorkflowShardMapping(ctx, workflowIDs) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Len(t, resp.Mappings, 3) + assert.Equal(t, uint32(0), resp.Mappings["workflow-1"]) + assert.Equal(t, uint32(1), resp.Mappings["workflow-2"]) + assert.Equal(t, uint32(2), resp.Mappings["workflow-3"]) + + assert.Len(t, resp.MappingStates, 3) + assert.Equal(t, uint64(1), resp.MappingVersion) + }) + + t.Run("partial workflow query", func(t *testing.T) { + workflowIDs := []string{"workflow-1", "workflow-unknown"} + resp, err := client.GetWorkflowShardMapping(ctx, workflowIDs) + require.NoError(t, err) + require.NotNil(t, resp) + + // Should only return mappings for known workflows + assert.Len(t, resp.Mappings, 1) + assert.Equal(t, uint32(0), resp.Mappings["workflow-1"]) + _, exists := resp.Mappings["workflow-unknown"] + assert.False(t, exists) + }) + + t.Run("empty workflow list", func(t *testing.T) { + resp, err := client.GetWorkflowShardMapping(ctx, []string{}) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Empty(t, resp.Mappings) + }) +} + +func TestClient_ReportWorkflowTriggerRegistration(t 
*testing.T) { + ctx := context.Background() + + mock := &mockShardOrchestratorServer{ + mappings: map[string]uint32{}, + } + + grpcServer, lis := setupTestServer(t, mock) + defer grpcServer.Stop() + + client := createTestClient(t, lis) + defer client.Close() + + t.Run("successful registration report", func(t *testing.T) { + req := &pb.ReportWorkflowTriggerRegistrationRequest{ + SourceShardId: 1, + RegisteredWorkflows: map[string]uint32{ + "workflow-1": 1, + "workflow-2": 1, + }, + TotalActiveWorkflows: 2, + } + + resp, err := client.ReportWorkflowTriggerRegistration(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.True(t, resp.Success) + assert.True(t, mock.registrationCalled) + }) +} + +func TestClient_Close(t *testing.T) { + mock := &mockShardOrchestratorServer{ + mappings: map[string]uint32{}, + } + + grpcServer, lis := setupTestServer(t, mock) + defer grpcServer.Stop() + + client := createTestClient(t, lis) + + err := client.Close() + assert.NoError(t, err) + + // Verify connection is closed by attempting to use it + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = client.GetWorkflowShardMapping(ctx, []string{"test"}) + assert.Error(t, err, "should fail after client is closed") +} + +func TestNewClient(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + + t.Run("creates client successfully", func(t *testing.T) { + // Note: This creates a client but doesn't connect immediately with grpc.NewClient + client, err := NewClient(ctx, "localhost:50051", lggr) + require.NoError(t, err) + require.NotNil(t, client) + defer client.Close() + + assert.NotNil(t, client.conn) + assert.NotNil(t, client.client) + assert.NotNil(t, client.logger) + }) +} diff --git a/pkg/workflows/shardorchestrator/pb/generate.go b/pkg/workflows/shardorchestrator/pb/generate.go new file mode 100644 index 000000000..be4026238 --- /dev/null +++ b/pkg/workflows/shardorchestrator/pb/generate.go @@ -0,0 +1,3 @@ +package pb + +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. 
--go-grpc_opt=paths=source_relative shard_orchestrator.proto diff --git a/pkg/workflows/ring/pb/shard_orchestrator.pb.go b/pkg/workflows/shardorchestrator/pb/shard_orchestrator.pb.go similarity index 68% rename from pkg/workflows/ring/pb/shard_orchestrator.pb.go rename to pkg/workflows/shardorchestrator/pb/shard_orchestrator.pb.go index 7a3f8491e..ba90ab901 100644 --- a/pkg/workflows/ring/pb/shard_orchestrator.pb.go +++ b/pkg/workflows/shardorchestrator/pb/shard_orchestrator.pb.go @@ -9,7 +9,6 @@ package pb import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" - timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" unsafe "unsafe" @@ -71,7 +70,6 @@ type WorkflowMappingState struct { OldShardId uint32 `protobuf:"varint,1,opt,name=old_shard_id,json=oldShardId,proto3" json:"old_shard_id,omitempty"` NewShardId uint32 `protobuf:"varint,2,opt,name=new_shard_id,json=newShardId,proto3" json:"new_shard_id,omitempty"` InTransition bool `protobuf:"varint,3,opt,name=in_transition,json=inTransition,proto3" json:"in_transition,omitempty"` - LastUpdated *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=last_updated,json=lastUpdated,proto3" json:"last_updated,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -127,19 +125,11 @@ func (x *WorkflowMappingState) GetInTransition() bool { return false } -func (x *WorkflowMappingState) GetLastUpdated() *timestamppb.Timestamp { - if x != nil { - return x.LastUpdated - } - return nil -} - type GetWorkflowShardMappingResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Mappings map[string]uint32 `protobuf:"bytes,1,rep,name=mappings,proto3" json:"mappings,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"varint,2,opt,name=value"` MappingStates map[string]*WorkflowMappingState `protobuf:"bytes,2,rep,name=mapping_states,json=mappingStates,proto3" json:"mapping_states,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` - Timestamp *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=timestamp,proto3" json:"timestamp,omitempty"` - MappingVersion uint64 `protobuf:"varint,4,opt,name=mapping_version,json=mappingVersion,proto3" json:"mapping_version,omitempty"` + MappingVersion uint64 `protobuf:"varint,3,opt,name=mapping_version,json=mappingVersion,proto3" json:"mapping_version,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -188,13 +178,6 @@ func (x *GetWorkflowShardMappingResponse) GetMappingStates() map[string]*Workflo return nil } -func (x *GetWorkflowShardMappingResponse) GetTimestamp() *timestamppb.Timestamp { - if x != nil { - return x.Timestamp - } - return nil -} - func (x *GetWorkflowShardMappingResponse) GetMappingVersion() uint64 { if x != nil { return x.MappingVersion @@ -206,8 +189,7 @@ type ReportWorkflowTriggerRegistrationRequest struct { state protoimpl.MessageState `protogen:"open.v1"` SourceShardId uint32 `protobuf:"varint,1,opt,name=source_shard_id,json=sourceShardId,proto3" json:"source_shard_id,omitempty"` RegisteredWorkflows map[string]uint32 `protobuf:"bytes,2,rep,name=registered_workflows,json=registeredWorkflows,proto3" json:"registered_workflows,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"varint,2,opt,name=value"` - ReportTimestamp *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=report_timestamp,json=reportTimestamp,proto3" json:"report_timestamp,omitempty"` - TotalActiveWorkflows 
uint32 `protobuf:"varint,4,opt,name=total_active_workflows,json=totalActiveWorkflows,proto3" json:"total_active_workflows,omitempty"` + TotalActiveWorkflows uint32 `protobuf:"varint,3,opt,name=total_active_workflows,json=totalActiveWorkflows,proto3" json:"total_active_workflows,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -256,13 +238,6 @@ func (x *ReportWorkflowTriggerRegistrationRequest) GetRegisteredWorkflows() map[ return nil } -func (x *ReportWorkflowTriggerRegistrationRequest) GetReportTimestamp() *timestamppb.Timestamp { - if x != nil { - return x.ReportTimestamp - } - return nil -} - func (x *ReportWorkflowTriggerRegistrationRequest) GetTotalActiveWorkflows() uint32 { if x != nil { return x.TotalActiveWorkflows @@ -318,40 +293,37 @@ var File_shard_orchestrator_proto protoreflect.FileDescriptor const file_shard_orchestrator_proto_rawDesc = "" + "\n" + - "\x18shard_orchestrator.proto\x12\x04ring\x1a\x1fgoogle/protobuf/timestamp.proto\"C\n" + + "\x18shard_orchestrator.proto\x12\x11shardorchestrator\"C\n" + "\x1eGetWorkflowShardMappingRequest\x12!\n" + - "\fworkflow_ids\x18\x01 \x03(\tR\vworkflowIds\"\xbe\x01\n" + + "\fworkflow_ids\x18\x01 \x03(\tR\vworkflowIds\"\x7f\n" + "\x14WorkflowMappingState\x12 \n" + "\fold_shard_id\x18\x01 \x01(\rR\n" + "oldShardId\x12 \n" + "\fnew_shard_id\x18\x02 \x01(\rR\n" + "newShardId\x12#\n" + - "\rin_transition\x18\x03 \x01(\bR\finTransition\x12=\n" + - "\flast_updated\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\vlastUpdated\"\xd1\x03\n" + - "\x1fGetWorkflowShardMappingResponse\x12O\n" + - "\bmappings\x18\x01 \x03(\v23.ring.GetWorkflowShardMappingResponse.MappingsEntryR\bmappings\x12_\n" + - "\x0emapping_states\x18\x02 \x03(\v28.ring.GetWorkflowShardMappingResponse.MappingStatesEntryR\rmappingStates\x128\n" + - "\ttimestamp\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12'\n" + - "\x0fmapping_version\x18\x04 \x01(\x04R\x0emappingVersion\x1a;\n" + + "\rin_transition\x18\x03 \x01(\bR\finTransition\"\xbe\x03\n" + + "\x1fGetWorkflowShardMappingResponse\x12\\\n" + + "\bmappings\x18\x01 \x03(\v2@.shardorchestrator.GetWorkflowShardMappingResponse.MappingsEntryR\bmappings\x12l\n" + + "\x0emapping_states\x18\x02 \x03(\v2E.shardorchestrator.GetWorkflowShardMappingResponse.MappingStatesEntryR\rmappingStates\x12'\n" + + "\x0fmapping_version\x18\x03 \x01(\x04R\x0emappingVersion\x1a;\n" + "\rMappingsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + - "\x05value\x18\x02 \x01(\rR\x05value:\x028\x01\x1a\\\n" + + "\x05value\x18\x02 \x01(\rR\x05value:\x028\x01\x1ai\n" + "\x12MappingStatesEntry\x12\x10\n" + - "\x03key\x18\x01 \x01(\tR\x03key\x120\n" + - "\x05value\x18\x02 \x01(\v2\x1a.ring.WorkflowMappingStateR\x05value:\x028\x01\"\x93\x03\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12=\n" + + "\x05value\x18\x02 \x01(\v2'.shardorchestrator.WorkflowMappingStateR\x05value:\x028\x01\"\xda\x02\n" + "(ReportWorkflowTriggerRegistrationRequest\x12&\n" + - "\x0fsource_shard_id\x18\x01 \x01(\rR\rsourceShardId\x12z\n" + - "\x14registered_workflows\x18\x02 \x03(\v2G.ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntryR\x13registeredWorkflows\x12E\n" + - "\x10report_timestamp\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\x0freportTimestamp\x124\n" + - "\x16total_active_workflows\x18\x04 \x01(\rR\x14totalActiveWorkflows\x1aF\n" + + "\x0fsource_shard_id\x18\x01 \x01(\rR\rsourceShardId\x12\x87\x01\n" + + "\x14registered_workflows\x18\x02 
\x03(\v2T.shardorchestrator.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntryR\x13registeredWorkflows\x124\n" + + "\x16total_active_workflows\x18\x03 \x01(\rR\x14totalActiveWorkflows\x1aF\n" + "\x18RegisteredWorkflowsEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + "\x05value\x18\x02 \x01(\rR\x05value:\x028\x01\"E\n" + ")ReportWorkflowTriggerRegistrationResponse\x12\x18\n" + - "\asuccess\x18\x01 \x01(\bR\asuccess2\x89\x02\n" + - "\x18ShardOrchestratorService\x12f\n" + - "\x17GetWorkflowShardMapping\x12$.ring.GetWorkflowShardMappingRequest\x1a%.ring.GetWorkflowShardMappingResponse\x12\x84\x01\n" + - "!ReportWorkflowTriggerRegistration\x12..ring.ReportWorkflowTriggerRegistrationRequest\x1a/.ring.ReportWorkflowTriggerRegistrationResponseBDZBgithub.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pbb\x06proto3" + "\asuccess\x18\x01 \x01(\bR\asuccess2\xbe\x02\n" + + "\x18ShardOrchestratorService\x12\x80\x01\n" + + "\x17GetWorkflowShardMapping\x121.shardorchestrator.GetWorkflowShardMappingRequest\x1a2.shardorchestrator.GetWorkflowShardMappingResponse\x12\x9e\x01\n" + + "!ReportWorkflowTriggerRegistration\x12;.shardorchestrator.ReportWorkflowTriggerRegistrationRequest\x1a<.shardorchestrator.ReportWorkflowTriggerRegistrationResponseBQZOgithub.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pbb\x06proto3" var ( file_shard_orchestrator_proto_rawDescOnce sync.Once @@ -367,33 +339,29 @@ func file_shard_orchestrator_proto_rawDescGZIP() []byte { var file_shard_orchestrator_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_shard_orchestrator_proto_goTypes = []any{ - (*GetWorkflowShardMappingRequest)(nil), // 0: ring.GetWorkflowShardMappingRequest - (*WorkflowMappingState)(nil), // 1: ring.WorkflowMappingState - (*GetWorkflowShardMappingResponse)(nil), // 2: ring.GetWorkflowShardMappingResponse - (*ReportWorkflowTriggerRegistrationRequest)(nil), // 3: ring.ReportWorkflowTriggerRegistrationRequest - (*ReportWorkflowTriggerRegistrationResponse)(nil), // 4: ring.ReportWorkflowTriggerRegistrationResponse - nil, // 5: ring.GetWorkflowShardMappingResponse.MappingsEntry - nil, // 6: ring.GetWorkflowShardMappingResponse.MappingStatesEntry - nil, // 7: ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry - (*timestamppb.Timestamp)(nil), // 8: google.protobuf.Timestamp + (*GetWorkflowShardMappingRequest)(nil), // 0: shardorchestrator.GetWorkflowShardMappingRequest + (*WorkflowMappingState)(nil), // 1: shardorchestrator.WorkflowMappingState + (*GetWorkflowShardMappingResponse)(nil), // 2: shardorchestrator.GetWorkflowShardMappingResponse + (*ReportWorkflowTriggerRegistrationRequest)(nil), // 3: shardorchestrator.ReportWorkflowTriggerRegistrationRequest + (*ReportWorkflowTriggerRegistrationResponse)(nil), // 4: shardorchestrator.ReportWorkflowTriggerRegistrationResponse + nil, // 5: shardorchestrator.GetWorkflowShardMappingResponse.MappingsEntry + nil, // 6: shardorchestrator.GetWorkflowShardMappingResponse.MappingStatesEntry + nil, // 7: shardorchestrator.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry } var file_shard_orchestrator_proto_depIdxs = []int32{ - 8, // 0: ring.WorkflowMappingState.last_updated:type_name -> google.protobuf.Timestamp - 5, // 1: ring.GetWorkflowShardMappingResponse.mappings:type_name -> ring.GetWorkflowShardMappingResponse.MappingsEntry - 6, // 2: ring.GetWorkflowShardMappingResponse.mapping_states:type_name -> ring.GetWorkflowShardMappingResponse.MappingStatesEntry - 8, // 3: 
ring.GetWorkflowShardMappingResponse.timestamp:type_name -> google.protobuf.Timestamp - 7, // 4: ring.ReportWorkflowTriggerRegistrationRequest.registered_workflows:type_name -> ring.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry - 8, // 5: ring.ReportWorkflowTriggerRegistrationRequest.report_timestamp:type_name -> google.protobuf.Timestamp - 1, // 6: ring.GetWorkflowShardMappingResponse.MappingStatesEntry.value:type_name -> ring.WorkflowMappingState - 0, // 7: ring.ShardOrchestratorService.GetWorkflowShardMapping:input_type -> ring.GetWorkflowShardMappingRequest - 3, // 8: ring.ShardOrchestratorService.ReportWorkflowTriggerRegistration:input_type -> ring.ReportWorkflowTriggerRegistrationRequest - 2, // 9: ring.ShardOrchestratorService.GetWorkflowShardMapping:output_type -> ring.GetWorkflowShardMappingResponse - 4, // 10: ring.ShardOrchestratorService.ReportWorkflowTriggerRegistration:output_type -> ring.ReportWorkflowTriggerRegistrationResponse - 9, // [9:11] is the sub-list for method output_type - 7, // [7:9] is the sub-list for method input_type - 7, // [7:7] is the sub-list for extension type_name - 7, // [7:7] is the sub-list for extension extendee - 0, // [0:7] is the sub-list for field type_name + 5, // 0: shardorchestrator.GetWorkflowShardMappingResponse.mappings:type_name -> shardorchestrator.GetWorkflowShardMappingResponse.MappingsEntry + 6, // 1: shardorchestrator.GetWorkflowShardMappingResponse.mapping_states:type_name -> shardorchestrator.GetWorkflowShardMappingResponse.MappingStatesEntry + 7, // 2: shardorchestrator.ReportWorkflowTriggerRegistrationRequest.registered_workflows:type_name -> shardorchestrator.ReportWorkflowTriggerRegistrationRequest.RegisteredWorkflowsEntry + 1, // 3: shardorchestrator.GetWorkflowShardMappingResponse.MappingStatesEntry.value:type_name -> shardorchestrator.WorkflowMappingState + 0, // 4: shardorchestrator.ShardOrchestratorService.GetWorkflowShardMapping:input_type -> shardorchestrator.GetWorkflowShardMappingRequest + 3, // 5: shardorchestrator.ShardOrchestratorService.ReportWorkflowTriggerRegistration:input_type -> shardorchestrator.ReportWorkflowTriggerRegistrationRequest + 2, // 6: shardorchestrator.ShardOrchestratorService.GetWorkflowShardMapping:output_type -> shardorchestrator.GetWorkflowShardMappingResponse + 4, // 7: shardorchestrator.ShardOrchestratorService.ReportWorkflowTriggerRegistration:output_type -> shardorchestrator.ReportWorkflowTriggerRegistrationResponse + 6, // [6:8] is the sub-list for method output_type + 4, // [4:6] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_shard_orchestrator_proto_init() } diff --git a/pkg/workflows/ring/pb/shard_orchestrator.proto b/pkg/workflows/shardorchestrator/pb/shard_orchestrator.proto similarity index 78% rename from pkg/workflows/ring/pb/shard_orchestrator.proto rename to pkg/workflows/shardorchestrator/pb/shard_orchestrator.proto index c7e3c1668..1d9fe6a6c 100644 --- a/pkg/workflows/ring/pb/shard_orchestrator.proto +++ b/pkg/workflows/shardorchestrator/pb/shard_orchestrator.proto @@ -1,10 +1,8 @@ syntax = "proto3"; -package ring; +package shardorchestrator; -import "google/protobuf/timestamp.proto"; - -option go_package = "github.com/smartcontractkit/chainlink-common/pkg/workflows/ring/pb"; +option go_package = "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb"; message 
GetWorkflowShardMappingRequest { repeated string workflow_ids = 1; @@ -14,21 +12,18 @@ message WorkflowMappingState { uint32 old_shard_id = 1; uint32 new_shard_id = 2; bool in_transition = 3; - google.protobuf.Timestamp last_updated = 4; } message GetWorkflowShardMappingResponse { map<string, uint32> mappings = 1; map<string, WorkflowMappingState> mapping_states = 2; - google.protobuf.Timestamp timestamp = 3; - uint64 mapping_version = 4; + uint64 mapping_version = 3; } message ReportWorkflowTriggerRegistrationRequest { uint32 source_shard_id = 1; map<string, uint32> registered_workflows = 2; - google.protobuf.Timestamp report_timestamp = 3; - uint32 total_active_workflows = 4; + uint32 total_active_workflows = 3; } message ReportWorkflowTriggerRegistrationResponse { diff --git a/pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go b/pkg/workflows/shardorchestrator/pb/shard_orchestrator_grpc.pb.go similarity index 96% rename from pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go rename to pkg/workflows/shardorchestrator/pb/shard_orchestrator_grpc.pb.go index d2ab234c3..d099fd0ac 100644 --- a/pkg/workflows/ring/pb/shard_orchestrator_grpc.pb.go +++ b/pkg/workflows/shardorchestrator/pb/shard_orchestrator_grpc.pb.go @@ -19,8 +19,8 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - ShardOrchestratorService_GetWorkflowShardMapping_FullMethodName = "/ring.ShardOrchestratorService/GetWorkflowShardMapping" - ShardOrchestratorService_ReportWorkflowTriggerRegistration_FullMethodName = "/ring.ShardOrchestratorService/ReportWorkflowTriggerRegistration" + ShardOrchestratorService_GetWorkflowShardMapping_FullMethodName = "/shardorchestrator.ShardOrchestratorService/GetWorkflowShardMapping" + ShardOrchestratorService_ReportWorkflowTriggerRegistration_FullMethodName = "/shardorchestrator.ShardOrchestratorService/ReportWorkflowTriggerRegistration" ) // ShardOrchestratorServiceClient is the client API for ShardOrchestratorService service.
@@ -143,7 +143,7 @@ func _ShardOrchestratorService_ReportWorkflowTriggerRegistration_Handler(srv int // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) var ShardOrchestratorService_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "ring.ShardOrchestratorService", + ServiceName: "shardorchestrator.ShardOrchestratorService", HandlerType: (*ShardOrchestratorServiceServer)(nil), Methods: []grpc.MethodDesc{ { diff --git a/pkg/workflows/shardorchestrator/service.go b/pkg/workflows/shardorchestrator/service.go new file mode 100644 index 000000000..4d0b8cb62 --- /dev/null +++ b/pkg/workflows/shardorchestrator/service.go @@ -0,0 +1,108 @@ +package shardorchestrator + +import ( + "context" + "fmt" + + "google.golang.org/grpc" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb" +) + +// Server implements the gRPC ShardOrchestratorService +// This runs on shard zero and serves requests from other shards +type Server struct { + pb.UnimplementedShardOrchestratorServiceServer + store *Store + logger logger.Logger +} + +func NewServer(store *Store, lggr logger.Logger) *Server { + return &Server{ + store: store, + logger: logger.Named(lggr, "ShardOrchestratorServer"), + } +} + +// RegisterWithGRPCServer registers this service with a gRPC server +func (s *Server) RegisterWithGRPCServer(grpcServer *grpc.Server) { + pb.RegisterShardOrchestratorServiceServer(grpcServer, s) + s.logger.Info("Registered ShardOrchestrator gRPC service") +} + +// GetWorkflowShardMapping handles batch requests for workflow-to-shard mappings +// This is called by other shards to determine where to route workflow executions +func (s *Server) GetWorkflowShardMapping(ctx context.Context, req *pb.GetWorkflowShardMappingRequest) (*pb.GetWorkflowShardMappingResponse, error) { + s.logger.Debugw("GetWorkflowShardMapping called", "workflowCount", len(req.WorkflowIds)) + + if len(req.WorkflowIds) == 0 { + return nil, fmt.Errorf("workflow_ids is required and must not be empty") + } + + // Retrieve batch from store + mappings, version, err := s.store.GetWorkflowMappingsBatch(ctx, req.WorkflowIds) + if err != nil { + s.logger.Errorw("Failed to get workflow mappings", "error", err) + return nil, fmt.Errorf("failed to get workflow mappings: %w", err) + } + + // Build simple mappings map (workflow_id -> shard_id) + simpleMappings := make(map[string]uint32, len(mappings)) + // Build detailed mapping states + mappingStates := make(map[string]*pb.WorkflowMappingState, len(mappings)) + + for workflowID, mapping := range mappings { + // Simple mapping: just the current shard + simpleMappings[workflowID] = mapping.NewShardID + + // Detailed state: includes transition information + mappingStates[workflowID] = &pb.WorkflowMappingState{ + OldShardId: mapping.OldShardID, + NewShardId: mapping.NewShardID, + InTransition: mapping.TransitionState.InTransition(), + } + } + + return &pb.GetWorkflowShardMappingResponse{ + Mappings: simpleMappings, + MappingStates: mappingStates, + MappingVersion: version, + }, nil +} + +// ReportWorkflowTriggerRegistration handles shard registration reports +// Shards call this to inform shard zero about which workflows they have loaded +func (s *Server) ReportWorkflowTriggerRegistration(ctx context.Context, req *pb.ReportWorkflowTriggerRegistrationRequest) (*pb.ReportWorkflowTriggerRegistrationResponse, error) { + s.logger.Debugw("ReportWorkflowTriggerRegistration 
called", + "shardID", req.SourceShardId, + "workflowCount", len(req.RegisteredWorkflows), + "totalActive", req.TotalActiveWorkflows, + ) + + // Extract workflow IDs from the map + workflowIDs := make([]string, 0, len(req.RegisteredWorkflows)) + for workflowID := range req.RegisteredWorkflows { + workflowIDs = append(workflowIDs, workflowID) + } + + err := s.store.ReportShardRegistration(ctx, req.SourceShardId, workflowIDs) + if err != nil { + s.logger.Errorw("Failed to update shard registrations", + "shardID", req.SourceShardId, + "error", err, + ) + return &pb.ReportWorkflowTriggerRegistrationResponse{ + Success: false, + }, nil + } + + s.logger.Infow("Successfully registered workflows", + "shardID", req.SourceShardId, + "workflowCount", len(workflowIDs), + ) + + return &pb.ReportWorkflowTriggerRegistrationResponse{ + Success: true, + }, nil +} diff --git a/pkg/workflows/shardorchestrator/service_test.go b/pkg/workflows/shardorchestrator/service_test.go new file mode 100644 index 000000000..bbe8766cc --- /dev/null +++ b/pkg/workflows/shardorchestrator/service_test.go @@ -0,0 +1,118 @@ +package shardorchestrator_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb" +) + +func TestServer_GetWorkflowShardMapping(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + + t.Run("returns_mappings_for_multiple_workflows", func(t *testing.T) { + store := shardorchestrator.NewStore(lggr) + server := shardorchestrator.NewServer(store, lggr) + + // Set up some workflow mappings + mappings := []*shardorchestrator.WorkflowMappingState{ + { + WorkflowID: "wf-alpha", + OldShardID: 0, + NewShardID: 1, + TransitionState: shardorchestrator.StateSteady, + }, + { + WorkflowID: "wf-beta", + OldShardID: 0, + NewShardID: 2, + TransitionState: shardorchestrator.StateSteady, + }, + { + WorkflowID: "wf-gamma", + OldShardID: 1, + NewShardID: 0, + TransitionState: shardorchestrator.StateTransitioning, + }, + } + err := store.BatchUpdateWorkflowMappings(ctx, mappings) + require.NoError(t, err) + + // Request all three workflows + req := &pb.GetWorkflowShardMappingRequest{ + WorkflowIds: []string{"wf-alpha", "wf-beta", "wf-gamma"}, + } + + resp, err := server.GetWorkflowShardMapping(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify simple mappings + require.Len(t, resp.Mappings, 3) + require.Equal(t, uint32(1), resp.Mappings["wf-alpha"]) + require.Equal(t, uint32(2), resp.Mappings["wf-beta"]) + require.Equal(t, uint32(0), resp.Mappings["wf-gamma"]) + + // Verify detailed mapping states + require.Len(t, resp.MappingStates, 3) + + // wf-alpha: steady state + alphaState := resp.MappingStates["wf-alpha"] + require.Equal(t, uint32(0), alphaState.OldShardId) + require.Equal(t, uint32(1), alphaState.NewShardId) + require.False(t, alphaState.InTransition, "steady state should not be in transition") + + // wf-gamma: transitioning state + gammaState := resp.MappingStates["wf-gamma"] + require.Equal(t, uint32(1), gammaState.OldShardId) + require.Equal(t, uint32(0), gammaState.NewShardId) + require.True(t, gammaState.InTransition, "transitioning state should be in transition") + + // Verify version + require.Equal(t, uint64(1), resp.MappingVersion) + }) + + t.Run("rejects_empty_workflow_ids", func(t *testing.T) { + store := 
shardorchestrator.NewStore(lggr) + server := shardorchestrator.NewServer(store, lggr) + + req := &pb.GetWorkflowShardMappingRequest{ + WorkflowIds: []string{}, + } + + resp, err := server.GetWorkflowShardMapping(ctx, req) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "required") + }) + + t.Run("handles_partial_results_for_nonexistent_workflows", func(t *testing.T) { + store := shardorchestrator.NewStore(lggr) + server := shardorchestrator.NewServer(store, lggr) + + // Add one workflow + err := store.BatchUpdateWorkflowMappings(ctx, []*shardorchestrator.WorkflowMappingState{ + {WorkflowID: "exists", NewShardID: 1, TransitionState: shardorchestrator.StateSteady}, + }) + require.NoError(t, err) + + // Request one that exists and one that doesn't - batch query silently skips missing workflows + req := &pb.GetWorkflowShardMappingRequest{ + WorkflowIds: []string{"exists", "does-not-exist"}, + } + + resp, err := server.GetWorkflowShardMapping(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Only the existing workflow is returned + require.Len(t, resp.Mappings, 1) + require.Equal(t, uint32(1), resp.Mappings["exists"]) + require.NotContains(t, resp.Mappings, "does-not-exist") + }) +} diff --git a/pkg/workflows/shardorchestrator/store.go b/pkg/workflows/shardorchestrator/store.go new file mode 100644 index 000000000..4da391bab --- /dev/null +++ b/pkg/workflows/shardorchestrator/store.go @@ -0,0 +1,235 @@ +package shardorchestrator + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" +) + +// TransitionState represents the state of a workflow's shard assignment +type TransitionState uint8 + +const ( + StateSteady TransitionState = iota + StateTransitioning +) + +// String returns the string representation of the TransitionState +func (s TransitionState) String() string { + switch s { + case StateSteady: + return "steady" + case StateTransitioning: + return "transitioning" + default: + return "unknown" + } +} + +// InTransition returns true if the state is transitioning +func (s TransitionState) InTransition() bool { + return s == StateTransitioning +} + +// WorkflowMappingState represents the state of a workflow assignment +type WorkflowMappingState struct { + WorkflowID string + OldShardID uint32 + NewShardID uint32 + TransitionState TransitionState + UpdatedAt time.Time +} + +// Store manages workflow-to-shard mappings that will be exposed via gRPC +// RingOCR plugin updates this store, and the gRPC service reads from it +type Store struct { + // workflowMappings tracks the current shard assignment for each workflow + workflowMappings map[string]*WorkflowMappingState // workflow_id -> mapping state + + // shardRegistrations tracks what workflows each shard has registered + // This is populated by ReportWorkflowTriggerRegistration calls from shards + shardRegistrations map[uint32]map[string]bool // shard_id -> set of workflow_ids + + // mappingVersion increments on any change to workflowMappings + // Used by clients for cache invalidation + mappingVersion uint64 + + // lastUpdateTime tracks when mappings were last modified + lastUpdateTime time.Time + + mu sync.RWMutex + logger logger.Logger +} + +func NewStore(lggr logger.Logger) *Store { + return &Store{ + workflowMappings: make(map[string]*WorkflowMappingState), + shardRegistrations: make(map[uint32]map[string]bool), + mappingVersion: 0, + lastUpdateTime: time.Now(), + logger: logger.Named(lggr, "ShardOrchestratorStore"), + } +} + 
+// UpdateWorkflowMapping is called by RingOCR to update workflow assignments +// This is the primary data source for shard orchestration +func (s *Store) UpdateWorkflowMapping(ctx context.Context, workflowID string, oldShardID, newShardID uint32, state TransitionState) error { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + s.workflowMappings[workflowID] = &WorkflowMappingState{ + WorkflowID: workflowID, + OldShardID: oldShardID, + NewShardID: newShardID, + TransitionState: state, + UpdatedAt: now, + } + + s.mappingVersion++ + s.lastUpdateTime = now + + s.logger.Debugw("Updated workflow mapping", + "workflowID", workflowID, + "oldShardID", oldShardID, + "newShardID", newShardID, + "state", state.String(), + "version", s.mappingVersion, + ) + + return nil +} + +// BatchUpdateWorkflowMappings allows RingOCR to update multiple mappings atomically +func (s *Store) BatchUpdateWorkflowMappings(ctx context.Context, mappings []*WorkflowMappingState) error { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + for _, mapping := range mappings { + s.workflowMappings[mapping.WorkflowID] = &WorkflowMappingState{ + WorkflowID: mapping.WorkflowID, + OldShardID: mapping.OldShardID, + NewShardID: mapping.NewShardID, + TransitionState: mapping.TransitionState, + UpdatedAt: now, + } + } + + s.mappingVersion++ + s.lastUpdateTime = now + + s.logger.Debugw("Batch updated workflow mappings", "count", len(mappings), "version", s.mappingVersion) + return nil +} + +// GetWorkflowMapping retrieves the shard assignment for a specific workflow +// This is called by the gRPC service to respond to GetWorkflowShardMapping requests +func (s *Store) GetWorkflowMapping(ctx context.Context, workflowID string) (*WorkflowMappingState, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + mapping, ok := s.workflowMappings[workflowID] + if !ok { + return nil, fmt.Errorf("workflow %s not found in shard mappings", workflowID) + } + + // Return a copy to avoid external mutations + return &WorkflowMappingState{ + WorkflowID: mapping.WorkflowID, + OldShardID: mapping.OldShardID, + NewShardID: mapping.NewShardID, + TransitionState: mapping.TransitionState, + UpdatedAt: mapping.UpdatedAt, + }, nil +} + +// GetAllWorkflowMappings returns all current workflow-to-shard assignments +func (s *Store) GetAllWorkflowMappings(ctx context.Context) ([]*WorkflowMappingState, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + mappings := make([]*WorkflowMappingState, 0, len(s.workflowMappings)) + for _, mapping := range s.workflowMappings { + mappings = append(mappings, &WorkflowMappingState{ + WorkflowID: mapping.WorkflowID, + OldShardID: mapping.OldShardID, + NewShardID: mapping.NewShardID, + TransitionState: mapping.TransitionState, + UpdatedAt: mapping.UpdatedAt, + }) + } + + return mappings, nil +} + +// ReportShardRegistration is called when a shard reports its registered workflows +// This helps track which workflows each shard has successfully loaded +func (s *Store) ReportShardRegistration(ctx context.Context, shardID uint32, workflowIDs []string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Clear and update + s.shardRegistrations[shardID] = make(map[string]bool) + for _, wfID := range workflowIDs { + s.shardRegistrations[shardID][wfID] = true + } + + s.logger.Debugw("Updated shard registrations", + "shardID", shardID, + "workflowCount", len(workflowIDs), + ) + + return nil +} + +// GetShardRegistrations returns the workflows registered on a specific shard +func (s *Store) GetShardRegistrations(ctx context.Context, 
shardID uint32) ([]string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + workflows, ok := s.shardRegistrations[shardID] + if !ok { + return []string{}, nil + } + + result := make([]string, 0, len(workflows)) + for wfID := range workflows { + result = append(result, wfID) + } + + return result, nil +} + +// GetWorkflowMappingsBatch retrieves mappings for multiple workflows +func (s *Store) GetWorkflowMappingsBatch(ctx context.Context, workflowIDs []string) (map[string]*WorkflowMappingState, uint64, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + result := make(map[string]*WorkflowMappingState, len(workflowIDs)) + + for _, workflowID := range workflowIDs { + if mapping, ok := s.workflowMappings[workflowID]; ok { + // Return a copy to avoid external mutations + result[workflowID] = &WorkflowMappingState{ + WorkflowID: mapping.WorkflowID, + OldShardID: mapping.OldShardID, + NewShardID: mapping.NewShardID, + TransitionState: mapping.TransitionState, + UpdatedAt: mapping.UpdatedAt, + } + } + } + + return result, s.mappingVersion, nil +} + +// GetMappingVersion returns the current version of the mapping set +func (s *Store) GetMappingVersion() uint64 { + s.mu.RLock() + defer s.mu.RUnlock() + return s.mappingVersion +} diff --git a/pkg/workflows/shardorchestrator/store_test.go b/pkg/workflows/shardorchestrator/store_test.go new file mode 100644 index 000000000..6ad2af5a0 --- /dev/null +++ b/pkg/workflows/shardorchestrator/store_test.go @@ -0,0 +1,198 @@ +package shardorchestrator_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator" +) + +func TestStore_BatchUpdateAndQuery(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Create and insert multiple workflow mappings + mappings := []*shardorchestrator.WorkflowMappingState{ + { + WorkflowID: "workflow-1", + OldShardID: 0, + NewShardID: 1, + TransitionState: shardorchestrator.StateSteady, + }, + { + WorkflowID: "workflow-2", + OldShardID: 0, + NewShardID: 2, + TransitionState: shardorchestrator.StateSteady, + }, + { + WorkflowID: "workflow-3", + OldShardID: 0, + NewShardID: 1, + TransitionState: shardorchestrator.StateSteady, + }, + } + + err := store.BatchUpdateWorkflowMappings(ctx, mappings) + require.NoError(t, err) + + // Query individual workflow + mapping1, err := store.GetWorkflowMapping(ctx, "workflow-1") + require.NoError(t, err) + assert.Equal(t, uint32(1), mapping1.NewShardID) + assert.Equal(t, shardorchestrator.StateSteady, mapping1.TransitionState) + + // Query all workflows + allMappings, err := store.GetAllWorkflowMappings(ctx) + require.NoError(t, err) + assert.Len(t, allMappings, 3) + + // Query batch + batchMappings, version, err := store.GetWorkflowMappingsBatch(ctx, []string{"workflow-1", "workflow-2"}) + require.NoError(t, err) + assert.Len(t, batchMappings, 2) + assert.Equal(t, uint64(1), version) // First update +} + +func TestStore_WorkflowTransition(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Initial assignment + err := store.UpdateWorkflowMapping(ctx, "workflow-123", 0, 1, shardorchestrator.StateSteady) + require.NoError(t, err) + + mapping, err := store.GetWorkflowMapping(ctx, "workflow-123") + require.NoError(t, err) + assert.Equal(t, uint32(1), 
mapping.NewShardID) + assert.Equal(t, shardorchestrator.StateSteady, mapping.TransitionState) + + // Move to different shard (transitioning) + err = store.UpdateWorkflowMapping(ctx, "workflow-123", 1, 3, shardorchestrator.StateTransitioning) + require.NoError(t, err) + + mapping, err = store.GetWorkflowMapping(ctx, "workflow-123") + require.NoError(t, err) + assert.Equal(t, uint32(1), mapping.OldShardID) + assert.Equal(t, uint32(3), mapping.NewShardID) + assert.Equal(t, shardorchestrator.StateTransitioning, mapping.TransitionState) + + // Complete transition + err = store.UpdateWorkflowMapping(ctx, "workflow-123", 1, 3, shardorchestrator.StateSteady) + require.NoError(t, err) + + mapping, err = store.GetWorkflowMapping(ctx, "workflow-123") + require.NoError(t, err) + assert.Equal(t, uint32(3), mapping.NewShardID) + assert.Equal(t, shardorchestrator.StateSteady, mapping.TransitionState) +} + +func TestStore_VersionTracking(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Initial version should be 0 + assert.Equal(t, uint64(0), store.GetMappingVersion()) + + // First update increments version + err := store.UpdateWorkflowMapping(ctx, "wf-1", 0, 1, shardorchestrator.StateSteady) + require.NoError(t, err) + assert.Equal(t, uint64(1), store.GetMappingVersion()) + + // Batch update increments version + err = store.BatchUpdateWorkflowMappings(ctx, []*shardorchestrator.WorkflowMappingState{ + {WorkflowID: "wf-2", NewShardID: 2, TransitionState: shardorchestrator.StateSteady}, + }) + require.NoError(t, err) + assert.Equal(t, uint64(2), store.GetMappingVersion()) + + // Version is included in batch query response + _, version, err := store.GetWorkflowMappingsBatch(ctx, []string{"wf-1", "wf-2"}) + require.NoError(t, err) + assert.Equal(t, uint64(2), version) +} + +func TestStore_ShardRegistrations(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Shard 1 reports its workflows + err := store.ReportShardRegistration(ctx, 1, []string{"workflow-1", "workflow-3"}) + require.NoError(t, err) + + // Shard 2 reports its workflows + err = store.ReportShardRegistration(ctx, 2, []string{"workflow-2"}) + require.NoError(t, err) + + // Query shard registrations + shard1Workflows, err := store.GetShardRegistrations(ctx, 1) + require.NoError(t, err) + assert.Len(t, shard1Workflows, 2) + assert.Contains(t, shard1Workflows, "workflow-1") + assert.Contains(t, shard1Workflows, "workflow-3") + + shard2Workflows, err := store.GetShardRegistrations(ctx, 2) + require.NoError(t, err) + assert.Len(t, shard2Workflows, 1) + assert.Contains(t, shard2Workflows, "workflow-2") + + // Query non-existent shard returns empty + shard3Workflows, err := store.GetShardRegistrations(ctx, 3) + require.NoError(t, err) + assert.Empty(t, shard3Workflows) + + // Re-reporting replaces previous registrations + err = store.ReportShardRegistration(ctx, 1, []string{"workflow-1"}) + require.NoError(t, err) + + shard1Workflows, err = store.GetShardRegistrations(ctx, 1) + require.NoError(t, err) + assert.Len(t, shard1Workflows, 1) + assert.Contains(t, shard1Workflows, "workflow-1") + assert.NotContains(t, shard1Workflows, "workflow-3") // Removed +} + +func TestStore_NotFoundError(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Query non-existent workflow + _, err := store.GetWorkflowMapping(ctx, "non-existent") + require.Error(t, err) + 
assert.Contains(t, err.Error(), "not found") +} + +func TestStore_BatchQueryPartialResults(t *testing.T) { + ctx := context.Background() + lggr := logger.Test(t) + store := shardorchestrator.NewStore(lggr) + + // Insert only some workflows + err := store.UpdateWorkflowMapping(ctx, "exists-1", 0, 1, shardorchestrator.StateSteady) + require.NoError(t, err) + err = store.UpdateWorkflowMapping(ctx, "exists-2", 0, 2, shardorchestrator.StateSteady) + require.NoError(t, err) + + // Query mix of existing and non-existing workflows + results, _, err := store.GetWorkflowMappingsBatch(ctx, []string{ + "exists-1", + "non-existent", + "exists-2", + }) + require.NoError(t, err) + + // Should only return existing ones + assert.Len(t, results, 2) + assert.Contains(t, results, "exists-1") + assert.Contains(t, results, "exists-2") + assert.NotContains(t, results, "non-existent") +}
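Reviewer note — a minimal sketch (not part of this diff) of the write/read split the new store is documented for: the RingOCR plugin pushes assignments and bumps the mapping version, while a reader uses that version for cache invalidation before re-querying. The package and function names here are hypothetical; only the shardorchestrator.Store API shown above is assumed.

package shardorchestratorexample // hypothetical package, for illustration only

import (
	"context"

	"github.com/smartcontractkit/chainlink-common/pkg/logger"
	"github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator"
)

// syncMappings sketches the producer side: the assigning component (RingOCR per
// the store's doc comment) pushes a batch of assignments, which increments the
// store's mapping version exactly once.
func syncMappings(ctx context.Context, lggr logger.Logger) (*shardorchestrator.Store, error) {
	store := shardorchestrator.NewStore(lggr)
	err := store.BatchUpdateWorkflowMappings(ctx, []*shardorchestrator.WorkflowMappingState{
		{WorkflowID: "wf-A", OldShardID: 0, NewShardID: 1, TransitionState: shardorchestrator.StateSteady},
		{WorkflowID: "wf-B", OldShardID: 1, NewShardID: 2, TransitionState: shardorchestrator.StateTransitioning},
	})
	return store, err
}

// refreshIfStale sketches the consumer side: a reader keeps the last mapping
// version it saw and only re-reads the batch when the store has moved on.
func refreshIfStale(ctx context.Context, store *shardorchestrator.Store, lastSeen uint64, ids []string) (map[string]*shardorchestrator.WorkflowMappingState, uint64, error) {
	if store.GetMappingVersion() == lastSeen {
		// Nothing has changed since the last read; keep the cached view.
		return nil, lastSeen, nil
	}
	return store.GetWorkflowMappingsBatch(ctx, ids)
}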
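And a similarly hedged sketch of the lookup path the new server tests exercise, calling the handler in-process rather than over a real gRPC connection. The caller below is hypothetical; the request/response fields (Mappings, MappingStates, MappingVersion, InTransition) are the ones asserted in service_test.go.

package shardorchestratorexample // hypothetical caller, for illustration only

import (
	"context"
	"fmt"

	"github.com/smartcontractkit/chainlink-common/pkg/logger"
	"github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator"
	"github.com/smartcontractkit/chainlink-common/pkg/workflows/shardorchestrator/pb"
)

// whereDoTheseRun resolves workflow IDs to shard IDs via GetWorkflowShardMapping,
// flagging any workflow that is still mid-transition between shards.
func whereDoTheseRun(ctx context.Context, store *shardorchestrator.Store, lggr logger.Logger, ids []string) error {
	server := shardorchestrator.NewServer(store, lggr)

	resp, err := server.GetWorkflowShardMapping(ctx, &pb.GetWorkflowShardMappingRequest{WorkflowIds: ids})
	if err != nil {
		return err // e.g. an empty WorkflowIds list is rejected
	}

	for _, id := range ids {
		shard, ok := resp.Mappings[id]
		if !ok {
			// Workflows without a mapping are silently omitted by the batch query.
			fmt.Printf("%s: no shard mapping yet\n", id)
			continue
		}
		state := resp.MappingStates[id]
		fmt.Printf("%s -> shard %d (from %d, inTransition=%v, mapping version %d)\n",
			id, shard, state.OldShardId, state.InTransition, resp.MappingVersion)
	}
	return nil
}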