diff --git a/client.go b/client.go index 3f29e545..044a237d 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,7 @@ import ( "github.com/riverqueue/river/internal/middlewarelookup" "github.com/riverqueue/river/internal/notifier" "github.com/riverqueue/river/internal/notifylimiter" + "github.com/riverqueue/river/internal/pluginconfig" "github.com/riverqueue/river/internal/rivercommon" "github.com/riverqueue/river/internal/rivermiddleware" "github.com/riverqueue/river/internal/workunit" @@ -219,6 +220,9 @@ type Config struct { // work hook runs and the insertion hooks on either side of it are skipped. // // Jobs may have their own specific hooks by implementing JobArgsWithHooks. + // + // If a type in Hooks also implements rivertype.Middleware, it will be + // installed as middleware too. Hooks []rivertype.Hook // Logger is the structured logger to use for logging purposes. If none is @@ -252,8 +256,25 @@ type Config struct { // middlewares will run one after another, and the work middleware between // them will not run. When a job is worked, the work middleware runs and the // insertion middlewares on either side of it are skipped. + // + // If a type in Middleware also implements rivertype.Hook, it will be + // installed as a hook too. Middleware []rivertype.Middleware + // Plugins contains extensions installed globally as both hooks and + // middleware. + // + // A plugin must implement both rivertype.Hook and rivertype.Middleware. It + // may embed PluginDefaults, or embed both HookDefaults and + // MiddlewareDefaults directly, then implement any operation-specific hook or + // middleware interfaces it needs. + // + // Hooks and Middleware are still supported. If a type in Hooks also + // implements middleware, or a type in Middleware also implements hooks, River + // will install it on both sides automatically. The Plugins list exists as a + // generic place for new extensions to be registered. + Plugins []rivertype.Plugin + // PeriodicJobs are a set of periodic jobs to run at the specified intervals // in the client. PeriodicJobs []*PeriodicJob @@ -476,6 +497,7 @@ func (c *Config) WithDefaults() *Config { MaxAttempts: cmp.Or(c.MaxAttempts, MaxAttemptsDefault), Middleware: c.Middleware, PeriodicJobs: c.PeriodicJobs, + Plugins: c.Plugins, PollOnly: c.PollOnly, Queues: c.Queues, ReindexerIndexNames: reindexerIndexNames, @@ -774,7 +796,11 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client } } - for _, hook := range config.Hooks { + configuredMiddleware := middlewareFromConfig(config) + effectiveHooks := pluginconfig.Hooks(config.Hooks, configuredMiddleware, config.Plugins) + effectiveMiddleware := pluginconfig.Middleware(config.Hooks, configuredMiddleware, config.Plugins) + + for _, hook := range effectiveHooks { if withBaseService, ok := hook.(baseservice.WithBaseService); ok { baseservice.Init(archetype, withBaseService) } @@ -788,7 +814,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client config: config, driver: driver, hookLookupByJob: hooklookup.NewJobHookLookup(), - hookLookupGlobal: hooklookup.NewHookLookup(config.Hooks), + hookLookupGlobal: hooklookup.NewHookLookup(effectiveHooks), producersByQueueName: make(map[string]*producer), testSignals: clientTestSignals{}, workCancel: func(cause error) {}, // replaced on start, but here in case StopAndCancel is called before start up @@ -806,31 +832,12 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.baseService.Name = "Client" // Have to correct the name because base service isn't embedded like it usually is client.insertNotifyLimiter = notifylimiter.NewLimiter(archetype, config.FetchCooldown) - // Validation ensures that config.JobInsertMiddleware/WorkerMiddleware or - // the more abstract config.Middleware for middleware are set, but not both, - // so in practice we never append all three of these to each other. + // effectiveMiddleware contains configured middleware, hook/middleware + // hybrids, and plugins. Default middleware stays first so user middleware + // wraps inside River's internal defaults like before. { middleware := rivermiddleware.DefaultMiddleware() - middleware = append(middleware, config.Middleware...) - for _, jobInsertMiddleware := range config.JobInsertMiddleware { - middleware = append(middleware, jobInsertMiddleware) - } - outerLoop: - for _, workerMiddleware := range config.WorkerMiddleware { - // Don't add the middleware if it also implements JobInsertMiddleware - // and the instance has been added to config.JobInsertMiddleware. This - // is a hedge to make sure we don't accidentally double add middleware - // as we've converted over to the unified config.Middleware setting. - if workerMiddlewareAsJobInsertMiddleware, ok := workerMiddleware.(rivertype.JobInsertMiddleware); ok { - for _, jobInsertMiddleware := range config.JobInsertMiddleware { - if workerMiddlewareAsJobInsertMiddleware == jobInsertMiddleware { - continue outerLoop - } - } - } - - middleware = append(middleware, workerMiddleware) - } + middleware = append(middleware, effectiveMiddleware...) for _, middleware := range middleware { if withBaseService, ok := middleware.(baseservice.WithBaseService); ok { @@ -1040,6 +1047,35 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client return client, nil } +func middlewareFromConfig(config *Config) []rivertype.Middleware { + middleware := make([]rivertype.Middleware, 0, + len(config.Middleware)+len(config.JobInsertMiddleware)+len(config.WorkerMiddleware)) + middleware = append(middleware, config.Middleware...) + + for _, jobInsertMiddleware := range config.JobInsertMiddleware { + middleware = append(middleware, jobInsertMiddleware) + } + +outerLoop: + for _, workerMiddleware := range config.WorkerMiddleware { + // Don't add the middleware if it also implements JobInsertMiddleware + // and the instance has been added to config.JobInsertMiddleware. This + // is a hedge to make sure we don't accidentally double add middleware + // as we've converted over to the unified config.Middleware setting. + if workerMiddlewareAsJobInsertMiddleware, ok := workerMiddleware.(rivertype.JobInsertMiddleware); ok { + for _, jobInsertMiddleware := range config.JobInsertMiddleware { + if workerMiddlewareAsJobInsertMiddleware == jobInsertMiddleware { + continue outerLoop + } + } + } + + middleware = append(middleware, workerMiddleware) + } + + return middleware +} + // Start starts the client's job fetching and working loops. Once this is called, // the client will run in a background goroutine until stopped. All jobs are // run with a context inheriting from the provided context, but with a timeout diff --git a/client_test.go b/client_test.go index 9352dcbb..94304c09 100644 --- a/client_test.go +++ b/client_test.go @@ -70,6 +70,77 @@ type noOpWorker struct { func (w *noOpWorker) Work(ctx context.Context, job *Job[noOpArgs]) error { return nil } +var ( + _ rivertype.HookInsertBegin = &hookMiddlewareEmbeddedDefaultsPlugin{} + _ rivertype.JobInsertMiddleware = &hookMiddlewareEmbeddedDefaultsPlugin{} + _ rivertype.Plugin = &hookMiddlewareEmbeddedDefaultsPlugin{} + + _ rivertype.HookInsertBegin = &hookMiddlewarePlugin{} + _ rivertype.JobInsertMiddleware = &hookMiddlewarePlugin{} + _ rivertype.Plugin = &hookMiddlewarePlugin{} + + _ rivertype.HookInsertBegin = hookMiddlewareValuePlugin{} + _ rivertype.JobInsertMiddleware = hookMiddlewareValuePlugin{} +) + +type hookMiddlewareEmbeddedDefaultsPlugin struct { + HookDefaults + MiddlewareDefaults + + insertBeginCount int + insertManyCount int +} + +func (p *hookMiddlewareEmbeddedDefaultsPlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + p.insertBeginCount++ + return nil +} + +func (p *hookMiddlewareEmbeddedDefaultsPlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + p.insertManyCount++ + return doInner(ctx) +} + +type hookMiddlewarePlugin struct { + PluginDefaults + + insertBeginCount int + insertManyCount int +} + +func (p *hookMiddlewarePlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + p.insertBeginCount++ + return nil +} + +func (p *hookMiddlewarePlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + p.insertManyCount++ + return doInner(ctx) +} + +type hookMiddlewareValuePlugin struct { + counts *hookMiddlewareValuePluginCounts +} + +func (p hookMiddlewareValuePlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + p.counts.insertBeginCount++ + return nil +} + +func (p hookMiddlewareValuePlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + p.counts.insertManyCount++ + return doInner(ctx) +} + +func (p hookMiddlewareValuePlugin) IsHook() bool { return true } + +func (p hookMiddlewareValuePlugin) IsMiddleware() bool { return true } + +type hookMiddlewareValuePluginCounts struct { + insertBeginCount int + insertManyCount int +} + type periodicJobArgs struct{} func (periodicJobArgs) Kind() string { return "periodic_job" } @@ -8195,6 +8266,121 @@ func Test_NewClient_Overrides(t *testing.T) { require.Len(t, client.config.WorkerMiddleware, 1) } +func Test_NewClient_PluginsAndHybrids(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + config *Config + dbPool *pgxpool.Pool + } + + setup := func(t *testing.T) *testBundle { + t.Helper() + + dbPool := riversharedtest.DBPool(ctx, t) + driver := riverpgxv5.New(dbPool) + schema := riverdbtest.TestSchema(ctx, t, driver, nil) + + return &testBundle{ + config: newTestConfig(t, schema), + dbPool: dbPool, + } + } + + insertAndRequireCounts := func(t *testing.T, bundle *testBundle, plugin *hookMiddlewarePlugin, expectedCount int) { + t.Helper() + + client := newTestClient(t, bundle.dbPool, bundle.config) + + _, err := client.Insert(ctx, noOpArgs{}, nil) + require.NoError(t, err) + + require.Equal(t, expectedCount, plugin.insertBeginCount) + require.Equal(t, expectedCount, plugin.insertManyCount) + } + + t.Run("DuplicatesAcrossHooksMiddlewareAndPluginsRunMultipleTimes", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewarePlugin{} + bundle.config.Hooks = []rivertype.Hook{plugin} + bundle.config.Middleware = []rivertype.Middleware{plugin} + bundle.config.Plugins = []rivertype.Plugin{plugin} + + insertAndRequireCounts(t, bundle, plugin, 3) + }) + + t.Run("HookAlsoRegisteredAsMiddleware", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewarePlugin{} + bundle.config.Hooks = []rivertype.Hook{plugin} + + insertAndRequireCounts(t, bundle, plugin, 1) + }) + + t.Run("MiddlewareAlsoRegisteredAsHook", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewarePlugin{} + bundle.config.Middleware = []rivertype.Middleware{plugin} + + insertAndRequireCounts(t, bundle, plugin, 1) + }) + + t.Run("PluginRegisteredAsHookAndMiddleware", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewarePlugin{} + bundle.config.Plugins = []rivertype.Plugin{plugin} + + insertAndRequireCounts(t, bundle, plugin, 1) + }) + + t.Run("PluginRegisteredWithEmbeddedDefaults", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewareEmbeddedDefaultsPlugin{} + bundle.config.Plugins = []rivertype.Plugin{plugin} + + client := newTestClient(t, bundle.dbPool, bundle.config) + + _, err := client.Insert(ctx, noOpArgs{}, nil) + require.NoError(t, err) + + require.Equal(t, 1, plugin.insertBeginCount) + require.Equal(t, 1, plugin.insertManyCount) + }) + + t.Run("SeparateEqualValueInstancesRunSeparately", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + counts := &hookMiddlewareValuePluginCounts{} + bundle.config.Hooks = []rivertype.Hook{ + hookMiddlewareValuePlugin{counts: counts}, + } + bundle.config.Middleware = []rivertype.Middleware{ + hookMiddlewareValuePlugin{counts: counts}, + } + + client := newTestClient(t, bundle.dbPool, bundle.config) + + _, err := client.Insert(ctx, noOpArgs{}, nil) + require.NoError(t, err) + + require.Equal(t, 2, counts.insertBeginCount) + require.Equal(t, 2, counts.insertManyCount) + }) +} + func Test_NewClient_ReindexerIndexNamesExplicitEmptyOverride(t *testing.T) { t.Parallel() diff --git a/internal/pluginconfig/plugin_config.go b/internal/pluginconfig/plugin_config.go new file mode 100644 index 00000000..1377970e --- /dev/null +++ b/internal/pluginconfig/plugin_config.go @@ -0,0 +1,51 @@ +package pluginconfig + +import "github.com/riverqueue/river/rivertype" + +// Hooks returns the effective hook list from configured hooks, middleware, and +// plugins. Explicit hooks are preserved first, followed by middleware that also +// implement hooks, then plugins. +func Hooks(hooks []rivertype.Hook, middleware []rivertype.Middleware, plugins []rivertype.Plugin) []rivertype.Hook { + effectiveHooks := make([]rivertype.Hook, 0, len(hooks)+len(middleware)+len(plugins)) + + effectiveHooks = append(effectiveHooks, hooks...) + + for _, middlewareItem := range middleware { + hook, ok := middlewareItem.(rivertype.Hook) + if !ok { + continue + } + + effectiveHooks = append(effectiveHooks, hook) + } + + for _, plugin := range plugins { + effectiveHooks = append(effectiveHooks, plugin) + } + + return effectiveHooks +} + +// Middleware returns the effective middleware list from configured hooks, +// middleware, and plugins. Explicit middleware are preserved first, followed by +// hooks that also implement middleware, then plugins. +func Middleware(hooks []rivertype.Hook, middleware []rivertype.Middleware, plugins []rivertype.Plugin) []rivertype.Middleware { + effectiveMiddleware := make([]rivertype.Middleware, 0, len(hooks)+len(middleware)+len(plugins)) + + effectiveMiddleware = append(effectiveMiddleware, middleware...) + + for _, hook := range hooks { + middlewareItem, ok := hook.(rivertype.Middleware) + if !ok { + continue + } + + effectiveMiddleware = append(effectiveMiddleware, middlewareItem) + } + + for _, plugin := range plugins { + effectiveMiddleware = append(effectiveMiddleware, plugin) + } + + return effectiveMiddleware +} diff --git a/plugin_defaults.go b/plugin_defaults.go new file mode 100644 index 00000000..7c9596b7 --- /dev/null +++ b/plugin_defaults.go @@ -0,0 +1,10 @@ +package river + +// PluginDefaults should be embedded on plugin implementations. It helps +// identify a struct as both hooks and middleware, and guarantees forward +// compatibility in case additions are necessary to the rivertype.Hook or +// rivertype.Middleware interfaces. +type PluginDefaults struct { + HookDefaults + MiddlewareDefaults +} diff --git a/plugin_defaults_test.go b/plugin_defaults_test.go new file mode 100644 index 00000000..e995398f --- /dev/null +++ b/plugin_defaults_test.go @@ -0,0 +1,9 @@ +package river + +import "github.com/riverqueue/river/rivertype" + +var ( + _ rivertype.Hook = &PluginDefaults{} + _ rivertype.Middleware = &PluginDefaults{} + _ rivertype.Plugin = &PluginDefaults{} +) diff --git a/rivertest/worker.go b/rivertest/worker.go index 8da01ae0..91751a50 100644 --- a/rivertest/worker.go +++ b/rivertest/worker.go @@ -13,6 +13,7 @@ import ( "github.com/riverqueue/river/internal/jobexecutor" "github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/middlewarelookup" + "github.com/riverqueue/river/internal/pluginconfig" "github.com/riverqueue/river/internal/rivermiddleware" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" @@ -147,12 +148,15 @@ func (w *Worker[T, TTx]) workJob(ctx context.Context, tb testing.TB, tx TTx, job } completer := jobcompleter.NewInlineCompleter(archetype, w.config.Schema, exec, w.client.Pilot(), subscribeCh) - for _, hook := range w.config.Hooks { + effectiveHooks := pluginconfig.Hooks(w.config.Hooks, w.config.Middleware, w.config.Plugins) + effectiveMiddleware := pluginconfig.Middleware(w.config.Hooks, w.config.Middleware, w.config.Plugins) + + for _, hook := range effectiveHooks { if withBaseService, ok := hook.(baseservice.WithBaseService); ok { baseservice.Init(archetype, withBaseService) } } - for _, middleware := range w.config.Middleware { + for _, middleware := range effectiveMiddleware { if withBaseService, ok := middleware.(baseservice.WithBaseService); ok { baseservice.Init(archetype, withBaseService) } @@ -205,10 +209,10 @@ func (w *Worker[T, TTx]) workJob(ctx context.Context, tb testing.TB, tx TTx, job return nil }, }, - HookLookupGlobal: hooklookup.NewHookLookup(w.config.Hooks), + HookLookupGlobal: hooklookup.NewHookLookup(effectiveHooks), HookLookupByJob: hooklookup.NewJobHookLookup(), JobRow: job, - MiddlewareLookupGlobal: middlewarelookup.NewMiddlewareLookup(append(rivermiddleware.DefaultMiddleware(), w.config.Middleware...)), + MiddlewareLookupGlobal: middlewarelookup.NewMiddlewareLookup(append(rivermiddleware.DefaultMiddleware(), effectiveMiddleware...)), ProducerCallbacks: struct { JobDone func(jobRow *rivertype.JobRow) Stuck func() diff --git a/rivertype/river_type.go b/rivertype/river_type.go index 13d07f63..f350e877 100644 --- a/rivertype/river_type.go +++ b/rivertype/river_type.go @@ -447,6 +447,17 @@ type Middleware interface { IsMiddleware() bool } +// Plugin is a generic extension installed globally as both a hook and +// middleware. +// +// Plugin structs should embed river.PluginDefaults, or embed both +// river.HookDefaults and river.MiddlewareDefaults directly, then implement any +// operation-specific hook or middleware interfaces they need. +type Plugin interface { + Hook + Middleware +} + // JobInsertMiddleware provides an interface for middleware that integrations // can use to encapsulate common logic around job insertion. //