diff --git a/tests/test_call.py b/tests/test_call.py index e983aa7..79ac5da 100644 --- a/tests/test_call.py +++ b/tests/test_call.py @@ -1,14 +1,14 @@ import textwrap -from typing import Unpack +from typing import Generic, Literal, Self, TypeVar, Unpack from typemap.type_eval import eval_call from typemap.typing import ( Attrs, BaseTypedDict, + GetName, NewProtocol, Member, - GetName, Iter, ) @@ -72,3 +72,83 @@ class Wrapped[typing.Literal[1]]: value: typing.Literal[1] def __init__(self: Self, value: Literal[1]) -> None: ... """) + + +def test_call_bound_method_01(): + # non-generic class, non-generic method + class C: + def invoke(self: Self, x: int) -> int: + return x + + c = C() + ret = eval_call(c.invoke, 1) + assert ret is int + + +def test_call_bound_method_02(): + # non-generic class, generic method + class C: + def invoke[X](self: Self, x: X) -> X: + return x + + c = C() + ret = eval_call(c.invoke, 1) + assert ret is Literal[1] + + +def test_call_bound_method_03(): + # generic class, non-generic method, with type var + X = TypeVar("X") + + class C(Generic[X]): + def invoke(self: Self, x: X) -> X: + return x + + c = C[int]() + ret = eval_call(c.invoke, 1) + assert ret is Literal[1] + + +def test_call_bound_method_04(): + # generic class, non-generic method, PEP695 syntax + class C[X]: + def invoke(self: Self, x: X) -> X: + return x + + c = C[int]() + ret = eval_call(c.invoke, 1) + assert ret is Literal[1] + + +def test_call_bound_method_05(): + # generic class, generic method, with type var + X = TypeVar("X") + + class C(Generic[X]): + def invoke[Y](self: Self, x: Y) -> Y: + return x + + c = C[int]() + ret = eval_call(c.invoke, "!!!") + assert ret is Literal["!!!"] + + +def test_call_bound_method_06(): + # generic class, generic method, PEP695 syntax + class C[X]: + def invoke[Y](self: Self, x: Y) -> Y: + return x + + c = C[int]() + ret = eval_call(c.invoke, "!!!") + assert ret is Literal["!!!"] + + +def test_call_local_type_01(): + class C: ... + + def invoke() -> C: + return C() + + ret = eval_call(invoke) + assert ret is C diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index f874a25..ea6987f 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -22,21 +22,56 @@ def _type(t): return type(t) -def eval_call(func: types.FunctionType, /, *args: Any, **kwargs: Any) -> RtType: +def eval_call( + func: types.FunctionType | types.MethodType, /, *args: Any, **kwargs: Any +) -> RtType: + bound_self: Any | None = None + if isinstance(func, types.MethodType): + bound_self = func.__self__ + func = func.__func__ # type: ignore[assignment] + arg_types = tuple(_type(t) for t in args) kwarg_types = {k: _type(t) for k, t in kwargs.items()} - return eval_call_with_types(func, arg_types, kwarg_types) + return eval_call_with_types(func, arg_types, kwarg_types, bound_self) def _get_bound_type_args( - func: types.FunctionType, + func: types.FunctionType | types.MethodType, arg_types: tuple[RtType, ...], kwarg_types: dict[str, RtType], + bound_self: Any | None = None, ) -> dict[str, RtType]: sig = inspect.signature(func) - bound = sig.bind(*arg_types, **kwarg_types) + bound = ( + sig.bind(bound_self, *arg_types, **kwarg_types) + if bound_self + else sig.bind(*arg_types, **kwarg_types) + ) vars: dict[str, RtType] = {} + + # Extract type parameters for bound methods + if bound_self and hasattr(bound_self, '__orig_class__'): + # Bound to a generic class + orig_class = bound_self.__orig_class__ + origin = orig_class.__origin__ + type_args = orig_class.__args__ + + for type_param, arg in zip( + origin.__type_params__, + type_args, + strict=False, + ): + vars[type_param.__name__] = arg + + if hasattr(origin, '__dict__'): + vars['__classdict__'] = dict(origin.__dict__) + elif bound_self: + # Bound to a non-generic class + bound_class = type(bound_self) + if hasattr(bound_class, '__dict__'): + vars['__classdict__'] = dict(bound_class.__dict__) + # TODO: duplication, error cases for param in sig.parameters.values(): if ( @@ -77,13 +112,16 @@ def _get_bound_type_args( def eval_call_with_types( - func: types.FunctionType, + func: types.FunctionType | types.MethodType, arg_types: tuple[RtType, ...], kwarg_types: dict[str, RtType], + bound_self: Any | None = None, ) -> RtType: vars: dict[str, Any] = {} - params = func.__type_params__ - vars = _get_bound_type_args(func, arg_types, kwarg_types) + params = ( + func.__type_params__ if isinstance(func, types.FunctionType) else () + ) + vars = _get_bound_type_args(func, arg_types, kwarg_types, bound_self) for p in params: if p.__name__ not in vars: vars[p.__name__] = p @@ -92,26 +130,38 @@ def eval_call_with_types( def eval_call_with_type_vars( - func: types.FunctionType, vars: dict[str, RtType] + func: types.FunctionType | types.MethodType, + vars: dict[str, RtType], ) -> RtType: with _eval_typing._ensure_context() as ctx: return _eval_call_with_type_vars(func, vars, ctx) def _eval_call_with_type_vars( - func: types.FunctionType, + func: types.FunctionType | types.MethodType, vars: dict[str, RtType], ctx: _eval_typing.EvalContext, ) -> RtType: try: - af = func.__annotate__ + af = ( + func.__annotate__ + if isinstance(func, types.FunctionType) + else func.__call__.__annotate__ + ) except AttributeError: raise ValueError("func has no __annotate__ attribute") if not af: raise ValueError("func has no __annotate__ attribute") + closure_vars_by_name = dict( + zip(func.__code__.co_freevars, func.__closure__ or (), strict=True) + ) + af_args = tuple( - types.CellType(vars[name]) for name in af.__code__.co_freevars + types.CellType(vars[name]) + if name in vars + else closure_vars_by_name[name] + for name in af.__code__.co_freevars ) ff = types.FunctionType(