diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index fc966c4..53703f2 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -17,18 +17,28 @@ from typemap.type_eval import eval_typing from typemap.typing import ( + And, Attrs, + Equals, FromUnion, GetArg, GetArgs, GetAttr, GetName, GetType, + GreaterThan, + GreaterThanOrEqual, + If, Is, Iter, Length, + LessThan, + LessThanOrEqual, Member, NewProtocol, + Not, + NotEquals, + Or, SpecialFormEllipsis, StrConcat, StrSlice, @@ -211,6 +221,24 @@ def test_type_strings_6(): assert d == Literal["bcd"] +type ReplacePrefix[S: str, P: str, N: str] = ( + StrConcat[N, StrSlice[S, 2, Literal[None]]] + if Is[StrSlice[S, 0, Length[P]], P] + else S +) + + +def test_type_strings_7(): + d = eval_typing( + ReplacePrefix[Literal["x_a"], Literal["x_"], Literal["hi_"]] + ) + assert d == Literal["hi_a"] + d = eval_typing( + ReplacePrefix[Literal["x_a"], Literal["y_"], Literal["hi_"]] + ) + assert d == Literal["x_a"] + + def _is_generic_permutation(t1, t2): return t1.__origin__ == t2.__origin__ and collections.Counter( t1.__args__ @@ -667,6 +695,247 @@ def test_consistency_01(): assert t == Literal[False] +def test_eval_equals(): + d = eval_typing(Equals[Literal[1], Literal[1]]) + assert d == Literal[True] + d = eval_typing(Equals[Literal[1], Literal[2]]) + assert d == Literal[False] + + d = eval_typing(Equals[1, 1]) + assert d == Literal[True] + d = eval_typing(Equals[1, 2]) + assert d == Literal[False] + + +def test_eval_not_equals(): + d = eval_typing(NotEquals[Literal[1], Literal[1]]) + assert d == Literal[False] + d = eval_typing(NotEquals[Literal[1], Literal[2]]) + assert d == Literal[True] + + d = eval_typing(NotEquals[1, 1]) + assert d == Literal[False] + d = eval_typing(NotEquals[1, 2]) + assert d == Literal[True] + + +def test_eval_greater_than(): + d = eval_typing(GreaterThan[Literal[1], Literal[1]]) + assert d == Literal[False] + d = eval_typing(GreaterThan[Literal[1], Literal[2]]) + assert d == Literal[False] + d = eval_typing(GreaterThan[Literal[2], Literal[1]]) + assert d == Literal[True] + + d = eval_typing(GreaterThan[1, 1]) + assert d == Literal[False] + d = eval_typing(GreaterThan[1, 2]) + assert d == Literal[False] + d = eval_typing(GreaterThan[2, 1]) + assert d == Literal[True] + + +def test_eval_greater_than_or_equal(): + d = eval_typing(GreaterThanOrEqual[Literal[1], Literal[1]]) + assert d == Literal[True] + d = eval_typing(GreaterThanOrEqual[Literal[1], Literal[2]]) + assert d == Literal[False] + d = eval_typing(GreaterThanOrEqual[Literal[2], Literal[1]]) + assert d == Literal[True] + + d = eval_typing(GreaterThanOrEqual[1, 1]) + assert d == Literal[True] + d = eval_typing(GreaterThanOrEqual[1, 2]) + assert d == Literal[False] + d = eval_typing(GreaterThanOrEqual[2, 1]) + assert d == Literal[True] + + +def test_eval_less_than(): + d = eval_typing(LessThan[Literal[1], Literal[1]]) + assert d == Literal[False] + d = eval_typing(LessThan[Literal[1], Literal[2]]) + assert d == Literal[True] + d = eval_typing(LessThan[Literal[2], Literal[1]]) + assert d == Literal[False] + + d = eval_typing(LessThan[1, 1]) + assert d == Literal[False] + d = eval_typing(LessThan[1, 2]) + assert d == Literal[True] + d = eval_typing(LessThan[2, 1]) + assert d == Literal[False] + + +def test_eval_less_than_or_equal(): + d = eval_typing(LessThanOrEqual[Literal[1], Literal[1]]) + assert d == Literal[True] + d = eval_typing(LessThanOrEqual[Literal[1], Literal[2]]) + assert d == Literal[True] + d = eval_typing(LessThanOrEqual[Literal[2], Literal[1]]) + assert d == Literal[False] + + d = eval_typing(LessThanOrEqual[1, 1]) + assert d == Literal[True] + d = eval_typing(LessThanOrEqual[1, 2]) + assert d == Literal[True] + d = eval_typing(LessThanOrEqual[2, 1]) + assert d == Literal[False] + + +def test_eval_not(): + d = eval_typing(Not[Literal[True]]) + assert d == Literal[False] + d = eval_typing(Not[Literal[False]]) + assert d == Literal[True] + + d = eval_typing(Not[True]) + assert d == Literal[False] + d = eval_typing(Not[False]) + assert d == Literal[True] + + +def test_eval_and(): + d = eval_typing(And[Literal[True], Literal[True]]) + assert d == Literal[True] + d = eval_typing(And[Literal[True], Literal[False]]) + assert d == Literal[False] + d = eval_typing(And[Literal[False], Literal[True]]) + assert d == Literal[False] + d = eval_typing(And[Literal[False], Literal[False]]) + assert d == Literal[False] + + d = eval_typing(And[True, True]) + assert d == Literal[True] + d = eval_typing(And[True, False]) + assert d == Literal[False] + d = eval_typing(And[False, True]) + assert d == Literal[False] + d = eval_typing(And[False, False]) + assert d == Literal[False] + + +def test_eval_or(): + d = eval_typing(Or[Literal[True], Literal[True]]) + assert d == Literal[True] + d = eval_typing(Or[Literal[True], Literal[False]]) + assert d == Literal[True] + d = eval_typing(Or[Literal[False], Literal[True]]) + assert d == Literal[True] + d = eval_typing(Or[Literal[False], Literal[False]]) + assert d == Literal[False] + + d = eval_typing(Or[True, True]) + assert d == Literal[True] + d = eval_typing(Or[True, False]) + assert d == Literal[True] + d = eval_typing(Or[False, True]) + assert d == Literal[True] + d = eval_typing(Or[False, False]) + assert d == Literal[False] + + +type ShorterTuple[L, R] = If[LessThan[Length[L], Length[R]], L, R] + + +def test_eval_if(): + d = eval_typing(If[Literal[True], int, str]) + assert d is int + d = eval_typing(If[Literal[False], int, str]) + assert d is str + + d = eval_typing(If[True, int, str]) + assert d is int + d = eval_typing(If[False, int, str]) + assert d is str + + d = eval_typing(ShorterTuple[tuple[int], tuple[str, str]]) + assert d == tuple[int] + d = eval_typing(ShorterTuple[tuple[int, int], tuple[str]]) + assert d == tuple[str] + + d = eval_typing(If[True, Literal[True], Literal[False]]) + assert d == Literal[True] + d = eval_typing(If[False, Literal[True], Literal[False]]) + assert d == Literal[False] + + +class Prefixed: + x_a: int + x_b: str + x_c: float + y_a: int + y_b: str + y_c: float + + +type FilterPrefix[T, P: str] = NewProtocol[ + *[x for x in Iter[Attrs[T]] if Is[StrSlice[GetName[x], 0, Length[P]], P]] +] +type FilterPrefix2[T, Pint: str, Pstr: str] = NewProtocol[ + *[ + x + for x in Iter[Attrs[T]] + if ( + Is[StrSlice[GetName[x], 0, Length[Pint]], Pint] + if Is[GetType[x], int] + else ( + Is[StrSlice[GetName[x], 0, Length[Pstr]], Pstr] + and Is[GetType[x], str] + ) + ) + ] +] +type FilterPrefix3[T, P: str, N: str] = NewProtocol[ + *[ + ( + x + if not Is[StrSlice[GetName[x], 0, Length[P]], P] + else Member[ + StrConcat[N, StrSlice[GetName[x], Length[P], Literal[None]]], + GetType[x], + ] + ) + for x in Iter[Attrs[T]] + ] +] + + +def test_filter_prefix_1(): + d = eval_typing(FilterPrefix[Prefixed, Literal["x_"]]) + fmt = format_helper.format_class(d) + assert fmt == textwrap.dedent("""\ + class FilterPrefix[tests.test_type_eval.Prefixed, typing.Literal['x_']]: + x_a: int + x_b: str + x_c: float + """) + + +def test_filter_prefix_2(): + d = eval_typing(FilterPrefix2[Prefixed, Literal["x_"], Literal["y_"]]) + fmt = format_helper.format_class(d) + assert fmt == textwrap.dedent("""\ + class FilterPrefix2[tests.test_type_eval.Prefixed, typing.Literal['x_'], typing.Literal['y_']]: + x_a: int + y_b: str + """) + + +def test_filter_prefix_3(): + d = eval_typing(FilterPrefix3[Prefixed, Literal["y_"], Literal["z_"]]) + fmt = format_helper.format_class(d) + assert fmt == textwrap.dedent("""\ + class FilterPrefix3[tests.test_type_eval.Prefixed, typing.Literal['y_'], typing.Literal['z_']]: + x_a: int + x_b: str + x_c: float + z_a: int + z_b: str + z_c: float + """) + + def test_uppercase_never(): d = eval_typing(Uppercase[Never]) assert d is Never diff --git a/typemap/type_eval/_eval_operators.py b/typemap/type_eval/_eval_operators.py index 639959e..a3c802f 100644 --- a/typemap/type_eval/_eval_operators.py +++ b/typemap/type_eval/_eval_operators.py @@ -13,20 +13,29 @@ from typemap.type_eval import _typing_inspect from typemap.type_eval._eval_typing import _eval_types from typemap.typing import ( + And, Attrs, Capitalize, + Equals, FromUnion, GetArg, GetArgs, GetAttr, + GreaterThan, + GreaterThanOrEqual, IsSubSimilar, IsSubtype, Iter, Length, + LessThan, + LessThanOrEqual, Lowercase, Member, Members, NewProtocol, + Not, + NotEquals, + Or, Param, SpecialFormEllipsis, StrConcat, @@ -474,9 +483,54 @@ def _eval_Length(tp, *, ctx) -> typing.Any: return typing.Literal[len(tp.__args__)] else: return typing.Literal[None] + elif ( + _typing_inspect.is_generic_alias(tp) + and tp.__origin__ is typing.Literal + and isinstance(tp.__args__[0], str) + ): + return typing.Literal[len(tp.__args__[0])] else: # XXX: Or should we return Never? - raise TypeError(f"Invalid type argument to Length: {tp} is not a tuple") + raise TypeError( + f"Invalid type argument to Length: {tp} is not a tuple or string" + ) + + +################################################################## + + +def _literal_unary_op(typ, op): + @_lift_over_unions + def func(val, *, ctx): + return typing.Literal[op(_from_literal(val, ctx))] + + type_eval.register_evaluator(typ)(func) + + +def _literal_binary_op(typ, op): + @_lift_over_unions + def func(lhs, rhs, *, ctx): + return typing.Literal[ + op(_from_literal(lhs, ctx), _from_literal(rhs, ctx)) + ] + + type_eval.register_evaluator(typ)(func) + + +_literal_binary_op(Equals, op=lambda lhs, rhs: lhs == rhs) +_literal_binary_op(NotEquals, op=lambda lhs, rhs: lhs != rhs) +_literal_binary_op(GreaterThan, op=lambda lhs, rhs: lhs > rhs) +_literal_binary_op(LessThan, op=lambda lhs, rhs: lhs < rhs) +_literal_binary_op(GreaterThanOrEqual, op=lambda lhs, rhs: lhs >= rhs) +_literal_binary_op(LessThanOrEqual, op=lambda lhs, rhs: lhs <= rhs) + + +_literal_unary_op(Not, op=lambda val: not val) +_literal_binary_op(And, op=lambda lhs, rhs: lhs and rhs) +_literal_binary_op(Or, op=lambda lhs, rhs: lhs or rhs) + + +################################################################## def _string_literal_op(typ, op): diff --git a/typemap/type_eval/_eval_typing.py b/typemap/type_eval/_eval_typing.py index 7f76660..387aba8 100644 --- a/typemap/type_eval/_eval_typing.py +++ b/typemap/type_eval/_eval_typing.py @@ -19,6 +19,7 @@ from typing import Any from . import _apply_generic +from typemap.typing import If __all__ = ("eval_typing",) @@ -354,6 +355,14 @@ def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext): """Eval a typing._GenericAlias -- an applied user-defined class""" # generic *classes* are typing._GenericAlias while generic type # aliases are types.GenericAlias? Why in the world. + + if obj.__origin__ == If: + cond = _eval_types(obj.__args__[0], ctx) + if cond is True or cond == typing.Literal[True]: + return _eval_types(obj.__args__[1], ctx) + else: + return _eval_types(obj.__args__[2], ctx) + new_args = tuple(_eval_types(arg, ctx) for arg in typing.get_args(obj)) if func := _eval_funcs.get(obj.__origin__): diff --git a/typemap/typing.py b/typemap/typing.py index 1413cc7..3f2cd35 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -66,7 +66,7 @@ class GetArgs[Tp, Base]: pass -class Length[S: tuple]: +class Length[S: tuple | str]: pass @@ -124,7 +124,16 @@ class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] def __bool__(self): evaluator = special_form_evaluator.get() if evaluator: - return evaluator(self) + result = evaluator(self) + if result is True or result is False: + return result + elif ( + isinstance(result, typing._GenericAlias) + and getattr(result, "__origin__", None) is typing.Literal + ): + return result.__args__[0] + else: + raise RuntimeError(f"Expected boolean, got {result}") else: return False @@ -139,4 +148,54 @@ def IsSubSimilar(self, tps): return _IsGenericAlias(self, tps) +@_SpecialForm +def Equals(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def NotEquals(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def GreaterThan(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def LessThan(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def GreaterThanOrEqual(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def LessThanOrEqual(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def Not(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def And(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def Or(self, params): + return _IsGenericAlias(self, params) + + +@_SpecialForm +def If(self, params): + return _IsGenericAlias(self, params) + + Is = IsSubSimilar