Skip to content

Commit cfc421e

Browse files
committed
Build GenericCallables when we can't evaluate a method's params
The lambda we stick into the GenericCallable will actually go evaluate the type annotations, then produce a Callable.
1 parent 122c640 commit cfc421e

File tree

5 files changed

+119
-20
lines changed

5 files changed

+119
-20
lines changed

tests/test_type_dir.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import typing
33
from typing import Literal, Never, TypeVar, TypedDict, Union, ReadOnly
44

5-
from typemap.type_eval import eval_typing
5+
from typemap.type_eval import eval_typing, _ensure_context
66
from typemap_extensions import (
77
Attrs,
88
FromUnion,
@@ -323,6 +323,25 @@ class Last[bool]:
323323
""")
324324

325325

326+
def test_type_dir_10():
327+
class Lurr:
328+
def foo[T](x: T) -> int if IsAssignable[T, str] else list[int]: ...
329+
330+
d = eval_typing(Lurr)
331+
332+
assert format_helper.format_class(d) == textwrap.dedent("""\
333+
class Lurr:
334+
foo: typing.ClassVar[typemap.typing.GenericCallable[tuple[T], <...>]]
335+
""")
336+
337+
member = _get_member(eval_typing(Members[Lurr]), "foo")
338+
339+
fn = member.__args__[1].__args__[1]
340+
with _ensure_context():
341+
assert fn(str).__args__[1] is int
342+
assert fn(bool).__args__[1] == list[int]
343+
344+
326345
def test_type_dir_get_arg_1():
327346
d = eval_typing(BaseArg[Final])
328347
assert d is int
@@ -405,10 +424,7 @@ def test_type_members_func_3():
405424
assert name == typing.Literal["sbase"]
406425
assert quals == typing.Literal["ClassVar"]
407426

408-
assert (
409-
str(typ)
410-
== "typemap.typing.GenericCallable[tuple[Z], typemap.type_eval._eval_operators._create_generic_callable_lambda.<locals>.<lambda>]"
411-
)
427+
assert str(typ) == "typemap.typing.GenericCallable[tuple[Z], <...>]"
412428

413429
evaled = eval_typing(
414430
typing.get_args(typ)[1](*typing.get_args(typing.get_args(typ)[0]))

typemap/type_eval/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ._eval_typing import (
22
eval_typing,
33
_get_current_context,
4+
_ensure_context,
45
register_evaluator,
56
StuckException,
67
_EvalProxy,
@@ -28,4 +29,5 @@
2829
"StuckException",
2930
"_EvalProxy",
3031
"_get_current_context",
32+
"_ensure_context",
3133
)

typemap/type_eval/_apply_generic.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from . import _eval_typing
1212
from . import _typing_inspect
1313

14+
1415
if typing.TYPE_CHECKING:
1516
from typing import Any, Mapping
1617

@@ -255,7 +256,34 @@ def get_annotations(
255256
return rr
256257

257258

259+
def _resolved_function_signature(func, args):
260+
"""Get the signature of a function with type hints resolved to arg values"""
261+
262+
import typemap.typing as nt
263+
264+
token = nt.special_form_evaluator.set(None)
265+
try:
266+
sig = inspect.signature(func)
267+
finally:
268+
nt.special_form_evaluator.reset(token)
269+
270+
if hints := get_annotations(func, args):
271+
params = []
272+
for name, param in sig.parameters.items():
273+
annotation = hints.get(name, param.annotation)
274+
params.append(param.replace(annotation=annotation))
275+
276+
return_annotation = hints.get("return", sig.return_annotation)
277+
sig = sig.replace(
278+
parameters=params, return_annotation=return_annotation
279+
)
280+
281+
return sig
282+
283+
258284
def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
285+
from typemap.typing import GenericCallable
286+
259287
annos: dict[str, Any] = {}
260288
dct: dict[str, Any] = {}
261289

@@ -274,23 +302,52 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
274302
# TODO: This annos_ok thing is a hack because processing
275303
# __annotations__ on methods broke stuff and I didn't want
276304
# to chase it down yet.
305+
stuck = False
277306
try:
278307
rr = get_annotations(stuff, boxed.str_args, annos_ok=False)
279308
except _eval_typing.StuckException:
280-
# TODO: Either generate a GenericCallable or a
281-
# function with our own __annotate__ for this case
282-
# where we can't even fetch the signature without
283-
# trouble.
309+
stuck = True
284310
rr = None
285311

286312
if rr is not None:
287313
local_fn = make_func(orig, rr)
288-
elif getattr(stuff, "__annotations__", None):
314+
elif not stuck and getattr(stuff, "__annotations__", None):
289315
# XXX: This is totally wrong; we still need to do
290316
# substitute in class vars
291317
local_fn = stuff
292318

293-
if local_fn is not None:
319+
# If we got stuck, we build a GenericCallable that
320+
# computes the type once it has been given type
321+
# variables!
322+
if stuck and stuff.__type_params__:
323+
type_params = stuff.__type_params__
324+
str_args = boxed.str_args
325+
326+
def _make_lambda(fn, o, sa, tp):
327+
from ._eval_operators import _function_type_from_sig
328+
329+
def lam(*vs):
330+
args = dict(sa)
331+
args.update(
332+
zip(
333+
(str(p) for p in tp),
334+
vs,
335+
strict=True,
336+
)
337+
)
338+
sig = _resolved_function_signature(fn, args)
339+
return _function_type_from_sig(
340+
sig, o, receiver_type=None
341+
)
342+
343+
return lam
344+
345+
gc = GenericCallable[ # type: ignore[valid-type,misc]
346+
tuple[*type_params], # type: ignore[valid-type]
347+
_make_lambda(stuff, orig, str_args, type_params),
348+
]
349+
annos[name] = typing.ClassVar[gc]
350+
elif local_fn is not None:
294351
if orig.__class__ is classmethod:
295352
local_fn = classmethod(local_fn)
296353
elif orig.__class__ is staticmethod:

typemap/type_eval/_eval_operators.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def cached_box(cls, *, ctx):
9797
return box
9898

9999

100-
def get_annotated_type_hints(cls, *, ctx, **kwargs):
100+
def get_annotated_type_hints(cls, *, ctx, attrs_only=False, **kwargs):
101101
"""Get the type hints/quals for a cls annotated with definition site.
102102
103103
This traverses the mro and finds the definition site for each annotation.
@@ -127,6 +127,10 @@ def get_annotated_type_hints(cls, *, ctx, **kwargs):
127127
else:
128128
break
129129

130+
# Skip method-like ClassVars when only attributes are wanted
131+
if attrs_only and "ClassVar" in quals and _is_method_like(ty):
132+
continue
133+
130134
if k in abox.cls.__dict__:
131135
# Wrap in tuple when creating Literal in case it *is* a tuple
132136
init = _make_init_type(abox.cls.__dict__[k])
@@ -650,11 +654,7 @@ def _callable_type_to_method(name, typ, ctx):
650654
return head(func)
651655

652656

653-
def _function_type(func, *, receiver_type):
654-
root = inspect.unwrap(func)
655-
sig = inspect.signature(root)
656-
# XXX: __type_params__!!!
657-
657+
def _function_type_from_sig(sig, func, *, receiver_type):
658658
empty = inspect.Parameter.empty
659659

660660
def _ann(x):
@@ -702,6 +702,15 @@ def _ann(x):
702702
f = classmethod[specified_receiver, tuple[*params[1:]], ret]
703703
else:
704704
f = typing.Callable[params, ret]
705+
706+
return f
707+
708+
709+
def _function_type(func, *, receiver_type):
710+
root = inspect.unwrap(func)
711+
sig = inspect.signature(root)
712+
f = _function_type_from_sig(sig, func, receiver_type=receiver_type)
713+
705714
if root.__type_params__:
706715
# Must store a lambda that performs type variable substitution
707716
type_params = root.__type_params__
@@ -757,7 +766,9 @@ def _hints_to_members(hints, ctx):
757766
@type_eval.register_evaluator(Attrs)
758767
@_lift_over_unions
759768
def _eval_Attrs(tp, *, ctx):
760-
hints = get_annotated_type_hints(tp, include_extras=True, ctx=ctx)
769+
hints = get_annotated_type_hints(
770+
tp, include_extras=True, attrs_only=True, ctx=ctx
771+
)
761772
return _hints_to_members(hints, ctx)
762773

763774

@@ -1185,7 +1196,10 @@ def _eval_NewProtocol(*etyps: Member, ctx):
11851196
if type_eval.issubtype(
11861197
typing.Literal["ClassVar"], tquals
11871198
) and _is_method_like(typ):
1188-
dct[name] = _callable_type_to_method(name, typ, ctx)
1199+
try:
1200+
dct[name] = _callable_type_to_method(name, typ, ctx)
1201+
except type_eval.StuckException:
1202+
annos[name] = _add_quals(typ, tquals)
11891203
else:
11901204
annos[name] = _add_quals(typ, tquals)
11911205
_unpack_init(dct, name, init)

typemap/typing.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,17 @@ class SpecialFormEllipsis:
7272

7373

7474
class _GenericCallableGenericAlias(_GenericAlias, _root=True):
75-
pass
75+
def __repr__(self):
76+
from typing import _type_repr
77+
78+
name = _type_repr(self.__origin__)
79+
if self.__args__:
80+
rargs = [_type_repr(self.__args__[0]), "<...>"]
81+
args = ", ".join(rargs)
82+
else:
83+
# To ensure the repr is eval-able.
84+
args = "()"
85+
return f'{name}[{args}]'
7686

7787

7888
class GenericCallable:

0 commit comments

Comments
 (0)