Skip to content

Commit d9884a0

Browse files
committed
Implement.
1 parent 35afd43 commit d9884a0

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

typemap/type_eval/_apply_generic.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,14 @@ def _resolved_function_signature(func, args):
281281
return sig
282282

283283

284-
def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
284+
def get_local_defns(
285+
boxed: Boxed,
286+
) -> tuple[
287+
dict[str, Any],
288+
dict[
289+
str, types.FunctionType | classmethod | staticmethod | WrappedOverloaded
290+
],
291+
]:
285292
from typemap.typing import GenericCallable
286293

287294
annos: dict[str, Any] = {}
@@ -315,6 +322,8 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
315322
# XXX: This is totally wrong; we still need to do
316323
# substitute in class vars
317324
local_fn = stuff
325+
elif overloaded := _is_overloaded_function(stuff):
326+
local_fn = overloaded
318327

319328
# If we got stuck, we build a GenericCallable that
320329
# computes the type once it has been given type
@@ -358,6 +367,23 @@ def lam(*vs):
358367
return annos, dct
359368

360369

370+
@dataclasses.dataclass(frozen=True)
371+
class WrappedOverloaded:
372+
functions: tuple[types.FunctionType, ...]
373+
374+
375+
def _is_overloaded_function(func):
376+
module_overload_registry = typing._overload_registry[func.__module__]
377+
if not module_overload_registry:
378+
return None
379+
380+
func_overload_registry = module_overload_registry[func.__qualname__]
381+
if not func_overload_registry:
382+
return
383+
384+
return WrappedOverloaded(tuple(func_overload_registry.values()))
385+
386+
361387
def flatten_class_new_proto(cls: type) -> type:
362388
# This is a hacky version of flatten_class that works by using
363389
# NewProtocol on Members!

typemap/type_eval/_eval_operators.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
Member,
4040
Members,
4141
NewProtocol,
42+
Overloaded,
4243
Param,
4344
RaiseError,
4445
Slice,
@@ -169,6 +170,17 @@ def get_annotated_method_hints(cls, *, ctx):
169170
object,
170171
acls,
171172
)
173+
elif isinstance(attr, _apply_generic.WrappedOverloaded):
174+
overloads = [
175+
_function_type(_eval_types(of, ctx), receiver_type=acls)
176+
for of in attr.functions
177+
]
178+
hints[name] = (
179+
Overloaded[*overloads],
180+
("ClassVar",),
181+
object,
182+
acls,
183+
)
172184

173185
return hints
174186

0 commit comments

Comments
 (0)