Skip to content

Commit 7a90bd1

Browse files
authored
Use lambda in GenericCallable. (#77)
`GenericCallable` should now be used like `GenericCallable[tuple[T, U], lambda T, U: Callable[[T, U], tuple[T, U]]]` The `__class_getitem__` method is used to generate a `_GenericCallableGenericAlias`.
1 parent 19013fb commit 7a90bd1

6 files changed

Lines changed: 153 additions & 62 deletions

File tree

tests/test_type_dir.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def base[Z](self: Self, a: int | Z | None, b: ~K) -> dict[str, int | Z]: ...
210210
@classmethod
211211
def cbase(cls: type[typing.Self], a: int | None, b: ~K) -> dict[str, int]: ...
212212
@staticmethod
213-
def sbase[Z](a: int | Literal['gotcha!'] | Z | None, b: ~K) -> dict[str, int | Z]: ...
213+
def sbase[Z](a: OrGotcha[int] | Z | None, b: ~K) -> dict[str, int | Z]: ...
214214
""")
215215

216216

@@ -412,10 +412,15 @@ def test_type_members_func_3():
412412

413413
assert (
414414
str(typ)
415-
# == "\
416-
# staticmethod[tuple[typemap.typing.Param[typing.Literal['a'], int | typing.Literal['gotcha!'] | Z | None, typing.Never], typemap.typing.Param[typing.Literal['b'], ~K, typing.Never]], dict[str, int | Z]]"
417-
== "\
418-
typemap.typing.GenericCallable[tuple[Z], staticmethod[tuple[typemap.typing.Param[typing.Literal['a'], int | typing.Literal['gotcha!'] | Z | None, typing.Never], typemap.typing.Param[typing.Literal['b'], ~K, typing.Never]], dict[str, int | Z]]]"
415+
== "typemap.typing.GenericCallable[tuple[Z], typemap.type_eval._eval_operators._create_generic_callable_lambda.<locals>.<lambda>]"
416+
)
417+
418+
evaled = eval_typing(
419+
typing.get_args(typ)[1](*typing.get_args(typing.get_args(typ)[0]))
420+
)
421+
assert (
422+
str(evaled)
423+
== "staticmethod[tuple[typemap.typing.Param[typing.Literal['a'], int | typing.Literal['gotcha!'] | Z | None, typing.Never], typemap.typing.Param[typing.Literal['b'], ~K, typing.Never]], dict[str, int | Z]]"
419424
)
420425

421426

tests/test_type_eval.py

Lines changed: 67 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Tuple,
1414
TypeVar,
1515
Union,
16+
get_args,
1617
)
1718

1819
import pytest
@@ -27,9 +28,11 @@
2728
GenericCallable,
2829
GetArg,
2930
GetArgs,
31+
GetDefiner,
3032
GetMember,
3133
GetMemberType,
3234
GetName,
35+
GetQuals,
3336
GetSpecialAttr,
3437
GetType,
3538
GetAnnotations,
@@ -394,6 +397,56 @@ def test_getmember_01():
394397
assert d == Never
395398

396399

400+
def test_getmember_02():
401+
type OnlyIntToSet[T] = set[T] if IsSub[T, int] else T
402+
403+
class C:
404+
def f[T](self, x: T) -> OnlyIntToSet[T]: ...
405+
406+
m = eval_typing(GetMember[C, Literal["f"]])
407+
assert eval_typing(GetName[m]) == Literal["f"]
408+
assert eval_typing(GetQuals[m]) == Literal["ClassVar"]
409+
assert eval_typing(GetDefiner[m]) == C
410+
411+
t = eval_typing(GetType[m])
412+
Vs = get_args(get_args(t)[0])
413+
L = get_args(t)[1]
414+
f = L(*Vs)
415+
assert (
416+
f
417+
== Callable[
418+
[Param[Literal["self"], C], Param[Literal["x"], Vs[0]]],
419+
OnlyIntToSet[Vs[0]],
420+
]
421+
)
422+
423+
424+
def test_getmember_03():
425+
type OnlyIntToSet[T] = set[T] if IsSub[T, int] else T
426+
427+
class C:
428+
def f[T](self, x: T) -> OnlyIntToSet[T]: ...
429+
430+
type P = IndirectProtocol[C]
431+
432+
m = eval_typing(GetMember[P, Literal["f"]])
433+
assert eval_typing(GetName[m]) == Literal["f"]
434+
assert eval_typing(GetQuals[m]) == Literal["ClassVar"]
435+
assert eval_typing(GetDefiner[m]) != C # eval typing generates a new class
436+
437+
t = eval_typing(GetType[m])
438+
Vs = get_args(get_args(t)[0])
439+
L = get_args(t)[1]
440+
f = L(*Vs)
441+
assert (
442+
f
443+
== Callable[
444+
[Param[Literal["self"], Self], Param[Literal["x"], Vs[0]]],
445+
OnlyIntToSet[Vs[0]],
446+
]
447+
)
448+
449+
397450
def test_getarg_never():
398451
d = eval_typing(GetArg[Never, object, Literal[0]])
399452
assert d is Never
@@ -480,11 +533,7 @@ def test_eval_getarg_callable_02():
480533
t = eval_typing(GetArg[gc, GenericCallable, Literal[0]])
481534
assert t == tuple[T]
482535
gc_f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
483-
assert gc_f == f
484-
t = eval_typing(GetArg[gc_f, Callable, Literal[0]])
485-
assert t == tuple[Param[Literal[None], T, Never]]
486-
t = eval_typing(GetArg[gc_f, Callable, Literal[1]])
487-
assert t is T
536+
assert gc_f == Never
488537

489538
# Params wrapped
490539
f = Callable[
@@ -502,7 +551,7 @@ def test_eval_getarg_callable_02():
502551
t = eval_typing(GetArg[gc, GenericCallable, Literal[0]])
503552
assert t == tuple[T]
504553
gc_f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
505-
assert gc_f == f
554+
assert gc_f == Never
506555

507556

508557
type IndirectProtocol[T] = NewProtocol[*[m for m in Iter[Members[T]]],]
@@ -650,18 +699,7 @@ def f[T](self, x: T, /, y: T, *, z: T) -> T: ...
650699
GetArg[GetArg[gc, GenericCallable, Literal[0]], tuple, Literal[0]]
651700
)
652701
f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
653-
t = eval_typing(GetArg[f, Callable, Literal[0]])
654-
assert (
655-
t
656-
== tuple[
657-
Param[Literal["self"], C, Literal["positional"]],
658-
Param[Literal["x"], _T, Literal["positional"]],
659-
Param[Literal["y"], _T],
660-
Param[Literal["z"], _T, Literal["keyword"]],
661-
]
662-
)
663-
t = eval_typing(GetArg[f, Callable, Literal[1]])
664-
assert t is _T
702+
assert f is Never
665703

666704

667705
def test_eval_getarg_callable_08():
@@ -675,19 +713,7 @@ def f[T](cls, x: T, /, y: T, *, z: T) -> T: ...
675713
GetArg[GetArg[gc, GenericCallable, Literal[0]], tuple, Literal[0]]
676714
)
677715
f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
678-
t = eval_typing(GetArg[f, classmethod, Literal[0]])
679-
assert t is C
680-
t = eval_typing(GetArg[f, classmethod, Literal[1]])
681-
assert (
682-
t
683-
== tuple[
684-
Param[Literal["x"], _T, Literal["positional"]],
685-
Param[Literal["y"], _T],
686-
Param[Literal["z"], _T, Literal["keyword"]],
687-
]
688-
)
689-
t = eval_typing(GetArg[f, classmethod, Literal[2]])
690-
assert t is _T
716+
assert f is Never
691717

692718

693719
def test_eval_getarg_callable_09():
@@ -701,17 +727,7 @@ def f[T](x: T, /, y: T, *, z: T) -> T: ...
701727
GetArg[GetArg[gc, GenericCallable, Literal[0]], tuple, Literal[0]]
702728
)
703729
f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
704-
t = eval_typing(GetArg[f, staticmethod, Literal[0]])
705-
assert (
706-
t
707-
== tuple[
708-
Param[Literal["x"], _T, Literal["positional"]],
709-
Param[Literal["y"], _T],
710-
Param[Literal["z"], _T, Literal["keyword"]],
711-
]
712-
)
713-
t = eval_typing(GetArg[f, staticmethod, Literal[1]])
714-
assert t is _T
730+
assert f is Never
715731

716732

717733
def test_eval_getarg_tuple():
@@ -989,6 +1005,15 @@ class Container2[T]: ...
9891005
assert eval_typing(GetArg[t, Container, Literal[1]]) == Never
9901006

9911007

1008+
def test_eval_getargs_generic_callable_01():
1009+
T = TypeVar("T")
1010+
t = GenericCallable[
1011+
tuple[T], lambda T: Callable[[Param[Literal["x"], T]], int]
1012+
]
1013+
args = eval_typing(GetArgs[t, GenericCallable])
1014+
assert args == tuple[tuple[T]]
1015+
1016+
9921017
class OuterType:
9931018
class InnerType:
9941019
pass

typemap/type_eval/_apply_generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def make_func(
170170
func.__globals__,
171171
"__call__",
172172
func.__defaults__,
173-
(),
173+
func.__closure__,
174174
func.__kwdefaults__,
175175
)
176176

typemap/type_eval/_eval_call.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,11 @@ def eval_call_with_types(
162162
_typing_inspect.is_generic_alias(resolved_callable)
163163
and resolved_callable.__origin__ is GenericCallable
164164
):
165-
_, resolved_callable = typing.get_args(resolved_callable)
165+
typevars_tuple, callable_lambda = typing.get_args(resolved_callable)
166+
type_vars = typing.get_args(typevars_tuple)
167+
resolved_callable = callable_lambda(*type_vars)
168+
# Evaluate the result to expand type aliases
169+
resolved_callable = _eval_typing.eval_typing(resolved_callable)
166170

167171
sig = _callable_type_to_signature(resolved_callable)
168172
bound = sig.bind(*arg_types, **kwarg_types)

typemap/type_eval/_eval_operators.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ def _is_pos_only(param):
464464
)
465465

466466

467-
def _callable_type_to_method(name, typ):
467+
def _callable_type_to_method(name, typ, ctx):
468468
"""Turn a callable type into a method.
469469
470470
I'm not totally sure if this is worth doing! The main accomplishment
@@ -475,8 +475,10 @@ def _callable_type_to_method(name, typ):
475475

476476
head = typing.get_origin(typ)
477477
if head is GenericCallable:
478-
ttparams, typ = typing.get_args(typ)
478+
# Call the lambda with type variables to substitute the type variables
479+
ttparams, ttfunc = typing.get_args(typ)
479480
type_params = typing.get_args(ttparams)
481+
typ = ttfunc(*type_params)
480482
head = typing.get_origin(typ)
481483

482484
if head is classmethod:
@@ -578,10 +580,40 @@ def _ann(x):
578580
else:
579581
f = typing.Callable[params, ret]
580582
if root.__type_params__:
581-
f = GenericCallable[tuple[*root.__type_params__], f]
583+
# Must store a lambda that performs type variable substitution
584+
type_params = root.__type_params__
585+
callable_lambda = _create_generic_callable_lambda(f, type_params)
586+
f = GenericCallable[tuple[*type_params], callable_lambda]
582587
return f
583588

584589

590+
def _create_generic_callable_lambda(
591+
f: typing.Callable | classmethod | staticmethod,
592+
type_params: tuple[typing.TypeVar, ...],
593+
):
594+
if typing.get_origin(f) in (staticmethod, classmethod):
595+
return lambda *vs: _apply_generic.substitute(
596+
f, dict(zip(type_params, vs, strict=True))
597+
)
598+
599+
else:
600+
# Callable params are stored as a list
601+
params, ret = typing.get_args(f)
602+
603+
return lambda *vs: typing.Callable[
604+
[
605+
_apply_generic.substitute(
606+
p,
607+
dict(zip(type_params, vs, strict=True)),
608+
)
609+
for p in params
610+
],
611+
_apply_generic.substitute(
612+
ret, dict(zip(type_params, vs, strict=True))
613+
),
614+
]
615+
616+
585617
def _resolved_function_signature(func, receiver_type=None):
586618
"""Get the signature of a function with type hints resolved.
587619
@@ -888,7 +920,13 @@ def _eval_GetArg(tp, base, idx, *, ctx) -> typing.Any:
888920
return typing.Never
889921

890922
try:
891-
return _fix_type(args[_eval_literal(idx, ctx)])
923+
idx_val = _eval_literal(idx, ctx)
924+
925+
if base_head is GenericCallable and idx_val >= 1:
926+
# Disallow access to callable lambda
927+
return typing.Never
928+
929+
return _fix_type(args[idx_val])
892930
except IndexError:
893931
return typing.Never
894932

@@ -900,6 +938,11 @@ def _eval_GetArgs(tp, base, *, ctx) -> typing.Any:
900938
args = _get_args(tp, base_head, ctx)
901939
if args is None:
902940
return typing.Never
941+
942+
if base_head is GenericCallable:
943+
# Disallow access to callable lambda
944+
return tuple[args[0]] # type: ignore[valid-type]
945+
903946
return tuple[*args] # type: ignore[valid-type]
904947

905948

@@ -1075,7 +1118,7 @@ def _eval_NewProtocol(*etyps: Member, ctx):
10751118
if type_eval.issubtype(
10761119
typing.Literal["ClassVar"], tquals
10771120
) and _is_method_like(typ):
1078-
dct[name] = _callable_type_to_method(name, typ)
1121+
dct[name] = _callable_type_to_method(name, typ, ctx)
10791122
else:
10801123
annos[name] = _add_quals(typ, tquals)
10811124
_unpack_init(dct, name, init)

typemap/typing.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import types
66

77
from typing import Literal, Unpack
8-
from typing import _GenericAlias, _LiteralGenericAlias, _UnpackGenericAlias
8+
from typing import (
9+
_GenericAlias,
10+
_LiteralGenericAlias,
11+
_UnpackGenericAlias,
12+
)
913

1014
_SpecialForm: typing.Any = typing._SpecialForm
1115

@@ -67,16 +71,26 @@ class SpecialFormEllipsis:
6771
###
6872

6973

70-
# We really need to be able to represent generic function types but it
71-
# is a problem for all kinds of reasons...
72-
# Can we bang it into Callable??
73-
class GenericCallable[
74-
TVs: tuple[typing.TypeVar, ...],
75-
C: typing.Callable | staticmethod | classmethod,
76-
]:
74+
class _GenericCallableGenericAlias(_GenericAlias, _root=True):
7775
pass
7876

7977

78+
class GenericCallable:
79+
def __class_getitem__(cls, params):
80+
message = (
81+
"GenericCallable must be used as "
82+
"GenericCallable[tuple[TypeVar, ...], lambda <vs>: callable]."
83+
)
84+
if not isinstance(params, tuple) or len(params) != 2:
85+
raise TypeError(message)
86+
87+
typevars, func = params
88+
if not callable(func):
89+
raise TypeError(message)
90+
91+
return _GenericCallableGenericAlias(cls, (typevars, func))
92+
93+
8094
###
8195

8296

0 commit comments

Comments
 (0)