Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Added root River CLI flag `--statement-timeout` so Postgres session statement timeout can be set explicitly for commands like migrations. Explicit flag values take priority over database URL query params, and query params still take priority over built-in defaults. [PR #1142](https://github.com/riverqueue/river/pull/1142).

### Fixed

- `JobCountByQueueAndState` now returns consistent results across drivers, including requested queues with zero jobs, and deduplicates repeated queue names in input. This resolves an issue with the sqlite driver in River UI reported in [riverqueue/riverui#496](https://github.com/riverqueue/riverui#496). [PR #1140](https://github.com/riverqueue/river/pull/1140).
Expand Down
37 changes: 27 additions & 10 deletions cmd/river/rivercli/command.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package rivercli

import (
"cmp"
"context"
"database/sql"
"fmt"
Expand Down Expand Up @@ -46,11 +47,12 @@ type CommandOpts interface {

// RunCommandBundle is a bundle of utilities for RunCommand.
type RunCommandBundle struct {
DatabaseURL *string
DriverProcurer DriverProcurer
Logger *slog.Logger
OutStd io.Writer
Schema string
DatabaseURL *string
DriverProcurer DriverProcurer
Logger *slog.Logger
OutStd io.Writer
Schema string
StatementTimeout *time.Duration
}

// RunCommand bootstraps and runs a River CLI subcommand.
Expand Down Expand Up @@ -81,7 +83,7 @@ func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle
if databaseURL != nil {
switch protocol {
case "postgres", "postgresql":
dbPool, err := openPgxV5DBPool(ctx, *databaseURL)
dbPool, err := openPgxV5DBPool(ctx, *databaseURL, bundle.StatementTimeout)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -128,7 +130,7 @@ func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle
return nil
}

func openPgxV5DBPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) {
func openPgxV5DBPool(ctx context.Context, databaseURL string, statementTimeout *time.Duration) (*pgxpool.Pool, error) {
const (
defaultIdleInTransactionSessionTimeout = 11 * time.Second // should be greater than statement timeout because statements count towards idle-in-transaction
defaultStatementTimeout = 10 * time.Second
Expand All @@ -149,9 +151,24 @@ func openPgxV5DBPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, er
runtimeParams[name] = val
}

setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "application_name", "river CLI")
setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "idle_in_transaction_session_timeout", strconv.Itoa(int(defaultIdleInTransactionSessionTimeout.Milliseconds())))
setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "statement_timeout", strconv.Itoa(int(defaultStatementTimeout.Milliseconds())))
runtimeParams := pgxConfig.ConnConfig.RuntimeParams
if runtimeParams == nil {
runtimeParams = make(map[string]string)
pgxConfig.ConnConfig.RuntimeParams = runtimeParams
}

var statementTimeoutMilliseconds string
if statementTimeout != nil {
statementTimeoutMilliseconds = strconv.Itoa(int(statementTimeout.Milliseconds()))
}

setParamIfUnset(runtimeParams, "application_name", "river CLI")
setParamIfUnset(runtimeParams, "idle_in_transaction_session_timeout", strconv.Itoa(int(defaultIdleInTransactionSessionTimeout.Milliseconds())))
runtimeParams["statement_timeout"] = cmp.Or(
statementTimeoutMilliseconds,
runtimeParams["statement_timeout"],
strconv.Itoa(int(defaultStatementTimeout.Milliseconds())),
)

dbPool, err := pgxpool.NewWithConfig(ctx, pgxConfig)
if err != nil {
Expand Down
44 changes: 36 additions & 8 deletions cmd/river/rivercli/river_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ func (c *CLI) BaseCommandSet() *cobra.Command {
ctx := context.Background()

var globalOpts struct {
Debug bool
Verbose bool
Debug bool
StatementTimeout time.Duration
Verbose bool
}
var rootCmd *cobra.Command

makeLogger := func() *slog.Logger {
switch {
Expand All @@ -75,18 +77,35 @@ func (c *CLI) BaseCommandSet() *cobra.Command {
}
}

statementTimeoutFlagSet := func() bool {
return rootCmd.PersistentFlags().Changed("statement-timeout")
}

validateGlobalOpts := func() error {
if statementTimeoutFlagSet() && globalOpts.StatementTimeout <= time.Millisecond {
return errors.New("`--statement-timeout` must be greater than 1ms when set")
}

return nil
}

// Make a bundle for RunCommand. Takes a database URL pointer because not every command is required to take a database URL.
makeCommandBundle := func(databaseURL *string, schema string) *RunCommandBundle {
var statementTimeout *time.Duration
if statementTimeoutFlagSet() {
statementTimeout = &globalOpts.StatementTimeout
}

return &RunCommandBundle{
DatabaseURL: databaseURL,
DriverProcurer: c.driverProcurer,
Logger: makeLogger(),
OutStd: c.out,
Schema: schema,
DatabaseURL: databaseURL,
DriverProcurer: c.driverProcurer,
Logger: makeLogger(),
OutStd: c.out,
Schema: schema,
StatementTimeout: statementTimeout,
}
}

var rootCmd *cobra.Command
{
var rootOpts struct {
Version bool
Expand All @@ -103,7 +122,15 @@ also accept Postgres configuration through the standard set of libpq environment
variables like PGHOST, PGPORT, PGDATABASE, PGUSER, PGPASSWORD, and PGSSLMODE,
with a minimum of PGDATABASE required. --database-url will take precedence of
PG* vars if it's been specified.

Use --statement-timeout to explicitly set Postgres statement_timeout for
Postgres-backed commands. Precedence is: --statement-timeout, then a
statement_timeout query parameter in --database-url, then the built-in 10s
default.
`),
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
return validateGlobalOpts()
},
RunE: func(cmd *cobra.Command, args []string) error {
if rootOpts.Version {
return RunCommand(ctx, makeCommandBundle(nil, ""), &version{}, &versionOpts{Name: c.name})
Expand All @@ -116,6 +143,7 @@ PG* vars if it's been specified.
rootCmd.SetOut(c.out)

rootCmd.PersistentFlags().BoolVar(&globalOpts.Debug, "debug", false, "output maximum logging verbosity (debug level)")
rootCmd.PersistentFlags().DurationVar(&globalOpts.StatementTimeout, "statement-timeout", 0, "override Postgres statement_timeout for Postgres commands (Go duration >1ms, e.g. 10s, 1m); precedence: flag > --database-url statement_timeout > default 10s")
rootCmd.PersistentFlags().BoolVarP(&globalOpts.Verbose, "verbose", "v", false, "output additional logging verbosity (info level)")
rootCmd.MarkFlagsMutuallyExclusive("debug", "verbose")

Expand Down
139 changes: 139 additions & 0 deletions cmd/river/rivercli/river_cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"cmp"
"context"
"fmt"
"maps"
"net/url"
"runtime/debug"
"strings"
Expand Down Expand Up @@ -169,6 +170,34 @@ func TestBaseCommandSetIntegration(t *testing.T) {
require.EqualError(t, cmd.Execute(), "either PG* env vars or --database-url must be set")
})

t.Run("StatementTimeoutValidation", func(t *testing.T) {
t.Parallel()

t.Run("AllowsGreaterThanOneMillisecond", func(t *testing.T) {
t.Parallel()

cmd, _ := setup(t)
cmd.SetArgs([]string{"--statement-timeout", "2ms", "--version"})
require.NoError(t, cmd.Execute())
})

t.Run("RejectsOneMillisecond", func(t *testing.T) {
t.Parallel()

cmd, _ := setup(t)
cmd.SetArgs([]string{"--statement-timeout", "1ms", "--version"})
require.EqualError(t, cmd.Execute(), "`--statement-timeout` must be greater than 1ms when set")
})

t.Run("RejectsZero", func(t *testing.T) {
t.Parallel()

cmd, _ := setup(t)
cmd.SetArgs([]string{"--statement-timeout", "0", "--version"})
require.EqualError(t, cmd.Execute(), "`--statement-timeout` must be greater than 1ms when set")
})
})

t.Run("VersionFlag", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -263,6 +292,116 @@ func TestBaseCommandSetNonParallel(t *testing.T) {
})
}

func TestBaseCommandSetPostgresTimeoutPrecedence(t *testing.T) {
t.Parallel()

type testCase struct {
databaseURLStatementTimeout string
expectedStatementTimeoutMS string
name string
statementTimeoutFlag string
}

makeCommandAndParams := func(t *testing.T) (*cobra.Command, func() map[string]string) {
t.Helper()

var capturedRuntimeParams map[string]string

migratorStub := &MigratorStub{}
migratorStub.allVersionsStub = func() []rivermigrate.Migration { return []rivermigrate.Migration{testMigration01} }
migratorStub.getVersionStub = func(version int) (rivermigrate.Migration, error) {
if version == 1 {
return testMigration01, nil
}

return rivermigrate.Migration{}, fmt.Errorf("unknown version: %d", version)
}
migratorStub.existingVersionsStub = func(ctx context.Context) ([]rivermigrate.Migration, error) { return nil, nil }

cli := NewCLI(&Config{
DriverProcurer: &DriverProcurerStub{
getMigratorStub: func(config *rivermigrate.Config) (MigratorInterface, error) {
return migratorStub, nil
},
initPgxV5Stub: func(pool *pgxpool.Pool) {
capturedRuntimeParams = maps.Clone(pool.Config().ConnConfig.RuntimeParams)
},
},
Name: "River",
})

var out bytes.Buffer
cli.SetOut(&out)

return cli.BaseCommandSet(), func() map[string]string {
return capturedRuntimeParams
}
}

makeBaseDatabaseURL := func(t *testing.T) *url.URL {
t.Helper()

testDatabaseURL := riversharedtest.TestDatabaseURL()
parsedDatabaseURL, err := url.Parse(testDatabaseURL)
require.NoError(t, err)

return parsedDatabaseURL
}

testCases := []testCase{
{
name: "DefaultsAppliedWhenNothingSpecified",
expectedStatementTimeoutMS: "10000",
},
{
databaseURLStatementTimeout: "11234",
name: "DatabaseURLQueryParamsOverrideDefaults",
expectedStatementTimeoutMS: "11234",
},
{
databaseURLStatementTimeout: "12345",
name: "ExplicitFlagsOverrideDatabaseURLQueryParams",
statementTimeoutFlag: "1m3.123s",
expectedStatementTimeoutMS: "63123",
},
{
databaseURLStatementTimeout: "12345",
name: "ExplicitFlagsUseMillisecondValue",
statementTimeoutFlag: "2ms",
expectedStatementTimeoutMS: "2",
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()

cmd, getRuntimeParams := makeCommandAndParams(t)

databaseURL := makeBaseDatabaseURL(t)
if testCase.databaseURLStatementTimeout != "" {
queryValues := databaseURL.Query()
queryValues.Set("statement_timeout", testCase.databaseURLStatementTimeout)
databaseURL.RawQuery = queryValues.Encode()
}

args := []string{
"migrate-get", "--up", "--version", "1", "--database-url", databaseURL.String(),
}
if testCase.statementTimeoutFlag != "" {
args = append(args, "--statement-timeout", testCase.statementTimeoutFlag)
}
cmd.SetArgs(args)
require.NoError(t, cmd.Execute())

runtimeParams := getRuntimeParams()
require.NotNil(t, runtimeParams)

require.Equal(t, testCase.expectedStatementTimeoutMS, runtimeParams["statement_timeout"])
})
}
}

func TestBaseCommandSetDriverProcurerPgxV5(t *testing.T) {
t.Parallel()

Expand Down
5 changes: 5 additions & 0 deletions docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ To run programs locally outside of tests, create and raise a development databas
createdb river_dev
go run ./cmd/river migrate-up --database-url postgres:///river_dev --line main

If needed, override Postgres timeouts for long-running migrations with root CLI
flags:

go run ./cmd/river --statement-timeout 2m migrate-up --database-url postgres:///river_dev --line main

## Releasing a new version

1. Fetch changes to the repo and any new tags. Export `VERSION` by incrementing the last tag. Execute `update-mod-version` to add it the project's `go.mod` files:
Expand Down
Loading