Skip to content

Commit a94d74f

Browse files
committed
use special forms
1 parent 8029c5b commit a94d74f

2 files changed

Lines changed: 112 additions & 41 deletions

File tree

tests/test_type_eval.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,58 @@ def test_eval_if():
854854
d = eval_typing(ShorterTuple[tuple[int, int], tuple[str]])
855855
assert d == tuple[str]
856856

857+
d = eval_typing(If[True, Literal[True], Literal[False]])
858+
assert d == Literal[True]
859+
d = eval_typing(If[False, Literal[True], Literal[False]])
860+
assert d == Literal[False]
861+
862+
863+
class Prefixed:
864+
x_a: int
865+
x_b: str
866+
y_a: int
867+
y_b: str
868+
869+
870+
type FilterPrefix[T, P: str] = NewProtocol[
871+
*[
872+
x
873+
for x in Iter[Attrs[T]]
874+
if Equals[StrSlice[GetName[x], 0, Length[P]], P]
875+
]
876+
]
877+
type FilterPrefix2[T, Pint: str, Pstr: str] = NewProtocol[
878+
*[
879+
x
880+
for x in Iter[Attrs[T]]
881+
if If[
882+
Is[GetType[x], int],
883+
Equals[StrSlice[GetName[x], 0, Length[Pint]], Pint],
884+
Equals[StrSlice[GetName[x], 0, Length[Pstr]], Pstr],
885+
]
886+
]
887+
]
888+
889+
890+
def test_filter_prefix_1():
891+
d = eval_typing(FilterPrefix[Prefixed, Literal["x_"]])
892+
fmt = format_helper.format_class(d)
893+
assert fmt == textwrap.dedent("""\
894+
class FilterPrefix[tests.test_type_eval.Prefixed, typing.Literal['x_']]:
895+
x_a: int
896+
x_b: str
897+
""")
898+
899+
900+
def test_filter_prefix_2():
901+
d = eval_typing(FilterPrefix2[Prefixed, Literal["x_"], Literal["y_"]])
902+
fmt = format_helper.format_class(d)
903+
assert fmt == textwrap.dedent("""\
904+
class FilterPrefix2[tests.test_type_eval.Prefixed, typing.Literal['x_'], typing.Literal['y_']]:
905+
x_a: int
906+
y_b: str
907+
""")
908+
857909

858910
def test_uppercase_never():
859911
d = eval_typing(Uppercase[Never])

typemap/typing.py

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -70,46 +70,6 @@ class Length[S: tuple | str]:
7070
pass
7171

7272

73-
class Equals[L, R]:
74-
pass
75-
76-
77-
class NotEquals[L, R]:
78-
pass
79-
80-
81-
class GreaterThan[L, R]:
82-
pass
83-
84-
85-
class LessThan[L, R]:
86-
pass
87-
88-
89-
class GreaterThanOrEqual[L, R]:
90-
pass
91-
92-
93-
class LessThanOrEqual[L, R]:
94-
pass
95-
96-
97-
class Not[T]:
98-
pass
99-
100-
101-
class And[L, R]:
102-
pass
103-
104-
105-
class Or[L, R]:
106-
pass
107-
108-
109-
class If[Cond, Then, Else]:
110-
pass
111-
112-
11373
class Uppercase[S: str]:
11474
pass
11575

@@ -164,7 +124,16 @@ class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
164124
def __bool__(self):
165125
evaluator = special_form_evaluator.get()
166126
if evaluator:
167-
return evaluator(self)
127+
result = evaluator(self)
128+
if result is True or result is False:
129+
return result
130+
elif (
131+
isinstance(result, typing._GenericAlias)
132+
and getattr(result, "__origin__", None) is typing.Literal
133+
):
134+
return result.__args__[0]
135+
else:
136+
raise RuntimeError(f"Expected boolean, got {result}")
168137
else:
169138
return False
170139

@@ -179,4 +148,54 @@ def IsSubSimilar(self, tps):
179148
return _IsGenericAlias(self, tps)
180149

181150

151+
@_SpecialForm
152+
def Equals(self, params):
153+
return _IsGenericAlias(self, params)
154+
155+
156+
@_SpecialForm
157+
def NotEquals(self, params):
158+
return _IsGenericAlias(self, params)
159+
160+
161+
@_SpecialForm
162+
def GreaterThan(self, params):
163+
return _IsGenericAlias(self, params)
164+
165+
166+
@_SpecialForm
167+
def LessThan(self, params):
168+
return _IsGenericAlias(self, params)
169+
170+
171+
@_SpecialForm
172+
def GreaterThanOrEqual(self, params):
173+
return _IsGenericAlias(self, params)
174+
175+
176+
@_SpecialForm
177+
def LessThanOrEqual(self, params):
178+
return _IsGenericAlias(self, params)
179+
180+
181+
@_SpecialForm
182+
def Not(self, params):
183+
return _IsGenericAlias(self, params)
184+
185+
186+
@_SpecialForm
187+
def And(self, params):
188+
return _IsGenericAlias(self, params)
189+
190+
191+
@_SpecialForm
192+
def Or(self, params):
193+
return _IsGenericAlias(self, params)
194+
195+
196+
@_SpecialForm
197+
def If(self, params):
198+
return _IsGenericAlias(self, params)
199+
200+
182201
Is = IsSubSimilar

0 commit comments

Comments
 (0)