diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 258806c..b351759 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -1,6 +1,19 @@ +import pytest + +import collections import textwrap import unittest -from typing import Any, Callable, Generic, List, Literal, Never, Tuple, TypeVar +from typing import ( + Any, + Callable, + Generic, + List, + Literal, + Never, + Tuple, + TypeVar, + Union, +) from typemap.type_eval import eval_typing from typemap.typing import ( @@ -101,6 +114,15 @@ class F[bool]: """) +type UnlabeledTree = list[UnlabeledTree] +type IntTree = int | list[IntTree] +type GenericTree[T] = T | list[GenericTree[T]] +type XNode[X, Y] = X | list[YNode[X, Y]] +type YNode[X, Y] = Y | list[XNode[X, Y]] +type XYTree[X, Y] = XNode[X, Y] | YNode[X, Y] +type NestedTree = str | list[NestedTree] | list[IntTree] + + class TA: x: int y: list[float] @@ -167,11 +189,127 @@ def test_type_strings_6(): assert d == Literal["bcd"] -def test_type_asdf(): +def _is_generic_permutation(t1, t2): + return t1.__origin__ == t2.__origin__ and collections.Counter( + t1.__args__ + ) == collections.Counter(t2.__args__) + + +def test_type_from_union_01(): d = eval_typing(FromUnion[int | bool]) arg = FromUnion[int | str] d = eval_typing(arg) - assert d == tuple[int, str] or d == tuple[str, int] + assert _is_generic_permutation(d, tuple[int, str]) + + +def test_type_from_union_02(): + d = eval_typing(FromUnion[UnlabeledTree]) + assert _is_generic_permutation(d, tuple[list[UnlabeledTree]]) + + d = eval_typing(GetArg[d, tuple, 0]) + assert d == list[UnlabeledTree] + d = eval_typing(GetArg[d, list, 0]) + assert d == list[UnlabeledTree] + d = eval_typing(FromUnion[d]) + assert _is_generic_permutation(d, tuple[list[UnlabeledTree]]) + + +def test_type_from_union_03(): + d = eval_typing(FromUnion[IntTree]) + assert _is_generic_permutation(d, tuple[int, list[IntTree]]) + + d = eval_typing(GetArg[d, tuple, 1]) + assert d == list[IntTree] + d = eval_typing(GetArg[d, list, 0]) + assert d == int | list[IntTree] + d = eval_typing(FromUnion[d]) + assert _is_generic_permutation(d, tuple[int, list[IntTree]]) + + +def test_type_from_union_04(): + d = eval_typing(FromUnion[GenericTree[str]]) + assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]]) + + d = eval_typing(GetArg[d, tuple, 1]) + assert d == list[GenericTree[str]] + d = eval_typing(GetArg[d, list, 0]) + assert d == str | list[GenericTree[str]] + d = eval_typing(FromUnion[d]) + assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]]) + + +def test_type_from_union_05(): + d = eval_typing(FromUnion[XYTree[int, str]]) + assert _is_generic_permutation( + d, + tuple[XNode[int, str], YNode[int, str]], + ) + + x = eval_typing(GetArg[d, tuple, 0]) + assert x == int | list[str | list[XNode[int, str]]] + + x = eval_typing(FromUnion[x]) + assert _is_generic_permutation( + x, tuple[int, list[str | list[XNode[int, str]]]] + ) + x = eval_typing(GetArg[x, tuple, 1]) + assert x == list[str | list[XNode[int, str]]] + x = eval_typing(GetArg[x, list, 0]) + assert x == str | list[XNode[int, str]] + x = eval_typing(FromUnion[x]) + assert _is_generic_permutation(x, tuple[str, list[XNode[int, str]]]) + x = eval_typing(GetArg[x, tuple, 1]) + assert x == list[XNode[int, str]] + x = eval_typing(GetArg[x, list, 0]) + assert x == int | list[str | list[XNode[int, str]]] + + y = eval_typing(GetArg[d, tuple, 1]) + assert y == str | list[int | list[YNode[int, str]]] + + y = eval_typing(FromUnion[y]) + assert _is_generic_permutation( + y, tuple[str, list[int | list[YNode[int, str]]]] + ) + y = eval_typing(GetArg[y, tuple, 1]) + assert y == list[int | list[YNode[int, str]]] + y = eval_typing(GetArg[y, list, 0]) + assert y == int | list[YNode[int, str]] + y = eval_typing(FromUnion[y]) + assert _is_generic_permutation(y, tuple[int, list[YNode[int, str]]]) + y = eval_typing(GetArg[y, tuple, 1]) + assert y == list[YNode[int, str]] + y = eval_typing(GetArg[y, list, 0]) + assert y == str | list[int | list[YNode[int, str]]] + + +def test_type_from_union_06(): + d = eval_typing(FromUnion[NestedTree]) + assert _is_generic_permutation( + d, + tuple[str, list[NestedTree], list[IntTree]], + ) + + n = eval_typing(GetArg[d, tuple, 1]) + assert n == list[NestedTree] + n = eval_typing(GetArg[n, list, 0]) + assert n == str | list[NestedTree] | list[IntTree] + n = eval_typing(FromUnion[n]) + assert _is_generic_permutation( + n, tuple[str, list[NestedTree], list[IntTree]] + ) + + n = eval_typing(FromUnion[GetArg[GetArg[n, tuple, 1], list, 0]]) + assert _is_generic_permutation( + n, tuple[str, list[NestedTree], list[IntTree]] + ) + + i = eval_typing(GetArg[d, tuple, 2]) + assert i == list[IntTree] + i = eval_typing(GetArg[i, list, 0]) + assert i == int | list[IntTree] + + n = eval_typing(FromUnion[GetArg[GetArg[d, tuple, 2], list, 0]]) + assert _is_generic_permutation(n, tuple[int, list[IntTree]]) def test_getarg_never(): @@ -330,6 +468,18 @@ def test_eval_getarg_list(): assert arg == Never +@pytest.mark.xfail(reason="Should this work?") +def test_eval_getarg_union_01(): + arg = eval_typing(GetArg[int | str, Union, 0]) + assert arg is int + + +@pytest.mark.xfail(reason="Should this work?") +def test_eval_getarg_union_02(): + arg = eval_typing(GetArg[GenericTree[int], GenericTree, 0]) + assert arg is int + + def test_eval_getarg_custom_01(): class A[T]: pass @@ -394,6 +544,49 @@ class A(Generic[T]): assert eval_typing(GetArg[t, A, 1]) == Never +@pytest.mark.xfail(reason="Should this work?") +def test_eval_getarg_custom_05(): + A = TypeVar("A") + + class ATree(Generic[A]): + val: A | list[ATree[A]] + + t = ATree[int] + assert eval_typing(GetArg[t, ATree, 0]) is int + assert eval_typing(GetArg[t, ATree, -1]) is int + assert eval_typing(GetArg[t, ATree, 1]) == Never + + t = ATree + assert eval_typing(GetArg[t, ATree, 0]) is Any + assert eval_typing(GetArg[t, ATree, -1]) is Any + assert eval_typing(GetArg[t, ATree, 1]) == Never + + +@pytest.mark.xfail(reason="Should this work?") +def test_eval_getarg_custom_06(): + A = TypeVar("A") + B = TypeVar("B") + + class ANode(Generic[A, B]): + val: A | list[BNode[A, B]] + + class BNode(Generic[A, B]): + val: B | list[ANode[A, B]] + + class ABTree(Generic[A, B]): + root: ANode[A, B] | BNode[A, B] + + t = ABTree[int, str] + assert eval_typing(GetArg[t, ABTree, 0]) is int + assert eval_typing(GetArg[t, ABTree, 1]) is str + assert eval_typing(GetArg[t, ABTree, 2]) == Never + + t = ABTree + assert eval_typing(GetArg[t, ABTree, 0]) is Any + assert eval_typing(GetArg[t, ABTree, 1]) is Any + assert eval_typing(GetArg[t, ABTree, 2]) == Never + + def test_uppercase_never(): d = eval_typing(Uppercase[Never]) assert d is Never diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index e00de08..a1291a6 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -112,10 +112,10 @@ def _eval_call_with_type_vars( af.__code__, af.__globals__, af.__name__, None, af_args ) - old_obj = ctx.current_alias - ctx.current_alias = func + old_obj = ctx.current_generic_alias + ctx.current_generic_alias = func try: rr = ff(annotationlib.Format.VALUE) return _eval_typing.eval_typing(rr["return"]) finally: - ctx.current_alias = old_obj + ctx.current_generic_alias = old_obj diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index b5c05d3..a0007d8 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -251,7 +251,10 @@ def _eval_Members(tp, *, ctx): @type_eval.register_evaluator(FromUnion) def _eval_FromUnion(tp, *, ctx): - return tuple[*_union_elems(tp, ctx)] + if tp in ctx.known_recursive_types: + return tuple[*_union_elems(ctx.known_recursive_types[tp], ctx)] + else: + return tuple[*_union_elems(tp, ctx)] ################################################################## @@ -485,12 +488,12 @@ def _eval_NewProtocol(*etyps: Member, ctx): # If the type evaluation context ctx = type_eval._get_current_context() - if ctx.current_alias: - if isinstance(ctx.current_alias, types.GenericAlias): - name = str(ctx.current_alias) + if ctx.current_generic_alias: + if isinstance(ctx.current_generic_alias, types.GenericAlias): + name = str(ctx.current_generic_alias) else: - name = f"{ctx.current_alias.__name__}[...]" - module_name = ctx.current_alias.__module__ + name = f"{ctx.current_generic_alias.__name__}[...]" + module_name = ctx.current_generic_alias.__module__ dct["__module__"] = module_name diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index fa4d78e..9d6dc4b 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -44,13 +44,62 @@ class _EvalProxy: @dataclasses.dataclass class EvalContext: - seen: dict[Any, Any] - # The typing.Any is really a types.FunctionType, but mypy gets - # confused and wants to treat it as a MethodType. - current_alias_stack: set[types.GenericAlias | typing.Any] = ( + # Fully resolved types + resolved: dict[Any, Any] = dataclasses.field(default_factory=dict) + # Types that have been seen, but may not be fully resolved + seen: dict[Any, Any] = dataclasses.field(default_factory=dict) + + # We want to resolve recursive type aliases correctly, but not have + # haphazardly expanded results which vary based on order of evaluation, + # nesting, etc. + # + # To produce consistent results, we leave recursive type aliases unexpanded, + # unless they are the final result. + # + # For example, given A = int|list[A], + # A expands to int|list[A] + # list[A] remains as list[A] + # + # IMPLEMENTATION + # + # To achieve this behavior, we resolve recursive type aliases in a way that + # prevents them from interacting with each other's evaluations. + # + # Once a recursive alias is fully resolved, we discard all intermediate + # evaluations and only keep the final result. We then mark the resolve value + # for the alias as itself, ensure that external evaluations don't expand it. + # We keep the actual expanded value in `known_recursive_types` for future + # reference. + # + # We identify recursive type aliases by tracking any aliases we see in + # `alias_stack`. If an alias is seen again, we know it is a recursive alias + # and note it in `recursive_type_alias`. When we finally unwind to the + # previous time we saw the alias, we know it is fully resolved. + # + # Intermediate evaluations are discarded because evaluating recursive + # generic classes use the `seen` dictionary as a cache. Sharing this cache + # would cause inconsistent expansion results. + # + # For example, given A = C|list[B] and B = D|list[A], A|B could expand to + # C|D|list[D|list[A]]|list[A] which is technically correct, but not + # consistent in the way we want. + # + # The alias stack is also used to evaluate generic classes. The current + # generic alias is tracked in `current_generic_alias`. + # See `_eval_types_generic`. + alias_stack: set[typing.TypeAliasType | types.GenericAlias] = ( dataclasses.field(default_factory=set) ) - current_alias: types.GenericAlias | typing.Any | None = None + recursive_type_alias: typing.TypeAliasType | types.GenericAlias | None = ( + None + ) + known_recursive_types: dict[ + typing.TypeAliasType | types.GenericAlias, typing.Any + ] = dataclasses.field(default_factory=dict) + + # The typing.Any is really a types.FunctionType, but mypy gets + # confused and wants to treat it as a MethodType. + current_generic_alias: types.GenericAlias | typing.Any | None = None # `eval_types()` calls can be nested, context must be preserved @@ -66,9 +115,7 @@ def _ensure_context() -> typing.Iterator[EvalContext]: ctx = _current_context.get() ctx_set = False if ctx is None: - ctx = EvalContext( - seen=dict(), - ) + ctx = EvalContext() _current_context.set(ctx) ctx_set = True evaluator_token = nt.special_form_evaluator.set( @@ -92,19 +139,99 @@ def _get_current_context() -> EvalContext: return ctx +@contextlib.contextmanager +def _child_context() -> typing.Iterator[EvalContext]: + ctx = _current_context.get() + if ctx is None: + raise RuntimeError( + "type_eval._create_child_context() called outside of eval_types()" + ) + + try: + child_ctx = EvalContext( + resolved={ + # Drop resolved recursive aliases. + # This is to allow other recursive aliases to expand them out + # independently. For example, if we have a recursive types + # A = B|C and B = A|D, we want B to expand even if we already + # know A. + k: v + for k, v in ctx.resolved.items() + if k not in ctx.known_recursive_types + }, + seen=ctx.seen.copy(), + alias_stack=ctx.alias_stack.copy(), + recursive_type_alias=ctx.recursive_type_alias, + known_recursive_types=ctx.known_recursive_types.copy(), + current_generic_alias=ctx.current_generic_alias, + ) + _current_context.set(child_ctx) + yield child_ctx + finally: + _current_context.set(ctx) + + def eval_typing(obj: typing.Any): with _ensure_context() as ctx: - return _eval_types(obj, ctx) + result = _eval_types(obj, ctx) + if result in ctx.known_recursive_types: + result = ctx.known_recursive_types[result] + return result + + +def _is_type_alias_type(obj: typing.Any) -> bool: + return isinstance(obj, typing.TypeAliasType) or ( + isinstance(obj, types.GenericAlias) + and isinstance(obj.__origin__, typing.TypeAliasType) + ) def _eval_types(obj: typing.Any, ctx: EvalContext): - # Don't recurse into any pending alias expansion - if obj in ctx.current_alias_stack: + # Found a recursive alias, we need to unwind it + if obj in ctx.alias_stack: + if _is_type_alias_type(obj): + ctx.recursive_type_alias = obj return obj - # strings match + + # Already resolved or seen, return the result + if obj in ctx.resolved: + return ctx.resolved[obj] if obj in ctx.seen: return ctx.seen[obj] - ctx.seen[obj] = evaled = _eval_types_impl(obj, ctx) + + if _is_type_alias_type(obj): + with _child_context() as child_ctx: + child_ctx.alias_stack.add(obj) + evaled = _eval_types_impl(obj, child_ctx) + else: + evaled = _eval_types_impl(obj, ctx) + child_ctx = None + + # If we have identified a recursive alias, discard evaluation results. + # This prevents external evaluations from being polluted by partial + # evaluations. + keep_intermediate = True + if child_ctx: + if child_ctx.recursive_type_alias: + if child_ctx.recursive_type_alias == obj: + # Finished unwinding. + ctx.known_recursive_types[obj] = evaled + evaled = obj + keep_intermediate = False + + else: + ctx.recursive_type_alias = child_ctx.recursive_type_alias + + if keep_intermediate: + ctx.resolved |= child_ctx.resolved + ctx.seen |= child_ctx.seen + + # In case a child context evaluated a nested recursive alias, we can + # keep those results as they are already "consistent". + ctx.resolved |= {x: x for x in child_ctx.known_recursive_types.keys()} + ctx.known_recursive_types |= child_ctx.known_recursive_types + + ctx.resolved[obj] = evaled return evaled @@ -196,22 +323,20 @@ def _eval_types_generic(obj: types.GenericAlias, ctx: EvalContext): args = tuple(types.CellType(_eval_types(arg, ctx)) for arg in obj.__args__) mod = sys.modules[obj.__module__] - old_obj = ctx.current_alias - ctx.current_alias = new_obj # alias is the new_obj, so names look better - ctx.current_alias_stack.add(new_obj) + with _child_context() as child_ctx: + child_ctx.current_generic_alias = new_obj + if not _is_type_alias_type(new_obj): + # Type alias types are already added in _eval_types + child_ctx.alias_stack.add(new_obj) - try: ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args) unpacked = ff(annotationlib.Format.VALUE) - ctx.seen[obj] = unpacked - evaled = _eval_types(unpacked, ctx) - except Exception: - ctx.seen.pop(obj, None) - raise - finally: - ctx.current_alias = old_obj - ctx.current_alias_stack.remove(new_obj) + child_ctx.seen[obj] = unpacked + evaled = _eval_types(unpacked, child_ctx) + + ctx.seen[obj] = unpacked + ctx.recursive_type_alias = child_ctx.recursive_type_alias return evaled