Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 146 additions & 1 deletion tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@
GetSpecialAttr,
GetType,
GetAnnotations,
IsSubtype,
IsSub,
IsSubSimilar,
IsSubtype,
Iter,
Length,
Matches,
Expand All @@ -45,6 +46,7 @@
StrConcat,
Uppercase,
_BoolLiteral,
_Lambda,
)

from . import format_helper
Expand Down Expand Up @@ -1428,6 +1430,149 @@ def test_eval_bool_literal_error_01():
eval_typing(_BoolLiteral[int])


def test_eval_lambda_01():
type OnlyIntToSet[T] = set[T] if IsSub[T, int] else T

a = lambda: int
b = lambda T: T
c = lambda T: list[T]
d = lambda T: OnlyIntToSet[T]

t = eval_typing(_Lambda[a])
assert t == _Lambda[a]
assert eval_typing(t()) is int

t = eval_typing(_Lambda[b])
assert t == _Lambda[b]
assert eval_typing(t(int)) is int
assert eval_typing(t(str)) is str

t = eval_typing(_Lambda[c])
assert t == _Lambda[c]
assert eval_typing(t(int)) == list[int]
assert eval_typing(t(str)) == list[str]

t = eval_typing(_Lambda[d])
assert t == _Lambda[d]
assert eval_typing(t(int)) == set[int]
assert eval_typing(t(str)) is str


LambdaInt1 = _Lambda[lambda: int]
LambdaInt2 = _Lambda[lambda: int]
LambdaStr = _Lambda[lambda: str]


def test_eval_lambda_02():
# nested lambdas
a = _Lambda[lambda: _Lambda[lambda: int]]
assert a == _Lambda[lambda: _Lambda[lambda: int]]
assert eval_typing(a()) == _Lambda[lambda: int]

assert a != _Lambda[lambda: _Lambda[lambda: str]]

# lambda closure
b = _Lambda[lambda: int]
c = _Lambda[lambda: b]
d = _Lambda[lambda: int]
e = _Lambda[lambda: str]
assert c == _Lambda[lambda: d]
assert eval_typing(c()) == _Lambda[lambda: int]

assert c != _Lambda[lambda: e]

# lambda global
f = _Lambda[lambda: LambdaInt1]
assert f == _Lambda[lambda: LambdaInt2]
assert eval_typing(f()) == _Lambda[lambda: int]

assert f != _Lambda[lambda: LambdaStr]


def test_eval_lambda_03():
# different lambdas with same bytecode are treated as the same

assert eval_typing(_Lambda[lambda: int]) == eval_typing(
_Lambda[lambda: int]
)
assert eval_typing(_Lambda[lambda: list[int]]) == eval_typing(
_Lambda[lambda: list[int]]
)

a1 = lambda: int
a2 = lambda: int

assert _Lambda[a1] == _Lambda[a2]
assert eval_typing(_Lambda[a1]) == eval_typing(_Lambda[a2])

l1 = Literal[1]
l2 = Literal[1]

assert _Lambda[lambda: l1] == _Lambda[lambda: l2]
assert eval_typing(_Lambda[lambda: l1]) == eval_typing(_Lambda[lambda: l2])


def test_eval_lambda_04():
# different lambdas with different bytecode are treated as different

assert eval_typing(_Lambda[lambda: int]) != eval_typing(
_Lambda[lambda: str]
)

def _f1():
X = str
return lambda: X

f1 = _f1()

def _f2():
X = int
return lambda: X

f2 = _f2()

assert _Lambda[f1] != _Lambda[f2]
assert eval_typing(_Lambda[f1]) != eval_typing(_Lambda[f2])


def test_eval_lambda_05():
# comparison operators
a1 = lambda: int
a2 = lambda: int

t = eval_typing(IsSubtype[_Lambda[a1], _Lambda[a2]])
assert t == _BoolLiteral[True]
t = eval_typing(IsSubSimilar[_Lambda[a1], _Lambda[a2]])
assert t == _BoolLiteral[True]
t = eval_typing(Matches[_Lambda[a1], _Lambda[a2]])
assert t == _BoolLiteral[True]

t = eval_typing(IsSubtype[_Lambda[lambda T: T], _Lambda[lambda U: U]])
assert t == _BoolLiteral[True]
t = eval_typing(IsSubSimilar[_Lambda[lambda T: T], _Lambda[lambda U: U]])
assert t == _BoolLiteral[True]
t = eval_typing(Matches[_Lambda[lambda T: T], _Lambda[lambda U: U]])
assert t == _BoolLiteral[True]


def test_eval_lambda_06():
# lambda captures non-hashable

# list is specially converted to tuple
a = [int, str, float]
b = [int, str, float]
assert (
_Lambda[lambda: Callable[a, int]] == _Lambda[lambda: Callable[b, int]]
)
assert eval_typing(_Lambda[lambda: a]) == eval_typing(_Lambda[lambda: a])

# other non-hashables are only compared by id
c = {1, 2, 3}
d = {1, 2, 3}
assert _Lambda[lambda: c] != _Lambda[lambda: d]
assert eval_typing(_Lambda[lambda: c]) != eval_typing(_Lambda[lambda: d])


def test_eval_length_01():
d = eval_typing(Length[tuple[int, str]])
assert d == Literal[2]
Expand Down
2 changes: 1 addition & 1 deletion typemap/type_eval/_apply_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def make_func(
func.__globals__,
"__call__",
func.__defaults__,
(),
func.__closure__,
func.__kwdefaults__,
)

Expand Down
4 changes: 4 additions & 0 deletions typemap/type_eval/_subsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def issubsimilar(lhs: typing.Any, rhs: typing.Any) -> bool:
):
return issubclass(lhs, rhs)

# lambda <:? lambda
elif _typing_inspect.is_lambda(lhs) or _typing_inspect.is_lambda(rhs):
return lhs == rhs

# literal <:? literal
elif _typing_inspect.is_literal(lhs) and _typing_inspect.is_literal(rhs):
# We need to check both value and type, since True == 1 but
Expand Down
4 changes: 4 additions & 0 deletions typemap/type_eval/_subtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def issubtype(lhs: typing.Any, rhs: typing.Any) -> bool:
):
return issubclass(lhs, rhs)

# lambda <:? lambda
elif _typing_inspect.is_lambda(lhs) or _typing_inspect.is_lambda(rhs):
return lhs == rhs

# literal <:? literal
elif bool(
_typing_inspect.is_literal(lhs) and _typing_inspect.is_literal(rhs)
Expand Down
6 changes: 6 additions & 0 deletions typemap/type_eval/_typing_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def is_literal(t: Any) -> bool:
return is_generic_alias(t) and get_origin(t) is Literal # type: ignore [comparison-overlap]


def is_lambda(t: Any) -> bool:
from typemap.typing import _Lambda

return is_generic_alias(t) and get_origin(t) is _Lambda


def get_head(t: Any) -> type | None:
if is_generic_alias(t):
return get_head(get_origin(t))
Expand Down
85 changes: 85 additions & 0 deletions typemap/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,88 @@ def _BoolLiteral(self, tp):
return tp

return _BoolLiteralGenericAlias(Literal, tp)


class _LambdaGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, _LambdaGenericAlias) and self.key() == other.key()
)

def __hash__(self) -> int:
return hash(self.key())

def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return self._func()(*args, **kwargs)

def _func(self) -> typing.Callable:
return typing.get_args(self)[0]

def key(
self,
) -> tuple[
tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]],
tuple[typing.Any, ...],
]:
return self._key(self._func())

@staticmethod
def _key(
func: typing.Callable,
) -> tuple[
tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]],
tuple[typing.Any, ...],
]:
import builtins

def _encode_code(
code: types.CodeType,
) -> tuple[bytes, tuple[typing.Any, ...], tuple[typing.Any, ...]]:
bytecode = code.co_code
consts = tuple(
_encode_code(c) if isinstance(c, types.CodeType) else c
for c in code.co_consts
)

globals = tuple(
func.__globals__.get(name, None)
or getattr(builtins, name, None)
for name in code.co_names
)

return (bytecode, consts, globals)

if func.__closure__ is None:
closures: tuple[typing.Any, ...] = ()
else:
closures = tuple(
# list is specially converted to tuple
tuple(cell.cell_contents)
if isinstance(cell.cell_contents, list)
else cell.cell_contents
if bool(
isinstance(
cell.cell_contents,
(
type,
typing.TypeVar,
typing.ParamSpec,
typing.TypeVarTuple,
typing.TypeAliasType,
typing._SpecialForm,
),
)
or typing.get_origin(cell.cell_contents) is not None
)
else cell.cell_contents.key()
if isinstance(cell.cell_contents, _LambdaGenericAlias)
else id(cell.cell_contents)
for cell in func.__closure__
)

return (_encode_code(func.__code__), closures)


@_SpecialForm
def _Lambda(self, tp):
return _LambdaGenericAlias(self, tp)