From af6c08ba6895e89cdfcb4f1d2e968dbcee0da2ba Mon Sep 17 00:00:00 2001 From: dnwpark Date: Tue, 3 Feb 2026 11:03:16 -0800 Subject: [PATCH 1/2] Add _Lambda special form. --- tests/test_type_eval.py | 82 +++++++++++++++++++++++++++- typemap/type_eval/_apply_generic.py | 2 +- typemap/type_eval/_subsim.py | 4 ++ typemap/type_eval/_subtype.py | 4 ++ typemap/type_eval/_typing_inspect.py | 6 ++ typemap/typing.py | 22 ++++++++ 6 files changed, 118 insertions(+), 2 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 5d954a8..85a45d3 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -31,8 +31,9 @@ GetSpecialAttr, GetType, GetAnnotations, - IsSubtype, IsSub, + IsSubSimilar, + IsSubtype, Iter, Length, Matches, @@ -45,6 +46,7 @@ StrConcat, Uppercase, _BoolLiteral, + _Lambda, ) from . import format_helper @@ -1428,6 +1430,84 @@ def test_eval_bool_literal_error_01(): eval_typing(_BoolLiteral[int]) +def test_eval_lambda_01(): + type OnlyIntToSet[T] = set[T] if IsSub[T, int] else T + + a = lambda: int + b = lambda T: T + c = lambda T: list[T] + d = lambda T: OnlyIntToSet[T] + + t = eval_typing(_Lambda[a]) + assert t == _Lambda[a] + assert eval_typing(t()) is int + + t = eval_typing(_Lambda[b]) + assert t == _Lambda[b] + assert eval_typing(t(int)) is int + assert eval_typing(t(str)) is str + + t = eval_typing(_Lambda[c]) + assert t == _Lambda[c] + assert eval_typing(t(int)) == list[int] + assert eval_typing(t(str)) == list[str] + + t = eval_typing(_Lambda[d]) + assert t == _Lambda[d] + assert eval_typing(t(int)) == set[int] + assert eval_typing(t(str)) is str + + +def test_eval_lambda_02(): + # different lambdas with same bytecode are treated as the same + a1 = lambda: int + a2 = lambda: int + + t1 = eval_typing(_Lambda[a1]) + t2 = eval_typing(_Lambda[a2]) + assert t1 == t2 + + t = eval_typing(IsSubtype[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[True] + t = eval_typing(IsSubSimilar[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[True] + t = eval_typing(Matches[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[True] + + t = eval_typing(IsSubtype[_Lambda[lambda T: T], _Lambda[lambda U: U]]) + assert t == _BoolLiteral[True] + t = eval_typing(IsSubSimilar[_Lambda[lambda T: T], _Lambda[lambda U: U]]) + assert t == _BoolLiteral[True] + t = eval_typing(Matches[_Lambda[lambda T: T], _Lambda[lambda U: U]]) + assert t == _BoolLiteral[True] + + +def test_eval_lambda_03(): + # different lambdas with different bytecode are treated as different + a1 = lambda: int + a2 = lambda: str + + t1 = eval_typing(_Lambda[a1]) + t2 = eval_typing(_Lambda[a2]) + assert t1 != t2 + + t = eval_typing(IsSubtype[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[False] + t = eval_typing(IsSubSimilar[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[False] + t = eval_typing(Matches[_Lambda[a1], _Lambda[a2]]) + assert t == _BoolLiteral[False] + + t = eval_typing(IsSubtype[_Lambda[lambda T: T], _Lambda[lambda T: list[T]]]) + assert t == _BoolLiteral[False] + t = eval_typing( + IsSubSimilar[_Lambda[lambda T: T], _Lambda[lambda T: list[T]]] + ) + assert t == _BoolLiteral[False] + t = eval_typing(Matches[_Lambda[lambda T: T], _Lambda[lambda T: list[T]]]) + assert t == _BoolLiteral[False] + + def test_eval_length_01(): d = eval_typing(Length[tuple[int, str]]) assert d == Literal[2] diff --git a/typemap/type_eval/_apply_generic.py b/typemap/type_eval/_apply_generic.py index 48a7c11..49b8224 100644 --- a/typemap/type_eval/_apply_generic.py +++ b/typemap/type_eval/_apply_generic.py @@ -170,7 +170,7 @@ def make_func( func.__globals__, "__call__", func.__defaults__, - (), + func.__closure__, func.__kwdefaults__, ) diff --git a/typemap/type_eval/_subsim.py b/typemap/type_eval/_subsim.py index e87d5b7..60ab636 100644 --- a/typemap/type_eval/_subsim.py +++ b/typemap/type_eval/_subsim.py @@ -40,6 +40,10 @@ def issubsimilar(lhs: typing.Any, rhs: typing.Any) -> bool: ): return issubclass(lhs, rhs) + # lambda <:? lambda + elif _typing_inspect.is_lambda(lhs) or _typing_inspect.is_lambda(rhs): + return lhs == rhs + # literal <:? literal elif _typing_inspect.is_literal(lhs) and _typing_inspect.is_literal(rhs): # We need to check both value and type, since True == 1 but diff --git a/typemap/type_eval/_subtype.py b/typemap/type_eval/_subtype.py index a80313c..c591d84 100644 --- a/typemap/type_eval/_subtype.py +++ b/typemap/type_eval/_subtype.py @@ -42,6 +42,10 @@ def issubtype(lhs: typing.Any, rhs: typing.Any) -> bool: ): return issubclass(lhs, rhs) + # lambda <:? lambda + elif _typing_inspect.is_lambda(lhs) or _typing_inspect.is_lambda(rhs): + return lhs == rhs + # literal <:? literal elif bool( _typing_inspect.is_literal(lhs) and _typing_inspect.is_literal(rhs) diff --git a/typemap/type_eval/_typing_inspect.py b/typemap/type_eval/_typing_inspect.py index 1904962..f377153 100644 --- a/typemap/type_eval/_typing_inspect.py +++ b/typemap/type_eval/_typing_inspect.py @@ -135,6 +135,12 @@ def is_literal(t: Any) -> bool: return is_generic_alias(t) and get_origin(t) is Literal # type: ignore [comparison-overlap] +def is_lambda(t: Any) -> bool: + from typemap.typing import _Lambda + + return is_generic_alias(t) and get_origin(t) is _Lambda + + def get_head(t: Any) -> type | None: if is_generic_alias(t): return get_head(get_origin(t)) diff --git a/typemap/typing.py b/typemap/typing.py index 670550d..6658fdd 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -315,3 +315,25 @@ def _BoolLiteral(self, tp): return tp return _BoolLiteralGenericAlias(Literal, tp) + + +class _LambdaGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, _LambdaGenericAlias) + and self._func().__code__.co_code == other._func().__code__.co_code + ) + + def __hash__(self) -> int: + return hash(self._func().__code__.co_code) + + def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + return self._func()(*args, **kwargs) + + def _func(self) -> typing.Callable: + return typing.get_args(self)[0] + + +@_SpecialForm +def _Lambda(self, tp): + return _LambdaGenericAlias(self, tp) From c3c25cf3f47f051c78de29ba28e4800e75d8c95b Mon Sep 17 00:00:00 2001 From: dnwpark Date: Tue, 3 Feb 2026 12:59:27 -0800 Subject: [PATCH 2/2] Better handling of closures. --- tests/test_type_eval.py | 115 +++++++++++++++++++++++++++++++--------- typemap/typing.py | 69 ++++++++++++++++++++++-- 2 files changed, 156 insertions(+), 28 deletions(-) diff --git a/tests/test_type_eval.py b/tests/test_type_eval.py index 85a45d3..7627f77 100644 --- a/tests/test_type_eval.py +++ b/tests/test_type_eval.py @@ -1458,14 +1458,87 @@ def test_eval_lambda_01(): assert eval_typing(t(str)) is str +LambdaInt1 = _Lambda[lambda: int] +LambdaInt2 = _Lambda[lambda: int] +LambdaStr = _Lambda[lambda: str] + + def test_eval_lambda_02(): + # nested lambdas + a = _Lambda[lambda: _Lambda[lambda: int]] + assert a == _Lambda[lambda: _Lambda[lambda: int]] + assert eval_typing(a()) == _Lambda[lambda: int] + + assert a != _Lambda[lambda: _Lambda[lambda: str]] + + # lambda closure + b = _Lambda[lambda: int] + c = _Lambda[lambda: b] + d = _Lambda[lambda: int] + e = _Lambda[lambda: str] + assert c == _Lambda[lambda: d] + assert eval_typing(c()) == _Lambda[lambda: int] + + assert c != _Lambda[lambda: e] + + # lambda global + f = _Lambda[lambda: LambdaInt1] + assert f == _Lambda[lambda: LambdaInt2] + assert eval_typing(f()) == _Lambda[lambda: int] + + assert f != _Lambda[lambda: LambdaStr] + + +def test_eval_lambda_03(): # different lambdas with same bytecode are treated as the same + + assert eval_typing(_Lambda[lambda: int]) == eval_typing( + _Lambda[lambda: int] + ) + assert eval_typing(_Lambda[lambda: list[int]]) == eval_typing( + _Lambda[lambda: list[int]] + ) + a1 = lambda: int a2 = lambda: int - t1 = eval_typing(_Lambda[a1]) - t2 = eval_typing(_Lambda[a2]) - assert t1 == t2 + assert _Lambda[a1] == _Lambda[a2] + assert eval_typing(_Lambda[a1]) == eval_typing(_Lambda[a2]) + + l1 = Literal[1] + l2 = Literal[1] + + assert _Lambda[lambda: l1] == _Lambda[lambda: l2] + assert eval_typing(_Lambda[lambda: l1]) == eval_typing(_Lambda[lambda: l2]) + + +def test_eval_lambda_04(): + # different lambdas with different bytecode are treated as different + + assert eval_typing(_Lambda[lambda: int]) != eval_typing( + _Lambda[lambda: str] + ) + + def _f1(): + X = str + return lambda: X + + f1 = _f1() + + def _f2(): + X = int + return lambda: X + + f2 = _f2() + + assert _Lambda[f1] != _Lambda[f2] + assert eval_typing(_Lambda[f1]) != eval_typing(_Lambda[f2]) + + +def test_eval_lambda_05(): + # comparison operators + a1 = lambda: int + a2 = lambda: int t = eval_typing(IsSubtype[_Lambda[a1], _Lambda[a2]]) assert t == _BoolLiteral[True] @@ -1482,30 +1555,22 @@ def test_eval_lambda_02(): assert t == _BoolLiteral[True] -def test_eval_lambda_03(): - # different lambdas with different bytecode are treated as different - a1 = lambda: int - a2 = lambda: str - - t1 = eval_typing(_Lambda[a1]) - t2 = eval_typing(_Lambda[a2]) - assert t1 != t2 - - t = eval_typing(IsSubtype[_Lambda[a1], _Lambda[a2]]) - assert t == _BoolLiteral[False] - t = eval_typing(IsSubSimilar[_Lambda[a1], _Lambda[a2]]) - assert t == _BoolLiteral[False] - t = eval_typing(Matches[_Lambda[a1], _Lambda[a2]]) - assert t == _BoolLiteral[False] +def test_eval_lambda_06(): + # lambda captures non-hashable - t = eval_typing(IsSubtype[_Lambda[lambda T: T], _Lambda[lambda T: list[T]]]) - assert t == _BoolLiteral[False] - t = eval_typing( - IsSubSimilar[_Lambda[lambda T: T], _Lambda[lambda T: list[T]]] + # list is specially converted to tuple + a = [int, str, float] + b = [int, str, float] + assert ( + _Lambda[lambda: Callable[a, int]] == _Lambda[lambda: Callable[b, int]] ) - assert t == _BoolLiteral[False] - t = eval_typing(Matches[_Lambda[lambda T: T], _Lambda[lambda T: list[T]]]) - assert t == _BoolLiteral[False] + assert eval_typing(_Lambda[lambda: a]) == eval_typing(_Lambda[lambda: a]) + + # other non-hashables are only compared by id + c = {1, 2, 3} + d = {1, 2, 3} + assert _Lambda[lambda: c] != _Lambda[lambda: d] + assert eval_typing(_Lambda[lambda: c]) != eval_typing(_Lambda[lambda: d]) def test_eval_length_01(): diff --git a/typemap/typing.py b/typemap/typing.py index 6658fdd..54eee1e 100644 --- a/typemap/typing.py +++ b/typemap/typing.py @@ -320,12 +320,11 @@ def _BoolLiteral(self, tp): class _LambdaGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg] def __eq__(self, other: typing.Any) -> bool: return ( - isinstance(other, _LambdaGenericAlias) - and self._func().__code__.co_code == other._func().__code__.co_code + isinstance(other, _LambdaGenericAlias) and self.key() == other.key() ) def __hash__(self) -> int: - return hash(self._func().__code__.co_code) + return hash(self.key()) def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: return self._func()(*args, **kwargs) @@ -333,6 +332,70 @@ def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: def _func(self) -> typing.Callable: return typing.get_args(self)[0] + def key( + self, + ) -> tuple[ + tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]], + tuple[typing.Any, ...], + ]: + return self._key(self._func()) + + @staticmethod + def _key( + func: typing.Callable, + ) -> tuple[ + tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]], + tuple[typing.Any, ...], + ]: + import builtins + + def _encode_code( + code: types.CodeType, + ) -> tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]]: + bytecode = code.co_code + consts = tuple( + _encode_code(c) if isinstance(c, types.CodeType) else c + for c in code.co_consts + ) + + globals = tuple( + func.__globals__.get(name, None) + or getattr(builtins, name, None) + for name in code.co_names + ) + + return (bytecode, consts, globals) + + if func.__closure__ is None: + closures: tuple[typing.Any, ...] = () + else: + closures = tuple( + # list is specially converted to tuple + tuple(cell.cell_contents) + if isinstance(cell.cell_contents, list) + else cell.cell_contents + if bool( + isinstance( + cell.cell_contents, + ( + type, + typing.TypeVar, + typing.ParamSpec, + typing.TypeVarTuple, + typing.TypeAliasType, + typing._SpecialForm, + ), + ) + or typing.get_origin(cell.cell_contents) is not None + ) + else cell.cell_contents.key() + if isinstance(cell.cell_contents, _LambdaGenericAlias) + else id(cell.cell_contents) + for cell in func.__closure__ + ) + + return (_encode_code(func.__code__), closures) + @_SpecialForm def _Lambda(self, tp):