Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
StrConcat,
StrSlice,
Uppercase,
bool_special_form,
)

from . import format_helper
Expand Down Expand Up @@ -769,3 +770,36 @@ def test_callable_to_signature():
'(_arg0: int, /, b: int, c: int = ..., *args: int, '
'd: int, e: int = ..., **kwargs: int) -> int'
)


@bool_special_form
class IsNotInt[T]:
__expr__ = not Is[T, int]


@bool_special_form
class IsNotStr[T]:
__expr__ = not Is[T, str]


@bool_special_form
class IsNotIntOrStr[T]:
__expr__ = IsNotInt[T] and IsNotStr[T]


type SetOfNotInt[T] = set[T] if IsNotInt[T] else T
type SetOfNotIntOrStr[T] = set[T] if IsNotIntOrStr[T] else T


def test_eval_if_generic_01():
t = eval_typing(SetOfNotInt[int])
assert t is int
t = eval_typing(SetOfNotInt[str])
assert t == set[str]

t = eval_typing(SetOfNotIntOrStr[int])
assert t is int
t = eval_typing(SetOfNotIntOrStr[str])
assert t is str
t = eval_typing(SetOfNotIntOrStr[float])
assert t == set[float]
56 changes: 50 additions & 6 deletions typemap/type_eval/_eval_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
from typing import Any

from . import _apply_generic
from ._special_form import (
BoolSpecialMetadata,
_bool_special_form_registry,
_special_form_evaluator,
)


__all__ = ("eval_typing",)
Expand Down Expand Up @@ -113,24 +118,20 @@ class EvalContext:

@contextlib.contextmanager
def _ensure_context() -> typing.Iterator[EvalContext]:
import typemap.typing as nt

ctx = _current_context.get()
ctx_set = False
if ctx is None:
ctx = EvalContext()
_current_context.set(ctx)
ctx_set = True
evaluator_token = nt.special_form_evaluator.set(
lambda t: _eval_types(t, ctx)
)
evaluator_token = _special_form_evaluator.set(lambda t: _eval_types(t, ctx))

try:
yield ctx
finally:
if ctx_set:
_current_context.set(None)
nt.special_form_evaluator.reset(evaluator_token)
_special_form_evaluator.reset(evaluator_token)


def _get_current_context() -> EvalContext:
Expand Down Expand Up @@ -349,13 +350,56 @@ def _eval_applied_type_alias(obj: types.GenericAlias, ctx: EvalContext):
return evaled


def _eval_bool_special_form(
metadata: BoolSpecialMetadata,
new_args: tuple[typing.Any, ...],
ctx: EvalContext,
) -> bool:
import ast

original_cls = metadata.cls

try:
namespace = {}

# Add the class's module
if cls_module := sys.modules.get(original_cls.__module__):
namespace.update(cls_module.__dict__)

# Add type parameters with their substituted values
type_params = metadata.type_params
if type_params and new_args:
for param, arg in zip(type_params, new_args, strict=False):
namespace[param.__name__] = arg

expr = compile(
ast.Expression(body=metadata.expr_node), # type: ignore[arg-type]
'<bool_expr>',
'eval',
)
bool_expr = eval(expr, namespace)

# Evaluate the type expression
result = _eval_types(bool_expr, ctx)

return result

except Exception as e:
raise RuntimeError(
f"Failed to evaluate special form for {original_cls.__name__}: {e}"
) from e


@_eval_types_impl.register
def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext):
"""Eval a typing._GenericAlias -- an applied user-defined class"""
# generic *classes* are typing._GenericAlias while generic type
# aliases are types.GenericAlias? Why in the world.
new_args = tuple(_eval_types(arg, ctx) for arg in typing.get_args(obj))

if metadata := _bool_special_form_registry.get(obj.__origin__):
return _eval_bool_special_form(metadata, new_args, ctx)

if func := _eval_funcs.get(obj.__origin__):
ret = func(*new_args, ctx=ctx)
# return _eval_types(ret, ctx) # ???
Expand Down
102 changes: 102 additions & 0 deletions typemap/type_eval/_special_form.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import ast
import contextvars
import dataclasses
import typing
from typing import _GenericAlias # type: ignore


_SpecialForm: typing.Any = typing._SpecialForm


# TODO: type better
_special_form_evaluator: contextvars.ContextVar[
typing.Callable[[typing.Any], typing.Any] | None
] = contextvars.ContextVar("special_form_evaluator", default=None)


class _IterGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
def __iter__(self):
evaluator = _special_form_evaluator.get()
if evaluator:
return evaluator(self)
else:
return iter(typing.TypeVarTuple("_IterDummy"))


class _BoolGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
def __bool__(self):
evaluator = _special_form_evaluator.get()
if evaluator:
return evaluator(self)
else:
return False


_IsGenericAlias = _BoolGenericAlias


_bool_special_form_registry: dict[typing.Any, BoolSpecialMetadata] = {}


@dataclasses.dataclass(frozen=True, kw_only=True)
class BoolSpecialMetadata:
cls: type
type_params: tuple[type]
expr_node: ast.AST


def _register_bool_special_form(cls):
import inspect
import textwrap

type_params = getattr(cls, '__type_params__', ())

if '__expr__' not in cls.__dict__:
raise TypeError(f"{cls.__name__} must have an '__expr__' field")

# Parse __expr__ to get the assigned expression
source = inspect.getsource(cls)
source = textwrap.dedent(source)
tree = ast.parse(source)

expr_node = None
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
for item in node.body:
if isinstance(item, ast.AnnAssign):
# __expr__: SomeType = expression
if (
isinstance(item.target, ast.Name)
and item.target.id == '__expr__'
):
expr_node = item.value
break
elif isinstance(item, ast.Assign):
# __expr__ = expression
for target in item.targets:
if (
isinstance(target, ast.Name)
and target.id == '__expr__'
):
expr_node = item.value
break
if expr_node:
break
if expr_node:
break

if expr_node is None:
raise TypeError(f"Could not find __expr__ assignment in {cls.__name__}")

def impl_func(self, params):
return _BoolGenericAlias(self, params)

sf = _SpecialForm(impl_func)

_bool_special_form_registry[sf] = BoolSpecialMetadata(
cls=cls,
type_params=type_params,
expr_node=expr_node,
)

return sf
36 changes: 10 additions & 26 deletions typemap/typing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import contextvars
import typing
from typing import _GenericAlias # type: ignore

_SpecialForm: typing.Any = typing._SpecialForm
from .type_eval._special_form import (
_IterGenericAlias,
_IsGenericAlias,
_SpecialForm,
_register_bool_special_form,
)

# Not type-level computation but related

Expand Down Expand Up @@ -113,35 +116,12 @@ class NewProtocol[*T]:

##################################################################

# TODO: type better
special_form_evaluator: contextvars.ContextVar[
typing.Callable[[typing.Any], typing.Any] | None
] = contextvars.ContextVar("special_form_evaluator", default=None)


class _IterGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
def __iter__(self):
evaluator = special_form_evaluator.get()
if evaluator:
return evaluator(self)
else:
return iter(typing.TypeVarTuple("_IterDummy"))


@_SpecialForm
def Iter(self, tp):
return _IterGenericAlias(self, (tp,))


class _IsGenericAlias(_GenericAlias, _root=True): # type: ignore[call-arg]
def __bool__(self):
evaluator = special_form_evaluator.get()
if evaluator:
return evaluator(self)
else:
return False


@_SpecialForm
def IsSubtype(self, tps):
return _IsGenericAlias(self, tps)
Expand All @@ -153,3 +133,7 @@ def IsSubSimilar(self, tps):


Is = IsSubSimilar


def bool_special_form(cls):
return _register_bool_special_form(cls)