diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index f00d9ce..5d954a8 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -31,6 +31,7 @@ GetSpecialAttr, GetType, GetAnnotations, + IsSubtype, IsSub, Iter, Length, @@ -1334,7 +1335,7 @@ def test_eval_bool_05(): assert d == Literal[False] -def test_eval_literal_generic_01(): +def test_eval_bool_literal_01(): d = eval_typing(_BoolLiteral[True]) assert d == _BoolLiteral[True] d = eval_typing(_BoolLiteral[False]) @@ -1349,7 +1350,7 @@ def test_eval_literal_generic_01(): assert d == _BoolLiteral[False] -def test_eval_literal_generic_02(): +def test_eval_bool_literal_02(): d = eval_typing(not _BoolLiteral[True]) assert d == _BoolLiteral[False] @@ -1359,7 +1360,7 @@ def test_eval_literal_generic_02(): assert d == _BoolLiteral[True] -def test_eval_literal_generic_03(): +def test_eval_bool_literal_03(): d = eval_typing(AndLiteralGeneric[_BoolLiteral[True], _BoolLiteral[True]]) assert d == _BoolLiteral[True] d = eval_typing(AndLiteralGeneric[_BoolLiteral[True], _BoolLiteral[False]]) @@ -1370,7 +1371,7 @@ def test_eval_literal_generic_03(): assert d == _BoolLiteral[False] -def test_eval_literal_generic_04(): +def test_eval_bool_literal_04(): d = eval_typing(OrLiteralGeneric[_BoolLiteral[True], _BoolLiteral[True]]) assert d == _BoolLiteral[True] d = eval_typing(OrLiteralGeneric[_BoolLiteral[True], _BoolLiteral[False]]) @@ -1381,21 +1382,48 @@ def test_eval_literal_generic_04(): assert d == _BoolLiteral[False] -def test_eval_literal_generic_05(): +def test_eval_bool_literal_05(): d = eval_typing(LiteralGenericToLiteral[_BoolLiteral[True]]) assert d == Literal[True] d = eval_typing(LiteralGenericToLiteral[_BoolLiteral[False]]) assert d == Literal[False] -def test_eval_literal_generic_06(): +def test_eval_bool_literal_06(): d = eval_typing(NotLiteralGenericToLiteral[_BoolLiteral[True]]) assert d == Literal[False] d = eval_typing(NotLiteralGenericToLiteral[_BoolLiteral[False]]) assert d == Literal[True] -def test_eval_literal_generic_error_01(): +def test_eval_bool_literal_07(): + d = eval_typing(IsSub[_BoolLiteral[True], Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(IsSub[_BoolLiteral[False], Literal[False]]) + assert d == _BoolLiteral[True] + + d = eval_typing(IsSub[Literal[True], _BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(IsSub[Literal[False], _BoolLiteral[False]]) + assert d == _BoolLiteral[True] + + d = eval_typing(IsSubtype[_BoolLiteral[True], Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(IsSubtype[_BoolLiteral[False], Literal[False]]) + assert d == _BoolLiteral[True] + + d = eval_typing(IsSubtype[Literal[True], _BoolLiteral[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(IsSubtype[Literal[False], _BoolLiteral[False]]) + assert d == _BoolLiteral[True] + + d = eval_typing(Matches[_BoolLiteral[True], Literal[True]]) + assert d == _BoolLiteral[True] + d = eval_typing(Matches[_BoolLiteral[False], Literal[False]]) + assert d == _BoolLiteral[True] + + +def test_eval_bool_literal_error_01(): with pytest.raises(TypeError, match="Expected literal type, got 'int'"): eval_typing(_BoolLiteral[int]) diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 4ebb4ca..0a89de6 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -275,21 +275,6 @@ def _eval_Bool(tp, *, ctx): return _eval_bool_tp(tp, ctx) -@type_eval.register_evaluator(_BoolLiteral) -@_lift_evaluated -def _eval_BoolLiteral(tp, *, ctx): - from typemap.typing import _BoolLiteralGenericAlias - - if isinstance(tp, type): - raise TypeError(f"Expected literal type, got '{tp.__name__}'") - - # If already wrapped, just return it - if isinstance(tp, _BoolLiteralGenericAlias): - return tp - - return _BoolLiteral[tp] - - ################################################################## diff --git a/typemap/typing.py b/typemap/typing.py index 4c7168a..670550d 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -307,4 +307,11 @@ def __bool__(self): @_SpecialForm def _BoolLiteral(self, tp): - return _BoolLiteralGenericAlias(self, tp) + if isinstance(tp, type): + raise TypeError(f"Expected literal type, got '{tp.__name__}'") + + # If already wrapped, just return it + if isinstance(tp, _BoolLiteralGenericAlias): + return tp + + return _BoolLiteralGenericAlias(Literal, tp)