diff --git a/tests/test_qblike.py b/tests/test_qblike.py index 8bee327..8870589 100644 --- a/tests/test_qblike.py +++ b/tests/test_qblike.py @@ -1,5 +1,7 @@ import textwrap +from typing import Literal + from typemap.type_eval import eval_call, eval_typing from typemap.typing import ( NewProtocol, @@ -106,10 +108,15 @@ def test_qblike_3(): class select[...]: x: tests.test_qblike.Property[int] w: tests.test_qblike.Property[list[str]] - z: tests.test_qblike.Link[PropsOnly[tests.test_qblike.Tgt]] + z: tests.test_qblike.Link[PropsOnly[typemap.typing.GetArg[\ +tests.test_qblike.Link[tests.test_qblike.Tgt], tests.test_qblike.Link, 0]]] """) + # z: tests.test_qblike.Link[PropsOnly[tests.test_qblike.Tgt]] - tgt = eval_typing(GetAttr[ret, "z"].__args__[0]) + res = eval_typing(GetAttr[ret, Literal["z"]]) + tgt = res.__args__[0] + # XXX: this should probably be pre-evaluated already? + tgt = eval_typing(tgt) fmt = format_helper.format_class(tgt) assert fmt == textwrap.dedent("""\ diff --git a/tests/test_type_dir.py b/tests/test_type_dir.py index e629f2d..1f01146 100644 --- a/tests/test_type_dir.py +++ b/tests/test_type_dir.py @@ -232,7 +232,9 @@ class NoLiterals2[tests.test_type_dir.Final]: def test_type_dir_7(): d = eval_typing(Members[Final]) - foo = next(iter(m for m in Iter[d] if m.__args__[0].__args__[0] == "foo")) + foo = next( + iter(m for m in d.__args__ if m.__args__[0].__args__[0] == "foo") + ) # XXX: drop self? assert ( str(foo) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 91be373..a7dcea4 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -69,15 +69,17 @@ def test_eval_types_1(): def test_eval_types_2(): evaled = eval_typing(MapRecursive[Recursive]) - # Validate that recursion worked properly and "Recursive" was only walked once - assert evaled.__annotations__["a"].__args__[0] is evaled + # FIXME, or think about: this doesn't work, we currently evaluate it to an + # *unexpanded* type alias. + # # Validate that recursion worked properly and "Recursive" was only walked once + # assert evaled.__annotations__["a"].__args__[0] is evaled assert format_helper.format_class(evaled) == textwrap.dedent("""\ class MapRecursive[tests.test_type_eval.Recursive]: n: int | typing.Literal['gotcha!'] m: str | typing.Literal['gotcha!'] t: typing.Literal[False] | typing.Literal['gotcha!'] - a: tests.test_type_eval.MapRecursive[tests.test_type_eval.Recursive] | typing.Literal['gotcha!'] + a: MapRecursive[tests.test_type_eval.Recursive] | typing.Literal['gotcha!'] fff: int | typing.Literal['gotcha!'] control: float """) @@ -159,3 +161,12 @@ def test_type_strings_5(): def test_type_strings_6(): d = eval_typing(StrSlice[Literal["abcd"], Literal[1], Literal[None]]) assert d == Literal["bcd"] + + +def test_type_asdf(): + from typemap.typing import FromUnion + + d = eval_typing(FromUnion[int | bool]) + arg = FromUnion[int | str] + d = eval_typing(arg) + assert d == tuple[int, str] or d == tuple[str, int] diff --git a/typemap/type_eval/__init__.py b/typemap/type_eval/__init__.py index 8973014..b28eabf 100644 --- a/typemap/type_eval/__init__.py +++ b/typemap/type_eval/__init__.py @@ -1,11 +1,22 @@ +from ._eval_typing import ( + eval_typing, + _get_current_context, + register_evaluator, + _EvalProxy, +) + +# XXX: this needs to go second due to nasty circularity -- try to fix that!! from ._eval_call import eval_call -from ._eval_typing import eval_typing, _get_current_context, _EvalProxy from ._subtype import issubtype from ._subsim import issubsimilar +# This one is imported for registering handlers +from . import _eval_operators # noqa + __all__ = ( "eval_typing", + "register_evaluator", "eval_call", "issubtype", "issubsimilar", diff --git a/typemap/type_eval/_eval_call.py b/typemap/type_eval/_eval_call.py index 415fa26..248030d 100644 --- a/typemap/type_eval/_eval_call.py +++ b/typemap/type_eval/_eval_call.py @@ -27,7 +27,9 @@ def _eval_call( params = func.__type_params__ for p in params: if hasattr(p, "__bound__") and p.__bound__ is next.CallSpec: - vars[p.__name__] = next._CallSpecWrapper(args, kwargs, func) + vars[p.__name__] = next._CallSpecWrapper( + args, tuple(kwargs.items()), func + ) else: vars[p.__name__] = p diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py new file mode 100644 index 0000000..0d80289 --- /dev/null +++ b/typemap/type_eval/_eval_operators.py @@ -0,0 +1,343 @@ +import functools +import inspect +import itertools +import types +import typing + +from typemap import type_eval +from typemap.type_eval import _typing_inspect +from typemap.type_eval._eval_typing import _eval_types + +from typemap.typing import ( + Attrs, + CallSpecKwargs, + _CallSpecWrapper, + Iter, + IsSubtype, + IsSubSimilar, + Member, + Members, + NewProtocol, + Param, + FromUnion, + GetArg, + GetAttr, + Capitalize, + Uncapitalize, + Uppercase, + Lowercase, + StrConcat, + StrSlice, +) + + +################################################################## + + +def _from_literal(val, ctx): + val = _eval_types(val, ctx) + if _typing_inspect.is_literal(val): + val = val.__args__[0] + return val + + +def get_annotated_type_hints(cls, **kwargs): + """Get the type hints for a cls annotated with definition site. + + This traverses the mro and finds the definition site for each annotation. + """ + ohints = typing.get_type_hints(cls, **kwargs) + hints = {} + for acls in cls.__mro__: + if not hasattr(acls, "__annotations__"): + continue + for k in acls.__annotations__: + if k not in hints: + hints[k] = ohints[k], acls + + # Stop early if we are done. + if len(hints) == len(ohints): + break + return hints + + +def _union_elems(tp, ctx): + tp = _eval_types(tp, ctx) + if isinstance(tp, types.UnionType): + return tuple(y for x in tp.__args__ for y in _union_elems(x, ctx)) + elif _typing_inspect.is_literal(tp) and len(tp.__args__) > 1: + return tuple(typing.Literal[x] for x in tp.__args__) + else: + return (tp,) + + +def _lift_over_unions(func): + @functools.wraps(func) + def wrapper(*args, ctx): + args2 = [_union_elems(x, ctx) for x in args] + # XXX: Never + parts = [func(*x, ctx=ctx) for x in itertools.product(*args2)] + return typing.Union[*parts] + + return wrapper + + +################################################################## + + +@type_eval.register_evaluator(Iter) +def _eval_Iter(tp, *, ctx): + tp = _eval_types(tp, ctx) + if ( + _typing_inspect.is_generic_alias(tp) + and tp.__origin__ is tuple + and (not tp.__args__ or tp.__args__[-1] is not Ellipsis) + ): + return iter(tp.__args__) + else: + # XXX: Or should we return []? + raise TypeError( + f"Invalid type argument to Iter: {tp} is not a fixed-length tuple" + ) + + +# N.B: These handle unions on their own + + +@type_eval.register_evaluator(IsSubtype) +def _eval_IsSubtype(lhs, rhs, *, ctx): + return type_eval.issubtype( + _eval_types(lhs, ctx), + _eval_types(rhs, ctx), + ) + + +@type_eval.register_evaluator(IsSubSimilar) +def _eval_IsSubSimilar(lhs, rhs, *, ctx): + return type_eval.issubsimilar( + _eval_types(lhs, ctx), + _eval_types(rhs, ctx), + ) + + +################################################################## + + +@type_eval.register_evaluator(CallSpecKwargs) +def _eval_CallSpecKwargs(spec: _CallSpecWrapper, *, ctx): + ff = types.FunctionType( + spec._func.__code__, + spec._func.__globals__, + spec._func.__name__, + None, + (), + ) + + # We can't call `inspect.signature` on `spec` directly -- + # signature() will attempt to resolve annotations and fail. + # So we run it on a copy of the function that doesn't have + # annotations set. + sig = inspect.signature(ff) + bound = sig.bind(*spec._args, **dict(spec._kwargs)) + + # TODO: Get the real type instead of Never + return tuple[ # type: ignore[misc] + *[ + Member[ + typing.Literal[name], # type: ignore[valid-type] + typing.Never, + ] + for name in bound.kwargs + ] + ] + + +################################################################## + + +def _function_type(func, *, is_method): + root = inspect.unwrap(func) + sig = inspect.signature(root) + # XXX: __type_params__!!! + + empty = inspect.Parameter.empty + + def _ann(x): + return typing.Any if x is empty else x + + params = [] + for _i, p in enumerate(sig.parameters.values()): + # XXX: what should we do about self? + # should we track classmethod/staticmethod somehow? + # mypy stores all this stuff in the SymbolNodes (FuncDef, etc), + # even though it kind of really is a type/descriptor thing + # if i == 0 and is_method: + # continue + has_name = p.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + quals = [] + if p.kind == inspect.Parameter.VAR_POSITIONAL: + quals.append("*") + if p.kind == inspect.Parameter.VAR_KEYWORD: + quals.append("**") + if p.default is not empty: + quals.append("=") + params.append( + Param[ + typing.Literal[p.name if has_name else None], + _ann(p.annotation), + typing.Literal[*quals] if quals else typing.Never, + ] + ) + + return typing.Callable[params, _ann(sig.return_annotation)] + + +@type_eval.register_evaluator(Attrs) +def _eval_Attrs(tp, *, ctx): + hints = get_annotated_type_hints(tp, include_extras=True) + + return tuple[ + *[ + Member[typing.Literal[n], t, typing.Never, d] + for n, (t, d) in hints.items() + ] + ] + + +@type_eval.register_evaluator(Members) +@_lift_over_unions +def _eval_Members(tp, *, ctx): + hints = get_annotated_type_hints(tp, include_extras=True) + + attrs = [ + Member[typing.Literal[n], t, typing.Never, d] + for n, (t, d) in hints.items() + ] + + for name, attr in tp.__dict__.items(): + if isinstance(attr, (types.FunctionType, types.MethodType)): + if attr is typing._no_init_or_replace_init: + continue + + # XXX: populate the source field + attrs.append( + Member[ + typing.Literal[name], + _function_type(attr, is_method=True), + typing.Literal["ClassVar"], + ] + ) + + return tuple[*attrs] + + +################################################################## + + +@type_eval.register_evaluator(FromUnion) +def _eval_FromUnion(tp, *, ctx): + return tuple[*_union_elems(tp, ctx)] + + +################################################################## + + +@type_eval.register_evaluator(GetAttr) +@_lift_over_unions +def _eval_GetAttr(lhs, prop, *, ctx): + # TODO: the prop missing, etc! + # XXX: extras? + name = _from_literal(prop, ctx) + return typing.get_type_hints(lhs)[name] + + +def _get_args(tp, base, ctx) -> typing.Any: + # XXX: check against base!! + evaled = _eval_types(tp, ctx) + + tp_head = _typing_inspect.get_head(tp) + base_head = _typing_inspect.get_head(base) + # XXX: not sure this is what we want! + # at the very least we want unions I think + if not tp_head or not base_head: + return None + + if tp_head is base_head: + return typing.get_args(evaled) + + # Scan the fully-annotated MRO to find the base + elif gen_mro := getattr(evaled, "__generalized_mro__", None): + for box in gen_mro: + if box.cls is base_head: + return tuple(box.args.values()) + return None + + else: + # or error?? + return None + + +@type_eval.register_evaluator(GetArg) +@_lift_over_unions +def _eval_GetArg(tp, base, idx, *, ctx) -> typing.Any: + args = _get_args(tp, base, ctx) + if args is None: + return typing.Never + + try: + return args[_from_literal(idx, ctx)] + except IndexError: + return typing.Never + + +def _string_literal_op(typ, op): + @_lift_over_unions + def func(*args, ctx): + return typing.Literal[op(*[_from_literal(x, ctx) for x in args])] + + type_eval.register_evaluator(typ)(func) + + +_string_literal_op(Uppercase, op=str.upper) +_string_literal_op(Lowercase, op=str.lower) +_string_literal_op(Capitalize, op=str.capitalize) +_string_literal_op(Uncapitalize, op=lambda s: s[0:1].lower() + s[1:]) +_string_literal_op(StrConcat, op=lambda s, t: s + t) +_string_literal_op(StrSlice, op=lambda s, start, end: s[start:end]) + + +################################################################## + + +@type_eval.register_evaluator(NewProtocol) +def _eval_NewProtocol(*etyps: Member, ctx): + dct: dict[str, object] = {} + dct["__annotations__"] = { + # XXX: Should eval_typing on the etyps evaluate the arguments?? + _from_literal(typing.get_args(prop)[0], ctx): _eval_types( + typing.get_args(prop)[1], ctx + ) + for prop in etyps + } + + module_name = __name__ + name = "NewProtocol" + + # If the type evaluation context + ctx = type_eval._get_current_context() + if ctx.current_alias: + if isinstance(ctx.current_alias, types.GenericAlias): + name = str(ctx.current_alias) + else: + name = f"{ctx.current_alias.__name__}[...]" + module_name = ctx.current_alias.__module__ + + dct["__module__"] = module_name + + mcls: type = type(typing.cast(type, typing.Protocol)) + cls = mcls(name, (typing.Protocol,), dct) + cls = _eval_types(cls, ctx) + return cls diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index e68ebaa..69114c1 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -9,6 +9,8 @@ import types import typing +from typing import _GenericAlias # type: ignore [attr-defined] # noqa: PLC2701 + if typing.TYPE_CHECKING: from typing import Any @@ -19,6 +21,20 @@ __all__ = ("eval_typing",) +_eval_funcs: dict[type, typing.Callable[..., Any]] = {} + + +def register_evaluator[T: typing.Callable[..., Any]]( + typ: type, +) -> typing.Callable[[T], T]: + def func(f: T) -> T: + assert typ not in _eval_funcs + _eval_funcs[typ] = f + return f + + return func + + # Base type for the proxy classes we generate to hold __annotations__ class _EvalProxy: # Make sure __origin__ doesn't show up at runtime... @@ -31,6 +47,9 @@ class EvalContext: seen: dict[Any, Any] # The typing.Any is really a types.FunctionType, but mypy gets # confused and wants to treat it as a MethodType. + current_alias_stack: set[types.GenericAlias | typing.Any] = ( + dataclasses.field(default_factory=set) + ) current_alias: types.GenericAlias | typing.Any | None = None @@ -42,6 +61,8 @@ class EvalContext: @contextlib.contextmanager def _ensure_context() -> typing.Iterator[EvalContext]: + import typemap.typing as nt + ctx = _current_context.get() ctx_set = False if ctx is None: @@ -50,12 +71,16 @@ def _ensure_context() -> typing.Iterator[EvalContext]: ) _current_context.set(ctx) ctx_set = True + evaluator_token = nt.special_form_evaluator.set( + lambda t: _eval_types(t, ctx) + ) try: yield ctx finally: if ctx_set: _current_context.set(None) + nt.special_form_evaluator.reset(evaluator_token) def _get_current_context() -> EvalContext: @@ -73,6 +98,10 @@ def eval_typing(obj: typing.Any): def _eval_types(obj: typing.Any, ctx: EvalContext): + # Don't recurse into any pending alias expansion + if obj in ctx.current_alias_stack: + return obj + # strings match if obj in ctx.seen: return ctx.seen[obj] ctx.seen[obj] = evaled = _eval_types_impl(obj, ctx) @@ -148,12 +177,14 @@ def _eval_type_alias(obj: typing.TypeAliasType, ctx: EvalContext): @_eval_types_impl.register -def _eval_generic(obj: types.GenericAlias, ctx: EvalContext): +def _eval_types_generic(obj: types.GenericAlias, ctx: EvalContext): + new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__) + + new_obj = obj.__origin__[new_args] # type: ignore[index] if isinstance(obj.__origin__, type): # This is a GenericAlias over a Python class, e.g. `dict[str, int]` # Let's reconstruct it by evaluating all arguments - new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__) - return obj.__origin__[new_args] # type: ignore[index] + return new_obj func = obj.evaluate_value @@ -161,7 +192,8 @@ def _eval_generic(obj: types.GenericAlias, ctx: EvalContext): mod = sys.modules[obj.__module__] old_obj = ctx.current_alias - ctx.current_alias = obj + ctx.current_alias = new_obj # alias is the new_obj, so names look better + ctx.current_alias_stack.add(new_obj) try: ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args) @@ -174,10 +206,25 @@ def _eval_generic(obj: types.GenericAlias, ctx: EvalContext): raise finally: ctx.current_alias = old_obj + ctx.current_alias_stack.remove(new_obj) return evaled +@_eval_types_impl.register +def _eval_typing_generic(obj: _GenericAlias, ctx: EvalContext): + # generic *classes* are typing._GenericAlias while generic type + # aliases are # types.GenericAlias? Why in the world. + if func := _eval_funcs.get(obj.__origin__): + new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__) + ret = func(*new_args, ctx=ctx) + # return _eval_types(ret, ctx) # ??? + return ret + + # TODO: Actually evaluate in this case! + return obj + + @_eval_types_impl.register def _eval_union(obj: typing.Union, ctx: EvalContext): # type: ignore args: typing.Sequence[typing.Any] = obj.__args__ diff --git a/typemap/typing.py b/typemap/typing.py index d6ee566..5b08b87 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -1,23 +1,14 @@ from dataclasses import dataclass -import functools -import inspect -import itertools +import contextvars import types import typing - -from typemap import type_eval -from typemap.type_eval import _typing_inspect +from typing import _GenericAlias # type: ignore _SpecialForm: typing.Any = typing._SpecialForm -class _NoCacheSpecialForm(_SpecialForm, _root=True): # type: ignore[call-arg] - def __getitem__(self, parameters): - return self._getitem(self, parameters) - - @dataclass(frozen=True) class CallSpec: pass @@ -26,7 +17,7 @@ class CallSpec: @dataclass(frozen=True) class _CallSpecWrapper: _args: tuple[typing.Any] - _kwargs: dict[str, typing.Any] + _kwargs: tuple[tuple[str, typing.Any], ...] # TODO: Support MethodType! _func: types.FunctionType # | types.MethodType @@ -39,45 +30,13 @@ def kwargs(self) -> None: pass -@_SpecialForm -def CallSpecKwargs(self, spec: _CallSpecWrapper): - ff = types.FunctionType( - spec._func.__code__, - spec._func.__globals__, - spec._func.__name__, - None, - (), - ) - - # We can't call `inspect.signature` on `spec` directly -- - # signature() will attempt to resolve annotations and fail. - # So we run it on a copy of the function that doesn't have - # annotations set. - sig = inspect.signature(ff) - bound = sig.bind(*spec._args, **spec._kwargs) - - # TODO: Get the real type instead of Never - return tuple[ # type: ignore[misc] - *[ - Member[ - typing.Literal[name], # type: ignore[valid-type] - typing.Never, - ] - for name in bound.kwargs - ] - ] +class CallSpecKwargs[Spec]: + pass ################################################################## -def _from_literal(val): - val = type_eval.eval_typing(val) - if _typing_inspect.is_literal(val): - val = val.__args__[0] - return val - - class Member[N: str, T, Q: str = typing.Never, D = typing.Never]: pass @@ -88,308 +47,97 @@ class Member[N: str, T, Q: str = typing.Never, D = typing.Never]: type GetDefiner[T: Member] = GetArg[T, Member, 3] # type: ignore[valid-type] -################################################################## - - -def get_annotated_type_hints(cls, **kwargs): - """Get the type hints for a cls annotated with definition site. - - This traverses the mro and finds the definition site for each annotation. - """ - ohints = typing.get_type_hints(cls, **kwargs) - hints = {} - for acls in cls.__mro__: - if not hasattr(acls, "__annotations__"): - continue - for k in acls.__annotations__: - if k not in hints: - hints[k] = ohints[k], acls - - # Stop early if we are done. - if len(hints) == len(ohints): - break - return hints - - -def _split_args(func): - @functools.wraps(func) - def wrapper(self, arg): - if isinstance(arg, tuple): - return func(self, *arg) - else: - return func(self, arg) - - return wrapper +class Attrs[T]: + pass -def _union_elems(tp): - tp = type_eval.eval_typing(tp) - if isinstance(tp, types.UnionType): - return tuple(y for x in tp.__args__ for y in _union_elems(x)) - elif _typing_inspect.is_literal(tp) and len(tp.__args__) > 1: - return tuple(typing.Literal[x] for x in tp.__args__) - else: - return (tp,) +class Param[N: str | None, T, Q: str = typing.Never]: + pass -def _lift_over_unions(func): - @functools.wraps(func) - @_split_args - def wrapper(self, *args): - args2 = [_union_elems(x) for x in args] - # XXX: Never - parts = [func(self, *x) for x in itertools.product(*args2)] - return typing.Union[*parts] +class Members[T]: + pass - return wrapper +class FromUnion[T]: + pass -@_SpecialForm -@_lift_over_unions -def Attrs(self, tp): - o = type_eval.eval_typing(tp) - hints = get_annotated_type_hints(o, include_extras=True) - return tuple[ - *[ - Member[typing.Literal[n], t, typing.Never, d] - for n, (t, d) in hints.items() - ] - ] +class GetAttr[Lhs, Prop]: + pass -class Param[N: str | None, T, Q: str = typing.Never]: +class GetArg[Tp, Base, Idx: int]: pass -def _function_type(func, *, is_method): - root = inspect.unwrap(func) - sig = inspect.signature(root) - # XXX: __type_params__!!! - - empty = inspect.Parameter.empty - - def _ann(x): - return typing.Any if x is empty else x - - params = [] - for _i, p in enumerate(sig.parameters.values()): - # XXX: what should we do about self? - # should we track classmethod/staticmethod somehow? - # mypy stores all this stuff in the SymbolNodes (FuncDef, etc), - # even though it kind of really is a type/descriptor thing - # if i == 0 and is_method: - # continue - has_name = p.kind in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ) - quals = [] - if p.kind == inspect.Parameter.VAR_POSITIONAL: - quals.append("*") - if p.kind == inspect.Parameter.VAR_KEYWORD: - quals.append("**") - if p.default is not empty: - quals.append("=") - params.append( - Param[ - typing.Literal[p.name if has_name else None], - _ann(p.annotation), - typing.Literal[*quals] if quals else typing.Never, - ] - ) - - return typing.Callable[params, _ann(sig.return_annotation)] +class Uppercase[S: str]: + pass -@_SpecialForm -@_lift_over_unions -def Members(self, tp): - o = type_eval.eval_typing(tp) - hints = get_annotated_type_hints(o, include_extras=True) +class Lowercase[S: str]: + pass - attrs = [ - Member[typing.Literal[n], t, typing.Never, d] - for n, (t, d) in hints.items() - ] - for name, attr in o.__dict__.items(): - if isinstance(attr, (types.FunctionType, types.MethodType)): - if attr is typing._no_init_or_replace_init: - continue +class Capitalize[S: str]: + pass - # XXX: populate the source field - attrs.append( - Member[ - typing.Literal[name], - _function_type(attr, is_method=True), - typing.Literal["ClassVar"], - ] - ) - return tuple[*attrs] +class Uncapitalize[S: str]: + pass -################################################################## +class StrConcat[S: str, T: str]: + pass -@_SpecialForm -def Iter(self, tp): - tp = type_eval.eval_typing(tp) - if ( - _typing_inspect.is_generic_alias(tp) - and tp.__origin__ is tuple - and (not tp.__args__ or tp.__args__[-1] is not Ellipsis) - ): - return tp.__args__ - else: - # XXX: Or should we return []? - raise TypeError( - f"Invalid type argument to Iter: {tp} is not a fixed-length tuple" - ) +class StrSlice[S: str, Start: int | None, End: int | None]: + pass -@_SpecialForm -def FromUnion(self, tp): - return tuple[*_union_elems(tp)] +class NewProtocol[*T]: + pass ################################################################## +# TODO: type better +special_form_evaluator: contextvars.ContextVar[ + typing.Callable[[typing.Any], typing.Any] | None +] = contextvars.ContextVar("special_form_evaluator", default=None) -@_SpecialForm -@_lift_over_unions -def GetAttr(self, lhs, prop): - # TODO: the prop missing, etc! - # XXX: extras? - name = _from_literal(type_eval.eval_typing(prop)) - return typing.get_type_hints(type_eval.eval_typing(lhs))[name] - - -def _get_args(tp, base) -> typing.Any: - # XXX: check against base!! - evaled = type_eval.eval_typing(tp) - - tp_head = _typing_inspect.get_head(tp) - base_head = _typing_inspect.get_head(base) - # XXX: not sure this is what we want! - # at the very least we want unions I think - if not tp_head or not base_head: - return None - if tp_head is base_head: - return typing.get_args(evaled) - - # Scan the fully-annotated MRO to find the base - elif gen_mro := getattr(evaled, "__generalized_mro__", None): - for box in gen_mro: - if box.cls is base_head: - return tuple(box.args.values()) - return None - - else: - # or error?? - return None +class _IterGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] + def __iter__(self): + evaluator = special_form_evaluator.get() + if evaluator: + return evaluator(self) + else: + return iter(typing.TypeVarTuple("_IterDummy")) @_SpecialForm -@_lift_over_unions -def GetArg(self, tp, base, idx) -> typing.Any: - args = _get_args(tp, base) - if args is None: - return typing.Never - - try: - return args[_from_literal(idx)] - except IndexError: - return typing.Never - +def Iter(self, tp): + return _IterGenericAlias(self, (tp,)) -################################################################## -# N.B: These handle unions on their own +class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] + def __bool__(self): + evaluator = special_form_evaluator.get() + if evaluator: + return evaluator(self) + else: + return False @_SpecialForm -@_split_args -def IsSubtype(self, lhs, rhs): - return type_eval.issubtype( - type_eval.eval_typing(lhs), - type_eval.eval_typing(rhs), - ) +def IsSubtype(self, tps): + return _IsGenericAlias(self, tps) @_SpecialForm -@_split_args -def IsSubSimilar(self, lhs, rhs): - return type_eval.issubsimilar( - type_eval.eval_typing(lhs), - type_eval.eval_typing(rhs), - ) +def IsSubSimilar(self, tps): + return _IsGenericAlias(self, tps) Is = IsSubSimilar - - -################################################################## - - -def _string_literal_op(op): - @_SpecialForm - @_lift_over_unions - def func(self, *args): - return typing.Literal[op(*[_from_literal(x) for x in args])] - - return func - - -Uppercase = _string_literal_op(op=str.upper) -Lowercase = _string_literal_op(op=str.lower) -Capitalize = _string_literal_op(op=str.capitalize) -Uncapitalize = _string_literal_op(op=lambda s: s[0:1].lower() + s[1:]) -StrConcat = _string_literal_op(op=lambda s, t: s + t) -StrSlice = _string_literal_op(op=lambda s, start, end: s[start:end]) - - -################################################################## - - -# XXX: We definitely can't use the normal _SpecialForm cache here -# directly, since we depend on the context's current_alias. -# Maybe we can add that to the cache, though. -# (Or maybe we need to never use the cache??) -@_NoCacheSpecialForm -def NewProtocol(self, val: Member | tuple[Member, ...]): - if not isinstance(val, tuple): - val = (val,) - - etyps = [type_eval.eval_typing(t) for t in val] - - dct: dict[str, object] = {} - dct["__annotations__"] = { - # XXX: Should eval_typing on the etyps evaluate the arguments?? - _from_literal(type_eval.eval_typing(typing.get_args(prop)[0])): - # XXX: We maybe (probably?) want to eval_typing the RHS, but - # we have infinite recursion issues in test_eval_types_2... - # type_eval.eval_typing(typing.get_args(prop)[1]) - typing.get_args(prop)[1] - for prop in etyps - } - - module_name = __name__ - name = "NewProtocol" - - # If the type evaluation context - ctx = type_eval._get_current_context() - if ctx.current_alias: - if isinstance(ctx.current_alias, types.GenericAlias): - name = str(ctx.current_alias) - else: - name = f"{ctx.current_alias.__name__}[...]" - module_name = ctx.current_alias.__module__ - - dct["__module__"] = module_name - - mcls: type = type(typing.cast(type, typing.Protocol)) - cls = mcls(name, (typing.Protocol,), dct) - return cls