diff --git a/tests/test_call.py b/tests/test_call.py index 2acf3ff..a0e92e0 100644 --- a/tests/test_call.py +++ b/tests/test_call.py @@ -17,8 +17,9 @@ def test_call_1(): ret = eval_call(func, a=1, b=2, c="aaa") fmt = format_helper.format_class(ret) + # XXX: We want better than this for a name assert fmt == textwrap.dedent("""\ - class Protocol: + class NewProtocol: a: int b: int c: int diff --git a/tests/test_type_dir.py b/tests/test_type_dir.py index 9dc9c5f..98842a4 100644 --- a/tests/test_type_dir.py +++ b/tests/test_type_dir.py @@ -81,11 +81,10 @@ def sbase[Z](cls, a: int | Literal['gotcha!'] | Z | None, b: ~K) -> dict[str, in def test_type_dir_2(): d = eval_typing(OptionalFinal) - # XXX: the class should probably be named something like "AllOptional__T" # XXX: `DirProperties` skips methods, true to its name. Perhaps we just need # `Dir` that would iterate over everything assert format_helper.format_class(d) == textwrap.dedent("""\ - class Protocol: + class AllOptional[tests.test_type_dir.Final]: last: int | typing.Literal[True] | None iii: str | int | typing.Literal['gotcha!'] | None t: dict[str, str | int | typing.Literal['gotcha!']] | None diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 8882cfc..26b4a45 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -52,11 +52,11 @@ def test_eval_types_2(): assert evaled.__annotations__["a"].__args__[0] is evaled assert format_helper.format_class(evaled) == textwrap.dedent("""\ - class Protocol: + class MapRecursive[tests.test_type_eval.Recursive]: n: int | typing.Literal['gotcha!'] m: str | typing.Literal['gotcha!'] t: typing.Literal[False] | typing.Literal['gotcha!'] - a: abc.Protocol | typing.Literal['gotcha!'] + a: tests.test_type_eval.MapRecursive[tests.test_type_eval.Recursive] | typing.Literal['gotcha!'] fff: int | typing.Literal['gotcha!'] control: float """) diff --git a/typemap/type_eval/__init__.py b/typemap/type_eval/__init__.py index a7738e4..80690a4 100644 --- a/typemap/type_eval/__init__.py +++ b/typemap/type_eval/__init__.py @@ -1,6 +1,6 @@ from ._eval_call import eval_call -from ._eval_typing import eval_typing +from ._eval_typing import eval_typing, _get_current_context -__all__ = ("eval_typing", "eval_call") +__all__ = ("eval_typing", "eval_call", "_get_current_context") 1 diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index 93f7e50..aff404a 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -11,6 +11,11 @@ def eval_call(func: types.FunctionType, /, *args: Any, **kwargs: Any) -> Any: + with _eval_typing._ensure_context(): + return _eval_call(func, *args, **kwargs) + + +def _eval_call(func: types.FunctionType, /, *args: Any, **kwargs: Any) -> Any: vars = {} params = func.__type_params__ diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index 720dae5..e1fbdaa 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -1,5 +1,6 @@ import annotationlib +import contextlib import contextvars import dataclasses import functools @@ -21,6 +22,7 @@ @dataclasses.dataclass class EvalContext: seen: dict[Any, Any] + current_alias: types.GenericAlias | None = None # `eval_types()` calls can be nested, context must be preserved @@ -29,7 +31,8 @@ class EvalContext: ) -def eval_typing(obj: typing.Any): +@contextlib.contextmanager +def _ensure_context() -> typing.Iterator[EvalContext]: ctx = _current_context.get() ctx_set = False if ctx is None: @@ -40,12 +43,26 @@ def eval_typing(obj: typing.Any): ctx_set = True try: - return _eval_types(obj, ctx) + yield ctx finally: if ctx_set: _current_context.set(None) +def _get_current_context() -> EvalContext: + ctx = _current_context.get() + if not ctx: + raise RuntimeError( + "type_eval._get_current_context() called outside of eval_types()" + ) + return ctx + + +def eval_typing(obj: typing.Any): + with _ensure_context() as ctx: + return _eval_types(obj, ctx) + + def _eval_types(obj: typing.Any, ctx: EvalContext): if obj in ctx.seen: return ctx.seen[obj] @@ -131,15 +148,21 @@ def _eval_generic(obj: types.GenericAlias, ctx: EvalContext): args = tuple(types.CellType(_eval_types(arg, ctx)) for arg in obj.__args__) mod = sys.modules[obj.__module__] - ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args) - unpacked = ff(annotationlib.Format.VALUE) - ctx.seen[obj] = unpacked + old_obj = ctx.current_alias + ctx.current_alias = obj + try: + ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args) + unpacked = ff(annotationlib.Format.VALUE) + + ctx.seen[obj] = unpacked evaled = _eval_types(unpacked, ctx) except Exception: - ctx.seen.pop(obj) + ctx.seen.pop(obj, None) raise + finally: + ctx.current_alias = old_obj return evaled diff --git a/typemap/typing.py b/typemap/typing.py index 104e63a..ad31ddf 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -88,9 +88,20 @@ def __getitem__(cls, val: list[Property]): dct = {} dct["__annotations__"] = {prop.name: prop.type for prop in val} + module_name = __name__ + name = "NewProtocol" + + # If the type evaluation context + ctx = type_eval._get_current_context() + if ctx.current_alias: + name = str(ctx.current_alias) + module_name = ctx.current_alias.__module__ + + dct["__module__"] = module_name + mcls = type(typing.Protocol) - # TODO: Replace the "Protocol" name with the type alias name - return mcls("Protocol", (typing.Protocol,), dct) + cls = mcls(name, (typing.Protocol,), dct) + return cls class NewProtocol(metaclass=NewProtocolMeta):