From 5c861601c83633007f27c5be6e154c4c0e7b501f Mon Sep 17 00:00:00 2001 From: dnwpark Date: Thu, 15 Jan 2026 14:25:09 -0800 Subject: [PATCH 1/6] Add tests. --- tests/test_type_eval.py | 312 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 312 insertions(+) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 0140e27..7f7f280 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -1025,3 +1025,315 @@ def test_type_eval_annotated_03(): def test_type_eval_annotated_04(): res = eval_typing(GetAnnotations[GetAttr[AnnoTest, Literal["b"]]]) assert res == Literal["blah"] + + +def test_type_call_callable_01(): + res = eval_type_call(Callable[[], int]) + assert res is int + + +def test_type_call_callable_02(): + res = eval_type_call(Callable[[Param[Literal["x"], int]], int], int) + assert res is int + + +def test_type_call_callable_03(): + res = eval_type_call( + Callable[[Param[Literal["x"], int, Literal["keyword"]]], int], x=int + ) + assert res is int + + +def test_type_call_callable_04(): + class C: ... + + res = eval_type_call(Callable[[Param[Literal["self"], Self]], int], C) + assert res is int + + +def test_type_call_callable_05(): + class C: ... + + res = eval_type_call(Callable[[Param[Literal["self"], Self]], C], C) + assert res is C + + +def test_type_call_callable_06(): + class C: ... + + res = eval_type_call( + Callable[[Param[Literal["self"], Self], Param[Literal["x"], int]], int], + C, + int, + ) + assert res is int + + +def test_type_call_callable_07(): + class C: ... + + res = eval_type_call( + Callable[ + [ + Param[Literal["self"], Self], + Param[Literal["x"], int, Literal["keyword"]], + ], + int, + ], + C, + x=int, + ) + assert res is int + + +def test_type_call_callable_08(): + T = TypeVar("T") + res = eval_type_call(Callable[[Param[Literal["x"], T]], str], int) + assert res is str + + +def test_type_call_callable_09(): + T = TypeVar("T") + res = eval_type_call(Callable[[Param[Literal["x"], T]], T], int) + assert res is int + + +def test_type_call_callable_10(): + T = TypeVar("T") + + class C(Generic[T]): ... + + res = eval_type_call(Callable[[Param[Literal["x"], C[T]]], T], C[int]) + assert res is int + + +def test_type_call_callable_11(): + T = TypeVar("T") + + class C(Generic[T]): ... + + class D(C[int]): ... + + class E(D): ... + + res = eval_type_call(Callable[[Param[Literal["x"], C[T]]], T], D) + assert res is int + res = eval_type_call(Callable[[Param[Literal["x"], C[T]]], T], E) + assert res is int + + +def test_type_call_local_function_01(): + def func(x: int) -> int: ... + + res = eval_type_call(func, int) + assert res is int + + +def test_type_call_local_function_02(): + def func(*, x: int) -> int: ... + + res = eval_type_call(func, x=int) + assert res is int + + +def test_type_call_local_function_03(): + def func[T](x: T) -> T: ... + + res = eval_type_call(func, int) + assert res is int + + +def test_type_call_local_function_04(): + class C: ... + + def func(x: C) -> C: ... + + res = eval_type_call(func, C) + assert res is C + + +def test_type_call_local_function_05(): + class C: ... + + def func[T](x: T) -> T: ... + + res = eval_type_call(func, C) + assert res is C + + +def test_type_call_local_function_06(): + T = TypeVar("T") + + class C(Generic[T]): ... + + def func[U](x: C[U]) -> C[U]: ... + + res = eval_type_call(func, C[int]) + assert res == C[int] + + +def test_type_call_local_function_07(): + T = TypeVar("T") + + class C(Generic[T]): ... + + class D(C[int]): ... + + class E(D): ... + + def func[U](x: C[U]) -> U: ... + + res = eval_type_call(func, D) + assert res is int + res = eval_type_call(func, E) + assert res is int + + +def test_type_call_local_function_08(): + class C[T]: ... + + class D(C[int]): ... + + class E(C[str]): ... + + class F(D, E): ... + + def func[U](x: C[U]) -> U: ... + + res = eval_type_call(func, F) + assert res is int + + +def test_type_call_local_function_09(): + class C[T, U]: ... + + def func[V](x: C[int, V]) -> V: ... + + res = eval_type_call(func, C[int, str]) + assert res is str + + +def test_type_call_bind_error_01(): + T = TypeVar("T") + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_type_call( + Callable[[Param[Literal["x"], T], Param[Literal["y"], T]], T], + int, + str, + ) + + +def test_type_call_bind_error_02(): + def func[T](x: T, y: T) -> T: ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_type_call(func, int, str) + + +def test_type_call_bind_error_03(): + T = TypeVar("T") + + class C(Generic[T]): ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_type_call( + Callable[[Param[Literal["x"], C[T]], Param[Literal["y"], C[T]]], T], + C[int], + C[str], + ) + + +def test_type_call_bind_error_04(): + class C[T]: ... + + def func[T](x: C[T], y: C[T]) -> T: ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_type_call(func, C[int], C[str]) + + +def test_type_call_bind_error_05(): + class C[T]: ... + + class D[T]: ... + + def func[T](x: C[T]) -> T: ... + + with pytest.raises(ValueError, match="Argument type mismatch for x"): + eval_type_call(func, D[int]) + + +type GetCallableMember[T, N: str] = GetArg[ + tuple[ + *[ + GetType[m] + for m in Iter[Members[T]] + if Sub[GetType[m], Callable] and Sub[GetName[m], N] + ] + ], + tuple, + 0, +] + + +def test_type_call_member_01(): + class C: + def invoke(self, x: int) -> int: ... + + res = eval_type_call(GetCallableMember[C, Literal["invoke"]], C, int) + assert res is int + + +def test_type_call_member_02(): + class C: + def invoke[T](self, x: T) -> T: ... + + res = eval_type_call(GetCallableMember[C, Literal["invoke"]], C, int) + assert res is int + + +def test_type_call_member_03(): + class C[T]: + def invoke(self, x: str) -> str: ... + + res = eval_type_call( + GetCallableMember[C[int], Literal["invoke"]], C[int], str + ) + assert res is str + + +def test_type_call_member_04(): + class C[T]: + def invoke(self, x: T) -> T: ... + + res = eval_type_call( + GetCallableMember[C[int], Literal["invoke"]], C[int], int + ) + assert res is int + + +def test_type_call_member_05(): + class C[T]: + def invoke(self) -> C[T]: ... + + res = eval_type_call(GetCallableMember[C[int], Literal["invoke"]], C[int]) + assert res == C[int] + + +def test_type_call_member_06(): + class C[T]: + def invoke[U](self, x: U) -> C[U]: ... + + res = eval_type_call( + GetCallableMember[C[int], Literal["invoke"]], C[int], str + ) + assert res == C[str] From b4cc42f7e0d58423ae64cacf635835b825023fda Mon Sep 17 00:00:00 2001 From: dnwpark Date: Thu, 15 Jan 2026 14:25:25 -0800 Subject: [PATCH 2/6] Implement eval_type_call. --- tests/test_type_eval.py | 2 +- typemap/type_eval/__init__.py | 3 ++- typemap/type_eval/_eval_call.py | 41 +++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 7f7f280..7fb1a4a 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -17,7 +17,7 @@ import pytest -from typemap.type_eval import eval_typing +from typemap.type_eval import eval_typing, eval_type_call from typemap.typing import ( Attrs, FromUnion, diff --git a/typemap/type_eval/__init__.py b/typemap/type_eval/__init__.py index b63c49d..c8dc962 100644 --- a/typemap/type_eval/__init__.py +++ b/typemap/type_eval/__init__.py @@ -7,7 +7,7 @@ from ._apply_generic import flatten_class # XXX: this needs to go second due to nasty circularity -- try to fix that!! -from ._eval_call import eval_call +from ._eval_call import eval_call, eval_type_call from ._subtype import issubtype from ._subsim import issubsimilar @@ -19,6 +19,7 @@ "eval_typing", "register_evaluator", "eval_call", + "eval_type_call", "flatten_class", "issubtype", "issubsimilar", diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index d71007a..e3d3f79 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -39,6 +39,13 @@ def _get_bound_type_args( sig = inspect.signature(func) bound = sig.bind(*arg_types, **kwarg_types) + return _get_bound_type_args_from_bound_args(sig, bound) + + +def _get_bound_type_args_from_bound_args( + sig: inspect.Signature, + bound: inspect.BoundArguments, +) -> dict[str, RtType]: vars: dict[str, RtType] = {} # TODO: duplication, error cases for param in sig.parameters.values(): @@ -142,3 +149,37 @@ def _eval_call_with_type_vars( return _eval_typing.eval_typing(rr["return"]) finally: ctx.current_generic_alias = old_obj + + +def eval_type_call( + callable: typing.Any, + *arg_types: Any, + **kwarg_types: Any, +) -> RtType: + arg_types = tuple(_eval_typing.eval_typing(t) for t in arg_types) + kwarg_types = { + k: _eval_typing.eval_typing(t) for k, t in kwarg_types.items() + } + if isinstance(callable, types.FunctionType): + sig = inspect.signature(callable) + else: + resolved_callable = _eval_typing.eval_typing(callable) + sig = _callable_type_to_signature(resolved_callable) + bound = sig.bind(*arg_types, **kwarg_types) + vars = _get_bound_type_args_from_bound_args(sig, bound) + res = _substitute_type_vars(sig.return_annotation, vars) + + return res + + +def _substitute_type_vars( + obj: typing.Any, vars: dict[str, RtType] +) -> typing.Any: + """Recursively substitute type variables into a type expression.""" + if isinstance(obj, typing.TypeVar): + return vars[obj.__name__] + elif _typing_inspect.is_generic_alias(obj): + args = tuple(_substitute_type_vars(v, vars) for v in obj.__args__) + return obj.__origin__[args] # type: ignore[index] + else: + return obj From b4176ad157f610407fa6cd4d96d983885d7894e2 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Thu, 15 Jan 2026 16:34:24 -0800 Subject: [PATCH 3/6] Support binding generics to typevars. --- tests/test_type_eval.py | 4 +++- typemap/type_eval/_eval_call.py | 42 +++++++++++++++++++++++++++++---- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 7fb1a4a..e758d4a 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -21,6 +21,7 @@ from typemap.typing import ( Attrs, FromUnion, + GenericCallable, GetArg, GetArgs, GetAttr, @@ -1277,7 +1278,8 @@ def func[T](x: C[T]) -> T: ... *[ GetType[m] for m in Iter[Members[T]] - if Sub[GetType[m], Callable] and Sub[GetName[m], N] + if (Sub[GetType[m], Callable] or Sub[GetType[m], GenericCallable]) + and Sub[GetName[m], N] ] ], tuple, diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index e3d3f79..555827c 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -89,17 +89,39 @@ def _get_bound_type_args_from_bound_args( ): vars[tv.__name__] = arg.__args__[0] # trivial T bindings - elif ( - isinstance(param.annotation, typing.TypeVar) - and param.name in bound.arguments - ): + elif isinstance( + param.annotation, typing.TypeVar + ) or _typing_inspect.is_generic_alias(param.annotation): param_value = bound.arguments[param.name] - vars[param.annotation.__name__] = param_value + _update_bound_typevar(param.annotation, param_value, vars) # TODO: simple bindings to other variables too return vars +def _update_bound_typevar( + tv: Any, + param_value: Any, + vars: dict[str, RtType], +) -> None: + if isinstance(tv, typing.TypeVar): + if tv.__name__ not in vars: + vars[tv.__name__] = param_value + elif vars[tv.__name__] != param_value: + raise ValueError( + f"Type variable {tv.__name__} " + f"is already bound to {vars[tv.__name__].__name__}, " + f"but got {param_value.__name__}" + ) + elif bool( + _typing_inspect.is_generic_alias(tv) + and _typing_inspect.is_generic_alias(param_value) + and tv.__origin__ == param_value.__origin__ + ): + for p_arg, c_arg in zip(tv.__args__, param_value.__args__, strict=True): + _update_bound_typevar(p_arg, c_arg, vars) + + def eval_call_with_types( func: types.FunctionType, arg_types: tuple[RtType, ...], @@ -156,6 +178,8 @@ def eval_type_call( *arg_types: Any, **kwarg_types: Any, ) -> RtType: + from typemap.typing import GenericCallable + arg_types = tuple(_eval_typing.eval_typing(t) for t in arg_types) kwarg_types = { k: _eval_typing.eval_typing(t) for k, t in kwarg_types.items() @@ -164,7 +188,15 @@ def eval_type_call( sig = inspect.signature(callable) else: resolved_callable = _eval_typing.eval_typing(callable) + + if ( + _typing_inspect.is_generic_alias(resolved_callable) + and resolved_callable.__origin__ is GenericCallable + ): + _, resolved_callable = typing.get_args(resolved_callable) + sig = _callable_type_to_signature(resolved_callable) + bound = sig.bind(*arg_types, **kwarg_types) vars = _get_bound_type_args_from_bound_args(sig, bound) res = _substitute_type_vars(sig.return_annotation, vars) From 29e2525c317ec69b114f6f57b21a9f317eafc635 Mon Sep 17 00:00:00 2001 From: dnwpark Date: Fri, 16 Jan 2026 09:15:09 -0800 Subject: [PATCH 4/6] Use _get_args to get base class args. --- typemap/type_eval/_eval_call.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index 555827c..163ec4d 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -8,6 +8,7 @@ from typing import Any +from . import _eval_operators from . import _eval_typing from . import _typing_inspect @@ -93,13 +94,16 @@ def _get_bound_type_args_from_bound_args( param.annotation, typing.TypeVar ) or _typing_inspect.is_generic_alias(param.annotation): param_value = bound.arguments[param.name] - _update_bound_typevar(param.annotation, param_value, vars) + _update_bound_typevar( + param.name, param.annotation, param_value, vars + ) # TODO: simple bindings to other variables too return vars def _update_bound_typevar( + param_name: str, tv: Any, param_value: Any, vars: dict[str, RtType], @@ -113,13 +117,19 @@ def _update_bound_typevar( f"is already bound to {vars[tv.__name__].__name__}, " f"but got {param_value.__name__}" ) - elif bool( - _typing_inspect.is_generic_alias(tv) - and _typing_inspect.is_generic_alias(param_value) - and tv.__origin__ == param_value.__origin__ - ): - for p_arg, c_arg in zip(tv.__args__, param_value.__args__, strict=True): - _update_bound_typevar(p_arg, c_arg, vars) + elif _typing_inspect.is_generic_alias(tv): + tv_args = tv.__args__ + + with _eval_typing._ensure_context() as ctx: + param_args = _eval_operators._get_args( + param_value, tv.__origin__, ctx + ) + + if param_args is None: + raise ValueError(f"Argument type mismatch for {param_name}") + + for p_arg, c_arg in zip(tv_args, param_args, strict=True): + _update_bound_typevar(param_name, p_arg, c_arg, vars) def eval_call_with_types( From b085d979f7180b27d4fc3b74b5bcf32178c7c8cb Mon Sep 17 00:00:00 2001 From: dnwpark Date: Mon, 19 Jan 2026 09:28:12 -0800 Subject: [PATCH 5/6] Remove duplicate substitute. --- tests/test_type_eval.py | 7 ++++-- typemap/type_eval/_eval_call.py | 42 +++++++++++++-------------------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index e758d4a..435cb90 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -1278,8 +1278,11 @@ def func[T](x: C[T]) -> T: ... *[ GetType[m] for m in Iter[Members[T]] - if (Sub[GetType[m], Callable] or Sub[GetType[m], GenericCallable]) - and Sub[GetName[m], N] + if ( + IsSub[GetType[m], Callable] + or IsSub[GetType[m], GenericCallable] + ) + and IsSub[GetName[m], N] ] ], tuple, diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index 163ec4d..6f2b334 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -11,6 +11,8 @@ from . import _eval_operators from . import _eval_typing from . import _typing_inspect +from ._eval_operators import _callable_type_to_signature +from ._apply_generic import substitute RtType = Any @@ -40,14 +42,17 @@ def _get_bound_type_args( sig = inspect.signature(func) bound = sig.bind(*arg_types, **kwarg_types) - return _get_bound_type_args_from_bound_args(sig, bound) + return { + tv.__name__: tp + for tv, tp in _get_bound_type_args_from_bound_args(sig, bound).items() + } def _get_bound_type_args_from_bound_args( sig: inspect.Signature, bound: inspect.BoundArguments, -) -> dict[str, RtType]: - vars: dict[str, RtType] = {} +) -> dict[typing.TypeVar | typing.TypeVarTuple, RtType]: + vars: dict[typing.TypeVar | typing.TypeVarTuple, RtType] = {} # TODO: duplication, error cases for param in sig.parameters.values(): # Unpack[TypeVarType] for *args @@ -62,7 +67,7 @@ def _get_bound_type_args_from_bound_args( and isinstance(tv, typing.TypeVarTuple) ): tps = bound.arguments.get(param.name, ()) - vars[tv.__name__] = tuple[tps] # type: ignore[valid-type] + vars[tv] = tuple[tps] # type: ignore[valid-type] # Unpack[T] for **kwargs elif ( param.kind == inspect.Parameter.VAR_KEYWORD @@ -77,7 +82,7 @@ def _get_bound_type_args_from_bound_args( and typing_extensions.is_typeddict(tv.__bound__) ): tp = typing.TypedDict(f"**{param.name}", bound.kwargs) # type: ignore[misc, operator] - vars[tv.__name__] = tp + vars[tv] = tp # trivial type[T] bindings elif ( _typing_inspect.is_generic_alias(param.annotation) @@ -88,7 +93,7 @@ def _get_bound_type_args_from_bound_args( and _typing_inspect.is_generic_alias(arg) and arg.__origin__ is type ): - vars[tv.__name__] = arg.__args__[0] + vars[tv] = arg.__args__[0] # trivial T bindings elif isinstance( param.annotation, typing.TypeVar @@ -106,15 +111,15 @@ def _update_bound_typevar( param_name: str, tv: Any, param_value: Any, - vars: dict[str, RtType], + vars: dict[typing.TypeVar | typing.TypeVarTuple, RtType], ) -> None: if isinstance(tv, typing.TypeVar): - if tv.__name__ not in vars: - vars[tv.__name__] = param_value - elif vars[tv.__name__] != param_value: + if tv not in vars: + vars[tv] = param_value + elif vars[tv] != param_value: raise ValueError( f"Type variable {tv.__name__} " - f"is already bound to {vars[tv.__name__].__name__}, " + f"is already bound to {vars[tv].__name__}, " f"but got {param_value.__name__}" ) elif _typing_inspect.is_generic_alias(tv): @@ -209,19 +214,6 @@ def eval_type_call( bound = sig.bind(*arg_types, **kwarg_types) vars = _get_bound_type_args_from_bound_args(sig, bound) - res = _substitute_type_vars(sig.return_annotation, vars) + res = substitute(sig.return_annotation, vars) return res - - -def _substitute_type_vars( - obj: typing.Any, vars: dict[str, RtType] -) -> typing.Any: - """Recursively substitute type variables into a type expression.""" - if isinstance(obj, typing.TypeVar): - return vars[obj.__name__] - elif _typing_inspect.is_generic_alias(obj): - args = tuple(_substitute_type_vars(v, vars) for v in obj.__args__) - return obj.__origin__[args] # type: ignore[index] - else: - return obj From 15c30d3576933de73d60815643d8c6dd02d7c54a Mon Sep 17 00:00:00 2001 From: dnwpark Date: Mon, 19 Jan 2026 12:19:24 -0800 Subject: [PATCH 6/6] Move Callable functionality to existing eval_call_with_types. --- tests/test_qblike.py | 24 +++++++++- tests/test_type_eval.py | 70 +++++++++++++-------------- typemap/type_eval/__init__.py | 4 +- typemap/type_eval/_eval_call.py | 83 +++++++++++++++------------------ 4 files changed, 99 insertions(+), 82 deletions(-) diff --git a/tests/test_qblike.py b/tests/test_qblike.py index 13b9ffb..870d721 100644 --- a/tests/test_qblike.py +++ b/tests/test_qblike.py @@ -2,7 +2,11 @@ from typing import Literal, Unpack -from typemap.type_eval import eval_call, eval_typing +from typemap.type_eval import ( + eval_call, + eval_call_with_types, + eval_typing, +) from typemap.typing import ( BaseTypedDict, NewProtocol, @@ -124,3 +128,21 @@ class select[...]: class PropsOnly[tests.test_qblike.Tgt]: name: tests.test_qblike.Property[str] """) + + +def test_qblike_4(): + t = eval_call_with_types( + select, + A, + x=bool, + w=bool, + z=bool, + ) + fmt = format_helper.format_class(t) + + assert fmt == textwrap.dedent("""\ + class select[...]: + x: tests.test_qblike.Property[int] + w: tests.test_qblike.Property[list[str]] + z: tests.test_qblike.Link[tests.test_qblike.PropsOnly[tests.test_qblike.Tgt]] + """) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 435cb90..ce9d1e6 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -17,7 +17,7 @@ import pytest -from typemap.type_eval import eval_typing, eval_type_call +from typemap.type_eval import eval_call_with_types, eval_typing from typemap.typing import ( Attrs, FromUnion, @@ -1029,17 +1029,17 @@ def test_type_eval_annotated_04(): def test_type_call_callable_01(): - res = eval_type_call(Callable[[], int]) + res = eval_call_with_types(Callable[[], int]) assert res is int def test_type_call_callable_02(): - res = eval_type_call(Callable[[Param[Literal["x"], int]], int], int) + res = eval_call_with_types(Callable[[Param[Literal["x"], int]], int], int) assert res is int def test_type_call_callable_03(): - res = eval_type_call( + res = eval_call_with_types( Callable[[Param[Literal["x"], int, Literal["keyword"]]], int], x=int ) assert res is int @@ -1048,21 +1048,21 @@ def test_type_call_callable_03(): def test_type_call_callable_04(): class C: ... - res = eval_type_call(Callable[[Param[Literal["self"], Self]], int], C) + res = eval_call_with_types(Callable[[Param[Literal["self"], Self]], int], C) assert res is int def test_type_call_callable_05(): class C: ... - res = eval_type_call(Callable[[Param[Literal["self"], Self]], C], C) + res = eval_call_with_types(Callable[[Param[Literal["self"], Self]], C], C) assert res is C def test_type_call_callable_06(): class C: ... - res = eval_type_call( + res = eval_call_with_types( Callable[[Param[Literal["self"], Self], Param[Literal["x"], int]], int], C, int, @@ -1073,7 +1073,7 @@ class C: ... def test_type_call_callable_07(): class C: ... - res = eval_type_call( + res = eval_call_with_types( Callable[ [ Param[Literal["self"], Self], @@ -1089,13 +1089,13 @@ class C: ... def test_type_call_callable_08(): T = TypeVar("T") - res = eval_type_call(Callable[[Param[Literal["x"], T]], str], int) + res = eval_call_with_types(Callable[[Param[Literal["x"], T]], str], int) assert res is str def test_type_call_callable_09(): T = TypeVar("T") - res = eval_type_call(Callable[[Param[Literal["x"], T]], T], int) + res = eval_call_with_types(Callable[[Param[Literal["x"], T]], T], int) assert res is int @@ -1104,7 +1104,7 @@ def test_type_call_callable_10(): class C(Generic[T]): ... - res = eval_type_call(Callable[[Param[Literal["x"], C[T]]], T], C[int]) + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], C[int]) assert res is int @@ -1117,30 +1117,30 @@ class D(C[int]): ... class E(D): ... - res = eval_type_call(Callable[[Param[Literal["x"], C[T]]], T], D) + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], D) assert res is int - res = eval_type_call(Callable[[Param[Literal["x"], C[T]]], T], E) + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], E) assert res is int def test_type_call_local_function_01(): def func(x: int) -> int: ... - res = eval_type_call(func, int) + res = eval_call_with_types(func, int) assert res is int def test_type_call_local_function_02(): def func(*, x: int) -> int: ... - res = eval_type_call(func, x=int) + res = eval_call_with_types(func, x=int) assert res is int def test_type_call_local_function_03(): def func[T](x: T) -> T: ... - res = eval_type_call(func, int) + res = eval_call_with_types(func, int) assert res is int @@ -1149,7 +1149,7 @@ class C: ... def func(x: C) -> C: ... - res = eval_type_call(func, C) + res = eval_call_with_types(func, C) assert res is C @@ -1158,7 +1158,7 @@ class C: ... def func[T](x: T) -> T: ... - res = eval_type_call(func, C) + res = eval_call_with_types(func, C) assert res is C @@ -1169,7 +1169,7 @@ class C(Generic[T]): ... def func[U](x: C[U]) -> C[U]: ... - res = eval_type_call(func, C[int]) + res = eval_call_with_types(func, C[int]) assert res == C[int] @@ -1184,9 +1184,9 @@ class E(D): ... def func[U](x: C[U]) -> U: ... - res = eval_type_call(func, D) + res = eval_call_with_types(func, D) assert res is int - res = eval_type_call(func, E) + res = eval_call_with_types(func, E) assert res is int @@ -1201,7 +1201,7 @@ class F(D, E): ... def func[U](x: C[U]) -> U: ... - res = eval_type_call(func, F) + res = eval_call_with_types(func, F) assert res is int @@ -1210,7 +1210,7 @@ class C[T, U]: ... def func[V](x: C[int, V]) -> V: ... - res = eval_type_call(func, C[int, str]) + res = eval_call_with_types(func, C[int, str]) assert res is str @@ -1220,7 +1220,7 @@ def test_type_call_bind_error_01(): with pytest.raises( ValueError, match="Type variable T is already bound to int, but got str" ): - eval_type_call( + eval_call_with_types( Callable[[Param[Literal["x"], T], Param[Literal["y"], T]], T], int, str, @@ -1233,7 +1233,7 @@ def func[T](x: T, y: T) -> T: ... with pytest.raises( ValueError, match="Type variable T is already bound to int, but got str" ): - eval_type_call(func, int, str) + eval_call_with_types(func, int, str) def test_type_call_bind_error_03(): @@ -1244,7 +1244,7 @@ class C(Generic[T]): ... with pytest.raises( ValueError, match="Type variable T is already bound to int, but got str" ): - eval_type_call( + eval_call_with_types( Callable[[Param[Literal["x"], C[T]], Param[Literal["y"], C[T]]], T], C[int], C[str], @@ -1259,7 +1259,7 @@ def func[T](x: C[T], y: C[T]) -> T: ... with pytest.raises( ValueError, match="Type variable T is already bound to int, but got str" ): - eval_type_call(func, C[int], C[str]) + eval_call_with_types(func, C[int], C[str]) def test_type_call_bind_error_05(): @@ -1270,7 +1270,7 @@ class D[T]: ... def func[T](x: C[T]) -> T: ... with pytest.raises(ValueError, match="Argument type mismatch for x"): - eval_type_call(func, D[int]) + eval_call_with_types(func, D[int]) type GetCallableMember[T, N: str] = GetArg[ @@ -1294,7 +1294,7 @@ def test_type_call_member_01(): class C: def invoke(self, x: int) -> int: ... - res = eval_type_call(GetCallableMember[C, Literal["invoke"]], C, int) + res = eval_call_with_types(GetCallableMember[C, Literal["invoke"]], C, int) assert res is int @@ -1302,7 +1302,7 @@ def test_type_call_member_02(): class C: def invoke[T](self, x: T) -> T: ... - res = eval_type_call(GetCallableMember[C, Literal["invoke"]], C, int) + res = eval_call_with_types(GetCallableMember[C, Literal["invoke"]], C, int) assert res is int @@ -1310,7 +1310,7 @@ def test_type_call_member_03(): class C[T]: def invoke(self, x: str) -> str: ... - res = eval_type_call( + res = eval_call_with_types( GetCallableMember[C[int], Literal["invoke"]], C[int], str ) assert res is str @@ -1320,7 +1320,7 @@ def test_type_call_member_04(): class C[T]: def invoke(self, x: T) -> T: ... - res = eval_type_call( + res = eval_call_with_types( GetCallableMember[C[int], Literal["invoke"]], C[int], int ) assert res is int @@ -1330,7 +1330,9 @@ def test_type_call_member_05(): class C[T]: def invoke(self) -> C[T]: ... - res = eval_type_call(GetCallableMember[C[int], Literal["invoke"]], C[int]) + res = eval_call_with_types( + GetCallableMember[C[int], Literal["invoke"]], C[int] + ) assert res == C[int] @@ -1338,7 +1340,7 @@ def test_type_call_member_06(): class C[T]: def invoke[U](self, x: U) -> C[U]: ... - res = eval_type_call( + res = eval_call_with_types( GetCallableMember[C[int], Literal["invoke"]], C[int], str ) assert res == C[str] diff --git a/typemap/type_eval/__init__.py b/typemap/type_eval/__init__.py index c8dc962..3d2a217 100644 --- a/typemap/type_eval/__init__.py +++ b/typemap/type_eval/__init__.py @@ -7,7 +7,7 @@ from ._apply_generic import flatten_class # XXX: this needs to go second due to nasty circularity -- try to fix that!! -from ._eval_call import eval_call, eval_type_call +from ._eval_call import eval_call, eval_call_with_types from ._subtype import issubtype from ._subsim import issubsimilar @@ -19,7 +19,7 @@ "eval_typing", "register_evaluator", "eval_call", - "eval_type_call", + "eval_call_with_types", "flatten_class", "issubtype", "issubsimilar", diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index 6f2b334..8246b39 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -12,7 +12,7 @@ from . import _eval_typing from . import _typing_inspect from ._eval_operators import _callable_type_to_signature -from ._apply_generic import substitute +from ._apply_generic import substitute, _get_closure_types RtType = Any @@ -31,7 +31,7 @@ def _type(t): def eval_call(func: types.FunctionType, /, *args: Any, **kwargs: Any) -> RtType: arg_types = tuple(_type(t) for t in args) kwarg_types = {k: _type(t) for k, t in kwargs.items()} - return eval_call_with_types(func, arg_types, kwarg_types) + return eval_call_with_types(func, *arg_types, **kwarg_types) def _get_bound_type_args( @@ -138,21 +138,40 @@ def _update_bound_typevar( def eval_call_with_types( - func: types.FunctionType, - arg_types: tuple[RtType, ...], - kwarg_types: dict[str, RtType], + func: types.FunctionType | typing.Callable, + *arg_types: tuple[RtType, ...], + **kwarg_types: dict[str, RtType], ) -> RtType: - vars: dict[str, Any] = {} - params = func.__type_params__ - vars = _get_bound_type_args(func, arg_types, kwarg_types) - for p in params: - if p.__name__ not in vars: - vars[p.__name__] = p + if isinstance(func, types.FunctionType): + vars: dict[str, Any] = _get_bound_type_args( + func, arg_types, kwarg_types + ) + for p in func.__type_params__: + if p.__name__ not in vars: + vars[p.__name__] = p - return eval_call_with_type_vars(func, vars) + return eval_func_with_type_vars(func, vars) + + else: + from typemap.typing import GenericCallable + resolved_callable = _eval_typing.eval_typing(func) -def eval_call_with_type_vars( + if ( + _typing_inspect.is_generic_alias(resolved_callable) + and resolved_callable.__origin__ is GenericCallable + ): + _, resolved_callable = typing.get_args(resolved_callable) + + sig = _callable_type_to_signature(resolved_callable) + bound = sig.bind(*arg_types, **kwarg_types) + bound_args = _get_bound_type_args_from_bound_args(sig, bound) + res = substitute(sig.return_annotation, bound_args) + + return res + + +def eval_func_with_type_vars( func: types.FunctionType, vars: dict[str, RtType] ) -> RtType: with _eval_typing._ensure_context() as ctx: @@ -165,12 +184,17 @@ def _eval_call_with_type_vars( ctx: _eval_typing.EvalContext, ) -> RtType: try: - af = func.__annotate__ + af = typing.cast(types.FunctionType, func.__annotate__) except AttributeError: raise ValueError("func has no __annotate__ attribute") if not af: raise ValueError("func has no __annotate__ attribute") + closure_types = _get_closure_types(af) + for name, value in closure_types.items(): + if name not in vars: + vars[name] = value + af_args = tuple( types.CellType(vars[name]) for name in af.__code__.co_freevars ) @@ -186,34 +210,3 @@ def _eval_call_with_type_vars( return _eval_typing.eval_typing(rr["return"]) finally: ctx.current_generic_alias = old_obj - - -def eval_type_call( - callable: typing.Any, - *arg_types: Any, - **kwarg_types: Any, -) -> RtType: - from typemap.typing import GenericCallable - - arg_types = tuple(_eval_typing.eval_typing(t) for t in arg_types) - kwarg_types = { - k: _eval_typing.eval_typing(t) for k, t in kwarg_types.items() - } - if isinstance(callable, types.FunctionType): - sig = inspect.signature(callable) - else: - resolved_callable = _eval_typing.eval_typing(callable) - - if ( - _typing_inspect.is_generic_alias(resolved_callable) - and resolved_callable.__origin__ is GenericCallable - ): - _, resolved_callable = typing.get_args(resolved_callable) - - sig = _callable_type_to_signature(resolved_callable) - - bound = sig.bind(*arg_types, **kwarg_types) - vars = _get_bound_type_args_from_bound_args(sig, bound) - res = substitute(sig.return_annotation, vars) - - return res