diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 9eee8ca..4a13b36 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -19,7 +19,10 @@ from typemap.type_eval import eval_typing from typemap.typing import ( + AllOf, + AnyOf, Attrs, + Bool, FromUnion, GenericCallable, GetArg, @@ -31,6 +34,7 @@ IsSub, Iter, Length, + Matches, Member, Members, NewProtocol, @@ -39,6 +43,7 @@ StrConcat, StrSlice, Uppercase, + _LiteralGeneric, ) from . import format_helper @@ -993,7 +998,100 @@ def test_uppercase_never(): def test_never_is(): d = eval_typing(IsSub[Never, Never]) - assert d is True + assert d == _LiteralGeneric[True] + + +def test_matches_01(): + d = eval_typing(Matches[int, int]) + assert d == _LiteralGeneric[True] + + d = eval_typing(Matches[int, str]) + assert d == _LiteralGeneric[False] + + d = eval_typing(Matches[str, int]) + assert d == _LiteralGeneric[False] + + +def test_matches_02(): + class A: + pass + + class B(A): + pass + + class C(B): + pass + + class D(A): + pass + + class X: + pass + + d = eval_typing(Matches[A, A]) + assert d == _LiteralGeneric[True] + + d = eval_typing(Matches[A, B]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[B, A]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[B, C]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[C, B]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[C, D]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[D, C]) + assert d == _LiteralGeneric[False] + + d = eval_typing(Matches[A, X]) + assert d == _LiteralGeneric[False] + + +def test_matches_03(): + class A[T]: + pass + + class B[T](A[T]): + pass + + class C(B[int]): + pass + + class D(A[str]): + pass + + class X: + pass + + d = eval_typing(Matches[A[int], A[int]]) + assert d == _LiteralGeneric[True] + d = eval_typing(Matches[A[int], A[str]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(Matches[A[int], B[int]]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[B[int], A[int]]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[A[int], B[str]]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[B[str], A[int]]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[B[int], C]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[C, B[int]]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[B[str], C]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[C, B[str]]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[C, D]) + assert d == _LiteralGeneric[False] + d = eval_typing(Matches[D, C]) + assert d == _LiteralGeneric[False] + + d = eval_typing(Matches[A[int], X]) + assert d == _LiteralGeneric[False] def test_eval_iter_01(): @@ -1026,6 +1124,321 @@ def test_eval_iter_02(): assert d == tuple[int, str, int, str] +type NotLiteralGeneric[T] = not T +type AndLiteralGeneric[L, R] = L and R +type OrLiteralGeneric[L, R] = L or R +type LiteralGenericToLiteral[T] = Literal[True] if T else Literal[False] +type NotLiteralGenericToLiteral[T] = Literal[True] if not T else Literal[False] + + +def test_eval_bool_01(): + d = eval_typing(Bool[Literal[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(Bool[Literal[False]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(Bool[Literal[1]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(Bool[Literal[0]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(Bool[Literal["true"]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(Bool[Literal["false"]]) + assert d == _LiteralGeneric[True] + + +def test_eval_bool_02(): + d = eval_typing(Bool[_LiteralGeneric[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(Bool[_LiteralGeneric[False]]) + assert d == _LiteralGeneric[False] + + +def test_eval_bool_03(): + d = eval_typing(NotLiteralGeneric[Bool[Literal[True]]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(NotLiteralGeneric[Bool[Literal[False]]]) + assert d == _LiteralGeneric[True] + + +type NestedBool0[T] = Bool[T] +type NestedBool1[T] = NestedBool0[Bool[T]] +type NestedBool2[T] = NestedBool1[Bool[T]] +type NestedBool3[T] = NestedBool2[Bool[T]] +type NestedBool4[T] = NestedBool3[Bool[T]] +type NestedBool5[T] = NestedBool4[Bool[T]] + + +def test_eval_bool_04(): + d = eval_typing(NestedBool5[Literal[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(NestedBool5[Literal[False]]) + assert d == _LiteralGeneric[False] + + +type IsIntBool[T] = Bool[IsSub[T, int]] +type IsIntLiteral[T] = Literal[True] if Bool[IsIntBool[T]] else Literal[False] + + +def test_eval_bool_05(): + d = eval_typing(IsIntLiteral[int]) + assert d == Literal[True] + + d = eval_typing(IsIntLiteral[str]) + assert d == Literal[False] + + +def test_eval_all_01(): + d = eval_typing(AllOf[()]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AllOf[_LiteralGeneric[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AllOf[_LiteralGeneric[False]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(AllOf[_LiteralGeneric[True], _LiteralGeneric[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AllOf[_LiteralGeneric[True], _LiteralGeneric[False]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(AllOf[_LiteralGeneric[False], _LiteralGeneric[True]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(AllOf[_LiteralGeneric[False], _LiteralGeneric[False]]) + assert d == _LiteralGeneric[False] + + +def test_eval_all_02(): + d = eval_typing(AllOf[Literal[True], Literal[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AllOf[Literal[True], Literal[False]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(AllOf[Literal[False], Literal[True]]) + assert d == _LiteralGeneric[False] + + +type ContainsAllInt[Ts] = AllOf[*[IsSub[t, int] for t in Iter[Ts]]] +type ContainsAllIntToLiteral[Ts] = ( + Literal[True] if Bool[ContainsAllInt[Ts]] else Literal[False] +) + + +def test_eval_all_03(): + d = eval_typing(ContainsAllInt[tuple[()]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(ContainsAllInt[tuple[int]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(ContainsAllInt[tuple[str]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(ContainsAllInt[tuple[int, int]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(ContainsAllInt[tuple[int, str]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(ContainsAllInt[tuple[str, str]]) + assert d == _LiteralGeneric[False] + + +def test_eval_all_04(): + d = eval_typing(ContainsAllIntToLiteral[tuple[()]]) + assert d == Literal[True] + + d = eval_typing(ContainsAllIntToLiteral[tuple[int]]) + assert d == Literal[True] + + d = eval_typing(ContainsAllIntToLiteral[tuple[str]]) + assert d == Literal[False] + + +def test_eval_any_01(): + d = eval_typing(AnyOf[()]) + assert d == _LiteralGeneric[False] + + d = eval_typing(AnyOf[_LiteralGeneric[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AnyOf[_LiteralGeneric[False]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(AnyOf[_LiteralGeneric[True], _LiteralGeneric[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AnyOf[_LiteralGeneric[True], _LiteralGeneric[False]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AnyOf[_LiteralGeneric[False], _LiteralGeneric[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AnyOf[_LiteralGeneric[False], _LiteralGeneric[False]]) + assert d == _LiteralGeneric[False] + + +def test_eval_any_02(): + d = eval_typing(AnyOf[()]) + assert d == _LiteralGeneric[False] + + d = eval_typing(AnyOf[Literal[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AnyOf[Literal[False]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(AnyOf[Literal[True], Literal[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AnyOf[Literal[True], Literal[False]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AnyOf[Literal[False], Literal[True]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(AnyOf[Literal[False], Literal[False]]) + assert d == _LiteralGeneric[False] + + +type ContainsAnyInt[Ts] = AnyOf[*[IsSub[t, int] for t in Iter[Ts]]] +type ContainsAnyIntToLiteral[Ts] = ( + Literal[True] if Bool[ContainsAnyInt[Ts]] else Literal[False] +) + + +def test_eval_any_03(): + d = eval_typing(ContainsAnyInt[tuple[()]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(ContainsAnyInt[tuple[int]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(ContainsAnyInt[tuple[str]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(ContainsAnyInt[tuple[int, int]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(ContainsAnyInt[tuple[int, str]]) + assert d == _LiteralGeneric[True] + + d = eval_typing(ContainsAnyInt[tuple[str, str]]) + assert d == _LiteralGeneric[False] + + +def test_eval_any_04(): + d = eval_typing(ContainsAnyIntToLiteral[tuple[()]]) + assert d == Literal[False] + + d = eval_typing(ContainsAnyIntToLiteral[tuple[int]]) + assert d == Literal[True] + + d = eval_typing(ContainsAnyIntToLiteral[tuple[str]]) + assert d == Literal[False] + + +def test_eval_literal_generic_01(): + d = eval_typing(_LiteralGeneric[True]) + assert d == _LiteralGeneric[True] + + d = eval_typing(_LiteralGeneric[False]) + assert d == _LiteralGeneric[False] + + d = eval_typing(_LiteralGeneric[1]) + assert d == _LiteralGeneric[True] + + d = eval_typing(_LiteralGeneric[0]) + assert d == _LiteralGeneric[False] + + +def test_eval_literal_generic_02(): + d = eval_typing(not _LiteralGeneric[True]) + assert d == _LiteralGeneric[False] + + d = eval_typing(NotLiteralGeneric[_LiteralGeneric[True]]) + assert d == _LiteralGeneric[False] + + d = eval_typing(NotLiteralGeneric[_LiteralGeneric[False]]) + assert d == _LiteralGeneric[True] + + +def test_eval_literal_generic_03(): + d = eval_typing( + AndLiteralGeneric[_LiteralGeneric[True], _LiteralGeneric[True]] + ) + assert d == _LiteralGeneric[True] + + d = eval_typing( + AndLiteralGeneric[_LiteralGeneric[True], _LiteralGeneric[False]] + ) + assert d == _LiteralGeneric[False] + + d = eval_typing( + AndLiteralGeneric[_LiteralGeneric[False], _LiteralGeneric[True]] + ) + assert d == _LiteralGeneric[False] + + d = eval_typing( + AndLiteralGeneric[_LiteralGeneric[False], _LiteralGeneric[False]] + ) + assert d == _LiteralGeneric[False] + + +def test_eval_literal_generic_04(): + d = eval_typing( + OrLiteralGeneric[_LiteralGeneric[True], _LiteralGeneric[True]] + ) + assert d == _LiteralGeneric[True] + + d = eval_typing( + OrLiteralGeneric[_LiteralGeneric[True], _LiteralGeneric[False]] + ) + assert d == _LiteralGeneric[True] + + d = eval_typing( + OrLiteralGeneric[_LiteralGeneric[False], _LiteralGeneric[True]] + ) + assert d == _LiteralGeneric[True] + + d = eval_typing( + OrLiteralGeneric[_LiteralGeneric[False], _LiteralGeneric[False]] + ) + assert d == _LiteralGeneric[False] + + +def test_eval_literal_generic_05(): + d = eval_typing(LiteralGenericToLiteral[_LiteralGeneric[True]]) + assert d == Literal[True] + + d = eval_typing(LiteralGenericToLiteral[_LiteralGeneric[False]]) + assert d == Literal[False] + + +def test_eval_literal_generic_06(): + d = eval_typing(NotLiteralGenericToLiteral[_LiteralGeneric[True]]) + assert d == Literal[False] + + d = eval_typing(NotLiteralGenericToLiteral[_LiteralGeneric[False]]) + assert d == Literal[True] + + +def test_eval_literal_generic_error_01(): + with pytest.raises(TypeError, match="Expected literal type, got 'int'"): + eval_typing(_LiteralGeneric[int]) + + def test_eval_length_01(): d = eval_typing(Length[tuple[int, str]]) assert d == Literal[2] @@ -1052,7 +1465,9 @@ def test_eval_literal_idempotent_01(): def test_is_literal_true_vs_one(): - assert eval_typing(IsSub[Literal[True], Literal[1]]) is False + assert ( + eval_typing(IsSub[Literal[True], Literal[1]]) == _LiteralGeneric[False] + ) def test_callable_to_signature_01(): @@ -1220,7 +1635,7 @@ class AnnoTest: def test_type_eval_annotated_02(): res = eval_typing(IsSub[GetAttr[AnnoTest, Literal["a"]], int]) - assert res is True + assert res == _LiteralGeneric[True] def test_type_eval_annotated_03(): diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 49a60ab..eff8c51 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -16,7 +16,10 @@ _get_class_type_hint_namespaces, ) from typemap.typing import ( + AllOf, + AnyOf, Attrs, + Bool, Capitalize, DropAnnotations, FromUnion, @@ -31,6 +34,7 @@ Iter, Length, Lowercase, + Matches, Member, Members, NewProtocol, @@ -40,7 +44,9 @@ StrSlice, Uncapitalize, Uppercase, + _LiteralGeneric, ) +from typemap.type_eval._wrapped_value import _BoolValue ################################################################## @@ -233,13 +239,47 @@ def _eval_Iter(tp, *, ctx): @type_eval.register_evaluator(IsSubtype) @_lift_evaluated def _eval_IsSubtype(lhs, rhs, *, ctx): - return type_eval.issubtype(lhs, rhs) + return _LiteralGeneric[type_eval.issubtype(lhs, rhs)] @type_eval.register_evaluator(IsSubSimilar) @_lift_evaluated def _eval_IsSubSimilar(lhs, rhs, *, ctx): - return type_eval.issubsimilar(lhs, rhs) + return _LiteralGeneric[type_eval.issubsimilar(lhs, rhs)] + + +@type_eval.register_evaluator(Matches) +@_lift_evaluated +def _eval_Matches(lhs, rhs, *, ctx): + return _LiteralGeneric[ + type_eval.issubsimilar(lhs, rhs) and type_eval.issubsimilar(rhs, lhs) + ] + + +@type_eval.register_evaluator(AllOf) +@_lift_evaluated +def _eval_AllOf(*tps, ctx): + return _LiteralGeneric[all(_eval_bool_tp(tp) for tp in tps)] + + +@type_eval.register_evaluator(AnyOf) +@_lift_evaluated +def _eval_AnyOf(*tps, ctx): + return _LiteralGeneric[any(_eval_bool_tp(tp) for tp in tps)] + + +@type_eval.register_evaluator(Bool) +@_lift_evaluated +def _eval_Bool(tp, *, ctx): + return _eval_bool_tp(tp) + + +def _eval_bool_tp(tp): + if typing.get_origin(tp) is typing.Literal: + return _LiteralGeneric[tp.__args__[0]] + elif isinstance(tp, _BoolValue._WrappedInstance): + return _LiteralGeneric[tp._value] + raise TypeError(f"Expected Literal type, got {tp}") ################################################################## diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index 2db3fd7..dae8291 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -21,7 +21,7 @@ if typing.TYPE_CHECKING: from typing import Any -from . import _apply_generic, _typing_inspect +from . import _apply_generic, _typing_inspect, _wrapped_value __all__ = ("eval_typing",) @@ -182,6 +182,16 @@ def eval_typing(obj: typing.Any): result = _eval_types(obj, ctx) if not isinstance(result, list) and result in ctx.known_recursive_types: result = ctx.known_recursive_types[result] + + if isinstance(result, bool): + # Wrap a boolean result with _LiteralGeneric + # This is because `not` calls `__bool__` first so a boolean + # expression like `not _LiteralGeneric[True]` will result `False`, + # not `_LiteralGeneric[False]` as we want. + from typemap.typing import _LiteralGeneric + + result = _LiteralGeneric[result] # type: ignore[valid-type] + return result @@ -400,7 +410,13 @@ def _eval_applied_type_alias(obj: types.GenericAlias, ctx: EvalContext): # Type alias types are already added in _eval_types child_ctx.alias_stack.add(new_obj) - ff = types.FunctionType(func.__code__, mod.__dict__, None, None, args) + ff = types.FunctionType( + func.__code__, + _GlobalsWrapper(mod.__dict__, child_ctx), + None, + None, + args, + ) unpacked = ff(annotationlib.Format.VALUE) child_ctx.seen[obj] = unpacked @@ -412,6 +428,50 @@ def _eval_applied_type_alias(obj: types.GenericAlias, ctx: EvalContext): return evaled +class _GlobalsWrapper(dict): + """Wraps module dict to make type aliases in type aliases evaluate + immediately. + + This allows us to ensure that generic aliases which resolve to + _LiteralGeneric are evaluated *before* they are used as booleans. + + For example, suppose we have: + + type BoolToLiteral[T] = Literal[True] if Bool[T] else Literal[False] + + Though `Bool` results in a `_LiteralGeneric`, it is not one itself. So when + used in `BoolToLiteral`, it will always evaluate to true. + """ + + def __init__(self, base_dict, ctx): + super().__init__(base_dict) + self._ctx = ctx + + def __getitem__(self, key): + value = super().__getitem__(key) + if isinstance(value, type) and issubclass( + value, _wrapped_value._BoolExpr + ): + return _GenericClassWrapper(value, self._ctx) + return value + + +class _GenericClassWrapper: + def __init__(self, generic_class, ctx): + self._generic_class = generic_class + self._ctx = ctx + + def __getitem__(self, item): + result = self._generic_class[item] + # Immediately evaluate the generic alias + if isinstance(result, (types.GenericAlias, typing._GenericAlias)): + return _eval_types(result, self._ctx) + return result + + def __getattr__(self, name): + return getattr(self._generic_class, name) + + @_eval_types_impl.register def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext): """Eval a typing._GenericAlias -- an applied user-defined class""" diff --git a/typemap/type_eval/_wrapped_value.py b/typemap/type_eval/_wrapped_value.py new file mode 100644 index 0000000..0b6bbc8 --- /dev/null +++ b/typemap/type_eval/_wrapped_value.py @@ -0,0 +1,34 @@ +class _BoolExpr: + pass + + +class _BoolValue: + class _WrappedInstance: + def __init__(self, value: bool, type_name: str): + self._value = value + self._type_name = type_name + + def __bool__(self): + return self._value + + def __repr__(self): + return f"typemap.typing.{self._type_name}[{self._value!r}]" + + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self._value == other._value + ) + + def __hash__(self): + return hash((self._type_name, self._value)) + + def __init_subclass__(cls): + cls.__true_instance = cls._WrappedInstance(True, cls.__name__) + cls.__false_instance = cls._WrappedInstance(False, cls.__name__) + + def __class_getitem__(cls, item): + if isinstance(item, type): + raise TypeError(f"Expected literal type, got '{item.__name__}'") + + return cls.__true_instance if bool(item) else cls.__false_instance diff --git a/typemap/typing.py b/typemap/typing.py index 021483d..eada354 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -3,6 +3,10 @@ from typing import Literal from typing import _GenericAlias # type: ignore + +from .type_eval._wrapped_value import _BoolExpr, _BoolValue + + _SpecialForm: typing.Any = typing._SpecialForm # Not type-level computation but related @@ -201,19 +205,37 @@ class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] def __bool__(self): evaluator = special_form_evaluator.get() if evaluator: - return evaluator(self) + return bool(evaluator(self)) else: return False -@_SpecialForm -def IsSubtype(self, tps): - return _IsGenericAlias(self, tps) +class IsSubtype[Lhs, Rhs](_BoolExpr): + pass -@_SpecialForm -def IsSubSimilar(self, tps): - return _IsGenericAlias(self, tps) +class IsSubSimilar[Lhs, Rhs](_BoolExpr): + pass + + +class Matches[Lhs, Rhs](_BoolExpr): + pass IsSub = IsSubSimilar + + +class AllOf[*Ts](_BoolExpr): + pass + + +class AnyOf[*Ts](_BoolExpr): + pass + + +class Bool[T: typing.Literal[True, False]](_BoolExpr): + pass + + +class _LiteralGeneric[B: bool](_BoolValue): + pass