Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/test_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ def test_call_1():
ret = eval_call(func, a=1, b=2, c="aaa")
fmt = format_helper.format_class(ret)

# XXX: We want better than this for a name
assert fmt == textwrap.dedent("""\
class Protocol:
class NewProtocol:
a: int
b: int
c: int
Expand Down
3 changes: 1 addition & 2 deletions tests/test_type_dir.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,10 @@ def sbase[Z](cls, a: int | Literal['gotcha!'] | Z | None, b: ~K) -> dict[str, in
def test_type_dir_2():
d = eval_typing(OptionalFinal)

# XXX: the class should probably be named something like "AllOptional__T"
# XXX: `DirProperties` skips methods, true to its name. Perhaps we just need
# `Dir` that would iterate over everything
assert format_helper.format_class(d) == textwrap.dedent("""\
class Protocol:
class AllOptional[tests.test_type_dir.Final]:
last: int | typing.Literal[True] | None
iii: str | int | typing.Literal['gotcha!'] | None
t: dict[str, str | int | typing.Literal['gotcha!']] | None
Expand Down
4 changes: 2 additions & 2 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def test_eval_types_2():
assert evaled.__annotations__["a"].__args__[0] is evaled

assert format_helper.format_class(evaled) == textwrap.dedent("""\
class Protocol:
class MapRecursive[tests.test_type_eval.Recursive]:
n: int | typing.Literal['gotcha!']
m: str | typing.Literal['gotcha!']
t: typing.Literal[False] | typing.Literal['gotcha!']
a: abc.Protocol | typing.Literal['gotcha!']
a: tests.test_type_eval.MapRecursive[tests.test_type_eval.Recursive] | typing.Literal['gotcha!']
fff: int | typing.Literal['gotcha!']
control: float
""")
4 changes: 2 additions & 2 deletions typemap/type_eval/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._eval_call import eval_call
from ._eval_typing import eval_typing
from ._eval_typing import eval_typing, _get_current_context


__all__ = ("eval_typing", "eval_call")
__all__ = ("eval_typing", "eval_call", "_get_current_context")
1
5 changes: 5 additions & 0 deletions typemap/type_eval/_eval_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@


def eval_call(func: types.FunctionType, /, *args: Any, **kwargs: Any) -> Any:
with _eval_typing._ensure_context():
return _eval_call(func, *args, **kwargs)


def _eval_call(func: types.FunctionType, /, *args: Any, **kwargs: Any) -> Any:
vars = {}

params = func.__type_params__
Expand Down
35 changes: 29 additions & 6 deletions typemap/type_eval/_eval_typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import annotationlib

import contextlib
import contextvars
import dataclasses
import functools
Expand All @@ -21,6 +22,7 @@
@dataclasses.dataclass
class EvalContext:
seen: dict[Any, Any]
current_alias: types.GenericAlias | None = None


# `eval_types()` calls can be nested, context must be preserved
Expand All @@ -29,7 +31,8 @@ class EvalContext:
)


def eval_typing(obj: typing.Any):
@contextlib.contextmanager
def _ensure_context() -> typing.Iterator[EvalContext]:
ctx = _current_context.get()
ctx_set = False
if ctx is None:
Expand All @@ -40,12 +43,26 @@ def eval_typing(obj: typing.Any):
ctx_set = True

try:
return _eval_types(obj, ctx)
yield ctx
finally:
if ctx_set:
_current_context.set(None)


def _get_current_context() -> EvalContext:
ctx = _current_context.get()
if not ctx:
raise RuntimeError(
"type_eval._get_current_context() called outside of eval_types()"
)
return ctx


def eval_typing(obj: typing.Any):
with _ensure_context() as ctx:
return _eval_types(obj, ctx)


def _eval_types(obj: typing.Any, ctx: EvalContext):
if obj in ctx.seen:
return ctx.seen[obj]
Expand Down Expand Up @@ -131,15 +148,21 @@ def _eval_generic(obj: types.GenericAlias, ctx: EvalContext):

args = tuple(types.CellType(_eval_types(arg, ctx)) for arg in obj.__args__)
mod = sys.modules[obj.__module__]
ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args)
unpacked = ff(annotationlib.Format.VALUE)

ctx.seen[obj] = unpacked
old_obj = ctx.current_alias
ctx.current_alias = obj

try:
ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args)
unpacked = ff(annotationlib.Format.VALUE)

ctx.seen[obj] = unpacked
evaled = _eval_types(unpacked, ctx)
except Exception:
ctx.seen.pop(obj)
ctx.seen.pop(obj, None)
raise
finally:
ctx.current_alias = old_obj

return evaled

Expand Down
15 changes: 13 additions & 2 deletions typemap/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,20 @@ def __getitem__(cls, val: list[Property]):
dct = {}
dct["__annotations__"] = {prop.name: prop.type for prop in val}

module_name = __name__
name = "NewProtocol"

# If the type evaluation context
ctx = type_eval._get_current_context()
if ctx.current_alias:
name = str(ctx.current_alias)
module_name = ctx.current_alias.__module__

dct["__module__"] = module_name

mcls = type(typing.Protocol)
# TODO: Replace the "Protocol" name with the type alias name
return mcls("Protocol", (typing.Protocol,), dct)
cls = mcls(name, (typing.Protocol,), dct)
return cls


class NewProtocol(metaclass=NewProtocolMeta):
Expand Down