Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 19 additions & 29 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1707,14 +1707,11 @@ def g(self) -> int: ... # omitted

# Attrs
attrs = eval_typing(Attrs[B])
assert (
attrs
== tuple[
Member[Literal["a1"], int, Never, Never, A],
Member[Literal["a2"], str, Never, Never, B],
Member[Literal["b1"], str, Never, Never, B],
Member[Literal["b2"], str, Never, Never, B],
]
assert attrs.__args__ == (
Member[Literal["a1"], int, Never, Never, A],
Member[Literal["a2"], str, Never, Never, B],
Member[Literal["b1"], str, Never, Never, B],
Member[Literal["b2"], str, Never, Never, B],
)

# Members
Expand Down Expand Up @@ -2006,14 +2003,6 @@ def __init_subclass__[T](
)


type AttrsAsList[T] = UpdateClass[
*[Member[GetName[m], list[GetType[m]]] for m in Iter[Attrs[T]]]
]
type AttrsAsTuple[T] = UpdateClass[
*[Member[GetName[m], tuple[GetType[m]]] for m in Iter[Attrs[T]]]
]


def test_update_class_inheritance_02():
# __init_subclass__ calls follow normal MRO
class A:
Expand All @@ -2025,34 +2014,35 @@ def __init_subclass__[T](
super().__init_subclass__()

class B(A):
b: int
b: bytes

def __init_subclass__[T](
cls: type[T],
) -> AttrsAsList[T]:
) -> UpdateClass[
*[Member[GetName[m], list[GetType[m]]] for m in Iter[Attrs[T]]]
]:
super().__init_subclass__()

class C:
c: int
c: float

def __init_subclass__[T](
cls: type[T],
) -> AttrsAsTuple[T]:
) -> UpdateClass[
*[Member[GetName[m], tuple[GetType[m]]] for m in Iter[Attrs[T]]]
]:
super().__init_subclass__()

class D(B, C):
d: int
d: bool

attrs = eval_typing(Attrs[D])
# MRO = D, B, A, C, object
assert (
attrs
== tuple[
Member[Literal["c"], tuple[set[list[int]]], Never, Never, D],
Member[Literal["a"], tuple[set[list[int]]], Never, Never, D],
Member[Literal["b"], tuple[set[list[int]]], Never, Never, D],
Member[Literal["d"], tuple[set[list[int]]], Never, Never, D],
]
assert attrs.__args__ == (
Member[Literal["c"], tuple[set[list[float]]], Never, Never, D],
Member[Literal["a"], tuple[set[list[int]]], Never, Never, D],
Member[Literal["b"], tuple[set[list[bytes]]], Never, Never, D],
Member[Literal["d"], tuple[set[list[bool]]], Never, Never, D],
)


Expand Down
2 changes: 2 additions & 0 deletions typemap/type_eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
eval_typing,
_get_current_context,
register_evaluator,
StuckException,
_EvalProxy,
)
from ._apply_generic import flatten_class
Expand All @@ -24,6 +25,7 @@
"flatten_class",
"issubtype",
"TypeMapError",
"StuckException",
"_EvalProxy",
"_get_current_context",
)
38 changes: 29 additions & 9 deletions typemap/type_eval/_apply_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@
from . import _typing_inspect

if typing.TYPE_CHECKING:
from typing import Any
from typing import Any, Mapping


@dataclasses.dataclass(frozen=True)
class Boxed:
cls: type[Any]
bases: list[Boxed]
args: dict[Any, Any]
orig_cls: type[Any] | None = (
None # Original class, before __init_subclass__ applied
)

str_args: dict[str, Any] = dataclasses.field(init=False)
mro: tuple[Boxed, ...] = dataclasses.field(init=False)
Expand All @@ -38,14 +41,22 @@ def __post_init__(self):
object.__setattr__(
self,
"mro",
_compute_mro(self),
tuple(_compute_mro(self)),
)

@property
def canonical_cls(self):
"""The class for the original boxing.

(Possibly a new one was created after __init_subclass__ applied.
"""
return self.orig_cls or self.cls

def alias_type(self):
if self.args:
return self.cls[*self.args.values()]
return self.canonical_cls[*self.args.values()]
else:
return self.cls
return self.canonical_cls

def __repr__(self):
return f"Boxed<{self.cls} {self.args}>"
Expand Down Expand Up @@ -194,7 +205,7 @@ def make_func(

def get_annotations(
obj: object,
args: dict[str, object],
args: Mapping[str, object],
key: str = '__annotate__',
cls: type | None = None,
annos_ok: bool = True,
Expand Down Expand Up @@ -222,7 +233,7 @@ def get_annotations(
globs.update(vars(mod))

if isinstance(rr, dict) and any(isinstance(v, str) for v in rr.values()):
args = args.copy()
args = dict(args)
# Copy in any __type_params__ that aren't provided for, so that if
# we have to eval, we have them.
if params := getattr(obj, "__type_params__", None):
Expand Down Expand Up @@ -273,9 +284,18 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
# TODO: This annos_ok thing is a hack because processing
# __annotations__ on methods broke stuff and I didn't want
# to chase it down yet.
if (
rr := get_annotations(stuff, boxed.str_args, cls=boxed.cls, annos_ok=False)
) is not None:
try:
rr = get_annotations(
stuff, boxed.str_args, cls=boxed.cls, annos_ok=False
)
except _eval_typing.StuckException:
# TODO: Either generate a GenericCallable or a
# function with our own __annotate__ for this case
# where we can't even fetch the signature without
# trouble.
rr = None

if rr is not None:
local_fn = make_func(orig, rr)
elif getattr(stuff, "__annotations__", None):
# XXX: This is totally wrong; we still need to do
Expand Down
79 changes: 46 additions & 33 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import collections.abc
import contextlib
import dataclasses
import functools
import inspect
import itertools
Expand Down Expand Up @@ -80,24 +81,34 @@ def _make_init_type(v):
return typing.Literal[(v,)]


def cached_box(cls, *, ctx):
if str(cls).startswith('typemap.typing'):
return _apply_generic.box(cls)
if cls in ctx.box_cache:
return ctx.box_cache[cls]
ctx.box_cache[cls] = box = _apply_generic.box(cls)
assert box.mro
# if not all(b.mro for b in box.mro):
# breakpoint()
# assert all(b.mro for b in box.mro)

if new_box := _eval_init_subclass(box, ctx):
ctx.box_cache[cls] = box = new_box
return box


def get_annotated_type_hints(cls, *, ctx, **kwargs):
"""Get the type hints/quals for a cls annotated with definition site.

This traverses the mro and finds the definition site for each annotation.
"""

# TODO: Cache the box (slash don't need it??)
box = _apply_generic.box(cls)
box = cached_box(cls, ctx=ctx)

hints = {}
for abox in reversed(box.mro):
acls = abox.alias_type()

if abox is box and (updated_cls := _eval_init_subclass(box, ctx)):
# For the class itself, apply all UpdateClass from
# ancesstors' __init_subclass__ to get the final type.
abox = _apply_generic.box(updated_cls)

annos, _ = _apply_generic.get_local_defns(abox)
for k, ty in annos.items():
quals = set()
Expand Down Expand Up @@ -128,18 +139,12 @@ def get_annotated_type_hints(cls, *, ctx, **kwargs):


def get_annotated_method_hints(cls, *, ctx):
# TODO: Cache the box (slash don't need it??)
box = _apply_generic.box(cls)
box = cached_box(cls, ctx=ctx)

hints = {}
for abox in reversed(box.mro):
acls = abox.alias_type()

if abox is box and (updated_cls := _eval_init_subclass(box, ctx)):
# For the class itself, apply all UpdateClass from
# ancesstors' __init_subclass__ to get the final type.
abox = _apply_generic.box(updated_cls)

_, dct = _apply_generic.get_local_defns(abox)
for name, attr in dct.items():
if isinstance(
Expand All @@ -166,25 +171,38 @@ def get_annotated_method_hints(cls, *, ctx):

def _eval_init_subclass(
box: _apply_generic.Boxed, ctx: typing.Any
) -> type | None:
) -> _apply_generic.Boxed:
"""Get type after all __init_subclass__ with UpdateClass are evaluated."""
for abox in reversed(box.mro[1:]): # Skip the type itself
if ms := _get_update_class_members(box.cls, abox.alias_type(), ctx=ctx):
return _create_updated_class(box.cls, ms, ctx=ctx)

return None
for abox in box.mro[1:]: # Skip the type itself
with _child_context() as ctx:
if ms := _get_update_class_members(
box.cls, abox.alias_type(), ctx=ctx
):
nbox = _apply_generic.box(
_create_updated_class(box.cls, ms, ctx=ctx)
)
# We want to preserve the original cls for Members output
box = dataclasses.replace(nbox, orig_cls=box.canonical_cls)
ctx.box_cache[box.cls] = box
return box


def _get_update_class_members(
cls: type, base: type, ctx: typing.Any
) -> list[Member] | None:
if (
(init_subclass := base.__dict__.get("__init_subclass__"))
# XXX: We're using get_type_hints now to evaluate hints but
# we should have our own generic infrastructure instead.
# (I'm working on it -sully)
and (init_subclass_annos := typing.get_type_hints(init_subclass))
and (ret_annotation := init_subclass_annos.get("return"))
init_subclass = base.__dict__.get("__init_subclass__")
if not init_subclass:
return None
init_subclass = inspect.unwrap(init_subclass)

args = {}
if type_params := getattr(init_subclass, '__type_params__', None):
args[str(type_params[0])] = cls

init_subclass_annos = _apply_generic.get_annotations(init_subclass, args)

if init_subclass_annos and (
ret_annotation := init_subclass_annos.get("return")
):
# Substitute the cls type var with the current class
# This may not happen if cls is not generic!
Expand All @@ -210,12 +228,7 @@ def _get_update_class_members(
)

# Evaluate the return annotation
# Do it in a child context, so the evaluations are isolated. For
# example, if the return annotation uses Attrs[MyClass], we want
# Attrs[MyClass] to be evaluated with the updated class, not the
# original.
with _child_context() as ctx:
evaled_ret = _eval_types(ret_annotation, ctx=ctx)
evaled_ret = _eval_types(ret_annotation, ctx=ctx)

# If the result is an UpdateClass, return the members
if (
Expand Down
20 changes: 19 additions & 1 deletion typemap/type_eval/_eval_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
__all__ = ("eval_typing",)


class StuckException(Exception):
"""Raised when a type operator receives a type variable argument."""

pass


_eval_funcs: dict[type, typing.Callable[..., Any]] = {}


Expand Down Expand Up @@ -101,6 +107,10 @@ class EvalContext:
typing.TypeAliasType | types.GenericAlias, typing.Any
] = dataclasses.field(default_factory=dict)

box_cache: dict[typing.Any, _apply_generic.Boxed] = dataclasses.field(
default_factory=dict
)

# The typing.Any is really a types.FunctionType, but mypy gets
# confused and wants to treat it as a MethodType.
current_generic_alias: types.GenericAlias | typing.Any | None = None
Expand All @@ -123,7 +133,7 @@ def _ensure_context() -> typing.Iterator[EvalContext]:
_current_context.set(ctx)
ctx_set = True
evaluator_token = nt.special_form_evaluator.set(
lambda t: _eval_types(t, ctx)
lambda t: _eval_types(t, _current_context.get()) # type: ignore[arg-type]
)

try:
Expand Down Expand Up @@ -168,6 +178,7 @@ def _child_context() -> typing.Iterator[EvalContext]:
recursive_type_alias=ctx.recursive_type_alias,
known_recursive_types=ctx.known_recursive_types.copy(),
current_generic_alias=ctx.current_generic_alias,
box_cache=ctx.box_cache, # Not copied!
)
_current_context.set(child_ctx)
yield child_ctx
Expand Down Expand Up @@ -394,6 +405,13 @@ def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext):
new_args = _eval_args(typing.get_args(obj), ctx)

if func := _eval_funcs.get(obj.__origin__):
_tvars = (
typing.TypeVar,
typing.ParamSpec,
typing.TypeVarTuple,
)
if any(isinstance(a, _tvars) for a in new_args):
raise StuckException(obj)
ret = func(*new_args, ctx=ctx)
# return _eval_types(ret, ctx) # ???
return ret
Expand Down