Skip to content

Commit d194a31

Browse files
committed
Add Bool special form.
1 parent 3ecea77 commit d194a31

3 files changed

Lines changed: 26 additions & 5 deletions

File tree

typemap/type_eval/_eval_operators.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typemap.type_eval._eval_typing import _eval_types
1515
from typemap.typing import (
1616
Attrs,
17+
Bool,
1718
Capitalize,
1819
DropAnnotations,
1920
FromUnion,
@@ -223,6 +224,11 @@ def _eval_Iter(tp, *, ctx):
223224
)
224225

225226

227+
@type_eval.register_evaluator(Bool)
228+
def _eval_Bool(tp, *, ctx):
229+
return _eval_types(tp, ctx)
230+
231+
226232
# N.B: These handle unions on their own
227233

228234

typemap/type_eval/_special_form.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import typing
33
from typing import _GenericAlias # type: ignore
44

5+
from . import _typing_inspect
56

67
_SpecialForm: typing.Any = typing._SpecialForm
78

@@ -21,10 +22,19 @@ def __iter__(self):
2122
return iter(typing.TypeVarTuple("_IterDummy"))
2223

2324

24-
class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
25+
class _BoolGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
26+
"""A special form for boolean expressions.
27+
Converts Literal[bool] to bool for use in boolean contexts.
28+
"""
2529
def __bool__(self):
2630
evaluator = _special_form_evaluator.get()
2731
if evaluator:
28-
return evaluator(self)
32+
result = evaluator(self)
33+
if isinstance(result, bool):
34+
return result
35+
elif _typing_inspect.is_literal(result):
36+
return typing.get_args(result)[0]
37+
38+
raise RuntimeError(f"Expected bool or Literal[bool], got {result}")
2939
else:
3040
return False

typemap/typing.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from .type_eval._special_form import (
55
_IterGenericAlias,
6-
_IsGenericAlias,
6+
_BoolGenericAlias,
77
_SpecialForm,
88
)
99

@@ -183,14 +183,19 @@ def Iter(self, tp):
183183
return _IterGenericAlias(self, (tp,))
184184

185185

186+
@_SpecialForm
187+
def Bool(self, tps):
188+
return _BoolGenericAlias(self, tps)
189+
190+
186191
@_SpecialForm
187192
def IsSubtype(self, tps):
188-
return _IsGenericAlias(self, tps)
193+
return _BoolGenericAlias(self, tps)
189194

190195

191196
@_SpecialForm
192197
def IsSubSimilar(self, tps):
193-
return _IsGenericAlias(self, tps)
198+
return _BoolGenericAlias(self, tps)
194199

195200

196201
Sub = IsSubSimilar

0 commit comments

Comments
 (0)