diff --git a/tests/test_qblike.py b/tests/test_qblike.py index 13b9ffb..870d721 100644 --- a/tests/test_qblike.py +++ b/tests/test_qblike.py @@ -2,7 +2,11 @@ from typing import Literal, Unpack -from typemap.type_eval import eval_call, eval_typing +from typemap.type_eval import ( + eval_call, + eval_call_with_types, + eval_typing, +) from typemap.typing import ( BaseTypedDict, NewProtocol, @@ -124,3 +128,21 @@ class select[...]: class PropsOnly[tests.test_qblike.Tgt]: name: tests.test_qblike.Property[str] """) + + +def test_qblike_4(): + t = eval_call_with_types( + select, + A, + x=bool, + w=bool, + z=bool, + ) + fmt = format_helper.format_class(t) + + assert fmt == textwrap.dedent("""\ + class select[...]: + x: tests.test_qblike.Property[int] + w: tests.test_qblike.Property[list[str]] + z: tests.test_qblike.Link[tests.test_qblike.PropsOnly[tests.test_qblike.Tgt]] + """) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 0140e27..ce9d1e6 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -17,10 +17,11 @@ import pytest -from typemap.type_eval import eval_typing +from typemap.type_eval import eval_call_with_types, eval_typing from typemap.typing import ( Attrs, FromUnion, + GenericCallable, GetArg, GetArgs, GetAttr, @@ -1025,3 +1026,321 @@ def test_type_eval_annotated_03(): def test_type_eval_annotated_04(): res = eval_typing(GetAnnotations[GetAttr[AnnoTest, Literal["b"]]]) assert res == Literal["blah"] + + +def test_type_call_callable_01(): + res = eval_call_with_types(Callable[[], int]) + assert res is int + + +def test_type_call_callable_02(): + res = eval_call_with_types(Callable[[Param[Literal["x"], int]], int], int) + assert res is int + + +def test_type_call_callable_03(): + res = eval_call_with_types( + Callable[[Param[Literal["x"], int, Literal["keyword"]]], int], x=int + ) + assert res is int + + +def test_type_call_callable_04(): + class C: ... + + res = eval_call_with_types(Callable[[Param[Literal["self"], Self]], int], C) + assert res is int + + +def test_type_call_callable_05(): + class C: ... + + res = eval_call_with_types(Callable[[Param[Literal["self"], Self]], C], C) + assert res is C + + +def test_type_call_callable_06(): + class C: ... + + res = eval_call_with_types( + Callable[[Param[Literal["self"], Self], Param[Literal["x"], int]], int], + C, + int, + ) + assert res is int + + +def test_type_call_callable_07(): + class C: ... + + res = eval_call_with_types( + Callable[ + [ + Param[Literal["self"], Self], + Param[Literal["x"], int, Literal["keyword"]], + ], + int, + ], + C, + x=int, + ) + assert res is int + + +def test_type_call_callable_08(): + T = TypeVar("T") + res = eval_call_with_types(Callable[[Param[Literal["x"], T]], str], int) + assert res is str + + +def test_type_call_callable_09(): + T = TypeVar("T") + res = eval_call_with_types(Callable[[Param[Literal["x"], T]], T], int) + assert res is int + + +def test_type_call_callable_10(): + T = TypeVar("T") + + class C(Generic[T]): ... + + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], C[int]) + assert res is int + + +def test_type_call_callable_11(): + T = TypeVar("T") + + class C(Generic[T]): ... + + class D(C[int]): ... + + class E(D): ... + + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], D) + assert res is int + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], E) + assert res is int + + +def test_type_call_local_function_01(): + def func(x: int) -> int: ... + + res = eval_call_with_types(func, int) + assert res is int + + +def test_type_call_local_function_02(): + def func(*, x: int) -> int: ... + + res = eval_call_with_types(func, x=int) + assert res is int + + +def test_type_call_local_function_03(): + def func[T](x: T) -> T: ... + + res = eval_call_with_types(func, int) + assert res is int + + +def test_type_call_local_function_04(): + class C: ... + + def func(x: C) -> C: ... + + res = eval_call_with_types(func, C) + assert res is C + + +def test_type_call_local_function_05(): + class C: ... + + def func[T](x: T) -> T: ... + + res = eval_call_with_types(func, C) + assert res is C + + +def test_type_call_local_function_06(): + T = TypeVar("T") + + class C(Generic[T]): ... + + def func[U](x: C[U]) -> C[U]: ... + + res = eval_call_with_types(func, C[int]) + assert res == C[int] + + +def test_type_call_local_function_07(): + T = TypeVar("T") + + class C(Generic[T]): ... + + class D(C[int]): ... + + class E(D): ... + + def func[U](x: C[U]) -> U: ... + + res = eval_call_with_types(func, D) + assert res is int + res = eval_call_with_types(func, E) + assert res is int + + +def test_type_call_local_function_08(): + class C[T]: ... + + class D(C[int]): ... + + class E(C[str]): ... + + class F(D, E): ... + + def func[U](x: C[U]) -> U: ... + + res = eval_call_with_types(func, F) + assert res is int + + +def test_type_call_local_function_09(): + class C[T, U]: ... + + def func[V](x: C[int, V]) -> V: ... + + res = eval_call_with_types(func, C[int, str]) + assert res is str + + +def test_type_call_bind_error_01(): + T = TypeVar("T") + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_call_with_types( + Callable[[Param[Literal["x"], T], Param[Literal["y"], T]], T], + int, + str, + ) + + +def test_type_call_bind_error_02(): + def func[T](x: T, y: T) -> T: ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_call_with_types(func, int, str) + + +def test_type_call_bind_error_03(): + T = TypeVar("T") + + class C(Generic[T]): ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_call_with_types( + Callable[[Param[Literal["x"], C[T]], Param[Literal["y"], C[T]]], T], + C[int], + C[str], + ) + + +def test_type_call_bind_error_04(): + class C[T]: ... + + def func[T](x: C[T], y: C[T]) -> T: ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_call_with_types(func, C[int], C[str]) + + +def test_type_call_bind_error_05(): + class C[T]: ... + + class D[T]: ... + + def func[T](x: C[T]) -> T: ... + + with pytest.raises(ValueError, match="Argument type mismatch for x"): + eval_call_with_types(func, D[int]) + + +type GetCallableMember[T, N: str] = GetArg[ + tuple[ + *[ + GetType[m] + for m in Iter[Members[T]] + if ( + IsSub[GetType[m], Callable] + or IsSub[GetType[m], GenericCallable] + ) + and IsSub[GetName[m], N] + ] + ], + tuple, + 0, +] + + +def test_type_call_member_01(): + class C: + def invoke(self, x: int) -> int: ... + + res = eval_call_with_types(GetCallableMember[C, Literal["invoke"]], C, int) + assert res is int + + +def test_type_call_member_02(): + class C: + def invoke[T](self, x: T) -> T: ... + + res = eval_call_with_types(GetCallableMember[C, Literal["invoke"]], C, int) + assert res is int + + +def test_type_call_member_03(): + class C[T]: + def invoke(self, x: str) -> str: ... + + res = eval_call_with_types( + GetCallableMember[C[int], Literal["invoke"]], C[int], str + ) + assert res is str + + +def test_type_call_member_04(): + class C[T]: + def invoke(self, x: T) -> T: ... + + res = eval_call_with_types( + GetCallableMember[C[int], Literal["invoke"]], C[int], int + ) + assert res is int + + +def test_type_call_member_05(): + class C[T]: + def invoke(self) -> C[T]: ... + + res = eval_call_with_types( + GetCallableMember[C[int], Literal["invoke"]], C[int] + ) + assert res == C[int] + + +def test_type_call_member_06(): + class C[T]: + def invoke[U](self, x: U) -> C[U]: ... + + res = eval_call_with_types( + GetCallableMember[C[int], Literal["invoke"]], C[int], str + ) + assert res == C[str] diff --git a/typemap/type_eval/__init__.py b/typemap/type_eval/__init__.py index b63c49d..3d2a217 100644 --- a/typemap/type_eval/__init__.py +++ b/typemap/type_eval/__init__.py @@ -7,7 +7,7 @@ from ._apply_generic import flatten_class # XXX: this needs to go second due to nasty circularity -- try to fix that!! -from ._eval_call import eval_call +from ._eval_call import eval_call, eval_call_with_types from ._subtype import issubtype from ._subsim import issubsimilar @@ -19,6 +19,7 @@ "eval_typing", "register_evaluator", "eval_call", + "eval_call_with_types", "flatten_class", "issubtype", "issubsimilar", diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index d71007a..8246b39 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -8,8 +8,11 @@ from typing import Any +from . import _eval_operators from . import _eval_typing from . import _typing_inspect +from ._eval_operators import _callable_type_to_signature +from ._apply_generic import substitute, _get_closure_types RtType = Any @@ -28,7 +31,7 @@ def _type(t): def eval_call(func: types.FunctionType, /, *args: Any, **kwargs: Any) -> RtType: arg_types = tuple(_type(t) for t in args) kwarg_types = {k: _type(t) for k, t in kwargs.items()} - return eval_call_with_types(func, arg_types, kwarg_types) + return eval_call_with_types(func, *arg_types, **kwarg_types) def _get_bound_type_args( @@ -39,7 +42,17 @@ def _get_bound_type_args( sig = inspect.signature(func) bound = sig.bind(*arg_types, **kwarg_types) - vars: dict[str, RtType] = {} + return { + tv.__name__: tp + for tv, tp in _get_bound_type_args_from_bound_args(sig, bound).items() + } + + +def _get_bound_type_args_from_bound_args( + sig: inspect.Signature, + bound: inspect.BoundArguments, +) -> dict[typing.TypeVar | typing.TypeVarTuple, RtType]: + vars: dict[typing.TypeVar | typing.TypeVarTuple, RtType] = {} # TODO: duplication, error cases for param in sig.parameters.values(): # Unpack[TypeVarType] for *args @@ -54,7 +67,7 @@ def _get_bound_type_args( and isinstance(tv, typing.TypeVarTuple) ): tps = bound.arguments.get(param.name, ()) - vars[tv.__name__] = tuple[tps] # type: ignore[valid-type] + vars[tv] = tuple[tps] # type: ignore[valid-type] # Unpack[T] for **kwargs elif ( param.kind == inspect.Parameter.VAR_KEYWORD @@ -69,7 +82,7 @@ def _get_bound_type_args( and typing_extensions.is_typeddict(tv.__bound__) ): tp = typing.TypedDict(f"**{param.name}", bound.kwargs) # type: ignore[misc, operator] - vars[tv.__name__] = tp + vars[tv] = tp # trivial type[T] bindings elif ( _typing_inspect.is_generic_alias(param.annotation) @@ -80,35 +93,85 @@ def _get_bound_type_args( and _typing_inspect.is_generic_alias(arg) and arg.__origin__ is type ): - vars[tv.__name__] = arg.__args__[0] + vars[tv] = arg.__args__[0] # trivial T bindings - elif ( - isinstance(param.annotation, typing.TypeVar) - and param.name in bound.arguments - ): + elif isinstance( + param.annotation, typing.TypeVar + ) or _typing_inspect.is_generic_alias(param.annotation): param_value = bound.arguments[param.name] - vars[param.annotation.__name__] = param_value + _update_bound_typevar( + param.name, param.annotation, param_value, vars + ) # TODO: simple bindings to other variables too return vars +def _update_bound_typevar( + param_name: str, + tv: Any, + param_value: Any, + vars: dict[typing.TypeVar | typing.TypeVarTuple, RtType], +) -> None: + if isinstance(tv, typing.TypeVar): + if tv not in vars: + vars[tv] = param_value + elif vars[tv] != param_value: + raise ValueError( + f"Type variable {tv.__name__} " + f"is already bound to {vars[tv].__name__}, " + f"but got {param_value.__name__}" + ) + elif _typing_inspect.is_generic_alias(tv): + tv_args = tv.__args__ + + with _eval_typing._ensure_context() as ctx: + param_args = _eval_operators._get_args( + param_value, tv.__origin__, ctx + ) + + if param_args is None: + raise ValueError(f"Argument type mismatch for {param_name}") + + for p_arg, c_arg in zip(tv_args, param_args, strict=True): + _update_bound_typevar(param_name, p_arg, c_arg, vars) + + def eval_call_with_types( - func: types.FunctionType, - arg_types: tuple[RtType, ...], - kwarg_types: dict[str, RtType], + func: types.FunctionType | typing.Callable, + *arg_types: tuple[RtType, ...], + **kwarg_types: dict[str, RtType], ) -> RtType: - vars: dict[str, Any] = {} - params = func.__type_params__ - vars = _get_bound_type_args(func, arg_types, kwarg_types) - for p in params: - if p.__name__ not in vars: - vars[p.__name__] = p + if isinstance(func, types.FunctionType): + vars: dict[str, Any] = _get_bound_type_args( + func, arg_types, kwarg_types + ) + for p in func.__type_params__: + if p.__name__ not in vars: + vars[p.__name__] = p - return eval_call_with_type_vars(func, vars) + return eval_func_with_type_vars(func, vars) + else: + from typemap.typing import GenericCallable + + resolved_callable = _eval_typing.eval_typing(func) + + if ( + _typing_inspect.is_generic_alias(resolved_callable) + and resolved_callable.__origin__ is GenericCallable + ): + _, resolved_callable = typing.get_args(resolved_callable) -def eval_call_with_type_vars( + sig = _callable_type_to_signature(resolved_callable) + bound = sig.bind(*arg_types, **kwarg_types) + bound_args = _get_bound_type_args_from_bound_args(sig, bound) + res = substitute(sig.return_annotation, bound_args) + + return res + + +def eval_func_with_type_vars( func: types.FunctionType, vars: dict[str, RtType] ) -> RtType: with _eval_typing._ensure_context() as ctx: @@ -121,12 +184,17 @@ def _eval_call_with_type_vars( ctx: _eval_typing.EvalContext, ) -> RtType: try: - af = func.__annotate__ + af = typing.cast(types.FunctionType, func.__annotate__) except AttributeError: raise ValueError("func has no __annotate__ attribute") if not af: raise ValueError("func has no __annotate__ attribute") + closure_types = _get_closure_types(af) + for name, value in closure_types.items(): + if name not in vars: + vars[name] = value + af_args = tuple( types.CellType(vars[name]) for name in af.__code__.co_freevars )