Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions tests/test_call.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import textwrap

from typing import Unpack
from typing import Generic, Literal, Self, TypeVar, Unpack

from typemap.type_eval import eval_call
from typemap.typing import (
Attrs,
BaseTypedDict,
GetName,
NewProtocol,
Member,
GetName,
Iter,
)

Expand Down Expand Up @@ -72,3 +72,83 @@ class Wrapped[typing.Literal[1]]:
value: typing.Literal[1]
def __init__(self: Self, value: Literal[1]) -> None: ...
""")


def test_call_bound_method_01():
# non-generic class, non-generic method
class C:
def invoke(self: Self, x: int) -> int:
return x

c = C()
ret = eval_call(c.invoke, 1)
assert ret is int


def test_call_bound_method_02():
# non-generic class, generic method
class C:
def invoke[X](self: Self, x: X) -> X:
return x

c = C()
ret = eval_call(c.invoke, 1)
assert ret is Literal[1]


def test_call_bound_method_03():
# generic class, non-generic method, with type var
X = TypeVar("X")

class C(Generic[X]):
def invoke(self: Self, x: X) -> X:
return x

c = C[int]()
ret = eval_call(c.invoke, 1)
assert ret is Literal[1]


def test_call_bound_method_04():
# generic class, non-generic method, PEP695 syntax
class C[X]:
def invoke(self: Self, x: X) -> X:
return x

c = C[int]()
ret = eval_call(c.invoke, 1)
assert ret is Literal[1]


def test_call_bound_method_05():
# generic class, generic method, with type var
X = TypeVar("X")

class C(Generic[X]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the generic class syntax work? (class C[X]:)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I just added tests to confirm this.

def invoke[Y](self: Self, x: Y) -> Y:
return x

c = C[int]()
ret = eval_call(c.invoke, "!!!")
assert ret is Literal["!!!"]


def test_call_bound_method_06():
# generic class, generic method, PEP695 syntax
class C[X]:
def invoke[Y](self: Self, x: Y) -> Y:
return x

c = C[int]()
ret = eval_call(c.invoke, "!!!")
assert ret is Literal["!!!"]


def test_call_local_type_01():
class C: ...

def invoke() -> C:
return C()

ret = eval_call(invoke)
assert ret is C
72 changes: 61 additions & 11 deletions typemap/type_eval/_eval_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,56 @@ def _type(t):
return type(t)


def eval_call(func: types.FunctionType, /, *args: Any, **kwargs: Any) -> RtType:
def eval_call(
func: types.FunctionType | types.MethodType, /, *args: Any, **kwargs: Any
) -> RtType:
bound_self: Any | None = None
if isinstance(func, types.MethodType):
bound_self = func.__self__
func = func.__func__ # type: ignore[assignment]

arg_types = tuple(_type(t) for t in args)
kwarg_types = {k: _type(t) for k, t in kwargs.items()}
return eval_call_with_types(func, arg_types, kwarg_types)
return eval_call_with_types(func, arg_types, kwarg_types, bound_self)


def _get_bound_type_args(
func: types.FunctionType,
func: types.FunctionType | types.MethodType,
arg_types: tuple[RtType, ...],
kwarg_types: dict[str, RtType],
bound_self: Any | None = None,
) -> dict[str, RtType]:
sig = inspect.signature(func)
bound = sig.bind(*arg_types, **kwarg_types)
bound = (
sig.bind(bound_self, *arg_types, **kwarg_types)
if bound_self
else sig.bind(*arg_types, **kwarg_types)
)

vars: dict[str, RtType] = {}

# Extract type parameters for bound methods
if bound_self and hasattr(bound_self, '__orig_class__'):
# Bound to a generic class
orig_class = bound_self.__orig_class__
origin = orig_class.__origin__
type_args = orig_class.__args__

for type_param, arg in zip(
origin.__type_params__,
type_args,
strict=False,
):
vars[type_param.__name__] = arg

if hasattr(origin, '__dict__'):
vars['__classdict__'] = dict(origin.__dict__)
elif bound_self:
# Bound to a non-generic class
bound_class = type(bound_self)
if hasattr(bound_class, '__dict__'):
vars['__classdict__'] = dict(bound_class.__dict__)

# TODO: duplication, error cases
for param in sig.parameters.values():
if (
Expand Down Expand Up @@ -77,13 +112,16 @@ def _get_bound_type_args(


def eval_call_with_types(
func: types.FunctionType,
func: types.FunctionType | types.MethodType,
arg_types: tuple[RtType, ...],
kwarg_types: dict[str, RtType],
bound_self: Any | None = None,
) -> RtType:
vars: dict[str, Any] = {}
params = func.__type_params__
vars = _get_bound_type_args(func, arg_types, kwarg_types)
params = (
func.__type_params__ if isinstance(func, types.FunctionType) else ()
)
vars = _get_bound_type_args(func, arg_types, kwarg_types, bound_self)
for p in params:
if p.__name__ not in vars:
vars[p.__name__] = p
Expand All @@ -92,26 +130,38 @@ def eval_call_with_types(


def eval_call_with_type_vars(
func: types.FunctionType, vars: dict[str, RtType]
func: types.FunctionType | types.MethodType,
vars: dict[str, RtType],
) -> RtType:
with _eval_typing._ensure_context() as ctx:
return _eval_call_with_type_vars(func, vars, ctx)


def _eval_call_with_type_vars(
func: types.FunctionType,
func: types.FunctionType | types.MethodType,
vars: dict[str, RtType],
ctx: _eval_typing.EvalContext,
) -> RtType:
try:
af = func.__annotate__
af = (
func.__annotate__
if isinstance(func, types.FunctionType)
else func.__call__.__annotate__
)
except AttributeError:
raise ValueError("func has no __annotate__ attribute")
if not af:
raise ValueError("func has no __annotate__ attribute")

closure_vars_by_name = dict(
zip(func.__code__.co_freevars, func.__closure__ or (), strict=True)
)

af_args = tuple(
types.CellType(vars[name]) for name in af.__code__.co_freevars
types.CellType(vars[name])
if name in vars
else closure_vars_by_name[name]
for name in af.__code__.co_freevars
)

ff = types.FunctionType(
Expand Down