Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions tests/test_type_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def base[Z](self: Self, a: int | Z | None, b: ~K) -> dict[str, int | Z]: ...
@classmethod
def cbase(cls: type[typing.Self], a: int | None, b: ~K) -> dict[str, int]: ...
@staticmethod
def sbase[Z](a: int | Literal['gotcha!'] | Z | None, b: ~K) -> dict[str, int | Z]: ...
def sbase[Z](a: OrGotcha[int] | Z | None, b: ~K) -> dict[str, int | Z]: ...
""")


Expand Down Expand Up @@ -410,10 +410,15 @@ def test_type_members_func_3():

assert (
str(typ)
# == "\
# staticmethod[tuple[typemap.typing.Param[typing.Literal['a'], int | typing.Literal['gotcha!'] | Z | None, typing.Never], typemap.typing.Param[typing.Literal['b'], ~K, typing.Never]], dict[str, int | Z]]"
== "\
typemap.typing.GenericCallable[tuple[Z], staticmethod[tuple[typemap.typing.Param[typing.Literal['a'], int | typing.Literal['gotcha!'] | Z | None, typing.Never], typemap.typing.Param[typing.Literal['b'], ~K, typing.Never]], dict[str, int | Z]]]"
== "typemap.typing.GenericCallable[tuple[Z], typemap.type_eval._eval_operators._create_generic_callable_lambda.<locals>.<lambda>]"
)

evaled = eval_typing(
typing.get_args(typ)[1](*typing.get_args(typing.get_args(typ)[0]))
)
assert (
str(evaled)
== "staticmethod[tuple[typemap.typing.Param[typing.Literal['a'], int | typing.Literal['gotcha!'] | Z | None, typing.Never], typemap.typing.Param[typing.Literal['b'], ~K, typing.Never]], dict[str, int | Z]]"
)


Expand Down
109 changes: 67 additions & 42 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Tuple,
TypeVar,
Union,
get_args,
)

import pytest
Expand All @@ -25,9 +26,11 @@
GenericCallable,
GetArg,
GetArgs,
GetDefiner,
GetMember,
GetMemberType,
GetName,
GetQuals,
GetSpecialAttr,
GetType,
GetAnnotations,
Expand Down Expand Up @@ -394,6 +397,56 @@ def test_getmember_01():
assert d == Never


def test_getmember_02():
type OnlyIntToSet[T] = set[T] if IsSub[T, int] else T

class C:
def f[T](self, x: T) -> OnlyIntToSet[T]: ...

m = eval_typing(GetMember[C, Literal["f"]])
assert eval_typing(GetName[m]) == Literal["f"]
assert eval_typing(GetQuals[m]) == Literal["ClassVar"]
assert eval_typing(GetDefiner[m]) == C

t = eval_typing(GetType[m])
Vs = get_args(get_args(t)[0])
L = get_args(t)[1]
f = L(*Vs)
assert (
f
== Callable[
[Param[Literal["self"], C], Param[Literal["x"], Vs[0]]],
OnlyIntToSet[Vs[0]],
]
)


def test_getmember_03():
type OnlyIntToSet[T] = set[T] if IsSub[T, int] else T

class C:
def f[T](self, x: T) -> OnlyIntToSet[T]: ...

type P = IndirectProtocol[C]

m = eval_typing(GetMember[P, Literal["f"]])
assert eval_typing(GetName[m]) == Literal["f"]
assert eval_typing(GetQuals[m]) == Literal["ClassVar"]
assert eval_typing(GetDefiner[m]) != C # eval typing generates a new class

t = eval_typing(GetType[m])
Vs = get_args(get_args(t)[0])
L = get_args(t)[1]
f = L(*Vs)
assert (
f
== Callable[
[Param[Literal["self"], Self], Param[Literal["x"], Vs[0]]],
OnlyIntToSet[Vs[0]],
]
)


def test_getarg_never():
d = eval_typing(GetArg[Never, object, Literal[0]])
assert d is Never
Expand Down Expand Up @@ -480,11 +533,7 @@ def test_eval_getarg_callable_02():
t = eval_typing(GetArg[gc, GenericCallable, Literal[0]])
assert t == tuple[T]
gc_f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
assert gc_f == f
t = eval_typing(GetArg[gc_f, Callable, Literal[0]])
assert t == tuple[Param[Literal[None], T, Never]]
t = eval_typing(GetArg[gc_f, Callable, Literal[1]])
assert t is T
assert gc_f == Never

# Params wrapped
f = Callable[
Expand All @@ -502,7 +551,7 @@ def test_eval_getarg_callable_02():
t = eval_typing(GetArg[gc, GenericCallable, Literal[0]])
assert t == tuple[T]
gc_f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
assert gc_f == f
assert gc_f == Never


type IndirectProtocol[T] = NewProtocol[*[m for m in Iter[Members[T]]],]
Expand Down Expand Up @@ -650,18 +699,7 @@ def f[T](self, x: T, /, y: T, *, z: T) -> T: ...
GetArg[GetArg[gc, GenericCallable, Literal[0]], tuple, Literal[0]]
)
f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
t = eval_typing(GetArg[f, Callable, Literal[0]])
assert (
t
== tuple[
Param[Literal["self"], C, Literal["positional"]],
Param[Literal["x"], _T, Literal["positional"]],
Param[Literal["y"], _T],
Param[Literal["z"], _T, Literal["keyword"]],
]
)
t = eval_typing(GetArg[f, Callable, Literal[1]])
assert t is _T
assert f is Never


def test_eval_getarg_callable_08():
Expand All @@ -675,19 +713,7 @@ def f[T](cls, x: T, /, y: T, *, z: T) -> T: ...
GetArg[GetArg[gc, GenericCallable, Literal[0]], tuple, Literal[0]]
)
f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
t = eval_typing(GetArg[f, classmethod, Literal[0]])
assert t is C
t = eval_typing(GetArg[f, classmethod, Literal[1]])
assert (
t
== tuple[
Param[Literal["x"], _T, Literal["positional"]],
Param[Literal["y"], _T],
Param[Literal["z"], _T, Literal["keyword"]],
]
)
t = eval_typing(GetArg[f, classmethod, Literal[2]])
assert t is _T
assert f is Never


def test_eval_getarg_callable_09():
Expand All @@ -701,17 +727,7 @@ def f[T](x: T, /, y: T, *, z: T) -> T: ...
GetArg[GetArg[gc, GenericCallable, Literal[0]], tuple, Literal[0]]
)
f = eval_typing(GetArg[gc, GenericCallable, Literal[1]])
t = eval_typing(GetArg[f, staticmethod, Literal[0]])
assert (
t
== tuple[
Param[Literal["x"], _T, Literal["positional"]],
Param[Literal["y"], _T],
Param[Literal["z"], _T, Literal["keyword"]],
]
)
t = eval_typing(GetArg[f, staticmethod, Literal[1]])
assert t is _T
assert f is Never


def test_eval_getarg_tuple():
Expand Down Expand Up @@ -989,6 +1005,15 @@ class Container2[T]: ...
assert eval_typing(GetArg[t, Container, Literal[1]]) == Never


def test_eval_getargs_generic_callable_01():
T = TypeVar("T")
t = GenericCallable[
tuple[T], lambda T: Callable[[Param[Literal["x"], T]], int]
]
args = eval_typing(GetArgs[t, GenericCallable])
assert args == tuple[tuple[T]]


class OuterType:
class InnerType:
pass
Expand Down
2 changes: 1 addition & 1 deletion typemap/type_eval/_apply_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def make_func(
func.__globals__,
"__call__",
func.__defaults__,
(),
func.__closure__,
func.__kwdefaults__,
)

Expand Down
6 changes: 5 additions & 1 deletion typemap/type_eval/_eval_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,11 @@ def eval_call_with_types(
_typing_inspect.is_generic_alias(resolved_callable)
and resolved_callable.__origin__ is GenericCallable
):
_, resolved_callable = typing.get_args(resolved_callable)
typevars_tuple, callable_lambda = typing.get_args(resolved_callable)
type_vars = typing.get_args(typevars_tuple)
resolved_callable = callable_lambda(*type_vars)
# Evaluate the result to expand type aliases
resolved_callable = _eval_typing.eval_typing(resolved_callable)

sig = _callable_type_to_signature(resolved_callable)
bound = sig.bind(*arg_types, **kwarg_types)
Expand Down
53 changes: 48 additions & 5 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def _is_pos_only(param):
)


def _callable_type_to_method(name, typ):
def _callable_type_to_method(name, typ, ctx):
"""Turn a callable type into a method.

I'm not totally sure if this is worth doing! The main accomplishment
Expand All @@ -482,8 +482,10 @@ def _callable_type_to_method(name, typ):

head = typing.get_origin(typ)
if head is GenericCallable:
ttparams, typ = typing.get_args(typ)
# Call the lambda with type variables to substitute the type variables
ttparams, ttfunc = typing.get_args(typ)
type_params = typing.get_args(ttparams)
typ = ttfunc(*type_params)
head = typing.get_origin(typ)

if head is classmethod:
Expand Down Expand Up @@ -585,10 +587,40 @@ def _ann(x):
else:
f = typing.Callable[params, ret]
if root.__type_params__:
f = GenericCallable[tuple[*root.__type_params__], f]
# Must store a lambda that performs type variable substitution
type_params = root.__type_params__
callable_lambda = _create_generic_callable_lambda(f, type_params)
f = GenericCallable[tuple[*type_params], callable_lambda]
Comment on lines 589 to +593
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for this to work right in the cases where the annotations do nontrivial computation on the params, it needs to go first, and then the whole rest of everything needs to be in the lambda?

return f


def _create_generic_callable_lambda(
f: typing.Callable | classmethod | staticmethod,
type_params: tuple[typing.TypeVar, ...],
):
if typing.get_origin(f) in (staticmethod, classmethod):
return lambda *vs: _apply_generic.substitute(
f, dict(zip(type_params, vs, strict=True))
)

else:
# Callable params are stored as a list
params, ret = typing.get_args(f)

return lambda *vs: typing.Callable[
[
_apply_generic.substitute(
p,
dict(zip(type_params, vs, strict=True)),
)
for p in params
],
_apply_generic.substitute(
ret, dict(zip(type_params, vs, strict=True))
),
]


def _resolved_function_signature(func, receiver_type=None):
"""Get the signature of a function with type hints resolved.

Expand Down Expand Up @@ -895,7 +927,13 @@ def _eval_GetArg(tp, base, idx, *, ctx) -> typing.Any:
return typing.Never

try:
return _fix_type(args[_eval_literal(idx, ctx)])
idx_val = _eval_literal(idx, ctx)

if base_head is GenericCallable and idx_val >= 1:
# Disallow access to callable lambda
return typing.Never

return _fix_type(args[idx_val])
except IndexError:
return typing.Never

Expand All @@ -907,6 +945,11 @@ def _eval_GetArgs(tp, base, *, ctx) -> typing.Any:
args = _get_args(tp, base_head, ctx)
if args is None:
return typing.Never

if base_head is GenericCallable:
# Disallow access to callable lambda
return tuple[args[0]] # type: ignore[valid-type]

return tuple[*args] # type: ignore[valid-type]


Expand Down Expand Up @@ -1082,7 +1125,7 @@ def _eval_NewProtocol(*etyps: Member, ctx):
if type_eval.issubsimilar(
typing.Literal["ClassVar"], tquals
) and _is_method_like(typ):
dct[name] = _callable_type_to_method(name, typ)
dct[name] = _callable_type_to_method(name, typ, ctx)
else:
annos[name] = _add_quals(typ, tquals)
_unpack_init(dct, name, init)
Expand Down
30 changes: 22 additions & 8 deletions typemap/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import types

from typing import Literal, Unpack
from typing import _GenericAlias, _LiteralGenericAlias, _UnpackGenericAlias
from typing import (
_GenericAlias,
_LiteralGenericAlias,
_UnpackGenericAlias,
)

_SpecialForm: typing.Any = typing._SpecialForm

Expand Down Expand Up @@ -67,16 +71,26 @@ class SpecialFormEllipsis:
###


# We really need to be able to represent generic function types but it
# is a problem for all kinds of reasons...
# Can we bang it into Callable??
class GenericCallable[
TVs: tuple[typing.TypeVar, ...],
C: typing.Callable | staticmethod | classmethod,
]:
class _GenericCallableGenericAlias(_GenericAlias, _root=True):
pass


class GenericCallable:
def __class_getitem__(cls, params):
message = (
"GenericCallable must be used as "
"GenericCallable[tuple[TypeVar, ...], lambda <vs>: callable]."
)
if not isinstance(params, tuple) or len(params) != 2:
raise TypeError(message)

typevars, func = params
if not callable(func):
raise TypeError(message)

return _GenericCallableGenericAlias(cls, (typevars, func))


###


Expand Down