From 544e90b1a1929d66b0aa3e5a1474c98a3f4fe69f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20BIDON?= Date: Fri, 19 Jun 2026 23:22:02 +0200 Subject: [PATCH] fix(assertions): guard reflection walkers against cyclic inputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three hand-rolled reflection traversals recursed unconditionally and would overflow the goroutine stack (a fatal, unrecoverable runtime error) when given cyclic-but-legal Go values: - copyExportedFields (reached via EqualExportedValues): a self-referential struct such as `type N struct{ Next *N }` with `n.Next = n` recursed forever. - isEmptyValue (Empty/NotEmpty/Zero/NotZero): a cyclic pointer chain (`type P *P` with `p = &p`) recursed forever. - unwrapAll (error-chain formatting): an error with a cyclic Unwrap chain recursed forever. Each walker now carries a set of values currently being visited on the recursion path and breaks back-edges, while leaving non-cyclic (including shared/diamond) inputs unchanged. unwrapAll only tracks comparable errors, since using an incomparable error as a map key would panic (mirroring how the standard library's errors.Is guards its own comparisons). copyExportedFields is split into per-kind helpers to keep cognitive complexity within the linter threshold. Adds recursion_cycles_test.go with regression tests for each case. The signal is simply that the tests run to completion: before the fix each case crashes the test binary with a stack overflow. Note: the public ErrorIs/NotErrorIs path on a cyclic error chain still hangs inside the standard library's errors.Is (which walks Unwrap without cycle detection) before our formatter runs. Guarding that would mean reimplementing errors.Is and diverging from stdlib semantics, so it is left as-is; unwrapAll is hardened on its own as defense-in-depth. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Frédéric BIDON --- internal/assertions/equal.go | 136 +++++++++------- internal/assertions/equal_unary.go | 18 ++- internal/assertions/error.go | 34 +++- internal/assertions/recursion_cycles_test.go | 155 +++++++++++++++++++ 4 files changed, 279 insertions(+), 64 deletions(-) create mode 100644 internal/assertions/recursion_cycles_test.go diff --git a/internal/assertions/equal.go b/internal/assertions/equal.go index 223df40d7..2765cb543 100644 --- a/internal/assertions/equal.go +++ b/internal/assertions/equal.go @@ -350,77 +350,103 @@ func formatUnequalValues(expected, actual any) (e string, a string) { // copyExportedFields iterates downward through nested data structures and creates a copy // that only contains the exported struct fields. func copyExportedFields(expected any) any { + return copyExportedFieldsRec(expected, make(map[uintptr]struct{})) +} + +// copyExportedFieldsRec carries a set of pointers currently being visited on the +// recursion path, so that cyclic pointer references break the recursion instead +// of overflowing the goroutine stack. +func copyExportedFieldsRec(expected any, visited map[uintptr]struct{}) any { if isNil(expected) { return expected } expectedType := reflect.TypeOf(expected) - expectedKind := expectedType.Kind() expectedValue := reflect.ValueOf(expected) - switch expectedKind { + switch expectedType.Kind() { case reflect.Struct: - result := reflect.New(expectedType).Elem() - for i := range expectedType.NumField() { - field := expectedType.Field(i) - isExported := field.IsExported() - if isExported { - fieldValue := expectedValue.Field(i) - if isNil(fieldValue) || isNil(fieldValue.Interface()) { - continue - } - newValue := copyExportedFields(fieldValue.Interface()) - result.Field(i).Set(reflect.ValueOf(newValue)) - } - } - return result.Interface() - + return copyExportedStruct(expectedType, expectedValue, visited) case reflect.Pointer: - result := reflect.New(expectedType.Elem()) - unexportedRemoved := copyExportedFields(expectedValue.Elem().Interface()) - if unexportedRemoved != nil { - result.Elem().Set(reflect.ValueOf(unexportedRemoved)) - } - return result.Interface() - + return copyExportedPointer(expected, expectedType, expectedValue, visited) case reflect.Array, reflect.Slice: - var result reflect.Value - if expectedKind == reflect.Array { - result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem() - } else { - result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) + return copyExportedSequence(expectedType, expectedValue, visited) + case reflect.Map: + return copyExportedMap(expectedType, expectedValue, visited) + default: + return expected + } +} + +func copyExportedStruct(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any { + result := reflect.New(expectedType).Elem() + for i := range expectedType.NumField() { + if !expectedType.Field(i).IsExported() { + continue } - for i := range expectedValue.Len() { - index := expectedValue.Index(i) - if !index.CanInterface() { - // this should not be possible with current reflect, since values are retrieved from an array or slice, not a struct - panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index)) - } - unexportedRemoved := copyExportedFields(index.Interface()) - if unexportedRemoved != nil { - result.Index(i).Set(reflect.ValueOf(unexportedRemoved)) - } + fieldValue := expectedValue.Field(i) + if isNil(fieldValue) || isNil(fieldValue.Interface()) { + continue } - return result.Interface() + newValue := copyExportedFieldsRec(fieldValue.Interface(), visited) + result.Field(i).Set(reflect.ValueOf(newValue)) + } + return result.Interface() +} - case reflect.Map: - result := reflect.MakeMap(expectedType) - for _, k := range expectedValue.MapKeys() { - index := expectedValue.MapIndex(k) - if !index.CanInterface() { - // this should not be possible with current reflect, since values are retrieved from a map, not a struct - panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index)) - } - unexportedRemoved := copyExportedFields(index.Interface()) - if unexportedRemoved != nil { - result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved)) - } +func copyExportedPointer(expected any, expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any { + // Guard against cyclic pointer references: if this pointer is already + // on the current recursion path, return it as-is to break the cycle. + ptr := expectedValue.Pointer() + if _, ok := visited[ptr]; ok { + return expected + } + visited[ptr] = struct{}{} + defer delete(visited, ptr) + + result := reflect.New(expectedType.Elem()) + unexportedRemoved := copyExportedFieldsRec(expectedValue.Elem().Interface(), visited) + if unexportedRemoved != nil { + result.Elem().Set(reflect.ValueOf(unexportedRemoved)) + } + return result.Interface() +} + +func copyExportedSequence(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any { + var result reflect.Value + if expectedType.Kind() == reflect.Array { + result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem() + } else { + result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) + } + for i := range expectedValue.Len() { + index := expectedValue.Index(i) + if !index.CanInterface() { + // this should not be possible with current reflect, since values are retrieved from an array or slice, not a struct + panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index)) + } + unexportedRemoved := copyExportedFieldsRec(index.Interface(), visited) + if unexportedRemoved != nil { + result.Index(i).Set(reflect.ValueOf(unexportedRemoved)) } - return result.Interface() + } + return result.Interface() +} - default: - return expected +func copyExportedMap(expectedType reflect.Type, expectedValue reflect.Value, visited map[uintptr]struct{}) any { + result := reflect.MakeMap(expectedType) + for _, k := range expectedValue.MapKeys() { + index := expectedValue.MapIndex(k) + if !index.CanInterface() { + // this should not be possible with current reflect, since values are retrieved from a map, not a struct + panic(fmt.Errorf("internal error: can't resolve Interface() for value %v", index)) + } + unexportedRemoved := copyExportedFieldsRec(index.Interface(), visited) + if unexportedRemoved != nil { + result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved)) + } } + return result.Interface() } func isFunction(arg any) bool { diff --git a/internal/assertions/equal_unary.go b/internal/assertions/equal_unary.go index 5b7b638eb..e18326a13 100644 --- a/internal/assertions/equal_unary.go +++ b/internal/assertions/equal_unary.go @@ -141,6 +141,13 @@ func isEmpty(object any) bool { // isEmptyValue gets whether the specified reflect.Value is considered empty or not. func isEmptyValue(objValue reflect.Value) bool { + return isEmptyValueRec(objValue, nil) +} + +// isEmptyValueRec carries the set of pointers already followed on the current +// recursion path, so that a cyclic pointer chain (e.g. type P *P with p = &p) +// breaks the recursion instead of overflowing the goroutine stack. +func isEmptyValueRec(objValue reflect.Value, visited map[uintptr]struct{}) bool { if objValue.IsZero() { return true } @@ -152,7 +159,16 @@ func isEmptyValue(objValue reflect.Value) bool { return objValue.Len() == 0 // non-nil pointers are empty if the value they point to is empty case reflect.Pointer: - return isEmptyValue(objValue.Elem()) + ptr := objValue.Pointer() + if _, ok := visited[ptr]; ok { + // a cyclic, non-nil pointer chain is not empty + return false + } + if visited == nil { + visited = make(map[uintptr]struct{}) + } + visited[ptr] = struct{}{} + return isEmptyValueRec(objValue.Elem(), visited) default: return false } diff --git a/internal/assertions/error.go b/internal/assertions/error.go index fedd48a69..e912afb02 100644 --- a/internal/assertions/error.go +++ b/internal/assertions/error.go @@ -271,21 +271,39 @@ func NotErrorAs(t T, err error, target any, msgAndArgs ...any) bool { ), msgAndArgs...) } -func unwrapAll(err error) (errs []error) { +func unwrapAll(err error) []error { + return appendUnwrapped(nil, err, make(map[error]struct{})) +} + +// appendUnwrapped flattens an error chain, guarding against cyclic chains that +// would otherwise recurse until the goroutine stack overflows. +// +// Only comparable errors are tracked: using an incomparable error as a map key +// would panic, so those simply skip cycle detection (mirroring how the standard +// library's errors.Is guards its own comparisons). +func appendUnwrapped(errs []error, err error, visited map[error]struct{}) []error { + if err != nil && reflect.TypeOf(err).Comparable() { + if _, ok := visited[err]; ok { + return errs // cyclic error chain: stop here + } + visited[err] = struct{}{} + defer delete(visited, err) + } + errs = append(errs, err) switch x := err.(type) { //nolint:errorlint // false positive: this type switch is checking for interfaces case interface{ Unwrap() error }: - err = x.Unwrap() - if err == nil { - return + next := x.Unwrap() + if next == nil { + return errs } - errs = append(errs, unwrapAll(err)...) + errs = appendUnwrapped(errs, next, visited) case interface{ Unwrap() []error }: - for _, err := range x.Unwrap() { - errs = append(errs, unwrapAll(err)...) + for _, next := range x.Unwrap() { + errs = appendUnwrapped(errs, next, visited) } } - return + return errs } func buildErrorChainString(err error, withType bool) string { diff --git a/internal/assertions/recursion_cycles_test.go b/internal/assertions/recursion_cycles_test.go new file mode 100644 index 000000000..f01bdca68 --- /dev/null +++ b/internal/assertions/recursion_cycles_test.go @@ -0,0 +1,155 @@ +// SPDX-FileCopyrightText: Copyright 2025 go-swagger maintainers +// SPDX-License-Identifier: Apache-2.0 + +package assertions + +import ( + "testing" +) + +// These tests guard the reflection-based assertion helpers against cyclic +// inputs. Such inputs are legal Go values but would, without an explicit guard, +// drive the helpers into unbounded recursion and overflow the goroutine stack. +// +// A goroutine stack overflow is a fatal runtime error that cannot be recovered, +// so the regression signal is simply that these tests run to completion: before +// the fix, each case crashes the test binary. + +type cyclicNode struct { + Name string + Next *cyclicNode +} + +// TestCopyExportedFieldsCycle covers copyExportedFields (reached via +// EqualExportedValues) with self-referential and mutually-referential pointers. +func TestCopyExportedFieldsCycle(t *testing.T) { + t.Parallel() + + t.Run("self-referential struct does not overflow the stack", func(t *testing.T) { + t.Parallel() + + a := &cyclicNode{Name: "node"} + a.Next = a + b := &cyclicNode{Name: "node"} + b.Next = b + + mock := new(mockT) + // Two structurally identical self-cycles compare as equal. + if !EqualExportedValues(mock, a, b) { + t.Errorf("expected structurally identical self-cycles to be equal") + } + }) + + t.Run("mutually-referential structs do not overflow the stack", func(t *testing.T) { + t.Parallel() + + a1 := &cyclicNode{Name: "a"} + a2 := &cyclicNode{Name: "b"} + a1.Next = a2 + a2.Next = a1 + + b1 := &cyclicNode{Name: "a"} + b2 := &cyclicNode{Name: "b"} + b1.Next = b2 + b2.Next = b1 + + mock := new(mockT) + if !EqualExportedValues(mock, a1, b1) { + t.Errorf("expected structurally identical mutual cycles to be equal") + } + }) + + t.Run("copyExportedFields returns on a direct cycle", func(t *testing.T) { + t.Parallel() + + a := &cyclicNode{Name: "node"} + a.Next = a + + // Direct call: the only property under test is that it returns. + _ = copyExportedFields(a) + }) +} + +// selfPtr is a pointer type that can point to a value of its own type, allowing +// the construction of a pointer chain that cycles back on itself. +type selfPtr *selfPtr + +// TestIsEmptyValueCycle covers isEmptyValue (reached via Empty/NotEmpty/Zero) +// with a cyclic pointer chain. +func TestIsEmptyValueCycle(t *testing.T) { + t.Parallel() + + var p selfPtr + p = &p // p points to itself + + mock := new(mockT) + + // A non-nil, cyclic pointer chain is not empty; the assertions must agree + // and, above all, must not recurse forever. + empty := Empty(mock, p) + notEmpty := NotEmpty(new(mockT), p) + + if empty { + t.Errorf("expected a non-nil cyclic pointer to be reported as not empty") + } + if empty == notEmpty { + t.Errorf("Empty and NotEmpty must be complementary, got Empty=%t NotEmpty=%t", empty, notEmpty) + } +} + +// cyclicError is an error whose Unwrap chain can be made to cycle. +type cyclicError struct { + msg string + next error +} + +func (e *cyclicError) Error() string { return e.msg } +func (e *cyclicError) Unwrap() error { return e.next } + +// multiCyclicError exercises the Unwrap() []error branch of the chain walker. +type multiCyclicError struct { + msg string + errs []error +} + +func (e *multiCyclicError) Error() string { return e.msg } +func (e *multiCyclicError) Unwrap() []error { return e.errs } + +// TestUnwrapAllCycle covers unwrapAll (reached when formatting error chains in +// ErrorIs/NotErrorIs failure messages) with cyclic Unwrap chains. +// +// Note: we deliberately test unwrapAll directly rather than through the public +// ErrorIs/NotErrorIs assertions. Those call the standard library's errors.Is +// first, which itself walks the Unwrap chain without cycle detection and so +// hangs on a cyclic chain before our formatter is ever reached. Guarding that +// would mean reimplementing errors.Is and diverging from standard-library +// semantics, which is out of scope here. unwrapAll is hardened on its own so +// that our code never contributes an unbounded recursion of its own. +func TestUnwrapAllCycle(t *testing.T) { + t.Parallel() + + t.Run("single Unwrap cycle does not overflow the stack", func(t *testing.T) { + t.Parallel() + + e := &cyclicError{msg: "boom"} + e.next = e // self-cycle + + errs := unwrapAll(e) + if len(errs) == 0 { + t.Fatalf("expected at least the head error") + } + }) + + t.Run("multi Unwrap cycle does not overflow the stack", func(t *testing.T) { + t.Parallel() + + leaf := &cyclicError{msg: "leaf"} + root := &multiCyclicError{msg: "root"} + root.errs = []error{leaf, root} // root references itself + + errs := unwrapAll(root) + if len(errs) == 0 { + t.Fatalf("expected at least the head error") + } + }) +}