Skip to content

Commit 3f31f0c

Browse files
committed
Make __init_subclass__ work with nontrivial RHSes
Basically the idea is to evaluate it with the type variable substituted in. I had to rework the caching for boxes to prevent infinite recursion. It's still a TODO to properly report out GenericCallables for methods with nontrivial computation at the top level, but the basic idea is to *try* to get the annotation and raise a StuckException if we would need to use an operator on a variable while evaluating a Bool or Iter. Then if the exception got raised, we will produce a GenericCallable.
1 parent 9b0ddec commit 3f31f0c

File tree

5 files changed

+115
-72
lines changed

5 files changed

+115
-72
lines changed

tests/test_type_eval.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,14 +1707,11 @@ def g(self) -> int: ... # omitted
17071707

17081708
# Attrs
17091709
attrs = eval_typing(Attrs[B])
1710-
assert (
1711-
attrs
1712-
== tuple[
1713-
Member[Literal["a1"], int, Never, Never, A],
1714-
Member[Literal["a2"], str, Never, Never, B],
1715-
Member[Literal["b1"], str, Never, Never, B],
1716-
Member[Literal["b2"], str, Never, Never, B],
1717-
]
1710+
assert attrs.__args__ == (
1711+
Member[Literal["a1"], int, Never, Never, A],
1712+
Member[Literal["a2"], str, Never, Never, B],
1713+
Member[Literal["b1"], str, Never, Never, B],
1714+
Member[Literal["b2"], str, Never, Never, B],
17181715
)
17191716

17201717
# Members
@@ -2006,14 +2003,6 @@ def __init_subclass__[T](
20062003
)
20072004

20082005

2009-
type AttrsAsList[T] = UpdateClass[
2010-
*[Member[GetName[m], list[GetType[m]]] for m in Iter[Attrs[T]]]
2011-
]
2012-
type AttrsAsTuple[T] = UpdateClass[
2013-
*[Member[GetName[m], tuple[GetType[m]]] for m in Iter[Attrs[T]]]
2014-
]
2015-
2016-
20172006
def test_update_class_inheritance_02():
20182007
# __init_subclass__ calls follow normal MRO
20192008
class A:
@@ -2025,34 +2014,35 @@ def __init_subclass__[T](
20252014
super().__init_subclass__()
20262015

20272016
class B(A):
2028-
b: int
2017+
b: bytes
20292018

20302019
def __init_subclass__[T](
20312020
cls: type[T],
2032-
) -> AttrsAsList[T]:
2021+
) -> UpdateClass[
2022+
*[Member[GetName[m], list[GetType[m]]] for m in Iter[Attrs[T]]]
2023+
]:
20332024
super().__init_subclass__()
20342025

20352026
class C:
2036-
c: int
2027+
c: float
20372028

20382029
def __init_subclass__[T](
20392030
cls: type[T],
2040-
) -> AttrsAsTuple[T]:
2031+
) -> UpdateClass[
2032+
*[Member[GetName[m], tuple[GetType[m]]] for m in Iter[Attrs[T]]]
2033+
]:
20412034
super().__init_subclass__()
20422035

20432036
class D(B, C):
2044-
d: int
2037+
d: bool
20452038

20462039
attrs = eval_typing(Attrs[D])
20472040
# MRO = D, B, A, C, object
2048-
assert (
2049-
attrs
2050-
== tuple[
2051-
Member[Literal["c"], tuple[set[list[int]]], Never, Never, D],
2052-
Member[Literal["a"], tuple[set[list[int]]], Never, Never, D],
2053-
Member[Literal["b"], tuple[set[list[int]]], Never, Never, D],
2054-
Member[Literal["d"], tuple[set[list[int]]], Never, Never, D],
2055-
]
2041+
assert attrs.__args__ == (
2042+
Member[Literal["c"], tuple[set[list[float]]], Never, Never, D],
2043+
Member[Literal["a"], tuple[set[list[int]]], Never, Never, D],
2044+
Member[Literal["b"], tuple[set[list[bytes]]], Never, Never, D],
2045+
Member[Literal["d"], tuple[set[list[bool]]], Never, Never, D],
20562046
)
20572047

20582048

typemap/type_eval/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
eval_typing,
33
_get_current_context,
44
register_evaluator,
5+
StuckException,
56
_EvalProxy,
67
)
78
from ._apply_generic import flatten_class
@@ -24,6 +25,7 @@
2425
"flatten_class",
2526
"issubtype",
2627
"TypeMapError",
28+
"StuckException",
2729
"_EvalProxy",
2830
"_get_current_context",
2931
)

typemap/type_eval/_apply_generic.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212
from . import _typing_inspect
1313

1414
if typing.TYPE_CHECKING:
15-
from typing import Any
15+
from typing import Any, Mapping
1616

1717

1818
@dataclasses.dataclass(frozen=True)
1919
class Boxed:
2020
cls: type[Any]
2121
bases: list[Boxed]
2222
args: dict[Any, Any]
23+
orig_cls: type[Any] | None = (
24+
None # Original class, before __init_subclass__ applied
25+
)
2326

2427
str_args: dict[str, Any] = dataclasses.field(init=False)
2528
mro: tuple[Boxed, ...] = dataclasses.field(init=False)
@@ -38,14 +41,22 @@ def __post_init__(self):
3841
object.__setattr__(
3942
self,
4043
"mro",
41-
_compute_mro(self),
44+
tuple(_compute_mro(self)),
4245
)
4346

47+
@property
48+
def canonical_cls(self):
49+
"""The class for the original boxing.
50+
51+
(Possibly a new one was created after __init_subclass__ applied.
52+
"""
53+
return self.orig_cls or self.cls
54+
4455
def alias_type(self):
4556
if self.args:
46-
return self.cls[*self.args.values()]
57+
return self.canonical_cls[*self.args.values()]
4758
else:
48-
return self.cls
59+
return self.canonical_cls
4960

5061
def __repr__(self):
5162
return f"Boxed<{self.cls} {self.args}>"
@@ -194,7 +205,7 @@ def make_func(
194205

195206
def get_annotations(
196207
obj: object,
197-
args: dict[str, object],
208+
args: Mapping[str, object],
198209
key: str = '__annotate__',
199210
cls: type | None = None,
200211
annos_ok: bool = True,
@@ -222,7 +233,7 @@ def get_annotations(
222233
globs.update(vars(mod))
223234

224235
if isinstance(rr, dict) and any(isinstance(v, str) for v in rr.values()):
225-
args = args.copy()
236+
args = dict(args)
226237
# Copy in any __type_params__ that aren't provided for, so that if
227238
# we have to eval, we have them.
228239
if params := getattr(obj, "__type_params__", None):
@@ -273,9 +284,18 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
273284
# TODO: This annos_ok thing is a hack because processing
274285
# __annotations__ on methods broke stuff and I didn't want
275286
# to chase it down yet.
276-
if (
277-
rr := get_annotations(stuff, boxed.str_args, cls=boxed.cls, annos_ok=False)
278-
) is not None:
287+
try:
288+
rr = get_annotations(
289+
stuff, boxed.str_args, cls=boxed.cls, annos_ok=False
290+
)
291+
except _eval_typing.StuckException:
292+
# TODO: Either generate a GenericCallable or a
293+
# function with our own __annotate__ for this case
294+
# where we can't even fetch the signature without
295+
# trouble.
296+
rr = None
297+
298+
if rr is not None:
279299
local_fn = make_func(orig, rr)
280300
elif getattr(stuff, "__annotations__", None):
281301
# XXX: This is totally wrong; we still need to do

typemap/type_eval/_eval_operators.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections
22
import collections.abc
33
import contextlib
4+
import dataclasses
45
import functools
56
import inspect
67
import itertools
@@ -80,24 +81,34 @@ def _make_init_type(v):
8081
return typing.Literal[(v,)]
8182

8283

84+
def cached_box(cls, *, ctx):
85+
if str(cls).startswith('typemap.typing'):
86+
return _apply_generic.box(cls)
87+
if cls in ctx.box_cache:
88+
return ctx.box_cache[cls]
89+
ctx.box_cache[cls] = box = _apply_generic.box(cls)
90+
assert box.mro
91+
# if not all(b.mro for b in box.mro):
92+
# breakpoint()
93+
# assert all(b.mro for b in box.mro)
94+
95+
if new_box := _eval_init_subclass(box, ctx):
96+
ctx.box_cache[cls] = box = new_box
97+
return box
98+
99+
83100
def get_annotated_type_hints(cls, *, ctx, **kwargs):
84101
"""Get the type hints/quals for a cls annotated with definition site.
85102
86103
This traverses the mro and finds the definition site for each annotation.
87104
"""
88105

89-
# TODO: Cache the box (slash don't need it??)
90-
box = _apply_generic.box(cls)
106+
box = cached_box(cls, ctx=ctx)
91107

92108
hints = {}
93109
for abox in reversed(box.mro):
94110
acls = abox.alias_type()
95111

96-
if abox is box and (updated_cls := _eval_init_subclass(box, ctx)):
97-
# For the class itself, apply all UpdateClass from
98-
# ancesstors' __init_subclass__ to get the final type.
99-
abox = _apply_generic.box(updated_cls)
100-
101112
annos, _ = _apply_generic.get_local_defns(abox)
102113
for k, ty in annos.items():
103114
quals = set()
@@ -128,18 +139,12 @@ def get_annotated_type_hints(cls, *, ctx, **kwargs):
128139

129140

130141
def get_annotated_method_hints(cls, *, ctx):
131-
# TODO: Cache the box (slash don't need it??)
132-
box = _apply_generic.box(cls)
142+
box = cached_box(cls, ctx=ctx)
133143

134144
hints = {}
135145
for abox in reversed(box.mro):
136146
acls = abox.alias_type()
137147

138-
if abox is box and (updated_cls := _eval_init_subclass(box, ctx)):
139-
# For the class itself, apply all UpdateClass from
140-
# ancesstors' __init_subclass__ to get the final type.
141-
abox = _apply_generic.box(updated_cls)
142-
143148
_, dct = _apply_generic.get_local_defns(abox)
144149
for name, attr in dct.items():
145150
if isinstance(
@@ -166,25 +171,38 @@ def get_annotated_method_hints(cls, *, ctx):
166171

167172
def _eval_init_subclass(
168173
box: _apply_generic.Boxed, ctx: typing.Any
169-
) -> type | None:
174+
) -> _apply_generic.Boxed:
170175
"""Get type after all __init_subclass__ with UpdateClass are evaluated."""
171-
for abox in reversed(box.mro[1:]): # Skip the type itself
172-
if ms := _get_update_class_members(box.cls, abox.alias_type(), ctx=ctx):
173-
return _create_updated_class(box.cls, ms, ctx=ctx)
174-
175-
return None
176+
for abox in box.mro[1:]: # Skip the type itself
177+
with _child_context() as ctx:
178+
if ms := _get_update_class_members(
179+
box.cls, abox.alias_type(), ctx=ctx
180+
):
181+
nbox = _apply_generic.box(
182+
_create_updated_class(box.cls, ms, ctx=ctx)
183+
)
184+
# We want to preserve the original cls for Members output
185+
box = dataclasses.replace(nbox, orig_cls=box.canonical_cls)
186+
ctx.box_cache[box.cls] = box
187+
return box
176188

177189

178190
def _get_update_class_members(
179191
cls: type, base: type, ctx: typing.Any
180192
) -> list[Member] | None:
181-
if (
182-
(init_subclass := base.__dict__.get("__init_subclass__"))
183-
# XXX: We're using get_type_hints now to evaluate hints but
184-
# we should have our own generic infrastructure instead.
185-
# (I'm working on it -sully)
186-
and (init_subclass_annos := typing.get_type_hints(init_subclass))
187-
and (ret_annotation := init_subclass_annos.get("return"))
193+
init_subclass = base.__dict__.get("__init_subclass__")
194+
if not init_subclass:
195+
return None
196+
init_subclass = inspect.unwrap(init_subclass)
197+
198+
args = {}
199+
if type_params := getattr(init_subclass, '__type_params__', None):
200+
args[str(type_params[0])] = cls
201+
202+
init_subclass_annos = _apply_generic.get_annotations(init_subclass, args)
203+
204+
if init_subclass_annos and (
205+
ret_annotation := init_subclass_annos.get("return")
188206
):
189207
# Substitute the cls type var with the current class
190208
# This may not happen if cls is not generic!
@@ -210,12 +228,7 @@ def _get_update_class_members(
210228
)
211229

212230
# Evaluate the return annotation
213-
# Do it in a child context, so the evaluations are isolated. For
214-
# example, if the return annotation uses Attrs[MyClass], we want
215-
# Attrs[MyClass] to be evaluated with the updated class, not the
216-
# original.
217-
with _child_context() as ctx:
218-
evaled_ret = _eval_types(ret_annotation, ctx=ctx)
231+
evaled_ret = _eval_types(ret_annotation, ctx=ctx)
219232

220233
# If the result is an UpdateClass, return the members
221234
if (

typemap/type_eval/_eval_typing.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
__all__ = ("eval_typing",)
2626

2727

28+
class StuckException(Exception):
29+
"""Raised when a type operator receives a type variable argument."""
30+
31+
pass
32+
33+
2834
_eval_funcs: dict[type, typing.Callable[..., Any]] = {}
2935

3036

@@ -101,6 +107,10 @@ class EvalContext:
101107
typing.TypeAliasType | types.GenericAlias, typing.Any
102108
] = dataclasses.field(default_factory=dict)
103109

110+
box_cache: dict[typing.Any, _apply_generic.Boxed] = dataclasses.field(
111+
default_factory=dict
112+
)
113+
104114
# The typing.Any is really a types.FunctionType, but mypy gets
105115
# confused and wants to treat it as a MethodType.
106116
current_generic_alias: types.GenericAlias | typing.Any | None = None
@@ -123,7 +133,7 @@ def _ensure_context() -> typing.Iterator[EvalContext]:
123133
_current_context.set(ctx)
124134
ctx_set = True
125135
evaluator_token = nt.special_form_evaluator.set(
126-
lambda t: _eval_types(t, ctx)
136+
lambda t: _eval_types(t, _current_context.get()) # type: ignore[arg-type]
127137
)
128138

129139
try:
@@ -168,6 +178,7 @@ def _child_context() -> typing.Iterator[EvalContext]:
168178
recursive_type_alias=ctx.recursive_type_alias,
169179
known_recursive_types=ctx.known_recursive_types.copy(),
170180
current_generic_alias=ctx.current_generic_alias,
181+
box_cache=ctx.box_cache, # Not copied!
171182
)
172183
_current_context.set(child_ctx)
173184
yield child_ctx
@@ -394,6 +405,13 @@ def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext):
394405
new_args = _eval_args(typing.get_args(obj), ctx)
395406

396407
if func := _eval_funcs.get(obj.__origin__):
408+
_tvars = (
409+
typing.TypeVar,
410+
typing.ParamSpec,
411+
typing.TypeVarTuple,
412+
)
413+
if any(isinstance(a, _tvars) for a in new_args):
414+
raise StuckException(obj)
397415
ret = func(*new_args, ctx=ctx)
398416
# return _eval_types(ret, ctx) # ???
399417
return ret

0 commit comments

Comments
 (0)