diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 5d954a8..7627f77 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -31,8 +31,9 @@ GetSpecialAttr, GetType, GetAnnotations, - IsSubtype, IsSub, + IsSubSimilar, + IsSubtype, Iter, Length, Matches, @@ -45,6 +46,7 @@ StrConcat, Uppercase, _BoolLiteral, + _Lambda, ) from . import format_helper @@ -1428,6 +1430,149 @@ def test_eval_bool_literal_error_01(): eval_typing(_BoolLiteral[int]) +def test_eval_lambda_01(): + type OnlyIntToSet[T] = set[T] if IsSub[T, int] else T + + a = lambda: int + b = lambda T: T + c = lambda T: list[T] + d = lambda T: OnlyIntToSet[T] + + t = eval_typing(_Lambda[a]) + assert t == _Lambda[a] + assert eval_typing(t()) is int + + t = eval_typing(_Lambda[b]) + assert t == _Lambda[b] + assert eval_typing(t(int)) is int + assert eval_typing(t(str)) is str + + t = eval_typing(_Lambda[c]) + assert t == _Lambda[c] + assert eval_typing(t(int)) == list[int] + assert eval_typing(t(str)) == list[str] + + t = eval_typing(_Lambda[d]) + assert t == _Lambda[d] + assert eval_typing(t(int)) == set[int] + assert eval_typing(t(str)) is str + + +LambdaInt1 = _Lambda[lambda: int] +LambdaInt2 = _Lambda[lambda: int] +LambdaStr = _Lambda[lambda: str] + + +def test_eval_lambda_02(): + # nested lambdas + a = _Lambda[lambda: _Lambda[lambda: int]] + assert a == _Lambda[lambda: _Lambda[lambda: int]] + assert eval_typing(a()) == _Lambda[lambda: int] + + assert a != _Lambda[lambda: _Lambda[lambda: str]] + + # lambda closure + b = _Lambda[lambda: int] + c = _Lambda[lambda: b] + d = _Lambda[lambda: int] + e = _Lambda[lambda: str] + assert c == _Lambda[lambda: d] + assert eval_typing(c()) == _Lambda[lambda: int] + + assert c != _Lambda[lambda: e] + + # lambda global + f = _Lambda[lambda: LambdaInt1] + assert f == _Lambda[lambda: LambdaInt2] + assert eval_typing(f()) == _Lambda[lambda: int] + + assert f != _Lambda[lambda: LambdaStr] + + +def test_eval_lambda_03(): + # different lambdas with same bytecode are treated as the same + + assert eval_typing(_Lambda[lambda: int]) == eval_typing( + _Lambda[lambda: int] + ) + assert eval_typing(_Lambda[lambda: list[int]]) == eval_typing( + _Lambda[lambda: list[int]] + ) + + a1 = lambda: int + a2 = lambda: int + + assert _Lambda[a1] == _Lambda[a2] + assert eval_typing(_Lambda[a1]) == eval_typing(_Lambda[a2]) + + l1 = Literal[1] + l2 = Literal[1] + + assert _Lambda[lambda: l1] == _Lambda[lambda: l2] + assert eval_typing(_Lambda[lambda: l1]) == eval_typing(_Lambda[lambda: l2]) + + +def test_eval_lambda_04(): + # different lambdas with different bytecode are treated as different + + assert eval_typing(_Lambda[lambda: int]) != eval_typing( + _Lambda[lambda: str] + ) + + def _f1(): + X = str + return lambda: X + + f1 = _f1() + + def _f2(): + X = int + return lambda: X + + f2 = _f2() + + assert _Lambda[f1] != _Lambda[f2] + assert eval_typing(_Lambda[f1]) != eval_typing(_Lambda[f2]) + + +def test_eval_lambda_05(): + # comparison operators + a1 = lambda: int + a2 = lambda: int + + t = eval_typing(IsSubtype[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[True] + t = eval_typing(IsSubSimilar[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[True] + t = eval_typing(Matches[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[True] + + t = eval_typing(IsSubtype[_Lambda[lambda T: T], _Lambda[lambda U: U]]) + assert t == _BoolLiteral[True] + t = eval_typing(IsSubSimilar[_Lambda[lambda T: T], _Lambda[lambda U: U]]) + assert t == _BoolLiteral[True] + t = eval_typing(Matches[_Lambda[lambda T: T], _Lambda[lambda U: U]]) + assert t == _BoolLiteral[True] + + +def test_eval_lambda_06(): + # lambda captures non-hashable + + # list is specially converted to tuple + a = [int, str, float] + b = [int, str, float] + assert ( + _Lambda[lambda: Callable[a, int]] == _Lambda[lambda: Callable[b, int]] + ) + assert eval_typing(_Lambda[lambda: a]) == eval_typing(_Lambda[lambda: a]) + + # other non-hashables are only compared by id + c = {1, 2, 3} + d = {1, 2, 3} + assert _Lambda[lambda: c] != _Lambda[lambda: d] + assert eval_typing(_Lambda[lambda: c]) != eval_typing(_Lambda[lambda: d]) + + def test_eval_length_01(): d = eval_typing(Length[tuple[int, str]]) assert d == Literal[2] diff --git a/typemap/type_eval/_apply_generic.py b/typemap/type_eval/_apply_generic.py index 48a7c11..49b8224 100644 --- a/typemap/type_eval/_apply_generic.py +++ b/typemap/type_eval/_apply_generic.py @@ -170,7 +170,7 @@ def make_func( func.__globals__, "__call__", func.__defaults__, - (), + func.__closure__, func.__kwdefaults__, ) diff --git a/typemap/type_eval/_subsim.py b/typemap/type_eval/_subsim.py index e87d5b7..60ab636 100644 --- a/typemap/type_eval/_subsim.py +++ b/typemap/type_eval/_subsim.py @@ -40,6 +40,10 @@ def issubsimilar(lhs: typing.Any, rhs: typing.Any) -> bool: ): return issubclass(lhs, rhs) + # lambda <:? lambda + elif _typing_inspect.is_lambda(lhs) or _typing_inspect.is_lambda(rhs): + return lhs == rhs + # literal <:? literal elif _typing_inspect.is_literal(lhs) and _typing_inspect.is_literal(rhs): # We need to check both value and type, since True == 1 but diff --git a/typemap/type_eval/_subtype.py b/typemap/type_eval/_subtype.py index a80313c..c591d84 100644 --- a/typemap/type_eval/_subtype.py +++ b/typemap/type_eval/_subtype.py @@ -42,6 +42,10 @@ def issubtype(lhs: typing.Any, rhs: typing.Any) -> bool: ): return issubclass(lhs, rhs) + # lambda <:? lambda + elif _typing_inspect.is_lambda(lhs) or _typing_inspect.is_lambda(rhs): + return lhs == rhs + # literal <:? literal elif bool( _typing_inspect.is_literal(lhs) and _typing_inspect.is_literal(rhs) diff --git a/typemap/type_eval/_typing_inspect.py b/typemap/type_eval/_typing_inspect.py index 1904962..f377153 100644 --- a/typemap/type_eval/_typing_inspect.py +++ b/typemap/type_eval/_typing_inspect.py @@ -135,6 +135,12 @@ def is_literal(t: Any) -> bool: return is_generic_alias(t) and get_origin(t) is Literal # type: ignore [comparison-overlap] +def is_lambda(t: Any) -> bool: + from typemap.typing import _Lambda + + return is_generic_alias(t) and get_origin(t) is _Lambda + + def get_head(t: Any) -> type | None: if is_generic_alias(t): return get_head(get_origin(t)) diff --git a/typemap/typing.py b/typemap/typing.py index 670550d..54eee1e 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -315,3 +315,88 @@ def _BoolLiteral(self, tp): return tp return _BoolLiteralGenericAlias(Literal, tp) + + +class _LambdaGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, _LambdaGenericAlias) and self.key() == other.key() + ) + + def __hash__(self) -> int: + return hash(self.key()) + + def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + return self._func()(*args, **kwargs) + + def _func(self) -> typing.Callable: + return typing.get_args(self)[0] + + def key( + self, + ) -> tuple[ + tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]], + tuple[typing.Any, ...], + ]: + return self._key(self._func()) + + @staticmethod + def _key( + func: typing.Callable, + ) -> tuple[ + tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]], + tuple[typing.Any, ...], + ]: + import builtins + + def _encode_code( + code: types.CodeType, + ) -> tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]]: + bytecode = code.co_code + consts = tuple( + _encode_code(c) if isinstance(c, types.CodeType) else c + for c in code.co_consts + ) + + globals = tuple( + func.__globals__.get(name, None) + or getattr(builtins, name, None) + for name in code.co_names + ) + + return (bytecode, consts, globals) + + if func.__closure__ is None: + closures: tuple[typing.Any, ...] = () + else: + closures = tuple( + # list is specially converted to tuple + tuple(cell.cell_contents) + if isinstance(cell.cell_contents, list) + else cell.cell_contents + if bool( + isinstance( + cell.cell_contents, + ( + type, + typing.TypeVar, + typing.ParamSpec, + typing.TypeVarTuple, + typing.TypeAliasType, + typing._SpecialForm, + ), + ) + or typing.get_origin(cell.cell_contents) is not None + ) + else cell.cell_contents.key() + if isinstance(cell.cell_contents, _LambdaGenericAlias) + else id(cell.cell_contents) + for cell in func.__closure__ + ) + + return (_encode_code(func.__code__), closures) + + +@_SpecialForm +def _Lambda(self, tp): + return _LambdaGenericAlias(self, tp)