Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
GetSpecialAttr,
GetType,
GetAnnotations,
IsSubtype,
IsSub,
Iter,
Length,
Expand Down Expand Up @@ -1334,7 +1335,7 @@ def test_eval_bool_05():
assert d == Literal[False]


def test_eval_literal_generic_01():
def test_eval_bool_literal_01():
d = eval_typing(_BoolLiteral[True])
assert d == _BoolLiteral[True]
d = eval_typing(_BoolLiteral[False])
Expand All @@ -1349,7 +1350,7 @@ def test_eval_literal_generic_01():
assert d == _BoolLiteral[False]


def test_eval_literal_generic_02():
def test_eval_bool_literal_02():
d = eval_typing(not _BoolLiteral[True])
assert d == _BoolLiteral[False]

Expand All @@ -1359,7 +1360,7 @@ def test_eval_literal_generic_02():
assert d == _BoolLiteral[True]


def test_eval_literal_generic_03():
def test_eval_bool_literal_03():
d = eval_typing(AndLiteralGeneric[_BoolLiteral[True], _BoolLiteral[True]])
assert d == _BoolLiteral[True]
d = eval_typing(AndLiteralGeneric[_BoolLiteral[True], _BoolLiteral[False]])
Expand All @@ -1370,7 +1371,7 @@ def test_eval_literal_generic_03():
assert d == _BoolLiteral[False]


def test_eval_literal_generic_04():
def test_eval_bool_literal_04():
d = eval_typing(OrLiteralGeneric[_BoolLiteral[True], _BoolLiteral[True]])
assert d == _BoolLiteral[True]
d = eval_typing(OrLiteralGeneric[_BoolLiteral[True], _BoolLiteral[False]])
Expand All @@ -1381,21 +1382,48 @@ def test_eval_literal_generic_04():
assert d == _BoolLiteral[False]


def test_eval_literal_generic_05():
def test_eval_bool_literal_05():
d = eval_typing(LiteralGenericToLiteral[_BoolLiteral[True]])
assert d == Literal[True]
d = eval_typing(LiteralGenericToLiteral[_BoolLiteral[False]])
assert d == Literal[False]


def test_eval_literal_generic_06():
def test_eval_bool_literal_06():
d = eval_typing(NotLiteralGenericToLiteral[_BoolLiteral[True]])
assert d == Literal[False]
d = eval_typing(NotLiteralGenericToLiteral[_BoolLiteral[False]])
assert d == Literal[True]


def test_eval_literal_generic_error_01():
def test_eval_bool_literal_07():
d = eval_typing(IsSub[_BoolLiteral[True], Literal[True]])
assert d == _BoolLiteral[True]
d = eval_typing(IsSub[_BoolLiteral[False], Literal[False]])
assert d == _BoolLiteral[True]

d = eval_typing(IsSub[Literal[True], _BoolLiteral[True]])
assert d == _BoolLiteral[True]
d = eval_typing(IsSub[Literal[False], _BoolLiteral[False]])
assert d == _BoolLiteral[True]

d = eval_typing(IsSubtype[_BoolLiteral[True], Literal[True]])
assert d == _BoolLiteral[True]
d = eval_typing(IsSubtype[_BoolLiteral[False], Literal[False]])
assert d == _BoolLiteral[True]

d = eval_typing(IsSubtype[Literal[True], _BoolLiteral[True]])
assert d == _BoolLiteral[True]
d = eval_typing(IsSubtype[Literal[False], _BoolLiteral[False]])
assert d == _BoolLiteral[True]

d = eval_typing(Matches[_BoolLiteral[True], Literal[True]])
assert d == _BoolLiteral[True]
d = eval_typing(Matches[_BoolLiteral[False], Literal[False]])
assert d == _BoolLiteral[True]


def test_eval_bool_literal_error_01():
with pytest.raises(TypeError, match="Expected literal type, got 'int'"):
eval_typing(_BoolLiteral[int])

Expand Down
15 changes: 0 additions & 15 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,21 +275,6 @@ def _eval_Bool(tp, *, ctx):
return _eval_bool_tp(tp, ctx)


@type_eval.register_evaluator(_BoolLiteral)
@_lift_evaluated
def _eval_BoolLiteral(tp, *, ctx):
from typemap.typing import _BoolLiteralGenericAlias

if isinstance(tp, type):
raise TypeError(f"Expected literal type, got '{tp.__name__}'")

# If already wrapped, just return it
if isinstance(tp, _BoolLiteralGenericAlias):
return tp

return _BoolLiteral[tp]


##################################################################


Expand Down
9 changes: 8 additions & 1 deletion typemap/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,4 +307,11 @@ def __bool__(self):

@_SpecialForm
def _BoolLiteral(self, tp):
return _BoolLiteralGenericAlias(self, tp)
if isinstance(tp, type):
raise TypeError(f"Expected literal type, got '{tp.__name__}'")

# If already wrapped, just return it
if isinstance(tp, _BoolLiteralGenericAlias):
return tp

return _BoolLiteralGenericAlias(Literal, tp)