Skip to content

Commit 2ce42f4

Browse files
committed
Short circuit If.
1 parent 9d45292 commit 2ce42f4

3 files changed

Lines changed: 14 additions & 14 deletions

File tree

tests/test_type_eval.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,11 @@ def test_eval_if():
850850
d = eval_typing(If[Literal[False], int, str])
851851
assert d is str
852852

853+
d = eval_typing(If[True, int, str])
854+
assert d is int
855+
d = eval_typing(If[False, int, str])
856+
assert d is str
857+
853858
d = eval_typing(ShorterTuple[tuple[int], tuple[str, str]])
854859
assert d == tuple[int]
855860
d = eval_typing(ShorterTuple[tuple[int, int], tuple[str]])

typemap/type_eval/_eval_operators.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
GetAttr,
2525
GreaterThan,
2626
GreaterThanOrEqual,
27-
If,
2827
IsSubSimilar,
2928
IsSubtype,
3029
Iter,
@@ -531,19 +530,6 @@ def func(lhs, rhs, *, ctx):
531530
##################################################################
532531

533532

534-
@type_eval.register_evaluator(If)
535-
def _eval_If(cond, then_branch, else_branch, *, ctx):
536-
cond_val = _from_literal(cond, ctx)
537-
return (
538-
_eval_types(then_branch, ctx)
539-
if cond_val
540-
else _eval_types(else_branch, ctx)
541-
)
542-
543-
544-
##################################################################
545-
546-
547533
def _string_literal_op(typ, op):
548534
@_lift_over_unions
549535
def func(*args, ctx):

typemap/type_eval/_eval_typing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any
1818

1919
from . import _apply_generic
20+
from typemap.typing import If
2021

2122

2223
__all__ = ("eval_typing",)
@@ -341,6 +342,14 @@ def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext):
341342
"""Eval a typing._GenericAlias -- an applied user-defined class"""
342343
# generic *classes* are typing._GenericAlias while generic type
343344
# aliases are types.GenericAlias? Why in the world.
345+
346+
if obj.__origin__ == If:
347+
cond = _eval_types(obj.__args__[0], ctx)
348+
if cond is True or cond == typing.Literal[True]:
349+
return _eval_types(obj.__args__[1], ctx)
350+
else:
351+
return _eval_types(obj.__args__[2], ctx)
352+
344353
if func := _eval_funcs.get(obj.__origin__):
345354
new_args = tuple(_eval_types(arg, ctx) for arg in obj.__args__)
346355
ret = func(*new_args, ctx=ctx)

0 commit comments

Comments
 (0)