diff --git a/go.mod b/go.mod index 9fc6d0b69..4bc804c2e 100644 --- a/go.mod +++ b/go.mod @@ -34,7 +34,7 @@ require ( github.com/jackc/pgconn v1.14.3 github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 github.com/jackc/pgx-zap v0.0.0-20221202020421-94b1cb2f889f - github.com/jackc/pgx/v5 v5.6.0 + github.com/jackc/pgx/v5 v5.8.0 github.com/joho/godotenv v1.5.1 github.com/kenshaw/inflector v0.2.0 github.com/kenshaw/snaker v0.2.0 @@ -47,7 +47,7 @@ require ( github.com/prometheus/client_golang v1.19.0 github.com/smartystreets/assertions v1.13.1 github.com/smartystreets/gunit v1.4.5 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.11.1 github.com/swaggest/jsonschema-go v0.3.72 github.com/swaggest/openapi-go v0.2.53 github.com/zitadel/oidc/v2 v2.12.0 @@ -84,7 +84,7 @@ require ( github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgproto3/v2 v2.3.3 // indirect github.com/jackc/pgtype v1.14.3 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/muhlemmer/gu v0.3.1 // indirect @@ -129,7 +129,7 @@ require ( github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/invopop/yaml v0.3.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/leodido/go-urn v1.4.0 // indirect diff --git a/go.sum b/go.sum index 65500831b..dde663728 100644 --- a/go.sum +++ b/go.sum @@ -221,8 +221,8 @@ github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUO github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= @@ -239,14 +239,14 @@ github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgS github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= github.com/jackc/pgx/v4 v4.18.3 h1:dE2/TrEsGX3RBprb3qryqSV9Y60iZN1C6i8IrmW9/BA= github.com/jackc/pgx/v4 v4.18.3/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= -github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= -github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= 
+github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc= github.com/jeremija/gosubmit v0.2.7/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= @@ -406,8 +406,9 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/swaggest/assertjson v1.9.0 h1:dKu0BfJkIxv/xe//mkCrK5yZbs79jL7OVf9Ija7o2xQ= github.com/swaggest/assertjson v1.9.0/go.mod h1:b+ZKX2VRiUjxfUIal0HDN85W0nHPAYUbYH5WkkSsFsU= github.com/swaggest/jsonschema-go v0.3.72 h1:IHaGlR1bdBUBPfhe4tfacN2TGAPKENEGiNyNzvnVHv4= diff --git a/vendor/github.com/jackc/pgservicefile/pgservicefile.go b/vendor/github.com/jackc/pgservicefile/pgservicefile.go index 797bbab9e..c62caa7fe 100644 --- a/vendor/github.com/jackc/pgservicefile/pgservicefile.go +++ b/vendor/github.com/jackc/pgservicefile/pgservicefile.go @@ -57,7 +57,7 @@ func ParseServicefile(r io.Reader) (*Servicefile, error) { } else if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { service = &Service{Name: line[1 : len(line)-1], Settings: make(map[string]string)} servicefile.Services = append(servicefile.Services, service) - } else { + } else if service != nil { parts := strings.SplitN(line, "=", 2) if len(parts) != 2 { return nil, fmt.Errorf("unable to parse line %d", lineNum) @@ -67,6 +67,8 @@ func ParseServicefile(r io.Reader) (*Servicefile, error) { value := strings.TrimSpace(parts[1]) service.Settings[key] = value + } else { + return nil, fmt.Errorf("line %d is not in a section", lineNum) } } diff --git a/vendor/github.com/jackc/pgx/v5/.golangci.yml b/vendor/github.com/jackc/pgx/v5/.golangci.yml new file mode 100644 index 000000000..ca74c703a --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/.golangci.yml @@ -0,0 +1,21 @@ +# See for configurations: https://golangci-lint.run/usage/configuration/ +version: 2 + +# See: https://golangci-lint.run/usage/formatters/ +formatters: + default: none + enable: + - gofmt # 
https://pkg.go.dev/cmd/gofmt + - gofumpt # https://github.com/mvdan/gofumpt + + settings: + gofmt: + simplify: true # Simplify code: gofmt with `-s` option. + + gofumpt: + # Module path which contains the source code being formatted. + # Default: "" + module-path: github.com/jackc/pgx/v5 # Should match with module in go.mod + # Choose whether to use the extra rules. + # Default: false + extra-rules: true diff --git a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md index 61b4695fd..b958a8161 100644 --- a/vendor/github.com/jackc/pgx/v5/CHANGELOG.md +++ b/vendor/github.com/jackc/pgx/v5/CHANGELOG.md @@ -1,3 +1,100 @@ +# 5.8.0 (December 26, 2025) + +* Require Go 1.24+ +* Remove golang.org/x/crypto dependency +* Add OptionShouldPing to control ResetSession ping behavior (ilyam8) +* Fix: Avoid overflow when MaxConns is set to MaxInt32 +* Fix: Close batch pipeline after a query error (Anthonin Bonnefoy) +* Faster shutdown of pgxpool.Pool background goroutines (Blake Gentry) +* Add pgxpool ping timeout (Amirsalar Safaei) +* Fix: Rows.FieldDescriptions for empty query +* Scan unknown types into *any as string or []byte based on format code +* Optimize pgtype.Numeric (Philip Dubé) +* Add AfterNetConnect hook to pgconn.Config +* Fix: Handle for preparing statements that fail during the Describe phase +* Fix overflow in numeric scanning (Ilia Demianenko) +* Fix: json/jsonb sql.Scanner source type is []byte +* Migrate from math/rand to math/rand/v2 (Mathias Bogaert) +* Optimize internal iobufpool (Mathias Bogaert) +* Optimize stmtcache invalidation (Mathias Bogaert) +* Fix: missing error case in interval parsing (Maxime Soulé) +* Fix: invalidate statement/description cache in Exec (James Hartig) +* ColumnTypeLength method return the type length for varbit type (DengChan) +* Array and Composite codecs handle typed nils + +# 5.7.6 (September 8, 2025) + +* Use ParseConfigError in pgx.ParseConfig and pgxpool.ParseConfig (Yurasov Ilia) +* Add PrepareConn hook to pgxpool (Jonathan Hall) +* Reduce allocations in QueryContext (Dominique Lefevre) +* Add MarshalJSON and UnmarshalJSON for pgtype.Uint32 (Panos Koutsovasilis) +* Configure ping behavior on pgxpool with ShouldPing (Christian Kiely) +* zeronull int types implement Int64Valuer and Int64Scanner (Li Zeghong) +* Fix panic when receiving terminate connection message during CopyFrom (Michal Drausowski) +* Fix statement cache not being invalidated on error during batch (Muhammadali Nazarov) + +# 5.7.5 (May 17, 2025) + +* Support sslnegotiation connection option (divyam234) +* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487. +* TraceLog now logs Acquire and Release at the debug level (dave sinclair) +* Add support for PGTZ environment variable +* Add support for PGOPTIONS environment variable +* Unpin memory used by Rows quicker +* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit. 
+ +# 5.7.4 (March 24, 2025) + +* Fix / revert change to scanning JSON `null` (Felix Röhrich) + +# 5.7.3 (March 21, 2025) + +* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32) +* Improve SQL sanitizer performance (ninedraft) +* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich) +* Fix Values() for xml type always returning nil instead of []byte +* Add ability to send Flush message in pipeline mode (zenkovev) +* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou) +* Better error messages when scanning structs (logicbomb) +* Fix handling of error on batch write (bonnefoa) +* Match libpq's connection fallback behavior more closely (felix-roehrich) +* Add MinIdleConns to pgxpool (djahandarie) + +# 5.7.2 (December 21, 2024) + +* Fix prepared statement already exists on batch prepare failure +* Add commit query to tx options (Lucas Hild) +* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels) +* Add message body size limits in frontend and backend (zene) +* Add xid8 type +* Ensure planning encodes and scans cannot infinitely recurse +* Implement pgtype.UUID.String() (Konstantin Grachev) +* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev) +* Update golang.org/x/crypto +* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo) + +# 5.7.1 (September 10, 2024) + +* Fix data race in tracelog.TraceLog +* Update puddle to v2.2.2. This removes the import of nanotime via linkname. +* Update golang.org/x/crypto and golang.org/x/text + +# 5.7.0 (September 7, 2024) + +* Add support for sslrootcert=system (Yann Soubeyrand) +* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell) +* Add XMLCodec supports encoding + scanning XML column type like json (nickcruess-soda) +* Add MultiTrace (Stepan Rabotkin) +* Add TraceLogConfig with customizable TimeKey (stringintech) +* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin) +* Support scanning binary formatted uint32 into string / TextScanner (jennifersp) +* Fix interval encoding to allow 0s and avoid extra spaces (Carlos Pérez-Aradros Herce) +* Update pgservicefile - fixes panic when parsing invalid file +* Better error message when reading past end of batch +* Don't print url when url.Parse returns an error (Kevin Biju) +* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler) +* Fix: Scan and encode types with underlying types of arrays + # 5.6.0 (May 25, 2024) * Add StrictNamedArgs (Tomas Zahradnicek) diff --git a/vendor/github.com/jackc/pgx/v5/README.md b/vendor/github.com/jackc/pgx/v5/README.md index 0cf2c2916..3c6f8e6a7 100644 --- a/vendor/github.com/jackc/pgx/v5/README.md +++ b/vendor/github.com/jackc/pgx/v5/README.md @@ -84,7 +84,7 @@ It is also possible to use the `database/sql` interface and convert a connection ## Testing -See CONTRIBUTING.md for setup instructions. +See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions. ## Architecture @@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube. ## Supported Go and PostgreSQL Versions -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. 
For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.24 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ## Version Policy @@ -126,7 +126,8 @@ pgerrcode contains constants for the PostgreSQL error codes. ## Adapters for 3rd Party Tracers -* [https://github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) +* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) +* [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx) ## Adapters for 3rd Party Loggers @@ -156,7 +157,7 @@ Library for scanning data from a database into Go structs and more. A carefully designed SQL client for making using SQL easier, more productive, and less error-prone on Golang. -### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) +### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) Adds GSSAPI / Kerberos authentication support. @@ -169,6 +170,22 @@ Explicit data mapping and scanning library for Go structs and slices. Type safe and flexible package for scanning database data into Go types. Supports, structs, maps, slices and custom mapping functions. -### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) +### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) Code first migration library for native pgx (no database/sql abstraction). + +### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring) + +A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry. + +### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox) + +Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver. + +### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy) + +Simplifies working with the pgx library, providing convenient scanning of nested structures. + +### [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter) + +Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters. diff --git a/vendor/github.com/jackc/pgx/v5/Rakefile b/vendor/github.com/jackc/pgx/v5/Rakefile index d957573e9..3e3aa5030 100644 --- a/vendor/github.com/jackc/pgx/v5/Rakefile +++ b/vendor/github.com/jackc/pgx/v5/Rakefile @@ -2,7 +2,7 @@ require "erb" rule '.go' => '.go.erb' do |task| erb = ERB.new(File.read(task.source)) - File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding)) + File.write(task.name, "// Code generated from #{task.source}. 
DO NOT EDIT.\n\n" + erb.result(binding)) sh "goimports", "-w", task.name end diff --git a/vendor/github.com/jackc/pgx/v5/batch.go b/vendor/github.com/jackc/pgx/v5/batch.go index 3540f57f5..d5e7dc8ec 100644 --- a/vendor/github.com/jackc/pgx/v5/batch.go +++ b/vendor/github.com/jackc/pgx/v5/batch.go @@ -43,6 +43,10 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { } // Exec sets fn to be called when the response to qq is received. +// +// Note: for simple batch insert uses where it is not required to handle +// each potential error individually, it's sufficient to not set any callbacks, +// and just handle the return value of BatchResults.Close. func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { qq.Fn = func(br BatchResults) error { ct, err := br.Exec() @@ -60,9 +64,13 @@ type Batch struct { QueuedQueries []*QueuedQuery } -// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. -// The only pgx option argument that is supported is QueryRewriter. Queries are executed using the -// connection's DefaultQueryExecMode. +// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option +// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode. +// +// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should +// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query, +// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that +// include the current query may reference the wrong query. func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { qq := &QueuedQuery{ SQL: query, @@ -79,7 +87,7 @@ func (b *Batch) Len() int { type BatchResults interface { // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer - // calling Exec on the QueuedQuery. + // calling Exec on the QueuedQuery, or just calling Close. Exec() (pgconn.CommandTag, error) // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer @@ -94,6 +102,9 @@ type BatchResults interface { // QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an // error or the batch encounters an error subsequent callback functions will not be called. // + // For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks, + // and just handle the return value of Close. + // // Close must be called before the underlying connection can be used again. Any error that occurred during a batch // operation may have made it impossible to resyncronize the connection with the server. In this case the underlying // connection will have been closed. 
@@ -128,7 +139,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { if !br.mrr.NextResult() { err := br.mrr.Close() if err == nil { - err = errors.New("no result") + err = errors.New("no more results in batch") } if br.conn.batchTracer != nil { br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ @@ -180,7 +191,7 @@ func (br *batchResults) Query() (Rows, error) { if !br.mrr.NextResult() { rows.err = br.mrr.Close() if rows.err == nil { - rows.err = errors.New("no result") + rows.err = errors.New("no more results in batch") } rows.closed = true @@ -203,7 +214,6 @@ func (br *batchResults) Query() (Rows, error) { func (br *batchResults) QueryRow() Row { rows, _ := br.Query() return (*connRow)(rows.(*baseRows)) - } // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to @@ -216,6 +226,8 @@ func (br *batchResults) Close() error { } br.endTraced = true } + + invalidateCachesOnBatchResultsError(br.conn, br.b, br.err) }() if br.err != nil { @@ -287,7 +299,10 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { return pgconn.CommandTag{}, br.err } - query, arguments, _ := br.nextQueryAndArgs() + query, arguments, err := br.nextQueryAndArgs() + if err != nil { + return pgconn.CommandTag{}, err + } results, err := br.pipeline.GetResults() if err != nil { @@ -330,9 +345,9 @@ func (br *pipelineBatchResults) Query() (Rows, error) { return &baseRows{err: br.err, closed: true}, br.err } - query, arguments, ok := br.nextQueryAndArgs() - if !ok { - query = "batch query" + query, arguments, err := br.nextQueryAndArgs() + if err != nil { + return &baseRows{err: err, closed: true}, err } rows := br.conn.getRows(br.ctx, query, arguments) @@ -371,7 +386,6 @@ func (br *pipelineBatchResults) Query() (Rows, error) { func (br *pipelineBatchResults) QueryRow() Row { rows, _ := br.Query() return (*connRow)(rows.(*baseRows)) - } // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to @@ -384,11 +398,12 @@ func (br *pipelineBatchResults) Close() error { } br.endTraced = true } + + invalidateCachesOnBatchResultsError(br.conn, br.b, br.err) }() if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { br.err = br.lastRows.err - return br.err } if br.closed { @@ -421,13 +436,72 @@ func (br *pipelineBatchResults) earlyError() error { return br.err } -func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { - bi := br.b.QueuedQueries[br.qqIdx] - query = bi.SQL - args = bi.Arguments - ok = true - br.qqIdx++ +func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, err error) { + if br.b == nil { + return "", nil, errors.New("no reference to batch") + } + + if br.qqIdx >= len(br.b.QueuedQueries) { + return "", nil, errors.New("no more results in batch") + } + + bi := br.b.QueuedQueries[br.qqIdx] + br.qqIdx++ + return bi.SQL, bi.Arguments, nil +} + +type emptyBatchResults struct { + conn *Conn + closed bool +} + +// Exec reads the results from the next query in the batch as if the query has been sent with Exec. +func (br *emptyBatchResults) Exec() (pgconn.CommandTag, error) { + if br.closed { + return pgconn.CommandTag{}, fmt.Errorf("batch already closed") + } + return pgconn.CommandTag{}, errors.New("no more results in batch") +} + +// Query reads the results from the next query in the batch as if the query has been sent with Query. 
+func (br *emptyBatchResults) Query() (Rows, error) { + if br.closed { + alreadyClosedErr := fmt.Errorf("batch already closed") + return &baseRows{err: alreadyClosedErr, closed: true}, alreadyClosedErr + } + + rows := br.conn.getRows(context.Background(), "", nil) + rows.err = errors.New("no more results in batch") + rows.closed = true + return rows, rows.err +} + +// QueryRow reads the results from the next query in the batch as if the query has been sent with QueryRow. +func (br *emptyBatchResults) QueryRow() Row { + rows, _ := br.Query() + return (*connRow)(rows.(*baseRows)) +} + +// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to +// resyncronize the connection with the server. In this case the underlying connection will have been closed. +func (br *emptyBatchResults) Close() error { + br.closed = true + return nil +} + +// invalidates statement and description caches on batch results error +func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) { + if err != nil && conn != nil && b != nil { + if sc := conn.statementCache; sc != nil { + for _, bi := range b.QueuedQueries { + sc.Invalidate(bi.SQL) + } + } + + if sc := conn.descriptionCache; sc != nil { + for _, bi := range b.QueuedQueries { + sc.Invalidate(bi.SQL) + } + } } - return } diff --git a/vendor/github.com/jackc/pgx/v5/conn.go b/vendor/github.com/jackc/pgx/v5/conn.go index 311721459..4a6d7fa97 100644 --- a/vendor/github.com/jackc/pgx/v5/conn.go +++ b/vendor/github.com/jackc/pgx/v5/conn.go @@ -3,6 +3,7 @@ package pgx import ( "context" "crypto/sha256" + "database/sql" "encoding/hex" "errors" "fmt" @@ -64,11 +65,12 @@ func (cc *ConnConfig) ConnString() string { return cc.connString } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access // to multiple database connections from multiple goroutines. type Conn struct { - pgConn *pgconn.PgConn - config *ConnConfig // config used when establishing this connection - preparedStatements map[string]*pgconn.StatementDescription - statementCache stmtcache.Cache - descriptionCache stmtcache.Cache + pgConn *pgconn.PgConn + config *ConnConfig // config used when establishing this connection + preparedStatements map[string]*pgconn.StatementDescription + failedDescribeStatement string + statementCache stmtcache.Cache + descriptionCache stmtcache.Cache queryTracer QueryTracer batchTracer BatchTracer @@ -102,13 +104,31 @@ func (ident Identifier) Sanitize() string { var ( // ErrNoRows occurs when rows are expected but none are returned. - ErrNoRows = errors.New("no rows in result set") + ErrNoRows = newProxyErr(sql.ErrNoRows, "no rows in result set") // ErrTooManyRows occurs when more rows than expected are returned. 
ErrTooManyRows = errors.New("too many rows in result set") ) -var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") -var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +func newProxyErr(background error, msg string) error { + return &proxyError{ + msg: msg, + background: background, + } +} + +type proxyError struct { + msg string + background error +} + +func (err *proxyError) Error() string { return err.msg } + +func (err *proxyError) Unwrap() error { return err.background } + +var ( + errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +) // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. @@ -153,7 +173,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con delete(config.RuntimeParams, "statement_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err) } statementCacheCapacity = int(n) } @@ -163,7 +183,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con delete(config.RuntimeParams, "description_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err) } descriptionCacheCapacity = int(n) } @@ -183,7 +203,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con case "simple_protocol": defaultQueryExecMode = QueryExecModeSimpleProtocol default: - return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s) + return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err) } } @@ -295,6 +315,14 @@ func (c *Conn) Close(ctx context.Context) error { // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This // allows a code path to Prepare and Query/Exec without concern for if the statement has already been prepared. 
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) { + if c.failedDescribeStatement != "" { + err = c.Deallocate(ctx, c.failedDescribeStatement) + if err != nil { + return nil, fmt.Errorf("failed to deallocate previously failed statement %q: %w", c.failedDescribeStatement, err) + } + c.failedDescribeStatement = "" + } + if c.prepareTracer != nil { ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql}) } @@ -327,6 +355,10 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.Statem sd, err = c.pgConn.Prepare(ctx, psName, sql, nil) if err != nil { + var pErr *pgconn.PrepareError + if errors.As(err, &pErr) { + c.failedDescribeStatement = psKey + } return nil, err } @@ -401,7 +433,7 @@ func (c *Conn) IsClosed() bool { return c.pgConn.IsClosed() } -func (c *Conn) die(err error) { +func (c *Conn) die() { if c.IsClosed() { return } @@ -483,6 +515,18 @@ optionLoop: mode = QueryExecModeSimpleProtocol } + defer func() { + if err != nil { + if sc := c.statementCache; sc != nil { + sc.Invalidate(sql) + } + + if sc := c.descriptionCache; sc != nil { + sc.Invalidate(sql) + } + } + }() + if sd, ok := c.preparedStatements[sql]; ok { return c.execPrepared(ctx, sd, arguments) } @@ -569,14 +613,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription return result.CommandTag, result.Err } -type unknownArgumentTypeQueryExecModeExecError struct { - arg any -} - -func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { - return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg) -} - func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { err := c.eqb.Build(c.typeMap, nil, args) if err != nil { @@ -631,21 +667,33 @@ const ( // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are // unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. + // + // On rare occasions user defined types may behave differently when encoded in the text format instead of the binary + // format. For example, this could happen if a "type RomanNumeral int32" implements fmt.Stringer to format integers as + // Roman numerals (e.g. 7 is VII). The binary format would properly encode the integer 7 as the binary value for 7. + // But the text format would encode the integer 7 as the string "VII". As QueryExecModeExec uses the text format, it + // is possible that changing query mode from another mode to QueryExecModeExec could change the behavior of the query. + // This should not occur with types pgx supports directly and can be avoided by registering the types with + // pgtype.Map.RegisterDefaultPgType and implementing the appropriate type interfaces. In the cas of RomanNumeral, it + // should implement pgtype.Int64Valuer. QueryExecModeExec - // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. - // Queries are executed in a single round trip. Type mappings can be registered with + // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is + // especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. 
string must be used + // instead for text type values including json and jsonb. Type mappings can be registered with // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous. - // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use - // a map[string]string directly as an argument. This mode cannot. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a + // map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip. // - // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor - // exceptions such as behavior when multiple result returning queries are erroneously sent in a single string. + // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes + // the warning regarding differences in text format and binary format encoding with user defined types. There may be + // other minor exceptions such as behavior when multiple result returning queries are erroneously sent in a single + // string. // // QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer - // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol - // should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does - // not support the extended protocol. + // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol should + // only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does not + // support the extended protocol. QueryExecModeSimpleProtocol ) @@ -843,7 +891,6 @@ func (c *Conn) getStatementDescription( mode QueryExecMode, sql string, ) (sd *pgconn.StatementDescription, err error) { - switch mode { case QueryExecModeCacheStatement: if c.statementCache == nil { @@ -886,7 +933,14 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. +// +// Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table +// and using it in a subsequent query in the same batch can fail. func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { + if len(b.QueuedQueries) == 0 { + return &emptyBatchResults{conn: c} + } + if c.batchTracer != nil { ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) defer func() { @@ -1108,47 +1162,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d // Prepare any needed queries if len(distinctNewQueries) > 0 { - for _, sd := range distinctNewQueries { - pipeline.SendPrepare(sd.Name, sd.SQL, nil) - } + err := func() (err error) { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } - err := pipeline.Sync() - if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} - } + // Store all statements we are preparing into the cache. 
It's fine if it overflows because HandleInvalidated will + // clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Put(sd) + } + } + + // If something goes wrong preparing the statements, we need to invalidate the cache entries we just added. + defer func() { + if err != nil && sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Invalidate(sd.SQL) + } + } + }() + + err = pipeline.Sync() + if err != nil { + return err + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return err + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return fmt.Errorf("expected statement description, got %T", results) + } + + // Fill in the previously empty / pending statement descriptions. + sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } - for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + return err } - resultSD, ok := results.(*pgconn.StatementDescription) + _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} + return fmt.Errorf("expected sync, got %T", results) } - // Fill in the previously empty / pending statement descriptions. - sd.ParamOIDs = resultSD.ParamOIDs - sd.Fields = resultSD.Fields - } - - results, err := pipeline.GetResults() + return nil + }() if err != nil { return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } - - _, ok := results.(*pgconn.PipelineSync) - if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} - } - } - - // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. - if sdCache != nil { - for _, sd := range distinctNewQueries { - sdCache.Put(sd) - } } // Queue the queries. diff --git a/vendor/github.com/jackc/pgx/v5/derived_types.go b/vendor/github.com/jackc/pgx/v5/derived_types.go new file mode 100644 index 000000000..72c0a2423 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/derived_types.go @@ -0,0 +1,256 @@ +package pgx + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/pgtype" +) + +/* +buildLoadDerivedTypesSQL generates the correct query for retrieving type information. + + pgVersion: the major version of the PostgreSQL server + typeNames: the names of the types to load. If nil, load all types. +*/ +func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string { + supportsMultirange := (pgVersion >= 14) + var typeNamesClause string + + if typeNames == nil { + // This should not occur; this will not return any types + typeNamesClause = "= ''" + } else { + typeNamesClause = "= ANY($1)" + } + parts := make([]string, 0, 10) + + // Each of the type names provided might be found in pg_class or pg_type. + // Additionally, it may or may not include a schema portion. 
+ parts = append(parts, ` +WITH RECURSIVE +-- find the OIDs in pg_class which match one of the provided type names +selected_classes(oid,reltype) AS ( + -- this query uses the namespace search path, so will match type names without a schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_catalog.pg_class + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace + WHERE pg_catalog.pg_table_is_visible(pg_class.oid) + AND relname `, typeNamesClause, ` +UNION ALL + -- this query will only match type names which include the schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_class + INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid) + WHERE nspname || '.' || relname `, typeNamesClause, ` +), +selected_types(oid) AS ( + -- collect the OIDs from pg_types which correspond to the selected classes + SELECT reltype AS oid + FROM selected_classes +UNION ALL + -- as well as any other type names which match our criteria + SELECT pg_type.oid + FROM pg_type + LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid) + WHERE typname `, typeNamesClause, ` + OR nspname || '.' || typname `, typeNamesClause, ` +), +-- this builds a parent/child mapping of objects, allowing us to know +-- all the child (ie: dependent) types that a parent (type) requires +-- As can be seen, there are 3 ways this can occur (the last of which +-- is due to being a composite class, where the composite fields are children) +pc(parent, child) AS ( + SELECT parent.oid, parent.typelem + FROM pg_type parent + WHERE parent.typtype = 'b' AND parent.typelem != 0 +UNION ALL + SELECT parent.oid, parent.typbasetype + FROM pg_type parent + WHERE parent.typtypmod = -1 AND parent.typbasetype != 0 +UNION ALL + SELECT pg_type.oid, atttypid + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 +), +-- Now construct a recursive query which includes a 'depth' element. +-- This is used to ensure that the "youngest" children are registered before +-- their parents. +relationships(parent, child, depth) AS ( + SELECT DISTINCT 0::OID, selected_types.oid, 0 + FROM selected_types +UNION ALL + SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1 + FROM selected_classes c + inner join pg_type ON (c.reltype = pg_type.oid) + inner join pg_attribute on (c.oid = pg_attribute.attrelid) +UNION ALL + SELECT pc.parent, pc.child, relationships.depth + 1 + FROM pc + INNER JOIN relationships ON (pc.parent = relationships.child) +), +-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration +composite AS ( + SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 + GROUP BY pg_type.oid +) +-- Bring together this information, showing all the information which might possibly be required +-- to complete the registration, applying filters to only show the items which relate to the selected +-- types/classes. 
+SELECT typname, + pg_namespace.nspname, + typtype, + typbasetype, + typelem, + pg_type.oid,`) + if supportsMultirange { + parts = append(parts, ` + COALESCE(multirange.rngtypid, 0) AS rngtypid,`) + } else { + parts = append(parts, ` + 0 AS rngtypid,`) + } + parts = append(parts, ` + COALESCE(pg_range.rngsubtype, 0) AS rngsubtype, + attnames, atttypids + FROM relationships + INNER JOIN pg_type ON (pg_type.oid = relationships.child) + LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`) + if supportsMultirange { + parts = append(parts, ` + LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`) + } + + parts = append(parts, ` + LEFT OUTER JOIN composite USING (oid) + LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid) + WHERE NOT (typtype = 'b' AND typelem = 0)`) + parts = append(parts, ` + GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`) + if supportsMultirange { + parts = append(parts, ` + multirange.rngtypid,`) + } + parts = append(parts, ` + attnames, atttypids + ORDER BY MAX(depth) desc, typname;`) + return strings.Join(parts, "") +} + +type derivedTypeInfo struct { + Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32 + TypeName, Typtype, NspName string + Attnames []string + Atttypids []uint32 +} + +// LoadTypes performs a single (complex) query, returning all the required +// information to register the named types, as well as any other types directly +// or indirectly required to complete the registration. +// The result of this call can be passed into RegisterTypes to complete the process. +func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) { + m := c.TypeMap() + if len(typeNames) == 0 { + return nil, fmt.Errorf("No type names were supplied.") + } + + // Disregard server version errors. 
This will result in + // the SQL not support recent structures such as multirange + serverVersion, _ := serverVersion(c) + sql := buildLoadDerivedTypesSQL(serverVersion, typeNames) + rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) + if err != nil { + return nil, fmt.Errorf("While generating load types query: %w", err) + } + defer rows.Close() + result := make([]*pgtype.Type, 0, 100) + for rows.Next() { + ti := derivedTypeInfo{} + err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids) + if err != nil { + return nil, fmt.Errorf("While scanning type information: %w", err) + } + var type_ *pgtype.Type + switch ti.Typtype { + case "b": // array + dt, ok := m.TypeForOID(ti.Typelem) + if !ok { + return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName) + } + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}} + case "c": // composite + var fields []pgtype.CompositeCodecField + for i, fieldName := range ti.Attnames { + dt, ok := m.TypeForOID(ti.Atttypids[i]) + if !ok { + return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i]) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}} + case "d": // domain + dt, ok := m.TypeForOID(ti.Typbasetype) + if !ok { + return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec} + case "e": // enum + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}} + case "r": // range + dt, ok := m.TypeForOID(ti.Rngsubtype) + if !ok { + return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}} + case "m": // multirange + dt, ok := m.TypeForOID(ti.Rngtypid) + if !ok { + return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}} + default: + return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName) + } + + // the type_ is imposible to be null + m.RegisterType(type_) + if ti.NspName != "" { + nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec} + m.RegisterType(nspType) + result = append(result, nspType) + } + result = append(result, type_) + } + return result, nil +} + +// serverVersion returns the postgresql server version. 
+func serverVersion(c *Conn) (int64, error) { + serverVersionStr := c.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr) + } + + version, err := strconv.ParseInt(serverVersionStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("postgres version parsing failed: %w", err) + } + return version, nil +} diff --git a/vendor/github.com/jackc/pgx/v5/doc.go b/vendor/github.com/jackc/pgx/v5/doc.go index bc0391dde..5d2ae3889 100644 --- a/vendor/github.com/jackc/pgx/v5/doc.go +++ b/vendor/github.com/jackc/pgx/v5/doc.go @@ -175,7 +175,7 @@ notification is received or the context is canceled. Tracing and Logging -pgx supports tracing by setting ConnConfig.Tracer. +pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer. In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer. @@ -183,7 +183,7 @@ For debug tracing of the actual PostgreSQL wire protocol messages see github.com Lower Level PostgreSQL Functionality -github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in +github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer. PgBouncer diff --git a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go index 89e0c2273..abc41f657 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go +++ b/vendor/github.com/jackc/pgx/v5/internal/iobufpool/iobufpool.go @@ -4,7 +4,10 @@ // an allocation is purposely not documented. https://github.com/golang/go/issues/16323 package iobufpool -import "sync" +import ( + "math/bits" + "sync" +) const minPoolExpOf2 = 8 @@ -37,15 +40,14 @@ func Get(size int) *[]byte { } func getPoolIdx(size int) int { - size-- - size >>= minPoolExpOf2 - i := 0 - for size > 0 { - size >>= 1 - i++ + if size < 2 { + return 0 } - - return i + idx := bits.Len(uint(size-1)) - minPoolExpOf2 + if idx < 0 { + return 0 + } + return idx } // Put returns buf to the pool. @@ -59,12 +61,18 @@ func Put(buf *[]byte) { } func putPoolIdx(size int) int { - minPoolSize := 1 << minPoolExpOf2 - for i := range pools { - if size == minPoolSize<= len(pools) { + return -1 } - return -1 + return idx } diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh new file mode 100644 index 000000000..ec0f7b03a --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/benchmmark.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +current_branch=$(git rev-parse --abbrev-ref HEAD) +if [ "$current_branch" == "HEAD" ]; then + current_branch=$(git rev-parse HEAD) +fi + +restore_branch() { + echo "Restoring original branch/commit: $current_branch" + git checkout "$current_branch" +} +trap restore_branch EXIT + +# Check if there are uncommitted changes +if ! git diff --quiet || ! git diff --cached --quiet; then + echo "There are uncommitted changes. Please commit or stash them before running this script." + exit 1 +fi + +# Ensure that at least one commit argument is passed +if [ "$#" -lt 1 ]; then + echo "Usage: $0 ... 
" + exit 1 +fi + +commits=("$@") +benchmarks_dir=benchmarks + +if ! mkdir -p "${benchmarks_dir}"; then + echo "Unable to create dir for benchmarks data" + exit 1 +fi + +# Benchmark results +bench_files=() + +# Run benchmark for each listed commit +for i in "${!commits[@]}"; do + commit="${commits[i]}" + git checkout "$commit" || { + echo "Failed to checkout $commit" + exit 1 + } + + # Sanitized commmit message + commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') + + # Benchmark data will go there + bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" + + if ! go test -bench=. -count=10 >"$bench_file"; then + echo "Benchmarking failed for commit $commit" + exit 1 + fi + + bench_files+=("$bench_file") +done + +# go install golang.org/x/perf/cmd/benchstat[@latest] +benchstat "${bench_files[@]}" diff --git a/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go b/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go index df58c4484..b516817cb 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go +++ b/vendor/github.com/jackc/pgx/v5/internal/sanitize/sanitize.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/hex" "fmt" + "slices" "strconv" "strings" + "sync" "time" "unicode/utf8" ) @@ -24,18 +26,33 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 +const maxBufSize = 16384 // 16 Ki + +var bufPool = &pool[*bytes.Buffer]{ + new: func() *bytes.Buffer { + return &bytes.Buffer{} + }, + reset: func(b *bytes.Buffer) bool { + n := b.Len() + b.Reset() + return n < maxBufSize + }, +} + +var null = []byte("null") + func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) - buf := &bytes.Buffer{} + buf := bufPool.get() + defer bufPool.put(buf) for _, part := range q.Parts { - var str string switch part := part.(type) { case string: - str = part + buf.WriteString(part) case int: argIdx := part - 1 - + var p []byte if argIdx < 0 { return "", fmt.Errorf("first sql argument must be > 0") } @@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) { if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + buf.WriteByte(' ') + arg := args[argIdx] switch arg := arg.(type) { case nil: - str = "null" + p = null case int64: - str = strconv.FormatInt(arg, 10) + p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) case float64: - str = strconv.FormatFloat(arg, 'f', -1, 64) + p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) case bool: - str = strconv.FormatBool(arg) + p = strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - str = QuoteBytes(arg) + p = QuoteBytes(buf.AvailableBuffer(), arg) case string: - str = QuoteString(arg) + p = QuoteString(buf.AvailableBuffer(), arg) case time.Time: - str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond). 
+ AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true + buf.Write(p) + // Prevent SQL injection via Line Comment Creation // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - str = " " + str + " " + buf.WriteByte(' ') default: return "", fmt.Errorf("invalid Part type: %T", part) } - buf.WriteString(str) } for i, used := range argUse { @@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) { } func NewQuery(sql string) (*Query, error) { - l := &sqlLexer{ - src: sql, - stateFn: rawState, + query := &Query{} + query.init(sql) + + return query, nil +} + +var sqlLexerPool = &pool[*sqlLexer]{ + new: func() *sqlLexer { + return &sqlLexer{} + }, + reset: func(sl *sqlLexer) bool { + *sl = sqlLexer{} + return true + }, +} + +func (q *Query) init(sql string) { + parts := q.Parts[:0] + if parts == nil { + // dirty, but fast heuristic to preallocate for ~90% usecases + n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 + parts = make([]Part, 0, n) } + l := sqlLexerPool.get() + defer sqlLexerPool.put(l) + + l.src = sql + l.stateFn = rawState + l.parts = parts + for l.stateFn != nil { l.stateFn = l.stateFn(l) } - query := &Query{Parts: l.parts} - - return query, nil + q.Parts = l.parts } -func QuoteString(str string) string { - return "'" + strings.ReplaceAll(str, "'", "''") + "'" +func QuoteString(dst []byte, str string) []byte { + const quote = '\'' + + // Preallocate space for the worst case scenario + dst = slices.Grow(dst, len(str)*2+2) + + // Add opening quote + dst = append(dst, quote) + + // Iterate through the string without allocating + for i := 0; i < len(str); i++ { + if str[i] == quote { + dst = append(dst, quote, quote) + } else { + dst = append(dst, str[i]) + } + } + + // Add closing quote + dst = append(dst, quote) + + return dst } -func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" +func QuoteBytes(dst, buf []byte) []byte { + if len(buf) == 0 { + return append(dst, `'\x'`...) + } + + // Calculate required length + requiredLen := 3 + hex.EncodedLen(len(buf)) + 1 + + // Ensure dst has enough capacity + if cap(dst)-len(dst) < requiredLen { + newDst := make([]byte, len(dst), len(dst)+requiredLen) + copy(newDst, dst) + dst = newDst + } + + // Record original length and extend slice + origLen := len(dst) + dst = dst[:origLen+requiredLen] + + // Add prefix + dst[origLen] = '\'' + dst[origLen+1] = '\\' + dst[origLen+2] = 'x' + + // Encode bytes directly into dst + hex.Encode(dst[origLen+3:len(dst)-1], buf) + + // Add suffix + dst[len(dst)-1] = '\'' + + return dst } type sqlLexer struct { @@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn { } } +var queryPool = &pool[*Query]{ + new: func() *Query { + return &Query{} + }, + reset: func(q *Query) bool { + n := len(q.Parts) + q.Parts = q.Parts[:0] + return n < 64 // drop too large queries + }, +} + // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. func SanitizeSQL(sql string, args ...any) (string, error) { - query, err := NewQuery(sql) - if err != nil { - return "", err - } + query := queryPool.get() + query.init(sql) + defer queryPool.put(query) + return query.Sanitize(args...) 
} + +type pool[E any] struct { + p sync.Pool + new func() E + reset func(E) bool +} + +func (pool *pool[E]) get() E { + v, ok := pool.p.Get().(E) + if !ok { + v = pool.new() + } + + return v +} + +func (p *pool[E]) put(v E) { + if p.reset(v) { + p.p.Put(v) + } +} diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go index dec83f47b..52b479ad7 100644 --- a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go +++ b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/lru_cache.go @@ -12,14 +12,16 @@ type LRUCache struct { m map[string]*list.Element l *list.List invalidStmts []*pgconn.StatementDescription + invalidSet map[string]struct{} } // NewLRUCache creates a new LRUCache. cap is the maximum size of the cache. func NewLRUCache(cap int) *LRUCache { return &LRUCache{ - cap: cap, - m: make(map[string]*list.Element), - l: list.New(), + cap: cap, + m: make(map[string]*list.Element), + l: list.New(), + invalidSet: make(map[string]struct{}), } } @@ -31,7 +33,6 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription { } return nil - } // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or @@ -46,10 +47,8 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) { } // The statement may have been invalidated but not yet handled. Do not readd it to the cache. - for _, invalidSD := range c.invalidStmts { - if invalidSD.SQL == sd.SQL { - return - } + if _, invalidated := c.invalidSet[sd.SQL]; invalidated { + return } if c.l.Len() == c.cap { @@ -64,7 +63,9 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) { func (c *LRUCache) Invalidate(sql string) { if el, ok := c.m[sql]; ok { delete(c.m, sql) - c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + sd := el.Value.(*pgconn.StatementDescription) + c.invalidStmts = append(c.invalidStmts, sd) + c.invalidSet[sql] = struct{}{} c.l.Remove(el) } } @@ -73,7 +74,9 @@ func (c *LRUCache) Invalidate(sql string) { func (c *LRUCache) InvalidateAll() { el := c.l.Front() for el != nil { - c.invalidStmts = append(c.invalidStmts, el.Value.(*pgconn.StatementDescription)) + sd := el.Value.(*pgconn.StatementDescription) + c.invalidStmts = append(c.invalidStmts, sd) + c.invalidSet[sd.SQL] = struct{}{} el = el.Next() } @@ -91,6 +94,7 @@ func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription { // never seen by the call to GetInvalidated. func (c *LRUCache) RemoveInvalidated() { c.invalidStmts = nil + c.invalidSet = make(map[string]struct{}) } // Len returns the number of cached prepared statement descriptions. @@ -107,6 +111,7 @@ func (c *LRUCache) invalidateOldest() { oldest := c.l.Back() sd := oldest.Value.(*pgconn.StatementDescription) c.invalidStmts = append(c.invalidStmts, sd) + c.invalidSet[sd.SQL] = struct{}{} delete(c.m, sd.SQL) c.l.Remove(oldest) } diff --git a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go b/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go deleted file mode 100644 index 696413291..000000000 --- a/vendor/github.com/jackc/pgx/v5/internal/stmtcache/unlimited_cache.go +++ /dev/null @@ -1,77 +0,0 @@ -package stmtcache - -import ( - "math" - - "github.com/jackc/pgx/v5/pgconn" -) - -// UnlimitedCache implements Cache with no capacity limit. 
-type UnlimitedCache struct { - m map[string]*pgconn.StatementDescription - invalidStmts []*pgconn.StatementDescription -} - -// NewUnlimitedCache creates a new UnlimitedCache. -func NewUnlimitedCache() *UnlimitedCache { - return &UnlimitedCache{ - m: make(map[string]*pgconn.StatementDescription), - } -} - -// Get returns the statement description for sql. Returns nil if not found. -func (c *UnlimitedCache) Get(sql string) *pgconn.StatementDescription { - return c.m[sql] -} - -// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache. -func (c *UnlimitedCache) Put(sd *pgconn.StatementDescription) { - if sd.SQL == "" { - panic("cannot store statement description with empty SQL") - } - - if _, present := c.m[sd.SQL]; present { - return - } - - c.m[sd.SQL] = sd -} - -// Invalidate invalidates statement description identified by sql. Does nothing if not found. -func (c *UnlimitedCache) Invalidate(sql string) { - if sd, ok := c.m[sql]; ok { - delete(c.m, sql) - c.invalidStmts = append(c.invalidStmts, sd) - } -} - -// InvalidateAll invalidates all statement descriptions. -func (c *UnlimitedCache) InvalidateAll() { - for _, sd := range c.m { - c.invalidStmts = append(c.invalidStmts, sd) - } - - c.m = make(map[string]*pgconn.StatementDescription) -} - -// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated. -func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription { - return c.invalidStmts -} - -// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a -// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were -// never seen by the call to GetInvalidated. -func (c *UnlimitedCache) RemoveInvalidated() { - c.invalidStmts = nil -} - -// Len returns the number of cached prepared statement descriptions. -func (c *UnlimitedCache) Len() int { - return len(c.m) -} - -// Cap returns the maximum number of cached prepared statement descriptions. -func (c *UnlimitedCache) Cap() int { - return math.MaxInt -} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go index 064983615..9979087a0 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/auth_scram.go @@ -15,15 +15,16 @@ package pgconn import ( "bytes" "crypto/hmac" + "crypto/pbkdf2" "crypto/rand" "crypto/sha256" "encoding/base64" "errors" "fmt" + "slices" "strconv" "github.com/jackc/pgx/v5/pgproto3" - "golang.org/x/crypto/pbkdf2" "golang.org/x/text/secure/precis" ) @@ -107,7 +108,7 @@ func (c *PgConn) rxSASLFinal() (*pgproto3.AuthenticationSASLFinal, error) { type scramClient struct { serverAuthMechanisms []string - password []byte + password string clientNonce []byte clientFirstMessageBare []byte @@ -127,23 +128,17 @@ func newScramClient(serverAuthMechanisms []string, password string) (*scramClien } // Ensure server supports SCRAM-SHA-256 - hasScramSHA256 := false - for _, mech := range sc.serverAuthMechanisms { - if mech == "SCRAM-SHA-256" { - hasScramSHA256 = true - break - } - } + hasScramSHA256 := slices.Contains(sc.serverAuthMechanisms, "SCRAM-SHA-256") if !hasScramSHA256 { return nil, errors.New("server does not support SCRAM-SHA-256") } // precis.OpaqueString is equivalent to SASLprep for password. 
var err error - sc.password, err = precis.OpaqueString.Bytes([]byte(password)) + sc.password, err = precis.OpaqueString.String(password) if err != nil { // PostgreSQL allows passwords invalid according to SCRAM / SASLprep. - sc.password = []byte(password) + sc.password = password } buf := make([]byte, clientNonceLen) @@ -158,8 +153,8 @@ func newScramClient(serverAuthMechanisms []string, password string) (*scramClien } func (sc *scramClient) clientFirstMessage() []byte { - sc.clientFirstMessageBare = []byte(fmt.Sprintf("n=,r=%s", sc.clientNonce)) - return []byte(fmt.Sprintf("n,,%s", sc.clientFirstMessageBare)) + sc.clientFirstMessageBare = fmt.Appendf(nil, "n=,r=%s", sc.clientNonce) + return fmt.Appendf(nil, "n,,%s", sc.clientFirstMessageBare) } func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { @@ -218,9 +213,13 @@ func (sc *scramClient) recvServerFirstMessage(serverFirstMessage []byte) error { } func (sc *scramClient) clientFinalMessage() string { - clientFinalMessageWithoutProof := []byte(fmt.Sprintf("c=biws,r=%s", sc.clientAndServerNonce)) + clientFinalMessageWithoutProof := fmt.Appendf(nil, "c=biws,r=%s", sc.clientAndServerNonce) - sc.saltedPassword = pbkdf2.Key([]byte(sc.password), sc.salt, sc.iterations, 32, sha256.New) + var err error + sc.saltedPassword, err = pbkdf2.Key(sha256.New, sc.password, sc.salt, sc.iterations, 32) + if err != nil { + panic(err) // This should never happen. + } sc.authMessage = bytes.Join([][]byte{sc.clientFirstMessageBare, sc.serverFirstMessage, clientFinalMessageWithoutProof}, []byte(",")) clientProof := computeClientProof(sc.saltedPassword, sc.authMessage) @@ -254,7 +253,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { clientSignature := computeHMAC(storedKey[:], authMessage) clientProof := make([]byte, len(clientSignature)) - for i := 0; i < len(clientSignature); i++ { + for i := range clientSignature { clientProof[i] = clientKey[i] ^ clientSignature[i] } @@ -263,7 +262,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { return buf } -func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { +func computeServerSignature(saltedPassword, authMessage []byte) []byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/config.go b/vendor/github.com/jackc/pgx/v5/pgconn/config.go index 598917f55..d5914aad9 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/config.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/config.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "maps" "math" "net" "net/url" @@ -23,9 +24,11 @@ import ( "github.com/jackc/pgx/v5/pgproto3" ) -type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error -type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -type GetSSLPasswordFunc func(ctx context.Context) string +type ( + AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error + ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error + GetSSLPasswordFunc func(ctx context.Context) string +) // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A // manually initialized Config will cause ConnectConfig to panic. 
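Editor's note on the internal/sanitize rewrite above: Sanitize now borrows a *bytes.Buffer from a small generic sync.Pool wrapper and appends quoted arguments into the buffer's spare capacity instead of building intermediate strings; buffers, lexers, and queries that grow past a threshold are dropped rather than returned to the pool. The sketch below reproduces that pooling pattern in isolation so the idea is easy to follow. It is not the pgx internals themselves, and the 16 KiB cap here is only an illustrative stand-in for maxBufSize.

```go
package main

import (
	"bytes"
	"fmt"
	"strconv"
	"sync"
)

// pool is a typed wrapper around sync.Pool with a reset hook that may
// refuse to keep a value, mirroring the pattern used by bufPool,
// sqlLexerPool and queryPool in internal/sanitize.
type pool[E any] struct {
	p     sync.Pool
	new   func() E
	reset func(E) bool // return false to drop the value instead of pooling it
}

func (pl *pool[E]) get() E {
	v, ok := pl.p.Get().(E)
	if !ok {
		v = pl.new()
	}
	return v
}

func (pl *pool[E]) put(v E) {
	if pl.reset(v) {
		pl.p.Put(v)
	}
}

// Illustrative cap playing the role of maxBufSize: buffers that grew
// larger than this are not returned to the pool.
const maxPooledBuf = 16 * 1024

var bufPool = &pool[*bytes.Buffer]{
	new: func() *bytes.Buffer { return &bytes.Buffer{} },
	reset: func(b *bytes.Buffer) bool {
		n := b.Len()
		b.Reset()
		return n < maxPooledBuf
	},
}

func main() {
	buf := bufPool.get()
	defer bufPool.put(buf)

	buf.WriteString("limit ")
	// Append-style formatting writes into the buffer's spare capacity,
	// the same trick Sanitize now uses for int64/float64/bool arguments.
	buf.Write(strconv.AppendInt(buf.AvailableBuffer(), 42, 10))
	fmt.Println(buf.String()) // limit 42
}
```

Relatedly, the SCRAM change above derives the salted password with the standard-library crypto/pbkdf2 package (added in Go 1.24) instead of golang.org/x/crypto/pbkdf2; the new Key takes the hash constructor first and returns an error, which is why clientFinalMessage now carries a panic-on-error guard.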
@@ -51,6 +54,15 @@ type Config struct { KerberosSpn string Fallbacks []*FallbackConfig + SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct + + // AfterNetConnect is called after the network connection, including TLS if applicable, is established but before any + // PostgreSQL protocol communication. It takes the established net.Conn and returns a net.Conn that will be used in + // its place. It can be used to wrap the net.Conn (e.g. for logging, diagnostics, or testing). Its functionality has + // some overlap with DialFunc. However, DialFunc takes place before TLS is established and cannot be used to control + // the final net.Conn used for PostgreSQL protocol communication while AfterNetConnect can. + AfterNetConnect func(ctx context.Context, config *Config, conn net.Conn) (net.Conn, error) + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. @@ -92,9 +104,7 @@ func (c *Config) Copy() *Config { } if newConf.RuntimeParams != nil { newConf.RuntimeParams = make(map[string]string, len(c.RuntimeParams)) - for k, v := range c.RuntimeParams { - newConf.RuntimeParams[k] = v - } + maps.Copy(newConf.RuntimeParams, c.RuntimeParams) } if newConf.Fallbacks != nil { newConf.Fallbacks = make([]*FallbackConfig, len(c.Fallbacks)) @@ -177,7 +187,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // values that will be tried in order. This can be used as part of a high availability system. See -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // // # Example URL // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb @@ -198,13 +208,15 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGSSLKEY // PGSSLROOTCERT // PGSSLPASSWORD +// PGOPTIONS // PGAPPNAME // PGCONNECT_TIMEOUT // PGTARGETSESSIONATTRS +// PGTZ // -// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. +// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables. // -// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are +// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are // usually but not always the environment variable name downcased and without the "PG" prefix. // // Important Security Notes: @@ -212,7 +224,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if // not set. // -// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// See http://www.postgresql.org/docs/current/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // security each sslmode provides. 
// // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of @@ -318,6 +330,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "sslnegotiation": {}, "sslpassword": {}, "sslsni": {}, "krbspn": {}, @@ -386,6 +399,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con config.Port = fallbacks[0].Port config.TLSConfig = fallbacks[0].TLSConfig config.Fallbacks = fallbacks[1:] + config.SSLNegotiation = settings["sslnegotiation"] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { @@ -423,9 +437,7 @@ func mergeSettings(settingSets ...map[string]string) map[string]string { settings := make(map[string]string) for _, s2 := range settingSets { - for k, v := range s2 { - settings[k] = v - } + maps.Copy(settings, s2) } return settings @@ -449,9 +461,12 @@ func parseEnvSettings() map[string]string { "PGSSLSNI": "sslsni", "PGSSLROOTCERT": "sslrootcert", "PGSSLPASSWORD": "sslpassword", + "PGSSLNEGOTIATION": "sslnegotiation", "PGTARGETSESSIONATTRS": "target_session_attrs", "PGSERVICE": "service", "PGSERVICEFILE": "servicefile", + "PGTZ": "timezone", + "PGOPTIONS": "options", } for envname, realname := range nameMap { @@ -467,14 +482,17 @@ func parseEnvSettings() map[string]string { func parseURLSettings(connString string) (map[string]string, error) { settings := make(map[string]string) - url, err := url.Parse(connString) + parsedURL, err := url.Parse(connString) if err != nil { + if urlErr := new(url.Error); errors.As(err, &urlErr) { + return nil, urlErr.Err + } return nil, err } - if url.User != nil { - settings["user"] = url.User.Username() - if password, present := url.User.Password(); present { + if parsedURL.User != nil { + settings["user"] = parsedURL.User.Username() + if password, present := parsedURL.User.Password(); present { settings["password"] = password } } @@ -482,7 +500,7 @@ func parseURLSettings(connString string) (map[string]string, error) { // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. 
var hosts []string var ports []string - for _, host := range strings.Split(url.Host, ",") { + for host := range strings.SplitSeq(parsedURL.Host, ",") { if host == "" { continue } @@ -508,7 +526,7 @@ func parseURLSettings(connString string) (map[string]string, error) { settings["port"] = strings.Join(ports, ",") } - database := strings.TrimLeft(url.Path, "/") + database := strings.TrimLeft(parsedURL.Path, "/") if database != "" { settings["database"] = database } @@ -517,7 +535,7 @@ func parseURLSettings(connString string) (map[string]string, error) { "dbname": "database", } - for k, v := range url.Query() { + for k, v := range parsedURL.Query() { if k2, present := nameMap[k]; present { k = k2 } @@ -643,6 +661,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P sslkey := settings["sslkey"] sslpassword := settings["sslpassword"] sslsni := settings["sslsni"] + sslnegotiation := settings["sslnegotiation"] // Match libpq default behavior if sslmode == "" { @@ -654,6 +673,43 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P tlsConfig := &tls.Config{} + if sslnegotiation == "direct" { + tlsConfig.NextProtos = []string{"postgresql"} + if sslmode == "prefer" { + sslmode = "require" + } + } + + if sslrootcert != "" { + var caCertPool *x509.CertPool + + if sslrootcert == "system" { + var err error + + caCertPool, err = x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("unable to load system certificate pool: %w", err) + } + + sslmode = "verify-full" + } else { + caCertPool = x509.NewCertPool() + + caPath := sslrootcert + caCert, err := os.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("unable to read CA file: %w", err) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.New("unable to add CA to cert pool") + } + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + switch sslmode { case "disable": return []*tls.Config{nil}, nil @@ -663,7 +719,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P // According to PostgreSQL documentation, if a root CA file exists, // the behavior of sslmode=require should be the same as that of verify-ca // - // See https://www.postgresql.org/docs/12/libpq-ssl.html + // See https://www.postgresql.org/docs/current/libpq-ssl.html if sslrootcert != "" { goto nextCase } @@ -711,23 +767,6 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P return nil, errors.New("sslmode is invalid") } - if sslrootcert != "" { - caCertPool := x509.NewCertPool() - - caPath := sslrootcert - caCert, err := os.ReadFile(caPath) - if err != nil { - return nil, fmt.Errorf("unable to read CA file: %w", err) - } - - if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, errors.New("unable to add CA to cert pool") - } - - tlsConfig.RootCAs = caCertPool - tlsConfig.ClientCAs = caCertPool - } - if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { return nil, errors.New(`both "sslcert" and "sslkey" are required`) } @@ -751,8 +790,8 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P if sslpassword != "" { decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) } - //if sslpassword not provided or has decryption error when use it - //try to find sslpassword with callback function + // if sslpassword not provided or has decryption error when use it + // try to find sslpassword with callback function if sslpassword == "" || 
decryptedError != nil { if parseConfigOptions.GetSSLPassword != nil { sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) @@ -845,12 +884,12 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { // ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) == "on" { + if string(result[0].Rows[0][0]) == "on" { return errors.New("read only connection") } @@ -860,12 +899,12 @@ func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgC // ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-only. func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "on" { + if string(result[0].Rows[0][0]) != "on" { return errors.New("connection is not read only") } @@ -875,12 +914,12 @@ func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgCo // ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=standby. func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "t" { + if string(result[0].Rows[0][0]) != "t" { return errors.New("server is not in hot standby mode") } @@ -890,12 +929,12 @@ func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgCon // ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=primary. func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) == "t" { + if string(result[0].Rows[0][0]) == "t" { return errors.New("server is in standby mode") } @@ -905,12 +944,12 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon // ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=prefer-standby. 
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "t" { + if string(result[0].Rows[0][0]) != "t" { return &NotPreferredError{err: errors.New("server is not in hot standby mode")} } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go index ec4a6d47c..bc1e31e31 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/errors.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/errors.go @@ -27,7 +27,7 @@ func Timeout(err error) bool { } // PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// http://www.postgresql.org/docs/current/static/protocol-error-fields.html for // detailed field description. type PgError struct { Severity string @@ -112,6 +112,14 @@ type ParseConfigError struct { err error } +func NewParseConfigError(conn, msg string, err error) error { + return &ParseConfigError{ + ConnString: conn, + msg: msg, + err: err, + } +} + func (e *ParseConfigError) Error() string { // Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only // return a static string. That would ensure that the error message cannot leak a password. The ConnString field would @@ -246,3 +254,20 @@ func (e *NotPreferredError) SafeToRetry() bool { func (e *NotPreferredError) Unwrap() error { return e.err } + +type PrepareError struct { + err error + + ParseComplete bool // Indicates whether the error occurred after a ParseComplete message was received. +} + +func (e *PrepareError) Error() string { + if e.ParseComplete { + return fmt.Sprintf("prepare failed after ParseComplete: %s", e.err.Error()) + } + return e.err.Error() +} + +func (e *PrepareError) Unwrap() error { + return e.err +} diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go index 3c1af3477..efb0d61b8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/krb5.go @@ -28,7 +28,7 @@ func RegisterGSSProvider(newGSSArg NewGSSFunc) { // GSS provides GSSAPI authentication (e.g., Kerberos). type GSS interface { - GetInitToken(host string, service string) ([]byte, error) + GetInitToken(host, service string) ([]byte, error) GetInitTokenFromSPN(spn string) ([]byte, error) Continue(inToken []byte) (done bool, outToken []byte, err error) } diff --git a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go index 7efb522a4..081e20578 100644 --- a/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go +++ b/vendor/github.com/jackc/pgx/v5/pgconn/pgconn.go @@ -1,6 +1,7 @@ package pgconn import ( + "container/list" "context" "crypto/md5" "crypto/tls" @@ -9,6 +10,7 @@ import ( "errors" "fmt" "io" + "maps" "math" "net" "strconv" @@ -134,7 +136,7 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. 
An // authentication error will terminate the chain of attempts (like libpq: -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) { // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // zero values. @@ -267,12 +269,15 @@ func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []* var pgErr *PgError if errors.As(err, &pgErr) { - const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password - const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings - const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist - const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + // pgx will try next host even if libpq does not in certain cases (see #2246) + // consider change for the next major version + + const ERRCODE_INVALID_PASSWORD = "28P01" + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + + // auth failed due to invalid password, db does not exist or user has no permission if pgErr.Code == ERRCODE_INVALID_PASSWORD || - pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil || pgErr.Code == ERRCODE_INVALID_CATALOG_NAME || pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { return nil, allErrors @@ -321,7 +326,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo if connectConfig.tlsConfig != nil { pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) pgConn.contextWatcher.Watch(ctx) - tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig) + var ( + tlsConn net.Conn + err error + ) + if config.SSLNegotiation == "direct" { + tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig) + } else { + tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig) + } pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. if err != nil { pgConn.conn.Close() @@ -331,6 +344,14 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo pgConn.conn = tlsConn } + if config.AfterNetConnect != nil { + pgConn.conn, err = config.AfterNetConnect(ctx, config, pgConn.conn) + if err != nil { + pgConn.conn.Close() + return nil, newPerDialConnectError("AfterNetConnect failed", err) + } + } + pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() @@ -354,9 +375,7 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo } // Copy default run-time params - for k, v := range config.RuntimeParams { - startupMsg.Parameters[k] = v - } + maps.Copy(startupMsg.Parameters, config.RuntimeParams) startupMsg.Parameters["user"] = config.User if config.Database != "" { @@ -846,6 +865,10 @@ type StatementDescription struct { // // Prepare does not send a PREPARE statement to the server. It uses the PostgreSQL Parse and Describe protocol messages // directly. +// +// In extremely rare cases, Prepare may fail after the Parse is successful, but before the Describe is complete. 
In this +// case, the returned error will be an error where errors.As with a *PrepareError succeeds and the *PrepareError has +// ParseComplete set to true. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*StatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, err @@ -873,7 +896,8 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [ psd := &StatementDescription{Name: name, SQL: sql} - var parseErr error + var ParseComplete bool + var pgErr *PgError readloop: for { @@ -884,20 +908,22 @@ readloop: } switch msg := msg.(type) { + case *pgproto3.ParseComplete: + ParseComplete = true case *pgproto3.ParameterDescription: psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: psd.Fields = pgConn.convertRowDescription(nil, msg) case *pgproto3.ErrorResponse: - parseErr = ErrorResponseToPgError(msg) + pgErr = ErrorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } } - if parseErr != nil { - return nil, parseErr + if pgErr != nil { + return nil, &PrepareError{err: pgErr, ParseComplete: ParseComplete} } return psd, nil } @@ -979,7 +1005,8 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there -// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +// is no way to be sure a query was canceled. +// See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-CANCELING-REQUESTS func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be @@ -1128,7 +1155,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { // binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result @@ -1154,7 +1181,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. 
-func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result @@ -1361,7 +1388,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co close(pgConn.cleanupDone) return CommandTag{}, normalizeTimeoutError(ctx, err) } - msg, _ := pgConn.receiveMessage() + // peekMessage never returns err in the bufferingReceive mode - it only forwards the bufferingReceive variables. + // Therefore, the only case for receiveMessage to return err is during handling of the ErrorResponse message type + // and using pgOnError handler to determine the connection is no longer valid (and thus closing the conn). + msg, serverError := pgConn.receiveMessage() + if serverError != nil { + close(abortCopyChan) + return CommandTag{}, serverError + } switch msg := msg.(type) { case *pgproto3.ErrorResponse: @@ -1408,9 +1442,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - pipeline *Pipeline + pgConn *PgConn + ctx context.Context rr *ResultReader @@ -1443,12 +1476,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: mrr.closed = true - if mrr.pipeline != nil { - mrr.pipeline.expectedReadyForQueryCount-- - } else { - mrr.pgConn.contextWatcher.Unwatch() - mrr.pgConn.unlock() - } + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = ErrorResponseToPgError(msg) } @@ -1484,7 +1513,12 @@ func (mrr *MultiResultReader) NextResult() bool { mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.EmptyQueryResponse: - return false + mrr.pgConn.resultReader = ResultReader{ + commandConcluded: true, + closed: true, + } + mrr.rr = &mrr.pgConn.resultReader + return true } } @@ -1672,7 +1706,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.EmptyQueryResponse: rr.concludeCommand(CommandTag{}, nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg)) + pgErr := ErrorResponseToPgError(msg) + if rr.pipeline != nil { + rr.pipeline.state.HandleError(pgErr) + } + rr.concludeCommand(CommandTag{}, pgErr) } return msg, nil @@ -1701,7 +1739,7 @@ type Batch struct { } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. -func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { if batch.err != nil { return } @@ -1714,7 +1752,7 @@ func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uin } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. 
-func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { if batch.err != nil { return } @@ -1773,9 +1811,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) if batch.err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err) multiResult.closed = true - multiResult.err = batch.err - pgConn.unlock() + pgConn.asyncClose() return multiResult } @@ -1783,9 +1822,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR defer pgConn.exitPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) if err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, err) multiResult.closed = true - multiResult.err = err - pgConn.unlock() + pgConn.asyncClose() return multiResult } @@ -1886,7 +1926,7 @@ func (pgConn *PgConn) flushWithPotentialWriteReadDeadlock() error { // // This should not be confused with the PostgreSQL protocol Sync message. func (pgConn *PgConn) SyncConn(ctx context.Context) error { - for i := 0; i < 10; i++ { + for range 10 { if pgConn.bgReader.Status() == bgreader.StatusStopped && pgConn.frontend.ReadBufferLen() == 0 { return nil } @@ -1999,9 +2039,7 @@ type Pipeline struct { conn *PgConn ctx context.Context - expectedReadyForQueryCount int - pendingSync bool - + state pipelineState err error closed bool } @@ -2012,6 +2050,122 @@ type PipelineSync struct{} // CloseComplete is returned by GetResults when a CloseComplete message is received. 
type CloseComplete struct{} +type pipelineRequestType int + +const ( + pipelineNil pipelineRequestType = iota + pipelinePrepare + pipelineQueryParams + pipelineQueryPrepared + pipelineDeallocate + pipelineSyncRequest + pipelineFlushRequest +) + +type pipelineRequestEvent struct { + RequestType pipelineRequestType + WasSentToServer bool + BeforeFlushOrSync bool +} + +type pipelineState struct { + requestEventQueue list.List + lastRequestType pipelineRequestType + pgErr *PgError + expectedReadyForQueryCount int +} + +func (s *pipelineState) Init() { + s.requestEventQueue.Init() + s.lastRequestType = pipelineNil +} + +func (s *pipelineState) RegisterSendingToServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.WasSentToServer { + return + } + val.WasSentToServer = true + elem.Value = val + } +} + +func (s *pipelineState) registerFlushingBufferOnServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.BeforeFlushOrSync { + return + } + val.BeforeFlushOrSync = true + elem.Value = val + } +} + +func (s *pipelineState) PushBackRequestType(req pipelineRequestType) { + if req == pipelineNil { + return + } + + if req != pipelineFlushRequest { + s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req}) + } + if req == pipelineFlushRequest || req == pipelineSyncRequest { + s.registerFlushingBufferOnServer() + } + s.lastRequestType = req + + if req == pipelineSyncRequest { + s.expectedReadyForQueryCount++ + } +} + +func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { + for { + elem := s.requestEventQueue.Front() + if elem == nil { + return pipelineNil + } + val := elem.Value.(pipelineRequestEvent) + if !(val.WasSentToServer && val.BeforeFlushOrSync) { + return pipelineNil + } + + s.requestEventQueue.Remove(elem) + if val.RequestType == pipelineSyncRequest { + s.pgErr = nil + } + if s.pgErr == nil { + return val.RequestType + } + } +} + +func (s *pipelineState) HandleError(err *PgError) { + s.pgErr = err +} + +func (s *pipelineState) HandleReadyForQuery() { + s.expectedReadyForQueryCount-- +} + +func (s *pipelineState) PendingSync() bool { + var notPendingSync bool + + if elem := s.requestEventQueue.Back(); elem != nil { + val := elem.Value.(pipelineRequestEvent) + notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer + } else { + notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil) + } + + return !notPendingSync +} + +func (s *pipelineState) ExpectedReadyForQuery() int { + return s.expectedReadyForQueryCount +} + // StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent // to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection // to normal mode. While in pipeline mode, no methods that communicate with the server may be called except @@ -2020,16 +2174,21 @@ type CloseComplete struct{} // Prefer ExecBatch when only sending one group of queries at once. 
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { if err := pgConn.lock(); err != nil { - return &Pipeline{ + pipeline := &Pipeline{ closed: true, err: err, } + pipeline.state.Init() + + return pipeline } pgConn.pipeline = Pipeline{ conn: pgConn, ctx: ctx, } + pgConn.pipeline.state.Init() + pipeline := &pgConn.pipeline if ctx != context.Background() { @@ -2052,10 +2211,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(pipelinePrepare) } // SendDeallocate deallocates a prepared statement. @@ -2063,34 +2222,65 @@ func (p *Pipeline) SendDeallocate(name string) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(pipelineDeallocate) } // SendQueryParams is the pipeline version of *PgConn.QueryParams. -func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { +func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryParams) } // SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. -func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { +func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryPrepared) +} + +// SendFlushRequest sends a request for the server to flush its output buffer. +// +// The server flushes its output buffer automatically as a result of Sync being called, +// or on any request when not in pipeline mode; this function is useful to cause the server +// to flush its output buffer in pipeline mode without establishing a synchronization point. +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendFlushRequest. +func (p *Pipeline) SendFlushRequest() { + if p.closed { + return + } + + p.conn.frontend.Send(&pgproto3.Flush{}) + p.state.PushBackRequestType(pipelineFlushRequest) +} + +// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message +// without flushing the send buffer. This serves as the delimiter of an implicit +// transaction and an error recovery point. +// +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. 
This copies the behavior of libpq PQsendPipelineSync. +func (p *Pipeline) SendPipelineSync() { + if p.closed { + return + } + + p.conn.frontend.SendSync(&pgproto3.Sync{}) + p.state.PushBackRequestType(pipelineSyncRequest) } // Flush flushes the queued requests without establishing a synchronization point. @@ -2115,28 +2305,14 @@ func (p *Pipeline) Flush() error { return err } + p.state.RegisterSendingToServer() return nil } // Sync establishes a synchronization point and flushes the queued requests. func (p *Pipeline) Sync() error { - if p.closed { - if p.err != nil { - return p.err - } - return errors.New("pipeline closed") - } - - p.conn.frontend.SendSync(&pgproto3.Sync{}) - err := p.Flush() - if err != nil { - return err - } - - p.pendingSync = false - p.expectedReadyForQueryCount++ - - return nil + p.SendPipelineSync() + return p.Flush() } // GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or @@ -2150,7 +2326,7 @@ func (p *Pipeline) GetResults() (results any, err error) { return nil, errors.New("pipeline closed") } - if p.expectedReadyForQueryCount == 0 { + if p.state.ExtractFrontRequestType() == pipelineNil { return nil, nil } @@ -2195,13 +2371,13 @@ func (p *Pipeline) getResults() (results any, err error) { case *pgproto3.CloseComplete: return &CloseComplete{}, nil case *pgproto3.ReadyForQuery: - p.expectedReadyForQueryCount-- + p.state.HandleReadyForQuery() return &PipelineSync{}, nil case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) return nil, pgErr } - } } @@ -2231,6 +2407,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { // These should never happen here. But don't take chances that could lead to a deadlock. case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) return nil, pgErr case *pgproto3.CommandComplete: p.conn.asyncClose() @@ -2250,7 +2427,7 @@ func (p *Pipeline) Close() error { p.closed = true - if p.pendingSync { + if p.state.PendingSync() { p.conn.asyncClose() p.err = errors.New("pipeline has unsynced requests") p.conn.contextWatcher.Unwatch() @@ -2259,7 +2436,7 @@ func (p *Pipeline) Close() error { return p.err } - for p.expectedReadyForQueryCount > 0 { + for p.state.ExpectedReadyForQuery() > 0 { _, err := p.getResults() if err != nil { p.err = err diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go index ac2962e9e..415e1a24a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_cleartext_password.go @@ -9,8 +9,7 @@ import ( ) // AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. -type AuthenticationCleartextPassword struct { -} +type AuthenticationCleartextPassword struct{} // Backend identifies this message as sendable by the PostgreSQL backend. 
func (*AuthenticationCleartextPassword) Backend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go index ec11d39f1..98c0b2d66 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/authentication_ok.go @@ -9,8 +9,7 @@ import ( ) // AuthenticationOk is a message sent from the backend indicating that authentication was successful. -type AuthenticationOk struct { -} +type AuthenticationOk struct{} // Backend identifies this message as sendable by the PostgreSQL backend. func (*AuthenticationOk) Backend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go index d146c3384..d9d0f370c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/backend.go @@ -46,8 +46,8 @@ type Backend struct { } const ( - minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. - maxStartupPacketLen = 10000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. + minStartupPacketLen = 4 // minStartupPacketLen is a single 32-bit int version or code. + maxStartupPacketLen = 10_000 // maxStartupPacketLen is MAX_STARTUP_PACKET_LENGTH from PG source. ) // NewBackend creates a new Backend. @@ -175,7 +175,13 @@ func (b *Backend) Receive() (FrontendMessage, error) { } b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgLength := int(binary.BigEndian.Uint32(header[1:])) + if msgLength < 4 { + return nil, fmt.Errorf("invalid message length: %d", msgLength) + } + + b.bodyLen = msgLength - 4 if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen { return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen} } @@ -282,9 +288,10 @@ func (b *Backend) SetAuthType(authType uint32) error { return nil } -// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return -// an error. This is useful for protecting against malicious clients that send large messages with the intent of -// causing memory exhaustion. +// SetMaxBodyLen sets the maximum length of a message body in octets. +// If a message body exceeds this length, Receive will return an error. +// This is useful for protecting against malicious clients that send +// large messages with the intent of causing memory exhaustion. // The default value is 0. // If maxBodyLen is 0, then no maximum is enforced. 
func (b *Backend) SetMaxBodyLen(maxBodyLen int) { diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go b/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go index ad6ac48bf..13700c39c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/bind.go @@ -54,7 +54,7 @@ func (dst *Bind) Decode(src []byte) error { if len(src[rp:]) < len(dst.ParameterFormatCodes)*2 { return &invalidMessageFormatErr{messageType: "Bind"} } - for i := 0; i < parameterFormatCodeCount; i++ { + for i := range parameterFormatCodeCount { dst.ParameterFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 } @@ -69,7 +69,7 @@ func (dst *Bind) Decode(src []byte) error { if parameterCount > 0 { dst.Parameters = make([][]byte, parameterCount) - for i := 0; i < parameterCount; i++ { + for i := range parameterCount { if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "Bind"} } @@ -101,7 +101,7 @@ func (dst *Bind) Decode(src []byte) error { if len(src[rp:]) < len(dst.ResultFormatCodes)*2 { return &invalidMessageFormatErr{messageType: "Bind"} } - for i := 0; i < resultFormatCodeCount; i++ { + for i := range resultFormatCodeCount { dst.ResultFormatCodes[i] = int16(binary.BigEndian.Uint16(src[rp:])) rp += 2 } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go index 99e1afea4..e2a402f9a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_both_response.go @@ -35,7 +35,7 @@ func (dst *CopyBothResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go index 040814dbd..c3421a9b5 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_done.go @@ -4,8 +4,7 @@ import ( "encoding/json" ) -type CopyDone struct { -} +type CopyDone struct{} // Backend identifies this message as sendable by the PostgreSQL backend. 
func (*CopyDone) Backend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go index 06cf99ced..0633935b9 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_in_response.go @@ -35,7 +35,7 @@ func (dst *CopyInResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go index 549e916c1..006864ac8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/copy_out_response.go @@ -34,7 +34,7 @@ func (dst *CopyOutResponse) Decode(src []byte) error { } columnFormatCodes := make([]uint16, columnCount) - for i := 0; i < columnCount; i++ { + for i := range columnCount { columnFormatCodes[i] = binary.BigEndian.Uint16(buf.Next(2)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go b/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go index fdfb0f7f6..54418d58c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/data_row.go @@ -31,16 +31,13 @@ func (dst *DataRow) Decode(src []byte) error { // large reallocate. This is too avoid one row with many columns from // permanently allocating memory. if cap(dst.Values) < fieldCount || cap(dst.Values)-fieldCount > 32 { - newCap := 32 - if newCap < fieldCount { - newCap = fieldCount - } + newCap := max(32, fieldCount) dst.Values = make([][]byte, fieldCount, newCap) } else { dst.Values = dst.Values[:fieldCount] } - for i := 0; i < fieldCount; i++ { + for i := range fieldCount { if len(src[rp:]) < 4 { return &invalidMessageFormatErr{messageType: "DataRow"} } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go index b41abbe10..056e547cd 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/frontend.go @@ -54,6 +54,7 @@ type Frontend struct { portalSuspended PortalSuspended bodyLen int + maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. msgType byte partialMsg bool authType uint32 @@ -317,6 +318,9 @@ func (f *Frontend) Receive() (BackendMessage, error) { } f.bodyLen = msgLength - 4 + if f.maxBodyLen > 0 && f.bodyLen > f.maxBodyLen { + return nil, &ExceededMaxBodyLenErr{f.maxBodyLen, f.bodyLen} + } f.partialMsg = true } @@ -452,3 +456,13 @@ func (f *Frontend) GetAuthType() uint32 { func (f *Frontend) ReadBufferLen() int { return f.cr.wp - f.cr.rp } + +// SetMaxBodyLen sets the maximum length of a message body in octets. +// If a message body exceeds this length, Receive will return an error. +// This is useful for protecting against a corrupted server that sends +// messages with incorrect length, which can cause memory exhaustion. +// The default value is 0. +// If maxBodyLen is 0, then no maximum is enforced. 
+func (f *Frontend) SetMaxBodyLen(maxBodyLen int) { + f.maxBodyLen = maxBodyLen +} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go index 7d83579ff..affb713f5 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/function_call.go @@ -33,7 +33,7 @@ func (dst *FunctionCall) Decode(src []byte) error { nArgumentCodes := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 argumentCodes := make([]uint16, nArgumentCodes) - for i := 0; i < nArgumentCodes; i++ { + for i := range nArgumentCodes { // The argument format codes. Each must presently be zero (text) or one (binary). ac := binary.BigEndian.Uint16(src[rp:]) if ac != 0 && ac != 1 { @@ -48,7 +48,7 @@ func (dst *FunctionCall) Decode(src []byte) error { nArguments := int(binary.BigEndian.Uint16(src[rp:])) rp += 2 arguments := make([][]byte, nArguments) - for i := 0; i < nArguments; i++ { + for i := range nArguments { // The length of the argument value, in bytes (this count does not include itself). Can be zero. // As a special case, -1 indicates a NULL argument value. No value bytes follow in the NULL case. argumentLength := int(binary.BigEndian.Uint32(src[rp:])) diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go b/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go index 70cb20cd5..122d1341c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/gss_enc_request.go @@ -10,8 +10,7 @@ import ( const gssEncReqNumber = 80877104 -type GSSEncRequest struct { -} +type GSSEncRequest struct{} // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*GSSEncRequest) Frontend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go b/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go index 1ef27b75f..58eb26ef0 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/parameter_description.go @@ -33,7 +33,7 @@ func (dst *ParameterDescription) Decode(src []byte) error { *dst = ParameterDescription{ParameterOIDs: make([]uint32, parameterCount)} - for i := 0; i < parameterCount; i++ { + for i := range parameterCount { dst.ParameterOIDs[i] = binary.BigEndian.Uint32(buf.Next(4)) } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go b/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go index 6ba3486cf..8fb8de5d4 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/parse.go @@ -43,7 +43,7 @@ func (dst *Parse) Decode(src []byte) error { } parameterOIDCount := int(binary.BigEndian.Uint16(buf.Next(2))) - for i := 0; i < parameterOIDCount; i++ { + for range parameterOIDCount { if buf.Len() < 4 { return &invalidMessageFormatErr{messageType: "Parse"} } diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go b/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go index d820d3275..67b78515d 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/password_message.go @@ -12,7 +12,7 @@ type PasswordMessage struct { // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} -// Frontend identifies this message as an authentication response. +// InitialResponse identifies this message as an authentication response. 
func (*PasswordMessage) InitialResponse() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go b/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go index dc2a4ddf2..b46f510dc 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/row_description.go @@ -56,7 +56,6 @@ func (*RowDescription) Backend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. func (dst *RowDescription) Decode(src []byte) error { - if len(src) < 2 { return &invalidMessageFormatErr{messageType: "RowDescription"} } @@ -65,7 +64,7 @@ func (dst *RowDescription) Decode(src []byte) error { dst.Fields = dst.Fields[0:0] - for i := 0; i < fieldCount; i++ { + for range fieldCount { var fd FieldDescription idx := bytes.IndexByte(src[rp:], 0) diff --git a/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go b/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go index b0fc28476..bdfc7c427 100644 --- a/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go +++ b/vendor/github.com/jackc/pgx/v5/pgproto3/ssl_request.go @@ -10,8 +10,7 @@ import ( const sslRequestNumber = 80877103 -type SSLRequest struct { -} +type SSLRequest struct{} // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*SSLRequest) Frontend() {} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array.go b/vendor/github.com/jackc/pgx/v5/pgtype/array.go index 06b824ad0..872a08891 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array.go @@ -374,8 +374,8 @@ func quoteArrayElementIfNeeded(src string) string { return src } -// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves -// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. +// Array represents a PostgreSQL array for T. It implements the [ArrayGetter] and [ArraySetter] interfaces. It preserves +// PostgreSQL dimensions and custom lower bounds. Use [FlatArray] if these are not needed. type Array[T any] struct { Elements []T Dims []ArrayDimension @@ -419,8 +419,8 @@ func (a Array[T]) ScanIndexType() any { return new(T) } -// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions -// and custom lower bounds. Use Array to preserve these. +// FlatArray implements the [ArrayGetter] and [ArraySetter] interfaces for any slice of T. It ignores PostgreSQL dimensions +// and custom lower bounds. Use [Array] to preserve these. 
type FlatArray[T any] []T func (a FlatArray[T]) Dimensions() []ArrayDimension { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go b/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go index bf5f6989a..f6b36f439 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/array_codec.go @@ -118,7 +118,7 @@ func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, var encodePlan EncodePlan var lastElemType reflect.Type inElemBuf := make([]byte, 0, 32) - for i := 0; i < elementCount; i++ { + for i := range elementCount { if i > 0 { buf = append(buf, ',') } @@ -131,7 +131,7 @@ func (p *encodePlanArrayCodecText) Encode(value any, buf []byte) (newBuf []byte, elem := array.Index(i) var elemBuf []byte - if elem != nil { + if isNil, _ := isNilDriverValuer(elem); !isNil { elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType @@ -189,13 +189,13 @@ func (p *encodePlanArrayCodecBinary) Encode(value any, buf []byte) (newBuf []byt var encodePlan EncodePlan var lastElemType reflect.Type - for i := 0; i < elementCount; i++ { + for i := range elementCount { sp := len(buf) buf = pgio.AppendInt32(buf, -1) elem := array.Index(i) var elemBuf []byte - if elem != nil { + if isNil, _ := isNilDriverValuer(elem); !isNil { elemType := reflect.TypeOf(elem) if lastElemType != elemType { lastElemType = elemType @@ -270,7 +270,7 @@ func (c *ArrayCodec) decodeBinary(m *Map, arrayOID uint32, src []byte, array Arr elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, array.ScanIndex(0)) } - for i := 0; i < elementCount; i++ { + for i := range elementCount { elem := array.ScanIndex(i) elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -388,7 +388,7 @@ func isRagged(slice reflect.Value) bool { sliceLen := slice.Len() innerLen := 0 - for i := 0; i < sliceLen; i++ { + for i := range sliceLen { if i == 0 { innerLen = slice.Index(i).Len() } else { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bits.go b/vendor/github.com/jackc/pgx/v5/pgtype/bits.go index e7a1d016d..2a48e3549 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bits.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bits.go @@ -23,16 +23,18 @@ type Bits struct { Valid bool } +// ScanBits implements the [BitsScanner] interface. func (b *Bits) ScanBits(v Bits) error { *b = v return nil } +// BitsValue implements the [BitsValuer] interface. func (b Bits) BitsValue() (Bits, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Bits) Scan(src any) error { if src == nil { *dst = Bits{} @@ -47,7 +49,7 @@ func (dst *Bits) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
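
The doc-link changes above touch the Array/FlatArray distinction: FlatArray drops dimensions and lower bounds, Array keeps them. A minimal scan sketch, assuming a reachable connection; the helper name is illustrative:

    package example

    import (
        "context"

        "github.com/jackc/pgx/v5"
        "github.com/jackc/pgx/v5/pgtype"
    )

    func readTags(ctx context.Context, conn *pgx.Conn) (pgtype.FlatArray[string], error) {
        var tags pgtype.FlatArray[string]
        // Use pgtype.Array[string] instead if PostgreSQL dimensions and custom
        // lower bounds need to be preserved, per the doc comments above.
        err := conn.QueryRow(ctx, `select array['a','b','c']::text[]`).Scan(&tags)
        return tags, err
    }
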
func (src Bits) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -127,7 +129,6 @@ func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go index 71caffa74..955f01fe8 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bool.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bool.go @@ -22,16 +22,18 @@ type Bool struct { Valid bool } +// ScanBool implements the [BoolScanner] interface. func (b *Bool) ScanBool(v Bool) error { *b = v return nil } +// BoolValue implements the [BoolValuer] interface. func (b Bool) BoolValue() (Bool, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Bool) Scan(src any) error { if src == nil { *dst = Bool{} @@ -61,7 +63,7 @@ func (dst *Bool) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Bool) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -70,6 +72,7 @@ func (src Bool) Value() (driver.Value, error) { return src.Bool, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Bool) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -82,6 +85,7 @@ func (src Bool) MarshalJSON() ([]byte, error) { } } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Bool) UnmarshalJSON(b []byte) error { var v *bool err := json.Unmarshal(b, &v) @@ -200,7 +204,6 @@ func (encodePlanBoolCodecTextBool) Encode(value any, buf []byte) (newBuf []byte, } func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -328,7 +331,7 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { return s.ScanBool(Bool{Bool: v, Valid: true}) } -// https://www.postgresql.org/docs/11/datatype-boolean.html +// https://www.postgresql.org/docs/current/datatype-boolean.html func planTextToBool(src []byte) (bool, error) { s := string(bytes.ToLower(bytes.TrimSpace(src))) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/box.go b/vendor/github.com/jackc/pgx/v5/pgtype/box.go index 887d268bc..d243f58e3 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/box.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/box.go @@ -24,16 +24,18 @@ type Box struct { Valid bool } +// ScanBox implements the [BoxScanner] interface. func (b *Box) ScanBox(v Box) error { *b = v return nil } +// BoxValue implements the [BoxValuer] interface. func (b Box) BoxValue() (Box, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Box) Scan(src any) error { if src == nil { *dst = Box{} @@ -48,7 +50,7 @@ func (dst *Box) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (src Box) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -127,7 +129,6 @@ func (encodePlanBoxCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go b/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go index b39d3fa10..126e0be2a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/builtin_wrappers.go @@ -527,6 +527,7 @@ func (w *netIPNetWrapper) ScanNetipPrefix(v netip.Prefix) error { return nil } + func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) { ip, ok := netip.AddrFromSlice(w.IP) if !ok { @@ -881,7 +882,6 @@ func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error return nil } - } func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value { @@ -892,7 +892,7 @@ func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type sliceLen := int(dimensions[0].Length) slice := reflect.MakeSlice(sliceType, sliceLen, sliceLen) - for i := 0; i < sliceLen; i++ { + for i := range sliceLen { subSlice := a.makeMultidimensionalSlice(sliceType.Elem(), dimensions[1:], flatSlice, flatSliceIdx+(i*int(dimensions[1].Length))) slice.Index(i).Set(subSlice) } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go b/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go index a247705e9..6c4f0c5ea 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/bytea.go @@ -148,7 +148,6 @@ func (encodePlanBytesCodecTextBytesValuer) Encode(value any, buf []byte) (newBuf } func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/circle.go b/vendor/github.com/jackc/pgx/v5/pgtype/circle.go index e8f118cc9..fb9b4c11d 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/circle.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/circle.go @@ -25,16 +25,18 @@ type Circle struct { Valid bool } +// ScanCircle implements the [CircleScanner] interface. func (c *Circle) ScanCircle(v Circle) error { *c = v return nil } +// CircleValue implements the [CircleValuer] interface. func (c Circle) CircleValue() (Circle, error) { return c, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Circle) Scan(src any) error { if src == nil { *dst = Circle{} @@ -49,7 +51,7 @@ func (dst *Circle) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (src Circle) Value() (driver.Value, error) { if !src.Valid { return nil, nil diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/composite.go b/vendor/github.com/jackc/pgx/v5/pgtype/composite.go index fb372325b..59081a823 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/composite.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/composite.go @@ -276,7 +276,6 @@ func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byt default: return nil, fmt.Errorf("unknown format code %d", format) } - } type CompositeBinaryScanner struct { @@ -477,7 +476,7 @@ func (b *CompositeBinaryBuilder) AppendValue(oid uint32, field any) { return } - if field == nil { + if isNil, _ := isNilDriverValuer(field); isNil { b.buf = pgio.AppendUint32(b.buf, oid) b.buf = pgio.AppendInt32(b.buf, -1) b.fieldCount++ @@ -534,7 +533,7 @@ func (b *CompositeTextBuilder) AppendValue(oid uint32, field any) { return } - if field == nil { + if isNil, _ := isNilDriverValuer(field); isNil { b.buf = append(b.buf, ',') return } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/date.go b/vendor/github.com/jackc/pgx/v5/pgtype/date.go index 784b16deb..447056860 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/date.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/date.go @@ -26,11 +26,13 @@ type Date struct { Valid bool } +// ScanDate implements the [DateScanner] interface. func (d *Date) ScanDate(v Date) error { *d = v return nil } +// DateValue implements the [DateValuer] interface. func (d Date) DateValue() (Date, error) { return d, nil } @@ -40,7 +42,7 @@ const ( infinityDayOffset = 2147483647 ) -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Date) Scan(src any) error { if src == nil { *dst = Date{} @@ -58,7 +60,7 @@ func (dst *Date) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Date) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -70,6 +72,7 @@ func (src Date) Value() (driver.Value, error) { return src.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Date) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -89,6 +92,7 @@ func (src Date) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Date) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -223,7 +227,6 @@ func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/doc.go b/vendor/github.com/jackc/pgx/v5/pgtype/doc.go index d56c1dc70..83dfc5de5 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/doc.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/doc.go @@ -53,6 +53,9 @@ similar fashion to database/sql. The second is to use a pointer to a pointer. return err } +When using nullable pgtype types as parameters for queries, one has to remember to explicitly set their Valid field to +true, otherwise the parameter's value will be NULL. + JSON Support pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. @@ -156,11 +159,16 @@ example_child_records_test.go for an example. 
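
The doc.go addition above warns that a nullable pgtype value used as a query argument sends NULL unless Valid is set. A minimal sketch; the table and function names are illustrative:

    package example

    import (
        "context"

        "github.com/jackc/pgx/v5"
        "github.com/jackc/pgx/v5/pgtype"
    )

    func insertScore(ctx context.Context, conn *pgx.Conn, score int32) error {
        // Leaving Valid as false would send NULL for $1.
        arg := pgtype.Int4{Int32: score, Valid: true}
        _, err := conn.Exec(ctx, `insert into scores (value) values ($1)`, arg)
        return err
    }
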
Overview of Scanning Implementation -The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID -from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for -scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are -interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner and -PointValuer interfaces. +The first step is to use the OID to lookup the correct Codec. The Map will call the Codec's PlanScan method to get a +plan for scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types +are interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner +and PointValuer interfaces. + +If a Go value is not supported directly by a Codec then Map will try see if it is a sql.Scanner. If is then that +interface will be used to scan the value. Most sql.Scanners require the input to be in the text format (e.g. UUIDs and +numeric). However, pgx will typically have received the value in the binary format. In this case the binary value will be +parsed, reencoded as text, and then passed to the sql.Scanner. This may incur additional overhead for query results with +a large number of affected values. If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again. For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/float4.go b/vendor/github.com/jackc/pgx/v5/pgtype/float4.go index 8646d9d22..241a25add 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/float4.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/float4.go @@ -16,26 +16,29 @@ type Float4 struct { Valid bool } -// ScanFloat64 implements the Float64Scanner interface. +// ScanFloat64 implements the [Float64Scanner] interface. func (f *Float4) ScanFloat64(n Float8) error { *f = Float4{Float32: float32(n.Float64), Valid: n.Valid} return nil } +// Float64Value implements the [Float64Valuer] interface. func (f Float4) Float64Value() (Float8, error) { return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (f *Float4) ScanInt64(n Int8) error { *f = Float4{Float32: float32(n.Int64), Valid: n.Valid} return nil } +// Int64Value implements the [Int64Valuer] interface. func (f Float4) Int64Value() (Int8, error) { return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (f *Float4) Scan(src any) error { if src == nil { *f = Float4{} @@ -58,7 +61,7 @@ func (f *Float4) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (f Float4) Value() (driver.Value, error) { if !f.Valid { return nil, nil @@ -66,6 +69,7 @@ func (f Float4) Value() (driver.Value, error) { return float64(f.Float32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. 
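
The rewritten scanning overview above documents the sql.Scanner fallback (binary results are re-encoded as text before the Scanner sees them). A sketch of a destination type that relies on that fallback; upperText and readUpper are throwaway names for illustration:

    package example

    import (
        "context"
        "fmt"
        "strings"

        "github.com/jackc/pgx/v5"
    )

    // upperText implements only database/sql.Scanner, so pgtype has no direct
    // codec support for it and falls back to the path described above.
    type upperText string

    func (u *upperText) Scan(src any) error {
        switch v := src.(type) {
        case string:
            *u = upperText(strings.ToUpper(v))
            return nil
        case []byte:
            *u = upperText(strings.ToUpper(string(v)))
            return nil
        }
        return fmt.Errorf("cannot scan %T into upperText", src)
    }

    func readUpper(ctx context.Context, conn *pgx.Conn) (upperText, error) {
        var u upperText
        err := conn.QueryRow(ctx, `select 'hello'::text`).Scan(&u)
        return u, err // expected "HELLO", assuming the fallback described above
    }
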
func (f Float4) MarshalJSON() ([]byte, error) { if !f.Valid { return []byte("null"), nil @@ -73,6 +77,7 @@ func (f Float4) MarshalJSON() ([]byte, error) { return json.Marshal(f.Float32) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (f *Float4) UnmarshalJSON(b []byte) error { var n *float32 err := json.Unmarshal(b, &n) @@ -170,7 +175,6 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (new } func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/float8.go b/vendor/github.com/jackc/pgx/v5/pgtype/float8.go index 9c923c9a3..54d6781ec 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/float8.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/float8.go @@ -24,26 +24,29 @@ type Float8 struct { Valid bool } -// ScanFloat64 implements the Float64Scanner interface. +// ScanFloat64 implements the [Float64Scanner] interface. func (f *Float8) ScanFloat64(n Float8) error { *f = n return nil } +// Float64Value implements the [Float64Valuer] interface. func (f Float8) Float64Value() (Float8, error) { return f, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (f *Float8) ScanInt64(n Int8) error { *f = Float8{Float64: float64(n.Int64), Valid: n.Valid} return nil } +// Int64Value implements the [Int64Valuer] interface. func (f Float8) Int64Value() (Int8, error) { return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (f *Float8) Scan(src any) error { if src == nil { *f = Float8{} @@ -66,7 +69,7 @@ func (f *Float8) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (f Float8) Value() (driver.Value, error) { if !f.Valid { return nil, nil @@ -74,6 +77,7 @@ func (f Float8) Value() (driver.Value, error) { return f.Float64, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (f Float8) MarshalJSON() ([]byte, error) { if !f.Valid { return []byte("null"), nil @@ -81,6 +85,7 @@ func (f Float8) MarshalJSON() ([]byte, error) { return json.Marshal(f.Float64) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (f *Float8) UnmarshalJSON(b []byte) error { var n *float64 err := json.Unmarshal(b, &n) @@ -208,7 +213,6 @@ func (encodePlanTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, e } func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go index 2f34f4c9e..61a42de88 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/hstore.go @@ -22,16 +22,18 @@ type HstoreValuer interface { // associated with its keys. type Hstore map[string]*string +// ScanHstore implements the [HstoreScanner] interface. func (h *Hstore) ScanHstore(v Hstore) error { *h = v return nil } +// HstoreValue implements the [HstoreValuer] interface. func (h Hstore) HstoreValue() (Hstore, error) { return h, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. 
func (h *Hstore) Scan(src any) error { if src == nil { *h = nil @@ -46,7 +48,7 @@ func (h *Hstore) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (h Hstore) Value() (driver.Value, error) { if h == nil { return nil, nil @@ -162,7 +164,6 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e } func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -201,7 +202,7 @@ func (scanPlanBinaryHstoreToHstoreScanner) Scan(src []byte, dst any) error { // one allocation for all *string, rather than one per string, just like text parsing valueStrings := make([]string, pairCount) - for i := 0; i < pairCount; i++ { + for i := range pairCount { if len(src[rp:]) < uint32Len { return fmt.Errorf("hstore incomplete %v", src) } @@ -298,7 +299,7 @@ func (p *hstoreParser) consume() (b byte, end bool) { return b, false } -func unexpectedByteErr(actualB byte, expectedB byte) error { +func unexpectedByteErr(actualB, expectedB byte) error { return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) } @@ -316,7 +317,7 @@ func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { // consumeExpected2 consumes two expected bytes or returns an error. // This was a bit faster than using a string argument (better inlining? Not sure). -func (p *hstoreParser) consumeExpected2(one byte, two byte) error { +func (p *hstoreParser) consumeExpected2(one, two byte) error { if p.pos+2 > len(p.str) { return errors.New("unexpected end of string") } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/inet.go b/vendor/github.com/jackc/pgx/v5/pgtype/inet.go index 6ca10ea07..b92edb239 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/inet.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/inet.go @@ -24,7 +24,7 @@ type NetipPrefixValuer interface { NetipPrefixValue() (netip.Prefix, error) } -// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are netip.Prefix and netip.Addr. If +// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are [netip.Prefix] and [netip.Addr]. If // IsValid() is false then they are treated as SQL NULL. type InetCodec struct{} @@ -107,7 +107,6 @@ func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/int.go b/vendor/github.com/jackc/pgx/v5/pgtype/int.go index 90a20a26a..d1b8eb612 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/int.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/int.go @@ -1,4 +1,5 @@ -// Do not edit. Generated from pgtype/int.go.erb +// Code generated from pgtype/int.go.erb. DO NOT EDIT. + package pgtype import ( @@ -25,7 +26,7 @@ type Int2 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int2) ScanInt64(n Int8) error { if !n.Valid { *dst = Int2{} @@ -43,11 +44,12 @@ func (dst *Int2) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. 
func (n Int2) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int2) Scan(src any) error { if src == nil { *dst = Int2{} @@ -86,7 +88,7 @@ func (dst *Int2) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int2) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -94,6 +96,7 @@ func (src Int2) Value() (driver.Value, error) { return int64(src.Int16), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int2) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -101,6 +104,7 @@ func (src Int2) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int2) UnmarshalJSON(b []byte) error { var n *int16 err := json.Unmarshal(b, &n) @@ -585,7 +589,7 @@ type Int4 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int4) ScanInt64(n Int8) error { if !n.Valid { *dst = Int4{} @@ -603,11 +607,12 @@ func (dst *Int4) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int4) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int4) Scan(src any) error { if src == nil { *dst = Int4{} @@ -646,7 +651,7 @@ func (dst *Int4) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int4) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -654,6 +659,7 @@ func (src Int4) Value() (driver.Value, error) { return int64(src.Int32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int4) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -661,6 +667,7 @@ func (src Int4) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int4) UnmarshalJSON(b []byte) error { var n *int32 err := json.Unmarshal(b, &n) @@ -1156,7 +1163,7 @@ type Int8 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int8) ScanInt64(n Int8) error { if !n.Valid { *dst = Int8{} @@ -1174,11 +1181,12 @@ func (dst *Int8) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int8) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int8) Scan(src any) error { if src == nil { *dst = Int8{} @@ -1217,7 +1225,7 @@ func (dst *Int8) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (src Int8) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -1225,6 +1233,7 @@ func (src Int8) Value() (driver.Value, error) { return int64(src.Int64), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int8) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -1232,6 +1241,7 @@ func (src Int8) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int8) UnmarshalJSON(b []byte) error { var n *int64 err := json.Unmarshal(b, &n) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb b/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb index e0c8b7a3f..c2d40f60b 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb +++ b/vendor/github.com/jackc/pgx/v5/pgtype/int.go.erb @@ -27,7 +27,7 @@ type Int<%= pg_byte_size %> struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { if !n.Valid { *dst = Int<%= pg_byte_size %>{} @@ -45,11 +45,12 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int<%= pg_byte_size %>) Scan(src any) error { if src == nil { *dst = Int<%= pg_byte_size %>{} @@ -88,7 +89,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +97,7 @@ func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { return int64(src.Int<%= pg_bit_size %>), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -103,6 +105,7 @@ func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. 
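
The Marshaler/Unmarshaler doc comments added above describe how the integer types round-trip JSON null through the Valid flag. A small standalone check:

    package main

    import (
        "encoding/json"
        "fmt"

        "github.com/jackc/pgx/v5/pgtype"
    )

    func main() {
        var n pgtype.Int8
        _ = json.Unmarshal([]byte(`null`), &n)
        fmt.Println(n.Valid) // false: JSON null maps to the zero, invalid value

        out, _ := json.Marshal(pgtype.Int8{Int64: 42, Valid: true})
        fmt.Println(string(out)) // 42

        out, _ = json.Marshal(pgtype.Int8{})
        fmt.Println(string(out)) // null
    }
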
func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { var n *int<%= pg_bit_size %> err := json.Unmarshal(b, &n) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/integration_benchmark_test.go.erb b/vendor/github.com/jackc/pgx/v5/pgtype/integration_benchmark_test.go.erb index 0175700a4..6f4011534 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/integration_benchmark_test.go.erb +++ b/vendor/github.com/jackc/pgx/v5/pgtype/integration_benchmark_test.go.erb @@ -25,7 +25,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go rows, _ := conn.Query( ctx, `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, - []any{pgx.QueryResultFormats{<%= format_code %>}}, + pgx.QueryResultFormats{<%= format_code %>}, ) _, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) if err != nil { @@ -49,7 +49,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, <%= array_size %>) n`, - []any{pgx.QueryResultFormats{<%= format_code %>}}, + pgx.QueryResultFormats{<%= format_code %>}, ) _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/interval.go b/vendor/github.com/jackc/pgx/v5/pgtype/interval.go index 06703d4dc..b1bc78527 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/interval.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/interval.go @@ -11,7 +11,7 @@ import ( ) const ( - microsecondsPerSecond = 1000000 + microsecondsPerSecond = 1_000_000 microsecondsPerMinute = 60 * microsecondsPerSecond microsecondsPerHour = 60 * microsecondsPerMinute microsecondsPerDay = 24 * microsecondsPerHour @@ -33,16 +33,18 @@ type Interval struct { Valid bool } +// ScanInterval implements the [IntervalScanner] interface. func (interval *Interval) ScanInterval(v Interval) error { *interval = v return nil } +// IntervalValue implements the [IntervalValuer] interface. func (interval Interval) IntervalValue() (Interval, error) { return interval, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (interval *Interval) Scan(src any) error { if src == nil { *interval = Interval{} @@ -57,7 +59,7 @@ func (interval *Interval) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (interval Interval) Value() (driver.Value, error) { if !interval.Valid { return nil, nil @@ -132,36 +134,31 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, if interval.Days != 0 { buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) - buf = append(buf, " day"...) + buf = append(buf, " day "...) } - if interval.Microseconds != 0 { - buf = append(buf, " "...) 
- - absMicroseconds := interval.Microseconds - if absMicroseconds < 0 { - absMicroseconds = -absMicroseconds - buf = append(buf, '-') - } + absMicroseconds := interval.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } - hours := absMicroseconds / microsecondsPerHour - minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute - seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond - timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) - buf = append(buf, timeStr...) + timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) + buf = append(buf, timeStr...) - microseconds := absMicroseconds % microsecondsPerSecond - if microseconds != 0 { - buf = append(buf, fmt.Sprintf(".%06d", microseconds)...) - } + microseconds := absMicroseconds % microsecondsPerSecond + if microseconds != 0 { + buf = append(buf, fmt.Sprintf(".%06d", microseconds)...) } return buf, nil } func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -226,6 +223,8 @@ func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { months += int32(scalar) case "day", "days": days = int32(scalar) + default: + return fmt.Errorf("bad interval format: %q", parts[i+1]) } } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/json.go b/vendor/github.com/jackc/pgx/v5/pgtype/json.go index e71dcb9bf..bf70735ee 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/json.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/json.go @@ -37,7 +37,7 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco // // https://github.com/jackc/pgx/issues/1430 // - // Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to beused + // Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to be used // when both are implemented https://github.com/jackc/pgx/issues/1805 case driver.Valuer: return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} @@ -71,6 +71,27 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco } } +// JSON needs its on scan plan for pointers to handle 'null'::json(b). +// Consider making pointerPointerScanPlan more flexible in the future. 
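
The interval.go changes above rework the text encoder and add an error for unrecognized units when parsing text intervals. For orientation, a hedged sketch of scanning an interval with pgtype; the exact split between Days and Microseconds is decided by the server:

    package example

    import (
        "context"

        "github.com/jackc/pgx/v5"
        "github.com/jackc/pgx/v5/pgtype"
    )

    func readInterval(ctx context.Context, conn *pgx.Conn) (pgtype.Interval, error) {
        var iv pgtype.Interval
        // Roughly expected: iv.Valid == true, iv.Days == 1, and the 02:03:04
        // time-of-day portion carried in iv.Microseconds.
        err := conn.QueryRow(ctx, `select interval '1 day 02:03:04'`).Scan(&iv)
        return iv, err
    }
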
+type jsonPointerScanPlan struct { + next ScanPlan +} + +func (p jsonPointerScanPlan) Scan(src []byte, dst any) error { + el := reflect.ValueOf(dst).Elem() + if src == nil || string(src) == "null" { + el.SetZero() + return nil + } + + el.Set(reflect.New(el.Type().Elem())) + if p.next != nil { + return p.next.Scan(src, el.Interface()) + } + + return nil +} + type encodePlanJSONCodecEitherFormatString struct{} func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { @@ -117,41 +138,35 @@ func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) ( return buf, nil } -func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan { + return c.planScan(m, oid, formatCode, target, 0) +} + +// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b), +// so we need to duplicate the logic here. +func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + switch target.(type) { case *string: - return scanPlanAnyToString{} - - case **string: - // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better - // solution would be. - // - // https://github.com/jackc/pgx/issues/1470 -- **string - // https://github.com/jackc/pgx/issues/1691 -- ** anything else - - if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { - if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { - if _, failed := nextPlan.(*scanPlanFail); !failed { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } - } - } - + return &scanPlanAnyToString{} case *[]byte: - return scanPlanJSONToByteSlice{} + return &scanPlanJSONToByteSlice{} case BytesScanner: - return scanPlanBinaryBytesToBytesScanner{} - - // Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence. 
- // - // https://github.com/jackc/pgx/issues/1418 + return &scanPlanBinaryBytesToBytesScanner{} case sql.Scanner: - return &scanPlanSQLScanner{formatCode: format} + return &scanPlanCodecSQLScanner{c: c, m: m, oid: oid, formatCode: formatCode} } - return &scanPlanJSONToJSONUnmarshal{ - unmarshal: c.Unmarshal, + rv := reflect.ValueOf(target) + if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer { + var plan jsonPointerScanPlan + plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1) + return plan + } else { + return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal} } } @@ -177,13 +192,6 @@ func (scanPlanJSONToByteSlice) Scan(src []byte, dst any) error { return nil } -type scanPlanJSONToBytesScanner struct{} - -func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error { - scanner := (dst).(BytesScanner) - return scanner.ScanBytes(src) -} - type scanPlanJSONToJSONUnmarshal struct { unmarshal func(data []byte, v any) error } @@ -203,7 +211,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - elem := reflect.ValueOf(dst).Elem() + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Pointer || v.IsNil() { + return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst) + } + + elem := v.Elem() elem.Set(reflect.Zero(elem.Type())) return s.unmarshal(src, dst) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/line.go b/vendor/github.com/jackc/pgx/v5/pgtype/line.go index 4ae8003e8..10efc8ce7 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/line.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/line.go @@ -24,11 +24,13 @@ type Line struct { Valid bool } +// ScanLine implements the [LineScanner] interface. func (line *Line) ScanLine(v Line) error { *line = v return nil } +// LineValue implements the [LineValuer] interface. func (line Line) LineValue() (Line, error) { return line, nil } @@ -37,7 +39,7 @@ func (line *Line) Set(src any) error { return fmt.Errorf("cannot convert %v to Line", src) } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (line *Line) Scan(src any) error { if src == nil { *line = Line{} @@ -52,7 +54,7 @@ func (line *Line) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (line Line) Value() (driver.Value, error) { if !line.Valid { return nil, nil @@ -129,7 +131,6 @@ func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/lseg.go b/vendor/github.com/jackc/pgx/v5/pgtype/lseg.go index 05a86e1c8..ed0d40d2a 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/lseg.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/lseg.go @@ -24,16 +24,18 @@ type Lseg struct { Valid bool } +// ScanLseg implements the [LsegScanner] interface. func (lseg *Lseg) ScanLseg(v Lseg) error { *lseg = v return nil } +// LsegValue implements the [LsegValuer] interface. func (lseg Lseg) LsegValue() (Lseg, error) { return lseg, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. 
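
Per the jsonPointerScanPlan added above, a JSON-level null scanned into a pointer destination now zeroes the pointer instead of being handed to the unmarshal plan. A minimal sketch; the function name is illustrative:

    package example

    import (
        "context"

        "github.com/jackc/pgx/v5"
    )

    func readMaybeName(ctx context.Context, conn *pgx.Conn) (*string, error) {
        var name *string
        // For a 'null' JSON document the plan above leaves name == nil.
        err := conn.QueryRow(ctx, `select 'null'::json`).Scan(&name)
        return name, err
    }
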
func (lseg *Lseg) Scan(src any) error { if src == nil { *lseg = Lseg{} @@ -48,7 +50,7 @@ func (lseg *Lseg) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (lseg Lseg) Value() (driver.Value, error) { if !lseg.Valid { return nil, nil @@ -127,7 +129,6 @@ func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go b/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go index e57637880..549917e4c 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/multirange.go @@ -98,7 +98,7 @@ func (p *encodePlanMultirangeCodecText) Encode(value any, buf []byte) (newBuf [] var encodePlan EncodePlan var lastElemType reflect.Type inElemBuf := make([]byte, 0, 32) - for i := 0; i < elementCount; i++ { + for i := range elementCount { if i > 0 { buf = append(buf, ',') } @@ -151,7 +151,7 @@ func (p *encodePlanMultirangeCodecBinary) Encode(value any, buf []byte) (newBuf var encodePlan EncodePlan var lastElemType reflect.Type - for i := 0; i < elementCount; i++ { + for i := range elementCount { sp := len(buf) buf = pgio.AppendInt32(buf, -1) @@ -224,7 +224,7 @@ func (c *MultirangeCodec) decodeBinary(m *Map, multirangeOID uint32, src []byte, elementScanPlan = m.PlanScan(c.ElementType.OID, BinaryFormatCode, multirange.ScanIndex(0)) } - for i := 0; i < elementCount; i++ { + for i := range elementCount { elem := multirange.ScanIndex(i) elemLen := int(int32(binary.BigEndian.Uint32(src[rp:]))) rp += 4 @@ -374,7 +374,6 @@ parseValueLoop: } return elements, nil - } func parseRange(buf *bytes.Buffer) (string, error) { @@ -403,8 +402,8 @@ func parseRange(buf *bytes.Buffer) (string, error) { // Multirange is a generic multirange type. // -// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to -// enforce the RangeScanner constraint. +// T should implement [RangeValuer] and *T should implement [RangeScanner]. However, there does not appear to be a way to +// enforce the [RangeScanner] constraint. 
type Multirange[T RangeValuer] []T func (r Multirange[T]) IsNull() bool { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go index 4dbec7861..b295c2ada 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/numeric.go @@ -14,7 +14,7 @@ import ( ) // PostgreSQL internal numeric storage uses 16-bit "digits" with base of 10,000 -const nbase = 10000 +const nbase = 10_000 const ( pgNumericNaN = 0x00000000c0000000 @@ -27,16 +27,19 @@ const ( pgNumericNegInfSign = 0xf000 ) -var big0 *big.Int = big.NewInt(0) -var big1 *big.Int = big.NewInt(1) -var big10 *big.Int = big.NewInt(10) -var big100 *big.Int = big.NewInt(100) -var big1000 *big.Int = big.NewInt(1000) +var ( + big1 *big.Int = big.NewInt(1) + big10 *big.Int = big.NewInt(10) + big100 *big.Int = big.NewInt(100) + big1000 *big.Int = big.NewInt(1000) +) -var bigNBase *big.Int = big.NewInt(nbase) -var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) -var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) -var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +var ( + bigNBase *big.Int = big.NewInt(nbase) + bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) + bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) + bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +) type NumericScanner interface { ScanNumeric(v Numeric) error @@ -54,15 +57,18 @@ type Numeric struct { Valid bool } +// ScanNumeric implements the [NumericScanner] interface. func (n *Numeric) ScanNumeric(v Numeric) error { *n = v return nil } +// NumericValue implements the [NumericValuer] interface. func (n Numeric) NumericValue() (Numeric, error) { return n, nil } +// Float64Value implements the [Float64Valuer] interface. func (n Numeric) Float64Value() (Float8, error) { if !n.Valid { return Float8{}, nil @@ -92,6 +98,7 @@ func (n Numeric) Float64Value() (Float8, error) { return Float8{Float64: f, Valid: true}, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (n *Numeric) ScanInt64(v Int8) error { if !v.Valid { *n = Numeric{} @@ -102,6 +109,7 @@ func (n *Numeric) ScanInt64(v Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Numeric) Int64Value() (Int8, error) { if !n.Valid { return Int8{}, nil @@ -157,7 +165,7 @@ func (n *Numeric) toBigInt() (*big.Int, error) { div.Exp(big10, big.NewInt(int64(-n.Exp)), nil) remainder := &big.Int{} num.DivMod(num, div, remainder) - if remainder.Cmp(big0) != 0 { + if remainder.Sign() != 0 { return nil, fmt.Errorf("cannot convert %v to integer", n) } return num, nil @@ -185,14 +193,11 @@ func parseNumericString(str string) (n *big.Int, exp int32, err error) { } func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { - digits := len(src) / 2 - if digits > 4 { - digits = 4 - } + digits := min(len(src)/2, 4) rp := 0 - for i := 0; i < digits; i++ { + for i := range digits { if i > 0 { accum *= nbase } @@ -203,7 +208,7 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { return accum, rp, digits } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (n *Numeric) Scan(src any) error { if src == nil { *n = Numeric{} @@ -218,7 +223,7 @@ func (n *Numeric) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (n Numeric) Value() (driver.Value, error) { if !n.Valid { return nil, nil @@ -231,6 +236,7 @@ func (n Numeric) Value() (driver.Value, error) { return string(buf), err } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (n Numeric) MarshalJSON() ([]byte, error) { if !n.Valid { return []byte("null"), nil @@ -243,6 +249,7 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return n.numberTextBytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (n *Numeric) UnmarshalJSON(src []byte) error { if bytes.Equal(src, []byte(`null`)) { *n = Numeric{} @@ -269,14 +276,14 @@ func (n Numeric) numberTextBytes() []byte { exp := int(n.Exp) if exp > 0 { buf.WriteString(intStr) - for i := 0; i < exp; i++ { + for range exp { buf.WriteByte('0') } } else if exp < 0 { if len(intStr) <= -exp { buf.WriteString("0.") leadingZeros := -exp - len(intStr) - for i := 0; i < leadingZeros; i++ { + for range leadingZeros { buf.WriteByte('0') } buf.WriteString(intStr) @@ -398,7 +405,7 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { } var sign int16 - if n.Int.Cmp(big0) < 0 { + if n.Int.Sign() < 0 { sign = 16384 } @@ -436,12 +443,12 @@ func encodeNumericBinary(n Numeric, buf []byte) (newBuf []byte, err error) { var wholeDigits, fracDigits []int16 - for wholePart.Cmp(big0) != 0 { + for wholePart.Sign() != 0 { wholePart.DivMod(wholePart, bigNBase, remainder) wholeDigits = append(wholeDigits, int16(remainder.Int64())) } - if fracPart.Cmp(big0) != 0 { + if fracPart.Sign() != 0 { for fracPart.Cmp(big1) != 0 { fracPart.DivMod(fracPart, bigNBase, remainder) fracDigits = append(fracDigits, int16(remainder.Int64())) @@ -553,7 +560,6 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { } func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -648,18 +654,19 @@ func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { exp := (int32(weight) - int32(ndigits) + 1) * 4 if dscale > 0 { - fracNBaseDigits := int16(int32(ndigits) - int32(weight) - 1) + fracNBaseDigits := int(ndigits) - int(weight) - 1 fracDecimalDigits := fracNBaseDigits * 4 + dscaleInt := int(dscale) - if dscale > fracDecimalDigits { - multCount := int(dscale - fracDecimalDigits) - for i := 0; i < multCount; i++ { + if dscaleInt > fracDecimalDigits { + multCount := dscaleInt - fracDecimalDigits + for range multCount { accum.Mul(accum, big10) exp-- } - } else if dscale < fracDecimalDigits { - divCount := int(fracDecimalDigits - dscale) - for i := 0; i < divCount; i++ { + } else if dscaleInt < fracDecimalDigits { + divCount := fracDecimalDigits - dscaleInt + for range divCount { accum.Div(accum, big10) exp++ } @@ -671,7 +678,7 @@ func (scanPlanBinaryNumericToNumericScanner) Scan(src []byte, dst any) error { if exp >= 0 { for { reduced.DivMod(accum, big10, remainder) - if remainder.Cmp(big0) != 0 { + if remainder.Sign() != 0 { break } accum.Set(reduced) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/path.go b/vendor/github.com/jackc/pgx/v5/pgtype/path.go index 73e0ec52f..685996a89 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/path.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/path.go @@ -25,16 +25,18 @@ type Path struct { Valid bool } +// ScanPath implements the [PathScanner] interface. func (path *Path) ScanPath(v Path) error { *path = v return nil } +// PathValue implements the [PathValuer] interface. 
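
The numeric.go cleanups above (Sign() instead of comparing against a cached big.Int zero, range-over-int loops) do not change behavior. For reference, a small standalone sketch of the Numeric accessors touched here:

    package main

    import (
        "fmt"
        "math/big"

        "github.com/jackc/pgx/v5/pgtype"
    )

    func main() {
        // 12345 * 10^-2, i.e. 123.45
        n := pgtype.Numeric{Int: big.NewInt(12345), Exp: -2, Valid: true}

        f8, err := n.Float64Value()
        fmt.Println(f8.Float64, f8.Valid, err) // 123.45 true <nil>

        // Int64Value refuses values with a fractional remainder, per toBigInt above.
        _, err = n.Int64Value()
        fmt.Println(err != nil) // true
    }
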
func (path Path) PathValue() (Path, error) { return path, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (path *Path) Scan(src any) error { if src == nil { *path = Path{} @@ -49,7 +51,7 @@ func (path *Path) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (path Path) Value() (driver.Value, error) { if !path.Valid { return nil, nil @@ -154,7 +156,6 @@ func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -194,7 +195,7 @@ func (scanPlanBinaryPathToPathScanner) Scan(src []byte, dst any) error { } points := make([]Vec2, pointCount) - for i := 0; i < len(points); i++ { + for i := range points { x := binary.BigEndian.Uint64(src[rp:]) rp += 8 y := binary.BigEndian.Uint64(src[rp:]) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go index 408295683..ae217dc31 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype.go @@ -26,7 +26,10 @@ const ( XIDOID = 28 CIDOID = 29 JSONOID = 114 + XMLOID = 142 + XMLArrayOID = 143 JSONArrayOID = 199 + XID8ArrayOID = 271 PointOID = 600 LsegOID = 601 PathOID = 602 @@ -115,6 +118,7 @@ const ( TstzmultirangeOID = 4534 DatemultirangeOID = 4535 Int8multirangeOID = 4536 + XID8OID = 5069 Int4multirangeArrayOID = 6150 NummultirangeArrayOID = 6151 TsmultirangeArrayOID = 6152 @@ -198,7 +202,6 @@ type Map struct { reflectTypeToType map[reflect.Type]*Type - memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan memoizedEncodePlans map[uint32]map[reflect.Type][2]EncodePlan // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every @@ -214,6 +217,15 @@ type Map struct { TryWrapScanPlanFuncs []TryWrapScanPlanFunc } +// Copy returns a new Map containing the same registered types. +func (m *Map) Copy() *Map { + newMap := NewMap() + for _, type_ := range m.oidToType { + newMap.RegisterType(type_) + } + return newMap +} + func NewMap() *Map { defaultMapInitOnce.Do(initDefaultMap) @@ -223,7 +235,6 @@ func NewMap() *Map { reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), - memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ @@ -248,6 +259,13 @@ func NewMap() *Map { } } +// RegisterTypes registers multiple data types in the sequence they are provided. +func (m *Map) RegisterTypes(types []*Type) { + for _, t := range types { + m.RegisterType(t) + } +} + // RegisterType registers a data type with the Map. t must not be mutated after it is registered. 
func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t @@ -256,9 +274,6 @@ func (m *Map) RegisterType(t *Type) { // Invalidated by type registration m.reflectTypeToType = nil - for k := range m.memoizedScanPlans { - delete(m.memoizedScanPlans, k) - } for k := range m.memoizedEncodePlans { delete(m.memoizedEncodePlans, k) } @@ -272,9 +287,6 @@ func (m *Map) RegisterDefaultPgType(value any, name string) { // Invalidated by type registration m.reflectTypeToType = nil - for k := range m.memoizedScanPlans { - delete(m.memoizedScanPlans, k) - } for k := range m.memoizedEncodePlans { delete(m.memoizedEncodePlans, k) } @@ -377,6 +389,7 @@ type scanPlanSQLScanner struct { func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error { scanner := dst.(sql.Scanner) + if src == nil { // This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the // text format path would be converted to empty string. @@ -431,14 +444,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { // As a horrible hack try all types to find anything that can scan into dst. for oid := range plan.m.oidToType { // using planScan instead of Scan or PlanScan to avoid polluting the planned scan cache. - plan := plan.m.planScan(oid, plan.formatCode, dst) + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) if _, ok := plan.(*scanPlanFail); !ok { return plan.Scan(src, dst) } } for oid := range defaultMap.oidToType { if _, ok := plan.m.oidToType[oid]; !ok { - plan := plan.m.planScan(oid, plan.formatCode, dst) + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) if _, ok := plan.(*scanPlanFail); !ok { return plan.Scan(src, dst) } @@ -555,17 +568,24 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex elemValue = dstValue.Elem() } nextDstType := elemKindToPointerTypes[elemValue.Kind()] - if nextDstType == nil && elemValue.Kind() == reflect.Slice { - if elemValue.Type().Elem().Kind() == reflect.Uint8 { - var v *[]byte - nextDstType = reflect.TypeOf(v) + if nextDstType == nil { + if elemValue.Kind() == reflect.Slice { + if elemValue.Type().Elem().Kind() == reflect.Uint8 { + var v *[]byte + nextDstType = reflect.TypeOf(v) + } + } + + // Get underlying type of any array. + // https://github.com/jackc/pgx/issues/2107 + if elemValue.Kind() == reflect.Array { + nextDstType = reflect.PointerTo(reflect.ArrayOf(elemValue.Len(), elemValue.Type().Elem())) } } if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) { return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } - } return nil, nil, false @@ -1039,24 +1059,14 @@ func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error { // PlanScan prepares a plan to scan a value into target. 
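
pgtype.go above gains Map.Copy and Map.RegisterTypes (and drops the scan-plan memoization). A minimal sketch of the two new helpers; the type name and OID are placeholders, not values from this diff:

    package example

    import "github.com/jackc/pgx/v5/pgtype"

    func buildCustomMap() *pgtype.Map {
        base := pgtype.NewMap()

        // Copy returns an independent Map with the same registrations.
        custom := base.Copy()
        custom.RegisterTypes([]*pgtype.Type{
            {Name: "my_text_domain", OID: 99999, Codec: pgtype.TextCodec{}},
        })
        return custom
    }
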
func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { - oidMemo := m.memoizedScanPlans[oid] - if oidMemo == nil { - oidMemo = make(map[reflect.Type][2]ScanPlan) - m.memoizedScanPlans[oid] = oidMemo - } - targetReflectType := reflect.TypeOf(target) - typeMemo := oidMemo[targetReflectType] - plan := typeMemo[formatCode] - if plan == nil { - plan = m.planScan(oid, formatCode, target) - typeMemo[formatCode] = plan - oidMemo[targetReflectType] = typeMemo - } - - return plan + return m.planScan(oid, formatCode, target, 0) } -func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { +func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + if target == nil { return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} } @@ -1116,7 +1126,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { for _, f := range m.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(target); ok { - if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { + if nextPlan := m.planScan(oid, formatCode, nextDst, depth+1); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan @@ -1125,10 +1135,18 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { } } - if dt != nil { - if _, ok := target.(*any); ok { - return &pointerEmptyInterfaceScanPlan{codec: dt.Codec, m: m, oid: oid, formatCode: formatCode} + if _, ok := target.(*any); ok { + var codec Codec + if dt != nil { + codec = dt.Codec + } else { + if formatCode == TextFormatCode { + codec = TextCodec{} + } else { + codec = ByteaCodec{} + } } + return &pointerEmptyInterfaceScanPlan{codec: codec, m: m, oid: oid, formatCode: formatCode} } return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} @@ -1173,9 +1191,18 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src } } -// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be +// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { + return m.planEncodeDepth(oid, format, value, 0) +} + +func (m *Map) planEncodeDepth(oid uint32, format int16, value any, depth int) EncodePlan { + // Guard against infinite recursion. 
+ if depth > 8 { + return nil + } + oidMemo := m.memoizedEncodePlans[oid] if oidMemo == nil { oidMemo = make(map[reflect.Type][2]EncodePlan) @@ -1185,7 +1212,7 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { typeMemo := oidMemo[targetReflectType] plan := typeMemo[format] if plan == nil { - plan = m.planEncode(oid, format, value) + plan = m.planEncode(oid, format, value, depth) typeMemo[format] = plan oidMemo[targetReflectType] = typeMemo } @@ -1193,7 +1220,7 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { return plan } -func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { +func (m *Map) planEncode(oid uint32, format int16, value any, depth int) EncodePlan { if format == TextFormatCode { switch value.(type) { case string: @@ -1224,7 +1251,7 @@ func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { for _, f := range m.TryWrapEncodePlanFuncs { if wrapperPlan, nextValue, ok := f(value); ok { - if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { + if nextPlan := m.planEncodeDepth(oid, format, nextValue, depth+1); nextPlan != nil { wrapperPlan.SetNext(nextPlan) return wrapperPlan } @@ -1405,6 +1432,15 @@ func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextS return &underlyingTypeEncodePlan{nextValueType: byteSliceType}, refValue.Convert(byteSliceType).Interface(), true } + // Get underlying type of any array. + // https://github.com/jackc/pgx/issues/2107 + if refValue.Kind() == reflect.Array { + underlyingArrayType := reflect.ArrayOf(refValue.Len(), refValue.Type().Elem()) + if refValue.Type() != underlyingArrayType { + return &underlyingTypeEncodePlan{nextValueType: underlyingArrayType}, refValue.Convert(underlyingArrayType).Interface(), true + } + } + return nil, nil, false } @@ -1971,37 +2007,18 @@ func (w *sqlScannerWrapper) Scan(src any) error { case []byte: bufSrc = src default: - bufSrc = []byte(fmt.Sprint(bufSrc)) + bufSrc = fmt.Append(nil, bufSrc) } } return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) } -// canBeNil returns true if value can be nil. -func canBeNil(value any) bool { - refVal := reflect.ValueOf(value) - kind := refVal.Kind() - switch kind { - case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: - return true - default: - return false - } -} - -// valuerReflectType is a reflect.Type for driver.Valuer. It has confusing syntax because reflect.TypeOf returns nil -// when it's argument is a nil interface value. So we use a pointer to the interface and call Elem to get the actual -// type. Yuck. -// -// This can be simplified in Go 1.22 with reflect.TypeFor. -// -// var valuerReflectType = reflect.TypeFor[driver.Valuer]() -var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var valuerReflectType = reflect.TypeFor[driver.Valuer]() // isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement // driver.Valuer if it is only implemented by T. 
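The reflect.Array branches introduced above (here on the encode side, and earlier on the scan side, both referencing issue 2107) let a defined array type fall back to its underlying array type. A hedged sketch of what that enables, assuming the underlying [16]byte is accepted by the registered uuid codec; traceID, the spans table, and the id column are invented for illustration:

import (
	"context"

	"github.com/jackc/pgx/v5"
)

type traceID [16]byte // defined type whose underlying type is [16]byte

// insertSpan relies on the new wrapping: traceID is converted to [16]byte
// before the existing uuid encode plan runs.
func insertSpan(ctx context.Context, conn *pgx.Conn, id traceID) error {
	_, err := conn.Exec(ctx, "insert into spans (id) values ($1)", id)
	return err
}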
-func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) { +func isNilDriverValuer(value any) (isNil, callNilDriverValuer bool) { if value == nil { return true, false } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go index 9525f37c9..5648d89bf 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/pgtype_default.go @@ -2,6 +2,7 @@ package pgtype import ( "encoding/json" + "encoding/xml" "net" "net/netip" "reflect" @@ -22,7 +23,6 @@ func initDefaultMap() { reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), - memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ @@ -89,6 +89,26 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}}) + defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{ + Marshal: xml.Marshal, + // xml.Unmarshal does not support unmarshalling into *any. However, XMLCodec.DecodeValue calls Unmarshal with a + // *any. Wrap xml.Marshal with a function that copies the data into a new byte slice in this case. Not implementing + // directly in XMLCodec.DecodeValue to allow for the unlikely possibility that someone uses an alternative XML + // unmarshaler that does support unmarshalling into *any. + // + // https://github.com/jackc/pgx/issues/2227 + // https://github.com/jackc/pgx/pull/2228 + Unmarshal: func(data []byte, v any) error { + if v, ok := v.(*any); ok { + dstBuf := make([]byte, len(data)) + copy(dstBuf, data) + *v = dstBuf + return nil + } + return xml.Unmarshal(data, v) + }, + }}) // Range types defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) @@ -153,6 +173,8 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_xid8", OID: XID8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XID8OID]}}) + defaultMap.RegisterType(&Type{Name: "_xml", OID: XMLArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XMLOID]}}) // Integer types that directly map to a PostgreSQL type registerDefaultPgTypeVariants[int16](defaultMap, "int2") diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/point.go b/vendor/github.com/jackc/pgx/v5/pgtype/point.go index 09b19bb53..b701513dc 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/point.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/point.go @@ -30,11 +30,13 @@ type Point struct { Valid bool } +// ScanPoint implements the [PointScanner] interface. func (p *Point) ScanPoint(v Point) error { *p = v return nil } +// PointValue implements the [PointValuer] interface. 
func (p Point) PointValue() (Point, error) { return p, nil } @@ -68,7 +70,7 @@ func parsePoint(src []byte) (*Point, error) { return &Point{P: Vec2{x, y}, Valid: true}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Point) Scan(src any) error { if src == nil { *dst = Point{} @@ -83,7 +85,7 @@ func (dst *Point) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Point) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +98,7 @@ func (src Point) Value() (driver.Value, error) { return string(buf), err } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Point) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -108,6 +111,7 @@ func (src Point) MarshalJSON() ([]byte, error) { return buff.Bytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Point) UnmarshalJSON(point []byte) error { p, err := parsePoint(point) if err != nil { @@ -178,7 +182,6 @@ func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, er } func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go b/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go index 04b0ba6b0..e18c9da63 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/polygon.go @@ -24,16 +24,18 @@ type Polygon struct { Valid bool } +// ScanPolygon implements the [PolygonScanner] interface. func (p *Polygon) ScanPolygon(v Polygon) error { *p = v return nil } +// PolygonValue implements the [PolygonValuer] interface. func (p Polygon) PolygonValue() (Polygon, error) { return p, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (p *Polygon) Scan(src any) error { if src == nil { *p = Polygon{} @@ -48,7 +50,7 @@ func (p *Polygon) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (p Polygon) Value() (driver.Value, error) { if !p.Valid { return nil, nil @@ -139,7 +141,6 @@ func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte, } func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -177,7 +178,7 @@ func (scanPlanBinaryPolygonToPolygonScanner) Scan(src []byte, dst any) error { } points := make([]Vec2, pointCount) - for i := 0; i < len(points); i++ { + for i := range points { x := binary.BigEndian.Uint64(src[rp:]) rp += 8 y := binary.BigEndian.Uint64(src[rp:]) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/range.go b/vendor/github.com/jackc/pgx/v5/pgtype/range.go index 16427cccd..62d699905 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/range.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/range.go @@ -191,11 +191,13 @@ type untypedBinaryRange struct { // 18 = [ = 10010 // 24 = = 11000 -const emptyMask = 1 -const lowerInclusiveMask = 2 -const upperInclusiveMask = 4 -const lowerUnboundedMask = 8 -const upperUnboundedMask = 16 +const ( + emptyMask = 1 + lowerInclusiveMask = 2 + upperInclusiveMask = 4 + lowerUnboundedMask = 8 + upperUnboundedMask = 16 +) func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { ubr := &untypedBinaryRange{} @@ -273,7 +275,6 @@ func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { } return ubr, nil - } // Range is a generic range type. diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/record_codec.go b/vendor/github.com/jackc/pgx/v5/pgtype/record_codec.go index b3b166045..90b9bd4bb 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/record_codec.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/record_codec.go @@ -121,5 +121,4 @@ func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an default: return nil, fmt.Errorf("unknown format code %d", format) } - } diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/text.go b/vendor/github.com/jackc/pgx/v5/pgtype/text.go index 021ee331b..e08b12549 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/text.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/text.go @@ -19,16 +19,18 @@ type Text struct { Valid bool } +// ScanText implements the [TextScanner] interface. func (t *Text) ScanText(v Text) error { *t = v return nil } +// TextValue implements the [TextValuer] interface. func (t Text) TextValue() (Text, error) { return t, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Text) Scan(src any) error { if src == nil { *dst = Text{} @@ -47,7 +49,7 @@ func (dst *Text) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Text) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -55,6 +57,7 @@ func (src Text) Value() (driver.Value, error) { return src.String, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Text) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -63,6 +66,7 @@ func (src Text) MarshalJSON() ([]byte, error) { return json.Marshal(src.String) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. 
func (dst *Text) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -146,7 +150,6 @@ func (encodePlanTextCodecTextValuer) Encode(value any, buf []byte) (newBuf []byt } func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/tid.go b/vendor/github.com/jackc/pgx/v5/pgtype/tid.go index 9bc2c2a14..05c9e6d98 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/tid.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/tid.go @@ -35,16 +35,18 @@ type TID struct { Valid bool } +// ScanTID implements the [TIDScanner] interface. func (b *TID) ScanTID(v TID) error { *b = v return nil } +// TIDValue implements the [TIDValuer] interface. func (b TID) TIDValue() (TID, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *TID) Scan(src any) error { if src == nil { *dst = TID{} @@ -59,7 +61,7 @@ func (dst *TID) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src TID) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -131,7 +133,6 @@ func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/time.go b/vendor/github.com/jackc/pgx/v5/pgtype/time.go index 61a3abdfd..4b8f69083 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/time.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/time.go @@ -19,24 +19,28 @@ type TimeValuer interface { // Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. // -// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time -// and date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due -// to needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time and +// date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due to +// needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +// +// The time with time zone type is not supported. Use of time with time zone is discouraged by the PostgreSQL documentation. type Time struct { Microseconds int64 // Number of microseconds since midnight Valid bool } +// ScanTime implements the [TimeScanner] interface. func (t *Time) ScanTime(v Time) error { *t = v return nil } +// TimeValue implements the [TimeValuer] interface. func (t Time) TimeValue() (Time, error) { return t, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (t *Time) Scan(src any) error { if src == nil { *t = Time{} @@ -56,7 +60,7 @@ func (t *Time) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (t Time) Value() (driver.Value, error) { if !t.Valid { return nil, nil @@ -135,7 +139,6 @@ func (encodePlanTimeCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go index 677a2c6ea..e94ce6bcd 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamp.go @@ -11,7 +11,10 @@ import ( "github.com/jackc/pgx/v5/internal/pgio" ) -const pgTimestampFormat = "2006-01-02 15:04:05.999999999" +const ( + pgTimestampFormat = "2006-01-02 15:04:05.999999999" + jsonISO8601 = "2006-01-02T15:04:05.999999999" +) type TimestampScanner interface { ScanTimestamp(v Timestamp) error @@ -28,16 +31,18 @@ type Timestamp struct { Valid bool } +// ScanTimestamp implements the [TimestampScanner] interface. func (ts *Timestamp) ScanTimestamp(v Timestamp) error { *ts = v return nil } +// TimestampValue implements the [TimestampValuer] interface. func (ts Timestamp) TimestampValue() (Timestamp, error) { return ts, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (ts *Timestamp) Scan(src any) error { if src == nil { *ts = Timestamp{} @@ -55,7 +60,7 @@ func (ts *Timestamp) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (ts Timestamp) Value() (driver.Value, error) { if !ts.Valid { return nil, nil @@ -67,6 +72,7 @@ func (ts Timestamp) Value() (driver.Value, error) { return ts.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (ts Timestamp) MarshalJSON() ([]byte, error) { if !ts.Valid { return []byte("null"), nil @@ -76,7 +82,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) { switch ts.InfinityModifier { case Finite: - s = ts.Time.Format(time.RFC3339Nano) + s = ts.Time.Format(jsonISO8601) case Infinity: s = "infinity" case NegativeInfinity: @@ -86,6 +92,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. 
func (ts *Timestamp) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -104,15 +111,23 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error { case "-infinity": *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: - // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz - tim, err := time.Parse(time.RFC3339Nano, *s) - if err != nil { - return err + // Parse time with or without timezone + tss := *s + // PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestamptz + tim, err := time.Parse(time.RFC3339Nano, tss) + if err == nil { + *ts = Timestamp{Time: tim, Valid: true} + return nil } - - *ts = Timestamp{Time: tim, Valid: true} + tim, err = time.ParseInLocation(jsonISO8601, tss, time.UTC) + if err == nil { + *ts = Timestamp{Time: tim, Valid: true} + return nil + } + ts.Valid = false + return fmt.Errorf("cannot unmarshal %s to timestamp with layout %s or %s (%w)", + *s, time.RFC3339Nano, jsonISO8601, err) } - return nil } @@ -161,7 +176,7 @@ func (encodePlanTimestampCodecBinary) Encode(value any, buf []byte) (newBuf []by switch ts.InfinityModifier { case Finite: t := discardTimeZone(ts.Time) - microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + microsecSinceUnixEpoch := t.Unix()*1_000_000 + int64(t.Nanosecond())/1000 microsecSinceY2K = microsecSinceUnixEpoch - microsecFromUnixEpochToY2K case Infinity: microsecSinceY2K = infinityMicrosecondOffset @@ -225,7 +240,6 @@ func discardTimeZone(t time.Time) time.Time { } func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -265,8 +279,8 @@ func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: tim := time.Unix( - microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, - (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + microsecFromUnixEpochToY2K/1_000_000+microsecSinceY2K/1_000_000, + (microsecFromUnixEpochToY2K%1_000_000*1_000)+(microsecSinceY2K%1_000_000*1000), ).UTC() if plan.location != nil { tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go b/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go index 7efbcffd2..4d055bfab 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/timestamptz.go @@ -11,10 +11,12 @@ import ( "github.com/jackc/pgx/v5/internal/pgio" ) -const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" -const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" -const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" -const microsecFromUnixEpochToY2K = 946684800 * 1000000 +const ( + pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" + pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" + pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" + microsecFromUnixEpochToY2K = 946_684_800 * 1_000_000 +) const ( negativeInfinityMicrosecondOffset = -9223372036854775808 @@ -36,16 +38,18 @@ type Timestamptz struct { Valid bool } +// ScanTimestamptz implements the [TimestamptzScanner] interface.
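A small sketch of the round trip implied by the Timestamp JSON changes above: MarshalJSON now writes the timezone-less jsonISO8601 layout, and UnmarshalJSON accepts either that layout or RFC 3339. The date value and function name are illustrative only:

import (
	"encoding/json"
	"time"

	"github.com/jackc/pgx/v5/pgtype"
)

func timestampJSONRoundTrip() (pgtype.Timestamp, error) {
	ts := pgtype.Timestamp{Time: time.Date(2024, 6, 1, 12, 0, 0, 0, time.UTC), Valid: true}
	b, err := json.Marshal(ts) // produces "2024-06-01T12:00:00" with the new layout
	if err != nil {
		return pgtype.Timestamp{}, err
	}
	var out pgtype.Timestamp
	// RFC 3339 input such as "2024-06-01T12:00:00Z" is still accepted.
	err = json.Unmarshal(b, &out)
	return out, err
}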
func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error { *tstz = v return nil } +// TimestamptzValue implements the [TimestamptzValuer] interface. func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) { return tstz, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (tstz *Timestamptz) Scan(src any) error { if src == nil { *tstz = Timestamptz{} @@ -63,7 +67,7 @@ func (tstz *Timestamptz) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (tstz Timestamptz) Value() (driver.Value, error) { if !tstz.Valid { return nil, nil @@ -75,6 +79,7 @@ func (tstz Timestamptz) Value() (driver.Value, error) { return tstz.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (tstz Timestamptz) MarshalJSON() ([]byte, error) { if !tstz.Valid { return []byte("null"), nil @@ -94,6 +99,7 @@ func (tstz Timestamptz) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -225,7 +231,6 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by } func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -265,8 +270,8 @@ func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst tstz = Timestamptz{Valid: true, InfinityModifier: -Infinity} default: tim := time.Unix( - microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, - (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), + microsecFromUnixEpochToY2K/1_000_000+microsecSinceY2K/1_000_000, + (microsecFromUnixEpochToY2K%1_000_000*1_000)+(microsecSinceY2K%1_000_000*1_000), ) if plan.location != nil { tim = tim.In(plan.location) diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/uint32.go b/vendor/github.com/jackc/pgx/v5/pgtype/uint32.go index 098c516c1..e6d4b1cf6 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/uint32.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/uint32.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "strconv" @@ -24,16 +25,18 @@ type Uint32 struct { Valid bool } +// ScanUint32 implements the [Uint32Scanner] interface. func (n *Uint32) ScanUint32(v Uint32) error { *n = v return nil } +// Uint32Value implements the [Uint32Valuer] interface. func (n Uint32) Uint32Value() (Uint32, error) { return n, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Uint32) Scan(src any) error { if src == nil { *dst = Uint32{} @@ -67,7 +70,7 @@ func (dst *Uint32) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Uint32) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -75,6 +78,31 @@ func (src Uint32) Value() (driver.Value, error) { return int64(src.Uint32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. 
+func (src Uint32) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return json.Marshal(src.Uint32) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Uint32) UnmarshalJSON(b []byte) error { + var n *uint32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Uint32{} + } else { + *dst = Uint32{Uint32: *n, Valid: true} + } + + return nil +} + type Uint32Codec struct{} func (Uint32Codec) FormatSupported(format int16) bool { @@ -197,7 +225,6 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value any, buf []byte) (newBu } func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -205,6 +232,8 @@ func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPl return scanPlanBinaryUint32ToUint32{} case Uint32Scanner: return scanPlanBinaryUint32ToUint32Scanner{} + case TextScanner: + return scanPlanBinaryUint32ToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -282,6 +311,26 @@ func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst any) error { return s.ScanUint32(Uint32{Uint32: n, Valid: true}) } +type scanPlanBinaryUint32ToTextScanner struct{} + +func (scanPlanBinaryUint32ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 4 { + return fmt.Errorf("invalid length for uint32: %v", len(src)) + } + + n := uint64(binary.BigEndian.Uint32(src)) + return s.ScanText(Text{String: strconv.FormatUint(n, 10), Valid: true}) +} + type scanPlanTextAnyToUint32Scanner struct{} func (scanPlanTextAnyToUint32Scanner) Scan(src []byte, dst any) error { diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/uint64.go b/vendor/github.com/jackc/pgx/v5/pgtype/uint64.go new file mode 100644 index 000000000..68fd16613 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/uint64.go @@ -0,0 +1,323 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type Uint64Scanner interface { + ScanUint64(v Uint64) error +} + +type Uint64Valuer interface { + Uint64Value() (Uint64, error) +} + +// Uint64 is the core type that is used to represent PostgreSQL types such as XID8. +type Uint64 struct { + Uint64 uint64 + Valid bool +} + +// ScanUint64 implements the [Uint64Scanner] interface. +func (n *Uint64) ScanUint64(v Uint64) error { + *n = v + return nil +} + +// Uint64Value implements the [Uint64Valuer] interface. +func (n Uint64) Uint64Value() (Uint64, error) { + return n, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Uint64) Scan(src any) error { + if src == nil { + *dst = Uint64{} + return nil + } + + var n uint64 + + switch src := src.(type) { + case int64: + if src < 0 { + return fmt.Errorf("%d is less than the minimum value for Uint64", src) + } + n = uint64(src) + case string: + un, err := strconv.ParseUint(src, 10, 64) + if err != nil { + return err + } + n = un + default: + return fmt.Errorf("cannot scan %T", src) + } + + *dst = Uint64{Uint64: n, Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. 
+func (src Uint64) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + // If the value is greater than the maximum value for int64, return it as a string instead of losing data or returning + // an error. + if src.Uint64 > math.MaxInt64 { + return strconv.FormatUint(src.Uint64, 10), nil + } + + return int64(src.Uint64), nil +} + +type Uint64Codec struct{} + +func (Uint64Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Uint64Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Uint64Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case uint64: + return encodePlanUint64CodecBinaryUint64{} + case Uint64Valuer: + return encodePlanUint64CodecBinaryUint64Valuer{} + case Int64Valuer: + return encodePlanUint64CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case uint64: + return encodePlanUint64CodecTextUint64{} + case Int64Valuer: + return encodePlanUint64CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanUint64CodecBinaryUint64 struct{} + +func (encodePlanUint64CodecBinaryUint64) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint64) + return pgio.AppendUint64(buf, v), nil +} + +type encodePlanUint64CodecBinaryUint64Valuer struct{} + +func (encodePlanUint64CodecBinaryUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint64Valuer).Uint64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return pgio.AppendUint64(buf, v.Uint64), nil +} + +type encodePlanUint64CodecBinaryInt64Valuer struct{} + +func (encodePlanUint64CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + return pgio.AppendUint64(buf, uint64(v.Int64)), nil +} + +type encodePlanUint64CodecTextUint64 struct{} + +func (encodePlanUint64CodecTextUint64) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint64) + return append(buf, strconv.FormatUint(uint64(v), 10)...), nil +} + +type encodePlanUint64CodecTextUint64Valuer struct{} + +func (encodePlanUint64CodecTextUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint64Valuer).Uint64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(v.Uint64, 10)...), nil +} + +type encodePlanUint64CodecTextInt64Valuer struct{} + +func (encodePlanUint64CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + return append(buf, strconv.FormatInt(v.Int64, 10)...), nil +} + +func (Uint64Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *uint64: + return scanPlanBinaryUint64ToUint64{} + case Uint64Scanner: + return scanPlanBinaryUint64ToUint64Scanner{} + case TextScanner: + return 
scanPlanBinaryUint64ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *uint64: + return scanPlanTextAnyToUint64{} + case Uint64Scanner: + return scanPlanTextAnyToUint64Scanner{} + } + } + + return nil +} + +func (c Uint64Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n uint64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return int64(n), nil +} + +func (c Uint64Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n uint64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryUint64ToUint64 struct{} + +func (scanPlanBinaryUint64ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + p := (dst).(*uint64) + *p = binary.BigEndian.Uint64(src) + + return nil +} + +type scanPlanBinaryUint64ToUint64Scanner struct{} + +func (scanPlanBinaryUint64ToUint64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint64(Uint64{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + n := binary.BigEndian.Uint64(src) + + return s.ScanUint64(Uint64{Uint64: n, Valid: true}) +} + +type scanPlanBinaryUint64ToTextScanner struct{} + +func (scanPlanBinaryUint64ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + n := uint64(binary.BigEndian.Uint64(src)) + return s.ScanText(Text{String: strconv.FormatUint(n, 10), Valid: true}) +} + +type scanPlanTextAnyToUint64Scanner struct{} + +func (scanPlanTextAnyToUint64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint64(Uint64{}) + } + + n, err := strconv.ParseUint(string(src), 10, 64) + if err != nil { + return err + } + + return s.ScanUint64(Uint64{Uint64: n, Valid: true}) +} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go b/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go index d57c0f2fa..83d0c4127 100644 --- a/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go +++ b/vendor/github.com/jackc/pgx/v5/pgtype/uuid.go @@ -20,11 +20,13 @@ type UUID struct { Valid bool } +// ScanUUID implements the [UUIDScanner] interface. func (b *UUID) ScanUUID(v UUID) error { *b = v return nil } +// UUIDValue implements the [UUIDValuer] interface. func (b UUID) UUIDValue() (UUID, error) { return b, nil } @@ -67,7 +69,7 @@ func encodeUUID(src [16]byte) string { return string(buf[:]) } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *UUID) Scan(src any) error { if src == nil { *dst = UUID{} @@ -87,7 +89,7 @@ func (dst *UUID) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
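Putting the new Uint64 type and codec to work: a hedged sketch of reading an xid8 value (PostgreSQL 13 or later), relying on the xid8 registration with Uint64Codec shown earlier in this diff; currentXact is an invented helper name:

import (
	"context"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

// currentXact scans an xid8 result. *pgtype.Uint64 satisfies Uint64Scanner,
// so the binary scan plan above decodes it directly.
func currentXact(ctx context.Context, conn *pgx.Conn) (uint64, error) {
	var xid pgtype.Uint64
	err := conn.QueryRow(ctx, "select pg_current_xact_id()").Scan(&xid)
	return xid.Uint64, err
}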
func (src UUID) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +98,15 @@ func (src UUID) Value() (driver.Value, error) { return encodeUUID(src.Bytes), nil } +func (src UUID) String() string { + if !src.Valid { + return "" + } + + return encodeUUID(src.Bytes) +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src UUID) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -108,6 +119,7 @@ func (src UUID) MarshalJSON() ([]byte, error) { return buff.Bytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *UUID) UnmarshalJSON(src []byte) error { if bytes.Equal(src, []byte("null")) { *dst = UUID{} diff --git a/vendor/github.com/jackc/pgx/v5/pgtype/xml.go b/vendor/github.com/jackc/pgx/v5/pgtype/xml.go new file mode 100644 index 000000000..79e3698a4 --- /dev/null +++ b/vendor/github.com/jackc/pgx/v5/pgtype/xml.go @@ -0,0 +1,198 @@ +package pgtype + +import ( + "database/sql" + "database/sql/driver" + "encoding/xml" + "fmt" + "reflect" +) + +type XMLCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} + +func (*XMLCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (*XMLCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (c *XMLCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch value.(type) { + case string: + return encodePlanXMLCodecEitherFormatString{} + case []byte: + return encodePlanXMLCodecEitherFormatByteSlice{} + + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. + // + // https://github.com/jackc/pgx/issues/1430 + // + // Check for driver.Valuer must come before xml.Marshaler so that it is guaranteed to be used + // when both are implemented https://github.com/jackc/pgx/issues/1805 + case driver.Valuer: + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + + // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be + // marshalled. + // + // https://github.com/jackc/pgx/issues/1681 + case xml.Marshaler: + return &encodePlanXMLCodecEitherFormatMarshal{ + marshal: c.Marshal, + } + } + + // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the + // appropriate wrappers here. + for _, f := range []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + } { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + return &encodePlanXMLCodecEitherFormatMarshal{ + marshal: c.Marshal, + } +} + +type encodePlanXMLCodecEitherFormatString struct{} + +func (encodePlanXMLCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlString := value.(string) + buf = append(buf, xmlString...) + return buf, nil +} + +type encodePlanXMLCodecEitherFormatByteSlice struct{} + +func (encodePlanXMLCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlBytes := value.([]byte) + if xmlBytes == nil { + return nil, nil + } + + buf = append(buf, xmlBytes...) 
+ return buf, nil +} + +type encodePlanXMLCodecEitherFormatMarshal struct { + marshal func(v any) ([]byte, error) +} + +func (e *encodePlanXMLCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlBytes, err := e.marshal(value) + if err != nil { + return nil, err + } + + buf = append(buf, xmlBytes...) + return buf, nil +} + +func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch target.(type) { + case *string: + return scanPlanAnyToString{} + + case **string: + // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better + // solution would be. + // + // https://github.com/jackc/pgx/issues/1470 -- **string + // https://github.com/jackc/pgx/issues/1691 -- ** anything else + + if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { + if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + case *[]byte: + return scanPlanXMLToByteSlice{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} + + // Cannot rely on sql.Scanner being handled later because scanPlanXMLToXMLUnmarshal will take precedence. + // + // https://github.com/jackc/pgx/issues/1418 + case sql.Scanner: + return &scanPlanSQLScanner{formatCode: format} + } + + return &scanPlanXMLToXMLUnmarshal{ + unmarshal: c.Unmarshal, + } +} + +type scanPlanXMLToByteSlice struct{} + +func (scanPlanXMLToByteSlice) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanXMLToXMLUnmarshal struct { + unmarshal func(data []byte, v any) error +} + +func (s *scanPlanXMLToXMLUnmarshal) Scan(src []byte, dst any) error { + if src == nil { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() == reflect.Ptr { + el := dstValue.Elem() + switch el.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface, reflect.Struct: + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + elem := reflect.ValueOf(dst).Elem() + elem.Set(reflect.Zero(elem.Type())) + + return s.unmarshal(src, dst) +} + +func (c *XMLCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil +} + +func (c *XMLCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var dst any + err := c.Unmarshal(src, &dst) + return dst, err +} diff --git a/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go b/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go index fdcba7241..df97bdbd7 100644 --- a/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go +++ b/vendor/github.com/jackc/pgx/v5/pgxpool/pool.go @@ -2,8 +2,8 @@ package pgxpool import ( "context" - "fmt" - "math/rand" + "errors" + "math/rand/v2" "runtime" "strconv" "sync" @@ -15,11 +15,14 @@ import ( "github.com/jackc/puddle/v2" ) -var defaultMaxConns = int32(4) -var defaultMinConns = int32(0) -var defaultMaxConnLifetime = time.Hour -var defaultMaxConnIdleTime = time.Minute * 30 -var defaultHealthCheckPeriod = time.Minute +var ( + defaultMaxConns = int32(4) + defaultMinConns = int32(0) + defaultMinIdleConns = int32(0) + 
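Before continuing with the pgxpool changes, a brief sketch of the XMLCodec introduced above: for a struct target the scan plan falls back to the codec's Unmarshal (xml.Unmarshal by default). The widget type, widgets table, and doc column are assumptions for illustration:

import (
	"context"
	"encoding/xml"

	"github.com/jackc/pgx/v5"
)

type widget struct {
	XMLName xml.Name `xml:"widget"`
	Name    string   `xml:"name"`
}

func loadWidget(ctx context.Context, conn *pgx.Conn) (widget, error) {
	var w widget
	// The xml column is decoded by scanPlanXMLToXMLUnmarshal, which
	// unmarshals the raw document into w.
	err := conn.QueryRow(ctx, "select doc from widgets limit 1").Scan(&w)
	return w, err
}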
defaultMaxConnLifetime = time.Hour + defaultMaxConnIdleTime = time.Minute * 30 + defaultHealthCheckPeriod = time.Minute +) type connResource struct { conn *pgx.Conn @@ -83,15 +86,21 @@ type Pool struct { config *Config beforeConnect func(context.Context, *pgx.ConnConfig) error afterConnect func(context.Context, *pgx.Conn) error - beforeAcquire func(context.Context, *pgx.Conn) bool + prepareConn func(context.Context, *pgx.Conn) (bool, error) afterRelease func(*pgx.Conn) bool beforeClose func(*pgx.Conn) + shouldPing func(context.Context, ShouldPingParams) bool minConns int32 + minIdleConns int32 maxConns int32 maxConnLifetime time.Duration maxConnLifetimeJitter time.Duration maxConnIdleTime time.Duration healthCheckPeriod time.Duration + pingTimeout time.Duration + + healthCheckMu sync.Mutex + healthCheckTimer *time.Timer healthCheckChan chan struct{} @@ -102,6 +111,12 @@ type Pool struct { closeChan chan struct{} } +// ShouldPingParams are the parameters passed to ShouldPing. +type ShouldPingParams struct { + Conn *pgx.Conn + IdleDuration time.Duration +} + // Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be // modified. type Config struct { @@ -117,8 +132,23 @@ type Config struct { // BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the // acquisition or false to indicate that the connection should be destroyed and a different connection should be // acquired. + // + // Deprecated: Use PrepareConn instead. If both PrepareConn and BeforeAcquire are set, PrepareConn will take + // precedence, ignoring BeforeAcquire. BeforeAcquire func(context.Context, *pgx.Conn) bool + // PrepareConn is called before a connection is acquired from the pool. If this function returns true, the connection + // is considered valid, otherwise the connection is destroyed. If the function returns a non-nil error, the instigating + // query will fail with the returned error. + // + // Specifically, this means that: + // + // - If it returns true and a nil error, the query proceeds as normal. + // - If it returns true and an error, the connection will be returned to the pool, and the instigating query will fail with the returned error. + // - If it returns false, and an error, the connection will be destroyed, and the query will fail with the returned error. + // - If it returns false and a nil error, the connection will be destroyed, and the instigating query will be retried on a new connection. + PrepareConn func(context.Context, *pgx.Conn) (bool, error) + // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to // return the connection to the pool or false to destroy the connection. AfterRelease func(*pgx.Conn) bool @@ -126,6 +156,10 @@ type Config struct { // BeforeClose is called right before a connection is closed and removed from the pool. BeforeClose func(*pgx.Conn) + // ShouldPing is called after a connection is acquired from the pool. If it returns true, the connection is pinged to check for liveness. + // If this func is not set, the default behavior is to ping connections that have been idle for at least 1 second. + ShouldPing func(context.Context, ShouldPingParams) bool + // MaxConnLifetime is the duration since creation after which a connection will be automatically closed. 
MaxConnLifetime time.Duration @@ -136,6 +170,10 @@ type Config struct { // MaxConnIdleTime is the duration after which an idle connection will be automatically closed by the health check. MaxConnIdleTime time.Duration + // PingTimeout is the maximum amount of time to wait for a connection to pong before considering it as unhealthy and + // destroying it. If zero, the default is no timeout. + PingTimeout time.Duration + // MaxConns is the maximum size of the pool. The default is the greater of 4 or runtime.NumCPU(). MaxConns int32 @@ -144,6 +182,13 @@ type Config struct { // to create new connections. MinConns int32 + // MinIdleConns is the minimum number of idle connections in the pool. You can increase this to ensure that + // there are always idle connections available. This can help reduce tail latencies during request processing, + // as you can avoid the latency of establishing a new connection while handling requests. It is superior + // to MinConns for this purpose. + // Similar to MinConns, the pool might temporarily dip below MinIdleConns after connection closes. + MinIdleConns int32 + // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration @@ -181,18 +226,27 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { panic("config must be created by ParseConfig") } + prepareConn := config.PrepareConn + if prepareConn == nil && config.BeforeAcquire != nil { + prepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) { + return config.BeforeAcquire(ctx, conn), nil + } + } + p := &Pool{ config: config, beforeConnect: config.BeforeConnect, afterConnect: config.AfterConnect, - beforeAcquire: config.BeforeAcquire, + prepareConn: prepareConn, afterRelease: config.AfterRelease, beforeClose: config.BeforeClose, minConns: config.MinConns, + minIdleConns: config.MinIdleConns, maxConns: config.MaxConns, maxConnLifetime: config.MaxConnLifetime, maxConnLifetimeJitter: config.MaxConnLifetimeJitter, maxConnIdleTime: config.MaxConnIdleTime, + pingTimeout: config.PingTimeout, healthCheckPeriod: config.HealthCheckPeriod, healthCheckChan: make(chan struct{}, 1), closeChan: make(chan struct{}), @@ -206,6 +260,14 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { p.releaseTracer = t } + if config.ShouldPing != nil { + p.shouldPing = config.ShouldPing + } else { + p.shouldPing = func(ctx context.Context, params ShouldPingParams) bool { + return params.IdleDuration > time.Second + } + } + var err error p.p, err = puddle.NewPool( &puddle.Config[*connResource]{ @@ -271,7 +333,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { } go func() { - p.createIdleResources(ctx, int(p.minConns)) + targetIdleResources := max(int(p.minConns), int(p.minIdleConns)) + p.createIdleResources(ctx, targetIdleResources) p.backgroundHealthCheck() }() @@ -281,20 +344,20 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // ParseConfig builds a Config from connString. 
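A hedged sketch tying together the new pool knobs documented above (PrepareConn, ShouldPing, MinIdleConns, PingTimeout); the particular thresholds and the IsClosed check are illustrative choices, not recommendations from this change. MinIdleConns can equivalently be supplied through the pool_min_idle_conns connection-string parameter parsed further down:

import (
	"context"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

func newPool(ctx context.Context, connString string) (*pgxpool.Pool, error) {
	cfg, err := pgxpool.ParseConfig(connString)
	if err != nil {
		return nil, err
	}
	cfg.MinIdleConns = 2              // keep a couple of warm connections ready
	cfg.PingTimeout = 2 * time.Second // bound the liveness ping during Acquire
	// PrepareConn supersedes BeforeAcquire: (false, nil) destroys the
	// connection and retries on another; a non-nil error reaches the caller.
	cfg.PrepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) {
		return !conn.IsClosed(), nil
	}
	// Only ping connections that sat idle for more than five seconds.
	cfg.ShouldPing = func(ctx context.Context, p pgxpool.ShouldPingParams) bool {
		return p.IdleDuration > 5*time.Second
	}
	return pgxpool.NewWithConfig(ctx, cfg)
}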
It parses connString with the same behavior as [pgx.ParseConfig] with the // addition of the following variables: // -// - pool_max_conns: integer greater than 0 -// - pool_min_conns: integer 0 or greater -// - pool_max_conn_lifetime: duration string -// - pool_max_conn_idle_time: duration string -// - pool_health_check_period: duration string -// - pool_max_conn_lifetime_jitter: duration string +// - pool_max_conns: integer greater than 0 (default 4) +// - pool_min_conns: integer 0 or greater (default 0) +// - pool_max_conn_lifetime: duration string (default 1 hour) +// - pool_max_conn_idle_time: duration string (default 30 minutes) +// - pool_health_check_period: duration string (default 1 minute) +// - pool_max_conn_lifetime_jitter: duration string (default 0) // // See Config for definitions of these arguments. // // # Example Keyword/Value -// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 pool_max_conn_lifetime=1h30m // // # Example URL -// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10 +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10&pool_max_conn_lifetime=1h30m func ParseConfig(connString string) (*Config, error) { connConfig, err := pgx.ParseConfig(connString) if err != nil { @@ -310,10 +373,10 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conns") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conns", err) } if n < 1 { - return nil, fmt.Errorf("pool_max_conns too small: %d", n) + return nil, pgconn.NewParseConfigError(connString, "pool_max_conns too small", err) } config.MaxConns = int32(n) } else { @@ -327,18 +390,29 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_min_conns") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_conns", err) } config.MinConns = int32(n) } else { config.MinConns = defaultMinConns } + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_idle_conns"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_idle_conns", err) + } + config.MinIdleConns = int32(n) + } else { + config.MinIdleConns = defaultMinIdleConns + } + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok { delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime", err) } config.MaxConnLifetime = d } else { @@ -349,7 +423,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_idle_time", 
err) } config.MaxConnIdleTime = d } else { @@ -360,7 +434,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_health_check_period") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_health_check_period: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_health_check_period", err) } config.HealthCheckPeriod = d } else { @@ -371,7 +445,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime_jitter", err) } config.MaxConnLifetimeJitter = d } @@ -393,15 +467,25 @@ func (p *Pool) isExpired(res *puddle.Resource[*connResource]) bool { } func (p *Pool) triggerHealthCheck() { - go func() { + const healthCheckDelay = 500 * time.Millisecond + + p.healthCheckMu.Lock() + defer p.healthCheckMu.Unlock() + + if p.healthCheckTimer == nil { // Destroy is asynchronous so we give it time to actually remove itself from // the pool otherwise we might try to check the pool size too soon - time.Sleep(500 * time.Millisecond) - select { - case p.healthCheckChan <- struct{}{}: - default: - } - }() + p.healthCheckTimer = time.AfterFunc(healthCheckDelay, func() { + select { + case <-p.closeChan: + case p.healthCheckChan <- struct{}{}: + default: + } + }) + return + } + + p.healthCheckTimer.Reset(healthCheckDelay) } func (p *Pool) backgroundHealthCheck() { @@ -472,7 +556,9 @@ func (p *Pool) checkMinConns() error { // TotalConns can include ones that are being destroyed but we should have // sleep(500ms) around all of the destroys to help prevent that from throwing // off this check - toCreate := p.minConns - p.Stat().TotalConns() + + // Create the number of connections needed to get to both minConns and minIdleConns + toCreate := max(p.minConns-p.Stat().TotalConns(), p.minIdleConns-p.Stat().IdleConns()) if toCreate > 0 { return p.createIdleResources(context.Background(), int(toCreate)) } @@ -485,7 +571,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in errs := make(chan error, targetResources) - for i := 0; i < targetResources; i++ { + for range targetResources { go func() { err := p.p.CreateResource(ctx) // Ignore ErrNotAvailable since it means that the pool has become full since we started creating resource. @@ -497,7 +583,7 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in } var firstError error - for i := 0; i < targetResources; i++ { + for range targetResources { err := <-errs if err != nil && firstError == nil { cancel() @@ -521,7 +607,10 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { }() } - for { + // Try to acquire from the connection pool up to maxConns + 1 times, so that + // even if fatal errors empty the pool we still make at least one attempt on + // a fresh connection.
+ for range int(p.maxConns) + 1 { res, err := p.p.Acquire(ctx) if err != nil { return nil, err @@ -529,20 +618,41 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { cr := res.Value() - if res.IdleDuration() > time.Second { - err := cr.conn.Ping(ctx) + shouldPingParams := ShouldPingParams{Conn: cr.conn, IdleDuration: res.IdleDuration()} + if p.shouldPing(ctx, shouldPingParams) { + pingCtx := ctx + if p.pingTimeout > 0 { + var cancel context.CancelFunc + pingCtx, cancel = context.WithTimeout(ctx, p.pingTimeout) + defer cancel() + } + + err := cr.conn.Ping(pingCtx) if err != nil { res.Destroy() continue } } - if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { - return cr.getConn(p, res), nil + if p.prepareConn != nil { + ok, err := p.prepareConn(ctx, cr.conn) + if !ok { + res.Destroy() + } + if err != nil { + if ok { + res.Release() + } + return nil, err + } + if !ok { + continue + } } - res.Destroy() + return cr.getConn(p, res), nil } + return nil, errors.New("pgxpool: detected infinite loop acquiring connection; likely bug in PrepareConn or BeforeAcquire hook") } // AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the @@ -565,11 +675,14 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { conns := make([]*Conn, 0, len(resources)) for _, res := range resources { cr := res.Value() - if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { - conns = append(conns, cr.getConn(p, res)) - } else { - res.Destroy() + if p.prepareConn != nil { + ok, err := p.prepareConn(ctx, cr.conn) + if !ok || err != nil { + res.Destroy() + continue + } } + conns = append(conns, cr.getConn(p, res)) } return conns diff --git a/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go b/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go index cfa0c4c56..e02b6ac39 100644 --- a/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go +++ b/vendor/github.com/jackc/pgx/v5/pgxpool/stat.go @@ -82,3 +82,10 @@ func (s *Stat) MaxLifetimeDestroyCount() int64 { func (s *Stat) MaxIdleDestroyCount() int64 { return s.idleDestroyCount } + +// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires +// from the pool for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireWaitTime() time.Duration { + return s.s.EmptyAcquireWaitTime() +} diff --git a/vendor/github.com/jackc/pgx/v5/pgxpool/tx.go b/vendor/github.com/jackc/pgx/v5/pgxpool/tx.go index 74df8593a..b49e7f4d9 100644 --- a/vendor/github.com/jackc/pgx/v5/pgxpool/tx.go +++ b/vendor/github.com/jackc/pgx/v5/pgxpool/tx.go @@ -18,9 +18,10 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { return tx.t.Begin(ctx) } -// Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed -// if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status -// (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. +// Commit commits the transaction and returns the associated connection back to the Pool. Commit will return an error +// where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is otherwise safe to call multiple times. If +// the commit fails with a rollback status (e.g. the transaction was already in a broken state) then ErrTxCommitRollback +// will be returned. 
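The revised Commit and Rollback comments in this hunk describe the closed-transaction case in terms of errors.Is rather than a direct ErrTxClosed comparison. A small sketch of the usual defer-rollback pattern under that contract; the pool, the accounts table, and the transfer helper are assumptions made up for the example.

package examples

import (
	"context"
	"errors"
	"log"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

func transfer(ctx context.Context, pool *pgxpool.Pool) error {
	tx, err := pool.Begin(ctx)
	if err != nil {
		return err
	}
	// Deferring Rollback is safe even when Commit runs first: the second close
	// yields an error for which errors.Is(err, pgx.ErrTxClosed) is true.
	defer func() {
		if rbErr := tx.Rollback(ctx); rbErr != nil && !errors.Is(rbErr, pgx.ErrTxClosed) {
			log.Printf("rollback failed: %v", rbErr)
		}
	}()

	if _, err := tx.Exec(ctx, "UPDATE accounts SET balance = balance - 1 WHERE id = $1", 42); err != nil {
		return err
	}
	return tx.Commit(ctx)
}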
func (tx *Tx) Commit(ctx context.Context) error { err := tx.t.Commit(ctx) if tx.c != nil { @@ -30,9 +31,9 @@ func (tx *Tx) Commit(ctx context.Context) error { return err } -// Rollback rolls back the transaction and returns the associated connection back to the Pool. Rollback will return ErrTxClosed -// if the Tx is already closed, but is otherwise safe to call multiple times. Hence, defer tx.Rollback() is safe even if -// tx.Commit() will be called first in a non-error condition. +// Rollback rolls back the transaction and returns the associated connection back to the Pool. Rollback will return +// an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is otherwise safe to call +// multiple times. Hence, defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error condition. func (tx *Tx) Rollback(ctx context.Context) error { err := tx.t.Rollback(ctx) if tx.c != nil { diff --git a/vendor/github.com/jackc/pgx/v5/rows.go b/vendor/github.com/jackc/pgx/v5/rows.go index d4f7a9016..ac02ba9a0 100644 --- a/vendor/github.com/jackc/pgx/v5/rows.go +++ b/vendor/github.com/jackc/pgx/v5/rows.go @@ -29,9 +29,9 @@ type Rows interface { // to call Close after rows is already closed. Close() - // Err returns any error that occurred while reading. Err must only be called after the Rows is closed (either by - // calling Close or by Next returning false). If it is called early it may return nil even if there was an error - // executing the query. + // Err returns any error that occurred while executing a query or reading its results. Err must be called after the + // Rows is closed (either by calling Close or by Next returning false) to check if the query was successful. If it is + // called before the Rows is closed it may return nil even if the query failed on the server. Err() error // CommandTag returns the command tag from this query. It is only available after Rows is closed. @@ -41,22 +41,19 @@ type Rows interface { // when there was an error executing the query. FieldDescriptions() []pgconn.FieldDescription - // Next prepares the next row for reading. It returns true if there is another - // row and false if no more rows are available or a fatal error has occurred. - // It automatically closes rows when all rows are read. + // Next prepares the next row for reading. It returns true if there is another row and false if no more rows are + // available or a fatal error has occurred. It automatically closes rows upon returning false (whether due to all rows + // having been read or due to an error). // - // Callers should check rows.Err() after rows.Next() returns false to detect - // whether result-set reading ended prematurely due to an error. See - // Conn.Query for details. + // Callers should check rows.Err() after rows.Next() returns false to detect whether result-set reading ended + // prematurely due to an error. See Conn.Query for details. // - // For simpler error handling, consider using the higher-level pgx v5 - // CollectRows() and ForEachRow() helpers instead. + // For simpler error handling, consider using the higher-level pgx v5 CollectRows() and ForEachRow() helpers instead. Next() bool - // Scan reads the values from the current row into dest values positionally. - // dest can include pointers to core types, values implementing the Scanner - // interface, and nil. nil will skip the value entirely. It is an error to - // call Scan without first calling Next() and checking that it returned true.
+ // Scan reads the values from the current row into dest values positionally. dest can include pointers to core types, + // values implementing the Scanner interface, and nil. nil will skip the value entirely. It is an error to call Scan + // without first calling Next() and checking that it returned true. Rows is automatically closed upon error. Scan(dest ...any) error // Values returns the decoded row values. As with Scan(), it is an error to @@ -188,6 +185,17 @@ func (rows *baseRows) Close() { } else if rows.queryTracer != nil { rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) } + + // Zero references to other memory allocations. This allows them to be GC'd even when the Rows still referenced. In + // particular, when using pgxpool GC could be delayed as pgxpool.poolRows are allocated in large slices. + // + // https://github.com/jackc/pgx/pull/2269 + rows.values = nil + rows.scanPlans = nil + rows.scanTypes = nil + rows.ctx = nil + rows.sql = "" + rows.args = nil } func (rows *baseRows) CommandTag() pgconn.CommandTag { @@ -272,7 +280,7 @@ func (rows *baseRows) Scan(dest ...any) error { err := rows.scanPlans[i].Scan(values[i], dst) if err != nil { - err = ScanArgError{ColumnIndex: i, Err: err} + err = ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err} rows.fatal(err) return err } @@ -334,11 +342,16 @@ func (rows *baseRows) Conn() *Conn { type ScanArgError struct { ColumnIndex int + FieldName string Err error } func (e ScanArgError) Error() string { - return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) + if e.FieldName == "?column?" { // Don't include the fieldname if it's unknown + return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) + } + + return fmt.Sprintf("can't scan into dest[%d] (col: %s): %v", e.ColumnIndex, e.FieldName, e.Err) } func (e ScanArgError) Unwrap() error { @@ -366,7 +379,7 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, v err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) if err != nil { - return ScanArgError{ColumnIndex: i, Err: err} + return ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err} } } @@ -468,6 +481,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { return value, err } + // The defer rows.Close() won't have executed yet. If the query returned more than one row, rows would still be open. + // rows.Close() must be called before rows.Err() so we explicitly call it here. rows.Close() return value, rows.Err() } @@ -545,7 +560,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error { return nil } -// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row +// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number of public fields as row // has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be // ignored. 
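Since the rows.go hunks here and just below adjust the scanning helpers (Scan now notes it closes Rows on error, ScanArgError carries the column name, and db-tagged struct fields are matched exactly), a brief sketch of the CollectRows/RowToStructByName path those comments refer to. The User type, the query, and the loadUsers helper are illustrative assumptions.

package examples

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// User is a hypothetical row type for this example.
type User struct {
	ID        int64  `db:"id"` // db-tagged: must match the column name exactly
	FirstName string // untagged: matched case-insensitively, ignoring underscores ("first_name")
}

func loadUsers(ctx context.Context, conn *pgx.Conn) ([]User, error) {
	rows, err := conn.Query(ctx, "SELECT id, first_name FROM users")
	if err != nil {
		return nil, err
	}
	// CollectRows closes rows and surfaces rows.Err(), so no manual
	// Next/Scan/Err bookkeeping is needed here.
	return pgx.CollectRows(rows, pgx.RowToStructByName[User])
}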
func RowToStructByPos[T any](row CollectableRow) (T, error) { @@ -797,7 +812,7 @@ func computeNamedStructFields( if !dbTagPresent { colName = sf.Name } - fpos := fieldPosByName(fldDescs, colName) + fpos := fieldPosByName(fldDescs, colName, !dbTagPresent) if fpos == -1 { if missingField == "" { missingField = colName @@ -816,16 +831,21 @@ func computeNamedStructFields( const structTagKey = "db" -func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { +func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, normalize bool) (i int) { i = -1 - for i, desc := range fldDescs { - // Snake case support. + if normalize { field = strings.ReplaceAll(field, "_", "") - descName := strings.ReplaceAll(desc.Name, "_", "") - - if strings.EqualFold(descName, field) { - return i + } + for i, desc := range fldDescs { + if normalize { + if strings.EqualFold(strings.ReplaceAll(desc.Name, "_", ""), field) { + return i + } + } else { + if desc.Name == field { + return i + } } } return diff --git a/vendor/github.com/jackc/pgx/v5/stdlib/sql.go b/vendor/github.com/jackc/pgx/v5/stdlib/sql.go index 29cd3fbbf..99a18c055 100644 --- a/vendor/github.com/jackc/pgx/v5/stdlib/sql.go +++ b/vendor/github.com/jackc/pgx/v5/stdlib/sql.go @@ -73,8 +73,9 @@ import ( "fmt" "io" "math" - "math/rand" + "math/rand/v2" "reflect" + "slices" "strconv" "strings" "sync" @@ -98,7 +99,7 @@ func init() { // if pgx driver was already registered by different pgx major version then we // skip registration under the default name. - if !contains(sql.Drivers(), "pgx") { + if !slices.Contains(sql.Drivers(), "pgx") { sql.Register("pgx", pgxDriver) } sql.Register("pgx/v5", pgxDriver) @@ -120,20 +121,25 @@ func init() { } } -// TODO replace by slices.Contains when experimental package will be merged to stdlib -// https://pkg.go.dev/golang.org/x/exp/slices#Contains -func contains(list []string, y string) bool { - for _, x := range list { - if x == y { - return true - } - } - return false -} - // OptionOpenDB options for configuring the driver when opening a new db pool. type OptionOpenDB func(*connector) +// ShouldPingParams are passed to OptionShouldPing to decide whether to ping before reusing a connection. +type ShouldPingParams struct { + // Conn is the underlying pgx connection. + Conn *pgx.Conn + // IdleDuration is how long it has been since ResetSession last ran. + IdleDuration time.Duration +} + +// OptionShouldPing controls whether stdlib should issue a liveness ping before reusing a connection. +// If the function returns true, stdlib will ping. +// If it returns false, stdlib will skip the ping. +// If not provided, default is ping only when IdleDuration > 1s. +func OptionShouldPing(f func(context.Context, ShouldPingParams) bool) OptionOpenDB { + return func(dc *connector) { dc.ShouldPing = f } +} + // OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will // be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig. func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB { @@ -226,7 +232,8 @@ func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { // OpenDBFromPool creates a new *sql.DB from the given *pgxpool.Pool. Note that this method automatically sets the // maximum number of idle connections in *sql.DB to zero, since they must be managed from the *pgxpool.Pool. 
This is -// required to avoid acquiring all the connections from the pgxpool and starving any direct users of the pgxpool. +// required to avoid acquiring all the connections from the pgxpool and starving any direct users of the pgxpool. Note +// that closing the returned *sql.DB will not close the *pgxpool.Pool. func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB { c := GetPoolConnector(pool, opts...) db := sql.OpenDB(c) @@ -240,6 +247,7 @@ type connector struct { BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused + ShouldPing func(context.Context, ShouldPingParams) bool // function to decide if stdlib should ping before reusing a connection driver *Driver } @@ -291,6 +299,7 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) { driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession, + shouldPing: c.ShouldPing, psRefCounts: make(map[*pgconn.StatementDescription]int), }, nil } @@ -398,7 +407,8 @@ type Conn struct { close func(context.Context) error driver *Driver connConfig pgx.ConnConfig - resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused + resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused + shouldPing func(context.Context, ShouldPingParams) bool // Function to decide if stdlib should ping before reusing a connection lastResetSessionTime time.Time // psRefCounts contains reference counts for prepared statements. Prepare uses the underlying pgx logic to generate @@ -480,7 +490,8 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam return nil, driver.ErrBadConn } - args := namedValueToInterface(argsV) + args := make([]any, len(argsV)) + convertNamedArguments(args, argsV) commandTag, err := c.conn.Exec(ctx, query, args...) // if we got a network error before we had a chance to send the query, retry @@ -497,8 +508,9 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na return nil, driver.ErrBadConn } - args := []any{databaseSQLResultFormats} - args = append(args, namedValueToInterface(argsV)...) + args := make([]any, 1+len(argsV)) + args[0] = databaseSQLResultFormats + convertNamedArguments(args[1:], argsV) rows, err := c.conn.Query(ctx, query, args...) 
if err != nil { @@ -544,11 +556,23 @@ func (c *Conn) ResetSession(ctx context.Context) error { } now := time.Now() - if now.Sub(c.lastResetSessionTime) > time.Second { + idle := now.Sub(c.lastResetSessionTime) + + doPing := idle > time.Second // default behavior: ping only if idle > 1s + + if c.shouldPing != nil { + doPing = c.shouldPing(ctx, ShouldPingParams{ + Conn: c.conn, + IdleDuration: idle, + }) + } + + if doPing { if err := c.conn.PgConn().Ping(ctx); err != nil { return driver.ErrBadConn } } + c.lastResetSessionTime = now return c.resetSessionFunc(ctx, c.conn) @@ -640,6 +664,8 @@ func (r *Rows) ColumnTypeLength(index int) (int64, bool) { return math.MaxInt64, true case pgtype.VarcharOID, pgtype.BPCharArrayOID: return int64(fd.TypeModifier - varHeaderSize), true + case pgtype.VarbitOID: + return int64(fd.TypeModifier), true default: return 0, false } @@ -805,6 +831,16 @@ func (r *Rows) Next(dest []driver.Value) error { } return d.Value() } + case pgtype.XMLOID: + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d, nil + } default: var d string scanPlan := m.PlanScan(dataTypeOID, format, &d) @@ -847,28 +883,14 @@ func (r *Rows) Next(dest []driver.Value) error { return nil } -func valueToInterface(argsV []driver.Value) []any { - args := make([]any, 0, len(argsV)) - for _, v := range argsV { - if v != nil { - args = append(args, v.(any)) - } else { - args = append(args, nil) - } - } - return args -} - -func namedValueToInterface(argsV []driver.NamedValue) []any { - args := make([]any, 0, len(argsV)) - for _, v := range argsV { +func convertNamedArguments(args []any, argsV []driver.NamedValue) { + for i, v := range argsV { if v.Value != nil { - args = append(args, v.Value.(any)) + args[i] = v.Value.(any) } else { - args = append(args, nil) + args[i] = nil } } - return args } type wrapTx struct { diff --git a/vendor/github.com/jackc/pgx/v5/tracelog/tracelog.go b/vendor/github.com/jackc/pgx/v5/tracelog/tracelog.go index 78b15fbe4..a68fc6a6a 100644 --- a/vendor/github.com/jackc/pgx/v5/tracelog/tracelog.go +++ b/vendor/github.com/jackc/pgx/v5/tracelog/tracelog.go @@ -6,10 +6,12 @@ import ( "encoding/hex" "errors" "fmt" + "sync" "time" "unicode/utf8" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) // LogLevel represents the pgx logging level. See LogLevel* constants for @@ -53,10 +55,10 @@ type Logger interface { } // LoggerFunc is a wrapper around a function to satisfy the pgx.Logger interface -type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) +type LoggerFunc func(ctx context.Context, level LogLevel, msg string, data map[string]any) // Log delegates the logging request to the wrapped function -func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]interface{}) { +func (f LoggerFunc) Log(ctx context.Context, level LogLevel, msg string, data map[string]any) { f(ctx, level, msg, data) } @@ -102,7 +104,7 @@ func logQueryArgs(args []any) []any { } case string: if len(v) > 64 { - var l int = 0 + l := 0 for w := 0; l < 64; l += w { _, w = utf8.DecodeRuneInString(v[l:]) } @@ -117,11 +119,38 @@ func logQueryArgs(args []any) []any { return logArgs } -// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are -// required. 
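Before the tracelog hunks that follow, a sketch of the OptionShouldPing hook added to stdlib/sql.go above. It replaces the fixed "ping when idle longer than one second" rule with a caller-supplied decision; the 30-second threshold and the openDB helper are arbitrary choices for illustration.

package examples

import (
	"context"
	"database/sql"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/stdlib"
)

func openDB(cfg pgx.ConnConfig) *sql.DB {
	// Only ping connections that have been idle for more than 30 seconds;
	// without this option the default is to ping when IdleDuration > 1s.
	return stdlib.OpenDB(cfg, stdlib.OptionShouldPing(func(_ context.Context, p stdlib.ShouldPingParams) bool {
		return p.IdleDuration > 30*time.Second
	}))
}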
+// TraceLogConfig holds the configuration for key names +type TraceLogConfig struct { + TimeKey string +} + +// DefaultTraceLogConfig returns the default configuration for TraceLog +func DefaultTraceLogConfig() *TraceLogConfig { + return &TraceLogConfig{ + TimeKey: "time", + } +} + +// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, pgx.CopyFromTracer, pgxpool.AcquireTracer, +// and pgxpool.ReleaseTracer. Logger and LogLevel are required. Config will be automatically initialized on the +// first use if nil. type TraceLog struct { Logger Logger LogLevel LogLevel + + Config *TraceLogConfig + ensureConfigOnce sync.Once +} + +// ensureConfig initializes the Config field with default values if it is nil. +func (tl *TraceLog) ensureConfig() { + tl.ensureConfigOnce.Do( + func() { + if tl.Config == nil { + tl.Config = DefaultTraceLogConfig() + } + }, + ) } type ctxKey int @@ -133,6 +162,7 @@ const ( tracelogCopyFromCtxKey tracelogConnectCtxKey tracelogPrepareCtxKey + tracelogAcquireCtxKey ) type traceQueryData struct { @@ -150,6 +180,7 @@ func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pg } func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + tl.ensureConfig() queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData) endTime := time.Now() @@ -157,13 +188,13 @@ func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx. if data.Err != nil { if tl.shouldLog(LogLevelError) { - tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval}) + tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, tl.Config.TimeKey: interval}) } return } if tl.shouldLog(LogLevelInfo) { - tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()}) + tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), tl.Config.TimeKey: interval, "commandTag": data.CommandTag.String()}) } } @@ -191,6 +222,7 @@ func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pg } func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + tl.ensureConfig() queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData) endTime := time.Now() @@ -198,13 +230,13 @@ func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx. 
if data.Err != nil { if tl.shouldLog(LogLevelError) { - tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval}) + tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, tl.Config.TimeKey: interval}) } return } if tl.shouldLog(LogLevelInfo) { - tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval}) + tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{tl.Config.TimeKey: interval}) } } @@ -223,6 +255,7 @@ func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data } func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + tl.ensureConfig() copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData) endTime := time.Now() @@ -230,13 +263,13 @@ func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data p if data.Err != nil { if tl.shouldLog(LogLevelError) { - tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval}) + tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, tl.Config.TimeKey: interval}) } return } if tl.shouldLog(LogLevelInfo) { - tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()}) + tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, tl.Config.TimeKey: interval, "rowCount": data.CommandTag.RowsAffected()}) } } @@ -253,6 +286,7 @@ func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnect } func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + tl.ensureConfig() connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData) endTime := time.Now() @@ -261,11 +295,11 @@ func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEn if data.Err != nil { if tl.shouldLog(LogLevelError) { tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{ - "host": connectData.connConfig.Host, - "port": connectData.connConfig.Port, - "database": connectData.connConfig.Database, - "time": interval, - "err": data.Err, + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + tl.Config.TimeKey: interval, + "err": data.Err, }) } return @@ -274,10 +308,10 @@ func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEn if data.Conn != nil { if tl.shouldLog(LogLevelInfo) { tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{ - "host": connectData.connConfig.Host, - "port": connectData.connConfig.Port, - "database": connectData.connConfig.Database, - "time": interval, + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + tl.Config.TimeKey: interval, }) } } @@ -298,6 +332,7 @@ func (tl *TraceLog) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx } func (tl *TraceLog) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tl.ensureConfig() prepareData := ctx.Value(tracelogPrepareCtxKey).(*tracePrepareData) endTime := time.Now() @@ -305,13 
+340,51 @@ func (tl *TraceLog) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pg if data.Err != nil { if tl.shouldLog(LogLevelError) { - tl.log(ctx, conn, LogLevelError, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "err": data.Err, "time": interval}) + tl.log(ctx, conn, LogLevelError, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "err": data.Err, tl.Config.TimeKey: interval}) } return } if tl.shouldLog(LogLevelInfo) { - tl.log(ctx, conn, LogLevelInfo, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "time": interval, "alreadyPrepared": data.AlreadyPrepared}) + tl.log(ctx, conn, LogLevelInfo, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, tl.Config.TimeKey: interval, "alreadyPrepared": data.AlreadyPrepared}) + } +} + +type traceAcquireData struct { + startTime time.Time +} + +func (tl *TraceLog) TraceAcquireStart(ctx context.Context, _ *pgxpool.Pool, _ pgxpool.TraceAcquireStartData) context.Context { + return context.WithValue(ctx, tracelogAcquireCtxKey, &traceAcquireData{ + startTime: time.Now(), + }) +} + +func (tl *TraceLog) TraceAcquireEnd(ctx context.Context, _ *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + tl.ensureConfig() + acquireData := ctx.Value(tracelogAcquireCtxKey).(*traceAcquireData) + + endTime := time.Now() + interval := endTime.Sub(acquireData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.Logger.Log(ctx, LogLevelError, "Acquire", map[string]any{"err": data.Err, tl.Config.TimeKey: interval}) + } + return + } + + if data.Conn != nil { + if tl.shouldLog(LogLevelDebug) { + tl.log(ctx, data.Conn, LogLevelDebug, "Acquire", map[string]any{tl.Config.TimeKey: interval}) + } + } +} + +func (tl *TraceLog) TraceRelease(_ *pgxpool.Pool, data pgxpool.TraceReleaseData) { + if tl.shouldLog(LogLevelDebug) { + // there is no context on the TraceRelease callback + tl.log(context.Background(), data.Conn, LogLevelDebug, "Release", map[string]any{}) } } diff --git a/vendor/github.com/jackc/pgx/v5/tx.go b/vendor/github.com/jackc/pgx/v5/tx.go index 8feeb5123..571e5e00f 100644 --- a/vendor/github.com/jackc/pgx/v5/tx.go +++ b/vendor/github.com/jackc/pgx/v5/tx.go @@ -3,7 +3,6 @@ package pgx import ( "context" "errors" - "fmt" "strconv" "strings" @@ -48,6 +47,8 @@ type TxOptions struct { // BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax // such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings. BeginQuery string + // CommitQuery is the SQL query that will be executed to commit the transaction. + CommitQuery string } var emptyTxOptions TxOptions @@ -101,11 +102,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { if err != nil { // begin should never fail unless there is an underlying connection issue or // a context timeout. In either case, the connection is possibly broken. - c.die(errors.New("failed to begin transaction")) + c.die() return nil, err } - return &dbTx{conn: c}, nil + return &dbTx{ + conn: c, + commitQuery: txOptions.CommitQuery, + }, nil } // Tx represents a database transaction. @@ -154,6 +158,7 @@ type dbTx struct { conn *Conn savepointNum int64 closed bool + commitQuery string } // Begin starts a pseudo nested transaction implemented with a savepoint. 
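The tracelog hunks above add a configurable TimeKey and acquire/release tracing. A sketch of wiring those together, assuming (as pgxpool does) that acquire/release tracers are discovered via ConnConfig.Tracer; the plain log-based logger, the "duration" key, and the newTracedPool helper are illustrative choices, not this repository's configuration.

package examples

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/jackc/pgx/v5/tracelog"
)

func newTracedPool(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
	tracer := &tracelog.TraceLog{
		Logger: tracelog.LoggerFunc(func(_ context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
			log.Printf("%s %s %v", level, msg, data)
		}),
		LogLevel: tracelog.LogLevelDebug,
		// Rename the duration key from the default "time".
		Config: &tracelog.TraceLogConfig{TimeKey: "duration"},
	}

	cfg, err := pgxpool.ParseConfig(dsn)
	if err != nil {
		return nil, err
	}
	// TraceLog now also implements pgxpool.AcquireTracer and ReleaseTracer, so
	// pool Acquire/Release events are logged at debug level as well.
	cfg.ConnConfig.Tracer = tracer
	return pgxpool.NewWithConfig(ctx, cfg)
}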
@@ -177,7 +182,12 @@ func (tx *dbTx) Commit(ctx context.Context) error { return ErrTxClosed } - commandTag, err := tx.conn.Exec(ctx, "commit") + commandSQL := "commit" + if tx.commitQuery != "" { + commandSQL = tx.commitQuery + } + + commandTag, err := tx.conn.Exec(ctx, commandSQL) tx.closed = true if err != nil { if tx.conn.PgConn().TxStatus() != 'I' { @@ -205,7 +215,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error { tx.closed = true if err != nil { // A rollback failure leaves the connection in an undefined state - tx.conn.die(fmt.Errorf("rollback failed: %w", err)) + tx.conn.die() return err } diff --git a/vendor/github.com/jackc/puddle/v2/CHANGELOG.md b/vendor/github.com/jackc/puddle/v2/CHANGELOG.md index a15991c58..d0d202c74 100644 --- a/vendor/github.com/jackc/puddle/v2/CHANGELOG.md +++ b/vendor/github.com/jackc/puddle/v2/CHANGELOG.md @@ -1,3 +1,8 @@ +# 2.2.2 (September 10, 2024) + +* Add empty acquire time to stats (Maxim Ivanov) +* Stop importing nanotime from runtime via linkname (maypok86) + # 2.2.1 (July 15, 2023) * Fix: CreateResource cannot overflow pool. This changes documented behavior of CreateResource. Previously, diff --git a/vendor/github.com/jackc/puddle/v2/README.md b/vendor/github.com/jackc/puddle/v2/README.md index 0ad07ec43..fa82a9d46 100644 --- a/vendor/github.com/jackc/puddle/v2/README.md +++ b/vendor/github.com/jackc/puddle/v2/README.md @@ -1,4 +1,4 @@ -[![](https://godoc.org/github.com/jackc/puddle?status.svg)](https://godoc.org/github.com/jackc/puddle) +[![Go Reference](https://pkg.go.dev/badge/github.com/jackc/puddle/v2.svg)](https://pkg.go.dev/github.com/jackc/puddle/v2) ![Build Status](https://github.com/jackc/puddle/actions/workflows/ci.yml/badge.svg) # Puddle diff --git a/vendor/github.com/jackc/puddle/v2/nanotime.go b/vendor/github.com/jackc/puddle/v2/nanotime.go new file mode 100644 index 000000000..8a5351a0d --- /dev/null +++ b/vendor/github.com/jackc/puddle/v2/nanotime.go @@ -0,0 +1,16 @@ +package puddle + +import "time" + +// nanotime returns the time in nanoseconds since process start. +// +// This approach, described at +// https://github.com/golang/go/issues/61765#issuecomment-1672090302, +// is fast, monotonic, and portable, and avoids the previous +// dependence on runtime.nanotime using the (unsafe) linkname hack. +// In particular, time.Since does less work than time.Now. +func nanotime() int64 { + return time.Since(globalStart).Nanoseconds() +} + +var globalStart = time.Now() diff --git a/vendor/github.com/jackc/puddle/v2/nanotime_time.go b/vendor/github.com/jackc/puddle/v2/nanotime_time.go deleted file mode 100644 index f8e759386..000000000 --- a/vendor/github.com/jackc/puddle/v2/nanotime_time.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build purego || appengine || js - -// This file contains the safe implementation of nanotime using time.Now(). - -package puddle - -import ( - "time" -) - -func nanotime() int64 { - return time.Now().UnixNano() -} diff --git a/vendor/github.com/jackc/puddle/v2/nanotime_unsafe.go b/vendor/github.com/jackc/puddle/v2/nanotime_unsafe.go deleted file mode 100644 index fc3b8a258..000000000 --- a/vendor/github.com/jackc/puddle/v2/nanotime_unsafe.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build !purego && !appengine && !js - -// This file contains the implementation of nanotime using runtime.nanotime. 
- -package puddle - -import "unsafe" - -var _ = unsafe.Sizeof(0) - -//go:linkname nanotime runtime.nanotime -func nanotime() int64 diff --git a/vendor/github.com/jackc/puddle/v2/pool.go b/vendor/github.com/jackc/puddle/v2/pool.go index c8edc0fb6..c411d2f6e 100644 --- a/vendor/github.com/jackc/puddle/v2/pool.go +++ b/vendor/github.com/jackc/puddle/v2/pool.go @@ -139,6 +139,7 @@ type Pool[T any] struct { acquireCount int64 acquireDuration time.Duration emptyAcquireCount int64 + emptyAcquireWaitTime time.Duration canceledAcquireCount atomic.Int64 resetCount int @@ -154,7 +155,7 @@ type Config[T any] struct { MaxSize int32 } -// NewPool creates a new pool. Panics if maxSize is less than 1. +// NewPool creates a new pool. Returns an error iff MaxSize is less than 1. func NewPool[T any](config *Config[T]) (*Pool[T], error) { if config.MaxSize < 1 { return nil, errors.New("MaxSize must be >= 1") @@ -202,6 +203,7 @@ type Stat struct { acquireCount int64 acquireDuration time.Duration emptyAcquireCount int64 + emptyAcquireWaitTime time.Duration canceledAcquireCount int64 } @@ -251,6 +253,13 @@ func (s *Stat) EmptyAcquireCount() int64 { return s.emptyAcquireCount } +// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires +// from the pool for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireWaitTime() time.Duration { + return s.emptyAcquireWaitTime +} + // CanceledAcquireCount returns the cumulative count of acquires from the pool // that were canceled by a context. func (s *Stat) CanceledAcquireCount() int64 { @@ -266,6 +275,7 @@ func (p *Pool[T]) Stat() *Stat { maxResources: p.maxSize, acquireCount: p.acquireCount, emptyAcquireCount: p.emptyAcquireCount, + emptyAcquireWaitTime: p.emptyAcquireWaitTime, canceledAcquireCount: p.canceledAcquireCount.Load(), acquireDuration: p.acquireDuration, } @@ -363,11 +373,13 @@ func (p *Pool[T]) acquire(ctx context.Context) (*Resource[T], error) { // If a resource is available in the pool. if res := p.tryAcquireIdleResource(); res != nil { + waitTime := time.Duration(nanotime() - startNano) if waitedForLock { p.emptyAcquireCount += 1 + p.emptyAcquireWaitTime += waitTime } p.acquireCount += 1 - p.acquireDuration += time.Duration(nanotime() - startNano) + p.acquireDuration += waitTime p.mux.Unlock() return res, nil } @@ -391,7 +403,9 @@ func (p *Pool[T]) acquire(ctx context.Context) (*Resource[T], error) { p.emptyAcquireCount += 1 p.acquireCount += 1 - p.acquireDuration += time.Duration(nanotime() - startNano) + waitTime := time.Duration(nanotime() - startNano) + p.acquireDuration += waitTime + p.emptyAcquireWaitTime += waitTime return res, nil } diff --git a/vendor/github.com/stretchr/testify/assert/assertion_compare.go b/vendor/github.com/stretchr/testify/assert/assertion_compare.go index 4d4b4aad6..ffb24e8e3 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_compare.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_compare.go @@ -7,10 +7,13 @@ import ( "time" ) -type CompareType int +// Deprecated: CompareType has only ever been for internal use and has accidentally been published since v1.6.0. Do not use it. 
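Looking back at the puddle and pgxpool stats hunks above, a small sketch of reading the new EmptyAcquireWaitTime counter next to the existing EmptyAcquireCount; the reportPoolPressure helper is an assumption for the example.

package examples

import (
	"log"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

func reportPoolPressure(pool *pgxpool.Pool) {
	stat := pool.Stat()
	// Cumulative time successful acquires spent waiting because the pool was empty.
	wait := stat.EmptyAcquireWaitTime()
	if n := stat.EmptyAcquireCount(); n > 0 {
		log.Printf("empty-pool acquires: %d, average wait: %s", n, wait/time.Duration(n))
	}
}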
+type CompareType = compareResult + +type compareResult int const ( - compareLess CompareType = iota - 1 + compareLess compareResult = iota - 1 compareEqual compareGreater ) @@ -39,7 +42,7 @@ var ( bytesType = reflect.TypeOf([]byte{}) ) -func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { +func compare(obj1, obj2 interface{}, kind reflect.Kind) (compareResult, bool) { obj1Value := reflect.ValueOf(obj1) obj2Value := reflect.ValueOf(obj2) @@ -325,7 +328,13 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { timeObj2 = obj2Value.Convert(timeType).Interface().(time.Time) } - return compare(timeObj1.UnixNano(), timeObj2.UnixNano(), reflect.Int64) + if timeObj1.Before(timeObj2) { + return compareLess, true + } + if timeObj1.Equal(timeObj2) { + return compareEqual, true + } + return compareGreater, true } case reflect.Slice: { @@ -345,7 +354,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { bytesObj2 = obj2Value.Convert(bytesType).Interface().([]byte) } - return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true + return compareResult(bytes.Compare(bytesObj1, bytesObj2)), true } case reflect.Uintptr: { @@ -381,7 +390,8 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface if h, ok := t.(tHelper); ok { h.Helper() } - return compareTwoValues(t, e1, e2, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) + failMessage := fmt.Sprintf("\"%v\" is not greater than \"%v\"", e1, e2) + return compareTwoValues(t, e1, e2, []compareResult{compareGreater}, failMessage, msgAndArgs...) } // GreaterOrEqual asserts that the first element is greater than or equal to the second @@ -394,7 +404,8 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in if h, ok := t.(tHelper); ok { h.Helper() } - return compareTwoValues(t, e1, e2, []CompareType{compareGreater, compareEqual}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) + failMessage := fmt.Sprintf("\"%v\" is not greater than or equal to \"%v\"", e1, e2) + return compareTwoValues(t, e1, e2, []compareResult{compareGreater, compareEqual}, failMessage, msgAndArgs...) } // Less asserts that the first element is less than the second @@ -406,7 +417,8 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) if h, ok := t.(tHelper); ok { h.Helper() } - return compareTwoValues(t, e1, e2, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) + failMessage := fmt.Sprintf("\"%v\" is not less than \"%v\"", e1, e2) + return compareTwoValues(t, e1, e2, []compareResult{compareLess}, failMessage, msgAndArgs...) } // LessOrEqual asserts that the first element is less than or equal to the second @@ -419,7 +431,8 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter if h, ok := t.(tHelper); ok { h.Helper() } - return compareTwoValues(t, e1, e2, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) + failMessage := fmt.Sprintf("\"%v\" is not less than or equal to \"%v\"", e1, e2) + return compareTwoValues(t, e1, e2, []compareResult{compareLess, compareEqual}, failMessage, msgAndArgs...) 
} // Positive asserts that the specified element is positive @@ -431,7 +444,8 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { h.Helper() } zero := reflect.Zero(reflect.TypeOf(e)) - return compareTwoValues(t, e, zero.Interface(), []CompareType{compareGreater}, "\"%v\" is not positive", msgAndArgs...) + failMessage := fmt.Sprintf("\"%v\" is not positive", e) + return compareTwoValues(t, e, zero.Interface(), []compareResult{compareGreater}, failMessage, msgAndArgs...) } // Negative asserts that the specified element is negative @@ -443,10 +457,11 @@ func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool { h.Helper() } zero := reflect.Zero(reflect.TypeOf(e)) - return compareTwoValues(t, e, zero.Interface(), []CompareType{compareLess}, "\"%v\" is not negative", msgAndArgs...) + failMessage := fmt.Sprintf("\"%v\" is not negative", e) + return compareTwoValues(t, e, zero.Interface(), []compareResult{compareLess}, failMessage, msgAndArgs...) } -func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool { +func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() } @@ -459,17 +474,17 @@ func compareTwoValues(t TestingT, e1 interface{}, e2 interface{}, allowedCompare compareResult, isComparable := compare(e1, e2, e1Kind) if !isComparable { - return Fail(t, fmt.Sprintf("Can not compare type \"%s\"", reflect.TypeOf(e1)), msgAndArgs...) + return Fail(t, fmt.Sprintf(`Can not compare type "%T"`, e1), msgAndArgs...) } if !containsValue(allowedComparesResults, compareResult) { - return Fail(t, fmt.Sprintf(failMessage, e1, e2), msgAndArgs...) + return Fail(t, failMessage, msgAndArgs...) } return true } -func containsValue(values []CompareType, value CompareType) bool { +func containsValue(values []compareResult, value compareResult) bool { for _, v := range values { if v == value { return true diff --git a/vendor/github.com/stretchr/testify/assert/assertion_format.go b/vendor/github.com/stretchr/testify/assert/assertion_format.go index 3ddab109a..c592f6ad5 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_format.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_format.go @@ -50,10 +50,19 @@ func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...) } -// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either -// a slice or a channel with len == 0. +// Emptyf asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". // // assert.Emptyf(t, obj, "error message %s", "formatted") +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -104,8 +113,8 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...) 
} -// EqualValuesf asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. // // assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { @@ -117,10 +126,8 @@ func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg stri // Errorf asserts that a function returned an error (i.e. not `nil`). // -// actualObj, err := SomeFunction() -// if assert.Errorf(t, err, "error message %s", "formatted") { -// assert.Equal(t, expectedErrorf, err) -// } +// actualObj, err := SomeFunction() +// assert.Errorf(t, err, "error message %s", "formatted") func Errorf(t TestingT, err error, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -186,7 +193,7 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick // assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -438,7 +445,19 @@ func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interf return IsNonIncreasing(t, object, append([]interface{}{msg}, args...)...) } +// IsNotTypef asserts that the specified objects are not of the same type. +// +// assert.IsNotTypef(t, &NotMyStruct{}, &MyStruct{}, "error message %s", "formatted") +func IsNotTypef(t TestingT, theType interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return IsNotType(t, theType, object, append([]interface{}{msg}, args...)...) +} + // IsTypef asserts that the specified objects are of the same type. +// +// assert.IsTypef(t, &MyStruct{}, &MyStruct{}, "error message %s", "formatted") func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -568,8 +587,24 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a return NotContains(t, s, contains, append([]interface{}{msg}, args...)...) } -// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either -// a slice or a channel with len == 0. +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. 
+// +// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// assert.NotElementsMatchf(t, [1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// assert.NotElementsMatchf(t, [1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func NotElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...) +} + +// NotEmptyf asserts that the specified object is NOT [Empty]. // // if assert.NotEmptyf(t, obj, "error message %s", "formatted") { // assert.Equal(t, "two", obj[1]) @@ -604,7 +639,16 @@ func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg s return NotEqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) } -// NotErrorIsf asserts that at none of the errors in err's chain matches target. +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + return NotErrorAs(t, err, target, append([]interface{}{msg}, args...)...) +} + +// NotErrorIsf asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { @@ -667,12 +711,15 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string, return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...) } -// NotSubsetf asserts that the specified list(array, slice...) or map does NOT -// contain all elements given in the specified subset list(array, slice...) or -// map. +// NotSubsetf asserts that the list (array, slice, or map) does NOT contain all +// elements given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted") // assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") +// assert.NotSubsetf(t, [1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted") +// assert.NotSubsetf(t, {"x": 1, "y": 2}, ["z"], "error message %s", "formatted") func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() @@ -756,11 +803,15 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg return Same(t, expected, actual, append([]interface{}{msg}, args...)...) } -// Subsetf asserts that the specified list(array, slice...) or map contains all -// elements given in the specified subset list(array, slice...) or map. +// Subsetf asserts that the list (array, slice, or map) contains all elements +// given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. 
// // assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted") // assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") +// assert.Subsetf(t, [1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted") +// assert.Subsetf(t, {"x": 1, "y": 2}, ["x"], "error message %s", "formatted") func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() diff --git a/vendor/github.com/stretchr/testify/assert/assertion_forward.go b/vendor/github.com/stretchr/testify/assert/assertion_forward.go index a84e09bd4..58db92845 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_forward.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_forward.go @@ -92,10 +92,19 @@ func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg st return ElementsMatchf(a.t, listA, listB, msg, args...) } -// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either -// a slice or a channel with len == 0. +// Empty asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". // // a.Empty(obj) +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -103,10 +112,19 @@ func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { return Empty(a.t, object, msgAndArgs...) } -// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either -// a slice or a channel with len == 0. +// Emptyf asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". // // a.Emptyf(obj, "error message %s", "formatted") +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -186,8 +204,8 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface return EqualExportedValuesf(a.t, expected, actual, msg, args...) } -// EqualValues asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. // // a.EqualValues(uint32(123), int32(123)) func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { @@ -197,8 +215,8 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn return EqualValues(a.t, expected, actual, msgAndArgs...) } -// EqualValuesf asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. 
// // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { @@ -224,10 +242,8 @@ func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string // Error asserts that a function returned an error (i.e. not `nil`). // -// actualObj, err := SomeFunction() -// if a.Error(err) { -// assert.Equal(t, expectedError, err) -// } +// actualObj, err := SomeFunction() +// a.Error(err) func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -297,10 +313,8 @@ func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...inter // Errorf asserts that a function returned an error (i.e. not `nil`). // -// actualObj, err := SomeFunction() -// if a.Errorf(err, "error message %s", "formatted") { -// assert.Equal(t, expectedErrorf, err) -// } +// actualObj, err := SomeFunction() +// a.Errorf(err, "error message %s", "formatted") func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -336,7 +350,7 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti // a.EventuallyWithT(func(c *assert.CollectT) { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -361,7 +375,7 @@ func (a *Assertions) EventuallyWithT(condition func(collect *CollectT), waitFor // a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func (a *Assertions) EventuallyWithTf(condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -868,7 +882,29 @@ func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...in return IsNonIncreasingf(a.t, object, msg, args...) } +// IsNotType asserts that the specified objects are not of the same type. +// +// a.IsNotType(&NotMyStruct{}, &MyStruct{}) +func (a *Assertions) IsNotType(theType interface{}, object interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsNotType(a.t, theType, object, msgAndArgs...) +} + +// IsNotTypef asserts that the specified objects are not of the same type. +// +// a.IsNotTypef(&NotMyStruct{}, &MyStruct{}, "error message %s", "formatted") +func (a *Assertions) IsNotTypef(theType interface{}, object interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return IsNotTypef(a.t, theType, object, msg, args...) +} + // IsType asserts that the specified objects are of the same type. 
+// +// a.IsType(&MyStruct{}, &MyStruct{}) func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -877,6 +913,8 @@ func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAnd } // IsTypef asserts that the specified objects are of the same type. +// +// a.IsTypef(&MyStruct{}, &MyStruct{}, "error message %s", "formatted") func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1128,8 +1166,41 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin return NotContainsf(a.t, s, contains, msg, args...) } -// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either -// a slice or a channel with len == 0. +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true +// +// a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true +func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotElementsMatch(a.t, listA, listB, msgAndArgs...) +} + +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotElementsMatchf(a.t, listA, listB, msg, args...) +} + +// NotEmpty asserts that the specified object is NOT [Empty]. // // if a.NotEmpty(obj) { // assert.Equal(t, "two", obj[1]) @@ -1141,8 +1212,7 @@ func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) boo return NotEmpty(a.t, object, msgAndArgs...) } -// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either -// a slice or a channel with len == 0. +// NotEmptyf asserts that the specified object is NOT [Empty]. // // if a.NotEmptyf(obj, "error message %s", "formatted") { // assert.Equal(t, "two", obj[1]) @@ -1200,7 +1270,25 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str return NotEqualf(a.t, expected, actual, msg, args...) } -// NotErrorIs asserts that at none of the errors in err's chain matches target. +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. 
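A short passing test exercising a few of the assertions added in this testify bump (NotErrorAs, IsNotType, NotElementsMatch); the timeoutError type is made up for the example.

package examples

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
)

// timeoutError is a hypothetical error type used only in this example.
type timeoutError struct{}

func (timeoutError) Error() string { return "timeout" }

func TestNewTestifyAssertions(t *testing.T) {
	err := errors.New("plain failure")

	// NotErrorAs: nothing in err's chain is a timeoutError, so this passes.
	var te timeoutError
	assert.NotErrorAs(t, err, &te)

	// IsNotType: err and timeoutError{} have different concrete types.
	assert.IsNotType(t, timeoutError{}, err)

	// NotElementsMatch: the duplicate counts differ between the two lists.
	assert.NotElementsMatch(t, []int{1, 1, 2}, []int{1, 2})
}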
+func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotErrorAs(a.t, err, target, msgAndArgs...) +} + +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) bool { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + return NotErrorAsf(a.t, err, target, msg, args...) +} + +// NotErrorIs asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { @@ -1209,7 +1297,7 @@ func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface return NotErrorIs(a.t, err, target, msgAndArgs...) } -// NotErrorIsf asserts that at none of the errors in err's chain matches target. +// NotErrorIsf asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { @@ -1326,12 +1414,15 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri return NotSamef(a.t, expected, actual, msg, args...) } -// NotSubset asserts that the specified list(array, slice...) or map does NOT -// contain all elements given in the specified subset list(array, slice...) or -// map. +// NotSubset asserts that the list (array, slice, or map) does NOT contain all +// elements given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // // a.NotSubset([1, 3, 4], [1, 2]) // a.NotSubset({"x": 1, "y": 2}, {"z": 3}) +// a.NotSubset([1, 3, 4], {1: "one", 2: "two"}) +// a.NotSubset({"x": 1, "y": 2}, ["z"]) func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1339,12 +1430,15 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs return NotSubset(a.t, list, subset, msgAndArgs...) } -// NotSubsetf asserts that the specified list(array, slice...) or map does NOT -// contain all elements given in the specified subset list(array, slice...) or -// map. +// NotSubsetf asserts that the list (array, slice, or map) does NOT contain all +// elements given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // // a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted") // a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") +// a.NotSubsetf([1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted") +// a.NotSubsetf({"x": 1, "y": 2}, ["z"], "error message %s", "formatted") func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1504,11 +1598,15 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string, return Samef(a.t, expected, actual, msg, args...) } -// Subset asserts that the specified list(array, slice...) or map contains all -// elements given in the specified subset list(array, slice...) or map. 
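A minimal sketch of the new NotErrorAs method added above; as with errors.As, the target must be a pointer to an error type, and timeoutError here is a hypothetical type used only for illustration:

    package example_test

    import (
        "errors"
        "fmt"
        "testing"

        "github.com/stretchr/testify/assert"
    )

    type timeoutError struct{ msg string }

    func (e *timeoutError) Error() string { return e.msg }

    func TestNotErrorAs(t *testing.T) {
        a := assert.New(t)
        err := fmt.Errorf("wrapping: %w", errors.New("plain failure"))

        // Passes: nothing in err's chain is a *timeoutError, so target stays nil.
        var target *timeoutError
        a.NotErrorAs(err, &target)
    }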
+// Subset asserts that the list (array, slice, or map) contains all elements +// given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // // a.Subset([1, 2, 3], [1, 2]) // a.Subset({"x": 1, "y": 2}, {"x": 1}) +// a.Subset([1, 2, 3], {1: "one", 2: "two"}) +// a.Subset({"x": 1, "y": 2}, ["x"]) func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -1516,11 +1614,15 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ... return Subset(a.t, list, subset, msgAndArgs...) } -// Subsetf asserts that the specified list(array, slice...) or map contains all -// elements given in the specified subset list(array, slice...) or map. +// Subsetf asserts that the list (array, slice, or map) contains all elements +// given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // // a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted") // a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") +// a.Subsetf([1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted") +// a.Subsetf({"x": 1, "y": 2}, ["x"], "error message %s", "formatted") func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { if h, ok := a.t.(tHelper); ok { h.Helper() diff --git a/vendor/github.com/stretchr/testify/assert/assertion_order.go b/vendor/github.com/stretchr/testify/assert/assertion_order.go index 00df62a05..2fdf80fdd 100644 --- a/vendor/github.com/stretchr/testify/assert/assertion_order.go +++ b/vendor/github.com/stretchr/testify/assert/assertion_order.go @@ -6,7 +6,7 @@ import ( ) // isOrdered checks that collection contains orderable elements. -func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareType, failMessage string, msgAndArgs ...interface{}) bool { +func isOrdered(t TestingT, object interface{}, allowedComparesResults []compareResult, failMessage string, msgAndArgs ...interface{}) bool { objKind := reflect.TypeOf(object).Kind() if objKind != reflect.Slice && objKind != reflect.Array { return false @@ -33,7 +33,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT compareResult, isComparable := compare(prevValueInterface, valueInterface, firstValueKind) if !isComparable { - return Fail(t, fmt.Sprintf("Can not compare type \"%s\" and \"%s\"", reflect.TypeOf(value), reflect.TypeOf(prevValue)), msgAndArgs...) + return Fail(t, fmt.Sprintf(`Can not compare type "%T" and "%T"`, value, prevValue), msgAndArgs...) } if !containsValue(allowedComparesResults, compareResult) { @@ -50,7 +50,7 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT // assert.IsIncreasing(t, []float{1, 2}) // assert.IsIncreasing(t, []string{"a", "b"}) func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) + return isOrdered(t, object, []compareResult{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...) 
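The reworded Subset/NotSubset documentation above calls out the map handling: when a map is compared against an array or slice, only the map keys are evaluated. A minimal sketch of the documented cases, using the same values as the doc comments:

    package example_test

    import (
        "testing"

        "github.com/stretchr/testify/assert"
    )

    func TestSubsetWithMaps(t *testing.T) {
        a := assert.New(t)

        // Subset given a map: only the keys 1 and 2 are looked up in the slice.
        a.Subset([]int{1, 2, 3}, map[int]string{1: "one", 2: "two"})

        // Map against a slice of keys: "x" is a key of the map.
        a.Subset(map[string]int{"x": 1, "y": 2}, []string{"x"})

        // NotSubset: "z" is not a key of the map.
        a.NotSubset(map[string]int{"x": 1, "y": 2}, []string{"z"})
    }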
} // IsNonIncreasing asserts that the collection is not increasing @@ -59,7 +59,7 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo // assert.IsNonIncreasing(t, []float{2, 1}) // assert.IsNonIncreasing(t, []string{"b", "a"}) func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) + return isOrdered(t, object, []compareResult{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...) } // IsDecreasing asserts that the collection is decreasing @@ -68,7 +68,7 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) // assert.IsDecreasing(t, []float{2, 1}) // assert.IsDecreasing(t, []string{"b", "a"}) func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) + return isOrdered(t, object, []compareResult{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...) } // IsNonDecreasing asserts that the collection is not decreasing @@ -77,5 +77,5 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) boo // assert.IsNonDecreasing(t, []float{1, 2}) // assert.IsNonDecreasing(t, []string{"a", "b"}) func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { - return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) + return isOrdered(t, object, []compareResult{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...) } diff --git a/vendor/github.com/stretchr/testify/assert/assertions.go b/vendor/github.com/stretchr/testify/assert/assertions.go index 0b7570f21..de8de0cb6 100644 --- a/vendor/github.com/stretchr/testify/assert/assertions.go +++ b/vendor/github.com/stretchr/testify/assert/assertions.go @@ -19,7 +19,9 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/pmezard/go-difflib/difflib" - "gopkg.in/yaml.v3" + + // Wrapper around gopkg.in/yaml.v3 + "github.com/stretchr/testify/assert/yaml" ) //go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl" @@ -45,6 +47,10 @@ type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool // for table driven tests. type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool +// PanicAssertionFunc is a common function prototype when validating a panic value. Can be useful +// for table driven tests. +type PanicAssertionFunc = func(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool + // Comparison is a custom function that returns true on success and false on failure type Comparison func() (success bool) @@ -204,59 +210,77 @@ the problem actually occurred in calling code.*/ // of each stack frame leading from the current test to the assert call that // failed. func CallerInfo() []string { - var pc uintptr - var ok bool var file string var line int var name string + const stackFrameBufferSize = 10 + pcs := make([]uintptr, stackFrameBufferSize) + callers := []string{} - for i := 0; ; i++ { - pc, file, line, ok = runtime.Caller(i) - if !ok { - // The breaks below failed to terminate the loop, and we ran off the - // end of the call stack. 
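The new PanicAssertionFunc alias added above is intended for table-driven tests; a minimal sketch under that assumption, with assert.Panics and assert.NotPanics plugged in as the per-case assertion:

    package example_test

    import (
        "testing"

        "github.com/stretchr/testify/assert"
    )

    func TestPanicsTableDriven(t *testing.T) {
        tests := []struct {
            name      string
            fn        func()
            assertion assert.PanicAssertionFunc
        }{
            {"panics", func() { panic("boom") }, assert.Panics},
            {"does not panic", func() {}, assert.NotPanics},
        }
        for _, tt := range tests {
            t.Run(tt.name, func(t *testing.T) {
                tt.assertion(t, tt.fn)
            })
        }
    }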
- break - } + offset := 1 - // This is a huge edge case, but it will panic if this is the case, see #180 - if file == "" { - break - } + for { + n := runtime.Callers(offset, pcs) - f := runtime.FuncForPC(pc) - if f == nil { - break - } - name = f.Name() - - // testing.tRunner is the standard library function that calls - // tests. Subtests are called directly by tRunner, without going through - // the Test/Benchmark/Example function that contains the t.Run calls, so - // with subtests we should break when we hit tRunner, without adding it - // to the list of callers. - if name == "testing.tRunner" { + if n == 0 { break } - parts := strings.Split(file, "/") - if len(parts) > 1 { - filename := parts[len(parts)-1] - dir := parts[len(parts)-2] - if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" { - callers = append(callers, fmt.Sprintf("%s:%d", file, line)) + frames := runtime.CallersFrames(pcs[:n]) + + for { + frame, more := frames.Next() + pc = frame.PC + file = frame.File + line = frame.Line + + // This is a huge edge case, but it will panic if this is the case, see #180 + if file == "" { + break } - } - // Drop the package - segments := strings.Split(name, ".") - name = segments[len(segments)-1] - if isTest(name, "Test") || - isTest(name, "Benchmark") || - isTest(name, "Example") { - break + f := runtime.FuncForPC(pc) + if f == nil { + break + } + name = f.Name() + + // testing.tRunner is the standard library function that calls + // tests. Subtests are called directly by tRunner, without going through + // the Test/Benchmark/Example function that contains the t.Run calls, so + // with subtests we should break when we hit tRunner, without adding it + // to the list of callers. + if name == "testing.tRunner" { + break + } + + parts := strings.Split(file, "/") + if len(parts) > 1 { + filename := parts[len(parts)-1] + dir := parts[len(parts)-2] + if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" { + callers = append(callers, fmt.Sprintf("%s:%d", file, line)) + } + } + + // Drop the package + dotPos := strings.LastIndexByte(name, '.') + name = name[dotPos+1:] + if isTest(name, "Test") || + isTest(name, "Benchmark") || + isTest(name, "Example") { + break + } + + if !more { + break + } } + + // Next batch + offset += cap(pcs) } return callers @@ -431,17 +455,34 @@ func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, return true } +func isType(expectedType, object interface{}) bool { + return ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) +} + // IsType asserts that the specified objects are of the same type. -func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { +// +// assert.IsType(t, &MyStruct{}, &MyStruct{}) +func IsType(t TestingT, expectedType, object interface{}, msgAndArgs ...interface{}) bool { + if isType(expectedType, object) { + return true + } if h, ok := t.(tHelper); ok { h.Helper() } + return Fail(t, fmt.Sprintf("Object expected to be of type %T, but was %T", expectedType, object), msgAndArgs...) +} - if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { - return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...) +// IsNotType asserts that the specified objects are not of the same type. 
+// +// assert.IsNotType(t, &NotMyStruct{}, &MyStruct{}) +func IsNotType(t TestingT, theType, object interface{}, msgAndArgs ...interface{}) bool { + if !isType(theType, object) { + return true } - - return true + if h, ok := t.(tHelper); ok { + h.Helper() + } + return Fail(t, fmt.Sprintf("Object type expected to be different than %T", theType), msgAndArgs...) } // Equal asserts that two objects are equal. @@ -469,7 +510,6 @@ func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) } return true - } // validateEqualArgs checks whether provided arguments can be safely used in the @@ -496,10 +536,17 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b h.Helper() } - if !samePointers(expected, actual) { + same, ok := samePointers(expected, actual) + if !ok { + return Fail(t, "Both arguments must be pointers", msgAndArgs...) + } + + if !same { + // both are pointers but not the same type & pointing to the same address return Fail(t, fmt.Sprintf("Not same: \n"+ - "expected: %p %#v\n"+ - "actual : %p %#v", expected, expected, actual, actual), msgAndArgs...) + "expected: %p %#[1]v\n"+ + "actual : %p %#[2]v", + expected, actual), msgAndArgs...) } return true @@ -516,29 +563,37 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{} h.Helper() } - if samePointers(expected, actual) { + same, ok := samePointers(expected, actual) + if !ok { + // fails when the arguments are not pointers + return !(Fail(t, "Both arguments must be pointers", msgAndArgs...)) + } + + if same { return Fail(t, fmt.Sprintf( - "Expected and actual point to the same object: %p %#v", - expected, expected), msgAndArgs...) + "Expected and actual point to the same object: %p %#[1]v", + expected), msgAndArgs...) } return true } -// samePointers compares two generic interface objects and returns whether -// they point to the same object -func samePointers(first, second interface{}) bool { +// samePointers checks if two generic interface objects are pointers of the same +// type pointing to the same object. It returns two values: same indicating if +// they are the same type and point to the same object, and ok indicating that +// both inputs are pointers. +func samePointers(first, second interface{}) (same bool, ok bool) { firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second) if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr { - return false + return false, false // not both are pointers } firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second) if firstType != secondType { - return false + return false, true // both are pointers, but of different types } // compare pointer addresses - return first == second + return first == second, true } // formatUnequalValues takes two values of arbitrary types and returns string @@ -572,8 +627,8 @@ func truncatingFormat(data interface{}) string { return value } -// EqualValues asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. 
// // assert.EqualValues(t, uint32(123), int32(123)) func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { @@ -590,7 +645,6 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa } return true - } // EqualExportedValues asserts that the types of two objects are equal and their public @@ -615,21 +669,6 @@ func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs .. return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) } - if aType.Kind() == reflect.Ptr { - aType = aType.Elem() - } - if bType.Kind() == reflect.Ptr { - bType = bType.Elem() - } - - if aType.Kind() != reflect.Struct { - return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...) - } - - if bType.Kind() != reflect.Struct { - return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...) - } - expected = copyExportedFields(expected) actual = copyExportedFields(actual) @@ -660,7 +699,6 @@ func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{} } return Equal(t, expected, actual, msgAndArgs...) - } // NotNil asserts that the specified object is not nil. @@ -710,37 +748,45 @@ func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { // isEmpty gets whether the specified object is considered empty or not. func isEmpty(object interface{}) bool { - // get nil case out of the way if object == nil { return true } - objValue := reflect.ValueOf(object) + return isEmptyValue(reflect.ValueOf(object)) +} +// isEmptyValue gets whether the specified reflect.Value is considered empty or not. +func isEmptyValue(objValue reflect.Value) bool { + if objValue.IsZero() { + return true + } + // Special cases of non-zero values that we consider empty switch objValue.Kind() { // collection types are empty when they have no element + // Note: array types are empty when they match their zero-initialized state. case reflect.Chan, reflect.Map, reflect.Slice: return objValue.Len() == 0 - // pointers are empty if nil or if the value they point to is empty + // non-nil pointers are empty if the value they point to is empty case reflect.Ptr: - if objValue.IsNil() { - return true - } - deref := objValue.Elem().Interface() - return isEmpty(deref) - // for all other types, compare against the zero value - // array types are empty when they match their zero-initialized state - default: - zero := reflect.Zero(objValue.Type()) - return reflect.DeepEqual(object, zero.Interface()) + return isEmptyValue(objValue.Elem()) } + return false } -// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either -// a slice or a channel with len == 0. +// Empty asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". 
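The rewritten Empty documentation above spells out the semantics: zero values, zero-length slices/maps/channels, all-zero arrays, and pointers that are nil or point at an "empty" value all count as empty. A minimal sketch of cases that pass under those rules:

    package example_test

    import (
        "testing"

        "github.com/stretchr/testify/assert"
    )

    func TestEmptySemantics(t *testing.T) {
        a := assert.New(t)

        a.Empty("")               // zero value string
        a.Empty(0)                // zero value int
        a.Empty([]int{})          // zero-length slice
        a.Empty(map[string]int{}) // zero-length map
        a.Empty([2]int{0, 0})     // array whose elements are all zero values

        var p *int
        a.Empty(p) // nil pointer
        zero := 0
        a.Empty(&zero) // pointer to an "empty" value

        a.NotEmpty([]int{0}) // non-empty slice, even though its element is zero
    }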
// // assert.Empty(t, obj) +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { pass := isEmpty(object) if !pass { @@ -751,11 +797,9 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { } return pass - } -// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either -// a slice or a channel with len == 0. +// NotEmpty asserts that the specified object is NOT [Empty]. // // if assert.NotEmpty(t, obj) { // assert.Equal(t, "two", obj[1]) @@ -770,7 +814,6 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { } return pass - } // getLen tries to get the length of an object. @@ -814,7 +857,6 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { } return true - } // False asserts that the specified value is false. @@ -829,7 +871,6 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { } return true - } // NotEqual asserts that the specified values are NOT equal. @@ -852,7 +893,6 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{ } return true - } // NotEqualValues asserts that two objects are not equal even when converted to the same type @@ -875,7 +915,6 @@ func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...inte // return (true, false) if element was not found. // return (true, true) if element was found. func containsElement(list interface{}, element interface{}) (ok, found bool) { - listValue := reflect.ValueOf(list) listType := reflect.TypeOf(list) if listType == nil { @@ -910,7 +949,6 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) { } } return true, false - } // Contains asserts that the specified string, list(array, slice...) or map contains the @@ -933,7 +971,6 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo } return true - } // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the @@ -956,14 +993,17 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) } return true - } -// Subset asserts that the specified list(array, slice...) or map contains all -// elements given in the specified subset list(array, slice...) or map. +// Subset asserts that the list (array, slice, or map) contains all elements +// given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // // assert.Subset(t, [1, 2, 3], [1, 2]) // assert.Subset(t, {"x": 1, "y": 2}, {"x": 1}) +// assert.Subset(t, [1, 2, 3], {1: "one", 2: "two"}) +// assert.Subset(t, {"x": 1, "y": 2}, ["x"]) func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { if h, ok := t.(tHelper); ok { h.Helper() @@ -978,7 +1018,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok } subsetKind := reflect.TypeOf(subset).Kind() - if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { + if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) 
} @@ -1002,6 +1042,13 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok } subsetList := reflect.ValueOf(subset) + if subsetKind == reflect.Map { + keys := make([]interface{}, subsetList.Len()) + for idx, key := range subsetList.MapKeys() { + keys[idx] = key.Interface() + } + subsetList = reflect.ValueOf(keys) + } for i := 0; i < subsetList.Len(); i++ { element := subsetList.Index(i).Interface() ok, found := containsElement(list, element) @@ -1016,12 +1063,15 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok return true } -// NotSubset asserts that the specified list(array, slice...) or map does NOT -// contain all elements given in the specified subset list(array, slice...) or -// map. +// NotSubset asserts that the list (array, slice, or map) does NOT contain all +// elements given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // // assert.NotSubset(t, [1, 3, 4], [1, 2]) // assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3}) +// assert.NotSubset(t, [1, 3, 4], {1: "one", 2: "two"}) +// assert.NotSubset(t, {"x": 1, "y": 2}, ["z"]) func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1036,7 +1086,7 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) } subsetKind := reflect.TypeOf(subset).Kind() - if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map { + if subsetKind != reflect.Array && subsetKind != reflect.Slice && subsetKind != reflect.Map { return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...) } @@ -1060,11 +1110,18 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) } subsetList := reflect.ValueOf(subset) + if subsetKind == reflect.Map { + keys := make([]interface{}, subsetList.Len()) + for idx, key := range subsetList.MapKeys() { + keys[idx] = key.Interface() + } + subsetList = reflect.ValueOf(keys) + } for i := 0; i < subsetList.Len(); i++ { element := subsetList.Index(i).Interface() ok, found := containsElement(list, element) if !ok { - return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...) + return Fail(t, fmt.Sprintf("%q could not be applied builtin len()", list), msgAndArgs...) } if !found { return true @@ -1170,6 +1227,39 @@ func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) stri return msg.String() } +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// assert.NotElementsMatch(t, [1, 1, 2, 3], [1, 2, 3]) -> true +// +// assert.NotElementsMatch(t, [1, 2, 3], [1, 2, 4]) -> true +func NotElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if isEmpty(listA) && isEmpty(listB) { + return Fail(t, "listA and listB contain the same elements", msgAndArgs) + } + + if !isList(t, listA, msgAndArgs...) { + return Fail(t, "listA is not a list type", msgAndArgs...) + } + if !isList(t, listB, msgAndArgs...) 
{ + return Fail(t, "listB is not a list type", msgAndArgs...) + } + + extraA, extraB := diffLists(listA, listB) + if len(extraA) == 0 && len(extraB) == 0 { + return Fail(t, "listA and listB contain the same elements", msgAndArgs) + } + + return true +} + // Condition uses a Comparison to assert a complex condition. func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { @@ -1488,6 +1578,9 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd if err != nil { return Fail(t, err.Error(), msgAndArgs...) } + if math.IsNaN(actualEpsilon) { + return Fail(t, "relative error is NaN", msgAndArgs...) + } if actualEpsilon > epsilon { return Fail(t, fmt.Sprintf("Relative error is too high: %#v (expected)\n"+ " < %#v (actual)", epsilon, actualEpsilon), msgAndArgs...) @@ -1550,10 +1643,8 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool { // Error asserts that a function returned an error (i.e. not `nil`). // -// actualObj, err := SomeFunction() -// if assert.Error(t, err) { -// assert.Equal(t, expectedError, err) -// } +// actualObj, err := SomeFunction() +// assert.Error(t, err) func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { if err == nil { if h, ok := t.(tHelper); ok { @@ -1611,7 +1702,6 @@ func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...in // matchRegexp return true if a specified regexp matches a string. func matchRegexp(rx interface{}, str interface{}) bool { - var r *regexp.Regexp if rr, ok := rx.(*regexp.Regexp); ok { r = rr @@ -1619,8 +1709,14 @@ func matchRegexp(rx interface{}, str interface{}) bool { r = regexp.MustCompile(fmt.Sprint(rx)) } - return (r.FindStringIndex(fmt.Sprint(str)) != nil) - + switch v := str.(type) { + case []byte: + return r.Match(v) + case string: + return r.MatchString(v) + default: + return r.MatchString(fmt.Sprint(v)) + } } // Regexp asserts that a specified regexp matches a string. @@ -1656,7 +1752,6 @@ func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interf } return !match - } // Zero asserts that i is the zero value for its type. @@ -1767,6 +1862,11 @@ func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{ return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid json.\nJSON parsing error: '%s'", expected, err.Error()), msgAndArgs...) } + // Shortcut if same bytes + if actual == expected { + return true + } + if err := json.Unmarshal([]byte(actual), &actualJSONAsInterface); err != nil { return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid json.\nJSON parsing error: '%s'", actual, err.Error()), msgAndArgs...) } @@ -1785,6 +1885,11 @@ func YAMLEq(t TestingT, expected string, actual string, msgAndArgs ...interface{ return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid yaml.\nYAML parsing error: '%s'", expected, err.Error()), msgAndArgs...) } + // Shortcut if same bytes + if actual == expected { + return true + } + if err := yaml.Unmarshal([]byte(actual), &actualYAMLAsInterface); err != nil { return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid yaml.\nYAML error: '%s'", actual, err.Error()), msgAndArgs...) 
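A minimal sketch of the package-level NotElementsMatch assertion introduced above, mirroring the examples in its doc comment:

    package example_test

    import (
        "testing"

        "github.com/stretchr/testify/assert"
    )

    func TestNotElementsMatch(t *testing.T) {
        // Passes: the multiset {1, 1, 2, 3} differs from {1, 2, 3}.
        assert.NotElementsMatch(t, []int{1, 1, 2, 3}, []int{1, 2, 3})

        // Passes: 4 appears only in the second list.
        assert.NotElementsMatch(t, []int{1, 2, 3}, []int{1, 2, 4})

        // Would fail: both lists contain exactly the same elements.
        // assert.NotElementsMatch(t, []int{1, 1, 2, 3}, []int{1, 1, 2, 3})
    }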
} @@ -1872,7 +1977,7 @@ var spewConfigStringerEnabled = spew.ConfigState{ MaxDepth: 10, } -type tHelper interface { +type tHelper = interface { Helper() } @@ -1886,6 +1991,7 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t } ch := make(chan bool, 1) + checkCond := func() { ch <- condition() } timer := time.NewTimer(waitFor) defer timer.Stop() @@ -1893,35 +1999,47 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t ticker := time.NewTicker(tick) defer ticker.Stop() - for tick := ticker.C; ; { + var tickC <-chan time.Time + + // Check the condition once first on the initial call. + go checkCond() + + for { select { case <-timer.C: return Fail(t, "Condition never satisfied", msgAndArgs...) - case <-tick: - tick = nil - go func() { ch <- condition() }() + case <-tickC: + tickC = nil + go checkCond() case v := <-ch: if v { return true } - tick = ticker.C + tickC = ticker.C } } } // CollectT implements the TestingT interface and collects all errors. type CollectT struct { + // A slice of errors. Non-nil slice denotes a failure. + // If it's non-nil but len(c.errors) == 0, this is also a failure + // obtained by direct c.FailNow() call. errors []error } +// Helper is like [testing.T.Helper] but does nothing. +func (CollectT) Helper() {} + // Errorf collects the error. func (c *CollectT) Errorf(format string, args ...interface{}) { c.errors = append(c.errors, fmt.Errorf(format, args...)) } -// FailNow panics. -func (*CollectT) FailNow() { - panic("Assertion failed") +// FailNow stops execution by calling runtime.Goexit. +func (c *CollectT) FailNow() { + c.fail() + runtime.Goexit() } // Deprecated: That was a method for internal usage that should not have been published. Now just panics. @@ -1934,6 +2052,16 @@ func (*CollectT) Copy(TestingT) { panic("Copy() is deprecated") } +func (c *CollectT) fail() { + if !c.failed() { + c.errors = []error{} // Make it non-nil to mark a failure. + } +} + +func (c *CollectT) failed() bool { + return c.errors != nil +} + // EventuallyWithT asserts that given condition will be met in waitFor time, // periodically checking target function each tick. In contrast to Eventually, // it supplies a CollectT to the condition function, so that the condition @@ -1951,14 +2079,22 @@ func (*CollectT) Copy(TestingT) { // assert.EventuallyWithT(t, func(c *assert.CollectT) { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() } var lastFinishedTickErrs []error - ch := make(chan []error, 1) + ch := make(chan *CollectT, 1) + + checkCond := func() { + collect := new(CollectT) + defer func() { + ch <- collect + }() + condition(collect) + } timer := time.NewTimer(waitFor) defer timer.Stop() @@ -1966,29 +2102,28 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time ticker := time.NewTicker(tick) defer ticker.Stop() - for tick := ticker.C; ; { + var tickC <-chan time.Time + + // Check the condition once first on the initial call. 
+ go checkCond() + + for { select { case <-timer.C: for _, err := range lastFinishedTickErrs { t.Errorf("%v", err) } return Fail(t, "Condition never satisfied", msgAndArgs...) - case <-tick: - tick = nil - go func() { - collect := new(CollectT) - defer func() { - ch <- collect.errors - }() - condition(collect) - }() - case errs := <-ch: - if len(errs) == 0 { + case <-tickC: + tickC = nil + go checkCond() + case collect := <-ch: + if !collect.failed() { return true } // Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached. - lastFinishedTickErrs = errs - tick = ticker.C + lastFinishedTickErrs = collect.errors + tickC = ticker.C } } } @@ -2003,6 +2138,7 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D } ch := make(chan bool, 1) + checkCond := func() { ch <- condition() } timer := time.NewTimer(waitFor) defer timer.Stop() @@ -2010,18 +2146,23 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D ticker := time.NewTicker(tick) defer ticker.Stop() - for tick := ticker.C; ; { + var tickC <-chan time.Time + + // Check the condition once first on the initial call. + go checkCond() + + for { select { case <-timer.C: return true - case <-tick: - tick = nil - go func() { ch <- condition() }() + case <-tickC: + tickC = nil + go checkCond() case v := <-ch: if v { return Fail(t, "Condition satisfied", msgAndArgs...) } - tick = ticker.C + tickC = ticker.C } } } @@ -2039,9 +2180,12 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { var expectedText string if target != nil { expectedText = target.Error() + if err == nil { + return Fail(t, fmt.Sprintf("Expected error with %q in chain but got nil.", expectedText), msgAndArgs...) + } } - chain := buildErrorChainString(err) + chain := buildErrorChainString(err, false) return Fail(t, fmt.Sprintf("Target error should be in err chain:\n"+ "expected: %q\n"+ @@ -2049,7 +2193,7 @@ func ErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { ), msgAndArgs...) } -// NotErrorIs asserts that at none of the errors in err's chain matches target. +// NotErrorIs asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { @@ -2064,7 +2208,7 @@ func NotErrorIs(t TestingT, err, target error, msgAndArgs ...interface{}) bool { expectedText = target.Error() } - chain := buildErrorChainString(err) + chain := buildErrorChainString(err, false) return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+ "found: %q\n"+ @@ -2082,24 +2226,70 @@ func ErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{ return true } - chain := buildErrorChainString(err) + expectedType := reflect.TypeOf(target).Elem().String() + if err == nil { + return Fail(t, fmt.Sprintf("An error is expected but got nil.\n"+ + "expected: %s", expectedType), msgAndArgs...) + } + + chain := buildErrorChainString(err, true) return Fail(t, fmt.Sprintf("Should be in error chain:\n"+ - "expected: %q\n"+ - "in chain: %s", target, chain, + "expected: %s\n"+ + "in chain: %s", expectedType, chain, ), msgAndArgs...) } -func buildErrorChainString(err error) string { +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. 
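Several doc comments in this update were corrected to pass waitFor before tick, and the Eventually/EventuallyWithT loops above now run the condition once immediately instead of waiting for the first tick. A minimal sketch with that argument order; the timings and the externalValue flag are illustrative only:

    package example_test

    import (
        "sync/atomic"
        "testing"
        "time"

        "github.com/stretchr/testify/assert"
    )

    func TestEventuallyWithTOrder(t *testing.T) {
        var externalValue atomic.Bool
        go func() {
            time.Sleep(50 * time.Millisecond)
            externalValue.Store(true)
        }()

        assert.EventuallyWithT(t, func(c *assert.CollectT) {
            // Any assertion failure here only fails the current tick.
            assert.True(c, externalValue.Load(), "expected 'externalValue' to be true")
        }, 2*time.Second, 10*time.Millisecond) // waitFor first, then tick
    }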
+func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if !errors.As(err, target) { + return true + } + + chain := buildErrorChainString(err, true) + + return Fail(t, fmt.Sprintf("Target error should not be in err chain:\n"+ + "found: %s\n"+ + "in chain: %s", reflect.TypeOf(target).Elem().String(), chain, + ), msgAndArgs...) +} + +func unwrapAll(err error) (errs []error) { + errs = append(errs, err) + switch x := err.(type) { + case interface{ Unwrap() error }: + err = x.Unwrap() + if err == nil { + return + } + errs = append(errs, unwrapAll(err)...) + case interface{ Unwrap() []error }: + for _, err := range x.Unwrap() { + errs = append(errs, unwrapAll(err)...) + } + } + return +} + +func buildErrorChainString(err error, withType bool) string { if err == nil { return "" } - e := errors.Unwrap(err) - chain := fmt.Sprintf("%q", err.Error()) - for e != nil { - chain += fmt.Sprintf("\n\t%q", e.Error()) - e = errors.Unwrap(e) + var chain string + errs := unwrapAll(err) + for i := range errs { + if i != 0 { + chain += "\n\t" + } + chain += fmt.Sprintf("%q", errs[i].Error()) + if withType { + chain += fmt.Sprintf(" (%T)", errs[i]) + } } return chain } diff --git a/vendor/github.com/stretchr/testify/assert/doc.go b/vendor/github.com/stretchr/testify/assert/doc.go index 4953981d3..a0b953aa5 100644 --- a/vendor/github.com/stretchr/testify/assert/doc.go +++ b/vendor/github.com/stretchr/testify/assert/doc.go @@ -1,5 +1,9 @@ // Package assert provides a set of comprehensive testing tools for use with the normal Go testing system. // +// # Note +// +// All functions in this package return a bool value indicating whether the assertion has passed. +// // # Example Usage // // The following is a complete example using assert in a standard test function: diff --git a/vendor/github.com/stretchr/testify/assert/http_assertions.go b/vendor/github.com/stretchr/testify/assert/http_assertions.go index 861ed4b7c..5a6bb75f2 100644 --- a/vendor/github.com/stretchr/testify/assert/http_assertions.go +++ b/vendor/github.com/stretchr/testify/assert/http_assertions.go @@ -138,7 +138,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, contains := strings.Contains(body, fmt.Sprint(str)) if !contains { - Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...) + Fail(t, fmt.Sprintf("Expected response body for %q to contain %q but found %q", url+"?"+values.Encode(), str, body), msgAndArgs...) } return contains @@ -158,7 +158,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url strin contains := strings.Contains(body, fmt.Sprint(str)) if contains { - Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...) + Fail(t, fmt.Sprintf("Expected response body for %q to NOT contain %q but found %q", url+"?"+values.Encode(), str, body), msgAndArgs...) 
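The new unwrapAll helper above walks both Unwrap() error and Unwrap() []error, so failure messages for ErrorIs/NotErrorIs/ErrorAs/NotErrorAs now list joined errors as well. A minimal sketch exercising a chain built with errors.Join:

    package example_test

    import (
        "errors"
        "fmt"
        "testing"

        "github.com/stretchr/testify/assert"
    )

    func TestJoinedErrorChain(t *testing.T) {
        errNotFound := errors.New("not found")
        err := fmt.Errorf("request failed: %w", errors.Join(errNotFound, errors.New("retry later")))

        // Passes: errNotFound is reachable through the joined chain.
        assert.ErrorIs(t, err, errNotFound)

        // Passes: this distinct error value is not in the chain; on failure the
        // message would list every error in the chain, including joined ones.
        assert.NotErrorIs(t, err, errors.New("some other error"))
    }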
} return !contains diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go new file mode 100644 index 000000000..5a74c4f4d --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_custom.go @@ -0,0 +1,24 @@ +//go:build testify_yaml_custom && !testify_yaml_fail && !testify_yaml_default + +// Package yaml is an implementation of YAML functions that calls a pluggable implementation. +// +// This implementation is selected with the testify_yaml_custom build tag. +// +// go test -tags testify_yaml_custom +// +// This implementation can be used at build time to replace the default implementation +// to avoid linking with [gopkg.in/yaml.v3]. +// +// In your test package: +// +// import assertYaml "github.com/stretchr/testify/assert/yaml" +// +// func init() { +// assertYaml.Unmarshal = func (in []byte, out interface{}) error { +// // ... +// return nil +// } +// } +package yaml + +var Unmarshal func(in []byte, out interface{}) error diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go new file mode 100644 index 000000000..0bae80e34 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_default.go @@ -0,0 +1,36 @@ +//go:build !testify_yaml_fail && !testify_yaml_custom + +// Package yaml is just an indirection to handle YAML deserialization. +// +// This package is just an indirection that allows the builder to override the +// indirection with an alternative implementation of this package that uses +// another implementation of YAML deserialization. This allows to not either not +// use YAML deserialization at all, or to use another implementation than +// [gopkg.in/yaml.v3] (for example for license compatibility reasons, see [PR #1120]). +// +// Alternative implementations are selected using build tags: +// +// - testify_yaml_fail: [Unmarshal] always fails with an error +// - testify_yaml_custom: [Unmarshal] is a variable. Caller must initialize it +// before calling any of [github.com/stretchr/testify/assert.YAMLEq] or +// [github.com/stretchr/testify/assert.YAMLEqf]. +// +// Usage: +// +// go test -tags testify_yaml_fail +// +// You can check with "go list" which implementation is linked: +// +// go list -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// go list -tags testify_yaml_fail -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// go list -tags testify_yaml_custom -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml +// +// [PR #1120]: https://github.com/stretchr/testify/pull/1120 +package yaml + +import goyaml "gopkg.in/yaml.v3" + +// Unmarshal is just a wrapper of [gopkg.in/yaml.v3.Unmarshal]. +func Unmarshal(in []byte, out interface{}) error { + return goyaml.Unmarshal(in, out) +} diff --git a/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go b/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go new file mode 100644 index 000000000..8041803fd --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/yaml/yaml_fail.go @@ -0,0 +1,17 @@ +//go:build testify_yaml_fail && !testify_yaml_custom && !testify_yaml_default + +// Package yaml is an implementation of YAML functions that always fail. 
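The yaml_custom.go variant above is compiled only under the testify_yaml_custom build tag, in which case Unmarshal is a package variable that the consuming test module must set. A sketch following the pattern from the package comment, with encoding/json standing in for an alternative YAML decoder:

    //go:build testify_yaml_custom

    package example_test

    import (
        "encoding/json"

        assertYaml "github.com/stretchr/testify/assert/yaml"
    )

    func init() {
        // Stand-in decoder: real code would plug in an alternative YAML library here.
        assertYaml.Unmarshal = func(in []byte, out interface{}) error {
            return json.Unmarshal(in, out)
        }
    }

Run the tests with go test -tags testify_yaml_custom; go list -tags testify_yaml_custom -f '{{.Imports}}' github.com/stretchr/testify/assert/yaml shows which implementation ends up linked.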
+// +// This implementation can be used at build time to replace the default implementation +// to avoid linking with [gopkg.in/yaml.v3]: +// +// go test -tags testify_yaml_fail +package yaml + +import "errors" + +var errNotImplemented = errors.New("YAML functions are not available (see https://pkg.go.dev/github.com/stretchr/testify/assert/yaml)") + +func Unmarshal([]byte, interface{}) error { + return errNotImplemented +} diff --git a/vendor/github.com/stretchr/testify/require/doc.go b/vendor/github.com/stretchr/testify/require/doc.go index 968434724..c8e3f94a8 100644 --- a/vendor/github.com/stretchr/testify/require/doc.go +++ b/vendor/github.com/stretchr/testify/require/doc.go @@ -23,6 +23,8 @@ // // The `require` package have same global functions as in the `assert` package, // but instead of returning a boolean result they call `t.FailNow()`. +// A consequence of this is that it must be called from the goroutine running +// the test function, not from other goroutines created during the test. // // Every assertion function also takes an optional string message as the final argument, // allowing custom error messages to be appended to the message the assertion method outputs. diff --git a/vendor/github.com/stretchr/testify/require/require.go b/vendor/github.com/stretchr/testify/require/require.go index 506a82f80..2d02f9bce 100644 --- a/vendor/github.com/stretchr/testify/require/require.go +++ b/vendor/github.com/stretchr/testify/require/require.go @@ -34,9 +34,9 @@ func Conditionf(t TestingT, comp assert.Comparison, msg string, args ...interfac // Contains asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // -// assert.Contains(t, "Hello World", "World") -// assert.Contains(t, ["Hello", "World"], "World") -// assert.Contains(t, {"Hello": "World"}, "Hello") +// require.Contains(t, "Hello World", "World") +// require.Contains(t, ["Hello", "World"], "World") +// require.Contains(t, {"Hello": "World"}, "Hello") func Contains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -50,9 +50,9 @@ func Contains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...int // Containsf asserts that the specified string, list(array, slice...) or map contains the // specified substring or element. // -// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted") -// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") -// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") +// require.Containsf(t, "Hello World", "World", "error message %s", "formatted") +// require.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") +// require.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -91,7 +91,7 @@ func DirExistsf(t TestingT, path string, msg string, args ...interface{}) { // listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, // the number of appearances of each of them in both lists should match. 
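The note added to require/doc.go above means require assertions must run on the goroutine that executes the test function, since FailNow only stops that goroutine. A minimal sketch of one safe pattern, with doWork as a hypothetical stand-in for the code under test: do the concurrent work in the goroutine, send its result back, and call require on the test goroutine.

    package example_test

    import (
        "testing"

        "github.com/stretchr/testify/require"
    )

    func TestRequireOnTestGoroutine(t *testing.T) {
        errCh := make(chan error, 1)
        go func() {
            // Report the error instead of calling require inside this goroutine.
            errCh <- doWork()
        }()

        // require runs on the test goroutine, so FailNow behaves correctly.
        require.NoError(t, <-errCh)
    }

    // doWork is a hypothetical helper standing in for the code under test.
    func doWork() error { return nil }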
// -// assert.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2]) +// require.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2]) func ElementsMatch(t TestingT, listA interface{}, listB interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -106,7 +106,7 @@ func ElementsMatch(t TestingT, listA interface{}, listB interface{}, msgAndArgs // listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, // the number of appearances of each of them in both lists should match. // -// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") +// require.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -117,10 +117,19 @@ func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string t.FailNow() } -// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either -// a slice or a channel with len == 0. +// Empty asserts that the given value is "empty". // -// assert.Empty(t, obj) +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". +// +// require.Empty(t, obj) +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -131,10 +140,19 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) { t.FailNow() } -// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either -// a slice or a channel with len == 0. +// Emptyf asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". +// +// require.Emptyf(t, obj, "error message %s", "formatted") // -// assert.Emptyf(t, obj, "error message %s", "formatted") +// [Zero values]: https://go.dev/ref/spec#The_zero_value func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -147,7 +165,7 @@ func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) { // Equal asserts that two objects are equal. // -// assert.Equal(t, 123, 123) +// require.Equal(t, 123, 123) // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality @@ -166,7 +184,7 @@ func Equal(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...i // and that it is equal to the provided error. // // actualObj, err := SomeFunction() -// assert.EqualError(t, err, expectedErrorString) +// require.EqualError(t, err, expectedErrorString) func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -181,7 +199,7 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte // and that it is equal to the provided error. 
// // actualObj, err := SomeFunction() -// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") +// require.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -200,8 +218,8 @@ func EqualErrorf(t TestingT, theError error, errString string, msg string, args // Exported int // notExported int // } -// assert.EqualExportedValues(t, S{1, 2}, S{1, 3}) => true -// assert.EqualExportedValues(t, S{1, 2}, S{2, 3}) => false +// require.EqualExportedValues(t, S{1, 2}, S{1, 3}) => true +// require.EqualExportedValues(t, S{1, 2}, S{2, 3}) => false func EqualExportedValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -220,8 +238,8 @@ func EqualExportedValues(t TestingT, expected interface{}, actual interface{}, m // Exported int // notExported int // } -// assert.EqualExportedValuesf(t, S{1, 2}, S{1, 3}, "error message %s", "formatted") => true -// assert.EqualExportedValuesf(t, S{1, 2}, S{2, 3}, "error message %s", "formatted") => false +// require.EqualExportedValuesf(t, S{1, 2}, S{1, 3}, "error message %s", "formatted") => true +// require.EqualExportedValuesf(t, S{1, 2}, S{2, 3}, "error message %s", "formatted") => false func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -232,10 +250,10 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, t.FailNow() } -// EqualValues asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. // -// assert.EqualValues(t, uint32(123), int32(123)) +// require.EqualValues(t, uint32(123), int32(123)) func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -246,10 +264,10 @@ func EqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArg t.FailNow() } -// EqualValuesf asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. // -// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") +// require.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -262,7 +280,7 @@ func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg stri // Equalf asserts that two objects are equal. // -// assert.Equalf(t, 123, 123, "error message %s", "formatted") +// require.Equalf(t, 123, 123, "error message %s", "formatted") // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). Function equality @@ -279,10 +297,8 @@ func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, ar // Error asserts that a function returned an error (i.e. not `nil`). 
// -// actualObj, err := SomeFunction() -// if assert.Error(t, err) { -// assert.Equal(t, expectedError, err) -// } +// actualObj, err := SomeFunction() +// require.Error(t, err) func Error(t TestingT, err error, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -321,7 +337,7 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int // and that the error contains the specified substring. // // actualObj, err := SomeFunction() -// assert.ErrorContains(t, err, expectedErrorSubString) +// require.ErrorContains(t, err, expectedErrorSubString) func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -336,7 +352,7 @@ func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...in // and that the error contains the specified substring. // // actualObj, err := SomeFunction() -// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") +// require.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted") func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -373,10 +389,8 @@ func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface // Errorf asserts that a function returned an error (i.e. not `nil`). // -// actualObj, err := SomeFunction() -// if assert.Errorf(t, err, "error message %s", "formatted") { -// assert.Equal(t, expectedErrorf, err) -// } +// actualObj, err := SomeFunction() +// require.Errorf(t, err, "error message %s", "formatted") func Errorf(t TestingT, err error, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -390,7 +404,7 @@ func Errorf(t TestingT, err error, msg string, args ...interface{}) { // Eventually asserts that given condition will be met in waitFor time, // periodically checking target function each tick. 
// -// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) +// require.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond) func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -415,10 +429,10 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t // time.Sleep(8*time.Second) // externalValue = true // }() -// assert.EventuallyWithT(t, func(c *assert.CollectT) { +// require.EventuallyWithT(t, func(c *require.CollectT) { // // add assertions as needed; any assertion failure will fail the current tick -// assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// require.True(c, externalValue, "expected 'externalValue' to be true") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func EventuallyWithT(t TestingT, condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -443,10 +457,10 @@ func EventuallyWithT(t TestingT, condition func(collect *assert.CollectT), waitF // time.Sleep(8*time.Second) // externalValue = true // }() -// assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") { +// require.EventuallyWithTf(t, func(c *require.CollectT, "error message %s", "formatted") { // // add assertions as needed; any assertion failure will fail the current tick -// assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// require.True(c, externalValue, "expected 'externalValue' to be true") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func EventuallyWithTf(t TestingT, condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -460,7 +474,7 @@ func EventuallyWithTf(t TestingT, condition func(collect *assert.CollectT), wait // Eventuallyf asserts that given condition will be met in waitFor time, // periodically checking target function each tick. // -// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +// require.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -473,7 +487,7 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick // Exactly asserts that two objects are equal in value and type. // -// assert.Exactly(t, int32(123), int64(123)) +// require.Exactly(t, int32(123), int64(123)) func Exactly(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -486,7 +500,7 @@ func Exactly(t TestingT, expected interface{}, actual interface{}, msgAndArgs .. // Exactlyf asserts that two objects are equal in value and type. 
// -// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted") +// require.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted") func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -543,7 +557,7 @@ func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) { // False asserts that the specified value is false. // -// assert.False(t, myBool) +// require.False(t, myBool) func False(t TestingT, value bool, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -556,7 +570,7 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) { // Falsef asserts that the specified value is false. // -// assert.Falsef(t, myBool, "error message %s", "formatted") +// require.Falsef(t, myBool, "error message %s", "formatted") func Falsef(t TestingT, value bool, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -593,9 +607,9 @@ func FileExistsf(t TestingT, path string, msg string, args ...interface{}) { // Greater asserts that the first element is greater than the second // -// assert.Greater(t, 2, 1) -// assert.Greater(t, float64(2), float64(1)) -// assert.Greater(t, "b", "a") +// require.Greater(t, 2, 1) +// require.Greater(t, float64(2), float64(1)) +// require.Greater(t, "b", "a") func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -608,10 +622,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface // GreaterOrEqual asserts that the first element is greater than or equal to the second // -// assert.GreaterOrEqual(t, 2, 1) -// assert.GreaterOrEqual(t, 2, 2) -// assert.GreaterOrEqual(t, "b", "a") -// assert.GreaterOrEqual(t, "b", "b") +// require.GreaterOrEqual(t, 2, 1) +// require.GreaterOrEqual(t, 2, 2) +// require.GreaterOrEqual(t, "b", "a") +// require.GreaterOrEqual(t, "b", "b") func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -624,10 +638,10 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in // GreaterOrEqualf asserts that the first element is greater than or equal to the second // -// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") -// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") -// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") -// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") +// require.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted") +// require.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted") +// require.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted") +// require.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted") func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -640,9 +654,9 @@ func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, arg // Greaterf asserts that the first element is greater than the second // -// assert.Greaterf(t, 2, 1, "error message %s", "formatted") -// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted") -// assert.Greaterf(t, "b", "a", "error message %s", "formatted") +// require.Greaterf(t, 2, 1, "error message %s", "formatted") +// require.Greaterf(t, float64(2), float64(1), "error 
message %s", "formatted") +// require.Greaterf(t, "b", "a", "error message %s", "formatted") func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -656,7 +670,7 @@ func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...in // HTTPBodyContains asserts that a specified handler returns a // body that contains a string. // -// assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// require.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) { @@ -672,7 +686,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method string, url s // HTTPBodyContainsf asserts that a specified handler returns a // body that contains a string. // -// assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// require.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) { @@ -688,7 +702,7 @@ func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url // HTTPBodyNotContains asserts that a specified handler returns a // body that does not contain a string. // -// assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") +// require.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) { @@ -704,7 +718,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method string, ur // HTTPBodyNotContainsf asserts that a specified handler returns a // body that does not contain a string. // -// assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") +// require.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) { @@ -719,7 +733,7 @@ func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, u // HTTPError asserts that a specified handler returns an error status code. // -// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// require.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). 
func HTTPError(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) { @@ -734,7 +748,7 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method string, url string, // HTTPErrorf asserts that a specified handler returns an error status code. // -// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// require.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) { @@ -749,7 +763,7 @@ func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, // HTTPRedirect asserts that a specified handler returns a redirect status code. // -// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// require.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPRedirect(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) { @@ -764,7 +778,7 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method string, url strin // HTTPRedirectf asserts that a specified handler returns a redirect status code. // -// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// require.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} // // Returns whether the assertion was successful (true) or not (false). func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) { @@ -779,7 +793,7 @@ func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url stri // HTTPStatusCode asserts that a specified handler returns a specified status code. // -// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501) +// require.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501) // // Returns whether the assertion was successful (true) or not (false). func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) { @@ -794,7 +808,7 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method string, url str // HTTPStatusCodef asserts that a specified handler returns a specified status code. // -// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") +// require.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) { @@ -809,7 +823,7 @@ func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url st // HTTPSuccess asserts that a specified handler returns a success status code. // -// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) +// require.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) // // Returns whether the assertion was successful (true) or not (false). 
func HTTPSuccess(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) { @@ -824,7 +838,7 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method string, url string // HTTPSuccessf asserts that a specified handler returns a success status code. // -// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") +// require.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted") // // Returns whether the assertion was successful (true) or not (false). func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) { @@ -839,7 +853,7 @@ func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url strin // Implements asserts that an object is implemented by the specified interface. // -// assert.Implements(t, (*MyInterface)(nil), new(MyObject)) +// require.Implements(t, (*MyInterface)(nil), new(MyObject)) func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -852,7 +866,7 @@ func Implements(t TestingT, interfaceObject interface{}, object interface{}, msg // Implementsf asserts that an object is implemented by the specified interface. // -// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +// require.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -865,7 +879,7 @@ func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, ms // InDelta asserts that the two numerals are within delta of each other. // -// assert.InDelta(t, math.Pi, 22/7.0, 0.01) +// require.InDelta(t, math.Pi, 22/7.0, 0.01) func InDelta(t TestingT, expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -922,7 +936,7 @@ func InDeltaSlicef(t TestingT, expected interface{}, actual interface{}, delta f // InDeltaf asserts that the two numerals are within delta of each other. 
// -// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") +// require.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted") func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -979,9 +993,9 @@ func InEpsilonf(t TestingT, expected interface{}, actual interface{}, epsilon fl // IsDecreasing asserts that the collection is decreasing // -// assert.IsDecreasing(t, []int{2, 1, 0}) -// assert.IsDecreasing(t, []float{2, 1}) -// assert.IsDecreasing(t, []string{"b", "a"}) +// require.IsDecreasing(t, []int{2, 1, 0}) +// require.IsDecreasing(t, []float{2, 1}) +// require.IsDecreasing(t, []string{"b", "a"}) func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -994,9 +1008,9 @@ func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { // IsDecreasingf asserts that the collection is decreasing // -// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted") -// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted") -// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted") +// require.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted") +// require.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted") +// require.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted") func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1009,9 +1023,9 @@ func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface // IsIncreasing asserts that the collection is increasing // -// assert.IsIncreasing(t, []int{1, 2, 3}) -// assert.IsIncreasing(t, []float{1, 2}) -// assert.IsIncreasing(t, []string{"a", "b"}) +// require.IsIncreasing(t, []int{1, 2, 3}) +// require.IsIncreasing(t, []float{1, 2}) +// require.IsIncreasing(t, []string{"a", "b"}) func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1024,9 +1038,9 @@ func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { // IsIncreasingf asserts that the collection is increasing // -// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted") -// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted") -// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted") +// require.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted") +// require.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted") +// require.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted") func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1039,9 +1053,9 @@ func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface // IsNonDecreasing asserts that the collection is not decreasing // -// assert.IsNonDecreasing(t, []int{1, 1, 2}) -// assert.IsNonDecreasing(t, []float{1, 2}) -// assert.IsNonDecreasing(t, []string{"a", "b"}) +// require.IsNonDecreasing(t, []int{1, 1, 2}) +// require.IsNonDecreasing(t, []float{1, 2}) +// require.IsNonDecreasing(t, []string{"a", "b"}) func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1054,9 +1068,9 @@ func 
IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) // IsNonDecreasingf asserts that the collection is not decreasing // -// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted") -// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted") -// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted") +// require.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted") +// require.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted") +// require.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted") func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1069,9 +1083,9 @@ func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interf // IsNonIncreasing asserts that the collection is not increasing // -// assert.IsNonIncreasing(t, []int{2, 1, 1}) -// assert.IsNonIncreasing(t, []float{2, 1}) -// assert.IsNonIncreasing(t, []string{"b", "a"}) +// require.IsNonIncreasing(t, []int{2, 1, 1}) +// require.IsNonIncreasing(t, []float{2, 1}) +// require.IsNonIncreasing(t, []string{"b", "a"}) func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1084,9 +1098,9 @@ func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) // IsNonIncreasingf asserts that the collection is not increasing // -// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted") -// assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted") -// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted") +// require.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted") +// require.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted") +// require.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted") func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1097,7 +1111,35 @@ func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interf t.FailNow() } +// IsNotType asserts that the specified objects are not of the same type. +// +// require.IsNotType(t, &NotMyStruct{}, &MyStruct{}) +func IsNotType(t TestingT, theType interface{}, object interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.IsNotType(t, theType, object, msgAndArgs...) { + return + } + t.FailNow() +} + +// IsNotTypef asserts that the specified objects are not of the same type. +// +// require.IsNotTypef(t, &NotMyStruct{}, &MyStruct{}, "error message %s", "formatted") +func IsNotTypef(t TestingT, theType interface{}, object interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.IsNotTypef(t, theType, object, msg, args...) { + return + } + t.FailNow() +} + // IsType asserts that the specified objects are of the same type. +// +// require.IsType(t, &MyStruct{}, &MyStruct{}) func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1109,6 +1151,8 @@ func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs } // IsTypef asserts that the specified objects are of the same type. 
+// +// require.IsTypef(t, &MyStruct{}, &MyStruct{}, "error message %s", "formatted") func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1121,7 +1165,7 @@ func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg strin // JSONEq asserts that two JSON strings are equivalent. // -// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +// require.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1134,7 +1178,7 @@ func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{ // JSONEqf asserts that two JSON strings are equivalent. // -// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") +// require.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1148,7 +1192,7 @@ func JSONEqf(t TestingT, expected string, actual string, msg string, args ...int // Len asserts that the specified object has specific length. // Len also fails if the object has a type that len() not accept. // -// assert.Len(t, mySlice, 3) +// require.Len(t, mySlice, 3) func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1162,7 +1206,7 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) // Lenf asserts that the specified object has specific length. // Lenf also fails if the object has a type that len() not accept. 
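// Editorial sketch, not part of the vendored patch: IsNotType/IsNotTypef are new in this
// testify release and complement IsType by asserting that the two values have different
// dynamic types. The types below are ordinary standard-library types chosen for illustration.
package example_test

import (
	"bytes"
	"io"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestTypeAssertions(t *testing.T) {
	var w io.Writer = &bytes.Buffer{}
	require.IsType(t, &bytes.Buffer{}, w)    // w's concrete type is *bytes.Buffer
	require.IsNotType(t, &bytes.Reader{}, w) // w is not a *bytes.Reader
}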
// -// assert.Lenf(t, mySlice, 3, "error message %s", "formatted") +// require.Lenf(t, mySlice, 3, "error message %s", "formatted") func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1175,9 +1219,9 @@ func Lenf(t TestingT, object interface{}, length int, msg string, args ...interf // Less asserts that the first element is less than the second // -// assert.Less(t, 1, 2) -// assert.Less(t, float64(1), float64(2)) -// assert.Less(t, "a", "b") +// require.Less(t, 1, 2) +// require.Less(t, float64(1), float64(2)) +// require.Less(t, "a", "b") func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1190,10 +1234,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) // LessOrEqual asserts that the first element is less than or equal to the second // -// assert.LessOrEqual(t, 1, 2) -// assert.LessOrEqual(t, 2, 2) -// assert.LessOrEqual(t, "a", "b") -// assert.LessOrEqual(t, "b", "b") +// require.LessOrEqual(t, 1, 2) +// require.LessOrEqual(t, 2, 2) +// require.LessOrEqual(t, "a", "b") +// require.LessOrEqual(t, "b", "b") func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1206,10 +1250,10 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter // LessOrEqualf asserts that the first element is less than or equal to the second // -// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted") -// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted") -// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted") -// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted") +// require.LessOrEqualf(t, 1, 2, "error message %s", "formatted") +// require.LessOrEqualf(t, 2, 2, "error message %s", "formatted") +// require.LessOrEqualf(t, "a", "b", "error message %s", "formatted") +// require.LessOrEqualf(t, "b", "b", "error message %s", "formatted") func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1222,9 +1266,9 @@ func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args . 
// Lessf asserts that the first element is less than the second // -// assert.Lessf(t, 1, 2, "error message %s", "formatted") -// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted") -// assert.Lessf(t, "a", "b", "error message %s", "formatted") +// require.Lessf(t, 1, 2, "error message %s", "formatted") +// require.Lessf(t, float64(1), float64(2), "error message %s", "formatted") +// require.Lessf(t, "a", "b", "error message %s", "formatted") func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1237,8 +1281,8 @@ func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...inter // Negative asserts that the specified element is negative // -// assert.Negative(t, -1) -// assert.Negative(t, -1.23) +// require.Negative(t, -1) +// require.Negative(t, -1.23) func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1251,8 +1295,8 @@ func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) { // Negativef asserts that the specified element is negative // -// assert.Negativef(t, -1, "error message %s", "formatted") -// assert.Negativef(t, -1.23, "error message %s", "formatted") +// require.Negativef(t, -1, "error message %s", "formatted") +// require.Negativef(t, -1.23, "error message %s", "formatted") func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1266,7 +1310,7 @@ func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) { // Never asserts that the given condition doesn't satisfy in waitFor time, // periodically checking the target function each tick. // -// assert.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond) +// require.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond) func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1280,7 +1324,7 @@ func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.D // Neverf asserts that the given condition doesn't satisfy in waitFor time, // periodically checking the target function each tick. // -// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") +// require.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted") func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1293,7 +1337,7 @@ func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time. // Nil asserts that the specified object is nil. // -// assert.Nil(t, err) +// require.Nil(t, err) func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1306,7 +1350,7 @@ func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) { // Nilf asserts that the specified object is nil. 
// -// assert.Nilf(t, err, "error message %s", "formatted") +// require.Nilf(t, err, "error message %s", "formatted") func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1344,8 +1388,8 @@ func NoDirExistsf(t TestingT, path string, msg string, args ...interface{}) { // NoError asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() -// if assert.NoError(t, err) { -// assert.Equal(t, expectedObj, actualObj) +// if require.NoError(t, err) { +// require.Equal(t, expectedObj, actualObj) // } func NoError(t TestingT, err error, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1360,8 +1404,8 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) { // NoErrorf asserts that a function returned no error (i.e. `nil`). // // actualObj, err := SomeFunction() -// if assert.NoErrorf(t, err, "error message %s", "formatted") { -// assert.Equal(t, expectedObj, actualObj) +// if require.NoErrorf(t, err, "error message %s", "formatted") { +// require.Equal(t, expectedObj, actualObj) // } func NoErrorf(t TestingT, err error, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1400,9 +1444,9 @@ func NoFileExistsf(t TestingT, path string, msg string, args ...interface{}) { // NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. // -// assert.NotContains(t, "Hello World", "Earth") -// assert.NotContains(t, ["Hello", "World"], "Earth") -// assert.NotContains(t, {"Hello": "World"}, "Earth") +// require.NotContains(t, "Hello World", "Earth") +// require.NotContains(t, ["Hello", "World"], "Earth") +// require.NotContains(t, {"Hello": "World"}, "Earth") func NotContains(t TestingT, s interface{}, contains interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1416,9 +1460,9 @@ func NotContains(t TestingT, s interface{}, contains interface{}, msgAndArgs ... // NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the // specified substring or element. // -// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") -// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") -// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") +// require.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted") +// require.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") +// require.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1429,11 +1473,50 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a t.FailNow() } -// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either -// a slice or a channel with len == 0. +// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. 
+// +// require.NotElementsMatch(t, [1, 1, 2, 3], [1, 1, 2, 3]) -> false +// +// require.NotElementsMatch(t, [1, 1, 2, 3], [1, 2, 3]) -> true +// +// require.NotElementsMatch(t, [1, 2, 3], [1, 2, 4]) -> true +func NotElementsMatch(t TestingT, listA interface{}, listB interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotElementsMatch(t, listA, listB, msgAndArgs...) { + return + } + t.FailNow() +} + +// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified +// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements, +// the number of appearances of each of them in both lists should not match. +// This is an inverse of ElementsMatch. +// +// require.NotElementsMatchf(t, [1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false +// +// require.NotElementsMatchf(t, [1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true +// +// require.NotElementsMatchf(t, [1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true +func NotElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotElementsMatchf(t, listA, listB, msg, args...) { + return + } + t.FailNow() +} + +// NotEmpty asserts that the specified object is NOT [Empty]. // -// if assert.NotEmpty(t, obj) { -// assert.Equal(t, "two", obj[1]) +// if require.NotEmpty(t, obj) { +// require.Equal(t, "two", obj[1]) // } func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1445,11 +1528,10 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { t.FailNow() } -// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either -// a slice or a channel with len == 0. +// NotEmptyf asserts that the specified object is NOT [Empty]. // -// if assert.NotEmptyf(t, obj, "error message %s", "formatted") { -// assert.Equal(t, "two", obj[1]) +// if require.NotEmptyf(t, obj, "error message %s", "formatted") { +// require.Equal(t, "two", obj[1]) // } func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1463,7 +1545,7 @@ func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) // NotEqual asserts that the specified values are NOT equal. // -// assert.NotEqual(t, obj1, obj2) +// require.NotEqual(t, obj1, obj2) // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). @@ -1479,7 +1561,7 @@ func NotEqual(t TestingT, expected interface{}, actual interface{}, msgAndArgs . 
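// Editorial sketch, not part of the vendored patch: NotElementsMatch/NotElementsMatchf are
// new; they are the inverse of ElementsMatch and pass only when the two lists differ once
// element order is ignored, including differing counts of duplicate elements.
package example_test

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestNotElementsMatch(t *testing.T) {
	require.ElementsMatch(t, []int{1, 2, 3}, []int{3, 2, 1})       // same elements, order ignored
	require.NotElementsMatch(t, []int{1, 1, 2, 3}, []int{1, 2, 3}) // duplicate counts differ
	require.NotElementsMatch(t, []int{1, 2, 3}, []int{1, 2, 4})    // different elements
}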
// NotEqualValues asserts that two objects are not equal even when converted to the same type // -// assert.NotEqualValues(t, obj1, obj2) +// require.NotEqualValues(t, obj1, obj2) func NotEqualValues(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1492,7 +1574,7 @@ func NotEqualValues(t TestingT, expected interface{}, actual interface{}, msgAnd // NotEqualValuesf asserts that two objects are not equal even when converted to the same type // -// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") +// require.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted") func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1505,7 +1587,7 @@ func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg s // NotEqualf asserts that the specified values are NOT equal. // -// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted") +// require.NotEqualf(t, obj1, obj2, "error message %s", "formatted") // // Pointer variable equality is determined based on the equality of the // referenced values (as opposed to the memory addresses). @@ -1519,7 +1601,31 @@ func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, t.FailNow() } -// NotErrorIs asserts that at none of the errors in err's chain matches target. +// NotErrorAs asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAs(t TestingT, err error, target interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotErrorAs(t, err, target, msgAndArgs...) { + return + } + t.FailNow() +} + +// NotErrorAsf asserts that none of the errors in err's chain matches target, +// but if so, sets target to that error value. +func NotErrorAsf(t TestingT, err error, target interface{}, msg string, args ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotErrorAsf(t, err, target, msg, args...) { + return + } + t.FailNow() +} + +// NotErrorIs asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func NotErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1531,7 +1637,7 @@ func NotErrorIs(t TestingT, err error, target error, msgAndArgs ...interface{}) t.FailNow() } -// NotErrorIsf asserts that at none of the errors in err's chain matches target. +// NotErrorIsf asserts that none of the errors in err's chain matches target. // This is a wrapper for errors.Is. func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { @@ -1545,7 +1651,7 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf // NotImplements asserts that an object does not implement the specified interface. // -// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject)) +// require.NotImplements(t, (*MyInterface)(nil), new(MyObject)) func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1558,7 +1664,7 @@ func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, // NotImplementsf asserts that an object does not implement the specified interface. 
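// Editorial sketch, not part of the vendored patch: NotErrorAs/NotErrorAsf are new negative
// counterparts to ErrorAs. As with errors.As, target must be a pointer to an error type; the
// assertion passes when nothing in err's chain can be assigned to it.
package example_test

import (
	"errors"
	"os"
	"testing"

	"github.com/stretchr/testify/require"
)

func TestNotErrorAs(t *testing.T) {
	err := errors.New("plain error with no *os.PathError in its chain")
	var pathErr *os.PathError
	require.NotErrorAs(t, err, &pathErr) // passes: err does not wrap an *os.PathError
}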
// -// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") +// require.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted") func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1571,7 +1677,7 @@ func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, // NotNil asserts that the specified object is not nil. // -// assert.NotNil(t, err) +// require.NotNil(t, err) func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1584,7 +1690,7 @@ func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) { // NotNilf asserts that the specified object is not nil. // -// assert.NotNilf(t, err, "error message %s", "formatted") +// require.NotNilf(t, err, "error message %s", "formatted") func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1597,7 +1703,7 @@ func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) { // NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. // -// assert.NotPanics(t, func(){ RemainCalm() }) +// require.NotPanics(t, func(){ RemainCalm() }) func NotPanics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1610,7 +1716,7 @@ func NotPanics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { // NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic. // -// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") +// require.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") func NotPanicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1623,8 +1729,8 @@ func NotPanicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interfac // NotRegexp asserts that a specified regexp does not match a string. // -// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") -// assert.NotRegexp(t, "^start", "it's not starting") +// require.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") +// require.NotRegexp(t, "^start", "it's not starting") func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1637,8 +1743,8 @@ func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interf // NotRegexpf asserts that a specified regexp does not match a string. // -// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") -// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") +// require.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted") +// require.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1651,7 +1757,7 @@ func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args .. // NotSame asserts that two pointers do not reference the same object. // -// assert.NotSame(t, ptr1, ptr2) +// require.NotSame(t, ptr1, ptr2) // // Both arguments must be pointer variables. 
Pointer variable sameness is // determined based on the equality of both type and value. @@ -1667,7 +1773,7 @@ func NotSame(t TestingT, expected interface{}, actual interface{}, msgAndArgs .. // NotSamef asserts that two pointers do not reference the same object. // -// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted") +// require.NotSamef(t, ptr1, ptr2, "error message %s", "formatted") // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1681,12 +1787,15 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string, t.FailNow() } -// NotSubset asserts that the specified list(array, slice...) or map does NOT -// contain all elements given in the specified subset list(array, slice...) or -// map. +// NotSubset asserts that the list (array, slice, or map) does NOT contain all +// elements given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // -// assert.NotSubset(t, [1, 3, 4], [1, 2]) -// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3}) +// require.NotSubset(t, [1, 3, 4], [1, 2]) +// require.NotSubset(t, {"x": 1, "y": 2}, {"z": 3}) +// require.NotSubset(t, [1, 3, 4], {1: "one", 2: "two"}) +// require.NotSubset(t, {"x": 1, "y": 2}, ["z"]) func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1697,12 +1806,15 @@ func NotSubset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...i t.FailNow() } -// NotSubsetf asserts that the specified list(array, slice...) or map does NOT -// contain all elements given in the specified subset list(array, slice...) or -// map. +// NotSubsetf asserts that the list (array, slice, or map) does NOT contain all +// elements given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // -// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted") -// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") +// require.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted") +// require.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted") +// require.NotSubsetf(t, [1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted") +// require.NotSubsetf(t, {"x": 1, "y": 2}, ["z"], "error message %s", "formatted") func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1737,7 +1849,7 @@ func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) { // Panics asserts that the code inside the specified PanicTestFunc panics. // -// assert.Panics(t, func(){ GoCrazy() }) +// require.Panics(t, func(){ GoCrazy() }) func Panics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1752,7 +1864,7 @@ func Panics(t TestingT, f assert.PanicTestFunc, msgAndArgs ...interface{}) { // panics, and that the recovered panic value is an error that satisfies the // EqualError comparison. 
// -// assert.PanicsWithError(t, "crazy error", func(){ GoCrazy() }) +// require.PanicsWithError(t, "crazy error", func(){ GoCrazy() }) func PanicsWithError(t TestingT, errString string, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1767,7 +1879,7 @@ func PanicsWithError(t TestingT, errString string, f assert.PanicTestFunc, msgAn // panics, and that the recovered panic value is an error that satisfies the // EqualError comparison. // -// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +// require.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") func PanicsWithErrorf(t TestingT, errString string, f assert.PanicTestFunc, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1781,7 +1893,7 @@ func PanicsWithErrorf(t TestingT, errString string, f assert.PanicTestFunc, msg // PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // -// assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) +// require.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) func PanicsWithValue(t TestingT, expected interface{}, f assert.PanicTestFunc, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1795,7 +1907,7 @@ func PanicsWithValue(t TestingT, expected interface{}, f assert.PanicTestFunc, m // PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that // the recovered panic value equals the expected panic value. // -// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") +// require.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") func PanicsWithValuef(t TestingT, expected interface{}, f assert.PanicTestFunc, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1808,7 +1920,7 @@ func PanicsWithValuef(t TestingT, expected interface{}, f assert.PanicTestFunc, // Panicsf asserts that the code inside the specified PanicTestFunc panics. // -// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") +// require.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") func Panicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1821,8 +1933,8 @@ func Panicsf(t TestingT, f assert.PanicTestFunc, msg string, args ...interface{} // Positive asserts that the specified element is positive // -// assert.Positive(t, 1) -// assert.Positive(t, 1.23) +// require.Positive(t, 1) +// require.Positive(t, 1.23) func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1835,8 +1947,8 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) { // Positivef asserts that the specified element is positive // -// assert.Positivef(t, 1, "error message %s", "formatted") -// assert.Positivef(t, 1.23, "error message %s", "formatted") +// require.Positivef(t, 1, "error message %s", "formatted") +// require.Positivef(t, 1.23, "error message %s", "formatted") func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1849,8 +1961,8 @@ func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) { // Regexp asserts that a specified regexp matches a string. 
// -// assert.Regexp(t, regexp.MustCompile("start"), "it's starting") -// assert.Regexp(t, "start...$", "it's not starting") +// require.Regexp(t, regexp.MustCompile("start"), "it's starting") +// require.Regexp(t, "start...$", "it's not starting") func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1863,8 +1975,8 @@ func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface // Regexpf asserts that a specified regexp matches a string. // -// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") -// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") +// require.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted") +// require.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1877,7 +1989,7 @@ func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...in // Same asserts that two pointers reference the same object. // -// assert.Same(t, ptr1, ptr2) +// require.Same(t, ptr1, ptr2) // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1893,7 +2005,7 @@ func Same(t TestingT, expected interface{}, actual interface{}, msgAndArgs ...in // Samef asserts that two pointers reference the same object. // -// assert.Samef(t, ptr1, ptr2, "error message %s", "formatted") +// require.Samef(t, ptr1, ptr2, "error message %s", "formatted") // // Both arguments must be pointer variables. Pointer variable sameness is // determined based on the equality of both type and value. @@ -1907,11 +2019,15 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg t.FailNow() } -// Subset asserts that the specified list(array, slice...) or map contains all -// elements given in the specified subset list(array, slice...) or map. +// Subset asserts that the list (array, slice, or map) contains all elements +// given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. // -// assert.Subset(t, [1, 2, 3], [1, 2]) -// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1}) +// require.Subset(t, [1, 2, 3], [1, 2]) +// require.Subset(t, {"x": 1, "y": 2}, {"x": 1}) +// require.Subset(t, [1, 2, 3], {1: "one", 2: "two"}) +// require.Subset(t, {"x": 1, "y": 2}, ["x"]) func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1922,11 +2038,15 @@ func Subset(t TestingT, list interface{}, subset interface{}, msgAndArgs ...inte t.FailNow() } -// Subsetf asserts that the specified list(array, slice...) or map contains all -// elements given in the specified subset list(array, slice...) or map. +// Subsetf asserts that the list (array, slice, or map) contains all elements +// given in the subset (array, slice, or map). +// Map elements are key-value pairs unless compared with an array or slice where +// only the map key is evaluated. 
// -// assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted") -// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") +// require.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted") +// require.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted") +// require.Subsetf(t, [1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted") +// require.Subsetf(t, {"x": 1, "y": 2}, ["x"], "error message %s", "formatted") func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1939,7 +2059,7 @@ func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args // True asserts that the specified value is true. // -// assert.True(t, myBool) +// require.True(t, myBool) func True(t TestingT, value bool, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1952,7 +2072,7 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) { // Truef asserts that the specified value is true. // -// assert.Truef(t, myBool, "error message %s", "formatted") +// require.Truef(t, myBool, "error message %s", "formatted") func Truef(t TestingT, value bool, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1965,7 +2085,7 @@ func Truef(t TestingT, value bool, msg string, args ...interface{}) { // WithinDuration asserts that the two times are within duration delta of each other. // -// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) +// require.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) func WithinDuration(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1978,7 +2098,7 @@ func WithinDuration(t TestingT, expected time.Time, actual time.Time, delta time // WithinDurationf asserts that the two times are within duration delta of each other. // -// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") +// require.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -1991,7 +2111,7 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim // WithinRange asserts that a time is within a time range (inclusive). // -// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) +// require.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second)) func WithinRange(t TestingT, actual time.Time, start time.Time, end time.Time, msgAndArgs ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() @@ -2004,7 +2124,7 @@ func WithinRange(t TestingT, actual time.Time, start time.Time, end time.Time, m // WithinRangef asserts that a time is within a time range (inclusive). 
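// Editorial sketch, not part of the vendored patch: the expanded Subset/NotSubset godoc above
// documents the mixed case — when a map is compared against an array or slice (on either
// side), only the map keys are consulted, not the values.
package example_test

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestSubsetWithMaps(t *testing.T) {
	m := map[string]int{"x": 1, "y": 2}
	require.Subset(t, m, []string{"x"})            // slice subset is checked against m's keys
	require.NotSubset(t, m, []string{"z"})         // "z" is not a key of m
	require.Subset(t, []int{1, 2, 3}, []int{1, 2}) // plain list-in-list case is unchanged
}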
// -// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") +// require.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted") func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) { if h, ok := t.(tHelper); ok { h.Helper() diff --git a/vendor/github.com/stretchr/testify/require/require.go.tmpl b/vendor/github.com/stretchr/testify/require/require.go.tmpl index 55e42ddeb..8b3283685 100644 --- a/vendor/github.com/stretchr/testify/require/require.go.tmpl +++ b/vendor/github.com/stretchr/testify/require/require.go.tmpl @@ -1,4 +1,4 @@ -{{.Comment}} +{{ replace .Comment "assert." "require."}} func {{.DocInfo.Name}}(t TestingT, {{.Params}}) { if h, ok := t.(tHelper); ok { h.Helper() } if assert.{{.DocInfo.Name}}(t, {{.ForwardedParams}}) { return } diff --git a/vendor/github.com/stretchr/testify/require/require_forward.go b/vendor/github.com/stretchr/testify/require/require_forward.go index eee8310a5..e6f7e9446 100644 --- a/vendor/github.com/stretchr/testify/require/require_forward.go +++ b/vendor/github.com/stretchr/testify/require/require_forward.go @@ -93,10 +93,19 @@ func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg st ElementsMatchf(a.t, listA, listB, msg, args...) } -// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either -// a slice or a channel with len == 0. +// Empty asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". // // a.Empty(obj) +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -104,10 +113,19 @@ func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) { Empty(a.t, object, msgAndArgs...) } -// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either -// a slice or a channel with len == 0. +// Emptyf asserts that the given value is "empty". +// +// [Zero values] are "empty". +// +// Arrays are "empty" if every element is the zero value of the type (stricter than "empty"). +// +// Slices, maps and channels with zero length are "empty". +// +// Pointer values are "empty" if the pointer is nil or if the pointed value is "empty". // // a.Emptyf(obj, "error message %s", "formatted") +// +// [Zero values]: https://go.dev/ref/spec#The_zero_value func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -187,8 +205,8 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface EqualExportedValuesf(a.t, expected, actual, msg, args...) } -// EqualValues asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValues asserts that two objects are equal or convertible to the larger +// type and equal. 
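// Editorial sketch, not part of the vendored patch: two things change in this hunk — the
// require.go.tmpl template now rewrites "assert." to "require." in generated godoc (which is
// why every example above switched prefix), and the Empty/Emptyf docs spell out what counts
// as "empty": zero values, zero-length slices/maps/channels, arrays whose elements are all
// zero values, and pointers that are nil or point at an empty value (per the updated docs).
package example_test

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestEmptySemantics(t *testing.T) {
	var s []int
	blank := ""
	require.Empty(t, s)           // nil slice has zero length
	require.Empty(t, 0)           // zero value of int
	require.Empty(t, [2]int{})    // array whose elements are all zero values
	require.Empty(t, &blank)      // pointer to an empty value
	require.NotEmpty(t, []int{1}) // non-empty slice
}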
// // a.EqualValues(uint32(123), int32(123)) func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) { @@ -198,8 +216,8 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn EqualValues(a.t, expected, actual, msgAndArgs...) } -// EqualValuesf asserts that two objects are equal or convertible to the same types -// and equal. +// EqualValuesf asserts that two objects are equal or convertible to the larger +// type and equal. // // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) { @@ -225,10 +243,8 @@ func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string // Error asserts that a function returned an error (i.e. not `nil`). // -// actualObj, err := SomeFunction() -// if a.Error(err) { -// assert.Equal(t, expectedError, err) -// } +// actualObj, err := SomeFunction() +// a.Error(err) func (a *Assertions) Error(err error, msgAndArgs ...interface{}) { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -298,10 +314,8 @@ func (a *Assertions) ErrorIsf(err error, target error, msg string, args ...inter // Errorf asserts that a function returned an error (i.e. not `nil`). // -// actualObj, err := SomeFunction() -// if a.Errorf(err, "error message %s", "formatted") { -// assert.Equal(t, expectedErrorf, err) -// } +// actualObj, err := SomeFunction() +// a.Errorf(err, "error message %s", "formatted") func (a *Assertions) Errorf(err error, msg string, args ...interface{}) { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -337,7 +351,7 @@ func (a *Assertions) Eventually(condition func() bool, waitFor time.Duration, ti // a.EventuallyWithT(func(c *assert.CollectT) { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -362,7 +376,7 @@ func (a *Assertions) EventuallyWithT(condition func(collect *assert.CollectT), w // a.EventuallyWithTf(func(c *assert.CollectT, "error message %s", "formatted") { // // add assertions as needed; any assertion failure will fail the current tick // assert.True(c, externalValue, "expected 'externalValue' to be true") -// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false") +// }, 10*time.Second, 1*time.Second, "external state has not changed to 'true'; still false") func (a *Assertions) EventuallyWithTf(condition func(collect *assert.CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) { if h, ok := a.t.(tHelper); ok { h.Helper() @@ -869,7 +883,29 @@ func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...in IsNonIncreasingf(a.t, object, msg, args...) } +// IsNotType asserts that the specified objects are not of the same type. +// +// a.IsNotType(&NotMyStruct{}, &MyStruct{}) +func (a *Assertions) IsNotType(theType interface{}, object interface{}, msgAndArgs ...interface{}) { + if h, ok := a.t.(tHelper); ok { + h.Helper() + } + IsNotType(a.t, theType, object, msgAndArgs...) 
@@ -869,7 +883,29 @@ func (a *Assertions) IsNonIncreasingf(object interface{}, msg string, args ...in
 	IsNonIncreasingf(a.t, object, msg, args...)
 }

+// IsNotType asserts that the specified objects are not of the same type.
+//
+//	a.IsNotType(&NotMyStruct{}, &MyStruct{})
+func (a *Assertions) IsNotType(theType interface{}, object interface{}, msgAndArgs ...interface{}) {
+	if h, ok := a.t.(tHelper); ok {
+		h.Helper()
+	}
+	IsNotType(a.t, theType, object, msgAndArgs...)
+}
+
+// IsNotTypef asserts that the specified objects are not of the same type.
+//
+//	a.IsNotTypef(&NotMyStruct{}, &MyStruct{}, "error message %s", "formatted")
+func (a *Assertions) IsNotTypef(theType interface{}, object interface{}, msg string, args ...interface{}) {
+	if h, ok := a.t.(tHelper); ok {
+		h.Helper()
+	}
+	IsNotTypef(a.t, theType, object, msg, args...)
+}
+
 // IsType asserts that the specified objects are of the same type.
+//
+//	a.IsType(&MyStruct{}, &MyStruct{})
 func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) {
 	if h, ok := a.t.(tHelper); ok {
 		h.Helper()
@@ -878,6 +914,8 @@ func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAnd
 }

 // IsTypef asserts that the specified objects are of the same type.
+//
+//	a.IsTypef(&MyStruct{}, &MyStruct{}, "error message %s", "formatted")
 func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) {
 	if h, ok := a.t.(tHelper); ok {
 		h.Helper()
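IsNotType and IsNotTypef are new forwarded assertions in this testify bump; both compare the dynamic types of their arguments. A brief sketch of the positive and negative checks, using hypothetical struct names:

package vendordiff_test

import (
	"testing"

	"github.com/stretchr/testify/require"
)

type myStruct struct{}
type otherStruct struct{}

func TestTypeAssertionsSketch(t *testing.T) {
	require.IsType(t, &myStruct{}, &myStruct{})       // same dynamic type
	require.IsNotType(t, &otherStruct{}, &myStruct{}) // different dynamic types
}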
@@ -1129,8 +1167,41 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin
 	NotContainsf(a.t, s, contains, msg, args...)
 }

-// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
-// a slice or a channel with len == 0.
+// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
+// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
+// the number of appearances of each of them in both lists should not match.
+// This is an inverse of ElementsMatch.
+//
+//	a.NotElementsMatch([1, 1, 2, 3], [1, 1, 2, 3]) -> false
+//
+//	a.NotElementsMatch([1, 1, 2, 3], [1, 2, 3]) -> true
+//
+//	a.NotElementsMatch([1, 2, 3], [1, 2, 4]) -> true
+func (a *Assertions) NotElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) {
+	if h, ok := a.t.(tHelper); ok {
+		h.Helper()
+	}
+	NotElementsMatch(a.t, listA, listB, msgAndArgs...)
+}
+
+// NotElementsMatchf asserts that the specified listA(array, slice...) is NOT equal to specified
+// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
+// the number of appearances of each of them in both lists should not match.
+// This is an inverse of ElementsMatch.
+//
+//	a.NotElementsMatchf([1, 1, 2, 3], [1, 1, 2, 3], "error message %s", "formatted") -> false
+//
+//	a.NotElementsMatchf([1, 1, 2, 3], [1, 2, 3], "error message %s", "formatted") -> true
+//
+//	a.NotElementsMatchf([1, 2, 3], [1, 2, 4], "error message %s", "formatted") -> true
+func (a *Assertions) NotElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) {
+	if h, ok := a.t.(tHelper); ok {
+		h.Helper()
+	}
+	NotElementsMatchf(a.t, listA, listB, msg, args...)
+}
+
+// NotEmpty asserts that the specified object is NOT [Empty].
 //
 //	if a.NotEmpty(obj) {
 //		assert.Equal(t, "two", obj[1])
@@ -1142,8 +1213,7 @@ func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) {
 	NotEmpty(a.t, object, msgAndArgs...)
 }

-// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
-// a slice or a channel with len == 0.
+// NotEmptyf asserts that the specified object is NOT [Empty].
 //
 //	if a.NotEmptyf(obj, "error message %s", "formatted") {
 //		assert.Equal(t, "two", obj[1])
@@ -1201,7 +1271,25 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str
 	NotEqualf(a.t, expected, actual, msg, args...)
 }

-// NotErrorIs asserts that at none of the errors in err's chain matches target.
+// NotErrorAs asserts that none of the errors in err's chain matches target,
+// but if so, sets target to that error value.
+func (a *Assertions) NotErrorAs(err error, target interface{}, msgAndArgs ...interface{}) {
+	if h, ok := a.t.(tHelper); ok {
+		h.Helper()
+	}
+	NotErrorAs(a.t, err, target, msgAndArgs...)
+}
+
+// NotErrorAsf asserts that none of the errors in err's chain matches target,
+// but if so, sets target to that error value.
+func (a *Assertions) NotErrorAsf(err error, target interface{}, msg string, args ...interface{}) {
+	if h, ok := a.t.(tHelper); ok {
+		h.Helper()
+	}
+	NotErrorAsf(a.t, err, target, msg, args...)
+}
+
+// NotErrorIs asserts that none of the errors in err's chain matches target.
 // This is a wrapper for errors.Is.
 func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface{}) {
 	if h, ok := a.t.(tHelper); ok {
@@ -1210,7 +1298,7 @@ func (a *Assertions) NotErrorIs(err error, target error, msgAndArgs ...interface
 	NotErrorIs(a.t, err, target, msgAndArgs...)
 }

-// NotErrorIsf asserts that at none of the errors in err's chain matches target.
+// NotErrorIsf asserts that none of the errors in err's chain matches target.
 // This is a wrapper for errors.Is.
 func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...interface{}) {
 	if h, ok := a.t.(tHelper); ok {
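NotElementsMatch and NotErrorAs (plus their *f variants) are likewise new in this bump. A sketch of both, assuming the vendored v1.11.1 behaves as the doc comments above describe; the error type and test name are hypothetical:

package vendordiff_test

import (
	"errors"
	"fmt"
	"testing"

	"github.com/stretchr/testify/require"
)

type timeoutError struct{}

func (timeoutError) Error() string { return "timeout" }

func TestNewNegativeAssertionsSketch(t *testing.T) {
	err := fmt.Errorf("wrapped: %w", errors.New("plain failure"))

	// Nothing in err's chain is a *timeoutError, so NotErrorAs passes.
	var target *timeoutError
	require.NotErrorAs(t, err, &target)

	// The two lists differ by one element, so they do NOT match as multisets.
	require.NotElementsMatch(t, []int{1, 2, 3}, []int{1, 2, 4})
}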
@@ -1327,12 +1415,15 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
 	NotSamef(a.t, expected, actual, msg, args...)
 }

-// NotSubset asserts that the specified list(array, slice...) or map does NOT
-// contain all elements given in the specified subset list(array, slice...) or
-// map.
+// NotSubset asserts that the list (array, slice, or map) does NOT contain all
+// elements given in the subset (array, slice, or map).
+// Map elements are key-value pairs unless compared with an array or slice where
+// only the map key is evaluated.
 //
 //	a.NotSubset([1, 3, 4], [1, 2])
 //	a.NotSubset({"x": 1, "y": 2}, {"z": 3})
+//	a.NotSubset([1, 3, 4], {1: "one", 2: "two"})
+//	a.NotSubset({"x": 1, "y": 2}, ["z"])
 func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
 	if h, ok := a.t.(tHelper); ok {
 		h.Helper()
@@ -1340,12 +1431,15 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
 	NotSubset(a.t, list, subset, msgAndArgs...)
 }

-// NotSubsetf asserts that the specified list(array, slice...) or map does NOT
-// contain all elements given in the specified subset list(array, slice...) or
-// map.
+// NotSubsetf asserts that the list (array, slice, or map) does NOT contain all
+// elements given in the subset (array, slice, or map).
+// Map elements are key-value pairs unless compared with an array or slice where
+// only the map key is evaluated.
 //
 //	a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
 //	a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
+//	a.NotSubsetf([1, 3, 4], {1: "one", 2: "two"}, "error message %s", "formatted")
+//	a.NotSubsetf({"x": 1, "y": 2}, ["z"], "error message %s", "formatted")
 func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
 	if h, ok := a.t.(tHelper); ok {
 		h.Helper()
@@ -1505,11 +1599,15 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
 	Samef(a.t, expected, actual, msg, args...)
 }

-// Subset asserts that the specified list(array, slice...) or map contains all
-// elements given in the specified subset list(array, slice...) or map.
+// Subset asserts that the list (array, slice, or map) contains all elements
+// given in the subset (array, slice, or map).
+// Map elements are key-value pairs unless compared with an array or slice where
+// only the map key is evaluated.
 //
 //	a.Subset([1, 2, 3], [1, 2])
 //	a.Subset({"x": 1, "y": 2}, {"x": 1})
+//	a.Subset([1, 2, 3], {1: "one", 2: "two"})
+//	a.Subset({"x": 1, "y": 2}, ["x"])
 func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) {
 	if h, ok := a.t.(tHelper); ok {
 		h.Helper()
@@ -1517,11 +1615,15 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
 	Subset(a.t, list, subset, msgAndArgs...)
 }

-// Subsetf asserts that the specified list(array, slice...) or map contains all
-// elements given in the specified subset list(array, slice...) or map.
+// Subsetf asserts that the list (array, slice, or map) contains all elements
+// given in the subset (array, slice, or map).
+// Map elements are key-value pairs unless compared with an array or slice where
+// only the map key is evaluated.
 //
 //	a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
 //	a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
+//	a.Subsetf([1, 2, 3], {1: "one", 2: "two"}, "error message %s", "formatted")
+//	a.Subsetf({"x": 1, "y": 2}, ["x"], "error message %s", "formatted")
 func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) {
 	if h, ok := a.t.(tHelper); ok {
 		h.Helper()
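The Subset/NotSubset doc updates above spell out the mixed map/slice comparison: when the subset is a slice or array and the list is a map, only the map keys are checked; map-against-map comparisons use full key-value pairs. A sketch under that documented behavior, with illustrative values only:

package vendordiff_test

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestSubsetWithMapsSketch(t *testing.T) {
	require.Subset(t, []int{1, 2, 3}, []int{1, 2})                            // slice contains slice
	require.Subset(t, map[string]int{"x": 1, "y": 2}, map[string]int{"x": 1}) // maps compared as key-value pairs
	require.Subset(t, map[string]int{"x": 1, "y": 2}, []string{"x"})          // slice subset checks map keys only

	require.NotSubset(t, []int{1, 3, 4}, []int{1, 2})
	require.NotSubset(t, map[string]int{"x": 1, "y": 2}, []string{"z"})
}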
diff --git a/vendor/github.com/stretchr/testify/require/requirements.go b/vendor/github.com/stretchr/testify/require/requirements.go
index 91772dfeb..6b7ce929e 100644
--- a/vendor/github.com/stretchr/testify/require/requirements.go
+++ b/vendor/github.com/stretchr/testify/require/requirements.go
@@ -6,7 +6,7 @@ type TestingT interface {
 	FailNow()
 }

-type tHelper interface {
+type tHelper = interface {
 	Helper()
 }

diff --git a/vendor/modules.txt b/vendor/modules.txt
index e06c6a74b..3072915fa 100644
--- a/vendor/modules.txt
+++ b/vendor/modules.txt
@@ -263,7 +263,7 @@ github.com/jackc/pgpassfile
 # github.com/jackc/pgproto3/v2 v2.3.3
 ## explicit; go 1.12
 github.com/jackc/pgproto3/v2
-# github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9
+# github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761
 ## explicit; go 1.14
 github.com/jackc/pgservicefile
 # github.com/jackc/pgtype v1.14.3
@@ -272,8 +272,8 @@ github.com/jackc/pgtype
 # github.com/jackc/pgx-zap v0.0.0-20221202020421-94b1cb2f889f
 ## explicit; go 1.17
 github.com/jackc/pgx-zap
-# github.com/jackc/pgx/v5 v5.6.0
-## explicit; go 1.20
+# github.com/jackc/pgx/v5 v5.8.0
+## explicit; go 1.24.0
 github.com/jackc/pgx/v5
 github.com/jackc/pgx/v5/internal/iobufpool
 github.com/jackc/pgx/v5/internal/pgio
@@ -287,7 +287,7 @@ github.com/jackc/pgx/v5/pgtype
 github.com/jackc/pgx/v5/pgxpool
 github.com/jackc/pgx/v5/stdlib
 github.com/jackc/pgx/v5/tracelog
-# github.com/jackc/puddle/v2 v2.2.1
+# github.com/jackc/puddle/v2 v2.2.2
 ## explicit; go 1.19
 github.com/jackc/puddle/v2
 github.com/jackc/puddle/v2/internal/genstack
@@ -421,9 +421,10 @@ github.com/spf13/cobra
 # github.com/spf13/pflag v1.0.5
 ## explicit; go 1.12
 github.com/spf13/pflag
-# github.com/stretchr/testify v1.9.0
+# github.com/stretchr/testify v1.11.1
 ## explicit; go 1.17
 github.com/stretchr/testify/assert
+github.com/stretchr/testify/assert/yaml
 github.com/stretchr/testify/require
 # github.com/swaggest/jsonschema-go v0.3.72
 ## explicit; go 1.18