diff --git a/internal/assertions/equal.go b/internal/assertions/equal.go index 223df40d7..2765cb543 100644 --- a/internal/assertions/equal.go +++ b/internal/assertions/equal.go @@ -350,77 +350,103 @@ func formatUnequalValues(expected, actual any) (e string, a string) { // copyExportedFields iterates downward through nested data structures and creates a copy // that only contains the exported struct fields. func copyExportedFields(expected any) any { + return copyExportedFieldsRec(expected, make(map[uintptr]struct{})) +} + +// copyExportedFieldsRec carries a set of pointers currently being visited on the +// recursion path, so that cyclic pointer references break the recursion instead +// of overflowing the goroutine stack. +func copyExportedFieldsRec(expected any, visited map[uintptr]struct{}) any { if isNil(expected) { return expected } expectedType := reflect.TypeOf(expected) - expectedKind := expectedType.Kind() expectedValue := reflect.ValueOf(expected) - switch expectedKind { + switch expectedType.Kind() { case reflect.Struct: - result := reflect.New(expectedType).Elem() - for i := range expectedType.NumField() { - field := expectedType.Field(i) - isExported := field.IsExported() - if isExported { - fieldValue := expectedValue.Field(i) - if isNil(fieldValue) || isNil(fieldValue.Interface()) { - continue - } - newValue := copyExportedFields(fieldValue.Interface()) - result.Field(i).Set(reflect.ValueOf(newValue)) - } - } - return result.Interface() - + return copyExportedStruct(expectedType, expectedValue, visited) case reflect.Pointer: - result := reflect.New(expectedType.Elem()) - unexportedRemoved := copyExportedFields(expectedValue.Elem().Interface()) - if unexportedRemoved != nil { - result.Elem().Set(reflect.ValueOf(unexportedRemoved)) - } - return result.Interface() - + return copyExportedPointer(expected, expectedType, expectedValue, visited) case reflect.Array, reflect.Slice: - var result reflect.Value - if expectedKind == reflect.Array { - result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem() - } else { - result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) + return copyExportedSequence(expectedType, expectedValue, visited) + case reflect.Map: + return copyExportedMap(expectedType, expectedValue, visited) + default: + return expected + } +} + +func copyExportedStruct(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any { + result := reflect.New(expectedType).Elem() + for i := range expectedType.NumField() { + if !expectedType.Field(i).IsExported() { + continue } - for i := range expectedValue.Len() { - index := expectedValue.Index(i) - if !index.CanInterface() { - // this should not be possible with current reflect, since values are retrieved from an array or slice, not a struct - panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index)) - } - unexportedRemoved := copyExportedFields(index.Interface()) - if unexportedRemoved != nil { - result.Index(i).Set(reflect.ValueOf(unexportedRemoved)) - } + fieldValue := expectedValue.Field(i) + if isNil(fieldValue) || isNil(fieldValue.Interface()) { + continue } - return result.Interface() + newValue := copyExportedFieldsRec(fieldValue.Interface(), visited) + result.Field(i).Set(reflect.ValueOf(newValue)) + } + return result.Interface() +} - case reflect.Map: - result := reflect.MakeMap(expectedType) - for _, k := range expectedValue.MapKeys() { - index := expectedValue.MapIndex(k) - if !index.CanInterface() { - // this should not be possible with current reflect, since values are retrieved from a map, not a struct - panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index)) - } - unexportedRemoved := copyExportedFields(index.Interface()) - if unexportedRemoved != nil { - result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved)) - } +func copyExportedPointer(expected any, expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any { + // Guard against cyclic pointer references: if this pointer is already + // on the current recursion path, return it as-is to break the cycle. + ptr := expectedValue.Pointer() + if _, ok := visited[ptr]; ok { + return expected + } + visited[ptr] = struct{}{} + defer delete(visited, ptr) + + result := reflect.New(expectedType.Elem()) + unexportedRemoved := copyExportedFieldsRec(expectedValue.Elem().Interface(), visited) + if unexportedRemoved != nil { + result.Elem().Set(reflect.ValueOf(unexportedRemoved)) + } + return result.Interface() +} + +func copyExportedSequence(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any { + var result reflect.Value + if expectedType.Kind() == reflect.Array { + result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem() + } else { + result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) + } + for i := range expectedValue.Len() { + index := expectedValue.Index(i) + if !index.CanInterface() { + // this should not be possible with current reflect, since values are retrieved from an array or slice, not a struct + panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index)) + } + unexportedRemoved := copyExportedFieldsRec(index.Interface(), visited) + if unexportedRemoved != nil { + result.Index(i).Set(reflect.ValueOf(unexportedRemoved)) } - return result.Interface() + } + return result.Interface() +} - default: - return expected +func copyExportedMap(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any { + result := reflect.MakeMap(expectedType) + for _, k := range expectedValue.MapKeys() { + index := expectedValue.MapIndex(k) + if !index.CanInterface() { + // this should not be possible with current reflect, since values are retrieved from a map, not a struct + panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index)) + } + unexportedRemoved := copyExportedFieldsRec(index.Interface(), visited) + if unexportedRemoved != nil { + result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved)) + } } + return result.Interface() } func isFunction(arg any) bool { diff --git a/internal/assertions/equal_unary.go b/internal/assertions/equal_unary.go index 5b7b638eb..e18326a13 100644 --- a/internal/assertions/equal_unary.go +++ b/internal/assertions/equal_unary.go @@ -141,6 +141,13 @@ func isEmpty(object any) bool { // isEmptyValue gets whether the specified reflect.Value is considered empty or not. func isEmptyValue(objValue reflect.Value) bool { + return isEmptyValueRec(objValue, nil) +} + +// isEmptyValueRec carries the set of pointers already followed on the current +// recursion path, so that a cyclic pointer chain (e.g. type P *P with p = &p) +// breaks the recursion instead of overflowing the goroutine stack. +func isEmptyValueRec(objValue reflect.Value, visited map[uintptr]struct{}) bool { if objValue.IsZero() { return true } @@ -152,7 +159,16 @@ func isEmptyValue(objValue reflect.Value) bool { return objValue.Len() == 0 // non-nil pointers are empty if the value they point to is empty case reflect.Pointer: - return isEmptyValue(objValue.Elem()) + ptr := objValue.Pointer() + if _, ok := visited[ptr]; ok { + // a cyclic, non-nil pointer chain is not empty + return false + } + if visited == nil { + visited = make(map[uintptr]struct{}) + } + visited[ptr] = struct{}{} + return isEmptyValueRec(objValue.Elem(), visited) default: return false } diff --git a/internal/assertions/error.go b/internal/assertions/error.go index fedd48a69..e912afb02 100644 --- a/internal/assertions/error.go +++ b/internal/assertions/error.go @@ -271,21 +271,39 @@ func NotErrorAs(t T, err error, target any, msgAndArgs ...any) bool { ), msgAndArgs...) } -func unwrapAll(err error) (errs []error) { +func unwrapAll(err error) []error { + return appendUnwrapped(nil, err, make(map[error]struct{})) +} + +// appendUnwrapped flattens an error chain, guarding against cyclic chains that +// would otherwise recurse until the goroutine stack overflows. +// +// Only comparable errors are tracked: using an incomparable error as a map key +// would panic, so those simply skip cycle detection (mirroring how the standard +// library's errors.Is guards its own comparisons). +func appendUnwrapped(errs []error, err error, visited map[error]struct{}) []error { + if err != nil && reflect.TypeOf(err).Comparable() { + if _, ok := visited[err]; ok { + return errs // cyclic error chain: stop here + } + visited[err] = struct{}{} + defer delete(visited, err) + } + errs = append(errs, err) switch x := err.(type) { //nolint:errorlint // false positive: this type switch is checking for interfaces case interface{ Unwrap() error }: - err = x.Unwrap() - if err == nil { - return + next := x.Unwrap() + if next == nil { + return errs } - errs = append(errs, unwrapAll(err)...) + errs = appendUnwrapped(errs, next, visited) case interface{ Unwrap() []error }: - for _, err := range x.Unwrap() { - errs = append(errs, unwrapAll(err)...) + for _, next := range x.Unwrap() { + errs = appendUnwrapped(errs, next, visited) } } - return + return errs } func buildErrorChainString(err error, withType bool) string { diff --git a/internal/assertions/recursion_cycles_test.go b/internal/assertions/recursion_cycles_test.go new file mode 100644 index 000000000..f01bdca68 --- /dev/null +++ b/internal/assertions/recursion_cycles_test.go @@ -0,0 +1,155 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package assertions + +import ( + "testing" +) + +// These tests guard the reflection-based assertion helpers against cyclic +// inputs. Such inputs are legal Go values but would, without an explicit guard, +// drive the helpers into unbounded recursion and overflow the goroutine stack. +// +// A goroutine stack overflow is a fatal runtime error that cannot be recovered, +// so the regression signal is simply that these tests run to completion: before +// the fix, each case crashes the test binary. + +type cyclicNode struct { + Name string + Next *cyclicNode +} + +// TestCopyExportedFieldsCycle covers copyExportedFields (reached via +// EqualExportedValues) with self-referential and mutually-referential pointers. +func TestCopyExportedFieldsCycle(t *testing.T) { + t.Parallel() + + t.Run("self-referential struct does not overflow the stack", func(t *testing.T) { + t.Parallel() + + a := &cyclicNode{Name: "node"} + a.Next = a + b := &cyclicNode{Name: "node"} + b.Next = b + + mock := new(mockT) + // Two structurally identical self-cycles compare as equal. + if !EqualExportedValues(mock, a, b) { + t.Errorf("expected structurally identical self-cycles to be equal") + } + }) + + t.Run("mutually-referential structs do not overflow the stack", func(t *testing.T) { + t.Parallel() + + a1 := &cyclicNode{Name: "a"} + a2 := &cyclicNode{Name: "b"} + a1.Next = a2 + a2.Next = a1 + + b1 := &cyclicNode{Name: "a"} + b2 := &cyclicNode{Name: "b"} + b1.Next = b2 + b2.Next = b1 + + mock := new(mockT) + if !EqualExportedValues(mock, a1, b1) { + t.Errorf("expected structurally identical mutual cycles to be equal") + } + }) + + t.Run("copyExportedFields returns on a direct cycle", func(t *testing.T) { + t.Parallel() + + a := &cyclicNode{Name: "node"} + a.Next = a + + // Direct call: the only property under test is that it returns. + _ = copyExportedFields(a) + }) +} + +// selfPtr is a pointer type that can point to a value of its own type, allowing +// the construction of a pointer chain that cycles back on itself. +type selfPtr *selfPtr + +// TestIsEmptyValueCycle covers isEmptyValue (reached via Empty/NotEmpty/Zero) +// with a cyclic pointer chain. +func TestIsEmptyValueCycle(t *testing.T) { + t.Parallel() + + var p selfPtr + p = &p // p points to itself + + mock := new(mockT) + + // A non-nil, cyclic pointer chain is not empty; the assertions must agree + // and, above all, must not recurse forever. + empty := Empty(mock, p) + notEmpty := NotEmpty(new(mockT), p) + + if empty { + t.Errorf("expected a non-nil cyclic pointer to be reported as not empty") + } + if empty == notEmpty { + t.Errorf("Empty and NotEmpty must be complementary, got Empty=%t NotEmpty=%t", empty, notEmpty) + } +} + +// cyclicError is an error whose Unwrap chain can be made to cycle. +type cyclicError struct { + msg string + next error +} + +func (e *cyclicError) Error() string { return e.msg } +func (e *cyclicError) Unwrap() error { return e.next } + +// multiCyclicError exercises the Unwrap() []error branch of the chain walker. +type multiCyclicError struct { + msg string + errs []error +} + +func (e *multiCyclicError) Error() string { return e.msg } +func (e *multiCyclicError) Unwrap() []error { return e.errs } + +// TestUnwrapAllCycle covers unwrapAll (reached when formatting error chains in +// ErrorIs/NotErrorIs failure messages) with cyclic Unwrap chains. +// +// Note: we deliberately test unwrapAll directly rather than through the public +// ErrorIs/NotErrorIs assertions. Those call the standard library's errors.Is +// first, which itself walks the Unwrap chain without cycle detection and so +// hangs on a cyclic chain before our formatter is ever reached. Guarding that +// would mean reimplementing errors.Is and diverging from standard-library +// semantics, which is out of scope here. unwrapAll is hardened on its own so +// that our code never contributes an unbounded recursion of its own. +func TestUnwrapAllCycle(t *testing.T) { + t.Parallel() + + t.Run("single Unwrap cycle does not overflow the stack", func(t *testing.T) { + t.Parallel() + + e := &cyclicError{msg: "boom"} + e.next = e // self-cycle + + errs := unwrapAll(e) + if len(errs) == 0 { + t.Fatalf("expected at least the head error") + } + }) + + t.Run("multi Unwrap cycle does not overflow the stack", func(t *testing.T) { + t.Parallel() + + leaf := &cyclicError{msg: "leaf"} + root := &multiCyclicError{msg: "root"} + root.errs = []error{leaf, root} // root references itself + + errs := unwrapAll(root) + if len(errs) == 0 { + t.Fatalf("expected at least the head error") + } + }) +}