diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 3f80152..79a3dd3 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -1707,14 +1707,11 @@ def g(self) -> int: ... # omitted # Attrs attrs = eval_typing(Attrs[B]) - assert ( - attrs - == tuple[ - Member[Literal["a1"], int, Never, Never, A], - Member[Literal["a2"], str, Never, Never, B], - Member[Literal["b1"], str, Never, Never, B], - Member[Literal["b2"], str, Never, Never, B], - ] + assert attrs.__args__ == ( + Member[Literal["a1"], int, Never, Never, A], + Member[Literal["a2"], str, Never, Never, B], + Member[Literal["b1"], str, Never, Never, B], + Member[Literal["b2"], str, Never, Never, B], ) # Members @@ -2006,14 +2003,6 @@ def __init_subclass__[T]( ) -type AttrsAsList[T] = UpdateClass[ - *[Member[GetName[m], list[GetType[m]]] for m in Iter[Attrs[T]]] -] -type AttrsAsTuple[T] = UpdateClass[ - *[Member[GetName[m], tuple[GetType[m]]] for m in Iter[Attrs[T]]] -] - - def test_update_class_inheritance_02(): # __init_subclass__ calls follow normal MRO class A: @@ -2025,34 +2014,35 @@ def __init_subclass__[T]( super().__init_subclass__() class B(A): - b: int + b: bytes def __init_subclass__[T]( cls: type[T], - ) -> AttrsAsList[T]: + ) -> UpdateClass[ + *[Member[GetName[m], list[GetType[m]]] for m in Iter[Attrs[T]]] + ]: super().__init_subclass__() class C: - c: int + c: float def __init_subclass__[T]( cls: type[T], - ) -> AttrsAsTuple[T]: + ) -> UpdateClass[ + *[Member[GetName[m], tuple[GetType[m]]] for m in Iter[Attrs[T]]] + ]: super().__init_subclass__() class D(B, C): - d: int + d: bool attrs = eval_typing(Attrs[D]) # MRO = D, B, A, C, object - assert ( - attrs - == tuple[ - Member[Literal["c"], tuple[set[list[int]]], Never, Never, D], - Member[Literal["a"], tuple[set[list[int]]], Never, Never, D], - Member[Literal["b"], tuple[set[list[int]]], Never, Never, D], - Member[Literal["d"], tuple[set[list[int]]], Never, Never, D], - ] + assert attrs.__args__ == ( + Member[Literal["c"], tuple[set[list[float]]], Never, Never, D], + Member[Literal["a"], tuple[set[list[int]]], Never, Never, D], + Member[Literal["b"], tuple[set[list[bytes]]], Never, Never, D], + Member[Literal["d"], tuple[set[list[bool]]], Never, Never, D], ) diff --git a/typemap/type_eval/__init__.py b/typemap/type_eval/__init__.py index 7e11eb1..030362b 100644 --- a/typemap/type_eval/__init__.py +++ b/typemap/type_eval/__init__.py @@ -2,6 +2,7 @@ eval_typing, _get_current_context, register_evaluator, + StuckException, _EvalProxy, ) from ._apply_generic import flatten_class @@ -24,6 +25,7 @@ "flatten_class", "issubtype", "TypeMapError", + "StuckException", "_EvalProxy", "_get_current_context", ) diff --git a/typemap/type_eval/_apply_generic.py b/typemap/type_eval/_apply_generic.py index 44adf55..ea8b4d3 100644 --- a/typemap/type_eval/_apply_generic.py +++ b/typemap/type_eval/_apply_generic.py @@ -12,7 +12,7 @@ from . import _typing_inspect if typing.TYPE_CHECKING: - from typing import Any + from typing import Any, Mapping @dataclasses.dataclass(frozen=True) @@ -20,6 +20,9 @@ class Boxed: cls: type[Any] bases: list[Boxed] args: dict[Any, Any] + orig_cls: type[Any] | None = ( + None # Original class, before __init_subclass__ applied + ) str_args: dict[str, Any] = dataclasses.field(init=False) mro: tuple[Boxed, ...] = dataclasses.field(init=False) @@ -38,14 +41,22 @@ def __post_init__(self): object.__setattr__( self, "mro", - _compute_mro(self), + tuple(_compute_mro(self)), ) + @property + def canonical_cls(self): + """The class for the original boxing. + + (Possibly a new one was created after __init_subclass__ applied. + """ + return self.orig_cls or self.cls + def alias_type(self): if self.args: - return self.cls[*self.args.values()] + return self.canonical_cls[*self.args.values()] else: - return self.cls + return self.canonical_cls def __repr__(self): return f"Boxed<{self.cls} {self.args}>" @@ -194,7 +205,7 @@ def make_func( def get_annotations( obj: object, - args: dict[str, object], + args: Mapping[str, object], key: str = '__annotate__', cls: type | None = None, annos_ok: bool = True, @@ -222,7 +233,7 @@ def get_annotations( globs.update(vars(mod)) if isinstance(rr, dict) and any(isinstance(v, str) for v in rr.values()): - args = args.copy() + args = dict(args) # Copy in any __type_params__ that aren't provided for, so that if # we have to eval, we have them. if params := getattr(obj, "__type_params__", None): @@ -273,9 +284,18 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]: # TODO: This annos_ok thing is a hack because processing # __annotations__ on methods broke stuff and I didn't want # to chase it down yet. - if ( - rr := get_annotations(stuff, boxed.str_args, cls=boxed.cls, annos_ok=False) - ) is not None: + try: + rr = get_annotations( + stuff, boxed.str_args, cls=boxed.cls, annos_ok=False + ) + except _eval_typing.StuckException: + # TODO: Either generate a GenericCallable or a + # function with our own __annotate__ for this case + # where we can't even fetch the signature without + # trouble. + rr = None + + if rr is not None: local_fn = make_func(orig, rr) elif getattr(stuff, "__annotations__", None): # XXX: This is totally wrong; we still need to do diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 51cce0a..6c9e611 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -1,6 +1,7 @@ import collections import collections.abc import contextlib +import dataclasses import functools import inspect import itertools @@ -80,24 +81,34 @@ def _make_init_type(v): return typing.Literal[(v,)] +def cached_box(cls, *, ctx): + if str(cls).startswith('typemap.typing'): + return _apply_generic.box(cls) + if cls in ctx.box_cache: + return ctx.box_cache[cls] + ctx.box_cache[cls] = box = _apply_generic.box(cls) + assert box.mro + # if not all(b.mro for b in box.mro): + # breakpoint() + # assert all(b.mro for b in box.mro) + + if new_box := _eval_init_subclass(box, ctx): + ctx.box_cache[cls] = box = new_box + return box + + def get_annotated_type_hints(cls, *, ctx, **kwargs): """Get the type hints/quals for a cls annotated with definition site. This traverses the mro and finds the definition site for each annotation. """ - # TODO: Cache the box (slash don't need it??) - box = _apply_generic.box(cls) + box = cached_box(cls, ctx=ctx) hints = {} for abox in reversed(box.mro): acls = abox.alias_type() - if abox is box and (updated_cls := _eval_init_subclass(box, ctx)): - # For the class itself, apply all UpdateClass from - # ancesstors' __init_subclass__ to get the final type. - abox = _apply_generic.box(updated_cls) - annos, _ = _apply_generic.get_local_defns(abox) for k, ty in annos.items(): quals = set() @@ -128,18 +139,12 @@ def get_annotated_type_hints(cls, *, ctx, **kwargs): def get_annotated_method_hints(cls, *, ctx): - # TODO: Cache the box (slash don't need it??) - box = _apply_generic.box(cls) + box = cached_box(cls, ctx=ctx) hints = {} for abox in reversed(box.mro): acls = abox.alias_type() - if abox is box and (updated_cls := _eval_init_subclass(box, ctx)): - # For the class itself, apply all UpdateClass from - # ancesstors' __init_subclass__ to get the final type. - abox = _apply_generic.box(updated_cls) - _, dct = _apply_generic.get_local_defns(abox) for name, attr in dct.items(): if isinstance( @@ -166,25 +171,38 @@ def get_annotated_method_hints(cls, *, ctx): def _eval_init_subclass( box: _apply_generic.Boxed, ctx: typing.Any -) -> type | None: +) -> _apply_generic.Boxed: """Get type after all __init_subclass__ with UpdateClass are evaluated.""" - for abox in reversed(box.mro[1:]): # Skip the type itself - if ms := _get_update_class_members(box.cls, abox.alias_type(), ctx=ctx): - return _create_updated_class(box.cls, ms, ctx=ctx) - - return None + for abox in box.mro[1:]: # Skip the type itself + with _child_context() as ctx: + if ms := _get_update_class_members( + box.cls, abox.alias_type(), ctx=ctx + ): + nbox = _apply_generic.box( + _create_updated_class(box.cls, ms, ctx=ctx) + ) + # We want to preserve the original cls for Members output + box = dataclasses.replace(nbox, orig_cls=box.canonical_cls) + ctx.box_cache[box.cls] = box + return box def _get_update_class_members( cls: type, base: type, ctx: typing.Any ) -> list[Member] | None: - if ( - (init_subclass := base.__dict__.get("__init_subclass__")) - # XXX: We're using get_type_hints now to evaluate hints but - # we should have our own generic infrastructure instead. - # (I'm working on it -sully) - and (init_subclass_annos := typing.get_type_hints(init_subclass)) - and (ret_annotation := init_subclass_annos.get("return")) + init_subclass = base.__dict__.get("__init_subclass__") + if not init_subclass: + return None + init_subclass = inspect.unwrap(init_subclass) + + args = {} + if type_params := getattr(init_subclass, '__type_params__', None): + args[str(type_params[0])] = cls + + init_subclass_annos = _apply_generic.get_annotations(init_subclass, args) + + if init_subclass_annos and ( + ret_annotation := init_subclass_annos.get("return") ): # Substitute the cls type var with the current class # This may not happen if cls is not generic! @@ -210,12 +228,7 @@ def _get_update_class_members( ) # Evaluate the return annotation - # Do it in a child context, so the evaluations are isolated. For - # example, if the return annotation uses Attrs[MyClass], we want - # Attrs[MyClass] to be evaluated with the updated class, not the - # original. - with _child_context() as ctx: - evaled_ret = _eval_types(ret_annotation, ctx=ctx) + evaled_ret = _eval_types(ret_annotation, ctx=ctx) # If the result is an UpdateClass, return the members if ( diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index e340dff..87989ce 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -25,6 +25,12 @@ __all__ = ("eval_typing",) +class StuckException(Exception): + """Raised when a type operator receives a type variable argument.""" + + pass + + _eval_funcs: dict[type, typing.Callable[..., Any]] = {} @@ -101,6 +107,10 @@ class EvalContext: typing.TypeAliasType | types.GenericAlias, typing.Any ] = dataclasses.field(default_factory=dict) + box_cache: dict[typing.Any, _apply_generic.Boxed] = dataclasses.field( + default_factory=dict + ) + # The typing.Any is really a types.FunctionType, but mypy gets # confused and wants to treat it as a MethodType. current_generic_alias: types.GenericAlias | typing.Any | None = None @@ -123,7 +133,7 @@ def _ensure_context() -> typing.Iterator[EvalContext]: _current_context.set(ctx) ctx_set = True evaluator_token = nt.special_form_evaluator.set( - lambda t: _eval_types(t, ctx) + lambda t: _eval_types(t, _current_context.get()) # type: ignore[arg-type] ) try: @@ -168,6 +178,7 @@ def _child_context() -> typing.Iterator[EvalContext]: recursive_type_alias=ctx.recursive_type_alias, known_recursive_types=ctx.known_recursive_types.copy(), current_generic_alias=ctx.current_generic_alias, + box_cache=ctx.box_cache, # Not copied! ) _current_context.set(child_ctx) yield child_ctx @@ -394,6 +405,13 @@ def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext): new_args = _eval_args(typing.get_args(obj), ctx) if func := _eval_funcs.get(obj.__origin__): + _tvars = ( + typing.TypeVar, + typing.ParamSpec, + typing.TypeVarTuple, + ) + if any(isinstance(a, _tvars) for a in new_args): + raise StuckException(obj) ret = func(*new_args, ctx=ctx) # return _eval_types(ret, ctx) # ??? return ret