diff --git a/src/compyre/api.py b/src/compyre/api.py index 0184745..369ba19 100644 --- a/src/compyre/api.py +++ b/src/compyre/api.py @@ -118,6 +118,7 @@ def compare( ) pairs: Deque[Pair] = deque([Pair(index=(), actual=actual, expected=expected)]) + seen: set[tuple[int, int]] = {(id(actual), id(expected))} errors: list[CompareError] = [] while pairs: pair = pairs.popleft() @@ -133,7 +134,10 @@ def compare( errors.append(CompareError(pair=pair, exception=unpack_result)) else: for p in reversed(unpack_result): - pairs.appendleft(p) + key = (id(p.actual), id(p.expected)) + if key not in seen: + seen.add(key) + pairs.appendleft(p) continue equal_result: EqualFnResult = None diff --git a/tests/test_api.py b/tests/test_api.py index 02e1f89..9e574e3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,10 +1,11 @@ +import dataclasses import inspect from copy import deepcopy from typing import Annotated, Any import pytest -from compyre import alias, api, builtin +from compyre import alias, api, builtin, utils class TestParametrizeFns: @@ -228,6 +229,100 @@ def test_alias(self): class TestCompare: + def test_circular_reference_two_nodes(self): + @dataclasses.dataclass + class Node: + value: int + peer: Any = None + + a = Node(value=1) + b = Node(value=1) + a.peer = b + b.peer = a + + c = Node(value=1) + d = Node(value=1) + c.peer = d + d.peer = c + + def unpack_node(p, /): + if not utils.both_isinstance(p, Node): + return None + + return [ + api.Pair( + index=(*p.index, "value"), + actual=p.actual.value, + expected=p.expected.value, + ), + api.Pair( + index=(*p.index, "peer"), + actual=p.actual.peer, + expected=p.expected.peer, + ), + ] + + def equal_fn(p, /): + return p.actual == p.expected + + errors = api.compare( + [a, b], + [c, d], + unpack_fns=[ + builtin.unpack_fns.collections_sequence, + unpack_node, + ], + equal_fns=[equal_fn], + ) + assert not errors + + def test_circular_reference_mismatch(self): + @dataclasses.dataclass + class Node: + value: int + peer: Any = None + + a = Node(value=1) + b = Node(value=2) + a.peer = b + b.peer = a + + c = Node(value=1) + d = Node(value=3) + c.peer = d + d.peer = c + + def unpack_node(p, /): + if not (isinstance(p.actual, Node) and isinstance(p.expected, Node)): + return None + return [ + api.Pair( + index=(*p.index, "value"), + actual=p.actual.value, + expected=p.expected.value, + ), + api.Pair( + index=(*p.index, "peer"), + actual=p.actual.peer, + expected=p.expected.peer, + ), + ] + + def equal_fn(p, /): + return p.actual == p.expected + + errors = api.compare( + [a, b], + [c, d], + unpack_fns=[ + builtin.unpack_fns.collections_sequence, + unpack_node, + ], + equal_fns=[equal_fn], + ) + assert len(errors) == 1 + assert errors[0].pair.index == (1, "value") + def test_unpack_fn_exception(self): exc = Exception()