diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index b351759..77829f4 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -544,8 +544,27 @@ class A(Generic[T]): assert eval_typing(GetArg[t, A, 1]) == Never -@pytest.mark.xfail(reason="Should this work?") +TestTypeVar = TypeVar("TestTypeVar") + + def test_eval_getarg_custom_05(): + # TypeVar declared outside of scope of class + class ATree(Generic[TestTypeVar]): + val: list[ATree[TestTypeVar]] + + t = ATree[int] + assert eval_typing(GetArg[t, ATree, 0]) is int + assert eval_typing(GetArg[t, ATree, -1]) is int + assert eval_typing(GetArg[t, ATree, 1]) == Never + + t = ATree + assert eval_typing(GetArg[t, ATree, 0]) is Any + assert eval_typing(GetArg[t, ATree, -1]) is Any + assert eval_typing(GetArg[t, ATree, 1]) == Never + + +def test_eval_getarg_custom_06(): + # TypeVar declared inside scope of class A = TypeVar("A") class ATree(Generic[A]): @@ -562,8 +581,8 @@ class ATree(Generic[A]): assert eval_typing(GetArg[t, ATree, 1]) == Never -@pytest.mark.xfail(reason="Should this work?") -def test_eval_getarg_custom_06(): +def test_eval_getarg_custom_07(): + # Doubly recursive generic types A = TypeVar("A") B = TypeVar("B") @@ -587,6 +606,30 @@ class ABTree(Generic[A, B]): assert eval_typing(GetArg[t, ABTree, 2]) == Never +def test_eval_getarg_custom_08(): + # Generic class with generic methods + T = TypeVar("T") + + class Container(Generic[T]): + data: list[T] + + def get[T](self, index: int, default: T) -> int | T: ... + def map[U](self, func: Callable[[int], U]) -> list[U]: ... + def convert[T](self, func: Callable[[int], T]) -> Container2[T]: ... + + class Container2[T]: ... + + t = Container[int] + assert eval_typing(GetArg[t, Container, 0]) is int + assert eval_typing(GetArg[t, Container, -1]) is int + assert eval_typing(GetArg[t, Container, 1]) == Never + + t = Container + assert eval_typing(GetArg[t, Container, 0]) is Any + assert eval_typing(GetArg[t, Container, -1]) is Any + assert eval_typing(GetArg[t, Container, 1]) == Never + + def test_uppercase_never(): d = eval_typing(Uppercase[Never]) assert d is Never diff --git a/typemap/type_eval/_apply_generic.py b/typemap/type_eval/_apply_generic.py index 4fd60fb..fd3ba28 100644 --- a/typemap/type_eval/_apply_generic.py +++ b/typemap/type_eval/_apply_generic.py @@ -22,7 +22,12 @@ def __post_init__(self): object.__setattr__( self, "str_args", - {str(k): v for k, v in self.args.items()}, + { + # Use __name__ when available instead of str() + # str(TypeVar('A')) returns '~A' + (k.__name__ if hasattr(k, '__name__') else str(k)): v + for k, v in self.args.items() + }, ) def __repr__(self): @@ -159,6 +164,21 @@ def make_func( return new_func +def _get_closure_types(af: types.FunctionType) -> dict[str, type]: + # Generate a fallback mapping of closure classes. + # This is needed for locally defined generic types which reference + # themselves in their type annotations. + if not af.__closure__: + return {} + return { + name: variable.cell_contents + for name, variable in zip( + af.__code__.co_freevars, af.__closure__, strict=True + ) + if isinstance(variable.cell_contents, type) + } + + def apply(cls: type[Any]) -> dict[str, Any]: mro_boxed = compute_mro(cls) @@ -166,15 +186,23 @@ def apply(cls: type[Any]) -> dict[str, Any]: dct: dict[str, Any] = {} for boxed in reversed(mro_boxed): - if af := getattr(boxed.cls, "__annotate__", None): + if af := typing.cast( + types.FunctionType, getattr(boxed.cls, "__annotate__", None) + ): # Class has annotations, let's resolve generic arguments + closure_types = _get_closure_types(af) + + def get_class_annotate_variable(name: str) -> typing.Any: + if name == "__classdict__": + return boxed.cls.__dict__ + elif name in boxed.str_args: + return boxed.str_args[name] + else: + return closure_types[name] + args = tuple( - types.CellType( - boxed.cls.__dict__ - if name == "__classdict__" - else boxed.str_args[name] - ) + types.CellType(get_class_annotate_variable(name)) for name in af.__code__.co_freevars ) @@ -205,7 +233,9 @@ def apply(cls: type[Any]) -> dict[str, Any]: stuff = inspect.unwrap(orig) if isinstance(stuff, types.FunctionType): - if af := getattr(stuff, "__annotate__", None): + if af := typing.cast( + types.FunctionType, getattr(stuff, "__annotate__", None) + ): params = dict( zip( map(str, stuff.__type_params__), @@ -214,14 +244,20 @@ def apply(cls: type[Any]) -> dict[str, Any]: ) ) + closure_types = _get_closure_types(af) + + def get_inner_annotate_variable(name: str) -> typing.Any: + if name == "__classdict__": + return boxed.cls.__dict__ + elif name in params: + return params[name] + elif name in boxed.str_args: + return boxed.str_args[name] + else: + return closure_types[name] + args = tuple( - types.CellType( - boxed.cls.__dict__ - if name == "__classdict__" - else params[name] - if name in params - else boxed.str_args[name] - ) + types.CellType(get_inner_annotate_variable(name)) for name in af.__code__.co_freevars )