From e3418eeb672fba5029b89e3345dc3473cb8e09a7 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Tue, 6 Jan 2026 16:32:57 -0800 Subject: [PATCH 1/7] Add tests --- tests/test_type_eval.py | 105 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 102 insertions(+), 3 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 258806c..5693dd9 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -1,6 +1,19 @@ +import pytest + +import collections import textwrap import unittest -from typing import Any, Callable, Generic, List, Literal, Never, Tuple, TypeVar +from typing import ( + Any, + Callable, + Generic, + List, + Literal, + Never, + Tuple, + TypeVar, + Union, +) from typemap.type_eval import eval_typing from typemap.typing import ( @@ -101,6 +114,13 @@ class F[bool]: """) +type IntTree = int | list[IntTree] +type GenericTree[T] = T | list[GenericTree[T]] +type XNode[X, Y] = X | list[YNode[X, Y]] +type YNode[X, Y] = Y | list[XNode[X, Y]] +type XYTree[X, Y] = XNode[X, Y] | YNode[X, Y] + + class TA: x: int y: list[float] @@ -167,11 +187,35 @@ def test_type_strings_6(): assert d == Literal["bcd"] -def test_type_asdf(): +def _is_generic_permutation(t1, t2): + return t1.__origin__ == t2.__origin__ and collections.Counter( + t1.__args__ + ) == collections.Counter(t2.__args__) + + +def test_type_from_union_01(): d = eval_typing(FromUnion[int | bool]) arg = FromUnion[int | str] d = eval_typing(arg) - assert d == tuple[int, str] or d == tuple[str, int] + assert _is_generic_permutation(d, tuple[int, str]) + + +def test_type_from_union_02(): + d = eval_typing(FromUnion[IntTree]) + assert _is_generic_permutation(d, tuple[int, list[IntTree]]) + + +def test_type_from_union_03(): + d = eval_typing(FromUnion[GenericTree[str]]) + assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]]) + + +def test_type_from_union_04(): + d = eval_typing(FromUnion[XYTree[int, str]]) + assert _is_generic_permutation( + d, + tuple[XNode[int, str], YNode[int, str]], + ) def test_getarg_never(): @@ -330,6 +374,18 @@ def test_eval_getarg_list(): assert arg == Never +@pytest.mark.xfail(reason="Should this work?") +def test_eval_getarg_union_01(): + arg = eval_typing(GetArg[int | str, Union, 0]) + assert arg is int + + +@pytest.mark.xfail(reason="Should this work?") +def test_eval_getarg_union_02(): + arg = eval_typing(GetArg[GenericTree[int], GenericTree, 0]) + assert arg is int + + def test_eval_getarg_custom_01(): class A[T]: pass @@ -394,6 +450,49 @@ class A(Generic[T]): assert eval_typing(GetArg[t, A, 1]) == Never +@pytest.mark.xfail(reason="Should this work?") +def test_eval_getarg_custom_05(): + A = TypeVar("A") + + class ATree(Generic[A]): + val: A | list[ATree[A]] + + t = ATree[int] + assert eval_typing(GetArg[t, ATree, 0]) is int + assert eval_typing(GetArg[t, ATree, -1]) is int + assert eval_typing(GetArg[t, ATree, 1]) == Never + + t = ATree + assert eval_typing(GetArg[t, ATree, 0]) is Any + assert eval_typing(GetArg[t, ATree, -1]) is Any + assert eval_typing(GetArg[t, ATree, 1]) == Never + + +@pytest.mark.xfail(reason="Should this work?") +def test_eval_getarg_custom_06(): + A = TypeVar("A") + B = TypeVar("B") + + class ANode(Generic[A, B]): + val: A | list[BNode[A, B]] + + class BNode(Generic[A, B]): + val: B | list[ANode[A, B]] + + class ABTree(Generic[A, B]): + root: ANode[A, B] | BNode[A, B] + + t = ABTree[int, str] + assert eval_typing(GetArg[t, ABTree, 0]) is int + assert eval_typing(GetArg[t, ABTree, 1]) is str + assert eval_typing(GetArg[t, ABTree, 2]) == Never + + t = ABTree + assert eval_typing(GetArg[t, ABTree, 0]) is Any + assert eval_typing(GetArg[t, ABTree, 1]) is Any + assert eval_typing(GetArg[t, ABTree, 2]) == Never + + def test_uppercase_never(): d = eval_typing(Uppercase[Never]) assert d is Never From 46ee2d7746ba278244c7e758ceefb09d3f27a316 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Tue, 6 Jan 2026 17:05:18 -0800 Subject: [PATCH 2/7] Add child contexts. --- typemap/type_eval/_eval_typing.py | 32 ++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index fa4d78e..d448953 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -44,7 +44,7 @@ class _EvalProxy: @dataclasses.dataclass class EvalContext: - seen: dict[Any, Any] + seen: dict[Any, Any] = dataclasses.field(default_factory=dict) # The typing.Any is really a types.FunctionType, but mypy gets # confused and wants to treat it as a MethodType. current_alias_stack: set[types.GenericAlias | typing.Any] = ( @@ -66,9 +66,7 @@ def _ensure_context() -> typing.Iterator[EvalContext]: ctx = _current_context.get() ctx_set = False if ctx is None: - ctx = EvalContext( - seen=dict(), - ) + ctx = EvalContext() _current_context.set(ctx) ctx_set = True evaluator_token = nt.special_form_evaluator.set( @@ -92,6 +90,26 @@ def _get_current_context() -> EvalContext: return ctx +@contextlib.contextmanager +def _child_context() -> typing.Iterator[EvalContext]: + ctx = _current_context.get() + if ctx is None: + raise RuntimeError( + "type_eval._create_child_context() called outside of eval_types()" + ) + + try: + child_ctx = EvalContext( + seen=ctx.seen.copy(), + current_alias_stack=ctx.current_alias_stack.copy(), + current_alias=ctx.current_alias, + ) + _current_context.set(child_ctx) + yield child_ctx + finally: + _current_context.set(ctx) + + def eval_typing(obj: typing.Any): with _ensure_context() as ctx: return _eval_types(obj, ctx) @@ -104,7 +122,11 @@ def _eval_types(obj: typing.Any, ctx: EvalContext): # strings match if obj in ctx.seen: return ctx.seen[obj] - ctx.seen[obj] = evaled = _eval_types_impl(obj, ctx) + + with _child_context() as child_ctx: + evaled = _eval_types_impl(obj, child_ctx) + + ctx.seen[obj] = evaled return evaled From 16e5c9be04a498dfe389a9f648a56fa476a1c269 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Tue, 6 Jan 2026 20:29:38 -0800 Subject: [PATCH 3/7] Prevent recursive types from interfering with each other's evaluation. --- typemap/type_eval/_eval_operators.py | 5 ++- typemap/type_eval/_eval_typing.py | 61 ++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index b5c05d3..4675c5a 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -251,7 +251,10 @@ def _eval_Members(tp, *, ctx): @type_eval.register_evaluator(FromUnion) def _eval_FromUnion(tp, *, ctx): - return tuple[*_union_elems(tp, ctx)] + if tp in ctx.known_recursive_types: + return tuple[*_union_elems(ctx.known_recursive_types[tp], ctx)] + else: + return tuple[*_union_elems(tp, ctx)] ################################################################## diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index d448953..a617ab9 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -44,6 +44,9 @@ class _EvalProxy: @dataclasses.dataclass class EvalContext: + # Fully resolved types + resolved: dict[Any, Any] = dataclasses.field(default_factory=dict) + # Types that have been seen, but may not be fully resolved seen: dict[Any, Any] = dataclasses.field(default_factory=dict) # The typing.Any is really a types.FunctionType, but mypy gets # confused and wants to treat it as a MethodType. @@ -52,6 +55,14 @@ class EvalContext: ) current_alias: types.GenericAlias | typing.Any | None = None + unwind_stack: set[typing.TypeAliasType | types.GenericAlias] = ( + dataclasses.field(default_factory=set) + ) + unwinding_until: typing.TypeAliasType | types.GenericAlias | None = None + known_recursive_types: dict[ + typing.TypeAliasType | types.GenericAlias, typing.Any + ] = dataclasses.field(default_factory=dict) + # `eval_types()` calls can be nested, context must be preserved _current_context: contextvars.ContextVar[EvalContext | None] = ( @@ -100,9 +111,22 @@ def _child_context() -> typing.Iterator[EvalContext]: try: child_ctx = EvalContext( + resolved={ + # Drop resolved recursive types. + # This is to allow other recursive types to expand them out + # independently. For example, if we have a recursive types + # A = B|C and B = A|D, we want B to expand even if we already + # know A. + k: v + for k, v in ctx.resolved.items() + if k not in ctx.known_recursive_types + }, seen=ctx.seen.copy(), current_alias_stack=ctx.current_alias_stack.copy(), current_alias=ctx.current_alias, + unwind_stack=ctx.unwind_stack.copy(), + unwinding_until=ctx.unwinding_until, + known_recursive_types=ctx.known_recursive_types.copy(), ) _current_context.set(child_ctx) yield child_ctx @@ -112,21 +136,52 @@ def _child_context() -> typing.Iterator[EvalContext]: def eval_typing(obj: typing.Any): with _ensure_context() as ctx: - return _eval_types(obj, ctx) + result = _eval_types(obj, ctx) + if result in ctx.known_recursive_types: + result = ctx.known_recursive_types[result] + return result def _eval_types(obj: typing.Any, ctx: EvalContext): + # Found a recursive type, we need to unwind it + if obj in ctx.unwind_stack: + ctx.unwinding_until = obj + return obj + # Don't recurse into any pending alias expansion if obj in ctx.current_alias_stack: return obj - # strings match + + # Already resolved or seen, return the result + if obj in ctx.resolved: + return ctx.resolved[obj] if obj in ctx.seen: return ctx.seen[obj] with _child_context() as child_ctx: + child_ctx.unwind_stack.add(obj) evaled = _eval_types_impl(obj, child_ctx) - ctx.seen[obj] = evaled + # If we have identified a recursive type, discard evaluation results. + # This prevents external evaluations from being polluted by partial + # evaluations. + keep_intermediate = True + if child_ctx.unwinding_until: + if child_ctx.unwinding_until == obj: + # Finished unwinding. + ctx.known_recursive_types[obj] = evaled + evaled = obj + keep_intermediate = False + + else: + ctx.unwinding_until = child_ctx.unwinding_until + + if keep_intermediate: + ctx.resolved |= child_ctx.resolved + ctx.seen |= child_ctx.seen + ctx.known_recursive_types |= child_ctx.known_recursive_types + + ctx.resolved[obj] = evaled return evaled From 73bc09733ed83ca631bd5272301cf5cadd712b67 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 7 Jan 2026 10:49:56 -0800 Subject: [PATCH 4/7] Improve consistency --- tests/test_type_eval.py | 81 +++++++++++++++++++++++++++++++ typemap/type_eval/_eval_typing.py | 41 ++++++++++------ 2 files changed, 106 insertions(+), 16 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 5693dd9..8d25c4d 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -119,6 +119,7 @@ class F[bool]: type XNode[X, Y] = X | list[YNode[X, Y]] type YNode[X, Y] = Y | list[XNode[X, Y]] type XYTree[X, Y] = XNode[X, Y] | YNode[X, Y] +type NestedTree = str | list[NestedTree] | list[IntTree] class TA: @@ -204,11 +205,25 @@ def test_type_from_union_02(): d = eval_typing(FromUnion[IntTree]) assert _is_generic_permutation(d, tuple[int, list[IntTree]]) + d = eval_typing(GetArg[d, tuple, 1]) + assert d == list[IntTree] + d = eval_typing(GetArg[d, list, 0]) + assert d == int | list[IntTree] + d = eval_typing(FromUnion[d]) + assert _is_generic_permutation(d, tuple[int, list[IntTree]]) + def test_type_from_union_03(): d = eval_typing(FromUnion[GenericTree[str]]) assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]]) + d = eval_typing(GetArg[d, tuple, 1]) + assert d == list[GenericTree[str]] + d = eval_typing(GetArg[d, list, 0]) + assert d == str | list[GenericTree[str]] + d = eval_typing(FromUnion[d]) + assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]]) + def test_type_from_union_04(): d = eval_typing(FromUnion[XYTree[int, str]]) @@ -217,6 +232,72 @@ def test_type_from_union_04(): tuple[XNode[int, str], YNode[int, str]], ) + x = eval_typing(GetArg[d, tuple, 0]) + assert x == int | list[str | list[XNode[int, str]]] + + x = eval_typing(FromUnion[x]) + assert _is_generic_permutation( + x, tuple[int, list[str | list[XNode[int, str]]]] + ) + x = eval_typing(GetArg[x, tuple, 1]) + assert x == list[str | list[XNode[int, str]]] + x = eval_typing(GetArg[x, list, 0]) + assert x == str | list[XNode[int, str]] + x = eval_typing(FromUnion[x]) + assert _is_generic_permutation(x, tuple[str, list[XNode[int, str]]]) + x = eval_typing(GetArg[x, tuple, 1]) + assert x == list[XNode[int, str]] + x = eval_typing(GetArg[x, list, 0]) + assert x == int | list[str | list[XNode[int, str]]] + + y = eval_typing(GetArg[d, tuple, 1]) + assert y == str | list[int | list[YNode[int, str]]] + + y = eval_typing(FromUnion[y]) + assert _is_generic_permutation( + y, tuple[str, list[int | list[YNode[int, str]]]] + ) + y = eval_typing(GetArg[y, tuple, 1]) + assert y == list[int | list[YNode[int, str]]] + y = eval_typing(GetArg[y, list, 0]) + assert y == int | list[YNode[int, str]] + y = eval_typing(FromUnion[y]) + assert _is_generic_permutation(y, tuple[int, list[YNode[int, str]]]) + y = eval_typing(GetArg[y, tuple, 1]) + assert y == list[YNode[int, str]] + y = eval_typing(GetArg[y, list, 0]) + assert y == str | list[int | list[YNode[int, str]]] + + +def test_type_from_union_05(): + d = eval_typing(FromUnion[NestedTree]) + assert _is_generic_permutation( + d, + tuple[str, list[NestedTree], list[IntTree]], + ) + + n = eval_typing(GetArg[d, tuple, 1]) + assert n == list[NestedTree] + n = eval_typing(GetArg[n, list, 0]) + assert n == str | list[NestedTree] | list[IntTree] + n = eval_typing(FromUnion[n]) + assert _is_generic_permutation( + n, tuple[str, list[NestedTree], list[IntTree]] + ) + + n = eval_typing(FromUnion[GetArg[GetArg[n, tuple, 1], list, 0]]) + assert _is_generic_permutation( + n, tuple[str, list[NestedTree], list[IntTree]] + ) + + i = eval_typing(GetArg[d, tuple, 2]) + assert i == list[IntTree] + i = eval_typing(GetArg[i, list, 0]) + assert i == int | list[IntTree] + + n = eval_typing(FromUnion[GetArg[GetArg[d, tuple, 2], list, 0]]) + assert _is_generic_permutation(n, tuple[int, list[IntTree]]) + def test_getarg_never(): d = eval_typing(GetArg[Never, object, 0]) diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index a617ab9..5442227 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -158,27 +158,36 @@ def _eval_types(obj: typing.Any, ctx: EvalContext): if obj in ctx.seen: return ctx.seen[obj] - with _child_context() as child_ctx: - child_ctx.unwind_stack.add(obj) - evaled = _eval_types_impl(obj, child_ctx) + if isinstance(obj, typing.TypeAliasType) or ( + isinstance(obj, types.GenericAlias) + and isinstance(obj.__origin__, typing.TypeAliasType) + ): + with _child_context() as child_ctx: + child_ctx.unwind_stack.add(obj) + evaled = _eval_types_impl(obj, child_ctx) + else: + evaled = _eval_types_impl(obj, ctx) + child_ctx = None # If we have identified a recursive type, discard evaluation results. # This prevents external evaluations from being polluted by partial # evaluations. keep_intermediate = True - if child_ctx.unwinding_until: - if child_ctx.unwinding_until == obj: - # Finished unwinding. - ctx.known_recursive_types[obj] = evaled - evaled = obj - keep_intermediate = False - - else: - ctx.unwinding_until = child_ctx.unwinding_until - - if keep_intermediate: - ctx.resolved |= child_ctx.resolved - ctx.seen |= child_ctx.seen + if child_ctx: + if child_ctx.unwinding_until: + if child_ctx.unwinding_until == obj: + # Finished unwinding. + ctx.known_recursive_types[obj] = evaled + evaled = obj + keep_intermediate = False + + else: + ctx.unwinding_until = child_ctx.unwinding_until + + if keep_intermediate: + ctx.resolved |= child_ctx.resolved + ctx.seen |= child_ctx.seen + ctx.known_recursive_types |= child_ctx.known_recursive_types ctx.resolved[obj] = evaled From 905507f820eece9bee5ca8ef338859b12fd674b4 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 7 Jan 2026 11:25:29 -0800 Subject: [PATCH 5/7] Documentation --- typemap/type_eval/_eval_typing.py | 43 ++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index 5442227..ce4e462 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -55,6 +55,38 @@ class EvalContext: ) current_alias: types.GenericAlias | typing.Any | None = None + # We want to resolve recursive aliases correctly, but not have haphazardly + # expanded results which vary based on order of evaluation, nesting, etc. + # To produce consistent results, we leave recursive aliases unexpanded, + # unless they are the final result. + # + # For example, given A = int|list[A], + # A expands to int|list[A] + # list[A] remains as list[A] + # + # IMPLEMENTATION + # + # To achieve this behavior, we resolve recursive aliases in a way that + # prevents them from interacting with each other's evaluations. + # + # Once a recursive alias is fully resolved, we discard all intermediate + # evaluations and only keep the final result. We then mark the resolve value + # for the alias as itself, ensure that external evaluations don't expand it. + # We keep the actual expanded value in `known_recursive_types` for future + # reference. + # + # We identify recursive aliases by tracking any aliases we see in + # `unwind_stack`. If an alias is seen again, we know it is a recursive alias + # and note it in `unwinding_until`. When we finally unwind to the previous + # time we saw the alias, we know it is fully resolved. + # + # Intermediate evaluations are discarded because evaluating recursive + # generic classes use the `seen` dictionary as a cache. Sharing this cache + # would cause inconsistent expansion results. + # + # For example, given A = C|list[B] and B = D|list[A], A|B could expand to + # C|D|list[D|list[A]]|list[A] which is technically correct, but not + # consistent in the way we want. unwind_stack: set[typing.TypeAliasType | types.GenericAlias] = ( dataclasses.field(default_factory=set) ) @@ -112,8 +144,8 @@ def _child_context() -> typing.Iterator[EvalContext]: try: child_ctx = EvalContext( resolved={ - # Drop resolved recursive types. - # This is to allow other recursive types to expand them out + # Drop resolved recursive aliases. + # This is to allow other recursive aliases to expand them out # independently. For example, if we have a recursive types # A = B|C and B = A|D, we want B to expand even if we already # know A. @@ -143,7 +175,7 @@ def eval_typing(obj: typing.Any): def _eval_types(obj: typing.Any, ctx: EvalContext): - # Found a recursive type, we need to unwind it + # Found a recursive alias, we need to unwind it if obj in ctx.unwind_stack: ctx.unwinding_until = obj return obj @@ -169,7 +201,7 @@ def _eval_types(obj: typing.Any, ctx: EvalContext): evaled = _eval_types_impl(obj, ctx) child_ctx = None - # If we have identified a recursive type, discard evaluation results. + # If we have identified a recursive alias, discard evaluation results. # This prevents external evaluations from being polluted by partial # evaluations. keep_intermediate = True @@ -188,6 +220,9 @@ def _eval_types(obj: typing.Any, ctx: EvalContext): ctx.resolved |= child_ctx.resolved ctx.seen |= child_ctx.seen + # In case a child context evaluated a nested recursive alias, we can + # keep those results as they are already "consistent". + ctx.resolved |= {x: x for x in child_ctx.known_recursive_types.keys()} ctx.known_recursive_types |= child_ctx.known_recursive_types ctx.resolved[obj] = evaled From 965e3b57e57133db144e71ba94c6f09f9aca870f Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 7 Jan 2026 17:19:20 -0800 Subject: [PATCH 6/7] Combine eval context stacks. --- typemap/type_eval/_eval_call.py | 6 +- typemap/type_eval/_eval_operators.py | 10 +-- typemap/type_eval/_eval_typing.py | 96 +++++++++++++++------------- 3 files changed, 58 insertions(+), 54 deletions(-) diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index e00de08..a1291a6 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -112,10 +112,10 @@ def _eval_call_with_type_vars( af.__code__, af.__globals__, af.__name__, None, af_args ) - old_obj = ctx.current_alias - ctx.current_alias = func + old_obj = ctx.current_generic_alias + ctx.current_generic_alias = func try: rr = ff(annotationlib.Format.VALUE) return _eval_typing.eval_typing(rr["return"]) finally: - ctx.current_alias = old_obj + ctx.current_generic_alias = old_obj diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 4675c5a..a0007d8 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -488,12 +488,12 @@ def _eval_NewProtocol(*etyps: Member, ctx): # If the type evaluation context ctx = type_eval._get_current_context() - if ctx.current_alias: - if isinstance(ctx.current_alias, types.GenericAlias): - name = str(ctx.current_alias) + if ctx.current_generic_alias: + if isinstance(ctx.current_generic_alias, types.GenericAlias): + name = str(ctx.current_generic_alias) else: - name = f"{ctx.current_alias.__name__}[...]" - module_name = ctx.current_alias.__module__ + name = f"{ctx.current_generic_alias.__name__}[...]" + module_name = ctx.current_generic_alias.__module__ dct["__module__"] = module_name diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index ce4e462..9d6dc4b 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -48,16 +48,12 @@ class EvalContext: resolved: dict[Any, Any] = dataclasses.field(default_factory=dict) # Types that have been seen, but may not be fully resolved seen: dict[Any, Any] = dataclasses.field(default_factory=dict) - # The typing.Any is really a types.FunctionType, but mypy gets - # confused and wants to treat it as a MethodType. - current_alias_stack: set[types.GenericAlias | typing.Any] = ( - dataclasses.field(default_factory=set) - ) - current_alias: types.GenericAlias | typing.Any | None = None - # We want to resolve recursive aliases correctly, but not have haphazardly - # expanded results which vary based on order of evaluation, nesting, etc. - # To produce consistent results, we leave recursive aliases unexpanded, + # We want to resolve recursive type aliases correctly, but not have + # haphazardly expanded results which vary based on order of evaluation, + # nesting, etc. + # + # To produce consistent results, we leave recursive type aliases unexpanded, # unless they are the final result. # # For example, given A = int|list[A], @@ -66,7 +62,7 @@ class EvalContext: # # IMPLEMENTATION # - # To achieve this behavior, we resolve recursive aliases in a way that + # To achieve this behavior, we resolve recursive type aliases in a way that # prevents them from interacting with each other's evaluations. # # Once a recursive alias is fully resolved, we discard all intermediate @@ -75,10 +71,10 @@ class EvalContext: # We keep the actual expanded value in `known_recursive_types` for future # reference. # - # We identify recursive aliases by tracking any aliases we see in - # `unwind_stack`. If an alias is seen again, we know it is a recursive alias - # and note it in `unwinding_until`. When we finally unwind to the previous - # time we saw the alias, we know it is fully resolved. + # We identify recursive type aliases by tracking any aliases we see in + # `alias_stack`. If an alias is seen again, we know it is a recursive alias + # and note it in `recursive_type_alias`. When we finally unwind to the + # previous time we saw the alias, we know it is fully resolved. # # Intermediate evaluations are discarded because evaluating recursive # generic classes use the `seen` dictionary as a cache. Sharing this cache @@ -87,14 +83,24 @@ class EvalContext: # For example, given A = C|list[B] and B = D|list[A], A|B could expand to # C|D|list[D|list[A]]|list[A] which is technically correct, but not # consistent in the way we want. - unwind_stack: set[typing.TypeAliasType | types.GenericAlias] = ( + # + # The alias stack is also used to evaluate generic classes. The current + # generic alias is tracked in `current_generic_alias`. + # See `_eval_types_generic`. + alias_stack: set[typing.TypeAliasType | types.GenericAlias] = ( dataclasses.field(default_factory=set) ) - unwinding_until: typing.TypeAliasType | types.GenericAlias | None = None + recursive_type_alias: typing.TypeAliasType | types.GenericAlias | None = ( + None + ) known_recursive_types: dict[ typing.TypeAliasType | types.GenericAlias, typing.Any ] = dataclasses.field(default_factory=dict) + # The typing.Any is really a types.FunctionType, but mypy gets + # confused and wants to treat it as a MethodType. + current_generic_alias: types.GenericAlias | typing.Any | None = None + # `eval_types()` calls can be nested, context must be preserved _current_context: contextvars.ContextVar[EvalContext | None] = ( @@ -154,11 +160,10 @@ def _child_context() -> typing.Iterator[EvalContext]: if k not in ctx.known_recursive_types }, seen=ctx.seen.copy(), - current_alias_stack=ctx.current_alias_stack.copy(), - current_alias=ctx.current_alias, - unwind_stack=ctx.unwind_stack.copy(), - unwinding_until=ctx.unwinding_until, + alias_stack=ctx.alias_stack.copy(), + recursive_type_alias=ctx.recursive_type_alias, known_recursive_types=ctx.known_recursive_types.copy(), + current_generic_alias=ctx.current_generic_alias, ) _current_context.set(child_ctx) yield child_ctx @@ -174,14 +179,18 @@ def eval_typing(obj: typing.Any): return result +def _is_type_alias_type(obj: typing.Any) -> bool: + return isinstance(obj, typing.TypeAliasType) or ( + isinstance(obj, types.GenericAlias) + and isinstance(obj.__origin__, typing.TypeAliasType) + ) + + def _eval_types(obj: typing.Any, ctx: EvalContext): # Found a recursive alias, we need to unwind it - if obj in ctx.unwind_stack: - ctx.unwinding_until = obj - return obj - - # Don't recurse into any pending alias expansion - if obj in ctx.current_alias_stack: + if obj in ctx.alias_stack: + if _is_type_alias_type(obj): + ctx.recursive_type_alias = obj return obj # Already resolved or seen, return the result @@ -190,12 +199,9 @@ def _eval_types(obj: typing.Any, ctx: EvalContext): if obj in ctx.seen: return ctx.seen[obj] - if isinstance(obj, typing.TypeAliasType) or ( - isinstance(obj, types.GenericAlias) - and isinstance(obj.__origin__, typing.TypeAliasType) - ): + if _is_type_alias_type(obj): with _child_context() as child_ctx: - child_ctx.unwind_stack.add(obj) + child_ctx.alias_stack.add(obj) evaled = _eval_types_impl(obj, child_ctx) else: evaled = _eval_types_impl(obj, ctx) @@ -206,15 +212,15 @@ def _eval_types(obj: typing.Any, ctx: EvalContext): # evaluations. keep_intermediate = True if child_ctx: - if child_ctx.unwinding_until: - if child_ctx.unwinding_until == obj: + if child_ctx.recursive_type_alias: + if child_ctx.recursive_type_alias == obj: # Finished unwinding. ctx.known_recursive_types[obj] = evaled evaled = obj keep_intermediate = False else: - ctx.unwinding_until = child_ctx.unwinding_until + ctx.recursive_type_alias = child_ctx.recursive_type_alias if keep_intermediate: ctx.resolved |= child_ctx.resolved @@ -317,22 +323,20 @@ def _eval_types_generic(obj: types.GenericAlias, ctx: EvalContext): args = tuple(types.CellType(_eval_types(arg, ctx)) for arg in obj.__args__) mod = sys.modules[obj.__module__] - old_obj = ctx.current_alias - ctx.current_alias = new_obj # alias is the new_obj, so names look better - ctx.current_alias_stack.add(new_obj) + with _child_context() as child_ctx: + child_ctx.current_generic_alias = new_obj + if not _is_type_alias_type(new_obj): + # Type alias types are already added in _eval_types + child_ctx.alias_stack.add(new_obj) - try: ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args) unpacked = ff(annotationlib.Format.VALUE) - ctx.seen[obj] = unpacked - evaled = _eval_types(unpacked, ctx) - except Exception: - ctx.seen.pop(obj, None) - raise - finally: - ctx.current_alias = old_obj - ctx.current_alias_stack.remove(new_obj) + child_ctx.seen[obj] = unpacked + evaled = _eval_types(unpacked, child_ctx) + + ctx.seen[obj] = unpacked + ctx.recursive_type_alias = child_ctx.recursive_type_alias return evaled From 6f01d4bbc246b3e1b096eb44967206a2c698f4ba Mon Sep 17 00:00:00 2001 From: dnwpark Date: Wed, 7 Jan 2026 17:22:34 -0800 Subject: [PATCH 7/7] Add test for T = list[T] --- tests/test_type_eval.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 8d25c4d..b351759 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -114,6 +114,7 @@ class F[bool]: """) +type UnlabeledTree = list[UnlabeledTree] type IntTree = int | list[IntTree] type GenericTree[T] = T | list[GenericTree[T]] type XNode[X, Y] = X | list[YNode[X, Y]] @@ -202,6 +203,18 @@ def test_type_from_union_01(): def test_type_from_union_02(): + d = eval_typing(FromUnion[UnlabeledTree]) + assert _is_generic_permutation(d, tuple[list[UnlabeledTree]]) + + d = eval_typing(GetArg[d, tuple, 0]) + assert d == list[UnlabeledTree] + d = eval_typing(GetArg[d, list, 0]) + assert d == list[UnlabeledTree] + d = eval_typing(FromUnion[d]) + assert _is_generic_permutation(d, tuple[list[UnlabeledTree]]) + + +def test_type_from_union_03(): d = eval_typing(FromUnion[IntTree]) assert _is_generic_permutation(d, tuple[int, list[IntTree]]) @@ -213,7 +226,7 @@ def test_type_from_union_02(): assert _is_generic_permutation(d, tuple[int, list[IntTree]]) -def test_type_from_union_03(): +def test_type_from_union_04(): d = eval_typing(FromUnion[GenericTree[str]]) assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]]) @@ -225,7 +238,7 @@ def test_type_from_union_03(): assert _is_generic_permutation(d, tuple[str, list[GenericTree[str]]]) -def test_type_from_union_04(): +def test_type_from_union_05(): d = eval_typing(FromUnion[XYTree[int, str]]) assert _is_generic_permutation( d, @@ -269,7 +282,7 @@ def test_type_from_union_04(): assert y == str | list[int | list[YNode[int, str]]] -def test_type_from_union_05(): +def test_type_from_union_06(): d = eval_typing(FromUnion[NestedTree]) assert _is_generic_permutation( d,