Skip to content

Commit 04cae66

Browse files
committed
Properly handle unions.
1 parent bb961c5 commit 04cae66

2 files changed

Lines changed: 98 additions & 49 deletions

File tree

tests/test_type_eval.py

Lines changed: 84 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,25 +1146,66 @@ def test_eval_bool_01():
11461146
assert d == _BoolLiteral[False]
11471147

11481148
d = eval_typing(Bool[Literal[1]])
1149-
assert d == _BoolLiteral[True]
1149+
assert d == _BoolLiteral[False]
11501150

11511151
d = eval_typing(Bool[Literal[0]])
11521152
assert d == _BoolLiteral[False]
11531153

11541154
d = eval_typing(Bool[Literal["true"]])
1155-
assert d == _BoolLiteral[True]
1155+
assert d == _BoolLiteral[False]
11561156

11571157
d = eval_typing(Bool[Literal["false"]])
1158-
assert d == _BoolLiteral[True]
1159-
1158+
assert d == _BoolLiteral[False]
11601159

1161-
def test_eval_bool_02():
11621160
d = eval_typing(Bool[_BoolLiteral[True]])
11631161
assert d == _BoolLiteral[True]
11641162

11651163
d = eval_typing(Bool[_BoolLiteral[False]])
11661164
assert d == _BoolLiteral[False]
11671165

1166+
d = eval_typing(Bool[Never])
1167+
assert d == _BoolLiteral[False]
1168+
1169+
d = eval_typing(Bool[int])
1170+
assert d == _BoolLiteral[False]
1171+
1172+
class C:
1173+
pass
1174+
1175+
d = eval_typing(Bool[C])
1176+
assert d == _BoolLiteral[False]
1177+
1178+
d = eval_typing(Bool[True])
1179+
assert d == _BoolLiteral[True]
1180+
1181+
d = eval_typing(Bool[False])
1182+
assert d == _BoolLiteral[False]
1183+
1184+
1185+
def test_eval_bool_02():
1186+
d = eval_typing(Bool[Literal[True] | Literal[False]])
1187+
assert d == _BoolLiteral[True]
1188+
d = eval_typing(Bool[Literal[False] | Literal[True]])
1189+
assert d == _BoolLiteral[True]
1190+
d = eval_typing(Bool[Literal[True] | Never])
1191+
assert d == _BoolLiteral[True]
1192+
d = eval_typing(Bool[Never | Literal[True]])
1193+
assert d == _BoolLiteral[True]
1194+
d = eval_typing(Bool[Literal[False] | Never])
1195+
assert d == _BoolLiteral[False]
1196+
d = eval_typing(Bool[Never | Literal[False]])
1197+
assert d == _BoolLiteral[False]
1198+
d = eval_typing(Bool[Literal[True] | int])
1199+
assert d == _BoolLiteral[True]
1200+
d = eval_typing(Bool[int | Literal[True]])
1201+
assert d == _BoolLiteral[True]
1202+
d = eval_typing(Bool[Literal[False] | int])
1203+
assert d == _BoolLiteral[False]
1204+
d = eval_typing(Bool[int | Literal[False]])
1205+
assert d == _BoolLiteral[False]
1206+
d = eval_typing(Bool[int | str])
1207+
assert d == _BoolLiteral[False]
1208+
11681209

11691210
def test_eval_bool_03():
11701211
d = eval_typing(NotLiteralGeneric[Bool[Literal[True]]])
@@ -1208,30 +1249,45 @@ def test_eval_all_01():
12081249

12091250
d = eval_typing(AllOf[_BoolLiteral[True]])
12101251
assert d == _BoolLiteral[True]
1211-
12121252
d = eval_typing(AllOf[_BoolLiteral[False]])
12131253
assert d == _BoolLiteral[False]
12141254

12151255
d = eval_typing(AllOf[_BoolLiteral[True], _BoolLiteral[True]])
12161256
assert d == _BoolLiteral[True]
1217-
12181257
d = eval_typing(AllOf[_BoolLiteral[True], _BoolLiteral[False]])
12191258
assert d == _BoolLiteral[False]
1220-
12211259
d = eval_typing(AllOf[_BoolLiteral[False], _BoolLiteral[True]])
12221260
assert d == _BoolLiteral[False]
1223-
12241261
d = eval_typing(AllOf[_BoolLiteral[False], _BoolLiteral[False]])
12251262
assert d == _BoolLiteral[False]
12261263

1264+
d = eval_typing(AllOf[Literal[True] | Literal[False]])
1265+
assert d == _BoolLiteral[True]
1266+
d = eval_typing(AllOf[int | Never])
1267+
assert d == _BoolLiteral[False]
1268+
d = eval_typing(
1269+
AllOf[Literal[0] | Literal[True], Literal[2] | Literal[True]]
1270+
)
1271+
assert d == _BoolLiteral[True]
1272+
d = eval_typing(AllOf[Literal[0] | Literal[1], Literal[2] | Literal[True]])
1273+
assert d == _BoolLiteral[False]
1274+
d = eval_typing(AllOf[Literal[0] | Literal[1], Literal[2] | Literal[3]])
1275+
assert d == _BoolLiteral[False]
1276+
12271277

12281278
def test_eval_all_02():
1229-
d = eval_typing(AllOf[Literal[True], Literal[True]])
1279+
d = eval_typing(AllOf[()])
12301280
assert d == _BoolLiteral[True]
12311281

1232-
d = eval_typing(AllOf[Literal[True], Literal[False]])
1282+
d = eval_typing(AllOf[Literal[True]])
1283+
assert d == _BoolLiteral[True]
1284+
d = eval_typing(AllOf[Literal[False]])
12331285
assert d == _BoolLiteral[False]
12341286

1287+
d = eval_typing(AllOf[Literal[True], Literal[True]])
1288+
assert d == _BoolLiteral[True]
1289+
d = eval_typing(AllOf[Literal[True], Literal[False]])
1290+
assert d == _BoolLiteral[False]
12351291
d = eval_typing(AllOf[Literal[False], Literal[True]])
12361292
assert d == _BoolLiteral[False]
12371293

@@ -1248,16 +1304,13 @@ def test_eval_all_03():
12481304

12491305
d = eval_typing(ContainsAllInt[tuple[int]])
12501306
assert d == _BoolLiteral[True]
1251-
12521307
d = eval_typing(ContainsAllInt[tuple[str]])
12531308
assert d == _BoolLiteral[False]
12541309

12551310
d = eval_typing(ContainsAllInt[tuple[int, int]])
12561311
assert d == _BoolLiteral[True]
1257-
12581312
d = eval_typing(ContainsAllInt[tuple[int, str]])
12591313
assert d == _BoolLiteral[False]
1260-
12611314
d = eval_typing(ContainsAllInt[tuple[str, str]])
12621315
assert d == _BoolLiteral[False]
12631316

@@ -1268,7 +1321,6 @@ def test_eval_all_04():
12681321

12691322
d = eval_typing(ContainsAllIntToLiteral[tuple[int]])
12701323
assert d == Literal[True]
1271-
12721324
d = eval_typing(ContainsAllIntToLiteral[tuple[str]])
12731325
assert d == Literal[False]
12741326

@@ -1279,42 +1331,47 @@ def test_eval_any_01():
12791331

12801332
d = eval_typing(AnyOf[_BoolLiteral[True]])
12811333
assert d == _BoolLiteral[True]
1282-
12831334
d = eval_typing(AnyOf[_BoolLiteral[False]])
12841335
assert d == _BoolLiteral[False]
12851336

12861337
d = eval_typing(AnyOf[_BoolLiteral[True], _BoolLiteral[True]])
12871338
assert d == _BoolLiteral[True]
1288-
12891339
d = eval_typing(AnyOf[_BoolLiteral[True], _BoolLiteral[False]])
12901340
assert d == _BoolLiteral[True]
1291-
12921341
d = eval_typing(AnyOf[_BoolLiteral[False], _BoolLiteral[True]])
12931342
assert d == _BoolLiteral[True]
1294-
12951343
d = eval_typing(AnyOf[_BoolLiteral[False], _BoolLiteral[False]])
12961344
assert d == _BoolLiteral[False]
12971345

1346+
d = eval_typing(AnyOf[Literal[True] | Literal[False]])
1347+
assert d == _BoolLiteral[True]
1348+
d = eval_typing(AnyOf[int | Never])
1349+
assert d == _BoolLiteral[False]
1350+
d = eval_typing(
1351+
AnyOf[Literal[0] | Literal[True], Literal[2] | Literal[True]]
1352+
)
1353+
assert d == _BoolLiteral[True]
1354+
d = eval_typing(AnyOf[Literal[0] | Literal[1], Literal[2] | Literal[True]])
1355+
assert d == _BoolLiteral[True]
1356+
d = eval_typing(AnyOf[Literal[0] | Literal[1], Literal[2] | Literal[3]])
1357+
assert d == _BoolLiteral[False]
1358+
12981359

12991360
def test_eval_any_02():
13001361
d = eval_typing(AnyOf[()])
13011362
assert d == _BoolLiteral[False]
13021363

13031364
d = eval_typing(AnyOf[Literal[True]])
13041365
assert d == _BoolLiteral[True]
1305-
13061366
d = eval_typing(AnyOf[Literal[False]])
13071367
assert d == _BoolLiteral[False]
13081368

13091369
d = eval_typing(AnyOf[Literal[True], Literal[True]])
13101370
assert d == _BoolLiteral[True]
1311-
13121371
d = eval_typing(AnyOf[Literal[True], Literal[False]])
13131372
assert d == _BoolLiteral[True]
1314-
13151373
d = eval_typing(AnyOf[Literal[False], Literal[True]])
13161374
assert d == _BoolLiteral[True]
1317-
13181375
d = eval_typing(AnyOf[Literal[False], Literal[False]])
13191376
assert d == _BoolLiteral[False]
13201377

@@ -1331,16 +1388,13 @@ def test_eval_any_03():
13311388

13321389
d = eval_typing(ContainsAnyInt[tuple[int]])
13331390
assert d == _BoolLiteral[True]
1334-
13351391
d = eval_typing(ContainsAnyInt[tuple[str]])
13361392
assert d == _BoolLiteral[False]
13371393

13381394
d = eval_typing(ContainsAnyInt[tuple[int, int]])
13391395
assert d == _BoolLiteral[True]
1340-
13411396
d = eval_typing(ContainsAnyInt[tuple[int, str]])
13421397
assert d == _BoolLiteral[True]
1343-
13441398
d = eval_typing(ContainsAnyInt[tuple[str, str]])
13451399
assert d == _BoolLiteral[False]
13461400

@@ -1351,23 +1405,23 @@ def test_eval_any_04():
13511405

13521406
d = eval_typing(ContainsAnyIntToLiteral[tuple[int]])
13531407
assert d == Literal[True]
1354-
13551408
d = eval_typing(ContainsAnyIntToLiteral[tuple[str]])
13561409
assert d == Literal[False]
13571410

13581411

13591412
def test_eval_literal_generic_01():
13601413
d = eval_typing(_BoolLiteral[True])
13611414
assert d == _BoolLiteral[True]
1362-
13631415
d = eval_typing(_BoolLiteral[False])
13641416
assert d == _BoolLiteral[False]
1365-
13661417
d = eval_typing(_BoolLiteral[1])
13671418
assert d == _BoolLiteral[True]
1368-
13691419
d = eval_typing(_BoolLiteral[0])
13701420
assert d == _BoolLiteral[False]
1421+
d = eval_typing(_BoolLiteral[_BoolLiteral[True]])
1422+
assert d == _BoolLiteral[True]
1423+
d = eval_typing(_BoolLiteral[_BoolLiteral[False]])
1424+
assert d == _BoolLiteral[False]
13711425

13721426

13731427
def test_eval_literal_generic_02():
@@ -1376,51 +1430,42 @@ def test_eval_literal_generic_02():
13761430

13771431
d = eval_typing(NotLiteralGeneric[_BoolLiteral[True]])
13781432
assert d == _BoolLiteral[False]
1379-
13801433
d = eval_typing(NotLiteralGeneric[_BoolLiteral[False]])
13811434
assert d == _BoolLiteral[True]
13821435

13831436

13841437
def test_eval_literal_generic_03():
13851438
d = eval_typing(AndLiteralGeneric[_BoolLiteral[True], _BoolLiteral[True]])
13861439
assert d == _BoolLiteral[True]
1387-
13881440
d = eval_typing(AndLiteralGeneric[_BoolLiteral[True], _BoolLiteral[False]])
13891441
assert d == _BoolLiteral[False]
1390-
13911442
d = eval_typing(AndLiteralGeneric[_BoolLiteral[False], _BoolLiteral[True]])
13921443
assert d == _BoolLiteral[False]
1393-
13941444
d = eval_typing(AndLiteralGeneric[_BoolLiteral[False], _BoolLiteral[False]])
13951445
assert d == _BoolLiteral[False]
13961446

13971447

13981448
def test_eval_literal_generic_04():
13991449
d = eval_typing(OrLiteralGeneric[_BoolLiteral[True], _BoolLiteral[True]])
14001450
assert d == _BoolLiteral[True]
1401-
14021451
d = eval_typing(OrLiteralGeneric[_BoolLiteral[True], _BoolLiteral[False]])
14031452
assert d == _BoolLiteral[True]
1404-
14051453
d = eval_typing(OrLiteralGeneric[_BoolLiteral[False], _BoolLiteral[True]])
14061454
assert d == _BoolLiteral[True]
1407-
14081455
d = eval_typing(OrLiteralGeneric[_BoolLiteral[False], _BoolLiteral[False]])
14091456
assert d == _BoolLiteral[False]
14101457

14111458

14121459
def test_eval_literal_generic_05():
14131460
d = eval_typing(LiteralGenericToLiteral[_BoolLiteral[True]])
14141461
assert d == Literal[True]
1415-
14161462
d = eval_typing(LiteralGenericToLiteral[_BoolLiteral[False]])
14171463
assert d == Literal[False]
14181464

14191465

14201466
def test_eval_literal_generic_06():
14211467
d = eval_typing(NotLiteralGenericToLiteral[_BoolLiteral[True]])
14221468
assert d == Literal[False]
1423-
14241469
d = eval_typing(NotLiteralGenericToLiteral[_BoolLiteral[False]])
14251470
assert d == Literal[True]
14261471

typemap/type_eval/_eval_operators.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -255,31 +255,35 @@ def _eval_Matches(lhs, rhs, *, ctx):
255255
]
256256

257257

258-
def _eval_bool_tp(tp):
259-
if _typing_inspect.is_generic_alias(tp):
260-
if tp.__origin__ is typing.Literal:
261-
return _BoolLiteral[bool(tp.__args__[0])]
262-
elif tp.__origin__ is _BoolLiteral:
263-
return _BoolLiteral[bool(tp.__args__[0])]
264-
raise TypeError(f"Expected Literal type, got {tp}")
258+
def _eval_bool_tp(tp, ctx):
259+
if _typing_inspect.is_generic_alias(tp) and tp.__origin__ is _BoolLiteral:
260+
return _BoolLiteral[bool(tp.__args__[0])]
261+
else:
262+
return _BoolLiteral[
263+
any(
264+
type_eval.issubsimilar(arg, typing.Literal[True])
265+
and not type_eval.issubsimilar(arg, typing.Never)
266+
for arg in _union_elems(tp, ctx)
267+
)
268+
]
265269

266270

267271
@type_eval.register_evaluator(Bool)
268272
@_lift_evaluated
269273
def _eval_Bool(tp, *, ctx):
270-
return _eval_bool_tp(tp)
274+
return _eval_bool_tp(tp, ctx)
271275

272276

273277
@type_eval.register_evaluator(AllOf)
274278
@_lift_evaluated
275279
def _eval_AllOf(*tp, ctx):
276-
return _BoolLiteral[all(_eval_bool_tp(tp) for tp in tp)]
280+
return _BoolLiteral[all(_eval_bool_tp(tp, ctx) for tp in tp)]
277281

278282

279283
@type_eval.register_evaluator(AnyOf)
280284
@_lift_evaluated
281285
def _eval_AnyOf(*tp, ctx):
282-
return _BoolLiteral[any(_eval_bool_tp(tp) for tp in tp)]
286+
return _BoolLiteral[any(_eval_bool_tp(tp, ctx) for tp in tp)]
283287

284288

285289
@type_eval.register_evaluator(_BoolLiteral)

0 commit comments

Comments
 (0)