From c991820b8ba2327db327cb4e8328acb9ddf38d37 Mon Sep 17 00:00:00 2001 From: Brandur Date: Fri, 12 Jun 2026 07:37:44 -0500 Subject: [PATCH] Cross-compatible hook/middleware + "plugins" which are allowed to be hook and middleware This one's largely aimed at extending a few parts of `otelriver` to be able to emit some additional useful metrics like time that it takes to lock jobs, number of jobs locked per batch, or any other arbitrary metrics we want to emit down the road. I previously had something similar back in #1203, but here we extract an isolated piece of it. This change does two things: * If either a hook sent to `Config.Hooks` implements a middleware or a middleware sent to `Config.Middleware` implements a hook, activate its alternate side as well. * Establish a new `Config.Plugins` that acts as a more generalized place where a hook/middleware can go. We define a plugin as this type: type Plugin interface { Hook Middleware } The reason for the first point is better backward compatibility. Notably, if I add a hook to `otelriver.Middleware`, I want it to be able to still work even if the user doesn't explicitly movie it from `Config.Middleware` to `Config.Plugins`. --- client.go | 86 ++++++++---- client_test.go | 186 +++++++++++++++++++++++++ internal/pluginconfig/plugin_config.go | 51 +++++++ plugin_defaults.go | 10 ++ plugin_defaults_test.go | 9 ++ rivertest/worker.go | 12 +- rivertype/river_type.go | 11 ++ 7 files changed, 336 insertions(+), 29 deletions(-) create mode 100644 internal/pluginconfig/plugin_config.go create mode 100644 plugin_defaults.go create mode 100644 plugin_defaults_test.go diff --git a/client.go b/client.go index 3f29e545..044a237d 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,7 @@ import ( "github.com/riverqueue/river/internal/middlewarelookup" "github.com/riverqueue/river/internal/notifier" "github.com/riverqueue/river/internal/notifylimiter" + "github.com/riverqueue/river/internal/pluginconfig" "github.com/riverqueue/river/internal/rivercommon" "github.com/riverqueue/river/internal/rivermiddleware" "github.com/riverqueue/river/internal/workunit" @@ -219,6 +220,9 @@ type Config struct { // work hook runs and the insertion hooks on either side of it are skipped. // // Jobs may have their own specific hooks by implementing JobArgsWithHooks. + // + // If a type in Hooks also implements rivertype.Middleware, it will be + // installed as middleware too. Hooks []rivertype.Hook // Logger is the structured logger to use for logging purposes. If none is @@ -252,8 +256,25 @@ type Config struct { // middlewares will run one after another, and the work middleware between // them will not run. When a job is worked, the work middleware runs and the // insertion middlewares on either side of it are skipped. + // + // If a type in Middleware also implements rivertype.Hook, it will be + // installed as a hook too. Middleware []rivertype.Middleware + // Plugins contains extensions installed globally as both hooks and + // middleware. + // + // A plugin must implement both rivertype.Hook and rivertype.Middleware. It + // may embed PluginDefaults, or embed both HookDefaults and + // MiddlewareDefaults directly, then implement any operation-specific hook or + // middleware interfaces it needs. + // + // Hooks and Middleware are still supported. If a type in Hooks also + // implements middleware, or a type in Middleware also implements hooks, River + // will install it on both sides automatically. The Plugins list exists as a + // generic place for new extensions to be registered. + Plugins []rivertype.Plugin + // PeriodicJobs are a set of periodic jobs to run at the specified intervals // in the client. PeriodicJobs []*PeriodicJob @@ -476,6 +497,7 @@ func (c *Config) WithDefaults() *Config { MaxAttempts: cmp.Or(c.MaxAttempts, MaxAttemptsDefault), Middleware: c.Middleware, PeriodicJobs: c.PeriodicJobs, + Plugins: c.Plugins, PollOnly: c.PollOnly, Queues: c.Queues, ReindexerIndexNames: reindexerIndexNames, @@ -774,7 +796,11 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client } } - for _, hook := range config.Hooks { + configuredMiddleware := middlewareFromConfig(config) + effectiveHooks := pluginconfig.Hooks(config.Hooks, configuredMiddleware, config.Plugins) + effectiveMiddleware := pluginconfig.Middleware(config.Hooks, configuredMiddleware, config.Plugins) + + for _, hook := range effectiveHooks { if withBaseService, ok := hook.(baseservice.WithBaseService); ok { baseservice.Init(archetype, withBaseService) } @@ -788,7 +814,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client config: config, driver: driver, hookLookupByJob: hooklookup.NewJobHookLookup(), - hookLookupGlobal: hooklookup.NewHookLookup(config.Hooks), + hookLookupGlobal: hooklookup.NewHookLookup(effectiveHooks), producersByQueueName: make(map[string]*producer), testSignals: clientTestSignals{}, workCancel: func(cause error) {}, // replaced on start, but here in case StopAndCancel is called before start up @@ -806,31 +832,12 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client client.baseService.Name = "Client" // Have to correct the name because base service isn't embedded like it usually is client.insertNotifyLimiter = notifylimiter.NewLimiter(archetype, config.FetchCooldown) - // Validation ensures that config.JobInsertMiddleware/WorkerMiddleware or - // the more abstract config.Middleware for middleware are set, but not both, - // so in practice we never append all three of these to each other. + // effectiveMiddleware contains configured middleware, hook/middleware + // hybrids, and plugins. Default middleware stays first so user middleware + // wraps inside River's internal defaults like before. { middleware := rivermiddleware.DefaultMiddleware() - middleware = append(middleware, config.Middleware...) - for _, jobInsertMiddleware := range config.JobInsertMiddleware { - middleware = append(middleware, jobInsertMiddleware) - } - outerLoop: - for _, workerMiddleware := range config.WorkerMiddleware { - // Don't add the middleware if it also implements JobInsertMiddleware - // and the instance has been added to config.JobInsertMiddleware. This - // is a hedge to make sure we don't accidentally double add middleware - // as we've converted over to the unified config.Middleware setting. - if workerMiddlewareAsJobInsertMiddleware, ok := workerMiddleware.(rivertype.JobInsertMiddleware); ok { - for _, jobInsertMiddleware := range config.JobInsertMiddleware { - if workerMiddlewareAsJobInsertMiddleware == jobInsertMiddleware { - continue outerLoop - } - } - } - - middleware = append(middleware, workerMiddleware) - } + middleware = append(middleware, effectiveMiddleware...) for _, middleware := range middleware { if withBaseService, ok := middleware.(baseservice.WithBaseService); ok { @@ -1040,6 +1047,35 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client return client, nil } +func middlewareFromConfig(config *Config) []rivertype.Middleware { + middleware := make([]rivertype.Middleware, 0, + len(config.Middleware)+len(config.JobInsertMiddleware)+len(config.WorkerMiddleware)) + middleware = append(middleware, config.Middleware...) + + for _, jobInsertMiddleware := range config.JobInsertMiddleware { + middleware = append(middleware, jobInsertMiddleware) + } + +outerLoop: + for _, workerMiddleware := range config.WorkerMiddleware { + // Don't add the middleware if it also implements JobInsertMiddleware + // and the instance has been added to config.JobInsertMiddleware. This + // is a hedge to make sure we don't accidentally double add middleware + // as we've converted over to the unified config.Middleware setting. + if workerMiddlewareAsJobInsertMiddleware, ok := workerMiddleware.(rivertype.JobInsertMiddleware); ok { + for _, jobInsertMiddleware := range config.JobInsertMiddleware { + if workerMiddlewareAsJobInsertMiddleware == jobInsertMiddleware { + continue outerLoop + } + } + } + + middleware = append(middleware, workerMiddleware) + } + + return middleware +} + // Start starts the client's job fetching and working loops. Once this is called, // the client will run in a background goroutine until stopped. All jobs are // run with a context inheriting from the provided context, but with a timeout diff --git a/client_test.go b/client_test.go index 9352dcbb..94304c09 100644 --- a/client_test.go +++ b/client_test.go @@ -70,6 +70,77 @@ type noOpWorker struct { func (w *noOpWorker) Work(ctx context.Context, job *Job[noOpArgs]) error { return nil } +var ( + _ rivertype.HookInsertBegin = &hookMiddlewareEmbeddedDefaultsPlugin{} + _ rivertype.JobInsertMiddleware = &hookMiddlewareEmbeddedDefaultsPlugin{} + _ rivertype.Plugin = &hookMiddlewareEmbeddedDefaultsPlugin{} + + _ rivertype.HookInsertBegin = &hookMiddlewarePlugin{} + _ rivertype.JobInsertMiddleware = &hookMiddlewarePlugin{} + _ rivertype.Plugin = &hookMiddlewarePlugin{} + + _ rivertype.HookInsertBegin = hookMiddlewareValuePlugin{} + _ rivertype.JobInsertMiddleware = hookMiddlewareValuePlugin{} +) + +type hookMiddlewareEmbeddedDefaultsPlugin struct { + HookDefaults + MiddlewareDefaults + + insertBeginCount int + insertManyCount int +} + +func (p *hookMiddlewareEmbeddedDefaultsPlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + p.insertBeginCount++ + return nil +} + +func (p *hookMiddlewareEmbeddedDefaultsPlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + p.insertManyCount++ + return doInner(ctx) +} + +type hookMiddlewarePlugin struct { + PluginDefaults + + insertBeginCount int + insertManyCount int +} + +func (p *hookMiddlewarePlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + p.insertBeginCount++ + return nil +} + +func (p *hookMiddlewarePlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + p.insertManyCount++ + return doInner(ctx) +} + +type hookMiddlewareValuePlugin struct { + counts *hookMiddlewareValuePluginCounts +} + +func (p hookMiddlewareValuePlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error { + p.counts.insertBeginCount++ + return nil +} + +func (p hookMiddlewareValuePlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + p.counts.insertManyCount++ + return doInner(ctx) +} + +func (p hookMiddlewareValuePlugin) IsHook() bool { return true } + +func (p hookMiddlewareValuePlugin) IsMiddleware() bool { return true } + +type hookMiddlewareValuePluginCounts struct { + insertBeginCount int + insertManyCount int +} + type periodicJobArgs struct{} func (periodicJobArgs) Kind() string { return "periodic_job" } @@ -8195,6 +8266,121 @@ func Test_NewClient_Overrides(t *testing.T) { require.Len(t, client.config.WorkerMiddleware, 1) } +func Test_NewClient_PluginsAndHybrids(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type testBundle struct { + config *Config + dbPool *pgxpool.Pool + } + + setup := func(t *testing.T) *testBundle { + t.Helper() + + dbPool := riversharedtest.DBPool(ctx, t) + driver := riverpgxv5.New(dbPool) + schema := riverdbtest.TestSchema(ctx, t, driver, nil) + + return &testBundle{ + config: newTestConfig(t, schema), + dbPool: dbPool, + } + } + + insertAndRequireCounts := func(t *testing.T, bundle *testBundle, plugin *hookMiddlewarePlugin, expectedCount int) { + t.Helper() + + client := newTestClient(t, bundle.dbPool, bundle.config) + + _, err := client.Insert(ctx, noOpArgs{}, nil) + require.NoError(t, err) + + require.Equal(t, expectedCount, plugin.insertBeginCount) + require.Equal(t, expectedCount, plugin.insertManyCount) + } + + t.Run("DuplicatesAcrossHooksMiddlewareAndPluginsRunMultipleTimes", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewarePlugin{} + bundle.config.Hooks = []rivertype.Hook{plugin} + bundle.config.Middleware = []rivertype.Middleware{plugin} + bundle.config.Plugins = []rivertype.Plugin{plugin} + + insertAndRequireCounts(t, bundle, plugin, 3) + }) + + t.Run("HookAlsoRegisteredAsMiddleware", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewarePlugin{} + bundle.config.Hooks = []rivertype.Hook{plugin} + + insertAndRequireCounts(t, bundle, plugin, 1) + }) + + t.Run("MiddlewareAlsoRegisteredAsHook", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewarePlugin{} + bundle.config.Middleware = []rivertype.Middleware{plugin} + + insertAndRequireCounts(t, bundle, plugin, 1) + }) + + t.Run("PluginRegisteredAsHookAndMiddleware", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewarePlugin{} + bundle.config.Plugins = []rivertype.Plugin{plugin} + + insertAndRequireCounts(t, bundle, plugin, 1) + }) + + t.Run("PluginRegisteredWithEmbeddedDefaults", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + plugin := &hookMiddlewareEmbeddedDefaultsPlugin{} + bundle.config.Plugins = []rivertype.Plugin{plugin} + + client := newTestClient(t, bundle.dbPool, bundle.config) + + _, err := client.Insert(ctx, noOpArgs{}, nil) + require.NoError(t, err) + + require.Equal(t, 1, plugin.insertBeginCount) + require.Equal(t, 1, plugin.insertManyCount) + }) + + t.Run("SeparateEqualValueInstancesRunSeparately", func(t *testing.T) { + t.Parallel() + + bundle := setup(t) + counts := &hookMiddlewareValuePluginCounts{} + bundle.config.Hooks = []rivertype.Hook{ + hookMiddlewareValuePlugin{counts: counts}, + } + bundle.config.Middleware = []rivertype.Middleware{ + hookMiddlewareValuePlugin{counts: counts}, + } + + client := newTestClient(t, bundle.dbPool, bundle.config) + + _, err := client.Insert(ctx, noOpArgs{}, nil) + require.NoError(t, err) + + require.Equal(t, 2, counts.insertBeginCount) + require.Equal(t, 2, counts.insertManyCount) + }) +} + func Test_NewClient_ReindexerIndexNamesExplicitEmptyOverride(t *testing.T) { t.Parallel() diff --git a/internal/pluginconfig/plugin_config.go b/internal/pluginconfig/plugin_config.go new file mode 100644 index 00000000..1377970e --- /dev/null +++ b/internal/pluginconfig/plugin_config.go @@ -0,0 +1,51 @@ +package pluginconfig + +import "github.com/riverqueue/river/rivertype" + +// Hooks returns the effective hook list from configured hooks, middleware, and +// plugins. Explicit hooks are preserved first, followed by middleware that also +// implement hooks, then plugins. +func Hooks(hooks []rivertype.Hook, middleware []rivertype.Middleware, plugins []rivertype.Plugin) []rivertype.Hook { + effectiveHooks := make([]rivertype.Hook, 0, len(hooks)+len(middleware)+len(plugins)) + + effectiveHooks = append(effectiveHooks, hooks...) + + for _, middlewareItem := range middleware { + hook, ok := middlewareItem.(rivertype.Hook) + if !ok { + continue + } + + effectiveHooks = append(effectiveHooks, hook) + } + + for _, plugin := range plugins { + effectiveHooks = append(effectiveHooks, plugin) + } + + return effectiveHooks +} + +// Middleware returns the effective middleware list from configured hooks, +// middleware, and plugins. Explicit middleware are preserved first, followed by +// hooks that also implement middleware, then plugins. +func Middleware(hooks []rivertype.Hook, middleware []rivertype.Middleware, plugins []rivertype.Plugin) []rivertype.Middleware { + effectiveMiddleware := make([]rivertype.Middleware, 0, len(hooks)+len(middleware)+len(plugins)) + + effectiveMiddleware = append(effectiveMiddleware, middleware...) + + for _, hook := range hooks { + middlewareItem, ok := hook.(rivertype.Middleware) + if !ok { + continue + } + + effectiveMiddleware = append(effectiveMiddleware, middlewareItem) + } + + for _, plugin := range plugins { + effectiveMiddleware = append(effectiveMiddleware, plugin) + } + + return effectiveMiddleware +} diff --git a/plugin_defaults.go b/plugin_defaults.go new file mode 100644 index 00000000..7c9596b7 --- /dev/null +++ b/plugin_defaults.go @@ -0,0 +1,10 @@ +package river + +// PluginDefaults should be embedded on plugin implementations. It helps +// identify a struct as both hooks and middleware, and guarantees forward +// compatibility in case additions are necessary to the rivertype.Hook or +// rivertype.Middleware interfaces. +type PluginDefaults struct { + HookDefaults + MiddlewareDefaults +} diff --git a/plugin_defaults_test.go b/plugin_defaults_test.go new file mode 100644 index 00000000..e995398f --- /dev/null +++ b/plugin_defaults_test.go @@ -0,0 +1,9 @@ +package river + +import "github.com/riverqueue/river/rivertype" + +var ( + _ rivertype.Hook = &PluginDefaults{} + _ rivertype.Middleware = &PluginDefaults{} + _ rivertype.Plugin = &PluginDefaults{} +) diff --git a/rivertest/worker.go b/rivertest/worker.go index 8da01ae0..91751a50 100644 --- a/rivertest/worker.go +++ b/rivertest/worker.go @@ -13,6 +13,7 @@ import ( "github.com/riverqueue/river/internal/jobexecutor" "github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/middlewarelookup" + "github.com/riverqueue/river/internal/pluginconfig" "github.com/riverqueue/river/internal/rivermiddleware" "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" @@ -147,12 +148,15 @@ func (w *Worker[T, TTx]) workJob(ctx context.Context, tb testing.TB, tx TTx, job } completer := jobcompleter.NewInlineCompleter(archetype, w.config.Schema, exec, w.client.Pilot(), subscribeCh) - for _, hook := range w.config.Hooks { + effectiveHooks := pluginconfig.Hooks(w.config.Hooks, w.config.Middleware, w.config.Plugins) + effectiveMiddleware := pluginconfig.Middleware(w.config.Hooks, w.config.Middleware, w.config.Plugins) + + for _, hook := range effectiveHooks { if withBaseService, ok := hook.(baseservice.WithBaseService); ok { baseservice.Init(archetype, withBaseService) } } - for _, middleware := range w.config.Middleware { + for _, middleware := range effectiveMiddleware { if withBaseService, ok := middleware.(baseservice.WithBaseService); ok { baseservice.Init(archetype, withBaseService) } @@ -205,10 +209,10 @@ func (w *Worker[T, TTx]) workJob(ctx context.Context, tb testing.TB, tx TTx, job return nil }, }, - HookLookupGlobal: hooklookup.NewHookLookup(w.config.Hooks), + HookLookupGlobal: hooklookup.NewHookLookup(effectiveHooks), HookLookupByJob: hooklookup.NewJobHookLookup(), JobRow: job, - MiddlewareLookupGlobal: middlewarelookup.NewMiddlewareLookup(append(rivermiddleware.DefaultMiddleware(), w.config.Middleware...)), + MiddlewareLookupGlobal: middlewarelookup.NewMiddlewareLookup(append(rivermiddleware.DefaultMiddleware(), effectiveMiddleware...)), ProducerCallbacks: struct { JobDone func(jobRow *rivertype.JobRow) Stuck func() diff --git a/rivertype/river_type.go b/rivertype/river_type.go index 13d07f63..f350e877 100644 --- a/rivertype/river_type.go +++ b/rivertype/river_type.go @@ -447,6 +447,17 @@ type Middleware interface { IsMiddleware() bool } +// Plugin is a generic extension installed globally as both a hook and +// middleware. +// +// Plugin structs should embed river.PluginDefaults, or embed both +// river.HookDefaults and river.MiddlewareDefaults directly, then implement any +// operation-specific hook or middleware interfaces they need. +type Plugin interface { + Hook + Middleware +} + // JobInsertMiddleware provides an interface for middleware that integrations // can use to encapsulate common logic around job insertion. //