diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 9eee8ca..759b525 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -19,7 +19,10 @@ from typemap.type_eval import eval_typing from typemap.typing import ( + AllOf, + AnyOf, Attrs, + Bool, FromUnion, GenericCallable, GetArg, @@ -31,6 +34,7 @@ IsSub, Iter, Length, + Matches, Member, Members, NewProtocol, @@ -39,6 +43,7 @@ StrConcat, StrSlice, Uppercase, + _BoolLiteral, ) from . import format_helper @@ -993,7 +998,107 @@ def test_uppercase_never(): def test_never_is(): d = eval_typing(IsSub[Never, Never]) - assert d is True + assert d == _BoolLiteral[True] + + +def test_eval_list_is_sub_01(): + d = eval_typing(list[IsSub[int, str]]) + assert d == list[_BoolLiteral[False]] + d = eval_typing(list[not IsSub[int, str]]) + assert d == list[_BoolLiteral[True]] + + +def test_matches_01(): + d = eval_typing(Matches[int, int]) + assert d == _BoolLiteral[True] + + d = eval_typing(Matches[int, str]) + assert d == _BoolLiteral[False] + + d = eval_typing(Matches[str, int]) + assert d == _BoolLiteral[False] + + +def test_matches_02(): + class A: + pass + + class B(A): + pass + + class C(B): + pass + + class D(A): + pass + + class X: + pass + + d = eval_typing(Matches[A, A]) + assert d == _BoolLiteral[True] + + d = eval_typing(Matches[A, B]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[B, A]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[B, C]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[C, B]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[C, D]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[D, C]) + assert d == _BoolLiteral[False] + + d = eval_typing(Matches[A, X]) + assert d == _BoolLiteral[False] + + +def test_matches_03(): + class A[T]: + pass + + class B[T](A[T]): + pass + + class C(B[int]): + pass + + class D(A[str]): + pass + + class X: + pass + + d = eval_typing(Matches[A[int], A[int]]) + assert d == _BoolLiteral[True] + d = eval_typing(Matches[A[int], A[str]]) + assert d == _BoolLiteral[True] + + d = eval_typing(Matches[A[int], B[int]]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[B[int], A[int]]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[A[int], B[str]]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[B[str], A[int]]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[B[int], C]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[C, B[int]]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[B[str], C]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[C, B[str]]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[C, D]) + assert d == _BoolLiteral[False] + d = eval_typing(Matches[D, C]) + assert d == _BoolLiteral[False] + + d = eval_typing(Matches[A[int], X]) + assert d == _BoolLiteral[False] def test_eval_iter_01(): @@ -1026,6 +1131,350 @@ def test_eval_iter_02(): assert d == tuple[int, str, int, str] +type NotLiteralGeneric[T] = not T +type AndLiteralGeneric[L, R] = L and R +type OrLiteralGeneric[L, R] = L or R +type LiteralGenericToLiteral[T] = Literal[True] if T else Literal[False] +type NotLiteralGenericToLiteral[T] = Literal[True] if not T else Literal[False] + + +def test_eval_bool_01(): + d = eval_typing(Bool[Literal[True]]) + assert d == _BoolLiteral[True] + + d = eval_typing(Bool[Literal[False]]) + assert d == _BoolLiteral[False] + + d = eval_typing(Bool[Literal[1]]) + assert d == _BoolLiteral[False] + + d = eval_typing(Bool[Literal[0]]) + assert d == _BoolLiteral[False] + + d = eval_typing(Bool[Literal["true"]]) + assert d == _BoolLiteral[False] + + d = eval_typing(Bool[Literal["false"]]) + assert d == _BoolLiteral[False] + + d = eval_typing(Bool[_BoolLiteral[True]]) + assert d == _BoolLiteral[True] + + d = eval_typing(Bool[_BoolLiteral[False]]) + assert d == _BoolLiteral[False] + + d = eval_typing(Bool[Never]) + assert d == _BoolLiteral[False] + + d = eval_typing(Bool[int]) + assert d == _BoolLiteral[False] + + class C: + pass + + d = eval_typing(Bool[C]) + assert d == _BoolLiteral[False] + + d = eval_typing(Bool[True]) + assert d == _BoolLiteral[True] + + d = eval_typing(Bool[False]) + assert d == _BoolLiteral[False] + + +def test_eval_bool_02(): + d = eval_typing(Bool[Literal[True] | Literal[False]]) + assert d == _BoolLiteral[True] + d = eval_typing(Bool[Literal[False] | Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(Bool[Literal[True] | Never]) + assert d == _BoolLiteral[True] + d = eval_typing(Bool[Never | Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(Bool[Literal[False] | Never]) + assert d == _BoolLiteral[False] + d = eval_typing(Bool[Never | Literal[False]]) + assert d == _BoolLiteral[False] + d = eval_typing(Bool[Literal[True] | int]) + assert d == _BoolLiteral[True] + d = eval_typing(Bool[int | Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(Bool[Literal[False] | int]) + assert d == _BoolLiteral[False] + d = eval_typing(Bool[int | Literal[False]]) + assert d == _BoolLiteral[False] + d = eval_typing(Bool[int | str]) + assert d == _BoolLiteral[False] + + +def test_eval_bool_03(): + d = eval_typing(NotLiteralGeneric[Bool[Literal[True]]]) + assert d == _BoolLiteral[False] + + d = eval_typing(NotLiteralGeneric[Bool[Literal[False]]]) + assert d == _BoolLiteral[True] + + +type NestedBool0[T] = Bool[T] +type NestedBool1[T] = NestedBool0[Bool[T]] +type NestedBool2[T] = NestedBool1[Bool[T]] +type NestedBool3[T] = NestedBool2[Bool[T]] +type NestedBool4[T] = NestedBool3[Bool[T]] +type NestedBool5[T] = NestedBool4[Bool[T]] + + +def test_eval_bool_04(): + d = eval_typing(NestedBool5[Literal[True]]) + assert d == _BoolLiteral[True] + + d = eval_typing(NestedBool5[Literal[False]]) + assert d == _BoolLiteral[False] + + +type IsIntBool[T] = Bool[IsSub[T, int]] +type IsIntLiteral[T] = Literal[True] if Bool[IsIntBool[T]] else Literal[False] + + +def test_eval_bool_05(): + d = eval_typing(IsIntLiteral[int]) + assert d == Literal[True] + + d = eval_typing(IsIntLiteral[str]) + assert d == Literal[False] + + +def test_eval_all_01(): + d = eval_typing(AllOf[()]) + assert d == _BoolLiteral[True] + + d = eval_typing(AllOf[_BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AllOf[_BoolLiteral[False]]) + assert d == _BoolLiteral[False] + + d = eval_typing(AllOf[_BoolLiteral[True], _BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AllOf[_BoolLiteral[True], _BoolLiteral[False]]) + assert d == _BoolLiteral[False] + d = eval_typing(AllOf[_BoolLiteral[False], _BoolLiteral[True]]) + assert d == _BoolLiteral[False] + d = eval_typing(AllOf[_BoolLiteral[False], _BoolLiteral[False]]) + assert d == _BoolLiteral[False] + + d = eval_typing(AllOf[Literal[True] | Literal[False]]) + assert d == _BoolLiteral[True] + d = eval_typing(AllOf[int | Never]) + assert d == _BoolLiteral[False] + d = eval_typing( + AllOf[Literal[0] | Literal[True], Literal[2] | Literal[True]] + ) + assert d == _BoolLiteral[True] + d = eval_typing(AllOf[Literal[0] | Literal[1], Literal[2] | Literal[True]]) + assert d == _BoolLiteral[False] + d = eval_typing(AllOf[Literal[0] | Literal[1], Literal[2] | Literal[3]]) + assert d == _BoolLiteral[False] + + +def test_eval_all_02(): + d = eval_typing(AllOf[()]) + assert d == _BoolLiteral[True] + + d = eval_typing(AllOf[Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AllOf[Literal[False]]) + assert d == _BoolLiteral[False] + + d = eval_typing(AllOf[Literal[True], Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AllOf[Literal[True], Literal[False]]) + assert d == _BoolLiteral[False] + d = eval_typing(AllOf[Literal[False], Literal[True]]) + assert d == _BoolLiteral[False] + + +type ContainsAllInt[Ts] = AllOf[*[IsSub[t, int] for t in Iter[Ts]]] +type ContainsAllIntToLiteral[Ts] = ( + Literal[True] if Bool[ContainsAllInt[Ts]] else Literal[False] +) + + +def test_eval_all_03(): + d = eval_typing(ContainsAllInt[tuple[()]]) + assert d == _BoolLiteral[True] + + d = eval_typing(ContainsAllInt[tuple[int]]) + assert d == _BoolLiteral[True] + d = eval_typing(ContainsAllInt[tuple[str]]) + assert d == _BoolLiteral[False] + + d = eval_typing(ContainsAllInt[tuple[int, int]]) + assert d == _BoolLiteral[True] + d = eval_typing(ContainsAllInt[tuple[int, str]]) + assert d == _BoolLiteral[False] + d = eval_typing(ContainsAllInt[tuple[str, str]]) + assert d == _BoolLiteral[False] + + +def test_eval_all_04(): + d = eval_typing(ContainsAllIntToLiteral[tuple[()]]) + assert d == Literal[True] + + d = eval_typing(ContainsAllIntToLiteral[tuple[int]]) + assert d == Literal[True] + d = eval_typing(ContainsAllIntToLiteral[tuple[str]]) + assert d == Literal[False] + + +def test_eval_any_01(): + d = eval_typing(AnyOf[()]) + assert d == _BoolLiteral[False] + + d = eval_typing(AnyOf[_BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[_BoolLiteral[False]]) + assert d == _BoolLiteral[False] + + d = eval_typing(AnyOf[_BoolLiteral[True], _BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[_BoolLiteral[True], _BoolLiteral[False]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[_BoolLiteral[False], _BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[_BoolLiteral[False], _BoolLiteral[False]]) + assert d == _BoolLiteral[False] + + d = eval_typing(AnyOf[Literal[True] | Literal[False]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[int | Never]) + assert d == _BoolLiteral[False] + d = eval_typing( + AnyOf[Literal[0] | Literal[True], Literal[2] | Literal[True]] + ) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[Literal[0] | Literal[1], Literal[2] | Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[Literal[0] | Literal[1], Literal[2] | Literal[3]]) + assert d == _BoolLiteral[False] + + +def test_eval_any_02(): + d = eval_typing(AnyOf[()]) + assert d == _BoolLiteral[False] + + d = eval_typing(AnyOf[Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[Literal[False]]) + assert d == _BoolLiteral[False] + + d = eval_typing(AnyOf[Literal[True], Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[Literal[True], Literal[False]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[Literal[False], Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AnyOf[Literal[False], Literal[False]]) + assert d == _BoolLiteral[False] + + +type ContainsAnyInt[Ts] = AnyOf[*[IsSub[t, int] for t in Iter[Ts]]] +type ContainsAnyIntToLiteral[Ts] = ( + Literal[True] if Bool[ContainsAnyInt[Ts]] else Literal[False] +) + + +def test_eval_any_03(): + d = eval_typing(ContainsAnyInt[tuple[()]]) + assert d == _BoolLiteral[False] + + d = eval_typing(ContainsAnyInt[tuple[int]]) + assert d == _BoolLiteral[True] + d = eval_typing(ContainsAnyInt[tuple[str]]) + assert d == _BoolLiteral[False] + + d = eval_typing(ContainsAnyInt[tuple[int, int]]) + assert d == _BoolLiteral[True] + d = eval_typing(ContainsAnyInt[tuple[int, str]]) + assert d == _BoolLiteral[True] + d = eval_typing(ContainsAnyInt[tuple[str, str]]) + assert d == _BoolLiteral[False] + + +def test_eval_any_04(): + d = eval_typing(ContainsAnyIntToLiteral[tuple[()]]) + assert d == Literal[False] + + d = eval_typing(ContainsAnyIntToLiteral[tuple[int]]) + assert d == Literal[True] + d = eval_typing(ContainsAnyIntToLiteral[tuple[str]]) + assert d == Literal[False] + + +def test_eval_literal_generic_01(): + d = eval_typing(_BoolLiteral[True]) + assert d == _BoolLiteral[True] + d = eval_typing(_BoolLiteral[False]) + assert d == _BoolLiteral[False] + d = eval_typing(_BoolLiteral[1]) + assert d == _BoolLiteral[True] + d = eval_typing(_BoolLiteral[0]) + assert d == _BoolLiteral[False] + d = eval_typing(_BoolLiteral[_BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(_BoolLiteral[_BoolLiteral[False]]) + assert d == _BoolLiteral[False] + + +def test_eval_literal_generic_02(): + d = eval_typing(not _BoolLiteral[True]) + assert d == _BoolLiteral[False] + + d = eval_typing(NotLiteralGeneric[_BoolLiteral[True]]) + assert d == _BoolLiteral[False] + d = eval_typing(NotLiteralGeneric[_BoolLiteral[False]]) + assert d == _BoolLiteral[True] + + +def test_eval_literal_generic_03(): + d = eval_typing(AndLiteralGeneric[_BoolLiteral[True], _BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(AndLiteralGeneric[_BoolLiteral[True], _BoolLiteral[False]]) + assert d == _BoolLiteral[False] + d = eval_typing(AndLiteralGeneric[_BoolLiteral[False], _BoolLiteral[True]]) + assert d == _BoolLiteral[False] + d = eval_typing(AndLiteralGeneric[_BoolLiteral[False], _BoolLiteral[False]]) + assert d == _BoolLiteral[False] + + +def test_eval_literal_generic_04(): + d = eval_typing(OrLiteralGeneric[_BoolLiteral[True], _BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(OrLiteralGeneric[_BoolLiteral[True], _BoolLiteral[False]]) + assert d == _BoolLiteral[True] + d = eval_typing(OrLiteralGeneric[_BoolLiteral[False], _BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(OrLiteralGeneric[_BoolLiteral[False], _BoolLiteral[False]]) + assert d == _BoolLiteral[False] + + +def test_eval_literal_generic_05(): + d = eval_typing(LiteralGenericToLiteral[_BoolLiteral[True]]) + assert d == Literal[True] + d = eval_typing(LiteralGenericToLiteral[_BoolLiteral[False]]) + assert d == Literal[False] + + +def test_eval_literal_generic_06(): + d = eval_typing(NotLiteralGenericToLiteral[_BoolLiteral[True]]) + assert d == Literal[False] + d = eval_typing(NotLiteralGenericToLiteral[_BoolLiteral[False]]) + assert d == Literal[True] + + +def test_eval_literal_generic_error_01(): + with pytest.raises(TypeError, match="Expected literal type, got 'int'"): + eval_typing(_BoolLiteral[int]) + + def test_eval_length_01(): d = eval_typing(Length[tuple[int, str]]) assert d == Literal[2] @@ -1052,7 +1501,7 @@ def test_eval_literal_idempotent_01(): def test_is_literal_true_vs_one(): - assert eval_typing(IsSub[Literal[True], Literal[1]]) is False + assert eval_typing(IsSub[Literal[True], Literal[1]]) == _BoolLiteral[False] def test_callable_to_signature_01(): @@ -1220,7 +1669,7 @@ class AnnoTest: def test_type_eval_annotated_02(): res = eval_typing(IsSub[GetAttr[AnnoTest, Literal["a"]], int]) - assert res is True + assert res == _BoolLiteral[True] def test_type_eval_annotated_03(): diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 49a60ab..e9c9301 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -16,7 +16,10 @@ _get_class_type_hint_namespaces, ) from typemap.typing import ( + AllOf, + AnyOf, Attrs, + Bool, Capitalize, DropAnnotations, FromUnion, @@ -31,6 +34,7 @@ Iter, Length, Lowercase, + Matches, Member, Members, NewProtocol, @@ -40,6 +44,7 @@ StrSlice, Uncapitalize, Uppercase, + _BoolLiteral, ) ################################################################## @@ -233,13 +238,67 @@ def _eval_Iter(tp, *, ctx): @type_eval.register_evaluator(IsSubtype) @_lift_evaluated def _eval_IsSubtype(lhs, rhs, *, ctx): - return type_eval.issubtype(lhs, rhs) + return _BoolLiteral[type_eval.issubtype(lhs, rhs)] @type_eval.register_evaluator(IsSubSimilar) @_lift_evaluated def _eval_IsSubSimilar(lhs, rhs, *, ctx): - return type_eval.issubsimilar(lhs, rhs) + return _BoolLiteral[type_eval.issubsimilar(lhs, rhs)] + + +@type_eval.register_evaluator(Matches) +@_lift_evaluated +def _eval_Matches(lhs, rhs, *, ctx): + return _BoolLiteral[ + type_eval.issubsimilar(lhs, rhs) and type_eval.issubsimilar(rhs, lhs) + ] + + +def _eval_bool_tp(tp, ctx): + if _typing_inspect.is_generic_alias(tp) and tp.__origin__ is _BoolLiteral: + return _BoolLiteral[bool(tp.__args__[0])] + else: + return _BoolLiteral[ + any( + type_eval.issubsimilar(arg, typing.Literal[True]) + and not type_eval.issubsimilar(arg, typing.Never) + for arg in _union_elems(tp, ctx) + ) + ] + + +@type_eval.register_evaluator(Bool) +@_lift_evaluated +def _eval_Bool(tp, *, ctx): + return _eval_bool_tp(tp, ctx) + + +@type_eval.register_evaluator(AllOf) +@_lift_evaluated +def _eval_AllOf(*tp, ctx): + return _BoolLiteral[all(_eval_bool_tp(tp, ctx) for tp in tp)] + + +@type_eval.register_evaluator(AnyOf) +@_lift_evaluated +def _eval_AnyOf(*tp, ctx): + return _BoolLiteral[any(_eval_bool_tp(tp, ctx) for tp in tp)] + + +@type_eval.register_evaluator(_BoolLiteral) +@_lift_evaluated +def _eval_BoolLiteral(tp, *, ctx): + from typemap.typing import _BoolLiteralGenericAlias + + if isinstance(tp, type): + raise TypeError(f"Expected literal type, got '{tp.__name__}'") + + # If already wrapped, just return it + if isinstance(tp, _BoolLiteralGenericAlias): + return tp + + return _BoolLiteral[tp] ################################################################## diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index 2db3fd7..29efd8b 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -247,6 +247,15 @@ def _eval_types(obj: typing.Any, ctx: EvalContext): ctx.resolved |= {x: x for x in child_ctx.known_recursive_types.keys()} ctx.known_recursive_types |= child_ctx.known_recursive_types + if isinstance(evaled, bool): + # Wrap a boolean result with _BoolLiteral + # This is because `not` calls `__bool__` first so a boolean + # expression like `not _BoolLiteral[True]` will result `False`, + # not `_BoolLiteral[False]` as we want. + import typemap.typing as nt + + evaled = nt._BoolLiteral[evaled] + # Don't cache iterators as they are stateful and can only be consumed once. # This is important for Iter results that may be used multiple times. if not isinstance(evaled, collections.abc.Iterator): @@ -337,6 +346,13 @@ def _eval_type_var(obj: typing.TypeVar, ctx: EvalContext): # do there, and doing it puts weird stuff in the caches. @_eval_types_impl.register def _eval_literal(obj: typing_LiteralGenericAlias, ctx: EvalContext): + from typemap.typing import _BoolLiteralGenericAlias + + if isinstance(obj, _BoolLiteralGenericAlias): + # If this is _BoolLiteralGenericAlias, defer to the registered evaluator + if func := _eval_funcs.get(obj.__origin__): + return func(*typing.get_args(obj), ctx=ctx) + return obj diff --git a/typemap/typing.py b/typemap/typing.py index 021483d..51e3955 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -1,7 +1,7 @@ import contextvars import typing from typing import Literal -from typing import _GenericAlias # type: ignore +from typing import _GenericAlias, _LiteralGenericAlias # type: ignore _SpecialForm: typing.Any = typing._SpecialForm @@ -197,23 +197,55 @@ def Iter(self, tp): return _IterGenericAlias(self, (tp,)) -class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] +class _BoolGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] def __bool__(self): evaluator = special_form_evaluator.get() if evaluator: - return evaluator(self) + result = evaluator(self) + # Unwrap _LiteralGeneric + return bool(result) else: return False @_SpecialForm def IsSubtype(self, tps): - return _IsGenericAlias(self, tps) + return _BoolGenericAlias(self, tps) @_SpecialForm def IsSubSimilar(self, tps): - return _IsGenericAlias(self, tps) + return _BoolGenericAlias(self, tps) + + +@_SpecialForm +def Matches(self, tps): + return _BoolGenericAlias(self, tps) IsSub = IsSubSimilar + + +@_SpecialForm +def Bool(self, tp): + return _BoolGenericAlias(self, tp) + + +@_SpecialForm +def AllOf(self, tp): + return _BoolGenericAlias(self, tp) + + +@_SpecialForm +def AnyOf(self, tp): + return _BoolGenericAlias(self, tp) + + +class _BoolLiteralGenericAlias(_LiteralGenericAlias, _root=True): # type: ignore[call-arg] + def __bool__(self): + return typing.get_args(self)[0] + + +@_SpecialForm +def _BoolLiteral(self, tp): + return _BoolLiteralGenericAlias(self, tp)