Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion spec-draft.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ It's important that there be a clearly specified type language for the type-leve
Special forms unfortunately require some special handling: the arguments list of a ``Callable`` will be packed in a tuple, and a ``...`` will become ``SpecialFormEllipsis``.


* ``GetArgs[T, Base]`` - returns a tuple containing all of the type arguments of ``T`` when interpreted as ``Base``, or ``Never`` if it cannot be. (TODO: UNIMPLEMENTED)
* ``GetArgs[T, Base]`` - returns a tuple containing all of the type arguments of ``T`` when interpreted as ``Base``, or ``Never`` if it cannot be.
* ``FromUnion[T]`` - returns a tuple containing all of the union elements, or a 1-ary tuple containing T if it is not a union.


Expand Down
42 changes: 30 additions & 12 deletions tests/test_type_dir.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import textwrap
from typing import Never, Literal, Union, TypeVar
import typing
from typing import Literal, Never, TypeVar, Union

from typemap.type_eval import eval_typing
from typemap.typing import (
NewProtocol,
Member,
Attrs,
FromUnion,
GetArg,
GetName,
GetQuals,
GetType,
Is,
Iter,
Attrs,
Member,
Members,
FromUnion,
NewProtocol,
Uppercase,
Is,
)

from . import format_helper


type OrGotcha[K] = K | Literal["gotcha!"]

type StrForInt[X] = (str | OrGotcha[X]) if X is int else (X | OrGotcha[X])
Expand All @@ -39,7 +40,9 @@ class Base[T]:
t: dict[str, StrForInt[T]]
kkk: K

def foo(self, a: T | None, b: int = 0) -> dict[str, T]:
fin: typing.Final[int]

def foo(self, a: T | None, *, b: int = 0) -> dict[str, T]:
pass

def base[Z](self, a: T | Z | None, b: K) -> dict[str, T | Z]:
Expand Down Expand Up @@ -75,14 +78,20 @@ class Final(Mine, Ordinary, Wrapper[float], AnotherBase[float], Last[int]):


type AllOptional[T] = NewProtocol[
*[Member[GetName[p], GetType[p] | None] for p in Iter[Attrs[T]]]
*[
Member[GetName[p], GetType[p] | None, GetQuals[p]]
for p in Iter[Attrs[T]]
]
]

type OptionalFinal = AllOptional[Final]


type Capitalize[T] = NewProtocol[
*[Member[Uppercase[GetName[p]], GetType[p]] for p in Iter[Attrs[T]]]
*[
Member[Uppercase[GetName[p]], GetType[p], GetQuals[p]]
for p in Iter[Attrs[T]]
]
]

type Prims[T] = NewProtocol[
Expand All @@ -102,6 +111,7 @@ class Final(Mine, Ordinary, Wrapper[float], AnotherBase[float], Last[int]):
if not Is[t, Literal]
]
],
GetQuals[p],
]
for p in Iter[Attrs[T]]
]
Expand Down Expand Up @@ -137,6 +147,7 @@ class Final(Mine, Ordinary, Wrapper[float], AnotherBase[float], Last[int]):
if not Is[IsLiteral[t], Literal[True]]
]
],
GetQuals[p],
]
for p in Iter[Attrs[T]]
]
Expand All @@ -152,9 +163,10 @@ class Final:
iii: str | int | typing.Literal['gotcha!']
t: dict[str, str | int | typing.Literal['gotcha!']]
kkk: ~K
fin: typing.Final[int]
x: tests.test_type_dir.Wrapper[int | None]
ordinary: str
def foo(self, a: int | None, b: int = 0) -> dict[str, int]: ...
def foo(self, a: int | None, *, b: int = 0) -> dict[str, int]: ...
def base[Z](self, a: int | Z | None, b: ~K) -> dict[str, int | Z]: ...
def cbase(cls, a: int | None, b: ~K) -> dict[str, int]: ...
def sbase[Z](cls, a: int | Literal['gotcha!'] | Z | None, b: ~K) -> dict[str, int | Z]: ...
Expand All @@ -172,6 +184,7 @@ class AllOptional[tests.test_type_dir.Final]:
iii: str | int | typing.Literal['gotcha!'] | None
t: dict[str, str | int | typing.Literal['gotcha!']] | None
kkk: ~K | None
fin: typing.Final[int | None]
x: tests.test_type_dir.Wrapper[int | None] | None
ordinary: str | None
""")
Expand All @@ -186,6 +199,7 @@ class Capitalize[tests.test_type_dir.Final]:
III: str | int | typing.Literal['gotcha!']
T: dict[str, str | int | typing.Literal['gotcha!']]
KKK: ~K
FIN: typing.Final[int]
X: tests.test_type_dir.Wrapper[int | None]
ORDINARY: str
""")
Expand All @@ -197,6 +211,7 @@ def test_type_dir_4():
assert format_helper.format_class(d) == textwrap.dedent("""\
class Prims[tests.test_type_dir.Final]:
last: int | typing.Literal[True]
fin: typing.Final[int]
ordinary: str
""")

Expand All @@ -211,6 +226,7 @@ class NoLiterals1[tests.test_type_dir.Final]:
iii: str | int
t: dict[str, str | int | typing.Literal['gotcha!']]
kkk: ~K
fin: typing.Final[int]
x: tests.test_type_dir.Wrapper[int | None]
ordinary: str
""")
Expand All @@ -225,6 +241,7 @@ class NoLiterals2[tests.test_type_dir.Final]:
iii: str | int
t: dict[str, str | int | typing.Literal['gotcha!']]
kkk: ~K
fin: typing.Final[int]
x: tests.test_type_dir.Wrapper[int | None]
ordinary: str
""")
Expand All @@ -243,7 +260,8 @@ def test_type_dir_7():
typing.Callable[[\
typemap.typing.Param[typing.Literal['self'], typing.Any, typing.Never], \
typemap.typing.Param[typing.Literal['a'], int | None, typing.Never], \
typemap.typing.Param[typing.Literal['b'], int, typing.Literal['=']]], \
typemap.typing.Param[typing.Literal['b'], int, typing.Literal['keyword', \
'default']]], \
dict[str, int]], typing.Literal['ClassVar'], tests.test_type_dir.Final]"
)

Expand Down
11 changes: 11 additions & 0 deletions tests/test_type_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Attrs,
FromUnion,
GetArg,
GetArgs,
GetAttr,
GetName,
GetType,
Expand Down Expand Up @@ -178,6 +179,16 @@ def test_getarg_never():
assert d is Never


def test_eval_getargs():
t = dict[int, str]
args = eval_typing(GetArgs[t, dict])
assert args == tuple[int, str]

t = dict
args = eval_typing(GetArgs[t, dict])
assert args == tuple[Any, Any]


def test_eval_getarg_callable():
# oh hmmmmmmm -- yeah maybe callable could be fully bespoke if we
# disallowed putting Callable here...!
Expand Down
7 changes: 6 additions & 1 deletion typemap/type_eval/_apply_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ def make_func(
annos: dict[str, Any],
) -> types.FunctionType:
new_func = types.FunctionType(
func.__code__, func.__globals__, "__call__", func.__defaults__, ()
func.__code__,
func.__globals__,
"__call__",
func.__defaults__,
(),
func.__kwdefaults__,
)

new_func.__module__ = func.__module__
Expand Down
63 changes: 52 additions & 11 deletions typemap/type_eval/_eval_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Capitalize,
FromUnion,
GetArg,
GetArgs,
GetAttr,
IsSubSimilar,
IsSubtype,
Expand Down Expand Up @@ -55,8 +56,24 @@ def get_annotated_type_hints(cls, **kwargs):
continue
for k in acls.__annotations__:
if k not in hints:
# XXX: TODO: Strip ClassVar/Final
hints[k] = ohints[k], (), acls
quals = set()
ty = ohints[k]

# Strip ClassVar/Final from ty and add them to quals
while True:
for form in [typing.ClassVar, typing.Final]:
if _typing_inspect.is_special_form(ty, form):
quals.add(form.__name__)
ty = (
typing.get_args(ty)[0]
if typing.get_args(ty)
else typing.Any
)
break
else:
break

hints[k] = ty, tuple(sorted(quals)), acls

# Stop early if we are done.
if len(hints) == len(ohints):
Expand Down Expand Up @@ -186,8 +203,10 @@ def _ann(x):
quals.append("*")
if p.kind == inspect.Parameter.VAR_KEYWORD:
quals.append("**")
if p.kind == inspect.Parameter.KEYWORD_ONLY:
quals.append("keyword")
if p.default is not empty:
quals.append("=")
quals.append("default")
params.append(
Param[
typing.Literal[p.name if has_name else None],
Expand Down Expand Up @@ -247,12 +266,10 @@ def _eval_GetAttr(lhs, prop, *, ctx):
return typing.get_type_hints(lhs)[name]


def _get_args(tp, base, ctx) -> typing.Any:
# XXX: check against base!!
def _get_raw_args(tp, base_head, ctx) -> typing.Any:
evaled = _eval_types(tp, ctx)

tp_head = _typing_inspect.get_head(tp)
base_head = _typing_inspect.get_head(base)
if not tp_head or not base_head:
return None

Expand All @@ -271,6 +288,14 @@ def _get_args(tp, base, ctx) -> typing.Any:
return None


def _get_args(tp, base, ctx) -> typing.Any:
base_head = _typing_inspect.get_head(base)
args = _get_raw_args(tp, base, ctx)
if args == ():
args = _get_defaults(base_head)
return args


def _fix_type(tp):
"""Fix up a type getting returned from GetArg

Expand Down Expand Up @@ -383,8 +408,6 @@ def _get_defaults(base_head):
def _eval_GetArg(tp, base, idx, *, ctx) -> typing.Any:
base_head = _typing_inspect.get_head(base)
args = _get_args(tp, base_head, ctx)
if args == ():
args = _get_defaults(base_head)
if args is None:
return typing.Never

Expand All @@ -394,6 +417,16 @@ def _eval_GetArg(tp, base, idx, *, ctx) -> typing.Any:
return typing.Never


@type_eval.register_evaluator(GetArgs)
@_lift_over_unions
def _eval_GetArgs(tp, base, *, ctx) -> typing.Any:
base_head = _typing_inspect.get_head(base)
args = _get_args(tp, base_head, ctx)
if args is None:
return typing.Never
return tuple[*args] # type: ignore[valid-type]


@type_eval.register_evaluator(Length)
@_lift_over_unions
def _eval_Length(tp, *, ctx) -> typing.Any:
Expand Down Expand Up @@ -428,15 +461,23 @@ def func(*args, ctx):
##################################################################


def _add_quals(typ, quals):
for qual in (typing.ClassVar, typing.Final):
if type_eval.issubsimilar(typing.Literal[qual.__name__], quals):
typ = qual[typ]
return typ


@type_eval.register_evaluator(NewProtocol)
def _eval_NewProtocol(*etyps: Member, ctx):
dct: dict[str, object] = {}
dct["__annotations__"] = {
# XXX: Should eval_typing on the etyps evaluate the arguments??
_from_literal(typing.get_args(prop)[0], ctx): _eval_types(
typing.get_args(prop)[1], ctx
_from_literal(name, ctx): _add_quals(
_eval_types(typ, ctx),
_eval_types(quals, ctx),
)
for prop in etyps
for name, typ, quals, _ in (typing.get_args(prop) for prop in etyps)
}

module_name = __name__
Expand Down
7 changes: 6 additions & 1 deletion typemap/type_eval/_eval_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ def _eval_func(func: types.FunctionType | types.MethodType, ctx: EvalContext):
annos = {name: _eval_types(tp, ctx) for name, tp in annos.items()}

new_func = types.FunctionType(
root.__code__, root.__globals__, "__call__", root.__defaults__, ()
root.__code__,
root.__globals__,
"__call__",
root.__defaults__,
(),
root.__kwdefaults__,
)

new_func.__name__ = root.__name__
Expand Down
25 changes: 17 additions & 8 deletions typemap/type_eval/_typing_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,37 @@


import typing

from typing import (
from types import GenericAlias, UnionType
from typing import ( # type: ignore [attr-defined] # noqa: PLC2701
Annotated,
Any,
ClassVar,
ForwardRef,
Literal,
TypeGuard,
TypeVar,
Union,
_GenericAlias,
_SpecialGenericAlias,
get_args,
get_origin,
)
from typing import _GenericAlias, _SpecialGenericAlias # type: ignore [attr-defined] # noqa: PLC2701

from typing_extensions import TypeAliasType, TypeVarTuple, Unpack
from types import GenericAlias, UnionType

from . import _eval_typing


def is_classvar(t: Any) -> bool:
return t is ClassVar or (is_generic_alias(t) and get_origin(t) is ClassVar) # type: ignore [comparison-overlap]
def is_special_form(t: Any, form: Any) -> bool:
"""Check if t is a special form or a generic alias of that form.

Args:
t: The type to check
form: The special form to check against (e.g., ClassVar, Final, Literal)

Returns:
True if t is the special form or a generic alias with that origin
"""
return t is form or (is_generic_alias(t) and get_origin(t) is form) # type: ignore [comparison-overlap]


def is_generic_alias(t: Any) -> TypeGuard[GenericAlias]:
Expand Down Expand Up @@ -142,12 +151,12 @@ def is_eval_proxy(t: Any) -> TypeGuard[type[_eval_typing._EvalProxy]]:

__all__ = (
"is_annotated",
"is_classvar",
"is_forward_ref",
"is_generic_alias",
"is_generic_type_alias",
"is_literal",
"is_optional_type",
"is_special_form",
"is_type_alias",
"is_union_type",
)
Loading