From 70203247aa42b6587c86946f6a2dc8fb930e7367 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Mon, 10 Nov 2025 12:04:12 -0800 Subject: [PATCH 01/12] WIP: start moving to pure evaluator semantics --- tests/test_type_eval.py | 9 +++++++ typemap/type_eval/__init__.py | 10 +++++++- typemap/type_eval/_eval_typing.py | 32 ++++++++++++++++++++++- typemap/typing.py | 42 +++++++++++++++++++++---------- 4 files changed, 78 insertions(+), 15 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 91be373..a0a9b1d 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -159,3 +159,12 @@ def test_type_strings_5(): def test_type_strings_6(): d = eval_typing(StrSlice[Literal["abcd"], Literal[1], Literal[None]]) assert d == Literal["bcd"] + + +def test_type_asdf(): + from typemap.typing import FromUnion + + d = eval_typing(FromUnion[int | bool]) + arg = FromUnion[int | str] + d = eval_typing(arg) + assert d == tuple[int, str] or d == tuple[str, int] diff --git a/typemap/type_eval/__init__.py b/typemap/type_eval/__init__.py index 8973014..5642705 100644 --- a/typemap/type_eval/__init__.py +++ b/typemap/type_eval/__init__.py @@ -1,11 +1,19 @@ +from ._eval_typing import ( + eval_typing, + _get_current_context, + register_evaluator, + _EvalProxy, +) + +# XXX: this needs to go second due to nasty circularity -- try to fix that!! from ._eval_call import eval_call -from ._eval_typing import eval_typing, _get_current_context, _EvalProxy from ._subtype import issubtype from ._subsim import issubsimilar __all__ = ( "eval_typing", + "register_evaluator", "eval_call", "issubtype", "issubsimilar", diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index e68ebaa..1cff7c3 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -9,6 +9,8 @@ import types import typing +from typing import _GenericAlias # type: ignore [attr-defined] # noqa: PLC2701 + if typing.TYPE_CHECKING: from typing import Any @@ -19,6 +21,19 @@ __all__ = ("eval_typing",) +_eval_funcs: dict[type, typing.Callable[..., Any]] = {} + + +def register_evaluator[T: typing.Callable[..., Any]]( + typ: type, +) -> typing.Callable[[T], T]: + def func(f: T) -> T: + _eval_funcs[typ] = f + return f + + return func + + # Base type for the proxy classes we generate to hold __annotations__ class _EvalProxy: # Make sure __origin__ doesn't show up at runtime... @@ -148,11 +163,12 @@ def _eval_type_alias(obj: typing.TypeAliasType, ctx: EvalContext): @_eval_types_impl.register -def _eval_generic(obj: types.GenericAlias, ctx: EvalContext): +def _eval_types_generic(obj: types.GenericAlias, ctx: EvalContext): if isinstance(obj.__origin__, type): # This is a GenericAlias over a Python class, e.g. `dict[str, int]` # Let's reconstruct it by evaluating all arguments new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__) + return obj.__origin__[new_args] # type: ignore[index] func = obj.evaluate_value @@ -178,6 +194,20 @@ def _eval_generic(obj: types.GenericAlias, ctx: EvalContext): return evaled +@_eval_types_impl.register +def _eval_typing_generic(obj: _GenericAlias, ctx: EvalContext): + # generic *classes* are typing._GenericAlias while generic type + # aliases are # types.GenericAlias? Why in the world. + if func := _eval_funcs.get(obj.__origin__): + new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__) + ret = func(*new_args) + # return _eval_types(ret, ctx) # ??? + return ret + + # TODO: Actually evaluate in this case! + return obj + + @_eval_types_impl.register def _eval_union(obj: typing.Union, ctx: EvalContext): # type: ignore args: typing.Sequence[typing.Any] = obj.__args__ diff --git a/typemap/typing.py b/typemap/typing.py index d6ee566..fc1ae5a 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -144,11 +144,13 @@ def wrapper(self, *args): return wrapper -@_SpecialForm -@_lift_over_unions -def Attrs(self, tp): - o = type_eval.eval_typing(tp) - hints = get_annotated_type_hints(o, include_extras=True) +class Attrs[T]: + pass + + +@type_eval.register_evaluator(Attrs) +def _eval_attrs(tp): + hints = get_annotated_type_hints(tp, include_extras=True) return tuple[ *[ @@ -202,18 +204,21 @@ def _ann(x): return typing.Callable[params, _ann(sig.return_annotation)] -@_SpecialForm -@_lift_over_unions -def Members(self, tp): - o = type_eval.eval_typing(tp) - hints = get_annotated_type_hints(o, include_extras=True) +class Members[T]: + pass + + +@type_eval.register_evaluator(Members) +# @_lift_over_unions +def _eval_members(tp): + hints = get_annotated_type_hints(tp, include_extras=True) attrs = [ Member[typing.Literal[n], t, typing.Never, d] for n, (t, d) in hints.items() ] - for name, attr in o.__dict__.items(): + for name, attr in tp.__dict__.items(): if isinstance(attr, (types.FunctionType, types.MethodType)): if attr is typing._no_init_or_replace_init: continue @@ -233,6 +238,8 @@ def Members(self, tp): ################################################################## +# NB - Iter needs to be interpreted, I think! +# XXX: Can we figure a way around this? @_SpecialForm def Iter(self, tp): tp = type_eval.eval_typing(tp) @@ -249,8 +256,12 @@ def Iter(self, tp): ) -@_SpecialForm -def FromUnion(self, tp): +class FromUnion[T]: + pass + + +@type_eval.register_evaluator(FromUnion) +def _eval_from_union(tp): return tuple[*_union_elems(tp)] @@ -310,6 +321,11 @@ def GetArg(self, tp, base, idx) -> typing.Any: # N.B: These handle unions on their own +# NB - Is needs to be interpreted, I think! +# XXX: Can we figure a way around this? +# By registering a handler?? + + @_SpecialForm @_split_args def IsSubtype(self, lhs, rhs): From 27e8ccc8c2682e80318bfadf4ca023d73f897315 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Mon, 10 Nov 2025 12:28:41 -0800 Subject: [PATCH 02/12] more working --- tests/test_qblike.py | 11 +++++++-- typemap/type_eval/_eval_typing.py | 9 ++++---- typemap/typing.py | 37 +++++++++++++++++++++++-------- 3 files changed, 42 insertions(+), 15 deletions(-) diff --git a/tests/test_qblike.py b/tests/test_qblike.py index 8bee327..8870589 100644 --- a/tests/test_qblike.py +++ b/tests/test_qblike.py @@ -1,5 +1,7 @@ import textwrap +from typing import Literal + from typemap.type_eval import eval_call, eval_typing from typemap.typing import ( NewProtocol, @@ -106,10 +108,15 @@ def test_qblike_3(): class select[...]: x: tests.test_qblike.Property[int] w: tests.test_qblike.Property[list[str]] - z: tests.test_qblike.Link[PropsOnly[tests.test_qblike.Tgt]] + z: tests.test_qblike.Link[PropsOnly[typemap.typing.GetArg[\ +tests.test_qblike.Link[tests.test_qblike.Tgt], tests.test_qblike.Link, 0]]] """) + # z: tests.test_qblike.Link[PropsOnly[tests.test_qblike.Tgt]] - tgt = eval_typing(GetAttr[ret, "z"].__args__[0]) + res = eval_typing(GetAttr[ret, Literal["z"]]) + tgt = res.__args__[0] + # XXX: this should probably be pre-evaluated already? + tgt = eval_typing(tgt) fmt = format_helper.format_class(tgt) assert fmt == textwrap.dedent("""\ diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index 1cff7c3..fc41be6 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -164,12 +164,13 @@ def _eval_type_alias(obj: typing.TypeAliasType, ctx: EvalContext): @_eval_types_impl.register def _eval_types_generic(obj: types.GenericAlias, ctx: EvalContext): + new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__) + + new_obj = obj.__origin__[new_args] # type: ignore[index] if isinstance(obj.__origin__, type): # This is a GenericAlias over a Python class, e.g. `dict[str, int]` # Let's reconstruct it by evaluating all arguments - new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__) - - return obj.__origin__[new_args] # type: ignore[index] + return new_obj func = obj.evaluate_value @@ -177,7 +178,7 @@ def _eval_types_generic(obj: types.GenericAlias, ctx: EvalContext): mod = sys.modules[obj.__module__] old_obj = ctx.current_alias - ctx.current_alias = obj + ctx.current_alias = new_obj # alias is the new_obj, so names look better try: ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args) diff --git a/typemap/typing.py b/typemap/typing.py index fc1ae5a..84032f1 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -144,6 +144,17 @@ def wrapper(self, *args): return wrapper +def _lift_over_unions_new(func): + @functools.wraps(func) + def wrapper(*args): + args2 = [_union_elems(x) for x in args] + # XXX: Never + parts = [func(*x) for x in itertools.product(*args2)] + return typing.Union[*parts] + + return wrapper + + class Attrs[T]: pass @@ -209,7 +220,7 @@ class Members[T]: @type_eval.register_evaluator(Members) -# @_lift_over_unions +@_lift_over_unions_new def _eval_members(tp): hints = get_annotated_type_hints(tp, include_extras=True) @@ -268,13 +279,17 @@ def _eval_from_union(tp): ################################################################## -@_SpecialForm -@_lift_over_unions -def GetAttr(self, lhs, prop): +class GetAttr[Lhs, Prop]: + pass + + +@type_eval.register_evaluator(GetAttr) +@_lift_over_unions_new +def _eval_GetAttr(lhs, prop): # TODO: the prop missing, etc! # XXX: extras? - name = _from_literal(type_eval.eval_typing(prop)) - return typing.get_type_hints(type_eval.eval_typing(lhs))[name] + name = _from_literal(prop) + return typing.get_type_hints(lhs)[name] def _get_args(tp, base) -> typing.Any: @@ -303,9 +318,13 @@ def _get_args(tp, base) -> typing.Any: return None -@_SpecialForm -@_lift_over_unions -def GetArg(self, tp, base, idx) -> typing.Any: +class GetArg[Tp, Base, Idx: int]: + pass + + +@type_eval.register_evaluator(GetArg) +@_lift_over_unions_new +def _eval_GetArg(tp, base, idx) -> typing.Any: args = _get_args(tp, base) if args is None: return typing.Never From a09b487419dc7e0e576bd42bdcd87b1e6b1e6197 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Mon, 10 Nov 2025 12:33:35 -0800 Subject: [PATCH 03/12] all the string operations --- typemap/typing.py | 61 ++++++++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 25 deletions(-) diff --git a/typemap/typing.py b/typemap/typing.py index 84032f1..801ae9d 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -133,18 +133,6 @@ def _union_elems(tp): def _lift_over_unions(func): - @functools.wraps(func) - @_split_args - def wrapper(self, *args): - args2 = [_union_elems(x) for x in args] - # XXX: Never - parts = [func(self, *x) for x in itertools.product(*args2)] - return typing.Union[*parts] - - return wrapper - - -def _lift_over_unions_new(func): @functools.wraps(func) def wrapper(*args): args2 = [_union_elems(x) for x in args] @@ -220,7 +208,7 @@ class Members[T]: @type_eval.register_evaluator(Members) -@_lift_over_unions_new +@_lift_over_unions def _eval_members(tp): hints = get_annotated_type_hints(tp, include_extras=True) @@ -284,7 +272,7 @@ class GetAttr[Lhs, Prop]: @type_eval.register_evaluator(GetAttr) -@_lift_over_unions_new +@_lift_over_unions def _eval_GetAttr(lhs, prop): # TODO: the prop missing, etc! # XXX: extras? @@ -323,7 +311,7 @@ class GetArg[Tp, Base, Idx: int]: @type_eval.register_evaluator(GetArg) -@_lift_over_unions_new +@_lift_over_unions def _eval_GetArg(tp, base, idx) -> typing.Any: args = _get_args(tp, base) if args is None: @@ -369,21 +357,44 @@ def IsSubSimilar(self, lhs, rhs): ################################################################## -def _string_literal_op(op): - @_SpecialForm +class Uppercase[S: str]: + pass + + +class Lowercase[S: str]: + pass + + +class Capitalize[S: str]: + pass + + +class Uncapitalize[S: str]: + pass + + +class StrConcat[S: str, T: str]: + pass + + +class StrSlice[S: str, Start: int | None, End: int | None]: + pass + + +def _string_literal_op(typ, op): @_lift_over_unions - def func(self, *args): + def func(*args): return typing.Literal[op(*[_from_literal(x) for x in args])] - return func + type_eval.register_evaluator(typ)(func) -Uppercase = _string_literal_op(op=str.upper) -Lowercase = _string_literal_op(op=str.lower) -Capitalize = _string_literal_op(op=str.capitalize) -Uncapitalize = _string_literal_op(op=lambda s: s[0:1].lower() + s[1:]) -StrConcat = _string_literal_op(op=lambda s, t: s + t) -StrSlice = _string_literal_op(op=lambda s, start, end: s[start:end]) +_string_literal_op(Uppercase, op=str.upper) +_string_literal_op(Lowercase, op=str.lower) +_string_literal_op(Capitalize, op=str.capitalize) +_string_literal_op(Uncapitalize, op=lambda s: s[0:1].lower() + s[1:]) +_string_literal_op(StrConcat, op=lambda s, t: s + t) +_string_literal_op(StrSlice, op=lambda s, start, end: s[start:end]) ################################################################## From c2c75d41bed9b37900f20a35a03d323727c28771 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Mon, 10 Nov 2025 13:57:32 -0800 Subject: [PATCH 04/12] NewProtocol --- tests/test_type_eval.py | 8 +++++--- typemap/type_eval/_eval_typing.py | 9 +++++++++ typemap/typing.py | 18 +++++++++--------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index a0a9b1d..a7dcea4 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -69,15 +69,17 @@ def test_eval_types_1(): def test_eval_types_2(): evaled = eval_typing(MapRecursive[Recursive]) - # Validate that recursion worked properly and "Recursive" was only walked once - assert evaled.__annotations__["a"].__args__[0] is evaled + # FIXME, or think about: this doesn't work, we currently evaluate it to an + # *unexpanded* type alias. + # # Validate that recursion worked properly and "Recursive" was only walked once + # assert evaled.__annotations__["a"].__args__[0] is evaled assert format_helper.format_class(evaled) == textwrap.dedent("""\ class MapRecursive[tests.test_type_eval.Recursive]: n: int | typing.Literal['gotcha!'] m: str | typing.Literal['gotcha!'] t: typing.Literal[False] | typing.Literal['gotcha!'] - a: tests.test_type_eval.MapRecursive[tests.test_type_eval.Recursive] | typing.Literal['gotcha!'] + a: MapRecursive[tests.test_type_eval.Recursive] | typing.Literal['gotcha!'] fff: int | typing.Literal['gotcha!'] control: float """) diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index fc41be6..96fe88a 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -46,6 +46,9 @@ class EvalContext: seen: dict[Any, Any] # The typing.Any is really a types.FunctionType, but mypy gets # confused and wants to treat it as a MethodType. + current_alias_stack: set[types.GenericAlias | typing.Any] = ( + dataclasses.field(default_factory=set) + ) current_alias: types.GenericAlias | typing.Any | None = None @@ -88,6 +91,10 @@ def eval_typing(obj: typing.Any): def _eval_types(obj: typing.Any, ctx: EvalContext): + # Don't recurse into any pending alias expansion + if obj in ctx.current_alias_stack: + return obj + # strings match if obj in ctx.seen: return ctx.seen[obj] ctx.seen[obj] = evaled = _eval_types_impl(obj, ctx) @@ -179,6 +186,7 @@ def _eval_types_generic(obj: types.GenericAlias, ctx: EvalContext): old_obj = ctx.current_alias ctx.current_alias = new_obj # alias is the new_obj, so names look better + ctx.current_alias_stack.add(new_obj) try: ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args) @@ -191,6 +199,7 @@ def _eval_types_generic(obj: types.GenericAlias, ctx: EvalContext): raise finally: ctx.current_alias = old_obj + ctx.current_alias_stack.remove(new_obj) return evaled diff --git a/typemap/typing.py b/typemap/typing.py index 801ae9d..e6b9d8d 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -400,25 +400,24 @@ def func(*args): ################################################################## +class NewProtocol[*T]: + pass + + # XXX: We definitely can't use the normal _SpecialForm cache here # directly, since we depend on the context's current_alias. # Maybe we can add that to the cache, though. # (Or maybe we need to never use the cache??) -@_NoCacheSpecialForm -def NewProtocol(self, val: Member | tuple[Member, ...]): - if not isinstance(val, tuple): - val = (val,) - - etyps = [type_eval.eval_typing(t) for t in val] - +@type_eval.register_evaluator(NewProtocol) +def _eval_NewProtocol(*etyps: Member): dct: dict[str, object] = {} dct["__annotations__"] = { # XXX: Should eval_typing on the etyps evaluate the arguments?? _from_literal(type_eval.eval_typing(typing.get_args(prop)[0])): # XXX: We maybe (probably?) want to eval_typing the RHS, but # we have infinite recursion issues in test_eval_types_2... - # type_eval.eval_typing(typing.get_args(prop)[1]) - typing.get_args(prop)[1] + type_eval.eval_typing(typing.get_args(prop)[1]) + # typing.get_args(prop)[1] for prop in etyps } @@ -438,4 +437,5 @@ def NewProtocol(self, val: Member | tuple[Member, ...]): mcls: type = type(typing.cast(type, typing.Protocol)) cls = mcls(name, (typing.Protocol,), dct) + cls = type_eval.eval_typing(cls) return cls From db449dc4e28885ebd670bf78a43dcc49219fe551 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Mon, 10 Nov 2025 13:58:29 -0800 Subject: [PATCH 05/12] tweaks --- typemap/typing.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/typemap/typing.py b/typemap/typing.py index e6b9d8d..470682d 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -148,7 +148,7 @@ class Attrs[T]: @type_eval.register_evaluator(Attrs) -def _eval_attrs(tp): +def _eval_Attrs(tp): hints = get_annotated_type_hints(tp, include_extras=True) return tuple[ @@ -209,7 +209,7 @@ class Members[T]: @type_eval.register_evaluator(Members) @_lift_over_unions -def _eval_members(tp): +def _eval_Members(tp): hints = get_annotated_type_hints(tp, include_extras=True) attrs = [ @@ -260,7 +260,7 @@ class FromUnion[T]: @type_eval.register_evaluator(FromUnion) -def _eval_from_union(tp): +def _eval_FromUnion(tp): return tuple[*_union_elems(tp)] @@ -404,20 +404,14 @@ class NewProtocol[*T]: pass -# XXX: We definitely can't use the normal _SpecialForm cache here -# directly, since we depend on the context's current_alias. -# Maybe we can add that to the cache, though. -# (Or maybe we need to never use the cache??) @type_eval.register_evaluator(NewProtocol) def _eval_NewProtocol(*etyps: Member): dct: dict[str, object] = {} dct["__annotations__"] = { # XXX: Should eval_typing on the etyps evaluate the arguments?? - _from_literal(type_eval.eval_typing(typing.get_args(prop)[0])): - # XXX: We maybe (probably?) want to eval_typing the RHS, but - # we have infinite recursion issues in test_eval_types_2... - type_eval.eval_typing(typing.get_args(prop)[1]) - # typing.get_args(prop)[1] + _from_literal( + type_eval.eval_typing(typing.get_args(prop)[0]) + ): type_eval.eval_typing(typing.get_args(prop)[1]) for prop in etyps } From 0a85ab933853f5638a8dbd78fbbcb0ba14ade89f Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 11 Nov 2025 08:49:07 -0800 Subject: [PATCH 06/12] Start moving everything to a new file --- typemap/type_eval/__init__.py | 3 + typemap/type_eval/_eval_operators.py | 281 ++++++++++++++++++++++++ typemap/type_eval/_eval_typing.py | 1 + typemap/typing.py | 306 +++------------------------ 4 files changed, 317 insertions(+), 274 deletions(-) create mode 100644 typemap/type_eval/_eval_operators.py diff --git a/typemap/type_eval/__init__.py b/typemap/type_eval/__init__.py index 5642705..b28eabf 100644 --- a/typemap/type_eval/__init__.py +++ b/typemap/type_eval/__init__.py @@ -10,6 +10,9 @@ from ._subtype import issubtype from ._subsim import issubsimilar +# This one is imported for registering handlers +from . import _eval_operators # noqa + __all__ = ( "eval_typing", diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py new file mode 100644 index 0000000..c060c9b --- /dev/null +++ b/typemap/type_eval/_eval_operators.py @@ -0,0 +1,281 @@ +import functools +import inspect +import itertools +import types +import typing + +from typemap import type_eval +from typemap.type_eval import _typing_inspect + +from typemap.typing import ( + Attrs, + Member, + Members, + NewProtocol, + Param, + FromUnion, + GetArg, + GetAttr, + Capitalize, + Uncapitalize, + Uppercase, + Lowercase, + StrConcat, + StrSlice, +) + + +# _SpecialForm: typing.Any = typing._SpecialForm + + +def _from_literal(val): + val = type_eval.eval_typing(val) + if _typing_inspect.is_literal(val): + val = val.__args__[0] + return val + + +################################################################## + + +def get_annotated_type_hints(cls, **kwargs): + """Get the type hints for a cls annotated with definition site. + + This traverses the mro and finds the definition site for each annotation. + """ + ohints = typing.get_type_hints(cls, **kwargs) + hints = {} + for acls in cls.__mro__: + if not hasattr(acls, "__annotations__"): + continue + for k in acls.__annotations__: + if k not in hints: + hints[k] = ohints[k], acls + + # Stop early if we are done. + if len(hints) == len(ohints): + break + return hints + + +def _split_args(func): + @functools.wraps(func) + def wrapper(self, arg): + if isinstance(arg, tuple): + return func(self, *arg) + else: + return func(self, arg) + + return wrapper + + +def _union_elems(tp): + tp = type_eval.eval_typing(tp) + if isinstance(tp, types.UnionType): + return tuple(y for x in tp.__args__ for y in _union_elems(x)) + elif _typing_inspect.is_literal(tp) and len(tp.__args__) > 1: + return tuple(typing.Literal[x] for x in tp.__args__) + else: + return (tp,) + + +def _lift_over_unions(func): + @functools.wraps(func) + def wrapper(*args): + args2 = [_union_elems(x) for x in args] + # XXX: Never + parts = [func(*x) for x in itertools.product(*args2)] + return typing.Union[*parts] + + return wrapper + + +@type_eval.register_evaluator(Attrs) +def _eval_Attrs(tp): + hints = get_annotated_type_hints(tp, include_extras=True) + + return tuple[ + *[ + Member[typing.Literal[n], t, typing.Never, d] + for n, (t, d) in hints.items() + ] + ] + + +################################################################## + + +def _function_type(func, *, is_method): + root = inspect.unwrap(func) + sig = inspect.signature(root) + # XXX: __type_params__!!! + + empty = inspect.Parameter.empty + + def _ann(x): + return typing.Any if x is empty else x + + params = [] + for _i, p in enumerate(sig.parameters.values()): + # XXX: what should we do about self? + # should we track classmethod/staticmethod somehow? + # mypy stores all this stuff in the SymbolNodes (FuncDef, etc), + # even though it kind of really is a type/descriptor thing + # if i == 0 and is_method: + # continue + has_name = p.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + quals = [] + if p.kind == inspect.Parameter.VAR_POSITIONAL: + quals.append("*") + if p.kind == inspect.Parameter.VAR_KEYWORD: + quals.append("**") + if p.default is not empty: + quals.append("=") + params.append( + Param[ + typing.Literal[p.name if has_name else None], + _ann(p.annotation), + typing.Literal[*quals] if quals else typing.Never, + ] + ) + + return typing.Callable[params, _ann(sig.return_annotation)] + + +@type_eval.register_evaluator(Members) +@_lift_over_unions +def _eval_Members(tp): + hints = get_annotated_type_hints(tp, include_extras=True) + + attrs = [ + Member[typing.Literal[n], t, typing.Never, d] + for n, (t, d) in hints.items() + ] + + for name, attr in tp.__dict__.items(): + if isinstance(attr, (types.FunctionType, types.MethodType)): + if attr is typing._no_init_or_replace_init: + continue + + # XXX: populate the source field + attrs.append( + Member[ + typing.Literal[name], + _function_type(attr, is_method=True), + typing.Literal["ClassVar"], + ] + ) + + return tuple[*attrs] + + +################################################################## + + +@type_eval.register_evaluator(FromUnion) +def _eval_FromUnion(tp): + return tuple[*_union_elems(tp)] + + +################################################################## + + +@type_eval.register_evaluator(GetAttr) +@_lift_over_unions +def _eval_GetAttr(lhs, prop): + # TODO: the prop missing, etc! + # XXX: extras? + name = _from_literal(prop) + return typing.get_type_hints(lhs)[name] + + +def _get_args(tp, base) -> typing.Any: + # XXX: check against base!! + evaled = type_eval.eval_typing(tp) + + tp_head = _typing_inspect.get_head(tp) + base_head = _typing_inspect.get_head(base) + # XXX: not sure this is what we want! + # at the very least we want unions I think + if not tp_head or not base_head: + return None + + if tp_head is base_head: + return typing.get_args(evaled) + + # Scan the fully-annotated MRO to find the base + elif gen_mro := getattr(evaled, "__generalized_mro__", None): + for box in gen_mro: + if box.cls is base_head: + return tuple(box.args.values()) + return None + + else: + # or error?? + return None + + +@type_eval.register_evaluator(GetArg) +@_lift_over_unions +def _eval_GetArg(tp, base, idx) -> typing.Any: + args = _get_args(tp, base) + if args is None: + return typing.Never + + try: + return args[_from_literal(idx)] + except IndexError: + return typing.Never + + +def _string_literal_op(typ, op): + @_lift_over_unions + def func(*args): + return typing.Literal[op(*[_from_literal(x) for x in args])] + + type_eval.register_evaluator(typ)(func) + + +_string_literal_op(Uppercase, op=str.upper) +_string_literal_op(Lowercase, op=str.lower) +_string_literal_op(Capitalize, op=str.capitalize) +_string_literal_op(Uncapitalize, op=lambda s: s[0:1].lower() + s[1:]) +_string_literal_op(StrConcat, op=lambda s, t: s + t) +_string_literal_op(StrSlice, op=lambda s, start, end: s[start:end]) + + +################################################################## + + +@type_eval.register_evaluator(NewProtocol) +def _eval_NewProtocol(*etyps: Member): + dct: dict[str, object] = {} + dct["__annotations__"] = { + # XXX: Should eval_typing on the etyps evaluate the arguments?? + _from_literal( + type_eval.eval_typing(typing.get_args(prop)[0]) + ): type_eval.eval_typing(typing.get_args(prop)[1]) + for prop in etyps + } + + module_name = __name__ + name = "NewProtocol" + + # If the type evaluation context + ctx = type_eval._get_current_context() + if ctx.current_alias: + if isinstance(ctx.current_alias, types.GenericAlias): + name = str(ctx.current_alias) + else: + name = f"{ctx.current_alias.__name__}[...]" + module_name = ctx.current_alias.__module__ + + dct["__module__"] = module_name + + mcls: type = type(typing.cast(type, typing.Protocol)) + cls = mcls(name, (typing.Protocol,), dct) + cls = type_eval.eval_typing(cls) + return cls diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index 96fe88a..9ad7eb8 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -28,6 +28,7 @@ def register_evaluator[T: typing.Callable[..., Any]]( typ: type, ) -> typing.Callable[[T], T]: def func(f: T) -> T: + assert typ not in _eval_funcs _eval_funcs[typ] = f return f diff --git a/typemap/typing.py b/typemap/typing.py index 470682d..031a657 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -2,7 +2,6 @@ import functools import inspect -import itertools import types import typing @@ -71,13 +70,6 @@ def CallSpecKwargs(self, spec: _CallSpecWrapper): ################################################################## -def _from_literal(val): - val = type_eval.eval_typing(val) - if _typing_inspect.is_literal(val): - val = val.__args__[0] - return val - - class Member[N: str, T, Q: str = typing.Never, D = typing.Never]: pass @@ -88,153 +80,70 @@ class Member[N: str, T, Q: str = typing.Never, D = typing.Never]: type GetDefiner[T: Member] = GetArg[T, Member, 3] # type: ignore[valid-type] -################################################################## - - -def get_annotated_type_hints(cls, **kwargs): - """Get the type hints for a cls annotated with definition site. - - This traverses the mro and finds the definition site for each annotation. - """ - ohints = typing.get_type_hints(cls, **kwargs) - hints = {} - for acls in cls.__mro__: - if not hasattr(acls, "__annotations__"): - continue - for k in acls.__annotations__: - if k not in hints: - hints[k] = ohints[k], acls +class Attrs[T]: + pass - # Stop early if we are done. - if len(hints) == len(ohints): - break - return hints +class Param[N: str | None, T, Q: str = typing.Never]: + pass -def _split_args(func): - @functools.wraps(func) - def wrapper(self, arg): - if isinstance(arg, tuple): - return func(self, *arg) - else: - return func(self, arg) - - return wrapper +class Members[T]: + pass -def _union_elems(tp): - tp = type_eval.eval_typing(tp) - if isinstance(tp, types.UnionType): - return tuple(y for x in tp.__args__ for y in _union_elems(x)) - elif _typing_inspect.is_literal(tp) and len(tp.__args__) > 1: - return tuple(typing.Literal[x] for x in tp.__args__) - else: - return (tp,) +class FromUnion[T]: + pass -def _lift_over_unions(func): - @functools.wraps(func) - def wrapper(*args): - args2 = [_union_elems(x) for x in args] - # XXX: Never - parts = [func(*x) for x in itertools.product(*args2)] - return typing.Union[*parts] - return wrapper +class GetAttr[Lhs, Prop]: + pass -class Attrs[T]: +class GetArg[Tp, Base, Idx: int]: pass -@type_eval.register_evaluator(Attrs) -def _eval_Attrs(tp): - hints = get_annotated_type_hints(tp, include_extras=True) - - return tuple[ - *[ - Member[typing.Literal[n], t, typing.Never, d] - for n, (t, d) in hints.items() - ] - ] +class Uppercase[S: str]: + pass -class Param[N: str | None, T, Q: str = typing.Never]: +class Lowercase[S: str]: pass -def _function_type(func, *, is_method): - root = inspect.unwrap(func) - sig = inspect.signature(root) - # XXX: __type_params__!!! +class Capitalize[S: str]: + pass - empty = inspect.Parameter.empty - def _ann(x): - return typing.Any if x is empty else x +class Uncapitalize[S: str]: + pass - params = [] - for _i, p in enumerate(sig.parameters.values()): - # XXX: what should we do about self? - # should we track classmethod/staticmethod somehow? - # mypy stores all this stuff in the SymbolNodes (FuncDef, etc), - # even though it kind of really is a type/descriptor thing - # if i == 0 and is_method: - # continue - has_name = p.kind in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ) - quals = [] - if p.kind == inspect.Parameter.VAR_POSITIONAL: - quals.append("*") - if p.kind == inspect.Parameter.VAR_KEYWORD: - quals.append("**") - if p.default is not empty: - quals.append("=") - params.append( - Param[ - typing.Literal[p.name if has_name else None], - _ann(p.annotation), - typing.Literal[*quals] if quals else typing.Never, - ] - ) - return typing.Callable[params, _ann(sig.return_annotation)] +class StrConcat[S: str, T: str]: + pass -class Members[T]: +class StrSlice[S: str, Start: int | None, End: int | None]: pass -@type_eval.register_evaluator(Members) -@_lift_over_unions -def _eval_Members(tp): - hints = get_annotated_type_hints(tp, include_extras=True) - - attrs = [ - Member[typing.Literal[n], t, typing.Never, d] - for n, (t, d) in hints.items() - ] +class NewProtocol[*T]: + pass - for name, attr in tp.__dict__.items(): - if isinstance(attr, (types.FunctionType, types.MethodType)): - if attr is typing._no_init_or_replace_init: - continue - # XXX: populate the source field - attrs.append( - Member[ - typing.Literal[name], - _function_type(attr, is_method=True), - typing.Literal["ClassVar"], - ] - ) +################################################################## - return tuple[*attrs] +def _split_args(func): + @functools.wraps(func) + def wrapper(self, arg): + if isinstance(arg, tuple): + return func(self, *arg) + else: + return func(self, arg) -################################################################## + return wrapper # NB - Iter needs to be interpreted, I think! @@ -255,76 +164,6 @@ def Iter(self, tp): ) -class FromUnion[T]: - pass - - -@type_eval.register_evaluator(FromUnion) -def _eval_FromUnion(tp): - return tuple[*_union_elems(tp)] - - -################################################################## - - -class GetAttr[Lhs, Prop]: - pass - - -@type_eval.register_evaluator(GetAttr) -@_lift_over_unions -def _eval_GetAttr(lhs, prop): - # TODO: the prop missing, etc! - # XXX: extras? - name = _from_literal(prop) - return typing.get_type_hints(lhs)[name] - - -def _get_args(tp, base) -> typing.Any: - # XXX: check against base!! - evaled = type_eval.eval_typing(tp) - - tp_head = _typing_inspect.get_head(tp) - base_head = _typing_inspect.get_head(base) - # XXX: not sure this is what we want! - # at the very least we want unions I think - if not tp_head or not base_head: - return None - - if tp_head is base_head: - return typing.get_args(evaled) - - # Scan the fully-annotated MRO to find the base - elif gen_mro := getattr(evaled, "__generalized_mro__", None): - for box in gen_mro: - if box.cls is base_head: - return tuple(box.args.values()) - return None - - else: - # or error?? - return None - - -class GetArg[Tp, Base, Idx: int]: - pass - - -@type_eval.register_evaluator(GetArg) -@_lift_over_unions -def _eval_GetArg(tp, base, idx) -> typing.Any: - args = _get_args(tp, base) - if args is None: - return typing.Never - - try: - return args[_from_literal(idx)] - except IndexError: - return typing.Never - - -################################################################## - # N.B: These handle unions on their own @@ -352,84 +191,3 @@ def IsSubSimilar(self, lhs, rhs): Is = IsSubSimilar - - -################################################################## - - -class Uppercase[S: str]: - pass - - -class Lowercase[S: str]: - pass - - -class Capitalize[S: str]: - pass - - -class Uncapitalize[S: str]: - pass - - -class StrConcat[S: str, T: str]: - pass - - -class StrSlice[S: str, Start: int | None, End: int | None]: - pass - - -def _string_literal_op(typ, op): - @_lift_over_unions - def func(*args): - return typing.Literal[op(*[_from_literal(x) for x in args])] - - type_eval.register_evaluator(typ)(func) - - -_string_literal_op(Uppercase, op=str.upper) -_string_literal_op(Lowercase, op=str.lower) -_string_literal_op(Capitalize, op=str.capitalize) -_string_literal_op(Uncapitalize, op=lambda s: s[0:1].lower() + s[1:]) -_string_literal_op(StrConcat, op=lambda s, t: s + t) -_string_literal_op(StrSlice, op=lambda s, start, end: s[start:end]) - - -################################################################## - - -class NewProtocol[*T]: - pass - - -@type_eval.register_evaluator(NewProtocol) -def _eval_NewProtocol(*etyps: Member): - dct: dict[str, object] = {} - dct["__annotations__"] = { - # XXX: Should eval_typing on the etyps evaluate the arguments?? - _from_literal( - type_eval.eval_typing(typing.get_args(prop)[0]) - ): type_eval.eval_typing(typing.get_args(prop)[1]) - for prop in etyps - } - - module_name = __name__ - name = "NewProtocol" - - # If the type evaluation context - ctx = type_eval._get_current_context() - if ctx.current_alias: - if isinstance(ctx.current_alias, types.GenericAlias): - name = str(ctx.current_alias) - else: - name = f"{ctx.current_alias.__name__}[...]" - module_name = ctx.current_alias.__module__ - - dct["__module__"] = module_name - - mcls: type = type(typing.cast(type, typing.Protocol)) - cls = mcls(name, (typing.Protocol,), dct) - cls = type_eval.eval_typing(cls) - return cls From be3ac64d793f00894b1490667a235eaf02166b09 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 12 Nov 2025 10:25:20 -0800 Subject: [PATCH 07/12] drop _NoCacheSpecialForm --- typemap/typing.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/typemap/typing.py b/typemap/typing.py index 031a657..9a93321 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -12,11 +12,6 @@ _SpecialForm: typing.Any = typing._SpecialForm -class _NoCacheSpecialForm(_SpecialForm, _root=True): # type: ignore[call-arg] - def __getitem__(self, parameters): - return self._getitem(self, parameters) - - @dataclass(frozen=True) class CallSpec: pass From 106d77cdc2c03fae990bc3f04532584de273c5f8 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 12 Nov 2025 14:26:22 -0800 Subject: [PATCH 08/12] more CallSpecWrapper also --- typemap/type_eval/_eval_call.py | 4 ++- typemap/type_eval/_eval_operators.py | 39 +++++++++++++++++++++++++--- typemap/typing.py | 32 +++-------------------- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index 415fa26..248030d 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -27,7 +27,9 @@ def _eval_call( params = func.__type_params__ for p in params: if hasattr(p, "__bound__") and p.__bound__ is next.CallSpec: - vars[p.__name__] = next._CallSpecWrapper(args, kwargs, func) + vars[p.__name__] = next._CallSpecWrapper( + args, tuple(kwargs.items()), func + ) else: vars[p.__name__] = p diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index c060c9b..715aed2 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -9,6 +9,8 @@ from typemap.typing import ( Attrs, + CallSpecKwargs, + _CallSpecWrapper, Member, Members, NewProtocol, @@ -25,7 +27,7 @@ ) -# _SpecialForm: typing.Any = typing._SpecialForm +################################################################## def _from_literal(val): @@ -35,9 +37,6 @@ def _from_literal(val): return val -################################################################## - - def get_annotated_type_hints(cls, **kwargs): """Get the type hints for a cls annotated with definition site. @@ -105,6 +104,38 @@ def _eval_Attrs(tp): ################################################################## +@type_eval.register_evaluator(CallSpecKwargs) +def eval_CallSpecKwargs(spec: _CallSpecWrapper): + ff = types.FunctionType( + spec._func.__code__, + spec._func.__globals__, + spec._func.__name__, + None, + (), + ) + + # We can't call `inspect.signature` on `spec` directly -- + # signature() will attempt to resolve annotations and fail. + # So we run it on a copy of the function that doesn't have + # annotations set. + sig = inspect.signature(ff) + bound = sig.bind(*spec._args, **dict(spec._kwargs)) + + # TODO: Get the real type instead of Never + return tuple[ # type: ignore[misc] + *[ + Member[ + typing.Literal[name], # type: ignore[valid-type] + typing.Never, + ] + for name in bound.kwargs + ] + ] + + +################################################################## + + def _function_type(func, *, is_method): root = inspect.unwrap(func) sig = inspect.signature(root) diff --git a/typemap/typing.py b/typemap/typing.py index 9a93321..4239d52 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import functools -import inspect import types import typing @@ -20,7 +19,7 @@ class CallSpec: @dataclass(frozen=True) class _CallSpecWrapper: _args: tuple[typing.Any] - _kwargs: dict[str, typing.Any] + _kwargs: tuple[tuple[str, typing.Any], ...] # TODO: Support MethodType! _func: types.FunctionType # | types.MethodType @@ -33,33 +32,8 @@ def kwargs(self) -> None: pass -@_SpecialForm -def CallSpecKwargs(self, spec: _CallSpecWrapper): - ff = types.FunctionType( - spec._func.__code__, - spec._func.__globals__, - spec._func.__name__, - None, - (), - ) - - # We can't call `inspect.signature` on `spec` directly -- - # signature() will attempt to resolve annotations and fail. - # So we run it on a copy of the function that doesn't have - # annotations set. - sig = inspect.signature(ff) - bound = sig.bind(*spec._args, **spec._kwargs) - - # TODO: Get the real type instead of Never - return tuple[ # type: ignore[misc] - *[ - Member[ - typing.Literal[name], # type: ignore[valid-type] - typing.Never, - ] - for name in bound.kwargs - ] - ] +class CallSpecKwargs[Spec]: + pass ################################################################## From 46dd7a2a402190e1cd229fce736098f1be60f5f0 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 12 Nov 2025 15:06:51 -0800 Subject: [PATCH 09/12] Sink a function down a bit --- typemap/typing.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/typemap/typing.py b/typemap/typing.py index 4239d52..a7faa0f 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -104,17 +104,6 @@ class NewProtocol[*T]: ################################################################## -def _split_args(func): - @functools.wraps(func) - def wrapper(self, arg): - if isinstance(arg, tuple): - return func(self, *arg) - else: - return func(self, arg) - - return wrapper - - # NB - Iter needs to be interpreted, I think! # XXX: Can we figure a way around this? @_SpecialForm @@ -136,6 +125,17 @@ def Iter(self, tp): # N.B: These handle unions on their own +def _split_args(func): + @functools.wraps(func) + def wrapper(self, arg): + if isinstance(arg, tuple): + return func(self, *arg) + else: + return func(self, arg) + + return wrapper + + # NB - Is needs to be interpreted, I think! # XXX: Can we figure a way around this? # By registering a handler?? From 0ab0cf96721af274d88ebed38c9a34527e0e3610 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Wed, 12 Nov 2025 15:25:47 -0800 Subject: [PATCH 10/12] Make Iter and Is get evaluated by a contextvar specified hook --- tests/test_type_dir.py | 4 +- typemap/type_eval/_eval_operators.py | 72 ++++++++++++++++++-------- typemap/type_eval/_eval_typing.py | 7 +++ typemap/typing.py | 75 +++++++++++----------------- 4 files changed, 89 insertions(+), 69 deletions(-) diff --git a/tests/test_type_dir.py b/tests/test_type_dir.py index e629f2d..1f01146 100644 --- a/tests/test_type_dir.py +++ b/tests/test_type_dir.py @@ -232,7 +232,9 @@ class NoLiterals2[tests.test_type_dir.Final]: def test_type_dir_7(): d = eval_typing(Members[Final]) - foo = next(iter(m for m in Iter[d] if m.__args__[0].__args__[0] == "foo")) + foo = next( + iter(m for m in d.__args__ if m.__args__[0].__args__[0] == "foo") + ) # XXX: drop self? assert ( str(foo) diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 715aed2..35b9f76 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -11,6 +11,9 @@ Attrs, CallSpecKwargs, _CallSpecWrapper, + Iter, + IsSubtype, + IsSubSimilar, Member, Members, NewProtocol, @@ -57,17 +60,6 @@ def get_annotated_type_hints(cls, **kwargs): return hints -def _split_args(func): - @functools.wraps(func) - def wrapper(self, arg): - if isinstance(arg, tuple): - return func(self, *arg) - else: - return func(self, arg) - - return wrapper - - def _union_elems(tp): tp = type_eval.eval_typing(tp) if isinstance(tp, types.UnionType): @@ -89,23 +81,49 @@ def wrapper(*args): return wrapper -@type_eval.register_evaluator(Attrs) -def _eval_Attrs(tp): - hints = get_annotated_type_hints(tp, include_extras=True) +################################################################## - return tuple[ - *[ - Member[typing.Literal[n], t, typing.Never, d] - for n, (t, d) in hints.items() - ] - ] + +@type_eval.register_evaluator(Iter) +def _eval_Iter(tp): + tp = type_eval.eval_typing(tp) + if ( + _typing_inspect.is_generic_alias(tp) + and tp.__origin__ is tuple + and (not tp.__args__ or tp.__args__[-1] is not Ellipsis) + ): + return iter(tp.__args__) + else: + # XXX: Or should we return []? + raise TypeError( + f"Invalid type argument to Iter: {tp} is not a fixed-length tuple" + ) + + +# N.B: These handle unions on their own + + +@type_eval.register_evaluator(IsSubtype) +def _eval_IsSubtype(lhs, rhs): + return type_eval.issubtype( + type_eval.eval_typing(lhs), + type_eval.eval_typing(rhs), + ) + + +@type_eval.register_evaluator(IsSubSimilar) +def _eval_IsSubSimilar(lhs, rhs): + return type_eval.issubsimilar( + type_eval.eval_typing(lhs), + type_eval.eval_typing(rhs), + ) ################################################################## @type_eval.register_evaluator(CallSpecKwargs) -def eval_CallSpecKwargs(spec: _CallSpecWrapper): +def _eval_CallSpecKwargs(spec: _CallSpecWrapper): ff = types.FunctionType( spec._func.__code__, spec._func.__globals__, @@ -176,6 +194,18 @@ def _ann(x): return typing.Callable[params, _ann(sig.return_annotation)] +@type_eval.register_evaluator(Attrs) +def _eval_Attrs(tp): + hints = get_annotated_type_hints(tp, include_extras=True) + + return tuple[ + *[ + Member[typing.Literal[n], t, typing.Never, d] + for n, (t, d) in hints.items() + ] + ] + + @type_eval.register_evaluator(Members) @_lift_over_unions def _eval_Members(tp): diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index 9ad7eb8..a797f3c 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -61,6 +61,8 @@ class EvalContext: @contextlib.contextmanager def _ensure_context() -> typing.Iterator[EvalContext]: + import typemap.typing as nt + ctx = _current_context.get() ctx_set = False if ctx is None: @@ -69,12 +71,17 @@ def _ensure_context() -> typing.Iterator[EvalContext]: ) _current_context.set(ctx) ctx_set = True + old_evaluator = nt.special_form_evaluator.get() + if old_evaluator is not eval_typing: + nt.special_form_evaluator.set(eval_typing) try: yield ctx finally: if ctx_set: _current_context.set(None) + if old_evaluator is not eval_typing: + nt.special_form_evaluator.set(old_evaluator) def _get_current_context() -> EvalContext: diff --git a/typemap/typing.py b/typemap/typing.py index a7faa0f..3058d18 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -1,11 +1,9 @@ from dataclasses import dataclass -import functools +import contextvars import types import typing - -from typemap import type_eval -from typemap.type_eval import _typing_inspect +from typing import _GenericAlias # type: ignore _SpecialForm: typing.Any = typing._SpecialForm @@ -103,60 +101,43 @@ class NewProtocol[*T]: ################################################################## +# TODO: type better +special_form_evaluator: contextvars.ContextVar[ + typing.Callable[[typing.Any], typing.Any] | None +] = contextvars.ContextVar("special_form_evaluator", default=None) -# NB - Iter needs to be interpreted, I think! -# XXX: Can we figure a way around this? -@_SpecialForm -def Iter(self, tp): - tp = type_eval.eval_typing(tp) - if ( - _typing_inspect.is_generic_alias(tp) - and tp.__origin__ is tuple - and (not tp.__args__ or tp.__args__[-1] is not Ellipsis) - ): - return tp.__args__ - else: - # XXX: Or should we return []? - raise TypeError( - f"Invalid type argument to Iter: {tp} is not a fixed-length tuple" - ) - - -# N.B: These handle unions on their own - - -def _split_args(func): - @functools.wraps(func) - def wrapper(self, arg): - if isinstance(arg, tuple): - return func(self, *arg) + +class _IterGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] + def __iter__(self): + evaluator = special_form_evaluator.get() + if evaluator: + return evaluator(self) else: - return func(self, arg) + return super().__iter__() - return wrapper +@_SpecialForm +def Iter(self, tp): + return _IterGenericAlias(self, (tp,)) -# NB - Is needs to be interpreted, I think! -# XXX: Can we figure a way around this? -# By registering a handler?? + +class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] + def __bool__(self): + evaluator = special_form_evaluator.get() + if evaluator: + return evaluator(self) + else: + raise TypeError(f"No evaluator provided for {self}") @_SpecialForm -@_split_args -def IsSubtype(self, lhs, rhs): - return type_eval.issubtype( - type_eval.eval_typing(lhs), - type_eval.eval_typing(rhs), - ) +def IsSubtype(self, tps): + return _IsGenericAlias(self, tps) @_SpecialForm -@_split_args -def IsSubSimilar(self, lhs, rhs): - return type_eval.issubsimilar( - type_eval.eval_typing(lhs), - type_eval.eval_typing(rhs), - ) +def IsSubSimilar(self, tps): + return _IsGenericAlias(self, tps) Is = IsSubSimilar From 15fa6af1bee2d3013a312b33c423ffe39a4aad6e Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Thu, 13 Nov 2025 11:38:26 -0800 Subject: [PATCH 11/12] Pass ctx into _eval_operators --- typemap/type_eval/_eval_operators.py | 71 ++++++++++++++-------------- typemap/type_eval/_eval_typing.py | 11 ++--- typemap/typing.py | 2 +- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 35b9f76..0d80289 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -6,6 +6,7 @@ from typemap import type_eval from typemap.type_eval import _typing_inspect +from typemap.type_eval._eval_typing import _eval_types from typemap.typing import ( Attrs, @@ -33,8 +34,8 @@ ################################################################## -def _from_literal(val): - val = type_eval.eval_typing(val) +def _from_literal(val, ctx): + val = _eval_types(val, ctx) if _typing_inspect.is_literal(val): val = val.__args__[0] return val @@ -60,10 +61,10 @@ def get_annotated_type_hints(cls, **kwargs): return hints -def _union_elems(tp): - tp = type_eval.eval_typing(tp) +def _union_elems(tp, ctx): + tp = _eval_types(tp, ctx) if isinstance(tp, types.UnionType): - return tuple(y for x in tp.__args__ for y in _union_elems(x)) + return tuple(y for x in tp.__args__ for y in _union_elems(x, ctx)) elif _typing_inspect.is_literal(tp) and len(tp.__args__) > 1: return tuple(typing.Literal[x] for x in tp.__args__) else: @@ -72,10 +73,10 @@ def _union_elems(tp): def _lift_over_unions(func): @functools.wraps(func) - def wrapper(*args): - args2 = [_union_elems(x) for x in args] + def wrapper(*args, ctx): + args2 = [_union_elems(x, ctx) for x in args] # XXX: Never - parts = [func(*x) for x in itertools.product(*args2)] + parts = [func(*x, ctx=ctx) for x in itertools.product(*args2)] return typing.Union[*parts] return wrapper @@ -85,8 +86,8 @@ def wrapper(*args): @type_eval.register_evaluator(Iter) -def _eval_Iter(tp): - tp = type_eval.eval_typing(tp) +def _eval_Iter(tp, *, ctx): + tp = _eval_types(tp, ctx) if ( _typing_inspect.is_generic_alias(tp) and tp.__origin__ is tuple @@ -104,18 +105,18 @@ def _eval_Iter(tp): @type_eval.register_evaluator(IsSubtype) -def _eval_IsSubtype(lhs, rhs): +def _eval_IsSubtype(lhs, rhs, *, ctx): return type_eval.issubtype( - type_eval.eval_typing(lhs), - type_eval.eval_typing(rhs), + _eval_types(lhs, ctx), + _eval_types(rhs, ctx), ) @type_eval.register_evaluator(IsSubSimilar) -def _eval_IsSubSimilar(lhs, rhs): +def _eval_IsSubSimilar(lhs, rhs, *, ctx): return type_eval.issubsimilar( - type_eval.eval_typing(lhs), - type_eval.eval_typing(rhs), + _eval_types(lhs, ctx), + _eval_types(rhs, ctx), ) @@ -123,7 +124,7 @@ def _eval_IsSubSimilar(lhs, rhs): @type_eval.register_evaluator(CallSpecKwargs) -def _eval_CallSpecKwargs(spec: _CallSpecWrapper): +def _eval_CallSpecKwargs(spec: _CallSpecWrapper, *, ctx): ff = types.FunctionType( spec._func.__code__, spec._func.__globals__, @@ -195,7 +196,7 @@ def _ann(x): @type_eval.register_evaluator(Attrs) -def _eval_Attrs(tp): +def _eval_Attrs(tp, *, ctx): hints = get_annotated_type_hints(tp, include_extras=True) return tuple[ @@ -208,7 +209,7 @@ def _eval_Attrs(tp): @type_eval.register_evaluator(Members) @_lift_over_unions -def _eval_Members(tp): +def _eval_Members(tp, *, ctx): hints = get_annotated_type_hints(tp, include_extras=True) attrs = [ @@ -237,8 +238,8 @@ def _eval_Members(tp): @type_eval.register_evaluator(FromUnion) -def _eval_FromUnion(tp): - return tuple[*_union_elems(tp)] +def _eval_FromUnion(tp, *, ctx): + return tuple[*_union_elems(tp, ctx)] ################################################################## @@ -246,16 +247,16 @@ def _eval_FromUnion(tp): @type_eval.register_evaluator(GetAttr) @_lift_over_unions -def _eval_GetAttr(lhs, prop): +def _eval_GetAttr(lhs, prop, *, ctx): # TODO: the prop missing, etc! # XXX: extras? - name = _from_literal(prop) + name = _from_literal(prop, ctx) return typing.get_type_hints(lhs)[name] -def _get_args(tp, base) -> typing.Any: +def _get_args(tp, base, ctx) -> typing.Any: # XXX: check against base!! - evaled = type_eval.eval_typing(tp) + evaled = _eval_types(tp, ctx) tp_head = _typing_inspect.get_head(tp) base_head = _typing_inspect.get_head(base) @@ -281,21 +282,21 @@ def _get_args(tp, base) -> typing.Any: @type_eval.register_evaluator(GetArg) @_lift_over_unions -def _eval_GetArg(tp, base, idx) -> typing.Any: - args = _get_args(tp, base) +def _eval_GetArg(tp, base, idx, *, ctx) -> typing.Any: + args = _get_args(tp, base, ctx) if args is None: return typing.Never try: - return args[_from_literal(idx)] + return args[_from_literal(idx, ctx)] except IndexError: return typing.Never def _string_literal_op(typ, op): @_lift_over_unions - def func(*args): - return typing.Literal[op(*[_from_literal(x) for x in args])] + def func(*args, ctx): + return typing.Literal[op(*[_from_literal(x, ctx) for x in args])] type_eval.register_evaluator(typ)(func) @@ -312,13 +313,13 @@ def func(*args): @type_eval.register_evaluator(NewProtocol) -def _eval_NewProtocol(*etyps: Member): +def _eval_NewProtocol(*etyps: Member, ctx): dct: dict[str, object] = {} dct["__annotations__"] = { # XXX: Should eval_typing on the etyps evaluate the arguments?? - _from_literal( - type_eval.eval_typing(typing.get_args(prop)[0]) - ): type_eval.eval_typing(typing.get_args(prop)[1]) + _from_literal(typing.get_args(prop)[0], ctx): _eval_types( + typing.get_args(prop)[1], ctx + ) for prop in etyps } @@ -338,5 +339,5 @@ def _eval_NewProtocol(*etyps: Member): mcls: type = type(typing.cast(type, typing.Protocol)) cls = mcls(name, (typing.Protocol,), dct) - cls = type_eval.eval_typing(cls) + cls = _eval_types(cls, ctx) return cls diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index a797f3c..69114c1 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -71,17 +71,16 @@ def _ensure_context() -> typing.Iterator[EvalContext]: ) _current_context.set(ctx) ctx_set = True - old_evaluator = nt.special_form_evaluator.get() - if old_evaluator is not eval_typing: - nt.special_form_evaluator.set(eval_typing) + evaluator_token = nt.special_form_evaluator.set( + lambda t: _eval_types(t, ctx) + ) try: yield ctx finally: if ctx_set: _current_context.set(None) - if old_evaluator is not eval_typing: - nt.special_form_evaluator.set(old_evaluator) + nt.special_form_evaluator.reset(evaluator_token) def _get_current_context() -> EvalContext: @@ -218,7 +217,7 @@ def _eval_typing_generic(obj: _GenericAlias, ctx: EvalContext): # aliases are # types.GenericAlias? Why in the world. if func := _eval_funcs.get(obj.__origin__): new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__) - ret = func(*new_args) + ret = func(*new_args, ctx=ctx) # return _eval_types(ret, ctx) # ??? return ret diff --git a/typemap/typing.py b/typemap/typing.py index 3058d18..0f03a51 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -113,7 +113,7 @@ def __iter__(self): if evaluator: return evaluator(self) else: - return super().__iter__() + raise TypeError(f"No evaluator provided for {self}") @_SpecialForm From 3ee413711dee50311f6c045c021f03c47b295fb9 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Thu, 13 Nov 2025 15:46:17 -0800 Subject: [PATCH 12/12] Try to not fail in _IterGenericAlias/_IsGenericAlias Is will just return False, while _IterGenericAlias will return an Unpack of some typevar. This is kind of unprincipled though still --- typemap/typing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/typemap/typing.py b/typemap/typing.py index 0f03a51..5b08b87 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -113,7 +113,7 @@ def __iter__(self): if evaluator: return evaluator(self) else: - raise TypeError(f"No evaluator provided for {self}") + return iter(typing.TypeVarTuple("_IterDummy")) @_SpecialForm @@ -127,7 +127,7 @@ def __bool__(self): if evaluator: return evaluator(self) else: - raise TypeError(f"No evaluator provided for {self}") + return False @_SpecialForm