From 422ef32877443ef47f9dcc3cee6747f9574749c9 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 11 Mar 2026 13:46:29 -0400 Subject: [PATCH 1/9] Support typing.Self --- effectful/internals/unification.py | 63 +++++++++ effectful/ops/types.py | 3 +- tests/test_internals_unification.py | 199 ++++++++++++++++++++++++++++ tests/test_ops_semantics.py | 84 ++++++++++++ 4 files changed, 348 insertions(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 71d6583f2..0e7f27fa6 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -94,6 +94,60 @@ Substitutions = collections.abc.Mapping[TypeVariable, TypeExpressions] +def _has_typing_self(typ) -> bool: + """Check if typing.Self appears anywhere in a type expression.""" + if typ is typing.Self: + return True + if isinstance(typ, inspect.Signature): + if _has_typing_self(typ.return_annotation): + return True + for p in typ.parameters.values(): + if _has_typing_self(p.annotation): + return True + return False + if isinstance(typ, list): + return any(_has_typing_self(item) for item in typ) + for arg in typing.get_args(typ): + if _has_typing_self(arg): + return True + return False + + +def _replace_self(typ, self_tv: typing.TypeVar): + """Replace all occurrences of typing.Self with the given TypeVar.""" + if typ is typing.Self: + return self_tv + elif isinstance(typ, inspect.Signature): + new_params = [] + for i, p in enumerate(typ.parameters.values()): + if _has_typing_self(p.annotation): + new_params.append( + p.replace(annotation=_replace_self(p.annotation, self_tv)) + ) + elif i == 0 and p.annotation is inspect.Parameter.empty: + new_params.append(p.replace(annotation=self_tv)) + else: + new_params.append(p) + new_ret = ( + _replace_self(typ.return_annotation, self_tv) + if _has_typing_self(typ.return_annotation) + else typ.return_annotation + ) + return typ.replace(parameters=new_params, return_annotation=new_ret) + elif isinstance(typ, list): + return [_replace_self(item, self_tv) for item in typ] + args = typing.get_args(typ) + if not args: + return typ + new_args = tuple(_replace_self(a, self_tv) for a in args) + if new_args == args: + return typ + origin = typing.get_origin(typ) + if origin is not None: + return origin[new_args] + return typ + + @dataclass class Box[T]: """Boxed types. Prevents confusion between types computed by __type_rule__ @@ -347,6 +401,12 @@ def _unify_signature( if typ != subtyp.signature: raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}. ") + if _has_typing_self(typ): + self_tv = typing.TypeVar("Self") + typ = _replace_self(typ, self_tv) + subtyp = typ.bind(*subtyp.args, **subtyp.kwargs) + return {**unify(typ, subtyp, subs), typing.Self: self_tv} # type: ignore + for name, param in typ.parameters.items(): if param.annotation is inspect.Parameter.empty: continue @@ -913,6 +973,9 @@ def substitute(typ, subs: Substitutions) -> TypeExpressions: >>> substitute(int, {T: str}) """ + if typing.Self in subs and _has_typing_self(typ): + return substitute(_replace_self(typ, subs[typing.Self]), subs) + if isinstance(typ, typing.TypeVar | typing.ParamSpec | typing.TypeVarTuple): return substitute(subs[typ], subs) if typ in subs else typ elif isinstance(typ, list | tuple): diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 975c66dce..52d891f4a 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -380,6 +380,7 @@ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]: """ from effectful.internals.unification import ( + _has_typing_self, freetypevars, nested_type, substitute, @@ -394,7 +395,7 @@ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]: return typing.cast(type[V], object) elif return_anno is None: return type(None) # type: ignore - elif not freetypevars(return_anno): + elif not freetypevars(return_anno) and not _has_typing_self(return_anno): return return_anno type_args = tuple(nested_type(a).value for a in args) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 18b158d85..61a1fcabc 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -7,6 +7,7 @@ from effectful.internals.unification import ( Box, + _has_typing_self, canonicalize, freetypevars, nested_type, @@ -137,6 +138,8 @@ class GenericClass[T]: (int, {T: str}, int), (str, {}, str), (list[int], {T: str}, list[int]), + # typing.Self with no binding passes through unchanged + (typing.Self, {}, typing.Self), # Single TypeVar in generic (list[T], {T: int}, list[int]), (set[T], {T: str}, set[str]), @@ -488,6 +491,44 @@ def variadic_kwargs_func[T](**kwargs: T) -> T: # Variadic kwargs not supported return next(iter(kwargs.values())) +class _Foo: + def return_self(self) -> typing.Self: + return self + + def return_list_self(self) -> list[typing.Self]: + return [self] + + def return_self_or_none(self) -> typing.Self | None: + return self + + def annotated_self(self: typing.Self) -> typing.Self: + return self + + def takes_other(self, other: typing.Self) -> typing.Self: + return self + + def mixed_with_typevar[T](self, x: T) -> tuple[typing.Self, T]: + return (self, x) + + def return_dict_self(self) -> dict[str, typing.Self]: + return {"me": self} + + def return_callable_self(self) -> collections.abc.Callable[[typing.Self], int]: + return id + + def return_type_self(self) -> type[typing.Self]: + return type(self) + + @classmethod + def from_config(cls, config: int) -> typing.Self: + return cls() + + +class _Bar: + def return_self(self) -> typing.Self: + return self + + @pytest.mark.parametrize( "func,args,kwargs,expected_return_type", [ @@ -565,6 +606,32 @@ def variadic_kwargs_func[T](**kwargs: T) -> T: # Variadic kwargs not supported (variadic_args_func, (int, int), {}, int), (variadic_kwargs_func, (), {"x": int}, int), (variadic_kwargs_func, (), {"x": int, "y": int}, int), + # typing.Self return types (methods) + (_Foo.return_self, (int,), {}, int), + (_Foo.return_self, (str,), {}, str), + (_Foo.return_self, (_Foo,), {}, _Foo), + (_Foo.return_self, (_Bar,), {}, _Bar), + (_Bar.return_self, (_Bar,), {}, _Bar), + (_Foo.annotated_self, (_Foo,), {}, _Foo), + (_Foo.return_list_self, (int,), {}, list[int]), + (_Foo.return_list_self, (_Foo,), {}, list[_Foo]), + (_Foo.return_self_or_none, (int,), {}, int | None), + (_Foo.return_self_or_none, (_Foo,), {}, _Foo | None), + # Self as a non-self parameter + (_Foo.takes_other, (int, int), {}, int), + (_Foo.takes_other, (_Foo, _Foo), {}, _Foo), + # Self mixed with other TypeVars + (_Foo.mixed_with_typevar, (int, str), {}, tuple[int, str]), + (_Foo.mixed_with_typevar, (_Foo, list[int]), {}, tuple[_Foo, list[int]]), + # Self in dict[str, Self] + (_Foo.return_dict_self, (int,), {}, dict[str, int]), + (_Foo.return_dict_self, (_Foo,), {}, dict[str, _Foo]), + # Self inside Callable[[Self], int] + (_Foo.return_callable_self, (int,), {}, collections.abc.Callable[[int], int]), + (_Foo.return_callable_self, (_Foo,), {}, collections.abc.Callable[[_Foo], int]), + # type[Self] + (_Foo.return_type_self, (int,), {}, type[int]), + (_Foo.return_type_self, (_Foo,), {}, type[_Foo]), ], ) def test_infer_return_type_success( @@ -1698,3 +1765,135 @@ def test_binary_on_sequence_elements(f, seq, index1, index2): ), collections.abc.Mapping, ) + + +# ============================================================ +# typing.Self resolution tests +# ============================================================ + + +# --- _has_typing_self --- + + +@pytest.mark.parametrize( + "typ,expected", + [ + (typing.Self, True), + (list[typing.Self], True), # type: ignore[misc] + (typing.Self | None, True), + (dict[str, typing.Self], True), # type: ignore[misc] + (collections.abc.Callable[[typing.Self], int], True), + (type[typing.Self], True), + (int, False), + (list[int], False), + (T, False), + (list[T], False), + ], +) +def test_has_typing_self(typ, expected): + assert _has_typing_self(typ) == expected + + +# --- chaining two signatures with Self from different "classes" --- + + +def test_chained_self_signatures(): + """Two unify calls sharing subs must not conflate Self.""" + sig_a = inspect.signature(_Foo.return_self) + sig_b = inspect.signature(_Bar.return_self) + + subs = unify(sig_a, sig_a.bind(_Foo)) + assert canonicalize(substitute(sig_a.return_annotation, subs)) == _Foo + + # Chaining: second unify with shared subs must not break + subs2 = unify(sig_b, sig_b.bind(_Bar), subs) + assert canonicalize(substitute(sig_b.return_annotation, subs2)) == _Bar + + +# --- classmethod with Self: cls is stripped, Self stays unresolved --- + + +def test_classmethod_self_not_resolved(): + """Classmethod Self stays unresolved when cls is stripped. + + inspect.signature strips `cls`, so `from_config(config: int) -> Self` + has no unannotated first parameter. The Self TypeVar is created but + nothing binds it, so it remains free in the substitution result. + """ + sig = inspect.signature(_Foo.from_config) + subs = unify(sig, sig.bind(int)) + result = substitute(sig.return_annotation, subs) + # Self was replaced with a TypeVar, but nothing bound it. + assert isinstance(result, typing.TypeVar) + assert result.__name__ == "Self" + + +# --- composition tests with Self --- + + +@pytest.mark.parametrize("obj_type", [_Foo, _Bar, int, str]) +def test_infer_self_composition_1(obj_type): + """Compose return_list_self -> get_first, verify matches return_self. + + Step 1: return_list_self(obj) -> list[Self] (Self method) + Step 2: get_first(list[T]) -> T (generic function) + Direct: return_self(obj) -> Self + """ + sig1 = inspect.signature(_Foo.return_list_self) + sig2 = inspect.signature(get_first) + sig_direct = inspect.signature(_Foo.return_self) + + # Step 1: infer list[Self] with Self bound to obj_type + inferred_type1 = substitute( + sig1.return_annotation, + unify(sig1, sig1.bind(obj_type)), + ) + + # Step 2: get_first(list[obj_type]) -> obj_type + inferred_type2 = substitute( + sig2.return_annotation, + unify(sig2, sig2.bind(nested_type(Box(inferred_type1)).value)), + ) + + # Direct: return_self(obj_type) -> obj_type + inferred_direct = substitute( + sig_direct.return_annotation, + unify(sig_direct, sig_direct.bind(obj_type)), + ) + + # The composed inference should match the direct inference + assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping) + + +@pytest.mark.parametrize("obj_type", [_Foo, _Bar, int, str]) +def test_infer_self_composition_2(obj_type): + """Compose identity -> return_list_self, verify matches wrap_in_list. + + Step 1: identity(x: T) -> T (generic function) + Step 2: return_list_self(self) -> list[Self] (Self method) + Direct: wrap_in_list(x: T) -> list[T] + """ + sig1 = inspect.signature(identity) + sig2 = inspect.signature(_Foo.return_list_self) + sig_direct = inspect.signature(wrap_in_list) + + # Step 1: identity(obj_type) -> obj_type + inferred_type1 = substitute( + sig1.return_annotation, + unify(sig1, sig1.bind(obj_type)), + ) + + # Step 2: return_list_self(obj_type) -> list[obj_type] + inferred_type2 = substitute( + sig2.return_annotation, + unify(sig2, sig2.bind(nested_type(Box(inferred_type1)).value)), + ) + + # Direct: wrap_in_list(obj_type) -> list[obj_type] + inferred_direct = substitute( + sig_direct.return_annotation, + unify(sig_direct, sig_direct.bind(obj_type)), + ) + + # The composed inference should match the direct inference + assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping) diff --git a/tests/test_ops_semantics.py b/tests/test_ops_semantics.py index 79b806633..82a900a41 100644 --- a/tests/test_ops_semantics.py +++ b/tests/test_ops_semantics.py @@ -2,6 +2,7 @@ import functools import itertools import logging +import typing from collections.abc import Callable, Mapping from typing import Annotated, Any, Literal, Union @@ -863,3 +864,86 @@ def get_mixed() -> Literal[1, "a"]: with pytest.raises(TypeError, match="Union types are not supported"): typeof(get_mixed()) + + +# --- Module-level classes for typing.Self tests --- +# Must be at module level so @defop can resolve annotations via get_type_hints. + + +class _SelfA: + @defop + def ret_self(self, x: int) -> typing.Self: + raise NotHandled + + @defop + def ret_list_self(self, x: int) -> list[typing.Self]: + raise NotHandled + + @defop + def annotated_self(self: typing.Self, x: int) -> typing.Self: + raise NotHandled + + @defop + def ret_self_or_none(self, x: int) -> typing.Self | None: + raise NotHandled + + +class _SelfB: + @defop + def ret_self(self, x: int) -> typing.Self: + raise NotHandled + + +def test_typeof_self_basic(): + """typeof resolves typing.Self to the type of the first argument.""" + obj = _SelfA() + assert typeof(_SelfA.ret_self(obj, 42)) is _SelfA + + +def test_typeof_self_list(): + """typeof resolves list[Self] to list (origin type).""" + obj = _SelfA() + assert typeof(_SelfA.ret_list_self(obj, 42)) is list + + +def test_typeof_self_annotated_param(): + """Self as both the self-parameter annotation and return type.""" + obj = _SelfA() + assert typeof(_SelfA.annotated_self(obj, 42)) is _SelfA + + +def test_typeof_self_two_classes(): + """Self resolves independently per class.""" + a, b = _SelfA(), _SelfB() + assert typeof(_SelfA.ret_self(a, 42)) is _SelfA + assert typeof(_SelfB.ret_self(b, 42)) is _SelfB + + +def test_typeof_self_nested_polymorphic(): + """Self composes with a polymorphic identity operation.""" + + @defop + def identity[T](x: T) -> T: + raise NotHandled + + obj = _SelfA() + assert typeof(identity(_SelfA.ret_self(obj, 42))) is _SelfA + + +@pytest.mark.xfail(reason="Union types are not yet supported") +def test_typeof_self_union(): + """Self | None is a union return type — unsupported.""" + obj = _SelfA() + typeof(_SelfA.ret_self_or_none(obj, 42)) + + +class _SelfClassmethod: + @defop + @classmethod + def cls_ret(cls) -> typing.Self: + raise NotHandled + + +def test_typeof_self_classmethod(): + """Classmethod with Self — cls is stripped so Self is unresolved but does not crash.""" + typeof(_SelfClassmethod.cls_ret()) From 7bd874c495d6670338c72045abac002d9a59d032 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 5 May 2026 17:34:48 -0400 Subject: [PATCH 2/9] checkpoint --- effectful/internals/unification.py | 79 ++++++++++++++++++++++++++++++ effectful/ops/types.py | 26 ++-------- 2 files changed, 82 insertions(+), 23 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 48f748718..361aa4c35 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -613,6 +613,85 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.") +def infer_return_type(bound_sig: inspect.BoundArguments) -> TypeExpressions: + """Infer the return type of a function from its bound arguments.""" + typ = _replace_self( + _freshen(_sig_to_type(bound_sig.signature)), typing.TypeVar("Self") + ) + subtyp = _freshen(_bound_sig_to_type(bound_sig)) + return substitute(bound_sig.signature.return_annotation, unify(typ, subtyp)) + + +def _sig_to_type(sig: inspect.Signature) -> TypeExpression: + """Convert an inspect.Signature to a type expression.""" + if sig.return_annotation is inspect.Parameter.empty: + return _sig_to_type(sig.replace(return_annotation=typing.Any)) + elif sig.return_annotation is None: + return _sig_to_type(sig.replace(return_annotation=type(None))) + elif typing.get_origin(sig.return_annotation) is typing.Annotated: + return _sig_to_type( + sig.replace(return_annotation=typing.get_args(sig.return_annotation)[0]) + ) + + annotations: dict[str, TypeExpressions] = { + "return": typing.NotRequired[sig.return_annotation] + } + for name, param in sig.parameters.items(): + if param.kind == inspect.Parameter.VAR_POSITIONAL: + annotations[name] = tuple[param.annotation, ...] + elif param.kind == inspect.Parameter.VAR_KEYWORD: + annotations[name] = dict[str, param.annotation] + elif param.kind not in { + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + }: + annotations[name] = param.annotation + + return typing.TypedDict(f"{sig}_Type", annotations) + + +def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: + """Convert an inspect.BoundArguments to a type expression for unification.""" + sig: inspect.Signature = bound_sig.signature + typed_arguments = sig.bind( + *[nested_type(a).value for a in bound_sig.args], + **{k: nested_type(v).value for k, v in bound_sig.kwargs.items()}, + ).arguments + + annotations: dict[str, TypeExpressions] = {} + for name, param in sig.parameters.items(): + if param.annotation is inspect.Parameter.empty: + continue + + if name not in typed_arguments: + assert param.kind in { + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + } + continue + + psubtyp = typed_arguments[name] + if param.kind == inspect.Parameter.VAR_POSITIONAL and isinstance( + psubtyp, collections.abc.Sequence + ): + annotations[name] = tuple[psubtyp] + elif param.kind == inspect.Parameter.VAR_KEYWORD and isinstance( + psubtyp, collections.abc.Mapping + ): + annotations[name] = typing.TypedDict(f"{name}BoundKwargs", psubtyp) + elif param.kind not in { + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + } or isinstance(psubtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs): + annotations[name] = psubtyp + else: + raise TypeError( + f"Cannot unify parameter {param} with argument {psubtyp} in signature unification." + ) + + return typing.TypedDict("BoundSigType", annotations) + + def _unify_signature( typ: inspect.Signature, subtyp: inspect.BoundArguments, subs: Substitutions ) -> Substitutions: diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 1591e6624..91d0694d2 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -379,30 +379,10 @@ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]: allows for terms that compute on type-valued arguments. """ - from effectful.internals.unification import ( - _has_typing_self, - freetypevars, - nested_type, - substitute, - unify, - ) + from effectful.internals.unification import infer_return_type - return_anno = self.__signature__.return_annotation - if typing.get_origin(return_anno) is typing.Annotated: - return_anno = typing.get_args(return_anno)[0] - - if return_anno is inspect.Parameter.empty: - return typing.cast(type[V], object) - elif return_anno is None: - return type(None) # type: ignore - elif not freetypevars(return_anno) and not _has_typing_self(return_anno): - return return_anno - - type_args = tuple(nested_type(a).value for a in args) - type_kwargs = {k: nested_type(v).value for k, v in kwargs.items()} - bound_sig = self.__signature__.bind(*type_args, **type_kwargs) - subst_type = substitute(return_anno, unify(self.__signature__, bound_sig)) - return typing.cast(type[V], subst_type) + bound_sig = self.__signature__.bind(*args, **kwargs) + return typing.cast(type[V], infer_return_type(bound_sig)) @typing.final def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> inspect.BoundArguments: From 1a8593840b39ba967e1e75f2505c68364358de44 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 5 May 2026 19:36:14 -0400 Subject: [PATCH 3/9] checkpoint, tests pass --- effectful/internals/unification.py | 316 +++++++------- tests/test_internals_unification.py | 634 +++++++++++++++------------- tests/test_ops_semantics.py | 8 +- 3 files changed, 501 insertions(+), 457 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 361aa4c35..9fa46cae8 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -9,7 +9,7 @@ 1. **unify(typ, subtyp, subs={})**: The core unification algorithm that attempts to find a substitution mapping for type variables that makes a pattern type equal to a concrete type. It handles TypeVars, generic types (List[T], Dict[K,V]), unions, - callables, and function signatures with inspect.Signature/BoundArguments. + TypedDicts, and structural collections. 2. **substitute(typ, subs)**: Applies a substitution mapping to a type expression, replacing all TypeVars with their mapped concrete types. This is used to @@ -23,13 +23,19 @@ collections by recursively determining element types. For example, [1, 2, 3] becomes list[int], and {"key": [1, 2]} becomes dict[str, list[int]]. +Function signatures are not first-class to ``unify``. They are converted to +TypedDict patterns via :func:`_sig_to_type` (with ``typing.Self`` eliminated by +:class:`SelfTypeReplacer` at the boundary), and the resulting TypedDicts are +unified through the regular path. :func:`infer_return_type` is the entry point +that ties this together for callable invocations. + The unification algorithm uses a single-dispatch pattern to handle different type combinations: - TypeVar unification binds variables to concrete types - Generic type unification matches origins and recursively unifies type arguments - Structural unification handles sequences and mappings by element - Union types attempt unification with any matching branch -- Function signatures unify parameter types with bound arguments +- TypedDict unification matches fields, with NotRequired / ReadOnly semantics Example usage: >>> from effectful.internals.unification import unify, substitute, freetypevars @@ -60,13 +66,13 @@ import builtins import collections import collections.abc +import dataclasses import functools import inspect import numbers import operator import types import typing -from dataclasses import dataclass try: from typing import _collect_type_parameters as _freetypevars # type: ignore @@ -94,61 +100,7 @@ Substitutions = collections.abc.Mapping[TypeVariable, TypeExpressions] -def _has_typing_self(typ) -> bool: - """Check if typing.Self appears anywhere in a type expression.""" - if typ is typing.Self: - return True - if isinstance(typ, inspect.Signature): - if _has_typing_self(typ.return_annotation): - return True - for p in typ.parameters.values(): - if _has_typing_self(p.annotation): - return True - return False - if isinstance(typ, list): - return any(_has_typing_self(item) for item in typ) - for arg in typing.get_args(typ): - if _has_typing_self(arg): - return True - return False - - -def _replace_self(typ, self_tv: typing.TypeVar): - """Replace all occurrences of typing.Self with the given TypeVar.""" - if typ is typing.Self: - return self_tv - elif isinstance(typ, inspect.Signature): - new_params = [] - for i, p in enumerate(typ.parameters.values()): - if _has_typing_self(p.annotation): - new_params.append( - p.replace(annotation=_replace_self(p.annotation, self_tv)) - ) - elif i == 0 and p.annotation is inspect.Parameter.empty: - new_params.append(p.replace(annotation=self_tv)) - else: - new_params.append(p) - new_ret = ( - _replace_self(typ.return_annotation, self_tv) - if _has_typing_self(typ.return_annotation) - else typ.return_annotation - ) - return typ.replace(parameters=new_params, return_annotation=new_ret) - elif isinstance(typ, list): - return [_replace_self(item, self_tv) for item in typ] - args = typing.get_args(typ) - if not args: - return typ - new_args = tuple(_replace_self(a, self_tv) for a in args) - if new_args == args: - return typ - origin = typing.get_origin(typ) - if origin is not None: - return origin[new_args] - return typ - - -@dataclass +@dataclasses.dataclass class Box[T]: """Boxed types. Prevents confusion between types computed by __type_rule__ and values. @@ -238,24 +190,14 @@ def _(self, typ: typing.ForwardRef): else: return typ - -@typing.overload -def unify( - typ: inspect.Signature, - subtyp: inspect.BoundArguments, - subs: Substitutions = {}, -) -> Substitutions: ... + @evaluate.register # type: ignore[arg-type] + def _(self, typ: typing._SpecialForm) -> TypeExpressions: # type: ignore[type-arg] + return typing.cast(TypeExpressions, typ) -@typing.overload def unify( - typ: TypeExpressions, - subtyp: TypeExpressions, - subs: Substitutions = {}, -) -> Substitutions: ... - - -def unify(typ, subtyp, subs: Substitutions = {}) -> Substitutions: + typ: TypeExpressions, subtyp: TypeExpressions, subs: Substitutions = {} +) -> Substitutions: """ Unify a pattern type with a concrete type, returning a substitution map. @@ -313,9 +255,6 @@ def unify(typ, subtyp, subs: Substitutions = {}) -> Substitutions: ... TypeError: Cannot unify ... """ - if isinstance(typ, inspect.Signature): - return _unify_signature(typ, subtyp, subs) - if typ != canonicalize(typ) or subtyp != canonicalize(subtyp): return unify(canonicalize(typ), canonicalize(subtyp), subs) @@ -613,19 +552,66 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.") +@dataclasses.dataclass +class SelfTypeReplacer(TypeEvaluator): + """Replace ``typing.Self`` with a TypeVar throughout a type expression. + + Used during signature-to-TypedDict conversion in :func:`_sig_to_type` so + that ``typing.Self`` is eliminated at the boundary and never appears in + downstream unification machinery (``unify``, ``substitute``, ``_freshen``, + ``freetypevars``). + """ + + tv: typing.TypeVar + + def evaluate(self, typ) -> TypeExpressions: + if typ is typing.Self: + return self.tv + return super().evaluate(typ) + + def infer_return_type(bound_sig: inspect.BoundArguments) -> TypeExpressions: - """Infer the return type of a function from its bound arguments.""" - typ = _replace_self( - _freshen(_sig_to_type(bound_sig.signature)), typing.TypeVar("Self") - ) - subtyp = _freshen(_bound_sig_to_type(bound_sig)) - return substitute(bound_sig.signature.return_annotation, unify(typ, subtyp)) + """Infer the return type of a function from its bound arguments. + + The signature is converted to a TypedDict-shaped pattern via + :func:`_sig_to_type`; the bound arguments are converted via + :func:`_bound_sig_to_type`; the two are unified using the regular TypedDict + unification path; and the freshened return type from the pattern is + substituted with the resulting bindings. + + ``typing.Self`` is eliminated by :func:`_sig_to_type` (via + :class:`SelfTypeReplacer`) — by the time we reach unification, no + ``typing.Self`` references remain anywhere. + """ + pattern = _sig_to_type(bound_sig.signature) + bound = _freshen(_bound_sig_to_type(bound_sig)) + return_type = _get_typeddict_hints(pattern)["return"] + return substitute(return_type, unify(pattern, bound)) def _sig_to_type(sig: inspect.Signature) -> TypeExpression: - """Convert an inspect.Signature to a type expression.""" + """Convert an :class:`inspect.Signature` to a TypedDict pattern for unification. + + .. note:: + + The precise encoding of signatures as TypedDicts is an implementation + detail of the unification machinery and is subject to change. Callers + outside this module should treat the returned TypedDict as opaque and + interact with it only through :func:`unify` / :func:`infer_return_type`. + In particular: the choice of field names (e.g. ``"return"``), wrapping + order (``NotRequired`` vs ``ReadOnly``), variadic encoding, and which + parameters are omitted may all change without notice. + + The current encoding maps each annotated parameter to a TypedDict field and + the return annotation to a ``"return"`` field. ``typing.Self`` is replaced + by a fresh ``TypeVar`` via :class:`SelfTypeReplacer` so that no Self leaks + into the unification path. Fields the bound side may omit (default-valued + params, ``*args``, ``**kwargs``, ``"return"``) are wrapped in + ``NotRequired``; all fields are also wrapped in ``ReadOnly`` to give the + TypedDict unification path covariant semantics for parameter types. + """ if sig.return_annotation is inspect.Parameter.empty: - return _sig_to_type(sig.replace(return_annotation=typing.Any)) + return _sig_to_type(sig.replace(return_annotation=object)) elif sig.return_annotation is None: return _sig_to_type(sig.replace(return_annotation=type(None))) elif typing.get_origin(sig.return_annotation) is typing.Annotated: @@ -633,25 +619,49 @@ def _sig_to_type(sig: inspect.Signature) -> TypeExpression: sig.replace(return_annotation=typing.get_args(sig.return_annotation)[0]) ) + replacer = SelfTypeReplacer(typing.TypeVar("Self")) # type: ignore[misc] + annotations: dict[str, TypeExpressions] = { - "return": typing.NotRequired[sig.return_annotation] + "return": typing.NotRequired[ + typing.ReadOnly[replacer.evaluate(sig.return_annotation)] + ] # type: ignore[assignment] } for name, param in sig.parameters.items(): + if param.annotation is inspect.Parameter.empty: + continue + ann = replacer.evaluate(param.annotation) if param.kind == inspect.Parameter.VAR_POSITIONAL: - annotations[name] = tuple[param.annotation, ...] + field: TypeExpressions = typing.NotRequired[ + typing.ReadOnly[tuple[ann, ...]] + ] # type: ignore[assignment] elif param.kind == inspect.Parameter.VAR_KEYWORD: - annotations[name] = dict[str, param.annotation] - elif param.kind not in { - inspect.Parameter.VAR_KEYWORD, - inspect.Parameter.VAR_POSITIONAL, - }: - annotations[name] = param.annotation + field = typing.NotRequired[typing.ReadOnly[dict[str, ann]]] # type: ignore[assignment] + elif param.default is not inspect.Parameter.empty: + field = typing.NotRequired[typing.ReadOnly[ann]] # type: ignore[assignment] + else: + field = typing.ReadOnly[ann] # type: ignore[assignment] + annotations[name] = field return typing.TypedDict(f"{sig}_Type", annotations) def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: - """Convert an inspect.BoundArguments to a type expression for unification.""" + """Convert an :class:`inspect.BoundArguments` to a TypedDict subtype. + + .. note:: + + Like :func:`_sig_to_type`, the precise encoding here is an + implementation detail of the unification machinery and is subject to + change. The two functions are designed in tandem to produce TypedDicts + that unify against each other; treat the output as opaque. + + The current encoding maps each bound argument to a field whose value is + its runtime type (from :func:`nested_type`). Variadic positional/keyword + arguments are encoded as a fixed-arity tuple / nested TypedDict + respectively. Parameters that were not bound (defaults, unfilled varargs) + are simply omitted; this aligns with the ``NotRequired`` fields produced + by :func:`_sig_to_type`. + """ sig: inspect.Signature = bound_sig.signature typed_arguments = sig.bind( *[nested_type(a).value for a in bound_sig.args], @@ -664,10 +674,14 @@ def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: continue if name not in typed_arguments: - assert param.kind in { - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - } + assert ( + param.kind + in { + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + } + or param.default is not inspect.Parameter.empty + ) continue psubtyp = typed_arguments[name] @@ -692,64 +706,16 @@ def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: return typing.TypedDict("BoundSigType", annotations) -def _unify_signature( - typ: inspect.Signature, subtyp: inspect.BoundArguments, subs: Substitutions -) -> Substitutions: - if typ != subtyp.signature: - raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}. ") - - if _has_typing_self(typ): - self_tv = typing.TypeVar("Self") - typ = _replace_self(typ, self_tv) - subtyp = typ.bind(*subtyp.args, **subtyp.kwargs) - return {**unify(typ, subtyp, subs), typing.Self: self_tv} # type: ignore - - for name, param in typ.parameters.items(): - if param.annotation is inspect.Parameter.empty: - continue - - if name not in subtyp.arguments: - assert param.kind in { - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - } - continue - - ptyp, psubtyp = param.annotation, subtyp.arguments[name] - if param.kind is inspect.Parameter.VAR_POSITIONAL and isinstance( - psubtyp, collections.abc.Sequence - ): - for psubtyp_item in _freshen(psubtyp): - subs = unify(ptyp, psubtyp_item, subs) - elif param.kind is inspect.Parameter.VAR_KEYWORD and isinstance( - psubtyp, collections.abc.Mapping - ): - for psubtyp_item in _freshen(tuple(psubtyp.values())): - subs = unify(ptyp, psubtyp_item, subs) - elif param.kind not in { - inspect.Parameter.VAR_KEYWORD, - inspect.Parameter.VAR_POSITIONAL, - } or isinstance(psubtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs): - subs = unify(ptyp, _freshen(psubtyp), subs) - else: - raise TypeError(f"Cannot unify {param} with {psubtyp} given {subs}") - return subs - - def _freshen(tp: typing.Any): """ Return a freshened version of the given type expression. - This function replaces all TypeVars in the type expression with new TypeVars - that have unique names, ensuring that the resulting type has no free TypeVars. - It is useful for creating fresh type variables in generic programming contexts. + Replaces all free :class:`typing.TypeVar` / :class:`typing.ParamSpec` + occurrences with new TypeVars/ParamSpecs of the same name, isolating type + variables across independent unification calls so they can't accidentally + collide. - Args: - tp: The type expression to freshen. Can be a plain type, TypeVar, - generic alias, or union type. - - Returns: - A new type expression with all TypeVars replaced by fresh TypeVars. + Traverses ``TypeExpression`` and ``TypedDict`` shapes. Examples: >>> import typing @@ -759,12 +725,13 @@ def _freshen(tp: typing.Any): >>> _freshen(T) == T False """ - assert all(canonicalize(fv) is fv for fv in freetypevars(tp)) + fvs = freetypevars(tp) + assert all(canonicalize(fv) is fv for fv in fvs) subs: Substitutions = { fv: typing.TypeVar(fv.__name__, bound=fv.__bound__) if isinstance(fv, typing.TypeVar) else typing.ParamSpec(fv.__name__) - for fv in freetypevars(tp) + for fv in fvs if isinstance(fv, typing.TypeVar | typing.ParamSpec) } return substitute(tp, subs) @@ -799,15 +766,19 @@ def _(typ: type | abc.ABCMeta): # Idempotency: if all field types are already canonical, return same object if all(canonicalize(h) == h for h in hints.values()): return typ - # Otherwise, create a fresh TypedDict with canonicalized field types + # Otherwise, create a fresh TypedDict with canonicalized field types, + # preserving NotRequired-ness and ReadOnly-ness via __optional_keys__ + # and __readonly_keys__. optional_keys: frozenset[str] = getattr(typ, "__optional_keys__", frozenset()) + readonly_keys: frozenset[str] = getattr(typ, "__readonly_keys__", frozenset()) canon_fields: dict[str, type] = {} for field, ftype in hints.items(): ct = canonicalize(ftype) + if field in readonly_keys: + ct = typing.ReadOnly[ct] # type: ignore[assignment] if field in optional_keys: - canon_fields[field] = typing.NotRequired[ct] # type: ignore[assignment] - else: - canon_fields[field] = ct # type: ignore[assignment] + ct = typing.NotRequired[ct] # type: ignore[assignment] + canon_fields[field] = ct # type: ignore[assignment] return typing.TypedDict(typ.__name__, canon_fields) # type: ignore[operator] elif typ is types.GeneratorType: return collections.abc.Generator @@ -1245,6 +1216,11 @@ def freetypevars(typ) -> collections.abc.Set[TypeVariable]: >>> freetypevars(dict[str, T]) {~T} """ + if _is_typeddict_type(typ): + result: set[TypeVariable] = set() + for ftype in _get_typeddict_hints(typ).values(): + result |= freetypevars(ftype) + return result return set(_freetypevars((typ,))) @@ -1293,15 +1269,37 @@ def substitute(typ, subs: Substitutions) -> TypeExpressions: >>> substitute(int, {T: str}) """ - if typing.Self in subs and _has_typing_self(typ): - return substitute(_replace_self(typ, subs[typing.Self]), subs) - if isinstance(typ, typing.TypeVar | typing.ParamSpec | typing.TypeVarTuple): return substitute(subs[typ], subs) if typ in subs else typ elif isinstance(typ, list | tuple): return type(typ)(substitute(item, subs) for item in typ) + elif _is_typeddict_type(typ): + return _substitute_typeddict(typ, subs) elif any(fv in subs for fv in freetypevars(typ)): args = tuple(subs.get(fv, fv) for fv in _freetypevars((typ,))) return substitute(typ[args], subs) else: return typ + + +def _substitute_typeddict(typ, subs: Substitutions) -> TypeExpressions: + """Substitute inside a TypedDict's fields, preserving NotRequired-ness.""" + typ_origin = typing.get_origin(typ) or typ + optional_keys: frozenset[str] = getattr( + typ_origin, "__optional_keys__", frozenset() + ) + hints = _get_typeddict_hints(typ) + new_fields: dict[str, TypeExpressions] = {} + changed = False + for field, ftype in hints.items(): + new_ftype = substitute(ftype, subs) + if new_ftype is not ftype: + changed = True + if field in optional_keys: + new_fields[field] = typing.NotRequired[new_ftype] # type: ignore[assignment] + else: + new_fields[field] = new_ftype + if not changed: + return typ + name = getattr(typ_origin, "__name__", "_Subbed") + return typing.TypedDict(name, new_fields) # type: ignore[operator,misc] diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 0f470cdd8..3f44afe28 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -13,9 +13,9 @@ TypeEvaluator, TypeExpressions, TypeVariable, - _has_typing_self, canonicalize, freetypevars, + infer_return_type, nested_type, substitute, unify, @@ -592,32 +592,54 @@ def variadic_kwargs_func[T](**kwargs: T) -> T: # Variadic kwargs not supported return next(iter(kwargs.values())) +def variadic_args_const_return(*args: object) -> int: + return len(args) + + +def default_param_func[T](x: T, y: str = "ok") -> T: + return x + + +def returns_none(x: int) -> None: + return None + + +def returns_annotated[T](x: T) -> typing.Annotated[T, "meta"]: + return x + + +def no_return_annotation_func(x: int): # type: ignore[no-untyped-def] + return x + + class _Foo: - def return_self(self) -> typing.Self: + def return_self(self: typing.Self) -> typing.Self: return self - def return_list_self(self) -> list[typing.Self]: + def return_list_self(self: typing.Self) -> list[typing.Self]: return [self] - def return_self_or_none(self) -> typing.Self | None: + def return_self_or_none(self: typing.Self) -> typing.Self | None: return self def annotated_self(self: typing.Self) -> typing.Self: return self - def takes_other(self, other: typing.Self) -> typing.Self: + def takes_other(self: typing.Self, other: typing.Self) -> typing.Self: return self - def mixed_with_typevar[T](self, x: T) -> tuple[typing.Self, T]: + def mixed_with_typevar[T](self: typing.Self, x: T) -> tuple[typing.Self, T]: return (self, x) - def return_dict_self(self) -> dict[str, typing.Self]: + def return_dict_self(self: typing.Self) -> dict[str, typing.Self]: return {"me": self} - def return_callable_self(self) -> collections.abc.Callable[[typing.Self], int]: + def return_callable_self( + self: typing.Self, + ) -> collections.abc.Callable[[typing.Self], int]: return id - def return_type_self(self) -> type[typing.Self]: + def return_type_self(self: typing.Self) -> type[typing.Self]: return type(self) @classmethod @@ -626,7 +648,7 @@ def from_config(cls, config: int) -> typing.Self: class _Bar: - def return_self(self) -> typing.Self: + def return_self(self: typing.Self) -> typing.Self: return self @@ -707,43 +729,93 @@ def return_self(self) -> typing.Self: (variadic_args_func, (int, int), {}, int), (variadic_kwargs_func, (), {"x": int}, int), (variadic_kwargs_func, (), {"x": int, "y": int}, int), - # typing.Self return types (methods) - (_Foo.return_self, (int,), {}, int), - (_Foo.return_self, (str,), {}, str), - (_Foo.return_self, (_Foo,), {}, _Foo), - (_Foo.return_self, (_Bar,), {}, _Bar), - (_Bar.return_self, (_Bar,), {}, _Bar), - (_Foo.annotated_self, (_Foo,), {}, _Foo), - (_Foo.return_list_self, (int,), {}, list[int]), - (_Foo.return_list_self, (_Foo,), {}, list[_Foo]), - (_Foo.return_self_or_none, (int,), {}, int | None), - (_Foo.return_self_or_none, (_Foo,), {}, _Foo | None), + # zero-arg variadic positional/keyword (return type doesn't depend on T) + (variadic_args_const_return, (), {}, int), + # default-valued param: omitted vs provided + (default_param_func, (int,), {}, int), + (default_param_func, (str,), {"y": str}, str), + # special return-annotation forms + (returns_none, (int,), {}, type(None)), + (returns_annotated, (int,), {}, int), + (no_return_annotation_func, (int,), {}, object), + ], +) +def test_infer_return_type_success( + func: collections.abc.Callable, + args: tuple, + kwargs: dict, + expected_return_type: type, +): + # Args here are type expressions, not values; box them so nested_type + # treats them as already-typed (per Box's intended purpose). + sig = inspect.signature(func) + bound = sig.bind( + *[Box(a) for a in args], + **{k: Box(v) for k, v in kwargs.items()}, + ) + result = infer_return_type(bound) + assert canonicalize(result) == canonicalize(expected_return_type) + + +@pytest.mark.parametrize( + "func,args,kwargs,expected_return_type", + [ + # typing.Self return types (methods). Args are wrapped in Box so + # nested_type treats them as already-typed values. + (_Foo.return_self, (Box(int),), {}, int), + (_Foo.return_self, (Box(str),), {}, str), + (_Foo.return_self, (Box(_Foo),), {}, _Foo), + (_Foo.return_self, (Box(_Bar),), {}, _Bar), + (_Bar.return_self, (Box(_Bar),), {}, _Bar), + (_Foo.annotated_self, (Box(_Foo),), {}, _Foo), + (_Foo.return_list_self, (Box(int),), {}, list[int]), + (_Foo.return_list_self, (Box(_Foo),), {}, list[_Foo]), + (_Foo.return_self_or_none, (Box(int),), {}, int | None), + (_Foo.return_self_or_none, (Box(_Foo),), {}, _Foo | None), # Self as a non-self parameter - (_Foo.takes_other, (int, int), {}, int), - (_Foo.takes_other, (_Foo, _Foo), {}, _Foo), + (_Foo.takes_other, (Box(int), Box(int)), {}, int), + (_Foo.takes_other, (Box(_Foo), Box(_Foo)), {}, _Foo), # Self mixed with other TypeVars - (_Foo.mixed_with_typevar, (int, str), {}, tuple[int, str]), - (_Foo.mixed_with_typevar, (_Foo, list[int]), {}, tuple[_Foo, list[int]]), + (_Foo.mixed_with_typevar, (Box(int), Box(str)), {}, tuple[int, str]), + ( + _Foo.mixed_with_typevar, + (Box(_Foo), Box(list[int])), + {}, + tuple[_Foo, list[int]], + ), # Self in dict[str, Self] - (_Foo.return_dict_self, (int,), {}, dict[str, int]), - (_Foo.return_dict_self, (_Foo,), {}, dict[str, _Foo]), + (_Foo.return_dict_self, (Box(int),), {}, dict[str, int]), + (_Foo.return_dict_self, (Box(_Foo),), {}, dict[str, _Foo]), # Self inside Callable[[Self], int] - (_Foo.return_callable_self, (int,), {}, collections.abc.Callable[[int], int]), - (_Foo.return_callable_self, (_Foo,), {}, collections.abc.Callable[[_Foo], int]), + ( + _Foo.return_callable_self, + (Box(int),), + {}, + collections.abc.Callable[[int], int], + ), + ( + _Foo.return_callable_self, + (Box(_Foo),), + {}, + collections.abc.Callable[[_Foo], int], + ), # type[Self] - (_Foo.return_type_self, (int,), {}, type[int]), - (_Foo.return_type_self, (_Foo,), {}, type[_Foo]), + (_Foo.return_type_self, (Box(int),), {}, type[int]), + (_Foo.return_type_self, (Box(_Foo),), {}, type[_Foo]), ], ) -def test_infer_return_type_success( +def test_infer_return_type_self_success( func: collections.abc.Callable, args: tuple, kwargs: dict, expected_return_type: type, ): + """Self resolution goes through infer_return_type (the Self TypeVar is + internal to the conversion in _sig_to_type and never appears in the result + subs that the user could pass to substitute).""" sig = inspect.signature(func) bound = sig.bind(*args, **kwargs) - result = substitute(sig.return_annotation, unify(sig, bound)) + result = infer_return_type(bound) assert canonicalize(result) == canonicalize(expected_return_type) @@ -773,9 +845,12 @@ def test_infer_return_type_failure( kwargs: dict, ): sig = inspect.signature(func) - bound = sig.bind(*args, **kwargs) + bound = sig.bind( + *[Box(a) for a in args], + **{k: Box(v) for k, v in kwargs.items()}, + ) with pytest.raises(TypeError): - unify(sig, bound) + infer_return_type(bound) @pytest.mark.parametrize( @@ -1256,28 +1331,11 @@ def test_infer_composition_1(seq, index, key): sig12 = inspect.signature(sequence_mapping_getitem) - inferred_type1 = substitute( - sig1.return_annotation, - unify(sig1, sig1.bind(nested_type(seq).value, nested_type(index).value)), - ) + inferred_type1 = infer_return_type(sig1.bind(seq, index)) - inferred_type2 = substitute( - sig2.return_annotation, - unify( - sig2, - sig2.bind(nested_type(Box(inferred_type1)).value, nested_type(key).value), - ), - ) + inferred_type2 = infer_return_type(sig2.bind(Box(inferred_type1), key)) - inferred_type12 = substitute( - sig12.return_annotation, - unify( - sig12, - sig12.bind( - nested_type(seq).value, nested_type(index).value, nested_type(key).value - ), - ), - ) + inferred_type12 = infer_return_type(sig12.bind(seq, index, key)) # check that the composed inference matches the direct inference assert isinstance(unify(inferred_type2, inferred_type12), collections.abc.Mapping) @@ -1388,32 +1446,13 @@ def test_infer_composition_2(mapping, key, index): sig12 = inspect.signature(mapping_sequence_getitem) # First infer type of mapping_getitem(mapping, key) -> should be a sequence - inferred_type1 = substitute( - sig1.return_annotation, - unify(sig1, sig1.bind(nested_type(mapping).value, nested_type(key).value)), - ) + inferred_type1 = infer_return_type(sig1.bind(mapping, key)) # Then infer type of sequence_getitem(result_from_step1, index) -> should be element type - inferred_type2 = substitute( - sig2.return_annotation, - unify( - sig2, - sig2.bind(nested_type(Box(inferred_type1)).value, nested_type(index).value), - ), - ) + inferred_type2 = infer_return_type(sig2.bind(Box(inferred_type1), index)) # Directly infer type of mapping_sequence_getitem(mapping, key, index) - inferred_type12 = substitute( - sig12.return_annotation, - unify( - sig12, - sig12.bind( - nested_type(mapping).value, - nested_type(key).value, - nested_type(index).value, - ), - ), - ) + inferred_type12 = infer_return_type(sig12.bind(mapping, key, index)) # The composed inference should match the direct inference assert isinstance(unify(inferred_type2, inferred_type12), collections.abc.Mapping) @@ -1453,25 +1492,15 @@ def test_get_from_constructed_sequence(a, b, index): sig_composed = inspect.signature(get_from_constructed_sequence) # Infer type of sequence_from_pair(a, b) -> Sequence[T] - construct_subs = unify( - sig_construct, sig_construct.bind(nested_type(a).value, nested_type(b).value) - ) - inferred_sequence_type = substitute(sig_construct.return_annotation, construct_subs) + inferred_sequence_type = infer_return_type(sig_construct.bind(a, b)) # Infer type of sequence_getitem(sequence, index) -> T - getitem_subs = unify( - sig_getitem, sig_getitem.bind(inferred_sequence_type, nested_type(index).value) + inferred_element_type = infer_return_type( + sig_getitem.bind(Box(inferred_sequence_type), index) ) - inferred_element_type = substitute(sig_getitem.return_annotation, getitem_subs) # Directly infer type of get_from_constructed_sequence(a, b, index) - direct_subs = unify( - sig_composed, - sig_composed.bind( - nested_type(a).value, nested_type(b).value, nested_type(index).value - ), - ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) + direct_type = infer_return_type(sig_composed.bind(a, b, index)) # The composed inference should match the direct inference assert isinstance( @@ -1511,29 +1540,15 @@ def test_get_from_constructed_mapping(key, value, lookup_key): sig_composed = inspect.signature(get_from_constructed_mapping) # Infer type of mapping_from_pair(key, value) -> Mapping[K, V] - construct_subs = unify( - sig_construct, - sig_construct.bind(nested_type(key).value, nested_type(value).value), - ) - inferred_mapping_type = substitute(sig_construct.return_annotation, construct_subs) + inferred_mapping_type = infer_return_type(sig_construct.bind(key, value)) # Infer type of mapping_getitem(mapping, lookup_key) -> V - getitem_subs = unify( - sig_getitem, - sig_getitem.bind(inferred_mapping_type, nested_type(lookup_key).value), + inferred_value_type = infer_return_type( + sig_getitem.bind(Box(inferred_mapping_type), lookup_key) ) - inferred_value_type = substitute(sig_getitem.return_annotation, getitem_subs) # Directly infer type of get_from_constructed_mapping(key, value, lookup_key) - direct_subs = unify( - sig_composed, - sig_composed.bind( - nested_type(key).value, - nested_type(value).value, - nested_type(lookup_key).value, - ), - ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) + direct_type = infer_return_type(sig_composed.bind(key, value, lookup_key)) # The composed inference should match the direct inference assert isinstance(unify(inferred_value_type, direct_type), collections.abc.Mapping) @@ -1569,29 +1584,25 @@ def test_sequence_of_mappings(key1, val1, key2, val2, index): sig_composed = inspect.signature(sequence_of_mappings) # Step 1: Infer types of the two mappings - map1_subs = unify( - sig_map, sig_map.bind(nested_type(key1).value, nested_type(val1).value) - ) - map1_type = substitute(sig_map.return_annotation, map1_subs) + map1_type = infer_return_type(sig_map.bind(key1, val1)) # Step 2: Infer type of sequence containing these mappings # We need to unify the two mapping types first unified_map_type = map1_type # Assuming they're compatible - seq_subs = unify(sig_seq, sig_seq.bind(unified_map_type, unified_map_type)) - seq_type = substitute(sig_seq.return_annotation, seq_subs) + seq_type = infer_return_type( + sig_seq.bind(Box(unified_map_type), Box(unified_map_type)) + ) # Direct inference - direct_subs = unify( - sig_composed, + direct_type = infer_return_type( sig_composed.bind( - nested_type(key1).value, - nested_type(val1).value, - nested_type(key2).value, - nested_type(val2).value, - ), + key1, + val1, + key2, + val2, + ) ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) # The types should match assert isinstance(unify(seq_type, direct_type), collections.abc.Mapping) @@ -1621,57 +1632,25 @@ def test_double_nested_get(k1, v1, v2, k2, v3, v4, outer_idx, inner_key, inner_i sig_composed = inspect.signature(double_nested_get) # Step 1: Infer type of nested_sequence_mapping construction - nested_subs = unify( - sig_nested, - sig_nested.bind( - nested_type(k1).value, - nested_type(v1).value, - nested_type(v2).value, - nested_type(k2).value, - nested_type(v3).value, - nested_type(v4).value, - ), - ) - nested_seq_type = substitute(sig_nested.return_annotation, nested_subs) + nested_seq_type = infer_return_type(sig_nested.bind(k1, v1, v2, k2, v3, v4)) # This should be Sequence[Mapping[K, Sequence[T]]] # Step 2: Get element from outer sequence - outer_get_subs = unify( - sig_seq_get, sig_seq_get.bind(nested_seq_type, nested_type(outer_idx).value) - ) - mapping_type = substitute(sig_seq_get.return_annotation, outer_get_subs) + mapping_type = infer_return_type(sig_seq_get.bind(Box(nested_seq_type), outer_idx)) # This should be Mapping[K, Sequence[T]] # Step 3: Get sequence from mapping - inner_map_subs = unify( - sig_map_get, sig_map_get.bind(mapping_type, nested_type(inner_key).value) - ) - sequence_type = substitute(sig_map_get.return_annotation, inner_map_subs) + sequence_type = infer_return_type(sig_map_get.bind(Box(mapping_type), inner_key)) # This should be Sequence[T] # Step 4: Get element from inner sequence - final_get_subs = unify( - sig_seq_get, sig_seq_get.bind(sequence_type, nested_type(inner_idx).value) - ) - composed_type = substitute(sig_seq_get.return_annotation, final_get_subs) + composed_type = infer_return_type(sig_seq_get.bind(Box(sequence_type), inner_idx)) # This should be T # Direct inference on the composed function - direct_subs = unify( - sig_composed, - sig_composed.bind( - nested_type(k1).value, - nested_type(v1).value, - nested_type(v2).value, - nested_type(k2).value, - nested_type(v3).value, - nested_type(v4).value, - nested_type(outer_idx).value, - nested_type(inner_key).value, - nested_type(inner_idx).value, - ), + direct_type = infer_return_type( + sig_composed.bind(k1, v1, v2, k2, v3, v4, outer_idx, inner_key, inner_idx) ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) # The composed inference should match the direct inference assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) @@ -1708,23 +1687,13 @@ def test_apply_to_sequence_element(f, seq, index): sig_composed = inspect.signature(apply_to_sequence_element) # Step 1: Infer type of sequence_getitem(seq, index) -> T - getitem_subs = unify( - sig_getitem, sig_getitem.bind(nested_type(seq).value, nested_type(index).value) - ) - element_type = substitute(sig_getitem.return_annotation, getitem_subs) + element_type = infer_return_type(sig_getitem.bind(seq, index)) # Step 2: Infer type of call_func(f, element) -> U - call_subs = unify(sig_call, sig_call.bind(nested_type(f).value, element_type)) - composed_type = substitute(sig_call.return_annotation, call_subs) + composed_type = infer_return_type(sig_call.bind(f, Box(element_type))) # Direct inference - direct_subs = unify( - sig_composed, - sig_composed.bind( - nested_type(f).value, nested_type(seq).value, nested_type(index).value - ), - ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) + direct_type = infer_return_type(sig_composed.bind(f, seq, index)) # The composed inference should match the direct inference assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) @@ -1753,25 +1722,13 @@ def test_map_and_get(f, seq, index): sig_composed = inspect.signature(map_and_get) # Step 1: Infer type of map_sequence(f, seq) -> Sequence[U] - map_subs = unify( - sig_map, sig_map.bind(nested_type(f).value, nested_type(seq).value) - ) - mapped_type = substitute(sig_map.return_annotation, map_subs) + mapped_type = infer_return_type(sig_map.bind(f, seq)) # Step 2: Infer type of sequence_getitem(mapped_seq, index) -> U - getitem_subs = unify( - sig_getitem, sig_getitem.bind(mapped_type, nested_type(index).value) - ) - composed_type = substitute(sig_getitem.return_annotation, getitem_subs) + composed_type = infer_return_type(sig_getitem.bind(Box(mapped_type), index)) # Direct inference - direct_subs = unify( - sig_composed, - sig_composed.bind( - nested_type(f).value, nested_type(seq).value, nested_type(index).value - ), - ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) + direct_type = infer_return_type(sig_composed.bind(f, seq, index)) # The composed inference should match the direct inference assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) @@ -1800,25 +1757,13 @@ def test_compose_and_apply(f, g, value): sig_composed = inspect.signature(compose_and_apply) # Step 1: Infer type of compose_mappings(f, g) -> Callable[[T], V] - compose_subs = unify( - sig_compose, sig_compose.bind(nested_type(f).value, nested_type(g).value) - ) - composed_func_type = substitute(sig_compose.return_annotation, compose_subs) + composed_func_type = infer_return_type(sig_compose.bind(f, g)) # Step 2: Infer type of call_func(composed, value) -> V - call_subs = unify( - sig_call, sig_call.bind(composed_func_type, nested_type(value).value) - ) - result_type = substitute(sig_call.return_annotation, call_subs) + result_type = infer_return_type(sig_call.bind(Box(composed_func_type), value)) # Direct inference - direct_subs = unify( - sig_composed, - sig_composed.bind( - nested_type(f).value, nested_type(g).value, nested_type(value).value - ), - ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) + direct_type = infer_return_type(sig_composed.bind(f, g, value)) # The composed inference should match the direct inference assert isinstance(unify(result_type, direct_type), collections.abc.Mapping) @@ -1847,29 +1792,22 @@ def test_construct_apply_and_get(f, a, b, index): sig_composed = inspect.signature(construct_apply_and_get) # Step 1: Infer type of sequence_from_pair(a, b) -> Sequence[T] - construct_subs = unify( - sig_construct, sig_construct.bind(nested_type(a).value, nested_type(b).value) - ) - seq_type = substitute(sig_construct.return_annotation, construct_subs) + seq_type = infer_return_type(sig_construct.bind(a, b)) # Step 2: Infer type of apply_to_sequence_element(f, seq, index) -> U - apply_subs = unify( - sig_apply, - sig_apply.bind(nested_type(f).value, seq_type, nested_type(index).value), + composed_type = infer_return_type( + sig_apply.bind(f, Box(seq_type), index), ) - composed_type = substitute(sig_apply.return_annotation, apply_subs) # Direct inference - direct_subs = unify( - sig_composed, + direct_type = infer_return_type( sig_composed.bind( - nested_type(f).value, - nested_type(a).value, - nested_type(b).value, - nested_type(index).value, + f, + a, + b, + index, ), ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) # The composed inference should match the direct inference assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) @@ -1898,34 +1836,24 @@ def test_binary_on_sequence_elements(f, seq, index1, index2): sig_composed = inspect.signature(binary_on_sequence_elements) # Step 1: Infer types of sequence_getitem calls - getitem1_subs = unify( - sig_getitem, sig_getitem.bind(nested_type(seq).value, nested_type(index1).value) - ) - elem1_type = substitute(sig_getitem.return_annotation, getitem1_subs) + elem1_type = infer_return_type(sig_getitem.bind(seq, index1)) - getitem2_subs = unify( - sig_getitem, sig_getitem.bind(nested_type(seq).value, nested_type(index2).value) - ) - elem2_type = substitute(sig_getitem.return_annotation, getitem2_subs) + elem2_type = infer_return_type(sig_getitem.bind(seq, index2)) # Step 2: Infer type of call_binary_func(f, elem1, elem2) -> V - call_subs = unify( - sig_call_binary, - sig_call_binary.bind(nested_type(f).value, elem1_type, elem2_type), + composed_type = infer_return_type( + sig_call_binary.bind(f, Box(elem1_type), Box(elem2_type)), ) - composed_type = substitute(sig_call_binary.return_annotation, call_subs) # Direct inference - direct_subs = unify( - sig_composed, + direct_type = infer_return_type( sig_composed.bind( - nested_type(f).value, - nested_type(seq).value, - nested_type(index1).value, - nested_type(index2).value, + f, + seq, + index1, + index2, ), ) - direct_type = substitute(sig_composed.return_annotation, direct_subs) # The composed inference should match the direct inference assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) @@ -1974,42 +1902,51 @@ class Info(typing.TypedDict): # ============================================================ -# --- _has_typing_self --- +# --- SelfTypeReplacer --- +@pytest.mark.parametrize("tv_name", ["Self", "X"]) @pytest.mark.parametrize( - "typ,expected", + "typ,expected_factory", [ - (typing.Self, True), - (list[typing.Self], True), # type: ignore[misc] - (typing.Self | None, True), - (dict[str, typing.Self], True), # type: ignore[misc] - (collections.abc.Callable[[typing.Self], int], True), - (type[typing.Self], True), - (int, False), - (list[int], False), - (T, False), - (list[T], False), + # bare Self → tv + (typing.Self, lambda tv: tv), + # Self nested in generics → tv nested in same generics + (list[typing.Self], lambda tv: list[tv]), # type: ignore[misc,valid-type] + (dict[str, typing.Self], lambda tv: dict[str, tv]), # type: ignore[misc,valid-type] + ( + collections.abc.Callable[[typing.Self], int], + lambda tv: collections.abc.Callable[[tv], int], + ), + (type[typing.Self], lambda tv: type[tv]), # type: ignore[misc,valid-type] + # Union containing Self + (typing.Self | None, lambda tv: tv | None), + # Types without Self pass through unchanged + (int, lambda _tv: int), + (list[int], lambda _tv: list[int]), + (T, lambda _tv: T), + (list[T], lambda _tv: list[T]), ], ) -def test_has_typing_self(typ, expected): - assert _has_typing_self(typ) == expected +def test_self_type_replacer(typ, expected_factory, tv_name): + """SelfTypeReplacer rewrites every typing.Self occurrence to the given TypeVar.""" + from effectful.internals.unification import SelfTypeReplacer + + tv = typing.TypeVar(tv_name) # type: ignore[misc] + replacer = SelfTypeReplacer(tv) + assert replacer.evaluate(typ) == expected_factory(tv) # --- chaining two signatures with Self from different "classes" --- def test_chained_self_signatures(): - """Two unify calls sharing subs must not conflate Self.""" + """Two infer_return_type calls don't conflate Self across classes.""" sig_a = inspect.signature(_Foo.return_self) sig_b = inspect.signature(_Bar.return_self) - subs = unify(sig_a, sig_a.bind(_Foo)) - assert canonicalize(substitute(sig_a.return_annotation, subs)) == _Foo - - # Chaining: second unify with shared subs must not break - subs2 = unify(sig_b, sig_b.bind(_Bar), subs) - assert canonicalize(substitute(sig_b.return_annotation, subs2)) == _Bar + assert canonicalize(infer_return_type(sig_a.bind(Box(_Foo)))) == _Foo + assert canonicalize(infer_return_type(sig_b.bind(Box(_Bar)))) == _Bar # --- classmethod with Self: cls is stripped, Self stays unresolved --- @@ -2018,14 +1955,12 @@ def test_chained_self_signatures(): def test_classmethod_self_not_resolved(): """Classmethod Self stays unresolved when cls is stripped. - inspect.signature strips `cls`, so `from_config(config: int) -> Self` - has no unannotated first parameter. The Self TypeVar is created but - nothing binds it, so it remains free in the substitution result. + ``inspect.signature`` strips ``cls``, so ``from_config(config: int) -> Self`` + has no parameter that binds Self. SelfTypeReplacer produces a fresh + TypeVar; nothing in unify binds it; the TypeVar leaks through the result. """ sig = inspect.signature(_Foo.from_config) - subs = unify(sig, sig.bind(int)) - result = substitute(sig.return_annotation, subs) - # Self was replaced with a TypeVar, but nothing bound it. + result = infer_return_type(sig.bind(Box(int))) assert isinstance(result, typing.TypeVar) assert result.__name__ == "Self" @@ -2046,22 +1981,13 @@ def test_infer_self_composition_1(obj_type): sig_direct = inspect.signature(_Foo.return_self) # Step 1: infer list[Self] with Self bound to obj_type - inferred_type1 = substitute( - sig1.return_annotation, - unify(sig1, sig1.bind(obj_type)), - ) + inferred_type1 = infer_return_type(sig1.bind(Box(obj_type))) # Step 2: get_first(list[obj_type]) -> obj_type - inferred_type2 = substitute( - sig2.return_annotation, - unify(sig2, sig2.bind(nested_type(Box(inferred_type1)).value)), - ) + inferred_type2 = infer_return_type(sig2.bind(Box(inferred_type1))) # Direct: return_self(obj_type) -> obj_type - inferred_direct = substitute( - sig_direct.return_annotation, - unify(sig_direct, sig_direct.bind(obj_type)), - ) + inferred_direct = infer_return_type(sig_direct.bind(Box(obj_type))) # The composed inference should match the direct inference assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping) @@ -2079,23 +2005,143 @@ def test_infer_self_composition_2(obj_type): sig2 = inspect.signature(_Foo.return_list_self) sig_direct = inspect.signature(wrap_in_list) - # Step 1: identity(obj_type) -> obj_type - inferred_type1 = substitute( - sig1.return_annotation, - unify(sig1, sig1.bind(obj_type)), - ) + inferred_type1 = infer_return_type(sig1.bind(Box(obj_type))) + inferred_type2 = infer_return_type(sig2.bind(Box(inferred_type1))) + inferred_direct = infer_return_type(sig_direct.bind(Box(obj_type))) - # Step 2: return_list_self(obj_type) -> list[obj_type] - inferred_type2 = substitute( - sig2.return_annotation, - unify(sig2, sig2.bind(nested_type(Box(inferred_type1)).value)), - ) + assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping) - # Direct: wrap_in_list(obj_type) -> list[obj_type] - inferred_direct = substitute( - sig_direct.return_annotation, - unify(sig_direct, sig_direct.bind(obj_type)), - ) - # The composed inference should match the direct inference - assert isinstance(unify(inferred_type2, inferred_direct), collections.abc.Mapping) +def test_infer_return_type_direct_simple(): + """infer_return_type produces the substituted return type. + + Inputs are Boxed types per the convention used by ``__type_rule__``. + """ + + def f(x: T) -> list[T]: ... # type: ignore[valid-type] + + sig = inspect.signature(f) + result = infer_return_type(sig.bind(Box(int))) + assert canonicalize(result) == canonicalize(list[int]) + + +def test_infer_return_type_direct_self(): + """infer_return_type resolves typing.Self via the new TypedDict path. + + typing.Self only resolves when ``self`` is annotated explicitly: there is + no implicit self-injection (it would be unsafe for class/static/free + functions where the first positional is not the receiver). + """ + + class C: + def m(self: typing.Self) -> typing.Self: ... + + sig = inspect.signature(C.m) + result = infer_return_type(sig.bind(C())) + assert canonicalize(result) == C + + +def test_infer_return_type_bare_self_unresolved(): + """A method with bare ``self`` does not bind Self — it stays a free TypeVar.""" + + class C: + def m(self) -> typing.Self: ... + + sig = inspect.signature(C.m) + result = infer_return_type(sig.bind(C())) + assert isinstance(result, typing.TypeVar) + assert result.__name__ == "Self" + + +def test_infer_return_type_direct_default_omitted(): + """infer_return_type with a default-valued param omitted still returns correctly.""" + + def f(x: T, y: str = "ok") -> T: ... # type: ignore[valid-type] + + sig = inspect.signature(f) + result = infer_return_type(sig.bind(Box(int))) + assert canonicalize(result) == int + + +# --- TypedDict-valued parameters in signatures --- + + +def test_infer_return_type_typeddict_param(): + """A function parameter typed as a TypedDict unifies with a runtime + str-keyed dict whose nested_type is a structurally-compatible TypedDict. + + Composes the sig→TypedDict conversion with the TypedDict-vs-TypedDict + unification path added in PR 651. + """ + + class User(typing.TypedDict): + name: str + age: int + + def greet(u: User) -> str: ... + + sig = inspect.signature(greet) + # nested_type({"name": "a", "age": 1}) returns a TypedDict, which is + # structurally compatible with User → unification should succeed. + result = infer_return_type(sig.bind({"name": "a", "age": 1})) + assert canonicalize(result) is str + + +def test_infer_return_type_generic_typeddict_param(): + """A parameterized TypedDict in a function parameter binds the TypeVar + from the runtime dict's value types.""" + + class Datum[T](typing.TypedDict): + name: str + value: T + + def extract[T](d: Datum[T]) -> T: ... + + sig = inspect.signature(extract) + result = infer_return_type(sig.bind({"name": "a", "value": 42})) + assert canonicalize(result) is int + + +# --- Failure cases --- + + +def test_infer_return_type_conflicting_self_bindings(): + """A method with two parameters both annotated typing.Self must reject + callers that pass differently-typed receivers for those slots. + + Without explicit Self binding, this would silently widen Self to a common + supertype; with the SelfTypeReplacer pattern, both Self occurrences alias + the same fresh TypeVar, so conflicting concrete types must raise. + """ + + class C: + def f(self: typing.Self, other: typing.Self) -> typing.Self: ... + + sig = inspect.signature(C.f) + # Same type for both: succeeds. + assert canonicalize(infer_return_type(sig.bind(Box(_Foo), Box(_Foo)))) == _Foo + # Mismatched concrete types: must raise. + with pytest.raises(TypeError): + infer_return_type(sig.bind(Box(_Foo), Box(_Bar))) + + +def test_infer_return_type_wrong_arg_type(): + """Passing an arg whose type is incompatible with the parameter + annotation must raise.""" + + def f(x: int) -> int: ... + + sig = inspect.signature(f) + with pytest.raises(TypeError): + infer_return_type(sig.bind(Box(str))) + + +def test_infer_return_type_missing_required_arg(): + """Omitting a required positional arg surfaces as a Signature.bind error + before we reach infer_return_type — document that contract here.""" + + def f(x: int, y: int) -> int: ... + + sig = inspect.signature(f) + with pytest.raises(TypeError): + sig.bind(Box(int)) # missing `y` diff --git a/tests/test_ops_semantics.py b/tests/test_ops_semantics.py index 008ac5446..f636f4373 100644 --- a/tests/test_ops_semantics.py +++ b/tests/test_ops_semantics.py @@ -917,11 +917,11 @@ def f(self): class _SelfA: @defop - def ret_self(self, x: int) -> typing.Self: + def ret_self(self: typing.Self, x: int) -> typing.Self: raise NotHandled @defop - def ret_list_self(self, x: int) -> list[typing.Self]: + def ret_list_self(self: typing.Self, x: int) -> list[typing.Self]: raise NotHandled @defop @@ -929,13 +929,13 @@ def annotated_self(self: typing.Self, x: int) -> typing.Self: raise NotHandled @defop - def ret_self_or_none(self, x: int) -> typing.Self | None: + def ret_self_or_none(self: typing.Self, x: int) -> typing.Self | None: raise NotHandled class _SelfB: @defop - def ret_self(self, x: int) -> typing.Self: + def ret_self(self: typing.Self, x: int) -> typing.Self: raise NotHandled From 7d45cc160267602439836f26cb3d67f586f142af Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 5 May 2026 20:00:42 -0400 Subject: [PATCH 4/9] docstring --- effectful/internals/unification.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 9fa46cae8..97bdec65c 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -4,7 +4,10 @@ Python's generic types. Unification is a fundamental operation in type systems that finds substitutions for type variables to make two types equivalent. -The module provides four main operations: +The module provides five main operations: + +0. **infer_return_type(bound_sig)**: Given a bound signature (from a callable invocation), + infer the return type implied by the signature and the types of the bound arguments. 1. **unify(typ, subtyp, subs={})**: The core unification algorithm that attempts to find a substitution mapping for type variables that makes a pattern type equal to @@ -23,12 +26,6 @@ collections by recursively determining element types. For example, [1, 2, 3] becomes list[int], and {"key": [1, 2]} becomes dict[str, list[int]]. -Function signatures are not first-class to ``unify``. They are converted to -TypedDict patterns via :func:`_sig_to_type` (with ``typing.Self`` eliminated by -:class:`SelfTypeReplacer` at the boundary), and the resulting TypedDicts are -unified through the regular path. :func:`infer_return_type` is the entry point -that ties this together for callable invocations. - The unification algorithm uses a single-dispatch pattern to handle different type combinations: - TypeVar unification binds variables to concrete types From 34b7e71c146f14f9af68cbb56a02dc5f3f7bfb38 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 6 May 2026 11:12:08 -0400 Subject: [PATCH 5/9] fix static typing --- effectful/internals/unification.py | 246 +++++++++++++++++++--------- tests/test_internals_unification.py | 24 +-- 2 files changed, 177 insertions(+), 93 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 97bdec65c..b4fca9d6d 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -79,10 +79,12 @@ import effectful.ops.types if typing.TYPE_CHECKING: - TypeConstant = type | abc.ABCMeta | types.EllipsisType | None + TypedDictType = type[dict[str, typing.Any]] + TypeConstant = TypedDictType | type | abc.ABCMeta | types.EllipsisType | None GenericAlias = types.GenericAlias UnionType = types.UnionType else: + TypedDictType = type TypeConstant = ( type | abc.ABCMeta | types.EllipsisType | type(None) | type(typing.Any) ) @@ -187,9 +189,28 @@ def _(self, typ: typing.ForwardRef): else: return typ - @evaluate.register # type: ignore[arg-type] - def _(self, typ: typing._SpecialForm) -> TypeExpressions: # type: ignore[type-arg] - return typing.cast(TypeExpressions, typ) + @evaluate.register + def _(self, typ: typing._SpecialForm) -> TypeExpressions: + return typ # type: ignore + + +def infer_return_type(bound_sig: inspect.BoundArguments) -> TypeExpressions: + """Infer the return type of a function from its bound arguments. + + The signature is converted to a TypedDict-shaped pattern via + :func:`_sig_to_type`; the bound arguments are converted via + :func:`_bound_sig_to_type`; the two are unified using the regular TypedDict + unification path; and the freshened return type from the pattern is + substituted with the resulting bindings. + + ``typing.Self`` is eliminated by :func:`_sig_to_type` (via + :class:`SelfTypeReplacer`) — by the time we reach unification, no + ``typing.Self`` references remain anywhere. + """ + pattern = _freshen(_sig_to_type(bound_sig.signature)) + bound = _freshen(_bound_sig_to_type(bound_sig)) + return_type = _get_encoded_return_annotation(pattern) + return substitute(return_type, unify(pattern, bound)) def unify( @@ -255,32 +276,20 @@ def unify( if typ != canonicalize(typ) or subtyp != canonicalize(subtyp): return unify(canonicalize(typ), canonicalize(subtyp), subs) - if _is_typeddict_type(typ) and _is_typeddict_type(subtyp): - return _unify_typeddict(typ, subtyp, subs) - - # unifying Mapping[K, V] and TypedDict - if _is_typeddict_type(subtyp) and isinstance(typ, GenericAlias): - origin = typing.get_origin(typ) - if ( - origin is not None - and isinstance(origin, type) - and issubclass(origin, collections.abc.Mapping) - and len(typing.get_args(typ)) == 2 - ): - return _unify_mapping_typeddict(typ, subtyp, subs) - if typ is subtyp or typ == subtyp: return subs elif isinstance(typ, TypeVariable) or isinstance(subtyp, TypeVariable): - return _unify_typevar(typ, subtyp, subs) + return _unify_typevar(typ, subtyp, subs) # type: ignore elif isinstance(typ, collections.abc.Sequence) or isinstance( subtyp, collections.abc.Sequence ): - return _unify_sequence(typ, subtyp, subs) + return _unify_sequence(typ, subtyp, subs) # type: ignore elif isinstance(typ, UnionType) or isinstance(subtyp, UnionType): - return _unify_union(typ, subtyp, subs) + return _unify_union(typ, subtyp, subs) # type: ignore elif isinstance(typ, GenericAlias) or isinstance(subtyp, GenericAlias): - return _unify_generic(typ, subtyp, subs) + return _unify_generic(typ, subtyp, subs) # type: ignore + elif _is_typeddict_type(typ) or _is_typeddict_type(subtyp): + return _unify_typeddict(typ, subtyp, subs) # type: ignore elif isinstance(typ, type) and isinstance(subtyp, type) and issubclass(subtyp, typ): return subs elif typ in (typing.Any, ...) or subtyp in (typing.Any, ...): @@ -301,7 +310,9 @@ def _unify_typevar( ) -> Substitutions: ... -def _unify_typevar(typ, subtyp, subs: Substitutions) -> Substitutions: +def _unify_typevar( + typ: TypeExpression, subtyp: TypeExpression, subs: Substitutions +) -> Substitutions: if isinstance(typ, TypeVariable) and isinstance(subtyp, TypeVariable): return subs if typ == subtyp else {typ: subtyp, **subs} elif isinstance(typ, TypeVariable) and not isinstance(subtyp, TypeVariable): @@ -318,19 +329,35 @@ def _unify_typevar(typ, subtyp, subs: Substitutions) -> Substitutions: @typing.overload def _unify_sequence( - typ: collections.abc.Sequence, subtyp: TypeExpressions, subs: Substitutions + typ: collections.abc.Sequence[TypeExpression], + subtyp: TypeExpression, + subs: Substitutions, +) -> Substitutions: ... + + +@typing.overload +def _unify_sequence( + typ: TypeExpression, + subtyp: collections.abc.Sequence[TypeExpression], + subs: Substitutions, ) -> Substitutions: ... @typing.overload def _unify_sequence( - typ: TypeExpressions, subtyp: collections.abc.Sequence, subs: Substitutions + typ: collections.abc.Sequence[TypeExpression], + subtyp: collections.abc.Sequence[TypeExpression], + subs: Substitutions, ) -> Substitutions: ... -def _unify_sequence(typ, subtyp, subs: Substitutions) -> Substitutions: +def _unify_sequence( + typ: TypeExpressions, subtyp: TypeExpressions, subs: Substitutions +) -> Substitutions: if isinstance(typ, types.EllipsisType) or isinstance(subtyp, types.EllipsisType): return subs + assert isinstance(typ, collections.abc.Sequence) + assert isinstance(subtyp, collections.abc.Sequence) if len(typ) != len(subtyp): raise TypeError(f"Cannot unify sequence {typ} with {subtyp} given {subs}. ") for p_item, c_item in zip(typ, subtyp): @@ -370,20 +397,25 @@ def _unify_union(typ, subtyp, subs: Substitutions) -> Substitutions: raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") -def _is_typeddict_type(typ) -> bool: +def _is_typeddict_type( + typ: TypeExpressions, +) -> typing.TypeIs[TypedDictType]: """Check if typ is a TypedDict class or a parameterized TypedDict (e.g. Datum[T]).""" - if isinstance(typ, type) and typing.is_typeddict(typ): - return True - origin = typing.get_origin(typ) - return ( - origin is not None and isinstance(origin, type) and typing.is_typeddict(origin) - ) + if isinstance(typ, GenericAlias): + return typing.is_typeddict(typing.get_origin(typ)) + else: + return typing.is_typeddict(typ) -def _get_typeddict_hints(typ) -> dict[str, TypeExpressions]: +def _get_typeddict_hints( + typ: TypedDictType | GenericAlias, +) -> collections.abc.Mapping[str, TypeExpression]: """Get type hints for a TypedDict, substituting type params if parameterized.""" - origin = typing.get_origin(typ) - if origin is not None and typing.is_typeddict(origin): + if isinstance(typ, GenericAlias): + origin = typing.get_origin(typ) + assert typing.is_typeddict(origin), ( + f"Expected a parameterized TypedDict, got {typ}." + ) args = typing.get_args(typ) type_params = origin.__type_params__ hints = typing.get_type_hints(origin) @@ -393,7 +425,7 @@ def _get_typeddict_hints(typ) -> dict[str, TypeExpressions]: hints = typing.get_type_hints(typ) # For classes like Derived(Base[int]), resolve unsubstituted TypeVars # from parameterized bases. - base_param_subs: dict[TypeVariable, TypeExpressions] = { + base_param_subs: dict[TypeVariable, TypeExpression] = { tp: arg for base in types.get_original_bases(typ) if (base_origin := typing.get_origin(base)) is not None @@ -408,7 +440,11 @@ def _get_typeddict_hints(typ) -> dict[str, TypeExpressions]: return hints -def _unify_typeddict(typ, subtyp, subs: Substitutions) -> Substitutions: +def _unify_typeddict( + typ: TypedDictType | GenericAlias, + subtyp: TypedDictType | GenericAlias, + subs: Substitutions, +) -> Substitutions: """Unify two TypedDict types by matching fields structurally. Per the typing spec for TypedDict structural subtyping: @@ -469,13 +505,22 @@ def _unify_typeddict(typ, subtyp, subs: Substitutions) -> Substitutions: return subs -def _unify_mapping_typeddict(typ, subtyp, subs: Substitutions) -> Substitutions: +def _unify_mapping_typeddict( + typ: GenericAlias, subtyp: TypedDictType, subs: Substitutions +) -> Substitutions: """Unify Mapping[K, V] (or MutableMapping[K, V]) with a TypedDict. TypedDict keys are always str, so K must unify with str. V must unify with each field's value type (covariant for Mapping, invariant for MutableMapping). """ + if not ( + typing.get_origin(typ) is not None + and issubclass(typing.get_origin(typ), collections.abc.Mapping) + and len(typing.get_args(typ)) == 2 + ): + raise TypeError(f"Expected a Mapping type, got {typ}.") + origin = typing.get_origin(typ) key_type, value_type = typing.get_args(typ) subtyp_hints = _get_typeddict_hints(subtyp) @@ -495,13 +540,13 @@ def _unify_mapping_typeddict(typ, subtyp, subs: Substitutions) -> Substitutions: @typing.overload def _unify_generic( - typ: GenericAlias, subtyp: type, subs: Substitutions + typ: GenericAlias, subtyp: TypeConstant, subs: Substitutions ) -> Substitutions: ... @typing.overload def _unify_generic( - typ: type, subtyp: GenericAlias, subs: Substitutions + typ: TypeConstant, subtyp: GenericAlias, subs: Substitutions ) -> Substitutions: ... @@ -511,7 +556,11 @@ def _unify_generic( ) -> Substitutions: ... -def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: +def _unify_generic( + typ: TypeConstant | GenericAlias, + subtyp: TypeConstant | GenericAlias, + subs: Substitutions, +) -> Substitutions: if ( isinstance(typ, GenericAlias) and isinstance(subtyp, GenericAlias) @@ -540,6 +589,11 @@ def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions: return unify(typ, base[typing.get_args(subtyp)], subs) # type: ignore elif isinstance(typ, type) and isinstance(subtyp, GenericAlias): return unify(typ, typing.get_origin(subtyp), subs) + elif isinstance(typ, GenericAlias) and _is_typeddict_type(subtyp): + if _is_typeddict_type(typ): + return _unify_typeddict(typ, subtyp, subs) + else: + return _unify_mapping_typeddict(typ, subtyp, subs) elif ( isinstance(typ, GenericAlias) and isinstance(subtyp, type) @@ -567,23 +621,9 @@ def evaluate(self, typ) -> TypeExpressions: return super().evaluate(typ) -def infer_return_type(bound_sig: inspect.BoundArguments) -> TypeExpressions: - """Infer the return type of a function from its bound arguments. - - The signature is converted to a TypedDict-shaped pattern via - :func:`_sig_to_type`; the bound arguments are converted via - :func:`_bound_sig_to_type`; the two are unified using the regular TypedDict - unification path; and the freshened return type from the pattern is - substituted with the resulting bindings. - - ``typing.Self`` is eliminated by :func:`_sig_to_type` (via - :class:`SelfTypeReplacer`) — by the time we reach unification, no - ``typing.Self`` references remain anywhere. - """ - pattern = _sig_to_type(bound_sig.signature) - bound = _freshen(_bound_sig_to_type(bound_sig)) - return_type = _get_typeddict_hints(pattern)["return"] - return substitute(return_type, unify(pattern, bound)) +def _get_encoded_return_annotation(sig_type_encoding) -> TypeExpression: + """Extract the return annotation from a signature encoded as a type.""" + return _get_typeddict_hints(sig_type_encoding)["return"] def _sig_to_type(sig: inspect.Signature) -> TypeExpression: @@ -616,12 +656,12 @@ def _sig_to_type(sig: inspect.Signature) -> TypeExpression: sig.replace(return_annotation=typing.get_args(sig.return_annotation)[0]) ) - replacer = SelfTypeReplacer(typing.TypeVar("Self")) # type: ignore[misc] + replacer = SelfTypeReplacer(typing.TypeVar("Self")) annotations: dict[str, TypeExpressions] = { - "return": typing.NotRequired[ + "return": typing.NotRequired[ # type: ignore[dict-item] typing.ReadOnly[replacer.evaluate(sig.return_annotation)] - ] # type: ignore[assignment] + ] } for name, param in sig.parameters.items(): if param.annotation is inspect.Parameter.empty: @@ -629,17 +669,17 @@ def _sig_to_type(sig: inspect.Signature) -> TypeExpression: ann = replacer.evaluate(param.annotation) if param.kind == inspect.Parameter.VAR_POSITIONAL: field: TypeExpressions = typing.NotRequired[ - typing.ReadOnly[tuple[ann, ...]] + typing.ReadOnly[tuple[ann, ...]] # type: ignore[valid-type] ] # type: ignore[assignment] elif param.kind == inspect.Parameter.VAR_KEYWORD: - field = typing.NotRequired[typing.ReadOnly[dict[str, ann]]] # type: ignore[assignment] + field = typing.NotRequired[typing.ReadOnly[dict[str, ann]]] # type: ignore[valid-type,assignment] elif param.default is not inspect.Parameter.empty: field = typing.NotRequired[typing.ReadOnly[ann]] # type: ignore[assignment] else: field = typing.ReadOnly[ann] # type: ignore[assignment] annotations[name] = field - return typing.TypedDict(f"{sig}_Type", annotations) + return typing.TypedDict(f"{sig}_Type", annotations) # type: ignore[operator] def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: @@ -660,12 +700,17 @@ def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: by :func:`_sig_to_type`. """ sig: inspect.Signature = bound_sig.signature - typed_arguments = sig.bind( + typed_arguments: collections.OrderedDict[ + str, + TypeExpression + | collections.abc.Sequence[TypeExpression] + | collections.abc.Mapping[str, TypeExpression], + ] = sig.bind( *[nested_type(a).value for a in bound_sig.args], **{k: nested_type(v).value for k, v in bound_sig.kwargs.items()}, ).arguments - annotations: dict[str, TypeExpressions] = {} + annotations: dict[str, TypeExpression] = {} for name, param in sig.parameters.items(): if param.annotation is inspect.Parameter.empty: continue @@ -685,25 +730,28 @@ def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: if param.kind == inspect.Parameter.VAR_POSITIONAL and isinstance( psubtyp, collections.abc.Sequence ): - annotations[name] = tuple[psubtyp] + annotations[name] = tuple[psubtyp] # type: ignore[valid-type] elif param.kind == inspect.Parameter.VAR_KEYWORD and isinstance( psubtyp, collections.abc.Mapping ): - annotations[name] = typing.TypedDict(f"{name}BoundKwargs", psubtyp) + annotations[name] = typing.TypedDict(f"{name}BoundKwargs", psubtyp) # type: ignore[operator] elif param.kind not in { inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL, } or isinstance(psubtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs): + assert not isinstance( + psubtyp, collections.abc.Sequence | collections.abc.Mapping + ) annotations[name] = psubtyp else: raise TypeError( f"Cannot unify parameter {param} with argument {psubtyp} in signature unification." ) - return typing.TypedDict("BoundSigType", annotations) + return typing.TypedDict("BoundSigType", annotations) # type: ignore[operator] -def _freshen(tp: typing.Any): +def _freshen[T: TypeExpressions](tp: T) -> T: """ Return a freshened version of the given type expression. @@ -724,10 +772,10 @@ def _freshen(tp: typing.Any): """ fvs = freetypevars(tp) assert all(canonicalize(fv) is fv for fv in fvs) - subs: Substitutions = { + subs: collections.abc.Mapping[TypeVariable, TypeVariable] = { fv: typing.TypeVar(fv.__name__, bound=fv.__bound__) if isinstance(fv, typing.TypeVar) - else typing.ParamSpec(fv.__name__) + else typing.ParamSpec(fv.__name__, bound=fv.__bound__) for fv in fvs if isinstance(fv, typing.TypeVar | typing.ParamSpec) } @@ -820,7 +868,7 @@ def _(typ: typing.TypeVar): @canonicalize.register def _(typ: typing.ParamSpec): if ( - typ.__bound__ + typ.__bound__ not in (None, type(None)) or typ.__covariant__ or typ.__contravariant__ or getattr(typ, "__default__", None) is not getattr(typing, "NoDefault", None) @@ -1163,7 +1211,7 @@ def _(value: str | bytes | range | None): return Box(type(value)) -def freetypevars(typ) -> collections.abc.Set[TypeVariable]: +def freetypevars(typ: TypeExpressions) -> collections.abc.Set[TypeVariable]: """ Return a set of free type variables in the given type expression. @@ -1221,6 +1269,38 @@ def freetypevars(typ) -> collections.abc.Set[TypeVariable]: return set(_freetypevars((typ,))) +@typing.overload +def substitute(typ: typing.TypeVar, subs: Substitutions) -> TypeExpression: ... + + +@typing.overload +def substitute( + typ: typing.ParamSpec, subs: Substitutions +) -> ( + typing.ParamSpec | types.EllipsisType | collections.abc.Sequence[TypeExpression] +): ... + + +@typing.overload +def substitute( + typ: typing.TypeVarTuple, subs: Substitutions +) -> ( + typing.TypeVarTuple | types.EllipsisType | collections.abc.Sequence[TypeExpression] +): ... + + +@typing.overload +def substitute[ + T: TypeConstant | TypeApplication | collections.abc.Sequence[TypeExpression] +](typ: T, subs: Substitutions) -> T: ... + + +@typing.overload +def substitute[T: TypeExpressions]( + typ: T, subs: collections.abc.Mapping[TypeVariable, TypeVariable] +) -> T: ... + + def substitute(typ, subs: Substitutions) -> TypeExpressions: """ Substitute type variables in a type expression with concrete types. @@ -1279,12 +1359,15 @@ def substitute(typ, subs: Substitutions) -> TypeExpressions: return typ -def _substitute_typeddict(typ, subs: Substitutions) -> TypeExpressions: - """Substitute inside a TypedDict's fields, preserving NotRequired-ness.""" +def _substitute_typeddict[T: TypedDictType](typ: T, subs: Substitutions) -> T: + """Substitute inside a TypedDict's fields, preserving NotRequired-ness and ReadOnly-ness.""" typ_origin = typing.get_origin(typ) or typ optional_keys: frozenset[str] = getattr( typ_origin, "__optional_keys__", frozenset() ) + readonly_keys: frozenset[str] = getattr( + typ_origin, "__readonly_keys__", frozenset() + ) hints = _get_typeddict_hints(typ) new_fields: dict[str, TypeExpressions] = {} changed = False @@ -1292,11 +1375,12 @@ def _substitute_typeddict(typ, subs: Substitutions) -> TypeExpressions: new_ftype = substitute(ftype, subs) if new_ftype is not ftype: changed = True + if field in readonly_keys: + new_ftype = typing.ReadOnly[new_ftype] # type: ignore[assignment] if field in optional_keys: - new_fields[field] = typing.NotRequired[new_ftype] # type: ignore[assignment] - else: - new_fields[field] = new_ftype + new_ftype = typing.NotRequired[new_ftype] # type: ignore[assignment] + new_fields[field] = new_ftype if not changed: return typ name = getattr(typ_origin, "__name__", "_Subbed") - return typing.TypedDict(name, new_fields) # type: ignore[operator,misc] + return typing.TypedDict(name, new_fields) # type: ignore[operator] diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 3f44afe28..f00890fb7 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -608,7 +608,7 @@ def returns_annotated[T](x: T) -> typing.Annotated[T, "meta"]: return x -def no_return_annotation_func(x: int): # type: ignore[no-untyped-def] +def no_return_annotation_func(x: int): return x @@ -1918,7 +1918,7 @@ class Info(typing.TypedDict): collections.abc.Callable[[typing.Self], int], lambda tv: collections.abc.Callable[[tv], int], ), - (type[typing.Self], lambda tv: type[tv]), # type: ignore[misc,valid-type] + (type[typing.Self], lambda tv: type[tv]), # Union containing Self (typing.Self | None, lambda tv: tv | None), # Types without Self pass through unchanged @@ -1932,7 +1932,7 @@ def test_self_type_replacer(typ, expected_factory, tv_name): """SelfTypeReplacer rewrites every typing.Self occurrence to the given TypeVar.""" from effectful.internals.unification import SelfTypeReplacer - tv = typing.TypeVar(tv_name) # type: ignore[misc] + tv = typing.TypeVar(tv_name) replacer = SelfTypeReplacer(tv) assert replacer.evaluate(typ) == expected_factory(tv) @@ -2018,7 +2018,7 @@ def test_infer_return_type_direct_simple(): Inputs are Boxed types per the convention used by ``__type_rule__``. """ - def f(x: T) -> list[T]: ... # type: ignore[valid-type] + def f(x: T) -> list[T]: ... # type: ignore[empty-body] sig = inspect.signature(f) result = infer_return_type(sig.bind(Box(int))) @@ -2034,7 +2034,7 @@ def test_infer_return_type_direct_self(): """ class C: - def m(self: typing.Self) -> typing.Self: ... + def m(self: typing.Self) -> typing.Self: ... # type: ignore[empty-body] sig = inspect.signature(C.m) result = infer_return_type(sig.bind(C())) @@ -2045,7 +2045,7 @@ def test_infer_return_type_bare_self_unresolved(): """A method with bare ``self`` does not bind Self — it stays a free TypeVar.""" class C: - def m(self) -> typing.Self: ... + def m(self) -> typing.Self: ... # type: ignore[empty-body] sig = inspect.signature(C.m) result = infer_return_type(sig.bind(C())) @@ -2056,7 +2056,7 @@ def m(self) -> typing.Self: ... def test_infer_return_type_direct_default_omitted(): """infer_return_type with a default-valued param omitted still returns correctly.""" - def f(x: T, y: str = "ok") -> T: ... # type: ignore[valid-type] + def f(x: T, y: str = "ok") -> T: ... sig = inspect.signature(f) result = infer_return_type(sig.bind(Box(int))) @@ -2078,7 +2078,7 @@ class User(typing.TypedDict): name: str age: int - def greet(u: User) -> str: ... + def greet(u: User) -> str: ... # type: ignore[empty-body] sig = inspect.signature(greet) # nested_type({"name": "a", "age": 1}) returns a TypedDict, which is @@ -2095,7 +2095,7 @@ class Datum[T](typing.TypedDict): name: str value: T - def extract[T](d: Datum[T]) -> T: ... + def extract[T](d: Datum[T]) -> T: ... # type: ignore[empty-body] sig = inspect.signature(extract) result = infer_return_type(sig.bind({"name": "a", "value": 42})) @@ -2115,7 +2115,7 @@ def test_infer_return_type_conflicting_self_bindings(): """ class C: - def f(self: typing.Self, other: typing.Self) -> typing.Self: ... + def f(self: typing.Self, other: typing.Self) -> typing.Self: ... # type: ignore[empty-body] sig = inspect.signature(C.f) # Same type for both: succeeds. @@ -2129,7 +2129,7 @@ def test_infer_return_type_wrong_arg_type(): """Passing an arg whose type is incompatible with the parameter annotation must raise.""" - def f(x: int) -> int: ... + def f(x: int) -> int: ... # type: ignore[empty-body] sig = inspect.signature(f) with pytest.raises(TypeError): @@ -2140,7 +2140,7 @@ def test_infer_return_type_missing_required_arg(): """Omitting a required positional arg surfaces as a Signature.bind error before we reach infer_return_type — document that contract here.""" - def f(x: int, y: int) -> int: ... + def f(x: int, y: int) -> int: ... # type: ignore[empty-body] sig = inspect.signature(f) with pytest.raises(TypeError): From c7198b032dc8b61344eabbb85a8c4c8d638b4001 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 6 May 2026 11:28:44 -0400 Subject: [PATCH 6/9] fix --- effectful/internals/unification.py | 41 +++++++++++++++++++++--------- effectful/ops/types.py | 2 +- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index b4fca9d6d..29123cad5 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -194,7 +194,9 @@ def _(self, typ: typing._SpecialForm) -> TypeExpressions: return typ # type: ignore -def infer_return_type(bound_sig: inspect.BoundArguments) -> TypeExpressions: +def infer_return_type( + bound_sig: inspect.BoundArguments, *, always_check: bool = True +) -> TypeExpressions: """Infer the return type of a function from its bound arguments. The signature is converted to a TypedDict-shaped pattern via @@ -208,9 +210,14 @@ def infer_return_type(bound_sig: inspect.BoundArguments) -> TypeExpressions: ``typing.Self`` references remain anywhere. """ pattern = _freshen(_sig_to_type(bound_sig.signature)) - bound = _freshen(_bound_sig_to_type(bound_sig)) - return_type = _get_encoded_return_annotation(pattern) - return substitute(return_type, unify(pattern, bound)) + return_anno = _get_encoded_return_annotation(pattern) + if not always_check and not freetypevars(return_anno): + # Fast path: if the return annotation is closed (has no free type variables), + # we can skip unification and just return it directly. + return return_anno + else: + bound = _freshen(_bound_sig_to_type(bound_sig)) + return substitute(return_anno, unify(pattern, bound)) def unify( @@ -626,6 +633,23 @@ def _get_encoded_return_annotation(sig_type_encoding) -> TypeExpression: return _get_typeddict_hints(sig_type_encoding)["return"] +def _fix_return_annotation(sig: inspect.Signature) -> inspect.Signature: + """Replace empty return annotations with object, and None with type(None). + + This ensures that all signatures have a return annotation that can be + processed by _sig_to_type without special-casing the empty annotation or + None. + """ + if sig.return_annotation is inspect.Parameter.empty: + return sig.replace(return_annotation=object) + elif sig.return_annotation is None: + return sig.replace(return_annotation=type(None)) + elif typing.get_origin(sig.return_annotation) is typing.Annotated: + return sig.replace(return_annotation=typing.get_args(sig.return_annotation)[0]) + else: + return sig + + def _sig_to_type(sig: inspect.Signature) -> TypeExpression: """Convert an :class:`inspect.Signature` to a TypedDict pattern for unification. @@ -647,14 +671,7 @@ def _sig_to_type(sig: inspect.Signature) -> TypeExpression: ``NotRequired``; all fields are also wrapped in ``ReadOnly`` to give the TypedDict unification path covariant semantics for parameter types. """ - if sig.return_annotation is inspect.Parameter.empty: - return _sig_to_type(sig.replace(return_annotation=object)) - elif sig.return_annotation is None: - return _sig_to_type(sig.replace(return_annotation=type(None))) - elif typing.get_origin(sig.return_annotation) is typing.Annotated: - return _sig_to_type( - sig.replace(return_annotation=typing.get_args(sig.return_annotation)[0]) - ) + sig = _fix_return_annotation(sig) replacer = SelfTypeReplacer(typing.TypeVar("Self")) diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 91d0694d2..3bb381d2d 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -382,7 +382,7 @@ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]: from effectful.internals.unification import infer_return_type bound_sig = self.__signature__.bind(*args, **kwargs) - return typing.cast(type[V], infer_return_type(bound_sig)) + return typing.cast(type[V], infer_return_type(bound_sig, always_check=False)) @typing.final def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> inspect.BoundArguments: From 0feb76acb6bb65b3091440d51bb04bacfa20b157 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 6 May 2026 11:48:35 -0400 Subject: [PATCH 7/9] typeguard --- effectful/internals/unification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 29123cad5..e65cc202b 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -406,7 +406,7 @@ def _unify_union(typ, subtyp, subs: Substitutions) -> Substitutions: def _is_typeddict_type( typ: TypeExpressions, -) -> typing.TypeIs[TypedDictType]: +) -> typing.TypeGuard[TypedDictType]: """Check if typ is a TypedDict class or a parameterized TypedDict (e.g. Datum[T]).""" if isinstance(typ, GenericAlias): return typing.is_typeddict(typing.get_origin(typ)) From a2a678ef6b347d3baa441623a19969dd1726c15d Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 6 May 2026 12:03:53 -0400 Subject: [PATCH 8/9] gate readonly --- effectful/internals/unification.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e65cc202b..fbe005d1e 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -68,6 +68,7 @@ import inspect import numbers import operator +import sys import types import typing @@ -675,25 +676,30 @@ def _sig_to_type(sig: inspect.Signature) -> TypeExpression: replacer = SelfTypeReplacer(typing.TypeVar("Self")) - annotations: dict[str, TypeExpressions] = { + annotations: dict[str, TypeExpression] = { "return": typing.NotRequired[ # type: ignore[dict-item] typing.ReadOnly[replacer.evaluate(sig.return_annotation)] + if sys.version_info >= (3, 13) + else replacer.evaluate(sig.return_annotation) ] } for name, param in sig.parameters.items(): if param.annotation is inspect.Parameter.empty: continue ann = replacer.evaluate(param.annotation) + assert not isinstance(ann, collections.abc.Sequence) if param.kind == inspect.Parameter.VAR_POSITIONAL: - field: TypeExpressions = typing.NotRequired[ - typing.ReadOnly[tuple[ann, ...]] # type: ignore[valid-type] + field: TypeExpression = typing.NotRequired[ + tuple[ann, ...] # type: ignore[valid-type] ] # type: ignore[assignment] elif param.kind == inspect.Parameter.VAR_KEYWORD: - field = typing.NotRequired[typing.ReadOnly[dict[str, ann]]] # type: ignore[valid-type,assignment] + field = typing.NotRequired[dict[str, ann]] # type: ignore[valid-type,assignment] elif param.default is not inspect.Parameter.empty: - field = typing.NotRequired[typing.ReadOnly[ann]] # type: ignore[assignment] + field = typing.NotRequired[ann] # type: ignore[assignment] else: - field = typing.ReadOnly[ann] # type: ignore[assignment] + field = ann + if sys.version_info >= (3, 13): + field = typing.ReadOnly[field] # type: ignore annotations[name] = field return typing.TypedDict(f"{sig}_Type", annotations) # type: ignore[operator] @@ -836,8 +842,8 @@ def _(typ: type | abc.ABCMeta): canon_fields: dict[str, type] = {} for field, ftype in hints.items(): ct = canonicalize(ftype) - if field in readonly_keys: - ct = typing.ReadOnly[ct] # type: ignore[assignment] + if field in readonly_keys and sys.version_info >= (3, 13): + ct = typing.ReadOnly[ct] # type: ignore if field in optional_keys: ct = typing.NotRequired[ct] # type: ignore[assignment] canon_fields[field] = ct # type: ignore[assignment] @@ -1392,8 +1398,8 @@ def _substitute_typeddict[T: TypedDictType](typ: T, subs: Substitutions) -> T: new_ftype = substitute(ftype, subs) if new_ftype is not ftype: changed = True - if field in readonly_keys: - new_ftype = typing.ReadOnly[new_ftype] # type: ignore[assignment] + if field in readonly_keys and sys.version_info >= (3, 13): + new_ftype = typing.ReadOnly[new_ftype] # type: ignore if field in optional_keys: new_ftype = typing.NotRequired[new_ftype] # type: ignore[assignment] new_fields[field] = new_ftype From c2043308b83c21459956578d681036fdfbf70efe Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 6 May 2026 13:09:16 -0400 Subject: [PATCH 9/9] py312 --- effectful/internals/unification.py | 57 ++++++++++--------- tests/test_internals_unification.py | 35 ++++++------ tests/test_internals_unification_typeddict.py | 35 +++++------- 3 files changed, 61 insertions(+), 66 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index fbe005d1e..a8073e766 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -68,10 +68,11 @@ import inspect import numbers import operator -import sys import types import typing +import typing_extensions + try: from typing import _collect_type_parameters as _freetypevars # type: ignore except ImportError: @@ -410,9 +411,9 @@ def _is_typeddict_type( ) -> typing.TypeGuard[TypedDictType]: """Check if typ is a TypedDict class or a parameterized TypedDict (e.g. Datum[T]).""" if isinstance(typ, GenericAlias): - return typing.is_typeddict(typing.get_origin(typ)) + return typing_extensions.is_typeddict(typing.get_origin(typ)) else: - return typing.is_typeddict(typ) + return typing_extensions.is_typeddict(typ) def _get_typeddict_hints( @@ -421,16 +422,16 @@ def _get_typeddict_hints( """Get type hints for a TypedDict, substituting type params if parameterized.""" if isinstance(typ, GenericAlias): origin = typing.get_origin(typ) - assert typing.is_typeddict(origin), ( + assert typing_extensions.is_typeddict(origin), ( f"Expected a parameterized TypedDict, got {typ}." ) args = typing.get_args(typ) type_params = origin.__type_params__ - hints = typing.get_type_hints(origin) - param_subs = dict(zip(type_params, args)) + hints = typing_extensions.get_type_hints(origin) + param_subs = typing.cast(Substitutions, dict(zip(type_params, args))) return {field: substitute(hint, param_subs) for field, hint in hints.items()} else: - hints = typing.get_type_hints(typ) + hints = typing_extensions.get_type_hints(typ) # For classes like Derived(Base[int]), resolve unsubstituted TypeVars # from parameterized bases. base_param_subs: dict[TypeVariable, TypeExpression] = { @@ -678,9 +679,7 @@ def _sig_to_type(sig: inspect.Signature) -> TypeExpression: annotations: dict[str, TypeExpression] = { "return": typing.NotRequired[ # type: ignore[dict-item] - typing.ReadOnly[replacer.evaluate(sig.return_annotation)] - if sys.version_info >= (3, 13) - else replacer.evaluate(sig.return_annotation) + typing_extensions.ReadOnly[replacer.evaluate(sig.return_annotation)] ] } for name, param in sig.parameters.items(): @@ -698,11 +697,9 @@ def _sig_to_type(sig: inspect.Signature) -> TypeExpression: field = typing.NotRequired[ann] # type: ignore[assignment] else: field = ann - if sys.version_info >= (3, 13): - field = typing.ReadOnly[field] # type: ignore - annotations[name] = field + annotations[name] = typing_extensions.ReadOnly[field] # type: ignore[assignment] - return typing.TypedDict(f"{sig}_Type", annotations) # type: ignore[operator] + return typing_extensions.TypedDict(f"{sig}_Type", annotations) # type: ignore[operator] def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: @@ -757,7 +754,9 @@ def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: elif param.kind == inspect.Parameter.VAR_KEYWORD and isinstance( psubtyp, collections.abc.Mapping ): - annotations[name] = typing.TypedDict(f"{name}BoundKwargs", psubtyp) # type: ignore[operator] + annotations[name] = typing_extensions.TypedDict( + f"{name}BoundKwargs", psubtyp + ) # type: ignore[operator] elif param.kind not in { inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL, @@ -771,7 +770,7 @@ def _bound_sig_to_type(bound_sig: inspect.BoundArguments) -> TypeExpression: f"Cannot unify parameter {param} with argument {psubtyp} in signature unification." ) - return typing.TypedDict("BoundSigType", annotations) # type: ignore[operator] + return typing_extensions.TypedDict("BoundSigType", annotations) # type: ignore[operator] def _freshen[T: TypeExpressions](tp: T) -> T: @@ -829,8 +828,8 @@ def _(typ: type | abc.ABCMeta): return collections.abc.Set elif typ is range: return collections.abc.Sequence[int] - elif typing.is_typeddict(typ): - hints = typing.get_type_hints(typ) + elif typing_extensions.is_typeddict(typ): + hints = typing_extensions.get_type_hints(typ) # Idempotency: if all field types are already canonical, return same object if all(canonicalize(h) == h for h in hints.values()): return typ @@ -842,12 +841,12 @@ def _(typ: type | abc.ABCMeta): canon_fields: dict[str, type] = {} for field, ftype in hints.items(): ct = canonicalize(ftype) - if field in readonly_keys and sys.version_info >= (3, 13): - ct = typing.ReadOnly[ct] # type: ignore + if field in readonly_keys: + ct = typing_extensions.ReadOnly[ct] # type: ignore[assignment] if field in optional_keys: ct = typing.NotRequired[ct] # type: ignore[assignment] canon_fields[field] = ct # type: ignore[assignment] - return typing.TypedDict(typ.__name__, canon_fields) # type: ignore[operator] + return typing_extensions.TypedDict(typ.__name__, canon_fields) # type: ignore[operator] elif typ is types.GeneratorType: return collections.abc.Generator elif typ in {types.FunctionType, types.BuiltinFunctionType, types.LambdaType}: @@ -1185,9 +1184,13 @@ def _(value: collections.abc.Mapping): elif len(value) == 1: ktyp = nested_type(next(iter(value.keys()))).value vtyp = nested_type(next(iter(value.values()))).value - if ktyp is str and isinstance(vtyp, type) and typing.is_typeddict(vtyp): + if ( + ktyp is str + and isinstance(vtyp, type) + and typing_extensions.is_typeddict(vtyp) + ): fields = {key: nested_type(vl).value for key, vl in value.items()} - return Box(typing.TypedDict("RuntimeTypeDict", fields)) # type: ignore + return Box(typing_extensions.TypedDict("RuntimeTypeDict", fields)) # type: ignore return Box(canonicalize(type(value))[ktyp, vtyp]) # type: ignore else: ktyp = functools.reduce( @@ -1196,7 +1199,7 @@ def _(value: collections.abc.Mapping): if ktyp is str: # str-keyed multi-entry dicts → always TypedDict fields = {key: nested_type(vl).value for key, vl in value.items()} - return Box(typing.TypedDict("RuntimeTypeDict", fields)) # type: ignore + return Box(typing_extensions.TypedDict("RuntimeTypeDict", fields)) # type: ignore vtyp = functools.reduce( operator.or_, [nested_type(x).value for x in value.values()] ) @@ -1398,12 +1401,12 @@ def _substitute_typeddict[T: TypedDictType](typ: T, subs: Substitutions) -> T: new_ftype = substitute(ftype, subs) if new_ftype is not ftype: changed = True - if field in readonly_keys and sys.version_info >= (3, 13): - new_ftype = typing.ReadOnly[new_ftype] # type: ignore + if field in readonly_keys: + new_ftype = typing_extensions.ReadOnly[new_ftype] # type: ignore[assignment] if field in optional_keys: new_ftype = typing.NotRequired[new_ftype] # type: ignore[assignment] new_fields[field] = new_ftype if not changed: return typ name = getattr(typ_origin, "__name__", "_Subbed") - return typing.TypedDict(name, new_fields) # type: ignore[operator] + return typing_extensions.TypedDict(name, new_fields) # type: ignore[operator] diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index f00890fb7..fe8849746 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -6,6 +6,7 @@ from typing import Literal import pytest +import typing_extensions from effectful.internals.unification import ( Box, @@ -198,9 +199,9 @@ class MyTD(typing.TypedDict): # Fields are already canonical types result = canonicalize(MyTD) - assert typing.is_typeddict(result) + assert typing_extensions.is_typeddict(result) - hints = typing.get_type_hints(result) + hints = typing_extensions.get_type_hints(result) assert hints["name"] is str assert hints["age"] is int @@ -210,8 +211,8 @@ class MyTD2(typing.TypedDict): # list[int] -> MutableSequence[int], so a new TypedDict is created result = canonicalize(MyTD2) - assert typing.is_typeddict(result) - hints = typing.get_type_hints(result) + assert typing_extensions.is_typeddict(result) + hints = typing_extensions.get_type_hints(result) assert hints["name"] is str assert hints["items"] == collections.abc.MutableSequence[int] @@ -935,8 +936,8 @@ def test_nested_type_typeddict_str_keys_mixed_values(): value = {"name": "Alice", "age": 30} result = nested_type(value).value # Should be a TypedDict, not dict - assert typing.is_typeddict(result) - hints = typing.get_type_hints(result) + assert typing_extensions.is_typeddict(result) + hints = typing_extensions.get_type_hints(result) assert hints == {"name": str, "age": int} @@ -944,8 +945,8 @@ def test_nested_type_typeddict_multiple_value_types(): """TypedDict with more than two distinct value types.""" value = {"label": "x", "count": 5, "flag": True} result = nested_type(value).value - assert typing.is_typeddict(result) - hints = typing.get_type_hints(result) + assert typing_extensions.is_typeddict(result) + hints = typing_extensions.get_type_hints(result) assert hints == {"label": str, "count": int, "flag": bool} @@ -953,8 +954,8 @@ def test_nested_type_typeddict_nested_values(): """TypedDict with nested collection values.""" value = {"items": [1, 2, 3], "name": "test"} result = nested_type(value).value - assert typing.is_typeddict(result) - hints = typing.get_type_hints(result) + assert typing_extensions.is_typeddict(result) + hints = typing_extensions.get_type_hints(result) assert canonicalize(hints["items"]) == canonicalize(list[int]) assert hints["name"] is str @@ -968,21 +969,21 @@ class UserTD(typing.TypedDict): value = UserTD(name="a", age=1) result = nested_type(value).value - assert typing.is_typeddict(result) - hints = typing.get_type_hints(result) + assert typing_extensions.is_typeddict(result) + hints = typing_extensions.get_type_hints(result) assert hints == {"name": str, "age": int} def test_nested_type_typeddict_homogeneous_str_keys(): """Multi-key str dicts produce TypedDict even with homogeneous value types.""" result = nested_type({"a": 1, "b": 2}).value - assert typing.is_typeddict(result) - hints = typing.get_type_hints(result) + assert typing_extensions.is_typeddict(result) + hints = typing_extensions.get_type_hints(result) assert hints == {"a": int, "b": int} result = nested_type({"a": {1, 2}, "b": {3, 4}}).value - assert typing.is_typeddict(result) - hints = typing.get_type_hints(result) + assert typing_extensions.is_typeddict(result) + hints = typing_extensions.get_type_hints(result) assert canonicalize(hints["a"]) == canonicalize(set[int]) assert canonicalize(hints["b"]) == canonicalize(set[int]) @@ -992,7 +993,7 @@ def test_nested_type_non_str_keys_mixed_values_stays_dict(): value = {1: "one", 2: True} result = nested_type(value).value # Should remain a plain dict type, not a TypedDict - assert not typing.is_typeddict(result) + assert not typing_extensions.is_typeddict(result) assert result is dict diff --git a/tests/test_internals_unification_typeddict.py b/tests/test_internals_unification_typeddict.py index 7fbc370b7..b07362015 100644 --- a/tests/test_internals_unification_typeddict.py +++ b/tests/test_internals_unification_typeddict.py @@ -1,19 +1,14 @@ -"""TypedDict unification tests using Required/NotRequired/ReadOnly annotations. +"""TypedDict unification tests using Required/NotRequired/typing_extensions.ReadOnly annotations. Separated from test_internals_unification.py because mypy 1.19 cannot serialize RequiredType instances to its cache, causing a crash. """ import collections.abc -import sys import typing import pytest - -if sys.version_info >= (3, 13): - from typing import ReadOnly -else: - from typing_extensions import ReadOnly +import typing_extensions from effectful.internals.unification import ( canonicalize, @@ -193,7 +188,7 @@ class TD(typing.TypedDict): items: typing.NotRequired[list[int]] result = canonicalize(TD) - assert typing.is_typeddict(result) + assert typing_extensions.is_typeddict(result) assert "items" in result.__optional_keys__ @@ -242,29 +237,25 @@ class Sub(typing.TypedDict): unify(Pattern, Sub) -@pytest.mark.skipif( - sys.version_info < (3, 13), reason="ReadOnly TypedDict requires 3.13+" -) def test_unify_typeddict_readonly_covariance(): - """ReadOnly field allows covariant subtyping.""" - Pattern = typing.TypedDict("Pattern", {"x": ReadOnly[int]}) # noqa: UP013 - Sub = typing.TypedDict("Sub", {"x": ReadOnly[bool]}) # noqa: UP013 + """typing_extensions.ReadOnly field allows covariant subtyping.""" + Pattern = typing_extensions.TypedDict( # noqa: UP013 + "Pattern", {"x": typing_extensions.ReadOnly[int]} + ) + Sub = typing_extensions.TypedDict("Sub", {"x": typing_extensions.ReadOnly[bool]}) # noqa: UP013 - # bool is subtype of int, ReadOnly allows covariance + # bool is subtype of int, typing_extensions.ReadOnly allows covariance subs = unify(Pattern, Sub) assert subs == {} -@pytest.mark.skipif( - sys.version_info < (3, 13), reason="ReadOnly TypedDict requires 3.13+" -) def test_unify_typeddict_readonly_notrequired_to_required(): - """ReadOnly NotRequired in typ, Required in subtyp → OK (promotion).""" - Pattern = typing.TypedDict( # noqa: UP013 + """typing_extensions.ReadOnly NotRequired in typ, Required in subtyp → OK (promotion).""" + Pattern = typing_extensions.TypedDict( # noqa: UP013 "Pattern", - {"x": ReadOnly[typing.NotRequired[str]]}, + {"x": typing_extensions.ReadOnly[typing.NotRequired[str]]}, ) - Sub = typing.TypedDict("Sub", {"x": str}) # noqa: UP013 + Sub = typing_extensions.TypedDict("Sub", {"x": str}) # noqa: UP013 subs = unify(Pattern, Sub) assert subs == {}