Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2210,6 +2210,55 @@ class D[T](C[T]):
)


def test_update_class_members_11():
class A:
a: int

def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[*Members[T]]:
super().__init_subclass__()

def f(self) -> int: ...

class B(A):
b: str

def g(self) -> str: ...

attrs = eval_typing(Attrs[B])
assert (
attrs
== tuple[
Member[Literal["a"], int, Never, Never, B],
Member[Literal["b"], str, Never, Never, B],
]
)

members = eval_typing(MembersExceptInitSubclass[B])
assert (
members
== tuple[
Member[Literal["a"], int, Never, Never, B],
Member[Literal["b"], str, Never, Never, B],
Member[
Literal["f"],
Callable[[Param[Literal["self"], Self]], int],
Literal["ClassVar"],
object,
B,
],
Member[
Literal["g"],
Callable[[Param[Literal["self"], Self]], str],
Literal["ClassVar"],
object,
B,
],
]
)


def test_update_class_inheritance_01():
# current class init subclass is not applied
class A:
Expand Down Expand Up @@ -2327,7 +2376,6 @@ class C(B[float]):
assert eval_typing(GetArg[C, A, Literal[1]]) is float


@pytest.mark.xfail(reason="TODO")
def test_update_class_empty_01():
class A:
a: int
Expand All @@ -2341,7 +2389,7 @@ class B(A):
b: int

attrs = eval_typing(Attrs[B])
assert attrs == tuple[()]
assert attrs == tuple[Member[Literal["a"], int, Never, Never, A]]


##############
Expand Down
18 changes: 11 additions & 7 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typemap.type_eval import _apply_generic, _typing_inspect
from typemap.type_eval._eval_typing import (
_child_context,
_eval_args,
_eval_types,
EvalContext,
)
Expand Down Expand Up @@ -192,7 +193,8 @@ def _eval_init_subclass(
"""Get type after all __init_subclass__ with UpdateClass are evaluated."""
for abox in box.mro[1:]: # Skip the type itself
with _child_context() as ctx:
if ms := _get_update_class_members(box, abox, ctx=ctx):
ms = _get_update_class_members(box, abox, ctx=ctx)
if ms is not None:
nbox = _apply_generic.box(
_create_updated_class(box, ms, ctx=ctx)
)
Expand All @@ -208,7 +210,7 @@ def _get_update_class_members(
box: _apply_generic.Boxed,
boxed_base: _apply_generic.Boxed,
ctx: EvalContext,
) -> list[Member] | None:
) -> typing.Sequence[Member] | None:
cls = box.cls

# Get __init_subclass__ from the base class's origin if base is generic.
Expand Down Expand Up @@ -267,13 +269,13 @@ def _get_update_class_members(
_typing_inspect.is_generic_alias(evaled_ret)
and typing.get_origin(evaled_ret) is UpdateClass
):
return [m for m in typing.get_args(evaled_ret)]
return _eval_args(typing.get_args(evaled_ret), ctx)

return None


def _create_updated_class(
box: _apply_generic.Boxed, ms: list[Member], ctx: EvalContext
box: _apply_generic.Boxed, ms: typing.Sequence[Member], ctx: EvalContext
) -> type:
t = box.cls
dct: dict[str, object] = {}
Expand All @@ -289,9 +291,11 @@ def _create_updated_class(
typ = _eval_types(typ, ctx)
tquals = _eval_types(quals, ctx)

if type_eval.issubtype(
typing.Literal["ClassVar"], tquals
) and _is_method_like(typ):
if (
type_eval.issubtype(typing.Literal["ClassVar"], tquals)
and _is_method_like(typ)
and _typing_inspect.get_head(typ) is not GenericCallable
):
dct[member_name] = _callable_type_to_method(member_name, typ, ctx)
else:
# Update/add the annotation
Expand Down