Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2017,6 +2017,199 @@ def g(self) -> int: ...
)


def test_update_class_members_05():
# Generic base class with UpdateClass
class A[T]:
a: T

def __init_subclass__[U](
cls: type[U],
) -> AttrsAsSets[U]:
super().__init_subclass__()

class B(A[int]):
b: str

attrs = eval_typing(Attrs[B])
assert attrs.__args__ == (
Member[Literal["a"], set[int], Never, Never, B],
Member[Literal["b"], set[str], Never, Never, B],
)


def test_update_class_members_06():
# Generic derived class with UpdateClass
class A:
a: int

def __init_subclass__[T](
cls: type[T],
) -> AttrsAsSets[T]:
super().__init_subclass__()

class B[T](A):
b: T

attrs = eval_typing(Attrs[B[int]])
assert attrs.__args__ == (
Member[Literal["a"], set[int], Never, Never, B[int]],
Member[Literal["b"], set[int], Never, Never, B[int]],
)

attrs = eval_typing(Attrs[B[str]])
assert attrs.__args__ == (
Member[Literal["a"], set[int], Never, Never, B[str]],
Member[Literal["b"], set[str], Never, Never, B[str]],
)


def test_update_class_members_07():
# Generic derived and base class with UpdateClass
# derived from generic base
class A[T]:
a: T

def __init_subclass__[U](
cls: type[U],
) -> AttrsAsSets[U]:
super().__init_subclass__()

class B[T](A[tuple[T]]):
b: T

class C[T, U](A[U]):
c: T

attrs = eval_typing(Attrs[B[int]])
assert attrs.__args__ == (
Member[Literal["a"], set[tuple[int]], Never, Never, B[int]],
Member[Literal["b"], set[int], Never, Never, B[int]],
)

attrs = eval_typing(Attrs[C[int, str]])
assert attrs.__args__ == (
Member[Literal["a"], set[str], Never, Never, C[int, str]],
Member[Literal["c"], set[int], Never, Never, C[int, str]],
)


def test_update_class_members_08():
# Generic derived and base class with UpdateClass
# derived from specialized base
class A[T]:
a: T

def __init_subclass__[U](
cls: type[U],
) -> AttrsAsSets[U]:
super().__init_subclass__()

class B[T](A[int]):
b: T

attrs = eval_typing(Attrs[B[int]])
assert attrs.__args__ == (
Member[Literal["a"], set[int], Never, Never, B[int]],
Member[Literal["b"], set[int], Never, Never, B[int]],
)

attrs = eval_typing(Attrs[B[str]])
assert attrs.__args__ == (
Member[Literal["a"], set[int], Never, Never, B[str]],
Member[Literal["b"], set[str], Never, Never, B[str]],
)


def test_update_class_members_09():
# Generic classes which use their type params in UpdateClass
class A[V]:
def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[Member[Literal["a"], V], *Attrs[T]]:
super().__init_subclass__()

class B[V](A[int]):
def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[Member[Literal["b"], V], *Attrs[T]]:
super().__init_subclass__()

class C[V](B[str]):
def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[Member[Literal["c"], V], *Attrs[T]]:
super().__init_subclass__()

class D[V](C[float]):
def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[Member[Literal["d"], V]]:
super().__init_subclass__()

attrs = eval_typing(Attrs[A[int]])
assert attrs == tuple[()]

attrs = eval_typing(Attrs[B[str]])
assert attrs == tuple[Member[Literal["a"], int, Never, Never, B[str]]]

attrs = eval_typing(Attrs[C[float]])
assert (
attrs
== tuple[
Member[Literal["a"], int, Never, Never, C[float]],
Member[Literal["b"], str, Never, Never, C[float]],
]
)

attrs = eval_typing(Attrs[D[bool]])
assert (
attrs
== tuple[
Member[Literal["a"], int, Never, Never, D[bool]],
Member[Literal["b"], str, Never, Never, D[bool]],
Member[Literal["c"], float, Never, Never, D[bool]],
]
)


@pytest.mark.xfail(reason="Super sketchy....")
def test_update_class_members_10():
# Generic classes which use other classes in their hierarchy
# within UpdateClass.
class A:
a: int

def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[
Member[Literal["b"], B[str]],
Member[Literal["c"], C[float]],
Member[Literal["d"], D[bool]],
]:
super().__init_subclass__()

class B[T](A):
pass

class C[T](B[T]):
pass

class D[T](C[T]):
pass

attrs = eval_typing(Attrs[C[int]])
# A's __init_subclass__ adds Member["x", B[str]]; we get "a" from A and "x" from UpdateClass.
assert (
attrs
== tuple[
Member[Literal["a"], int, Never, Never, A],
Member[Literal["b"], B[str], Never, Never, C[int]],
Member[Literal["c"], C[float], Never, Never, C[int]],
Member[Literal["d"], D[bool], Never, Never, C[int]],
]
)


def test_update_class_inheritance_01():
# current class init subclass is not applied
class A:
Expand Down Expand Up @@ -2134,6 +2327,23 @@ class C(B[float]):
assert eval_typing(GetArg[C, A, Literal[1]]) is float


@pytest.mark.xfail(reason="TODO")
def test_update_class_empty_01():
class A:
a: int

def __init_subclass__[T](
cls: type[T],
) -> UpdateClass[()]:
super().__init_subclass__()

class B(A):
b: int

attrs = eval_typing(Attrs[B])
assert attrs == tuple[()]


##############

type XTest[X] = Annotated[X, 'blah']
Expand Down
4 changes: 4 additions & 0 deletions typemap/type_eval/_apply_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ def get_annotations(
if mod := sys.modules.get(obj.__module__):
globs.update(vars(mod))

# Make a copy in case we need to eval the annotations. We don't want to
# modify the original.
rr = dict(rr)

if isinstance(rr, dict) and any(isinstance(v, str) for v in rr.values()):
args = dict(args)
# Copy in any __type_params__ that aren't provided for, so that if
Expand Down
62 changes: 49 additions & 13 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typemap.type_eval._eval_typing import (
_child_context,
_eval_types,
EvalContext,
)
from typemap.typing import (
Attrs,
Expand Down Expand Up @@ -82,7 +83,7 @@ def _make_init_type(v):
return typing.Literal[(v,)]


def cached_box(cls, *, ctx):
def cached_box(cls, *, ctx: EvalContext):
if str(cls).startswith('typemap.typing'):
return _apply_generic.box(cls)
if cls in ctx.box_cache:
Expand Down Expand Up @@ -186,34 +187,49 @@ def get_annotated_method_hints(cls, *, ctx):


def _eval_init_subclass(
box: _apply_generic.Boxed, ctx: typing.Any
box: _apply_generic.Boxed, ctx: EvalContext
) -> _apply_generic.Boxed:
"""Get type after all __init_subclass__ with UpdateClass are evaluated."""
for abox in box.mro[1:]: # Skip the type itself
with _child_context() as ctx:
if ms := _get_update_class_members(
box.cls, abox.alias_type(), ctx=ctx
):
if ms := _get_update_class_members(box, abox, ctx=ctx):
nbox = _apply_generic.box(
_create_updated_class(box.cls, ms, ctx=ctx)
_create_updated_class(box, ms, ctx=ctx)
)
# We want to preserve the original cls for Members output
box = dataclasses.replace(nbox, orig_cls=box.canonical_cls)
ctx.box_cache[box.cls] = box
box = dataclasses.replace(
nbox, orig_cls=box.canonical_cls, args=box.args
)
ctx.box_cache[box.alias_type()] = box
return box


def _get_update_class_members(
cls: type, base: type, ctx: typing.Any
box: _apply_generic.Boxed,
boxed_base: _apply_generic.Boxed,
ctx: EvalContext,
) -> list[Member] | None:
init_subclass = base.__dict__.get("__init_subclass__")
cls = box.cls

# Get __init_subclass__ from the base class's origin if base is generic.
base_origin = boxed_base.cls
init_subclass = base_origin.__dict__.get("__init_subclass__")
if not init_subclass:
return None
init_subclass = inspect.unwrap(init_subclass)

args = {}
# Get any type params from the base class if it is generic
if (base_args := boxed_base.args.values()) and (
origin_params := getattr(base_origin, '__type_params__', None)
):
Comment on lines +223 to +225
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use str_args?

args = dict(
zip((p.__name__ for p in origin_params), base_args, strict=True)
)

# Get type params from function
if type_params := getattr(init_subclass, '__type_params__', None):
args[str(type_params[0])] = cls
args[type_params[0].__name__] = box.alias_type()

init_subclass_annos = _apply_generic.get_annotations(init_subclass, args)

Expand Down Expand Up @@ -256,7 +272,10 @@ def _get_update_class_members(
return None


def _create_updated_class(t: type, ms: list[Member], ctx) -> type:
def _create_updated_class(
box: _apply_generic.Boxed, ms: list[Member], ctx: EvalContext
) -> type:
t = box.cls
dct: dict[str, object] = {}

# Copy the module
Expand All @@ -280,8 +299,25 @@ def _create_updated_class(t: type, ms: list[Member], ctx) -> type:
_unpack_init(dct, member_name, init)

# Create the updated class

# If typing.Generic is a base, we need to use it with the type params
# applied. Additionally, use types.newclass to properly resolve the mro.
bases = tuple(
b.alias_type()
if b.cls is not typing.Generic
else typing.Generic[t.__type_params__] # type: ignore[index]
for b in box.bases
)

kwds = {}
mcls = type(t)
cls = mcls(t.__name__, t.__bases__, dct)
if mcls is not type:
kwds["metaclass"] = mcls

cls = types.new_class(t.__name__, bases, kwds, lambda ns: ns.update(dct))
# Explicitly set __type_params__. This normally doesn't work, but we are
# creating fake classes for the purpose of type evaluation.
cls.__type_params__ = t.__type_params__

return cls

Expand Down