Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TypeVar,
Union,
get_args,
overload,
)

import pytest
Expand Down Expand Up @@ -45,6 +46,7 @@
Member,
Members,
NewProtocol,
Overloaded,
Param,
Slice,
SpecialFormEllipsis,
Expand Down Expand Up @@ -466,6 +468,46 @@ def f[T](self, x: T) -> OnlyIntToSet[T]: ...
)


def test_getmember_04():
class C:
@overload
def f(self, x: int) -> set[int]: ...

@overload
def f[T](self, x: T) -> T: ...

def f(self, x): ...

m = eval_typing(GetMember[C, Literal["f"]])
mt = eval_typing(GetType[m])
assert mt.__origin__ is Overloaded
assert len(mt.__args__) == 2

# Non-generic overload
assert (
eval_typing(IsAssignable[GetArg[mt, Overloaded, Literal[0]], Callable])
== _BoolLiteral[True]
)
assert (
mt.__args__[0]
== Callable[
[Param[Literal["self"], C], Param[Literal["x"], int]], set[int]
]
)

# Generic overload
assert (
eval_typing(
IsAssignable[GetArg[mt, Overloaded, Literal[1]], GenericCallable]
)
== _BoolLiteral[True]
)
assert (
eval_typing(mt.__args__[1].__args__[1](int))
== Callable[[Param[Literal["self"], C], Param[Literal["x"], int]], int]
)


def test_getarg_never():
d = eval_typing(GetArg[Never, object, Literal[0]])
assert d is Never
Expand Down
96 changes: 15 additions & 81 deletions typemap/type_eval/_apply_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,14 @@ def _resolved_function_signature(func, args):
return sig


def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
def get_local_defns(
boxed: Boxed,
) -> tuple[
dict[str, Any],
dict[
str, types.FunctionType | classmethod | staticmethod | WrappedOverloads
],
]:
from typemap.typing import GenericCallable

annos: dict[str, Any] = {}
Expand Down Expand Up @@ -327,6 +334,8 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
# XXX: This is totally wrong; we still need to do
# substitute in class vars
local_fn = stuff
elif overloads := typing.get_overloads(stuff):
local_fn = WrappedOverloads(tuple(overloads))

# If we got stuck, we build a GenericCallable that
# computes the type once it has been given type
Expand Down Expand Up @@ -370,6 +379,11 @@ def lam(*vs):
return annos, dct


@dataclasses.dataclass(frozen=True)
class WrappedOverloads:
functions: tuple[typing.Callable[..., Any], ...]


def flatten_class_new_proto(cls: type) -> type:
# This is a hacky version of flatten_class that works by using
# NewProtocol on Members!
Expand Down Expand Up @@ -405,84 +419,4 @@ def _type_repr(t: Any) -> str:
return repr(t)


# TODO: Potentially most of this could be ripped out. The internals
# don't use this at all, it's only used by format_class.
def _flatten_class_explicit(
cls: type[Any], ctx: _eval_typing.EvalContext
) -> type[_eval_typing._EvalProxy]:
cls_boxed = box(cls)
mro_boxed = cls_boxed.mro

# TODO: I think we want to create the whole mro chain...
# before we evaluate the contents?

# FIXME: right now we flatten out all the attributes... but should we??
# XXX: Yeah, a lot of work is put into copying everything into every
# class and it is not worth it, at all.

new = {}

# Run through the mro and populate everything
for boxed in reversed(mro_boxed):
# We create it early so we can add it to seen, to handle recursion
# XXX: currently we are doing this even for types with no generics...
# that simplifies the flow... - probably keep it this way until
# we stop flattening attributes into every class
name = boxed.cls.__name__
cboxed: Any

args = tuple(boxed.args.values())
args_str = ", ".join(_type_repr(a) for a in args)
fullname = f"{name}[{args_str}]" if args_str else name
cboxed = type(
fullname,
(_eval_typing._EvalProxy,),
{
"__module__": boxed.cls.__module__,
"__name__": fullname,
"__origin__": boxed.cls,
"__local_args__": args,
},
)
new[boxed] = cboxed

annos: dict[str, Any] = {}
dct: dict[str, Any] = {}
sources: dict[str, Any] = {}

cboxed.__local_annotations__, cboxed.__local_defns__ = get_local_defns(
boxed
)
for base in reversed(boxed.mro):
cbase = new[base]
annos.update(cbase.__local_annotations__)
dct.update(cbase.__local_defns__) # uh.
for k in [*cbase.__local_annotations__, *cbase.__local_defns__]:
sources[k] = cbase

cboxed.__defn_names__ = set(dct)
cboxed.__annotations__ = annos
cboxed.__defn_sources__ = sources
cboxed.__generalized_mro__ = [new[b] for b in boxed.mro]

for k, v in dct.items():
setattr(cboxed, k, v)

# Run through the mro again and evaluate everything
for cboxed in new.values():
for k, v in cboxed.__annotations__.items():
cboxed.__annotations__[k] = _eval_typing._eval_types(v, ctx=ctx)

for k in cboxed.__defn_names__:
v = cboxed.__dict__[k]
setattr(cboxed, k, _eval_typing._eval_types(v, ctx=ctx))

return new[cls_boxed]


def flatten_class_explicit(obj: typing.Any):
with _eval_typing._ensure_context() as ctx:
return _flatten_class_explicit(obj, ctx)


Comment on lines -408 to -487
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this dead code as a standalone PR, please, since it might be worth recovering at some point.

flatten_class = flatten_class_new_proto
12 changes: 12 additions & 0 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Member,
Members,
NewProtocol,
Overloaded,
Param,
RaiseError,
Slice,
Expand Down Expand Up @@ -169,6 +170,17 @@ def get_annotated_method_hints(cls, *, ctx):
object,
acls,
)
elif isinstance(attr, _apply_generic.WrappedOverloads):
overloads = [
_function_type(_eval_types(of, ctx), receiver_type=acls)
for of in attr.functions
]
hints[name] = (
Overloaded[*overloads],
("ClassVar",),
object,
acls,
)

return hints

Expand Down
4 changes: 4 additions & 0 deletions typemap/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def __class_getitem__(cls, params):
return _GenericCallableGenericAlias(cls, (typevars, func))


class Overloaded[*Callables]:
pass


###


Expand Down