Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 81 additions & 55 deletions internal/assertions/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,77 +350,103 @@ func formatUnequalValues(expected, actual any) (e string, a string) {
// copyExportedFields iterates downward through nested data structures and creates a copy
// that only contains the exported struct fields.
func copyExportedFields(expected any) any {
return copyExportedFieldsRec(expected, make(map[uintptr]struct{}))
}

// copyExportedFieldsRec carries a set of pointers currently being visited on the
// recursion path, so that cyclic pointer references break the recursion instead
// of overflowing the goroutine stack.
func copyExportedFieldsRec(expected any, visited map[uintptr]struct{}) any {
if isNil(expected) {
return expected
}

expectedType := reflect.TypeOf(expected)
expectedKind := expectedType.Kind()
expectedValue := reflect.ValueOf(expected)

switch expectedKind {
switch expectedType.Kind() {
case reflect.Struct:
result := reflect.New(expectedType).Elem()
for i := range expectedType.NumField() {
field := expectedType.Field(i)
isExported := field.IsExported()
if isExported {
fieldValue := expectedValue.Field(i)
if isNil(fieldValue) || isNil(fieldValue.Interface()) {
continue
}
newValue := copyExportedFields(fieldValue.Interface())
result.Field(i).Set(reflect.ValueOf(newValue))
}
}
return result.Interface()

return copyExportedStruct(expectedType, expectedValue, visited)
case reflect.Pointer:
result := reflect.New(expectedType.Elem())
unexportedRemoved := copyExportedFields(expectedValue.Elem().Interface())
if unexportedRemoved != nil {
result.Elem().Set(reflect.ValueOf(unexportedRemoved))
}
return result.Interface()

return copyExportedPointer(expected, expectedType, expectedValue, visited)
case reflect.Array, reflect.Slice:
var result reflect.Value
if expectedKind == reflect.Array {
result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem()
} else {
result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len())
return copyExportedSequence(expectedType, expectedValue, visited)
case reflect.Map:
return copyExportedMap(expectedType, expectedValue, visited)
default:
return expected
}
}

func copyExportedStruct(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any {
result := reflect.New(expectedType).Elem()
for i := range expectedType.NumField() {
if !expectedType.Field(i).IsExported() {
continue
}
for i := range expectedValue.Len() {
index := expectedValue.Index(i)
if !index.CanInterface() {
// this should not be possible with current reflect, since values are retrieved from an array or slice, not a struct
panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index))
}
unexportedRemoved := copyExportedFields(index.Interface())
if unexportedRemoved != nil {
result.Index(i).Set(reflect.ValueOf(unexportedRemoved))
}
fieldValue := expectedValue.Field(i)
if isNil(fieldValue) || isNil(fieldValue.Interface()) {
continue
}
return result.Interface()
newValue := copyExportedFieldsRec(fieldValue.Interface(), visited)
result.Field(i).Set(reflect.ValueOf(newValue))
}
return result.Interface()
}

case reflect.Map:
result := reflect.MakeMap(expectedType)
for _, k := range expectedValue.MapKeys() {
index := expectedValue.MapIndex(k)
if !index.CanInterface() {
// this should not be possible with current reflect, since values are retrieved from a map, not a struct
panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index))
}
unexportedRemoved := copyExportedFields(index.Interface())
if unexportedRemoved != nil {
result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved))
}
func copyExportedPointer(expected any, expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any {
// Guard against cyclic pointer references: if this pointer is already
// on the current recursion path, return it as-is to break the cycle.
ptr := expectedValue.Pointer()
if _, ok := visited[ptr]; ok {
return expected
}
visited[ptr] = struct{}{}
defer delete(visited, ptr)

result := reflect.New(expectedType.Elem())
unexportedRemoved := copyExportedFieldsRec(expectedValue.Elem().Interface(), visited)
if unexportedRemoved != nil {
result.Elem().Set(reflect.ValueOf(unexportedRemoved))
}
return result.Interface()
}

func copyExportedSequence(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any {
var result reflect.Value
if expectedType.Kind() == reflect.Array {
result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem()
} else {
result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len())
}
for i := range expectedValue.Len() {
index := expectedValue.Index(i)
if !index.CanInterface() {
// this should not be possible with current reflect, since values are retrieved from an array or slice, not a struct
panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index))
}
unexportedRemoved := copyExportedFieldsRec(index.Interface(), visited)
if unexportedRemoved != nil {
result.Index(i).Set(reflect.ValueOf(unexportedRemoved))
}
return result.Interface()
}
return result.Interface()
}

default:
return expected
func copyExportedMap(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any {
result := reflect.MakeMap(expectedType)
for _, k := range expectedValue.MapKeys() {
index := expectedValue.MapIndex(k)
if !index.CanInterface() {
// this should not be possible with current reflect, since values are retrieved from a map, not a struct
panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index))
}
unexportedRemoved := copyExportedFieldsRec(index.Interface(), visited)
if unexportedRemoved != nil {
result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved))
}
}
return result.Interface()
}

func isFunction(arg any) bool {
Expand Down
18 changes: 17 additions & 1 deletion internal/assertions/equal_unary.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,13 @@ func isEmpty(object any) bool {

// isEmptyValue gets whether the specified reflect.Value is considered empty or not.
func isEmptyValue(objValue reflect.Value) bool {
return isEmptyValueRec(objValue, nil)
}

// isEmptyValueRec carries the set of pointers already followed on the current
// recursion path, so that a cyclic pointer chain (e.g. type P *P with p = &p)
// breaks the recursion instead of overflowing the goroutine stack.
func isEmptyValueRec(objValue reflect.Value, visited map[uintptr]struct{}) bool {
if objValue.IsZero() {
return true
}
Expand All @@ -152,7 +159,16 @@ func isEmptyValue(objValue reflect.Value) bool {
return objValue.Len() == 0
// non-nil pointers are empty if the value they point to is empty
case reflect.Pointer:
return isEmptyValue(objValue.Elem())
ptr := objValue.Pointer()
if _, ok := visited[ptr]; ok {
// a cyclic, non-nil pointer chain is not empty
return false
}
if visited == nil {
visited = make(map[uintptr]struct{})
}
visited[ptr] = struct{}{}
return isEmptyValueRec(objValue.Elem(), visited)
default:
return false
}
Expand Down
34 changes: 26 additions & 8 deletions internal/assertions/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,21 +271,39 @@ func NotErrorAs(t T, err error, target any, msgAndArgs ...any) bool {
), msgAndArgs...)
}

func unwrapAll(err error) (errs []error) {
func unwrapAll(err error) []error {
return appendUnwrapped(nil, err, make(map[error]struct{}))
}

// appendUnwrapped flattens an error chain, guarding against cyclic chains that
// would otherwise recurse until the goroutine stack overflows.
//
// Only comparable errors are tracked: using an incomparable error as a map key
// would panic, so those simply skip cycle detection (mirroring how the standard
// library's errors.Is guards its own comparisons).
func appendUnwrapped(errs []error, err error, visited map[error]struct{}) []error {
if err != nil && reflect.TypeOf(err).Comparable() {
if _, ok := visited[err]; ok {
return errs // cyclic error chain: stop here
}
visited[err] = struct{}{}
defer delete(visited, err)
}

errs = append(errs, err)
switch x := err.(type) { //nolint:errorlint // false positive: this type switch is checking for interfaces
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
return
next := x.Unwrap()
if next == nil {
return errs
}
errs = append(errs, unwrapAll(err)...)
errs = appendUnwrapped(errs, next, visited)
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
errs = append(errs, unwrapAll(err)...)
for _, next := range x.Unwrap() {
errs = appendUnwrapped(errs, next, visited)
}
}
return
return errs
}

func buildErrorChainString(err error, withType bool) string {
Expand Down
155 changes: 155 additions & 0 deletions internal/assertions/recursion_cycles_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers
// SPDX-License-Identifier: Apache-2.0

package assertions

import (
"testing"
)

// These tests guard the reflection-based assertion helpers against cyclic
// inputs. Such inputs are legal Go values but would, without an explicit guard,
// drive the helpers into unbounded recursion and overflow the goroutine stack.
//
// A goroutine stack overflow is a fatal runtime error that cannot be recovered,
// so the regression signal is simply that these tests run to completion: before
// the fix, each case crashes the test binary.

type cyclicNode struct {
Name string
Next *cyclicNode
}

// TestCopyExportedFieldsCycle covers copyExportedFields (reached via
// EqualExportedValues) with self-referential and mutually-referential pointers.
func TestCopyExportedFieldsCycle(t *testing.T) {
t.Parallel()

t.Run("self-referential struct does not overflow the stack", func(t *testing.T) {
t.Parallel()

a := &cyclicNode{Name: "node"}
a.Next = a
b := &cyclicNode{Name: "node"}
b.Next = b

mock := new(mockT)
// Two structurally identical self-cycles compare as equal.
if !EqualExportedValues(mock, a, b) {
t.Errorf("expected structurally identical self-cycles to be equal")
}
})

t.Run("mutually-referential structs do not overflow the stack", func(t *testing.T) {
t.Parallel()

a1 := &cyclicNode{Name: "a"}
a2 := &cyclicNode{Name: "b"}
a1.Next = a2
a2.Next = a1

b1 := &cyclicNode{Name: "a"}
b2 := &cyclicNode{Name: "b"}
b1.Next = b2
b2.Next = b1

mock := new(mockT)
if !EqualExportedValues(mock, a1, b1) {
t.Errorf("expected structurally identical mutual cycles to be equal")
}
})

t.Run("copyExportedFields returns on a direct cycle", func(t *testing.T) {
t.Parallel()

a := &cyclicNode{Name: "node"}
a.Next = a

// Direct call: the only property under test is that it returns.
_ = copyExportedFields(a)
})
}

// selfPtr is a pointer type that can point to a value of its own type, allowing
// the construction of a pointer chain that cycles back on itself.
type selfPtr *selfPtr

// TestIsEmptyValueCycle covers isEmptyValue (reached via Empty/NotEmpty/Zero)
// with a cyclic pointer chain.
func TestIsEmptyValueCycle(t *testing.T) {
t.Parallel()

var p selfPtr
p = &p // p points to itself

mock := new(mockT)

// A non-nil, cyclic pointer chain is not empty; the assertions must agree
// and, above all, must not recurse forever.
empty := Empty(mock, p)
notEmpty := NotEmpty(new(mockT), p)

if empty {
t.Errorf("expected a non-nil cyclic pointer to be reported as not empty")
}
if empty == notEmpty {
t.Errorf("Empty and NotEmpty must be complementary, got Empty=%t NotEmpty=%t", empty, notEmpty)
}
}

// cyclicError is an error whose Unwrap chain can be made to cycle.
type cyclicError struct {
msg string
next error
}

func (e *cyclicError) Error() string { return e.msg }
func (e *cyclicError) Unwrap() error { return e.next }

// multiCyclicError exercises the Unwrap() []error branch of the chain walker.
type multiCyclicError struct {
msg string
errs []error
}

func (e *multiCyclicError) Error() string { return e.msg }
func (e *multiCyclicError) Unwrap() []error { return e.errs }

// TestUnwrapAllCycle covers unwrapAll (reached when formatting error chains in
// ErrorIs/NotErrorIs failure messages) with cyclic Unwrap chains.
//
// Note: we deliberately test unwrapAll directly rather than through the public
// ErrorIs/NotErrorIs assertions. Those call the standard library's errors.Is
// first, which itself walks the Unwrap chain without cycle detection and so
// hangs on a cyclic chain before our formatter is ever reached. Guarding that
// would mean reimplementing errors.Is and diverging from standard-library
// semantics, which is out of scope here. unwrapAll is hardened on its own so
// that our code never contributes an unbounded recursion of its own.
func TestUnwrapAllCycle(t *testing.T) {
t.Parallel()

t.Run("single Unwrap cycle does not overflow the stack", func(t *testing.T) {
t.Parallel()

e := &cyclicError{msg: "boom"}
e.next = e // self-cycle

errs := unwrapAll(e)
if len(errs) == 0 {
t.Fatalf("expected at least the head error")
}
})

t.Run("multi Unwrap cycle does not overflow the stack", func(t *testing.T) {
t.Parallel()

leaf := &cyclicError{msg: "leaf"}
root := &multiCyclicError{msg: "root"}
root.errs = []error{leaf, root} // root references itself

errs := unwrapAll(root)
if len(errs) == 0 {
t.Fatalf("expected at least the head error")
}
})
}
Loading