diff --git a/tests/test_call.py b/tests/test_call.py index e983aa7..8f7ad99 100644 --- a/tests/test_call.py +++ b/tests/test_call.py @@ -21,7 +21,7 @@ def func[*T, K: BaseTypedDict]( ) -> NewProtocol[*[Member[GetName[c], int] for c in Iter[Attrs[K]]]]: ... -def test_call_1(): +def test_eval_call_01(): ret = eval_call(func, a=1, b=2, c="aaa") fmt = format_helper.format_class(ret) @@ -40,7 +40,7 @@ def func_trivial[*T, K: BaseTypedDict]( return kwargs -def test_call_2(): +def test_eval_call_02(): ret = eval_call(func_trivial, a=1, b=2, c="aaa") fmt = format_helper.format_class(ret) @@ -63,7 +63,7 @@ def wrapped[T](value: T) -> Wrapped[T]: return Wrapped[T](value) -def test_call_3(): +def test_eval_call_03(): ret = eval_call(wrapped, 1) fmt = format_helper.format_class(ret) diff --git a/tests/test_eval_call_with_types.py b/tests/test_eval_call_with_types.py new file mode 100644 index 0000000..e014084 --- /dev/null +++ b/tests/test_eval_call_with_types.py @@ -0,0 +1,899 @@ +import pytest +import textwrap + +from types import GenericAlias +from typing import Callable, Generic, Literal, Self, TypeVar, Unpack +from typing_extensions import TypeAliasType + +from typemap.type_eval import eval_call_with_types +from typemap.typing import ( + Attrs, + BaseTypedDict, + NewProtocol, + Member, + GetAttr, + Iter, + Param, +) + +from typing import _ProtocolMeta + +from . import format_helper + + +class Wrapper[T]: + value: T + + +class WrappedInt(Wrapper[int]): + pass + + +class WrappedStr(Wrapper[str]): + pass + + +class WrappedIntStr(WrappedInt, WrappedStr): + pass + + +class Pair[T, U]: + first: T + second: U + + +def func_positional(x: int) -> int: ... +def func_named(*, x: int) -> int: ... +def func_generic_to_value[T](x: T) -> T: ... +def func_generic_to_wrapped[T](x: T) -> Wrapper[T]: ... +def func_generic_from_wrapped[T](x: Wrapper[T]) -> T: ... +def func_generic_partial[T](x: Pair[int, T]) -> T: ... + + +def func_unpack_tuple[*T]( + *args: Unpack[T], +) -> T: ... +def func_unpack_dict[K: BaseTypedDict]( + **kwargs: Unpack[K], +) -> K: ... + + +def test_eval_call_with_types_module_function_01(): + ret = eval_call_with_types(func_positional, int) + assert ret is int + + +def test_eval_call_with_types_module_function_02(): + ret = eval_call_with_types(func_named, x=int) + assert ret is int + + +def test_eval_call_with_types_module_function_03(): + ret = eval_call_with_types(func_generic_to_value, int) + assert ret is int + + +def test_eval_call_with_types_module_function_04(): + ret = eval_call_with_types(func_generic_to_wrapped, int) + assert ret is Wrapper[int] + + +def test_eval_call_with_types_module_function_05(): + ret = eval_call_with_types(func_generic_from_wrapped, Wrapper[int]) + assert ret is int + ret = eval_call_with_types(func_generic_from_wrapped, WrappedInt) + assert ret is int + ret = eval_call_with_types(func_generic_from_wrapped, WrappedStr) + assert ret is str + ret = eval_call_with_types(func_generic_from_wrapped, WrappedIntStr) + assert ret is int + + +def test_eval_call_with_types_module_function_06(): + ret = eval_call_with_types(func_generic_partial, Pair[int, int]) + assert ret is int + ret = eval_call_with_types(func_generic_partial, Pair[int, str]) + assert ret is str + + +def test_eval_call_with_types_module_function_07(): + ret = eval_call_with_types(func_unpack_tuple, int, float, str) + assert ret == tuple[int, float, str] + + +def test_eval_call_with_types_module_function_08(): + ret = eval_call_with_types(func_unpack_dict, a=int, b=float, c=str) + fmt = format_helper.format_class(ret) + + assert fmt == textwrap.dedent("""\ + class **kwargs: + a: int + b: float + c: str + """) + + +def test_eval_call_with_types_local_function_01(): + def func(x: int) -> int: ... + + res = eval_call_with_types(func, int) + assert res is int + + +def test_eval_call_with_types_local_function_02(): + def func(*, x: int) -> int: ... + + res = eval_call_with_types(func, x=int) + assert res is int + + +def test_eval_call_with_types_local_function_03(): + def func[T](x: T) -> T: ... + + res = eval_call_with_types(func, int) + assert res is int + + class C: ... + + res = eval_call_with_types(func, C) + assert res is C + + +def test_eval_call_with_types_local_function_04(): + class C[T]: + pass + + def f[T](x: T) -> C[T]: ... + + ret = eval_call_with_types(f, int) + assert ret == C[int] + + +def test_eval_call_with_types_local_function_05(): + T = TypeVar("T") + + class C(Generic[T]): ... + + class D(C[int]): ... + + class E(C[str]): ... + + class F(D, E): ... + + def func[U](x: C[U]) -> U: ... + + res = eval_call_with_types(func, C[int]) + assert res is int + res = eval_call_with_types(func, D) + assert res is int + res = eval_call_with_types(func, E) + assert res is str + res = eval_call_with_types(func, F) + assert res is int + + +def test_eval_call_with_types_local_function_06(): + class C[T, U]: ... + + def func[V](x: C[int, V]) -> V: ... + + res = eval_call_with_types(func, C[int, str]) + assert res is str + + +class ModuleClass: + def member_func(self, x: int) -> str: ... + @classmethod + def class_func(self, x: int) -> str: ... + @staticmethod + def static_func(x: int) -> str: ... + + def generic_member_func[T](self, x: T) -> T: ... + @classmethod + def generic_class_func[T](cls, x: T) -> T: ... + @staticmethod + def generic_static_func[T](x: T) -> T: ... + + +def test_eval_call_with_types_module_class_01(): + ret = eval_call_with_types( + GetAttr[ModuleClass, Literal["member_func"]], ModuleClass, int + ) + assert ret is str + + +def test_eval_call_with_types_module_class_02(): + ret = eval_call_with_types( + GetAttr[ModuleClass, Literal["class_func"]], type(ModuleClass), int + ) + assert ret is str + + +def test_eval_call_with_types_module_class_03(): + ret = eval_call_with_types( + GetAttr[ModuleClass, Literal["static_func"]], int + ) + assert ret is str + + +def test_eval_call_with_types_module_class_04(): + ret = eval_call_with_types( + GetAttr[ModuleClass, Literal["generic_member_func"]], ModuleClass, int + ) + assert ret is int + + +def test_eval_call_with_types_module_class_05(): + ret = eval_call_with_types( + GetAttr[ModuleClass, Literal["generic_class_func"]], + type(ModuleClass), + int, + ) + assert ret is int + + +def test_eval_call_with_types_module_class_06(): + ret = eval_call_with_types( + GetAttr[ModuleClass, Literal["generic_static_func"]], int + ) + assert ret is int + + +class ModuleGeneric[T]: + def member_func(self, x: int) -> str: ... + @classmethod + def class_func(self, x: int) -> str: ... + @staticmethod + def static_func(x: int) -> str: ... + + def specialized_member_func[T](self, x: T) -> T: ... + @classmethod + def specialized_class_func[T](cls, x: T) -> T: ... + @staticmethod + def specialized_static_func[T](x: T) -> T: ... + + def generic_method[U](self, x: T, y: U) -> U: ... + @classmethod + def generic_class_method[U](cls, x: T, y: U) -> U: ... + @staticmethod + def generic_static_method[U](x: T, y: U) -> U: ... + + +def test_eval_call_with_types_module_generic_class_01(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["member_func"]], + ModuleGeneric[float], + int, + ) + assert ret is str + + +def test_eval_call_with_types_module_generic_class_02(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["class_func"]], + ModuleGeneric[float], + int, + ) + assert ret is str + + +def test_eval_call_with_types_module_generic_class_03(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["static_func"]], int + ) + assert ret is str + + +def test_eval_call_with_types_module_generic_class_04(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["specialized_member_func"]], + ModuleGeneric[float], + int, + ) + assert ret is int + + +def test_eval_call_with_types_module_generic_class_05(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["specialized_class_func"]], + ModuleGeneric[float], + int, + ) + assert ret is int + + +def test_eval_call_with_types_module_generic_class_06(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["specialized_static_func"]], int + ) + assert ret is int + + +def test_eval_call_with_types_module_generic_class_07(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["generic_method"]], + ModuleGeneric[float], + float, + int, + ) + assert ret is int + + +def test_eval_call_with_types_module_generic_class_08(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["generic_class_method"]], + ModuleGeneric[float], + float, + int, + ) + assert ret is int + + +def test_eval_call_with_types_module_generic_class_09(): + ret = eval_call_with_types( + GetAttr[ModuleGeneric[float], Literal["generic_static_method"]], + float, + int, + ) + assert ret is int + + +def test_eval_call_with_types_local_class_01(): + class C: + def member_func(self, x: int) -> str: ... + + ret = eval_call_with_types(GetAttr[C, Literal["member_func"]], C, int) + assert ret is str + + +def test_eval_call_with_types_local_class_02(): + class C: + @classmethod + def class_func(cls, x: int) -> str: ... + + ret = eval_call_with_types(GetAttr[C, Literal["class_func"]], type(C), int) + assert ret is str + + +def test_eval_call_with_types_local_class_03(): + class C: + @staticmethod + def static_func(x: int) -> str: ... + + ret = eval_call_with_types(GetAttr[C, Literal["static_func"]], int) + assert ret is str + + +def test_eval_call_with_types_local_class_04(): + class C: + def generic_member_func[T](self, x: T) -> T: ... + + ret = eval_call_with_types( + GetAttr[C, Literal["generic_member_func"]], C, int + ) + assert ret is int + + +def test_eval_call_with_types_local_class_05(): + class C: + @classmethod + def generic_class_func[T](cls, x: T) -> T: ... + + ret = eval_call_with_types( + GetAttr[C, Literal["generic_class_func"]], type(C), int + ) + assert ret is int + + +def test_eval_call_with_types_local_class_06(): + class C: + @staticmethod + def generic_static_func[T](x: T) -> T: ... + + ret = eval_call_with_types(GetAttr[C, Literal["generic_static_func"]], int) + assert ret is int + + +def test_eval_call_with_types_local_generic_class_01(): + class C[T]: + def member_func(self, x: int) -> str: ... + + ret = eval_call_with_types( + GetAttr[C[float], Literal["member_func"]], C[float], int + ) + assert ret is str + + +def test_eval_call_with_types_local_generic_class_02(): + class C[T]: + @classmethod + def class_func(cls, x: int) -> str: ... + + ret = eval_call_with_types( + GetAttr[C[float], Literal["class_func"]], C[float], int + ) + assert ret is str + + +def test_eval_call_with_types_local_generic_class_03(): + class C[T]: + @staticmethod + def static_func(x: int) -> str: ... + + ret = eval_call_with_types(GetAttr[C[float], Literal["static_func"]], int) + assert ret is str + + +def test_eval_call_with_types_local_generic_class_04(): + class C[T]: + def specialized_member_func[T](self, x: T) -> T: ... + + ret = eval_call_with_types( + GetAttr[C[float], Literal["specialized_member_func"]], C[float], int + ) + assert ret is int + + +def test_eval_call_with_types_local_generic_class_05(): + class C[T]: + @classmethod + def specialized_class_func[T](cls, x: T) -> T: ... + + ret = eval_call_with_types( + GetAttr[C[float], Literal["specialized_class_func"]], C[float], int + ) + assert ret is int + + +def test_eval_call_with_types_local_generic_class_06(): + class C[T]: + @staticmethod + def specialized_static_func[T](x: T) -> T: ... + + ret = eval_call_with_types( + GetAttr[C[float], Literal["specialized_static_func"]], int + ) + assert ret is int + + +def test_eval_call_with_types_local_generic_class_07(): + class C[T]: + def generic_method[U](self, x: T, y: U) -> U: ... + + ret = eval_call_with_types( + GetAttr[C[float], Literal["generic_method"]], C[float], float, int + ) + assert ret is int + + +def test_eval_call_with_types_local_generic_class_08(): + class C[T]: + @classmethod + def generic_class_method[U](cls, x: T, y: U) -> U: ... + + ret = eval_call_with_types( + GetAttr[C[float], Literal["generic_class_method"]], C[float], float, int + ) + assert ret is int + + +def test_eval_call_with_types_local_generic_class_09(): + class C[T]: + @staticmethod + def generic_static_method[U](x: float, y: U) -> U: ... + + ret = eval_call_with_types( + GetAttr[C[float], Literal["generic_static_method"]], float, int + ) + assert ret is int + + +class Foo: + a: int + + +class Bar: + a: str + + +U = TypeVar("U") +type WithCopy[T] = NewProtocol[ + *[c for c in Iter[Attrs[T]]], + Member[ + Literal["copy"], + Callable[[Param[Literal["self"], Self]], WithCopy[T]], + Literal["ClassVar"], + ], +] +type WithEq[T] = NewProtocol[ + *[c for c in Iter[Attrs[T]]], + Member[ + Literal["__eq__"], + Callable[ + [Param[Literal["self"], Self], Param[Literal["other"], WithEq[T]]], + bool, + ], + Literal["ClassVar"], + ], +] +type WithContains[T] = NewProtocol[ + *[c for c in Iter[Attrs[T]]], + Member[ + Literal["__contains__"], + Callable[ + [Param[Literal["self"], Self], Param[Literal["item"], U]], bool + ], + Literal["ClassVar"], + ], +] +type WithAdd[T] = NewProtocol[ + *[c for c in Iter[Attrs[T]]], + Member[ + Literal["__add__"], + Callable[ + [Param[Literal["self"], Self], Param[Literal["other"], U]], + WithAdd[U], + ], + Literal["ClassVar"], + ], +] +type WithMax[T] = NewProtocol[ + *[c for c in Iter[Attrs[T]]], + Member[ + Literal["from"], + Callable[ + [Param[Literal["self"], Self], Param[Literal["other"], WithMax[U]]], + U, + ], + Literal["ClassVar"], + ], +] + + +def with_copy[T](value: T) -> WithCopy[T]: ... +def with_eq[T](value: T) -> WithEq[T]: ... +def with_contains[T](value: T) -> WithContains[T]: ... +def with_add[T](value: T) -> WithAdd[T]: ... +def with_max[T](value: T) -> WithMax[T]: ... + + +def test_eval_call_with_types_protocol_01(): + # Member function of a protocol + # Returns same protocol + + cls = eval_call_with_types(with_copy, Foo) + assert type(cls) is _ProtocolMeta + + fmt = format_helper.format_class(cls) + assert fmt == textwrap.dedent("""\ + class WithCopy[tests.test_eval_call_with_types.Foo]: + a: int + def copy(self: Self) -> WithCopy[tests.test_eval_call_with_types.Foo]: ... + """) + + ret = eval_call_with_types(GetAttr[cls, Literal["copy"]], WithCopy[Foo]) + assert ret == WithCopy[Foo] + + # Note: ret here is a generic TypeAliasType + assert isinstance(ret, GenericAlias) + assert isinstance(ret.__origin__, TypeAliasType) + + # Still renders the same as the original protocol + fmt = format_helper.format_class(ret) + assert fmt == textwrap.dedent("""\ + class WithCopy[tests.test_eval_call_with_types.Foo]: + a: int + def copy(self: Self) -> WithCopy[tests.test_eval_call_with_types.Foo]: ... + """) + + # Make sure we can keep calling the member function + ret2 = eval_call_with_types(GetAttr[ret, Literal["copy"]], WithCopy[Foo]) + assert ret2 == ret + + +def test_eval_call_with_types_protocol_02(): + # Member function of a protocol + # Param is the same protocol + # Returns bool + + cls = eval_call_with_types(with_eq, Foo) + fmt = format_helper.format_class(cls) + assert fmt == textwrap.dedent("""\ + class WithEq[tests.test_eval_call_with_types.Foo]: + a: int + def __eq__(self: Self, other: WithEq[tests.test_eval_call_with_types.Foo]) -> bool: ... + """) + + ret = eval_call_with_types( + GetAttr[cls, Literal["__eq__"]], WithEq[Foo], WithEq[Foo] + ) + assert ret is bool + + with pytest.raises(ValueError, match="Argument type mismatch for other"): + eval_call_with_types(GetAttr[cls, Literal["__eq__"]], WithEq[Foo], int) + with pytest.raises(ValueError, match="Argument type mismatch for other"): + eval_call_with_types(GetAttr[cls, Literal["__eq__"]], WithEq[Foo], Foo) + with pytest.raises(ValueError, match="Argument type mismatch for other"): + eval_call_with_types( + GetAttr[cls, Literal["__eq__"]], WithEq[Foo], WithEq[Bar] + ) + with pytest.raises(ValueError, match="Argument type mismatch for other"): + eval_call_with_types( + GetAttr[cls, Literal["__eq__"]], WithEq[Foo], WithAdd[Foo] + ) + + +def test_eval_call_with_types_protocol_03(): + # Member function of a protocol + # Param is a different type + # Returns bool + + cls = eval_call_with_types(with_contains, Foo) + fmt = format_helper.format_class(cls) + assert fmt == textwrap.dedent("""\ + class WithContains[tests.test_eval_call_with_types.Foo]: + a: int + def __contains__(self: Self, item: ~U) -> bool: ... + """) + + ret = eval_call_with_types( + GetAttr[cls, Literal["__contains__"]], + WithContains[Foo], + int, + ) + assert ret is bool + ret = eval_call_with_types( + GetAttr[cls, Literal["__contains__"]], + WithContains[Foo], + str, + ) + assert ret is bool + ret = eval_call_with_types( + GetAttr[cls, Literal["__contains__"]], + WithContains[Foo], + float, + ) + assert ret is bool + + +def test_eval_call_with_types_protocol_04(): + # Member function of a protocol + # Param is a different type + # Returns a protocol based on the param type + + cls = eval_call_with_types(with_add, Foo) + fmt = format_helper.format_class(cls) + assert fmt == textwrap.dedent("""\ + class WithAdd[tests.test_eval_call_with_types.Foo]: + a: int + def __add__(self: Self, other: ~U) -> WithAdd[~U]: ... + """) + ret = eval_call_with_types( + GetAttr[cls, Literal["__add__"]], WithAdd[Foo], Bar + ) + assert ret == WithAdd[Bar] + + # Note: ret here is a generic TypeAliasType + assert isinstance(ret, GenericAlias) + assert isinstance(ret.__origin__, TypeAliasType) + + fmt = format_helper.format_class(ret) + assert fmt == textwrap.dedent("""\ + class WithAdd[tests.test_eval_call_with_types.Bar]: + a: str + def __add__(self: Self, other: ~U) -> WithAdd[~U]: ... + """) + + # Make sure we can keep calling the member function + ret2 = eval_call_with_types( + GetAttr[ret, Literal["__add__"]], WithAdd[Bar], Foo + ) + assert ret2 == WithAdd[Foo] + + +def test_eval_call_with_types_protocol_05(): + cls = eval_call_with_types(with_max, Foo) + fmt = format_helper.format_class(cls) + assert fmt == textwrap.dedent("""\ + class WithMax[tests.test_eval_call_with_types.Foo]: + a: int + def from(self: Self, other: WithMax[~U]) -> ~U: ... + """) + ret = eval_call_with_types( + GetAttr[cls, Literal["from"]], WithMax[Foo], WithMax[Bar] + ) + assert ret is Bar + + with pytest.raises(ValueError, match="Argument type mismatch for other"): + eval_call_with_types(GetAttr[cls, Literal["from"]], WithMax[Foo], int) + with pytest.raises(ValueError, match="Argument type mismatch for other"): + eval_call_with_types( + GetAttr[cls, Literal["from"]], WithMax[Foo], WithEq[Foo] + ) + + +def test_eval_call_with_types_callable_01(): + res = eval_call_with_types(Callable[[], int]) + assert res is int + + +def test_eval_call_with_types_callable_02(): + res = eval_call_with_types(Callable[[Param[Literal["x"], int]], int], int) + assert res is int + + +def test_eval_call_with_types_callable_03(): + res = eval_call_with_types( + Callable[[Param[Literal["x"], int, Literal["keyword"]]], int], x=int + ) + assert res is int + + +def test_eval_call_with_types_callable_04(): + class C: ... + + res = eval_call_with_types(Callable[[Param[Literal["self"], Self]], int], C) + assert res is int + + +def test_eval_call_with_types_callable_05(): + class C: ... + + res = eval_call_with_types(Callable[[Param[Literal["self"], Self]], C], C) + assert res is C + + +def test_eval_call_with_types_callable_06(): + class C: ... + + res = eval_call_with_types( + Callable[[Param[Literal["self"], Self], Param[Literal["x"], int]], int], + C, + int, + ) + assert res is int + + +def test_eval_call_with_types_callable_07(): + class C: ... + + res = eval_call_with_types( + Callable[ + [ + Param[Literal["self"], Self], + Param[Literal["x"], int, Literal["keyword"]], + ], + int, + ], + C, + x=int, + ) + assert res is int + + +def test_eval_call_with_types_callable_08(): + T = TypeVar("T") + res = eval_call_with_types(Callable[[Param[Literal["x"], T]], str], int) + assert res is str + + +def test_eval_call_with_types_callable_09(): + T = TypeVar("T") + res = eval_call_with_types(Callable[[Param[Literal["x"], T]], T], int) + assert res is int + + +def test_eval_call_with_types_callable_10(): + T = TypeVar("T") + + class C(Generic[T]): ... + + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], C[int]) + assert res is int + + +def test_eval_call_with_types_callable_11(): + T = TypeVar("T") + + class C(Generic[T]): ... + + class D(C[int]): ... + + class E(D): ... + + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], D) + assert res is int + res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], E) + assert res is int + + +def test_eval_call_with_types_callable_12(): + T = TypeVar("T") + + class C[U]: ... + + ret = eval_call_with_types(Callable[[Param[Literal["x"], T]], C[T]], int) + assert ret == C[int] + + +def test_eval_call_with_types_callable_13(): + T = TypeVar("T") + U = TypeVar("U") + + class C(Generic[T, U]): ... + + ret = eval_call_with_types( + Callable[[Param[Literal["x"], C[int, T]]], T], C[int, str] + ) + assert ret is str + + +def test_eval_call_with_types_bind_error_01(): + T = TypeVar("T") + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_call_with_types( + Callable[[Param[Literal["x"], T], Param[Literal["y"], T]], T], + int, + str, + ) + + +def test_eval_call_with_types_bind_error_02(): + def func[T](x: T, y: T) -> T: ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_call_with_types(func, int, str) + + +def test_eval_call_with_types_bind_error_03(): + T = TypeVar("T") + + class C(Generic[T]): ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_call_with_types( + Callable[[Param[Literal["x"], C[T]], Param[Literal["y"], C[T]]], T], + C[int], + C[str], + ) + + +def test_eval_call_with_types_bind_error_04(): + class C[T]: ... + + def func[T](x: C[T], y: C[T]) -> T: ... + + with pytest.raises( + ValueError, match="Type variable T is already bound to int, but got str" + ): + eval_call_with_types(func, C[int], C[str]) + + +def test_eval_call_with_types_bind_error_05(): + class C[T]: ... + + class D[T]: ... + + def func[T](x: C[T]) -> T: ... + + with pytest.raises(ValueError, match="Argument type mismatch for x"): + eval_call_with_types(func, D[int]) + + +def test_eval_call_with_types_bind_error_06(): + T = TypeVar("T") + U = TypeVar("U") + + class C(Generic[T, U]): ... + + with pytest.raises(ValueError, match="Argument type mismatch for x"): + eval_call_with_types( + Callable[[Param[Literal["x"], C[int, T]]], T], C[float, str] + ) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index ce9d1e6..0140e27 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -17,11 +17,10 @@ import pytest -from typemap.type_eval import eval_call_with_types, eval_typing +from typemap.type_eval import eval_typing from typemap.typing import ( Attrs, FromUnion, - GenericCallable, GetArg, GetArgs, GetAttr, @@ -1026,321 +1025,3 @@ def test_type_eval_annotated_03(): def test_type_eval_annotated_04(): res = eval_typing(GetAnnotations[GetAttr[AnnoTest, Literal["b"]]]) assert res == Literal["blah"] - - -def test_type_call_callable_01(): - res = eval_call_with_types(Callable[[], int]) - assert res is int - - -def test_type_call_callable_02(): - res = eval_call_with_types(Callable[[Param[Literal["x"], int]], int], int) - assert res is int - - -def test_type_call_callable_03(): - res = eval_call_with_types( - Callable[[Param[Literal["x"], int, Literal["keyword"]]], int], x=int - ) - assert res is int - - -def test_type_call_callable_04(): - class C: ... - - res = eval_call_with_types(Callable[[Param[Literal["self"], Self]], int], C) - assert res is int - - -def test_type_call_callable_05(): - class C: ... - - res = eval_call_with_types(Callable[[Param[Literal["self"], Self]], C], C) - assert res is C - - -def test_type_call_callable_06(): - class C: ... - - res = eval_call_with_types( - Callable[[Param[Literal["self"], Self], Param[Literal["x"], int]], int], - C, - int, - ) - assert res is int - - -def test_type_call_callable_07(): - class C: ... - - res = eval_call_with_types( - Callable[ - [ - Param[Literal["self"], Self], - Param[Literal["x"], int, Literal["keyword"]], - ], - int, - ], - C, - x=int, - ) - assert res is int - - -def test_type_call_callable_08(): - T = TypeVar("T") - res = eval_call_with_types(Callable[[Param[Literal["x"], T]], str], int) - assert res is str - - -def test_type_call_callable_09(): - T = TypeVar("T") - res = eval_call_with_types(Callable[[Param[Literal["x"], T]], T], int) - assert res is int - - -def test_type_call_callable_10(): - T = TypeVar("T") - - class C(Generic[T]): ... - - res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], C[int]) - assert res is int - - -def test_type_call_callable_11(): - T = TypeVar("T") - - class C(Generic[T]): ... - - class D(C[int]): ... - - class E(D): ... - - res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], D) - assert res is int - res = eval_call_with_types(Callable[[Param[Literal["x"], C[T]]], T], E) - assert res is int - - -def test_type_call_local_function_01(): - def func(x: int) -> int: ... - - res = eval_call_with_types(func, int) - assert res is int - - -def test_type_call_local_function_02(): - def func(*, x: int) -> int: ... - - res = eval_call_with_types(func, x=int) - assert res is int - - -def test_type_call_local_function_03(): - def func[T](x: T) -> T: ... - - res = eval_call_with_types(func, int) - assert res is int - - -def test_type_call_local_function_04(): - class C: ... - - def func(x: C) -> C: ... - - res = eval_call_with_types(func, C) - assert res is C - - -def test_type_call_local_function_05(): - class C: ... - - def func[T](x: T) -> T: ... - - res = eval_call_with_types(func, C) - assert res is C - - -def test_type_call_local_function_06(): - T = TypeVar("T") - - class C(Generic[T]): ... - - def func[U](x: C[U]) -> C[U]: ... - - res = eval_call_with_types(func, C[int]) - assert res == C[int] - - -def test_type_call_local_function_07(): - T = TypeVar("T") - - class C(Generic[T]): ... - - class D(C[int]): ... - - class E(D): ... - - def func[U](x: C[U]) -> U: ... - - res = eval_call_with_types(func, D) - assert res is int - res = eval_call_with_types(func, E) - assert res is int - - -def test_type_call_local_function_08(): - class C[T]: ... - - class D(C[int]): ... - - class E(C[str]): ... - - class F(D, E): ... - - def func[U](x: C[U]) -> U: ... - - res = eval_call_with_types(func, F) - assert res is int - - -def test_type_call_local_function_09(): - class C[T, U]: ... - - def func[V](x: C[int, V]) -> V: ... - - res = eval_call_with_types(func, C[int, str]) - assert res is str - - -def test_type_call_bind_error_01(): - T = TypeVar("T") - - with pytest.raises( - ValueError, match="Type variable T is already bound to int, but got str" - ): - eval_call_with_types( - Callable[[Param[Literal["x"], T], Param[Literal["y"], T]], T], - int, - str, - ) - - -def test_type_call_bind_error_02(): - def func[T](x: T, y: T) -> T: ... - - with pytest.raises( - ValueError, match="Type variable T is already bound to int, but got str" - ): - eval_call_with_types(func, int, str) - - -def test_type_call_bind_error_03(): - T = TypeVar("T") - - class C(Generic[T]): ... - - with pytest.raises( - ValueError, match="Type variable T is already bound to int, but got str" - ): - eval_call_with_types( - Callable[[Param[Literal["x"], C[T]], Param[Literal["y"], C[T]]], T], - C[int], - C[str], - ) - - -def test_type_call_bind_error_04(): - class C[T]: ... - - def func[T](x: C[T], y: C[T]) -> T: ... - - with pytest.raises( - ValueError, match="Type variable T is already bound to int, but got str" - ): - eval_call_with_types(func, C[int], C[str]) - - -def test_type_call_bind_error_05(): - class C[T]: ... - - class D[T]: ... - - def func[T](x: C[T]) -> T: ... - - with pytest.raises(ValueError, match="Argument type mismatch for x"): - eval_call_with_types(func, D[int]) - - -type GetCallableMember[T, N: str] = GetArg[ - tuple[ - *[ - GetType[m] - for m in Iter[Members[T]] - if ( - IsSub[GetType[m], Callable] - or IsSub[GetType[m], GenericCallable] - ) - and IsSub[GetName[m], N] - ] - ], - tuple, - 0, -] - - -def test_type_call_member_01(): - class C: - def invoke(self, x: int) -> int: ... - - res = eval_call_with_types(GetCallableMember[C, Literal["invoke"]], C, int) - assert res is int - - -def test_type_call_member_02(): - class C: - def invoke[T](self, x: T) -> T: ... - - res = eval_call_with_types(GetCallableMember[C, Literal["invoke"]], C, int) - assert res is int - - -def test_type_call_member_03(): - class C[T]: - def invoke(self, x: str) -> str: ... - - res = eval_call_with_types( - GetCallableMember[C[int], Literal["invoke"]], C[int], str - ) - assert res is str - - -def test_type_call_member_04(): - class C[T]: - def invoke(self, x: T) -> T: ... - - res = eval_call_with_types( - GetCallableMember[C[int], Literal["invoke"]], C[int], int - ) - assert res is int - - -def test_type_call_member_05(): - class C[T]: - def invoke(self) -> C[T]: ... - - res = eval_call_with_types( - GetCallableMember[C[int], Literal["invoke"]], C[int] - ) - assert res == C[int] - - -def test_type_call_member_06(): - class C[T]: - def invoke[U](self, x: U) -> C[U]: ... - - res = eval_call_with_types( - GetCallableMember[C[int], Literal["invoke"]], C[int], str - ) - assert res == C[str] diff --git a/typemap/type_eval/_apply_generic.py b/typemap/type_eval/_apply_generic.py index 2f2eb4d..f876970 100644 --- a/typemap/type_eval/_apply_generic.py +++ b/typemap/type_eval/_apply_generic.py @@ -5,18 +5,16 @@ import typing from typing import _GenericAlias as typing_GenericAlias # type: ignore [attr-defined] # noqa: PLC2701 +from typing import Any from . import _eval_typing from . import _typing_inspect -if typing.TYPE_CHECKING: - from typing import Any - @dataclasses.dataclass(frozen=True) class Boxed: - cls: type[Any] + cls: type[Any] | typing.TypeVar bases: list[Boxed] args: dict[Any, Any] @@ -74,7 +72,14 @@ def substitute(ty, args): return ty -def box(cls: type[Any]) -> Boxed: +def box(cls: type[Any] | typing.TypeVar) -> Boxed: + if isinstance(cls, typing.TypeVar): + return Boxed(cls, [], {}) + + if _typing_inspect.is_generic_type_alias(cls): + evaled = _eval_typing.eval_typing(cls) + return box(evaled) + # TODO: We want a cache for this!! def _box(cls: type[Any], args: dict[Any, Any]) -> Boxed: boxed_bases: list[Boxed] = [] @@ -107,13 +112,15 @@ def _box(cls: type[Any], args: dict[Any, Any]) -> Boxed: return Boxed(cls, boxed_bases, args) if isinstance(cls, (typing._GenericAlias, types.GenericAlias)): # type: ignore[attr-defined] - if params := getattr(cls.__origin__, "__parameters__", None): - args = dict( - zip(cls.__origin__.__parameters__, cls.__args__, strict=True) - ) + origin = typing.cast(type[Any], cls.__origin__) + + if params := getattr(origin, "__parameters__", None) or getattr( + origin, "__type_params__", None + ): + args = dict(zip(params, cls.__args__, strict=True)) else: args = {} - cls = cls.__origin__ + cls = origin else: if params := getattr(cls, "__parameters__", None): args = {p: _typing_inspect.param_default(p) for p in params} @@ -205,6 +212,9 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]: annos: dict[str, Any] = {} dct: dict[str, Any] = {} + if isinstance(boxed.cls, typing.TypeVar): + return annos, dct + if af := typing.cast( types.FunctionType, getattr(boxed.cls, "__annotate__", None) ): @@ -305,8 +315,16 @@ def flatten_class_new_proto(cls: type) -> type: args_str = ", ".join(_type_repr(a) for a in args) args_str = f'[{args_str}]' if args_str else '' - nt.__name__ = f'{cls.__name__}{args_str}' - nt.__qualname__ = f'{cls.__qualname__}{args_str}' + origin = typing.get_origin(cls) + if isinstance(origin, typing.TypeAliasType): + cls_name = origin.__name__ + cls_qualname = origin.__name__ + else: + cls_name = cls.__name__ + cls_qualname = cls.__qualname__ + + nt.__name__ = f'{cls_name}{args_str}' + nt.__qualname__ = f'{cls_qualname}{args_str}' del nt.__subclasshook__ return nt diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index 8246b39..e045260 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -125,10 +125,21 @@ def _update_bound_typevar( elif _typing_inspect.is_generic_alias(tv): tv_args = tv.__args__ - with _eval_typing._ensure_context() as ctx: - param_args = _eval_operators._get_args( - param_value, tv.__origin__, ctx - ) + param_args: Any | None = None + if ( + isinstance(tv.__origin__, typing.TypeAliasType) + and _typing_inspect.is_generic_type_alias(param_value) + and param_value.__origin__ is tv.__origin__ + ): + # Type aliases should match their arguments 1 to 1 + # For example, binding C[A] to C[B] should bind A to B + param_args = param_value.__args__ + + else: + with _eval_typing._ensure_context() as ctx: + param_args = _eval_operators._get_args( + param_value, tv.__origin__, ctx + ) if param_args is None: raise ValueError(f"Argument type mismatch for {param_name}") @@ -136,6 +147,9 @@ def _update_bound_typevar( for p_arg, c_arg in zip(tv_args, param_args, strict=True): _update_bound_typevar(param_name, p_arg, c_arg, vars) + elif tv != param_value: + raise ValueError(f"Argument type mismatch for {param_name}") + def eval_call_with_types( func: types.FunctionType | typing.Callable, diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 3489846..032fb4c 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -282,10 +282,46 @@ def _callable_type_to_signature(callable_type: object) -> inspect.Signature: or Never if no qualifiers """ args = typing.get_args(callable_type) - if len(args) != 2: - raise TypeError(f"Expected Callable[[...], ret], got {callable_type}") + if ( + isinstance(callable_type, types.GenericAlias) + and callable_type.__origin__ is classmethod + ): + if len(args) != 3: + raise TypeError( + f"Expected classmethod[cls, [...], ret], got {callable_type}" + ) + + receiver, param_types, return_type = typing.get_args(callable_type) + param_types = [ + Param[ + typing.Literal["cls"], + receiver, # type: ignore[valid-type] + typing.Literal["positional"], + ], + *param_types.__args__, + ] + + elif ( + isinstance(callable_type, types.GenericAlias) + and callable_type.__origin__ is staticmethod + ): + if len(args) != 2: + raise TypeError( + f"Expected staticmethod[...], ret], got {callable_type}" + ) + + param_types, return_type = typing.get_args(callable_type) + param_types = [ + *param_types.__args__, + ] + + else: + if len(args) != 2: + raise TypeError( + f"Expected Callable[[...], ret], got {callable_type}" + ) - param_types, return_type = args + param_types, return_type = args # Handle the case where param_types is a list of Param types if not isinstance(param_types, (list, tuple)):