diff --git a/docs/source/effectful.rst b/docs/source/effectful.rst
index 6d33542a..5ef25be8 100644
--- a/docs/source/effectful.rst
+++ b/docs/source/effectful.rst
@@ -98,3 +98,7 @@ Internals
 .. automodule:: effectful.internals.runtime
    :members:
    :undoc-members:
+
+.. automodule:: effectful.internals.unification
+   :members:
+   :undoc-members:
diff --git a/docs/source/semi_ring.py b/docs/source/semi_ring.py
index f4425a84..c60dabf0 100644
--- a/docs/source/semi_ring.py
+++ b/docs/source/semi_ring.py
@@ -62,17 +62,17 @@ def Let[S, T, A, B](
 
 
 @defop
-def Record[T](**kwargs: T) -> dict[str, T]:
+def Record[T](**kwargs: T) -> collections.abc.Mapping[str, T]:
     raise NotImplementedError
 
 
 @defop
-def Field[T](record: dict[str, T], key: str) -> T:
+def Field[T](record: collections.abc.Mapping[str, T], key: str) -> T:
     raise NotImplementedError
 
 
 @defop
-def Dict[K, V](*contents: Union[K, V]) -> SemiRingDict[K, V]:
+def Dict[K, V](*contents: tuple[K, V]) -> SemiRingDict[K, V]:
     raise NotImplementedError
 
 
@@ -92,20 +92,14 @@ def add[T](x: T, y: T) -> T:
 ops.Field = Field
 
 
-def eager_dict[K, V](*contents: Tuple[K, V]) -> SemiRingDict[K, V]:
-    if not any(isinstance(v, Term) for v in contents):
-        if len(contents) % 2 != 0:
-            raise ValueError("Dict requires an even number of arguments")
-
-        kv = []
-        for i in range(0, len(contents), 2):
-            kv.append((contents[i], contents[i + 1]))
-        return SemiRingDict(kv)
+def eager_dict[K, V](*contents: tuple[K, V]) -> SemiRingDict[K, V]:
+    if not any(isinstance(v, Term) for kv in contents for v in kv):
+        return SemiRingDict(list(contents))
     else:
         return fwd()
 
 
-def eager_record[T](**kwargs: T) -> dict[str, T]:
+def eager_record[T](**kwargs: T) -> collections.abc.Mapping[str, T]:
     if not any(isinstance(v, Term) for v in kwargs.values()):
         return dict(**kwargs)
     else:
@@ -215,7 +209,7 @@ def vertical_fusion[S, T](e1: T, x: Operation[[], T], e2: S) -> S:
     )
 
 term: SemiRingDict[int, int] = Let(
-    Sum(x(), k, v, Dict(k(), v() + 1)), y, Sum(y(), k, v, Dict(k(), v() + 1))
+    Sum(x(), k, v, Dict((k(), v() + 1))), y, Sum(y(), k, v, Dict((k(), v() + 1)))
 )
 
 print("Without optimization:", term)
diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py
new file mode 100644
index 00000000..ff9eebb2
--- /dev/null
+++ b/effectful/internals/unification.py
@@ -0,0 +1,900 @@
+"""Type unification and inference utilities for Python's generic type system.
+
+This module implements a unification algorithm for type inference over a subset of
+Python's generic types. Unification is a fundamental operation in type systems that
+finds substitutions for type variables to make two types equivalent.
+
+The module provides four main operations:
+
+1. **unify(typ, subtyp, subs={})**: The core unification algorithm that attempts to
+   find a substitution mapping for type variables that makes a pattern type equal to
+   a concrete type. It handles TypeVars, generic types (List[T], Dict[K,V]), unions,
+   callables, and function signatures with inspect.Signature/BoundArguments.
+
+2. **substitute(typ, subs)**: Applies a substitution mapping to a type expression,
+   replacing all TypeVars with their mapped concrete types. This is used to
+   instantiate generic types after unification.
+
+3. **freetypevars(typ)**: Extracts all free (unbound) type variables from a type
+   expression. Useful for analyzing generic types and ensuring all TypeVars are
+   properly bound.
+
+4. **nested_type(value)**: Infers the type of a runtime value, handling nested
+   collections by recursively determining element types. For example, [1, 2, 3]
+   becomes list[int], and {"key": [1, 2]} becomes dict[str, list[int]].
+
+The unification algorithm uses a single-dispatch pattern to handle different type
+combinations:
+- TypeVar unification binds variables to concrete types
+- Generic type unification matches origins and recursively unifies type arguments
+- Structural unification handles sequences and mappings by element
+- Union types attempt unification with any matching branch
+- Function signatures unify parameter types with bound arguments
+
+Example usage:
+    >>> from effectful.internals.unification import unify, substitute, freetypevars
+    >>> import typing
+    >>> T = typing.TypeVar('T')
+    >>> K = typing.TypeVar('K')
+    >>> V = typing.TypeVar('V')
+
+    >>> # Find substitution that makes list[T] equal to list[int]
+    >>> subs = unify(list[T], list[int])
+    >>> subs
+    {~T: <class 'int'>}
+
+    >>> # Apply substitution to instantiate a generic type
+    >>> substitute(dict[K, list[V]], {K: str, V: int})
+    dict[str, list[int]]
+
+    >>> # Find all type variables in a type expression
+    >>> freetypevars(dict[str, list[V]])
+    {~V}
+
+This module is primarily used internally by effectful for type inference in its
+effect system, allowing it to track and propagate type information through
+effect handlers and operations.
+"""
+
+import abc
+import builtins
+import collections
+import collections.abc
+import functools
+import inspect
+import numbers
+import operator
+import types
+import typing
+
+try:
+    from typing import _collect_type_parameters as _freetypevars  # type: ignore
+except ImportError:
+    from typing import _collect_parameters as _freetypevars  # type: ignore
+
+import effectful.ops.types
+
+if typing.TYPE_CHECKING:
+    TypeConstant = type | abc.ABCMeta | types.EllipsisType | None
+    GenericAlias = types.GenericAlias
+    UnionType = types.UnionType
+else:
+    TypeConstant = (
+        type | abc.ABCMeta | types.EllipsisType | type(None) | type(typing.Any)
+    )
+    GenericAlias = types.GenericAlias | typing._GenericAlias
+    UnionType = types.UnionType | typing._UnionGenericAlias
+
+TypeVariable = typing.TypeVar | typing.TypeVarTuple | typing.ParamSpec
+TypeApplication = GenericAlias | UnionType
+TypeExpression = TypeVariable | TypeConstant | TypeApplication
+TypeExpressions = TypeExpression | collections.abc.Sequence[TypeExpression]
+
+Substitutions = collections.abc.Mapping[TypeVariable, TypeExpressions]
+
+
+@typing.overload
+def unify(
+    typ: inspect.Signature,
+    subtyp: inspect.BoundArguments,
+    subs: Substitutions = {},
+) -> Substitutions: ...
+
+
+@typing.overload
+def unify(
+    typ: TypeExpressions,
+    subtyp: TypeExpressions,
+    subs: Substitutions = {},
+) -> Substitutions: ...
+
+
+def unify(typ, subtyp, subs: Substitutions = {}) -> Substitutions:
+    """
+    Unify a pattern type with a concrete type, returning a substitution map.
+
+    This function attempts to find a substitution of type variables that makes
+    the pattern type (typ) equal to the concrete type (subtyp). It updates
+    and returns the substitution mapping, or raises TypeError if unification
+    is not possible.
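+
+    In addition to plain type expressions, unify accepts an inspect.Signature
+    pattern together with an inspect.BoundArguments whose arguments are the
+    *types* of a call's actual arguments. A minimal illustrative sketch (the
+    function ``first`` below is a hypothetical example, not part of the API):
+
+        >>> import inspect
+        >>> def first[T](xs: list[T]) -> T: ...
+        >>> sig = inspect.signature(first)
+        >>> unify(sig, sig.bind(list[int]))
+        {~T: <class 'int'>}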
+
+    The function handles:
+    - TypeVar unification (binding type variables to concrete types)
+    - Generic type unification (matching origins and recursively unifying args)
+    - Structural unification of sequences and mappings
+    - Exact type matching for non-generic types
+
+    Args:
+        typ: The pattern type that may contain TypeVars to be unified
+        subtyp: The concrete type to unify with the pattern
+        subs: Existing substitution mappings to be extended (not modified)
+
+    Returns:
+        A new substitution mapping that includes all previous substitutions
+        plus any new TypeVar bindings discovered during unification.
+
+    Raises:
+        TypeError: If unification is not possible (incompatible types or
+            conflicting TypeVar bindings)
+
+    Examples:
+        >>> import typing
+        >>> T = typing.TypeVar('T')
+        >>> K = typing.TypeVar('K')
+        >>> V = typing.TypeVar('V')
+
+        >>> # Simple TypeVar unification
+        >>> unify(T, int, {})
+        {~T: <class 'int'>}
+
+        >>> # Generic type unification
+        >>> unify(list[T], list[int], {})
+        {~T: <class 'int'>}
+
+        >>> # Exact type matching
+        >>> unify(int, int, {})
+        {}
+
+        >>> # Failed unification - incompatible types
+        >>> unify(list[T], dict[str, int], {})  # doctest: +ELLIPSIS
+        Traceback (most recent call last):
+          ...
+        TypeError: Cannot unify ...
+
+        >>> # Failed unification - conflicting TypeVar binding
+        >>> unify(T, str, {T: int})  # doctest: +ELLIPSIS
+        Traceback (most recent call last):
+          ...
+        TypeError: Cannot unify ...
+    """
+    if isinstance(typ, inspect.Signature):
+        return _unify_signature(typ, subtyp, subs)
+
+    if typ != canonicalize(typ) or subtyp != canonicalize(subtyp):
+        return unify(canonicalize(typ), canonicalize(subtyp), subs)
+
+    if typ is subtyp or typ == subtyp:
+        return subs
+    elif isinstance(typ, TypeVariable) or isinstance(subtyp, TypeVariable):
+        return _unify_typevar(typ, subtyp, subs)
+    elif isinstance(typ, collections.abc.Sequence) or isinstance(
+        subtyp, collections.abc.Sequence
+    ):
+        return _unify_sequence(typ, subtyp, subs)
+    elif isinstance(typ, UnionType) or isinstance(subtyp, UnionType):
+        return _unify_union(typ, subtyp, subs)
+    elif isinstance(typ, GenericAlias) or isinstance(subtyp, GenericAlias):
+        return _unify_generic(typ, subtyp, subs)
+    elif isinstance(typ, type) and isinstance(subtyp, type) and issubclass(subtyp, typ):
+        return subs
+    elif typ in (typing.Any, ...) or subtyp in (typing.Any, ...):
+        return subs
+    else:
+        raise TypeError(f"Cannot unify type {typ} with {subtyp} given {subs}.")
+
+
+@typing.overload
+def _unify_typevar(
+    typ: TypeVariable, subtyp: TypeExpression, subs: Substitutions
+) -> Substitutions: ...
+
+
+@typing.overload
+def _unify_typevar(
+    typ: TypeExpression, subtyp: TypeVariable, subs: Substitutions
+) -> Substitutions: ...
+
+
+def _unify_typevar(typ, subtyp, subs: Substitutions) -> Substitutions:
+    if isinstance(typ, TypeVariable) and isinstance(subtyp, TypeVariable):
+        return subs if typ == subtyp else {typ: subtyp, **subs}
+    elif isinstance(typ, TypeVariable) and not isinstance(subtyp, TypeVariable):
+        return unify(subs.get(typ, subtyp), subtyp, {typ: subtyp, **subs})
+    elif (
+        not isinstance(typ, TypeVariable)
+        and isinstance(subtyp, TypeVariable)
+        and getattr(subtyp, "__bound__", None) is None
+    ):
+        return unify(typ, subs.get(subtyp, typ), {subtyp: typ, **subs})
+    else:
+        raise TypeError(f"Cannot unify type variable {typ} with {subtyp} given {subs}.")
+
+
+@typing.overload
+def _unify_sequence(
+    typ: collections.abc.Sequence, subtyp: TypeExpressions, subs: Substitutions
+) -> Substitutions: ...
+
+
+@typing.overload
+def _unify_sequence(
+    typ: TypeExpressions, subtyp: collections.abc.Sequence, subs: Substitutions
+) -> Substitutions: ...
+
+
+def _unify_sequence(typ, subtyp, subs: Substitutions) -> Substitutions:
+    if isinstance(typ, types.EllipsisType) or isinstance(subtyp, types.EllipsisType):
+        return subs
+    if len(typ) != len(subtyp):
+        raise TypeError(f"Cannot unify sequence {typ} with {subtyp} given {subs}.")
+    for p_item, c_item in zip(typ, subtyp):
+        subs = unify(p_item, c_item, subs)
+    return subs
+
+
+@typing.overload
+def _unify_union(
+    typ: UnionType, subtyp: TypeExpression, subs: Substitutions
+) -> Substitutions: ...
+
+
+@typing.overload
+def _unify_union(
+    typ: TypeExpression, subtyp: UnionType, subs: Substitutions
+) -> Substitutions: ...
+
+
+def _unify_union(typ, subtyp, subs: Substitutions) -> Substitutions:
+    if typ == subtyp:
+        return subs
+    elif isinstance(subtyp, UnionType):
+        # If subtyp is a union, the pattern must unify with every branch of it
+        for arg in typing.get_args(subtyp):
+            subs = unify(typ, arg, subs)
+        return subs
+    elif isinstance(typ, UnionType):
+        # If the pattern is a union, any branch may match, but all branches
+        # that do match must agree on the resulting substitution
+        unifiers: list[Substitutions] = []
+        for arg in typing.get_args(typ):
+            try:
+                unifiers.append(unify(arg, subtyp, subs))
+            except TypeError:  # noqa
+                continue
+        if len(unifiers) > 0 and all(u == unifiers[0] for u in unifiers):
+            return unifiers[0]
+    raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}")
+
+
+@typing.overload
+def _unify_generic(
+    typ: GenericAlias, subtyp: type, subs: Substitutions
+) -> Substitutions: ...
+
+
+@typing.overload
+def _unify_generic(
+    typ: type, subtyp: GenericAlias, subs: Substitutions
+) -> Substitutions: ...
+
+
+@typing.overload
+def _unify_generic(
+    typ: GenericAlias, subtyp: GenericAlias, subs: Substitutions
+) -> Substitutions: ...
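+
+
+# A few notes on the dispatch below (summarizing the implementation, not
+# adding behavior): a tuple subtype matched against a non-tuple pattern is
+# unified element-wise, treating each element arg like tuple[arg, ...]; a
+# Mapping or Generator subtype matched against a non-Mapping/non-Generator
+# pattern exposes only its first type argument (iterating a mapping yields
+# keys, a generator yields values); otherwise equal origins unify their
+# argument lists, and as a last resort the subtype's original bases are
+# searched for a parameterized base compatible with the pattern's origin.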
+def _unify_generic(typ, subtyp, subs: Substitutions) -> Substitutions:
+    if (
+        isinstance(typ, GenericAlias)
+        and isinstance(subtyp, GenericAlias)
+        and issubclass(typing.get_origin(subtyp), typing.get_origin(typ))
+    ):
+        if typing.get_origin(subtyp) is tuple and typing.get_origin(typ) is not tuple:
+            for arg in typing.get_args(subtyp):
+                subs = unify(typ, tuple[arg, ...], subs)  # type: ignore
+            return subs
+        elif typing.get_origin(subtyp) is collections.abc.Mapping and not issubclass(
+            typing.get_origin(typ), collections.abc.Mapping
+        ):
+            return unify(typing.get_args(typ)[0], typing.get_args(subtyp)[0], subs)
+        elif typing.get_origin(subtyp) is collections.abc.Generator and not issubclass(
+            typing.get_origin(typ), collections.abc.Generator
+        ):
+            return unify(typing.get_args(typ)[0], typing.get_args(subtyp)[0], subs)
+        elif typing.get_origin(typ) == typing.get_origin(subtyp):
+            return unify(typing.get_args(typ), typing.get_args(subtyp), subs)
+        elif types.get_original_bases(typing.get_origin(subtyp)):
+            for base in types.get_original_bases(typing.get_origin(subtyp)):
+                if isinstance(base, type | GenericAlias) and issubclass(
+                    typing.get_origin(base) or base,  # type: ignore
+                    typing.get_origin(typ),
+                ):
+                    return unify(typ, base[typing.get_args(subtyp)], subs)  # type: ignore
+    elif isinstance(typ, type) and isinstance(subtyp, GenericAlias):
+        return unify(typ, typing.get_origin(subtyp), subs)
+    elif (
+        isinstance(typ, GenericAlias)
+        and isinstance(subtyp, type)
+        and issubclass(subtyp, typing.get_origin(typ))
+    ):
+        return subs  # implicit expansion to subtyp[Any]
+    raise TypeError(f"Cannot unify generic type {typ} with {subtyp} given {subs}.")
+
+
+def _unify_signature(
+    typ: inspect.Signature, subtyp: inspect.BoundArguments, subs: Substitutions
+) -> Substitutions:
+    if typ != subtyp.signature:
+        raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}.")
+
+    for name, param in typ.parameters.items():
+        if param.annotation is inspect.Parameter.empty:
+            continue
+
+        if name not in subtyp.arguments:
+            assert param.kind in {
+                inspect.Parameter.VAR_POSITIONAL,
+                inspect.Parameter.VAR_KEYWORD,
+            }
+            continue
+
+        ptyp, psubtyp = param.annotation, subtyp.arguments[name]
+        if param.kind is inspect.Parameter.VAR_POSITIONAL and isinstance(
+            psubtyp, collections.abc.Sequence
+        ):
+            for psubtyp_item in _freshen(psubtyp):
+                subs = unify(ptyp, psubtyp_item, subs)
+        elif param.kind is inspect.Parameter.VAR_KEYWORD and isinstance(
+            psubtyp, collections.abc.Mapping
+        ):
+            for psubtyp_item in _freshen(tuple(psubtyp.values())):
+                subs = unify(ptyp, psubtyp_item, subs)
+        elif param.kind not in {
+            inspect.Parameter.VAR_KEYWORD,
+            inspect.Parameter.VAR_POSITIONAL,
+        } or isinstance(psubtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs):
+            subs = unify(ptyp, _freshen(psubtyp), subs)
+        else:
+            raise TypeError(f"Cannot unify {param} with {psubtyp} given {subs}")
+    return subs
+
+
+def _freshen(tp: typing.Any):
+    """
+    Return a freshened version of the given type expression.
+
+    This function replaces every TypeVar in the type expression with a newly
+    created TypeVar that keeps the original name (and bound) but has a distinct
+    identity, so the result shares no type variables with the input. It is
+    useful for avoiding accidental variable capture when the same generic
+    signature is unified more than once.
+
+    Args:
+        tp: The type expression to freshen. Can be a plain type, TypeVar,
+            generic alias, or union type.
+
+    Returns:
+        A new type expression with all TypeVars replaced by fresh TypeVars.
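+
+    Note:
+        Because the fresh variables keep their original ``__name__``, a
+        freshened type prints the same as the input; only the identity of its
+        variables changes, which is why ``_freshen(T) == T`` is False below.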
+
+    Examples:
+        >>> import typing
+        >>> T = typing.TypeVar('T')
+        >>> isinstance(_freshen(T), typing.TypeVar)
+        True
+        >>> _freshen(T) == T
+        False
+    """
+    assert all(canonicalize(fv) is fv for fv in freetypevars(tp))
+    subs: Substitutions = {
+        fv: typing.TypeVar(fv.__name__, bound=fv.__bound__)
+        if isinstance(fv, typing.TypeVar)
+        else typing.ParamSpec(fv.__name__)
+        for fv in freetypevars(tp)
+        if isinstance(fv, typing.TypeVar | typing.ParamSpec)
+    }
+    return substitute(tp, subs)
+
+
+@functools.singledispatch
+def canonicalize(typ) -> TypeExpressions:
+    """
+    Normalize a type expression to a canonical form, mapping concrete builtin
+    containers to their abstract collections.abc equivalents (e.g. list[int]
+    becomes collections.abc.MutableSequence[int]).
+    """
+    raise TypeError(f"Cannot canonicalize type {typ}.")
+
+
+@canonicalize.register
+def _(typ: type | abc.ABCMeta):
+    if issubclass(typ, effectful.ops.types.Term):
+        return effectful.ops.types.Term
+    elif issubclass(typ, effectful.ops.types.Operation):
+        return effectful.ops.types.Operation
+    elif typ is dict:
+        return collections.abc.MutableMapping
+    elif typ is list:
+        return collections.abc.MutableSequence
+    elif typ is set:
+        return collections.abc.MutableSet
+    elif typ is frozenset:
+        return collections.abc.Set
+    elif typ is range:
+        return collections.abc.Sequence[int]
+    elif typ is types.GeneratorType:
+        return collections.abc.Generator
+    elif typ in {types.FunctionType, types.BuiltinFunctionType, types.LambdaType}:
+        return collections.abc.Callable[..., typing.Any]
+    elif isinstance(typ, abc.ABCMeta) and (
+        typ in collections.abc.__dict__.values() or typ in numbers.__dict__.values()
+    ):
+        return typ
+    elif isinstance(typ, type) and (
+        typ in builtins.__dict__.values() or typ in types.__dict__.values()
+    ):
+        return typ
+    elif types.get_original_bases(typ):
+        for base in types.get_original_bases(typ):
+            cbase = canonicalize(base)
+            if cbase != object:
+                return cbase
+        return typ
+    else:
+        raise TypeError(f"Cannot canonicalize type {typ}.")
+
+
+@canonicalize.register
+def _(typ: types.EllipsisType | None):
+    return typ
+
+
+@canonicalize.register
+def _(typ: typing.TypeVar):
+    if (
+        typ.__constraints__
+        or typ.__covariant__
+        or typ.__contravariant__
+        or getattr(typ, "__default__", None) is not getattr(typing, "NoDefault", None)
+    ):
+        raise TypeError(f"Cannot canonicalize typevar {typ} with nonempty attributes")
+    return typ
+
+
+@canonicalize.register
+def _(typ: typing.ParamSpec):
+    if (
+        typ.__bound__
+        or typ.__covariant__
+        or typ.__contravariant__
+        or getattr(typ, "__default__", None) is not getattr(typing, "NoDefault", None)
+    ):
+        raise TypeError(f"Cannot canonicalize typevar {typ} with nonempty attributes")
+    return typ
+
+
+@canonicalize.register
+def _(typ: typing.TypeVarTuple):
+    if getattr(typ, "__default__", None) is not getattr(typing, "NoDefault", None):
+        raise TypeError(f"Cannot canonicalize typevar {typ} with nonempty attributes")
+    return typ
+
+
+@canonicalize.register
+def _(typ: UnionType):
+    ctyp = canonicalize(typing.get_args(typ)[0])
+    for arg in typing.get_args(typ)[1:]:
+        ctyp = ctyp | canonicalize(arg)  # type: ignore
+    return ctyp
+
+
+@canonicalize.register
+def _(typ: GenericAlias):
+    origin, args = typing.get_origin(typ), typing.get_args(typ)
+    if origin is tuple and len(args) == 2 and args[-1] is Ellipsis:  # Variadic tuple
+        return collections.abc.Sequence[canonicalize(args[0])]  # type: ignore
+    elif isinstance(origin, typing._SpecialForm):
+        if len(args) == 1:
+            return canonicalize(args[0])
+        else:
+            raise TypeError(f"Cannot canonicalize type {typ}")
+    else:
+        return canonicalize(origin)[tuple(canonicalize(a) for a in args)]  # type: ignore
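+
+
+# Illustrative round-trips (mirroring the unit tests below, not new behavior):
+# canonicalize(dict[str, set[int]]) yields
+# collections.abc.MutableMapping[str, collections.abc.MutableSet[int]], and
+# canonicalize(tuple[int, ...]) yields collections.abc.Sequence[int].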
+
+
+@canonicalize.register
+def _(typ: list | tuple):
+    return type(typ)(canonicalize(item) for item in typ)
+
+
+@canonicalize.register
+def _(typ: effectful.ops.types._InterpretationMeta):
+    return typ
+
+
+@canonicalize.register
+def _(typ: typing._AnnotatedAlias):  # type: ignore
+    return canonicalize(typing.get_args(typ)[0])
+
+
+@canonicalize.register
+def _(typ: typing._SpecialGenericAlias):  # type: ignore
+    assert not typing.get_args(typ), "Should not have type arguments"
+    return canonicalize(typing.get_origin(typ))
+
+
+@canonicalize.register
+def _(typ: typing._LiteralGenericAlias):  # type: ignore
+    return canonicalize(nested_type(typing.get_args(typ)[0]))
+
+
+@canonicalize.register
+def _(typ: typing.NewType):
+    return canonicalize(typ.__supertype__)
+
+
+@canonicalize.register
+def _(typ: typing.TypeAliasType):
+    return canonicalize(typ.__value__)
+
+
+@canonicalize.register
+def _(typ: typing._ConcatenateGenericAlias):  # type: ignore
+    return Ellipsis
+
+
+@canonicalize.register
+def _(typ: typing._AnyMeta):  # type: ignore
+    return typing.Any
+
+
+@canonicalize.register
+def _(typ: typing.ParamSpecArgs | typing.ParamSpecKwargs):
+    return typing.Any
+
+
+@canonicalize.register
+def _(typ: typing._SpecialForm):
+    return typing.Any
+
+
+@canonicalize.register
+def _(typ: typing._ProtocolMeta):
+    return typing.Any
+
+
+@canonicalize.register
+def _(typ: typing._UnpackGenericAlias):  # type: ignore
+    raise TypeError(f"Cannot canonicalize type {typ}")
+
+
+@canonicalize.register
+def _(typ: typing.ForwardRef):
+    if typ.__forward_value__ is not None:
+        return canonicalize(typ.__forward_value__)
+    else:
+        raise TypeError(f"Cannot canonicalize lazy ForwardRef {typ}.")
+
+
+@functools.singledispatch
+def nested_type(value) -> TypeExpression:
+    """
+    Infer the type of a value, handling nested collections with generic parameters.
+
+    This function is a singledispatch generic function that determines the type
+    of a given value. For collections (mappings, sequences, sets), it recursively
+    infers the types of contained elements to produce a properly parameterized
+    generic type. For example, a list [1, 2, 3] becomes Sequence[int].
+
+    The function handles:
+    - Basic types and type annotations (passed through unchanged)
+    - Collections with recursive type inference for elements
+    - Special cases like str/bytes (treated as types, not sequences)
+    - Tuples (preserving exact element types)
+    - Empty collections (returning the collection's type without parameters)
+
+    This is primarily used by canonicalize() to handle cases where values
+    are provided instead of type annotations.
+
+    Args:
+        value: Any value whose type needs to be inferred. Can be a type,
+            a value instance, or a collection containing other values.
+
+    Returns:
+        The inferred type, potentially with generic parameters for collections.
+
+    Raises:
+        TypeError: If the value is a Term from effectful.ops.types; Terms are
+            syntax objects and should not appear inside concrete values.
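+
+    Note:
+        Heterogeneous collections whose elements do not share a single type
+        fall back to the bare container type: a list mixing int and str is
+        typed as plain list, not as a union-parameterized sequence.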
+
+    Examples:
+        >>> import collections.abc
+        >>> import typing
+        >>> from effectful.internals.unification import nested_type
+
+        # Basic types are returned as their type
+        >>> nested_type(42)
+        <class 'int'>
+        >>> nested_type("hello")
+        <class 'str'>
+        >>> nested_type(3.14)
+        <class 'float'>
+        >>> nested_type(True)
+        <class 'bool'>
+
+        # Type objects pass through unchanged
+        >>> nested_type(int)
+        <class 'int'>
+        >>> nested_type(str)
+        <class 'str'>
+        >>> nested_type(list)
+        <class 'list'>
+
+        # Empty collections return their base type
+        >>> nested_type([])
+        <class 'list'>
+        >>> nested_type({})
+        <class 'dict'>
+        >>> nested_type(set())
+        <class 'set'>
+
+        # Sequences become Sequence[element_type]
+        >>> nested_type([1, 2, 3])
+        collections.abc.MutableSequence[int]
+        >>> nested_type(["a", "b", "c"])
+        collections.abc.MutableSequence[str]
+
+        # Tuples preserve exact structure
+        >>> nested_type((1, "hello", 3.14))
+        tuple[int, str, float]
+        >>> nested_type(())
+        <class 'tuple'>
+        >>> nested_type((1,))
+        tuple[int]
+
+        # Sets become Set[element_type]
+        >>> nested_type({1, 2, 3})
+        collections.abc.MutableSet[int]
+        >>> nested_type({"a", "b"})
+        collections.abc.MutableSet[str]
+
+        # Mappings become Mapping[key_type, value_type]
+        >>> nested_type({"key": "value"})
+        collections.abc.MutableMapping[str, str]
+        >>> nested_type({1: "one", 2: "two"})
+        collections.abc.MutableMapping[int, str]
+
+        # Strings and bytes are NOT treated as sequences
+        >>> nested_type("hello")
+        <class 'str'>
+        >>> nested_type(b"bytes")
+        <class 'bytes'>
+
+        # Annotated functions return types derived from their annotations
+        >>> def annotated_func(x: int) -> str:
+        ...     return str(x)
+        >>> nested_type(annotated_func)
+        collections.abc.Callable[[int], str]
+
+        # Unannotated functions/callables return their type
+        >>> def f(): pass
+        >>> nested_type(f)
+        <class 'function'>
+        >>> nested_type(lambda x: x)
+        <class 'function'>
+
+        # Generic aliases and union types pass through
+        >>> nested_type(list[int])
+        list[int]
+        >>> nested_type(int | str)
+        int | str
+    """
+    return type(value)
+
+
+@nested_type.register
+def _(value: TypeExpression):
+    return value
+
+
+@nested_type.register
+def _(value: effectful.ops.types.Term):
+    raise TypeError(f"Terms should not appear in nested_type, but got {value}")
+
+
+@nested_type.register
+def _(value: effectful.ops.types.Operation):
+    typ = nested_type.dispatch(collections.abc.Callable)(value)
+    (arg_types, return_type) = typing.get_args(typ)
+    return effectful.ops.types.Operation[arg_types, return_type]  # type: ignore
+
+
+@nested_type.register
+def _(value: collections.abc.Callable):
+    if typing.get_overloads(value):
+        return type(value)
+
+    try:
+        sig = inspect.signature(value)
+    except ValueError:
+        return type(value)
+
+    if sig.return_annotation is inspect.Signature.empty:
+        return type(value)
+    elif any(
+        p.annotation is inspect.Parameter.empty
+        or p.kind
+        in {
+            inspect.Parameter.VAR_POSITIONAL,
+            inspect.Parameter.VAR_KEYWORD,
+            inspect.Parameter.KEYWORD_ONLY,
+        }
+        for p in sig.parameters.values()
+    ):
+        return collections.abc.Callable[..., sig.return_annotation]
+    else:
+        return collections.abc.Callable[
+            [p.annotation for p in sig.parameters.values()], sig.return_annotation
+        ]
+
+
+@nested_type.register
+def _(value: collections.abc.Mapping):
+    if value and isinstance(value, effectful.ops.types.Interpretation):
+        return effectful.ops.types.Interpretation
+
+    if len(value) == 0:
+        return type(value)
+    elif len(value) == 1:
+        ktyp = nested_type(next(iter(value.keys())))
+        vtyp = nested_type(next(iter(value.values())))
+        return canonicalize(type(value))[ktyp, vtyp]  # type: ignore
+    else:
+        ktyp = functools.reduce(operator.or_, map(nested_type, value.keys()))
+        vtyp = functools.reduce(operator.or_, map(nested_type, value.values()))
+        if isinstance(ktyp, UnionType) or isinstance(vtyp, UnionType):
+            return type(value)
+        else:
+            return canonicalize(type(value))[ktyp, vtyp]  # type: ignore
+
+
+@nested_type.register
+def _(value: collections.abc.Collection):
+    if len(value) == 0:
+        return type(value)
+    elif len(value) == 1:
+        vtyp = nested_type(next(iter(value)))
+        return canonicalize(type(value))[vtyp]  # type: ignore
+    else:
+        valtyp = functools.reduce(operator.or_, map(nested_type, value))
+        if isinstance(valtyp, UnionType):
+            return type(value)
+        else:
+            return canonicalize(type(value))[valtyp]  # type: ignore
+
+
+@nested_type.register
+def _(value: tuple):
+    return (
+        nested_type.dispatch(collections.abc.Sequence)(value)
+        if type(value) != tuple or len(value) == 0
+        else tuple[tuple(nested_type(item) for item in value)]  # type: ignore
+    )
+
+
+@nested_type.register
+def _(value: str | bytes | range | None):
+    return type(value)
+
+
+def freetypevars(typ) -> collections.abc.Set[TypeVariable]:
+    """
+    Return a set of free type variables in the given type expression.
+
+    This function recursively traverses a type expression to find all TypeVar
+    instances that appear within it. It handles both simple types and generic
+    type aliases with nested type arguments. TypeVars are considered "free"
+    when they are not bound to a specific concrete type.
+
+    Args:
+        typ: The type expression to analyze. Can be a plain type (e.g., int),
+            a TypeVar, or a generic type alias (e.g., List[T], Dict[K, V]).
+
+    Returns:
+        A set containing all TypeVar instances found in the type expression.
+        Returns an empty set if no TypeVars are present.
+
+    Examples:
+        >>> import typing
+        >>> T = typing.TypeVar('T')
+        >>> K = typing.TypeVar('K')
+        >>> V = typing.TypeVar('V')
+
+        >>> # TypeVar returns itself
+        >>> freetypevars(T)
+        {~T}
+
+        >>> # Generic type with one TypeVar
+        >>> freetypevars(list[T])
+        {~T}
+
+        >>> # Generic type with multiple TypeVars
+        >>> freetypevars(dict[K, V]) == {K, V}
+        True
+
+        >>> # Nested generic types
+        >>> freetypevars(list[dict[K, V]]) == {K, V}
+        True
+
+        >>> # Concrete types have no free TypeVars
+        >>> freetypevars(int)
+        set()
+
+        >>> # Generic types with concrete arguments have no free TypeVars
+        >>> freetypevars(list[int])
+        set()
+
+        >>> # Mixed concrete and TypeVar arguments
+        >>> freetypevars(dict[str, T])
+        {~T}
+    """
+    return set(_freetypevars((typ,)))
+
+
+def substitute(typ, subs: Substitutions) -> TypeExpressions:
+    """
+    Substitute type variables in a type expression with concrete types.
+
+    This function recursively traverses a type expression and replaces any TypeVar
+    instances found with their corresponding concrete types from the substitution
+    mapping. If a TypeVar is not present in the substitution mapping, it remains
+    unchanged. The function handles nested generic types by recursively substituting
+    in their type arguments.
+
+    Args:
+        typ: The type expression to perform substitution on. Can be a plain type,
+            a TypeVar, or a generic type alias (e.g., List[T], Dict[K, V]).
+        subs: A mapping from TypeVar instances to concrete types that should
+            replace them.
+
+    Returns:
+        A new type expression with all mapped TypeVars replaced by their
+        corresponding concrete types.
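+
+    Note:
+        Substitution is applied recursively to mapped values, so chained
+        substitutions resolve fully: substitute(T, {T: V, V: int}) returns int.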
+
+    Examples:
+        >>> import typing
+        >>> T = typing.TypeVar('T')
+        >>> K = typing.TypeVar('K')
+        >>> V = typing.TypeVar('V')
+
+        >>> # Simple TypeVar substitution
+        >>> substitute(T, {T: int})
+        <class 'int'>
+
+        >>> # Generic type substitution
+        >>> substitute(list[T], {T: str})
+        list[str]
+
+        >>> # Nested generic substitution
+        >>> substitute(dict[K, list[V]], {K: str, V: int})
+        dict[str, list[int]]
+
+        >>> # TypeVar not in mapping remains unchanged
+        >>> substitute(T, {K: int})
+        ~T
+
+        >>> # Non-generic types pass through unchanged
+        >>> substitute(int, {T: str})
+        <class 'int'>
+    """
+    if isinstance(typ, typing.TypeVar | typing.ParamSpec | typing.TypeVarTuple):
+        return substitute(subs[typ], subs) if typ in subs else typ
+    elif isinstance(typ, list | tuple):
+        return type(typ)(substitute(item, subs) for item in typ)
+    elif any(fv in subs for fv in freetypevars(typ)):
+        args = tuple(subs.get(fv, fv) for fv in _freetypevars((typ,)))
+        return substitute(typ[args], subs)
+    else:
+        return typ
diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py
index 5879f0a4..9fea27d6 100644
--- a/effectful/ops/semantics.py
+++ b/effectful/ops/semantics.py
@@ -1,5 +1,7 @@
 import contextlib
 import functools
+import types
+import typing
 from collections.abc import Callable
 from typing import Any
 
@@ -285,7 +287,24 @@ def typeof[T](term: Expr[T]) -> type[T]:
     from effectful.internals.runtime import interpreter
 
     with interpreter({apply: lambda _, op, *a, **k: op.__type_rule__(*a, **k)}):
-        return evaluate(term) if isinstance(term, Term) else type(term)  # type: ignore
+        if isinstance(term, Term):
+            # If term is a Term, we evaluate it to get its type
+            tp = evaluate(term)
+            if isinstance(tp, typing.TypeVar):
+                tp = (
+                    tp.__bound__
+                    if tp.__bound__
+                    else tp.__constraints__[0]
+                    if tp.__constraints__
+                    else object
+                )
+            if isinstance(tp, types.UnionType):
+                raise TypeError(
+                    f"Cannot determine type of {term} because it is a union type: {tp}"
+                )
+            return typing.get_origin(tp) or tp  # type: ignore
+        else:
+            return type(term)
 
 
 def fvsof[S](term: Expr[S]) -> set[Operation]:
diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py
index 99d6edaf..8a39eb9e 100644
--- a/effectful/ops/syntax.py
+++ b/effectful/ops/syntax.py
@@ -588,48 +588,28 @@ def __fvs_rule__(
         return tuple(result_sig.args), dict(result_sig.kwargs)
 
     def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]:
-        def unwrap_annotation(typ):
-            """Unwrap Annotated types."""
-            return (
-                typing.get_args(typ)[0] if typing.get_origin(typ) is Annotated else typ
-            )
-
-        def drop_params(typ):
-            """Strip parameters from polymorphic types."""
-            origin = typing.get_origin(typ)
-            return typ if origin is None else origin
-
-        sig = self.__signature__
-        bound_sig = sig.bind(*args, **kwargs)
-        bound_sig.apply_defaults()
-
-        anno = sig.return_annotation
-        anno = unwrap_annotation(anno)
-
-        if anno is None:
-            return typing.cast(type[V], type(None))
-
-        if anno is inspect.Signature.empty:
-            return typing.cast(type[V], object)
+        from effectful.internals.unification import (
+            freetypevars,
+            nested_type,
+            substitute,
+            unify,
+        )
 
-        if isinstance(anno, typing.TypeVar):
-            # rudimentary but sound special-case type inference sufficient for syntax ops:
-            # if the return type annotation is a TypeVar,
-            # look for a parameter with the same annotation and return its type,
-            # otherwise give up and return Any/object
-            for name, param in bound_sig.signature.parameters.items():
-                param_typ = unwrap_annotation(param.annotation)
-                if param_typ is anno and param.kind not in (
-                    inspect.Parameter.VAR_POSITIONAL,
-                    inspect.Parameter.VAR_KEYWORD,
-                ):
-                    arg = bound_sig.arguments[name]
-                    tp: type[V] = type(arg) if not isinstance(arg, type) else arg
-                    return drop_params(tp)
+        return_anno = self.__signature__.return_annotation
+        if typing.get_origin(return_anno) is typing.Annotated:
+            return_anno = typing.get_args(return_anno)[0]
 
+        if return_anno is inspect.Parameter.empty:
             return typing.cast(type[V], object)
+        elif return_anno is None:
+            return type(None)  # type: ignore
+        elif not freetypevars(return_anno):
+            return return_anno
 
-        return drop_params(anno)
+        type_args = tuple(nested_type(a) for a in args)
+        type_kwargs = {k: nested_type(v) for k, v in kwargs.items()}
+        bound_sig = self.__signature__.bind(*type_args, **type_kwargs)
+        return substitute(return_anno, unify(self.__signature__, bound_sig))  # type: ignore
 
     def __repr__(self):
         return f"_BaseOperation({self._default}, name={self.__name__}, freshening={self._freshening})"
@@ -660,6 +640,9 @@ def func(*args, **kwargs):
 
 
 @defop.register(type)
+@defop.register(typing.cast(type, types.GenericAlias))
+@defop.register(typing.cast(type, typing._GenericAlias))  # type: ignore
+@defop.register(typing.cast(type, types.UnionType))
 def _[T](t: type[T], *, name: str | None = None) -> Operation[[], T]:
     def func() -> t:  # type: ignore
         raise NotImplementedError
@@ -994,9 +977,6 @@ def _(op, *args, **kwargs):
     base_term = __dispatch(typing.cast(type[T], object))(op, *args_, **kwargs_)
 
     tp = typeof(base_term)
-    if tp is typing.Union:
-        raise ValueError("Terms that return Union types are not supported.")
-
     assert isinstance(tp, type)
     typed_term = __dispatch(tp)(op, *args_, **kwargs_)
     return typed_term
diff --git a/tests/test_handlers_numbers.py b/tests/test_handlers_numbers.py
index 1086cb1e..31d117dd 100644
--- a/tests/test_handlers_numbers.py
+++ b/tests/test_handlers_numbers.py
@@ -1,6 +1,8 @@
 import collections
+import collections.abc
 import logging
 import os
+import typing
 
 import pytest
 
@@ -11,6 +13,8 @@
 logger = logging.getLogger(__name__)
 
+T = typing.TypeVar("T")
+
 
 def test_lambda_calculus_1():
     x, y = defop(int), defop(int)
@@ -40,7 +44,11 @@ def test_lambda_calculus_2():
 
 
 def test_lambda_calculus_3():
-    x, y, f = defop(int), defop(int), defop(collections.abc.Callable)
+    x, y, f = (
+        defop(int),
+        defop(int),
+        defop(collections.abc.Callable[[int], collections.abc.Callable[[int], int]]),
+    )
 
     with handler(eager_mixed):
         f2 = Lam(x, Lam(y, (x() + y())))
@@ -51,8 +59,8 @@ def test_lambda_calculus_4():
     x, f, g = (
         defop(int),
-        defop(collections.abc.Callable),
-        defop(collections.abc.Callable),
+        defop(collections.abc.Callable[[T], T]),
+        defop(collections.abc.Callable[[T], T]),
     )
 
     with handler(eager_mixed):
@@ -179,7 +187,7 @@ def f2(x: int, y: int) -> int:
         return x + y
 
     @trace
-    def app2(f: collections.abc.Callable, x: int, y: int) -> int:
+    def app2(f: collections.abc.Callable[[int, int], int], x: int, y: int) -> int:
         return f(x, y)
 
     assert syntactic_eq(app2(f2, 1, 2), 3)
diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py
new file mode 100644
index 00000000..db236829
--- /dev/null
+++ b/tests/test_internals_unification.py
@@ -0,0 +1,1646 @@
+import collections.abc
+import inspect
+import typing
+
+import pytest
+
+from effectful.internals.unification import (
+    canonicalize,
+    freetypevars,
+    nested_type,
+    substitute,
+    unify,
+)
+
+if typing.TYPE_CHECKING:
+    T = typing.Any
+    K = typing.Any
+    V = typing.Any
+    U = typing.Any
+    W = typing.Any
+else:
typing.TypeVar("T") + K = typing.TypeVar("K") + V = typing.TypeVar("V") + U = typing.TypeVar("U") + W = typing.TypeVar("W") + + +@pytest.mark.parametrize( + "typ,fvs", + [ + # Basic cases + (T, {T}), + (int, set()), + (str, set()), + # Single TypeVar in generic + (list[T], {T}), + (set[T], {T}), + (tuple[T], {T}), + # Multiple TypeVars + (dict[K, V], {K, V}), + (tuple[K, V], {K, V}), + (dict[T, T], {T}), # Same TypeVar used twice + # Nested generics with TypeVars + (list[dict[K, V]], {K, V}), + (dict[K, list[V]], {K, V}), + (list[tuple[T, U]], {T, U}), + (tuple[list[T], dict[K, V]], {T, K, V}), + # Concrete types in generics + (list[int], set()), + (dict[str, int], set()), + (tuple[int, str, float], set()), + # Mixed concrete and TypeVars + (dict[str, T], {T}), + (dict[K, int], {K}), + (tuple[T, int, V], {T, V}), + (list[tuple[int, T]], {T}), + # Deeply nested + (list[dict[K, list[tuple[V, T]]]], {K, V, T}), + (dict[tuple[K, V], list[dict[U, T]]], {K, V, U, T}), + # Union types (if supported) + (list[T] | dict[K, V], {T, K, V}), + (T | int, {T}), + # Callable types + (collections.abc.Callable[[T], V], {T, V}), + (collections.abc.Callable[[int, T], T], {T}), + (collections.abc.Callable[[], T], {T}), + (collections.abc.Callable[[T, U], V], {T, U, V}), + (collections.abc.Callable[[int], int], set()), + (collections.abc.Callable[[T], list[T]], {T}), + (collections.abc.Callable[[dict[K, V]], tuple[K, V]], {K, V}), + # Nested Callable + (collections.abc.Callable[[T], collections.abc.Callable[[U], V]], {T, U, V}), + (list[collections.abc.Callable[[T], V]], {T, V}), + (dict[K, collections.abc.Callable[[T], V]], {K, T, V}), + # ParamSpec and TypeVarTuple (if needed later) + # (collections.abc.Callable[typing.ParamSpec("P"), T], {T}), # Would need to handle ParamSpec + ], +) +def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): + assert freetypevars(typ) == fvs + + +def test_canonicalize_1(): + assert canonicalize(int) == int + assert canonicalize(list[int]) == collections.abc.MutableSequence[int] + assert canonicalize(dict[str, int]) == collections.abc.MutableMapping[str, int] + assert ( + canonicalize(dict[str, set[int]]) + == collections.abc.MutableMapping[str, collections.abc.MutableSet[int]] + ) + assert canonicalize(tuple[int, ...]) == collections.abc.Sequence[int] + assert canonicalize(tuple[int, str]) == tuple[int, str] + + class CustomDict[T](dict[T, T]): + pass + + class ConcreteCustomDict(CustomDict[int]): + pass + + class CustomDictSet[T](CustomDict[frozenset[T]]): + pass + + assert canonicalize(CustomDict[int]) == canonicalize(dict[int, int]) + assert canonicalize(ConcreteCustomDict) == canonicalize(CustomDict[int]) + assert canonicalize(CustomDictSet[T]) == canonicalize( + dict[frozenset[T], frozenset[T]] + ) + assert canonicalize(CustomDictSet[int]) == canonicalize( + dict[frozenset[int], frozenset[int]] + ) + + +@pytest.mark.parametrize( + "typ,subs,expected", + [ + # Basic substitution + (T, {T: int}, int), + (T, {T: str}, str), + (T, {T: list[int]}, list[int]), + # TypeVar not in mapping + (T, {K: int}, T), + (T, {}, T), + # Non-TypeVar types + (int, {T: str}, int), + (str, {}, str), + (list[int], {T: str}, list[int]), + # Single TypeVar in generic + (list[T], {T: int}, list[int]), + (set[T], {T: str}, set[str]), + (tuple[T], {T: float}, tuple[float]), + # Multiple TypeVars + (dict[K, V], {K: str, V: int}, dict[str, int]), + (tuple[K, V], {K: int, V: str}, tuple[int, str]), + (dict[K, V], {K: str}, dict[str, V]), # Partial substitution + # Same TypeVar used multiple times + 
+        (dict[T, T], {T: int}, dict[int, int]),
+        (tuple[T, T, T], {T: str}, tuple[str, str, str]),
+        # Nested generics
+        (list[dict[K, V]], {K: str, V: int}, list[dict[str, int]]),
+        (dict[K, list[V]], {K: int, V: str}, dict[int, list[str]]),
+        (list[tuple[T, U]], {T: int, U: str}, list[tuple[int, str]]),
+        # Mixed concrete and TypeVars
+        (dict[str, T], {T: int}, dict[str, int]),
+        (tuple[int, T, str], {T: float}, tuple[int, float, str]),
+        (list[tuple[int, T]], {T: str}, list[tuple[int, str]]),
+        # Deeply nested substitution
+        (list[dict[K, list[V]]], {K: str, V: int}, list[dict[str, list[int]]]),
+        (
+            dict[tuple[K, V], list[T]],
+            {K: int, V: str, T: float},
+            dict[tuple[int, str], list[float]],
+        ),
+        # Substituting with generic types
+        (T, {T: list[int]}, list[int]),
+        (list[T], {T: dict[str, int]}, list[dict[str, int]]),
+        (
+            dict[K, V],
+            {K: list[int], V: dict[str, float]},
+            dict[list[int], dict[str, float]],
+        ),
+        # Empty substitution
+        (list[T], {}, list[T]),
+        (dict[K, V], {}, dict[K, V]),
+        # Union types (if supported)
+        (T | int, {T: str}, str | int),
+        (
+            list[T] | dict[K, V],
+            {T: int, K: str, V: float},
+            list[int] | dict[str, float],
+        ),
+        # Irrelevant substitutions (TypeVars not in type)
+        (list[T], {K: int, V: str}, list[T]),
+        (int, {T: str, K: int}, int),
+        # Callable types
+        (
+            collections.abc.Callable[[T], V],
+            {T: int, V: str},
+            collections.abc.Callable[[int], str],
+        ),
+        (
+            collections.abc.Callable[[int, T], T],
+            {T: str},
+            collections.abc.Callable[[int, str], str],
+        ),
+        (
+            collections.abc.Callable[[], T],
+            {T: float},
+            collections.abc.Callable[[], float],
+        ),
+        (
+            collections.abc.Callable[[T, U], V],
+            {T: int, U: str, V: bool},
+            collections.abc.Callable[[int, str], bool],
+        ),
+        (
+            collections.abc.Callable[[int], int],
+            {T: str},
+            collections.abc.Callable[[int], int],
+        ),
+        (
+            collections.abc.Callable[[T], list[T]],
+            {T: int},
+            collections.abc.Callable[[int], list[int]],
+        ),
+        (
+            collections.abc.Callable[[dict[K, V]], tuple[K, V]],
+            {K: str, V: int},
+            collections.abc.Callable[[dict[str, int]], tuple[str, int]],
+        ),
+        # Nested Callable
+        (
+            collections.abc.Callable[[T], collections.abc.Callable[[U], V]],
+            {T: int, U: str, V: bool},
+            collections.abc.Callable[[int], collections.abc.Callable[[str], bool]],
+        ),
+        (
+            list[collections.abc.Callable[[T], V]],
+            {T: int, V: str},
+            list[collections.abc.Callable[[int], str]],
+        ),
+        (
+            dict[K, collections.abc.Callable[[T], V]],
+            {K: str, T: int, V: float},
+            dict[str, collections.abc.Callable[[int], float]],
+        ),
+        # Partial substitution with Callable
+        (
+            collections.abc.Callable[[T, U], V],
+            {T: int},
+            collections.abc.Callable[[int, U], V],
+        ),
+        (
+            collections.abc.Callable[[T], dict[K, V]],
+            {T: int, K: str},
+            collections.abc.Callable[[int], dict[str, V]],
+        ),
+    ],
+)
+def test_substitute(
+    typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type
+):
+    assert substitute(typ, subs) == expected  # type: ignore
+
+
+@pytest.mark.parametrize(
+    "typ,subtyp,initial_subs,expected_subs",
+    [
+        # Basic TypeVar unification
+        (T, int, {}, {T: int}),
+        (T, str, {}, {T: str}),
+        (T, list[int], {}, {T: list[int]}),
+        # With existing substitutions
+        (V, bool, {T: int}, {T: int, V: bool}),
+        (K, str, {T: int, V: bool}, {T: int, V: bool, K: str}),
+        # Generic type unification
+        (list[T], list[int], {}, {T: int}),
+        (dict[K, V], dict[str, int], {}, {K: str, V: int}),
+        (tuple[T, U], tuple[int, str], {}, {T: int, U: str}),
+        (set[T], set[float], {}, {T: float}),
+        # Same TypeVar used multiple times
+        (dict[T, T], dict[int, int], {}, {T: int}),
+        (tuple[T, T, T], tuple[str, str, str], {}, {T: str}),
+        # Nested generic unification
+        (list[dict[K, V]], list[dict[str, int]], {}, {K: str, V: int}),
+        (dict[K, list[V]], dict[int, list[str]], {}, {K: int, V: str}),
+        (list[tuple[T, U]], list[tuple[bool, float]], {}, {T: bool, U: float}),
+        # Deeply nested
+        (list[dict[K, list[V]]], list[dict[str, list[int]]], {}, {K: str, V: int}),
+        (
+            dict[tuple[K, V], list[T]],
+            dict[tuple[int, str], list[bool]],
+            {},
+            {K: int, V: str, T: bool},
+        ),
+        # Mixed concrete and TypeVars
+        (dict[str, T], dict[str, int], {}, {T: int}),
+        (tuple[int, T, str], tuple[int, float, str], {}, {T: float}),
+        (list[tuple[int, T]], list[tuple[int, str]], {}, {T: str}),
+        # Exact type matching (no TypeVars)
+        (int, int, {}, {}),
+        (str, str, {}, {}),
+        (list[int], list[int], {}, {}),
+        (dict[str, int], dict[str, int], {}, {}),
+        # Callable type unification
+        (
+            collections.abc.Callable[[T], V],
+            collections.abc.Callable[[int], str],
+            {},
+            {T: int, V: str},
+        ),
+        (
+            collections.abc.Callable[[T, U], V],
+            collections.abc.Callable[[int, str], bool],
+            {},
+            {T: int, U: str, V: bool},
+        ),
+        (
+            collections.abc.Callable[[], T],
+            collections.abc.Callable[[], float],
+            {},
+            {T: float},
+        ),
+        (
+            collections.abc.Callable[[T], list[T]],
+            collections.abc.Callable[[int], list[int]],
+            {},
+            {T: int},
+        ),
+        # Nested Callable
+        (
+            collections.abc.Callable[[T], collections.abc.Callable[[U], V]],
+            collections.abc.Callable[[int], collections.abc.Callable[[str], bool]],
+            {},
+            {T: int, U: str, V: bool},
+        ),
+        # Complex combinations
+        (
+            dict[K, collections.abc.Callable[[T], V]],
+            dict[str, collections.abc.Callable[[int], bool]],
+            {},
+            {K: str, T: int, V: bool},
+        ),
+    ],
+)
+def test_unify_success(
+    typ: type,
+    subtyp: type,
+    initial_subs: typing.Mapping,
+    expected_subs: typing.Mapping,
+):
+    assert unify(typ, subtyp, initial_subs) == {
+        k: canonicalize(v) for k, v in expected_subs.items()
+    }
+
+
+@pytest.mark.parametrize(
+    "typ,subtyp",
+    [
+        # Incompatible types
+        (list[T], dict[str, int]),
+        (int, str),
+        (list[int], list[str]),
+        # Mismatched generic types
+        (list[T], set[int]),
+        (dict[K, V], list[int]),
+        # Same TypeVar with different values
+        (dict[T, T], dict[int, str]),
+        (tuple[T, T], tuple[int, str]),
+        # Mismatched arities
+        (tuple[T, U], tuple[int, str, bool]),
+        (
+            collections.abc.Callable[[T], V],
+            collections.abc.Callable[[int, str], bool],
+        ),
+        # Sequence length mismatch
+        ((T, V), (int,)),
+        ([T, V], [int, str, bool]),
+    ],
+)
+def test_unify_failure(
+    typ: type,
+    subtyp: type,
+):
+    with pytest.raises(TypeError):
+        unify(typ, subtyp, {})
+
+
+def test_unify_union_1():
+    assert unify(int | str, int | str) == {}
+    assert unify(int | str, str) == {}
+    assert unify(int | str, int) == {}
+
+    assert unify(T, int | str) == {T: int | str}
+
+
+def test_unify_tuple_variadic():
+    assert unify(tuple[T, ...], tuple[int, ...]) == {T: int}
+    assert unify(tuple[T, ...], tuple[int]) == {T: int}
+    assert unify(tuple[T, ...], tuple[int, int]) == {T: int}
+    assert unify(collections.abc.Sequence[T], tuple[int, ...]) == {T: int}
+
+
+def test_unify_tuple_non_variadic():
+    assert unify(tuple[T], tuple[int | str]) == {T: int | str}
+    assert unify(tuple[T, V], tuple[int, str]) == {T: int, V: str}
+    assert unify(tuple[T, T], tuple[int, int]) == {T: int}
+    assert unify(tuple[T, T, T], tuple[str, str, str]) == {T: str}
+    assert unify(collections.abc.Sequence[T], tuple[int, int]) == {T: int}
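+
+
+def test_unify_existing_binding():
+    # Illustrative sketch mirroring the documented behavior of unify: an
+    # existing binding is reused, so re-unifying a bound variable with the
+    # same type is a no-op, while a conflicting type raises TypeError.
+    assert unify(T, int, {T: int}) == {T: int}
+    with pytest.raises(TypeError):
+        unify(T, str, {T: int})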
+
+
+def test_unify_both_abstract():
+    assert unify(tuple[T, ...], tuple[V, V]) == {T: V}
+    assert unify(tuple[T, V], tuple[int, U]) == {T: int, V: U}
+    assert unify(list[T], list[tuple[V, V]]) == {T: tuple[V, V]}
+    assert unify(
+        collections.abc.Callable[[T], T],
+        collections.abc.Callable[[tuple[int, V]], tuple[int, V]],
+    ) == {T: tuple[int, V]}
+    assert unify(
+        (tuple[T, int], tuple[int, int]),
+        (tuple[int, V], tuple[U, V]),
+    ) == {T: int, U: int, V: int}
+    assert unify(
+        (list[T], T),
+        (list[list[V]], list[V]),
+    ) == {T: canonicalize(list[V])}
+
+
+# Test functions with various type patterns
+def identity[T](x: T) -> T:
+    return x
+
+
+def make_pair[T, V](x: T, y: V) -> tuple[T, V]:
+    return (x, y)
+
+
+def wrap_in_list[T](x: T) -> list[T]:
+    return [x]
+
+
+def get_first[T](items: list[T]) -> T:
+    return items[0]
+
+
+def getitem_mapping[K, V](mapping: collections.abc.Mapping[K, V], key: K) -> V:
+    return mapping[key]
+
+
+def dict_values[K, V](d: dict[K, V]) -> list[V]:
+    return list(d.values())
+
+
+def process_callable[T, V](func: collections.abc.Callable[[T], V], arg: T) -> V:
+    return func(arg)
+
+
+def chain_callables[T, U, V](
+    f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U], V]
+) -> collections.abc.Callable[[T], V]:
+    def result(x: T) -> V:
+        return g(f(x))
+
+    return result
+
+
+def constant_func() -> int:
+    return 42
+
+
+def multi_generic[T, K, V](a: T, b: list[T], c: dict[K, V]) -> tuple[T, K, V]:
+    return (a, next(iter(c.keys())), next(iter(c.values())))
+
+
+def same_type_twice[T](x: T, y: T) -> T:
+    return x if len(str(x)) > len(str(y)) else y
+
+
+def nested_generic[T](x: T) -> dict[str, list[T]]:
+    return {"items": [x]}
+
+
+def variadic_args_func[T](*args: T) -> T:  # Variadic positional args
+    return args[0]
+
+
+def variadic_kwargs_func[T](**kwargs: T) -> T:  # Variadic keyword args
+    return next(iter(kwargs.values()))
+
+
+@pytest.mark.parametrize(
+    "func,args,kwargs,expected_return_type",
+    [
+        # Simple generic functions
+        (identity, (int,), {}, int),
+        (identity, (str,), {}, str),
+        (identity, (list[int],), {}, list[int]),
+        # Multiple TypeVars
+        (make_pair, (int, str), {}, tuple[int, str]),
+        (make_pair, (bool, list[float]), {}, tuple[bool, list[float]]),
+        # Generic collections
+        (wrap_in_list, (int,), {}, list[int]),
+        (wrap_in_list, (dict[str, bool],), {}, list[dict[str, bool]]),
+        (get_first, (list[str],), {}, str),
+        (get_first, (list[tuple[int, float]],), {}, tuple[int, float]),
+        (getitem_mapping, (collections.abc.Mapping[str, int], str), {}, int),
+        (
+            getitem_mapping,
+            (collections.abc.Mapping[bool, list[str]], bool),
+            {},
+            list[str],
+        ),
+        # Dict operations
+        (dict_values, (dict[str, int],), {}, list[int]),
+        (dict_values, (dict[bool, list[str]],), {}, list[list[str]]),
+        # Callable types
+        (process_callable, (collections.abc.Callable[[int], str], int), {}, str),
+        (
+            process_callable,
+            (collections.abc.Callable[[list[int]], bool], list[int]),
+            {},
+            bool,
+        ),
+        # Complex callable return
+        (
+            chain_callables,
+            (
+                collections.abc.Callable[[int], str],
+                collections.abc.Callable[[str], bool],
+            ),
+            {},
+            collections.abc.Callable[[int], bool],
+        ),
+        # No generics
+        (constant_func, (), {}, int),
+        # Mixed generics
+        (multi_generic, (int, list[int], dict[str, bool]), {}, tuple[int, str, bool]),
+        (
+            multi_generic,
+            (float, list[float], dict[bool, list[str]]),
+            {},
+            tuple[float, bool, list[str]],
+        ),
+        # Same TypeVar used multiple times
+        (same_type_twice, (int, int), {}, int),
+        (same_type_twice, (str, str), {}, str),
+        # Nested generics
+        (nested_generic, (int,), {}, dict[str, list[int]]),
+        (
+            nested_generic,
+            (collections.abc.Callable[[str], bool],),
+            {},
+            dict[str, list[collections.abc.Callable[[str], bool]]],
+        ),
+        # Keyword arguments
+        (make_pair, (), {"x": int, "y": str}, tuple[int, str]),
+        (
+            multi_generic,
+            (),
+            {"a": bool, "b": list[bool], "c": dict[int, str]},
+            tuple[bool, int, str],
+        ),
+        # variadic args and kwargs
+        (variadic_args_func, (int,), {}, int),
+        (variadic_args_func, (int, int), {}, int),
+        (variadic_kwargs_func, (), {"x": int}, int),
+        (variadic_kwargs_func, (), {"x": int, "y": int}, int),
+    ],
+)
+def test_infer_return_type_success(
+    func: collections.abc.Callable,
+    args: tuple,
+    kwargs: dict,
+    expected_return_type: type,
+):
+    sig = inspect.signature(func)
+    bound = sig.bind(*args, **kwargs)
+    result = substitute(sig.return_annotation, unify(sig, bound))
+    assert canonicalize(result) == canonicalize(expected_return_type)
+
+
+# Error cases
+def unbound_typevar_func[T](x: T) -> tuple[T, V]:  # V not in parameters
+    return (x, "error")
+
+
+def no_return_annotation[T](x: T):  # No return annotation
+    return x
+
+
+def no_param_annotation[T](x) -> T:  # type: ignore
+    return x
+
+
+@pytest.mark.parametrize(
+    "func,args,kwargs",
+    [
+        # Type mismatch - trying to unify incompatible types
+        (same_type_twice, (int, str), {}),
+    ],
+)
+def test_infer_return_type_failure(
+    func: collections.abc.Callable,
+    args: tuple,
+    kwargs: dict,
+):
+    sig = inspect.signature(func)
+    bound = sig.bind(*args, **kwargs)
+    with pytest.raises(TypeError):
+        unify(sig, bound)
+
+
+@pytest.mark.parametrize(
+    "value,expected",
+    [
+        # Basic value types
+        (42, int),
+        (0, int),
+        (-5, int),
+        ("hello", str),
+        ("", str),
+        (3.14, float),
+        (0.0, float),
+        (True, bool),
+        (False, bool),
+        (None, type(None)),
+        (b"bytes", bytes),
+        (b"", bytes),
+        # Type objects pass through
+        (int, int),
+        (str, str),
+        (float, float),
+        (bool, bool),
+        (list, list),
+        (dict, dict),
+        (set, set),
+        (tuple, tuple),
+        (type(None), type(None)),
+        (type(...), type(...)),
+        # Generic aliases pass through
+        (list[int], list[int]),
+        (dict[str, int], dict[str, int]),
+        (set[bool], set[bool]),
+        (tuple[int, str], tuple[int, str]),
+        (int | str, int | str),
+        (list[T], list[T]),
+        (dict[K, V], dict[K, V]),
+        # Union types pass through
+        (int | str, int | str),
+        # Empty collections
+        ([], list),
+        ({}, dict),
+        (set(), set),
+        ((), tuple),
+        # Lists/sequences with single type
+        ([1, 2, 3], list[int]),
+        ([1], list[int]),
+        (["a", "b", "c"], list[str]),
+        ([True, False], list[bool]),
+        ([1.1, 2.2], list[float]),
+        # Sets with elements
+        ({1, 2, 3}, set[int]),
+        ({1}, set[int]),
+        ({"a", "b"}, set[str]),
+        ({True, False}, set[bool]),
+        # Dicts/mappings
+        ({"key": "value"}, dict[str, str]),
+        ({1: "one", 2: "two"}, dict[int, str]),
+        ({"a": 1, "b": 2}, dict[str, int]),
+        ({True: 1.0, False: 2.0}, dict[bool, float]),
+        # Tuples preserve exact structure
+        ((1, "hello", 3.14), tuple[int, str, float]),
+        ((1,), tuple[int]),
+        ((1, 2), tuple[int, int]),
+        (("a", "b", "c"), tuple[str, str, str]),
+        ((True, 1, "x", 3.14), tuple[bool, int, str, float]),
+        # Nested collections
+        ([[1, 2], [3, 4]], list[list[int]]),
+        ([{1, 2}, {3, 4}], list[set[int]]),
+        ([{"a": 1}, {"b": 2}], list[dict[str, int]]),
+        ({"key": [1, 2, 3]}, dict[str, list[int]]),
+        ({"a": {1, 2}, "b": {3, 4}}, dict[str, set[int]]),
+        ({1: {"x": True}, 2: {"y": False}}, dict[int, dict[str, bool]]),
+        # Tuples in collections
"b")], list[tuple[int, str]]), + ({(1, 2), (3, 4)}, set[tuple[int, int]]), + ({1: (True, "x"), 2: (False, "y")}, dict[int, tuple[bool, str]]), + # Functions/callables + (lambda x: x, type(lambda x: x)), + (print, type(print)), + (len, type(len)), + # Complex nested structures + ([[[1]]], list[list[list[int]]]), + ({"a": {"b": {"c": 1}}}, dict[str, dict[str, dict[str, int]]]), + # Special string/bytes handling (NOT treated as sequences) + ("hello", str), + (b"world", bytes), + # Other built-in types + (range(5), type(range(5))), + (slice(1, 10), type(slice(1, 10))), + ], +) +def test_nested_type(value, expected): + result = nested_type(value) + assert canonicalize(result) == canonicalize(expected) + + +def test_nested_type_term_error(): + """Test that Terms raise TypeError in nested_type""" + # We can't import Term here without creating a circular dependency, + # so we'll create a mock object that would trigger the isinstance check + from unittest.mock import Mock + + from effectful.ops.types import Term + + mock_term = Mock(spec=Term) + with pytest.raises(TypeError, match="Terms should not appear in nested_type"): + nested_type(mock_term) + + +def sequence_getitem[T](seq: collections.abc.Sequence[T], index: int) -> T: + return seq[index] + + +def mapping_getitem[K, V](mapping: collections.abc.Mapping[K, V], key: K) -> V: + return mapping[key] + + +def sequence_mapping_getitem[K, V]( + seq: collections.abc.Sequence[collections.abc.Mapping[K, V]], index: int, key: K +) -> V: + return mapping_getitem(sequence_getitem(seq, index), key) + + +def mapping_sequence_getitem[K, T]( + mapping: collections.abc.Mapping[K, collections.abc.Sequence[T]], key: K, index: int +) -> T: + return sequence_getitem(mapping_getitem(mapping, key), index) + + +def sequence_from_pair[T](a: T, b: T) -> collections.abc.Sequence[T]: + return [a, b] + + +def mapping_from_pair[K, V](a: K, b: V) -> collections.abc.Mapping[K, V]: + return {a: b} + + +def sequence_of_mappings[K, V]( + key1: K, val1: V, key2: K, val2: V +) -> collections.abc.Sequence[collections.abc.Mapping[K, V]]: + """Creates a sequence containing two mappings.""" + return sequence_from_pair( + mapping_from_pair(key1, val1), mapping_from_pair(key2, val2) + ) + + +def mapping_of_sequences[K, T]( + key1: K, val1: T, val2: T, key2: K, val3: T, val4: T +) -> collections.abc.Mapping[K, collections.abc.Sequence[T]]: + """Creates a mapping where each key maps to a sequence of two values.""" + return mapping_from_pair(key1, sequence_from_pair(val1, val2)) + + +def nested_sequence_mapping[K, T]( + k1: K, v1: T, v2: T, k2: K, v3: T, v4: T +) -> collections.abc.Sequence[collections.abc.Mapping[K, collections.abc.Sequence[T]]]: + """Creates a sequence of mappings, where each mapping contains sequences.""" + return sequence_from_pair( + mapping_from_pair(k1, sequence_from_pair(v1, v2)), + mapping_from_pair(k2, sequence_from_pair(v3, v4)), + ) + + +def get_from_constructed_sequence[T](a: T, b: T, index: int) -> T: + """Constructs a sequence from two elements and gets one by index.""" + return sequence_getitem(sequence_from_pair(a, b), index) + + +def get_from_constructed_mapping[K, V](key: K, value: V, lookup_key: K) -> V: + """Constructs a mapping from a key-value pair and looks up the value.""" + return mapping_getitem(mapping_from_pair(key, value), lookup_key) + + +def double_nested_get[K, T]( + k1: K, + v1: T, + v2: T, + k2: K, + v3: T, + v4: T, + outer_index: int, + inner_key: K, + inner_index: int, +) -> T: + """Creates nested structure and retrieves deeply nested 
+def double_nested_get[K, T]( + k1: K, + v1: T, + v2: T, + k2: K, + v3: T, + v4: T, + outer_index: int, + inner_key: K, + inner_index: int, +) -> T: + """Creates a nested structure and retrieves a deeply nested value.""" + nested = nested_sequence_mapping(k1, v1, v2, k2, v3, v4) + mapping = sequence_getitem(nested, outer_index) + sequence = mapping_getitem(mapping, inner_key) + return sequence_getitem(sequence, inner_index) + + +def construct_and_extend_sequence[T]( + a: T, b: T, c: T, d: T +) -> collections.abc.Sequence[collections.abc.Sequence[T]]: + """Constructs two sequences and combines them into a sequence of sequences.""" + seq1 = sequence_from_pair(a, b) + seq2 = sequence_from_pair(c, d) + return sequence_from_pair(seq1, seq2) + + +def transform_mapping_values[K, T]( + key1: K, val1: T, key2: K, val2: T +) -> collections.abc.Mapping[K, collections.abc.Sequence[T]]: + """Creates a single-entry mapping whose value is wrapped in a sequence.""" + # There is no sequence_from_single helper, so wrap the value in a + # two-element sequence by repeating it; key2 and val2 are unused and + # only constrain the inferred types. + return mapping_from_pair(key1, sequence_from_pair(val1, val1)) + + +def call_func[T, V]( + func: collections.abc.Callable[[T], V], + arg: T, +) -> V: + """Calls a function with a single argument.""" + return func(arg) + + +def call_binary_func[T, U, V]( + func: collections.abc.Callable[[T, U], V], + arg1: T, + arg2: U, +) -> V: + """Calls a binary function with two arguments.""" + return func(arg1, arg2) + + +def map_sequence[T, U]( + f: collections.abc.Callable[[T], U], + seq: collections.abc.Sequence[T], +) -> collections.abc.Sequence[U]: + """Applies a function to each element in a sequence.""" + return [call_func(f, x) for x in seq] + + +def compose_mappings[T, U, V]( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U], V], +) -> collections.abc.Callable[[T], V]: + """Composes two unary functions, applying f first and then g.""" + + def composed(x: T) -> V: + return call_func(g, call_func(f, x)) + + return composed + + +def compose_binary[T, U, V]( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U, U], V], +) -> collections.abc.Callable[[T], V]: + """Composes a unary function with a binary function.""" + + def composed(x: T) -> V: + return call_binary_func(g, call_func(f, x), call_func(f, x)) + + return composed + + +def apply_to_sequence_element[T, U]( + f: collections.abc.Callable[[T], U], + seq: collections.abc.Sequence[T], + index: int, +) -> U: + """Gets an element from a sequence and applies a function to it.""" + element = sequence_getitem(seq, index) + return call_func(f, element) + + +def map_and_get[T, U]( + f: collections.abc.Callable[[T], U], + seq: collections.abc.Sequence[T], + index: int, +) -> U: + """Maps a function over a sequence and gets the element at index.""" + mapped_seq = map_sequence(f, seq) + return sequence_getitem(mapped_seq, index) + + +def compose_and_apply[T, U, V]( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U], V], + value: T, +) -> V: + """Composes two functions and applies the result to a value.""" + composed = compose_mappings(f, g) + return call_func(composed, value) + + +def double_compose_apply[T, U, V, W]( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U], V], + h: collections.abc.Callable[[V], W], + value: T, +) -> W: + """Composes three functions and applies the result to a value.""" + fg = compose_mappings(f, g) + fgh = compose_mappings(fg, h) + return call_func(fgh, value) + + +def binary_on_sequence_elements[T, U]( + f: collections.abc.Callable[[T, T], U], + seq: collections.abc.Sequence[T], + index1: int, + index2: 
int, +) -> U: + """Gets two elements from a sequence and applies a binary function.""" + elem1 = sequence_getitem(seq, index1) + elem2 = sequence_getitem(seq, index2) + return call_binary_func(f, elem1, elem2) + + +def map_sequence_and_apply_binary[T, U, V]( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U, U], V], + seq: collections.abc.Sequence[T], + index1: int, + index2: int, +) -> V: + """Maps a function over sequence, then applies binary function to two elements.""" + mapped = map_sequence(f, seq) + elem1 = sequence_getitem(mapped, index1) + elem2 = sequence_getitem(mapped, index2) + return call_binary_func(g, elem1, elem2) + + +def construct_apply_and_get[T, U]( + f: collections.abc.Callable[[T], U], + a: T, + b: T, + index: int, +) -> U: + """Constructs a sequence, applies function to elements, and gets one.""" + seq = sequence_from_pair(a, b) + return apply_to_sequence_element(f, seq, index) + + +def sequence_function_composition[T]( + funcs: collections.abc.Sequence[collections.abc.Callable[[T], T]], + value: T, +) -> T: + """Applies a sequence of functions in order to a value.""" + result = value + for func in funcs: + result = call_func(func, result) + return result + + +def map_with_constructed_function[T, U, V]( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U], V], + seq: collections.abc.Sequence[T], +) -> collections.abc.Sequence[V]: + """Composes two functions and maps the result over a sequence.""" + composed = compose_mappings(f, g) + return map_sequence(composed, seq) + + +def cross_apply_binary[T, U, V]( + f: collections.abc.Callable[[T, U], V], + seq1: collections.abc.Sequence[T], + seq2: collections.abc.Sequence[U], + index1: int, + index2: int, +) -> V: + """Gets elements from two sequences and applies a binary function.""" + elem1 = sequence_getitem(seq1, index1) + elem2 = sequence_getitem(seq2, index2) + return call_binary_func(f, elem1, elem2) + + +def nested_function_application[T, U, V]( + outer_f: collections.abc.Callable[[T], collections.abc.Callable[[U], V]], + inner_arg: U, + outer_arg: T, +) -> V: + """Applies a function that returns a function, then applies the result.""" + inner_f = call_func(outer_f, outer_arg) + return call_func(inner_f, inner_arg) + + +@pytest.mark.parametrize( + "seq,index,key", + [ + # Original test case: list of dicts with string keys and int values + ([{"a": 1}, {"b": 2}, {"c": 3}], 1, "b"), + # Different value types + ([{"x": "hello"}, {"y": "world"}, {"z": "test"}], 2, "z"), + ([{"name": 3.14}, {"value": 2.71}, {"constant": 1.41}], 0, "name"), + ([{"flag": True}, {"enabled": False}, {"active": True}], 1, "enabled"), + # Mixed value types in same dict (should still work) + ([{"a": [1, 2, 3]}, {"b": [4, 5, 6]}, {"c": [7, 8, 9]}], 0, "a"), + ([{"data": {"nested": "value"}}, {"info": {"deep": "data"}}], 1, "info"), + # Different key types + ([{1: "one"}, {2: "two"}, {3: "three"}], 2, 3), + ([{True: "yes"}, {False: "no"}], 0, True), + # Nested collections as values + ([{"items": [1, 2, 3]}, {"values": [4, 5, 6]}], 0, "items"), + ([{"matrix": [[1, 2], [3, 4]]}, {"grid": [[5, 6], [7, 8]]}], 1, "grid"), + ([{"sets": {1, 2, 3}}, {"groups": {4, 5, 6}}], 0, "sets"), + # Complex nested structures + ( + [ + {"users": [{"id": 1, "name": "Alice"}]}, + {"users": [{"id": 2, "name": "Bob"}]}, + ], + 1, + "users", + ), + ( + [ + {"config": {"db": {"host": "localhost", "port": 5432}}}, + {"config": {"cache": {"ttl": 300}}}, + ], + 0, + "config", + ), + # Edge cases with single element sequences + 
([{"only": "one"}], 0, "only"), + # Tuples as values + ([{"point": (1, 2)}, {"coord": (3, 4)}, {"pos": (5, 6)}], 2, "pos"), + ([{"rgb": (255, 0, 0)}, {"hsv": (0, 100, 100)}], 0, "rgb"), + ], +) +def test_infer_composition_1(seq, index, key): + sig1 = inspect.signature(sequence_getitem) + sig2 = inspect.signature(mapping_getitem) + + sig12 = inspect.signature(sequence_mapping_getitem) + + inferred_type1 = substitute( + sig1.return_annotation, + unify(sig1, sig1.bind(nested_type(seq), nested_type(index))), + ) + + inferred_type2 = substitute( + sig2.return_annotation, + unify(sig2, sig2.bind(nested_type(inferred_type1), nested_type(key))), + ) + + inferred_type12 = substitute( + sig12.return_annotation, + unify( + sig12, + sig12.bind(nested_type(seq), nested_type(index), nested_type(key)), + ), + ) + + # check that the composed inference matches the direct inference + assert isinstance(unify(inferred_type2, inferred_type12), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(sequence_mapping_getitem(seq, index, key)), inferred_type12), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "mapping,key,index", + [ + # Dict of lists with string keys + ( + { + "fruits": ["apple", "banana", "cherry"], + "colors": ["red", "green", "blue"], + }, + "fruits", + 1, + ), + ({"numbers": [1, 2, 3, 4, 5], "primes": [2, 3, 5, 7, 11]}, "primes", 3), + # Different value types in sequences + ({"floats": [1.1, 2.2, 3.3], "constants": [3.14, 2.71, 1.41]}, "constants", 0), + ( + {"flags": [True, False, True, False], "states": [False, True, False]}, + "flags", + 2, + ), + # Nested structures + ( + {"matrix": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "identity": [[1, 0], [0, 1]]}, + "matrix", + 1, + ), + ( + {"teams": [{"name": "A", "score": 10}, {"name": "B", "score": 20}]}, + "teams", + 0, + ), + # Different key types + ( + { + 1: ["one", "uno", "un"], + 2: ["two", "dos", "deux"], + 3: ["three", "tres", "trois"], + }, + 2, + 1, + ), + ({True: ["yes", "true", "1"], False: ["no", "false", "0"]}, False, 2), + # Lists of different collection types + ( + {"data": [{"a": 1}, {"b": 2}, {"c": 3}], "info": [{"x": 10}, {"y": 20}]}, + "data", + 2, + ), + # Edge cases + ({"single": ["only"]}, "single", 0), + # Complex nested case + ( + { + "users": [ + {"id": 1, "tags": ["admin", "user"]}, + {"id": 2, "tags": ["user", "guest"]}, + {"id": 3, "tags": ["guest"]}, + ] + }, + "users", + 1, + ), + # More diverse cases + ( + {"names": ["Alice", "Bob", "Charlie", "David"], "ages": [25, 30, 35, 40]}, + "names", + 3, + ), + ( + {"options": [[1, 2], [3, 4], [5, 6]], "choices": [[7], [8], [9]]}, + "options", + 2, + ), + # Deeply nested lists + ( + {"deep": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], "shallow": [[9, 10]]}, + "deep", + 0, + ), + ], +) +def test_infer_composition_2(mapping, key, index): + sig1 = inspect.signature(mapping_getitem) + sig2 = inspect.signature(sequence_getitem) + + sig12 = inspect.signature(mapping_sequence_getitem) + + # First infer type of mapping_getitem(mapping, key) -> should be a sequence + inferred_type1 = substitute( + sig1.return_annotation, + unify(sig1, sig1.bind(nested_type(mapping), nested_type(key))), + ) + + # Then infer type of sequence_getitem(result_from_step1, index) -> should be element type + inferred_type2 = substitute( + sig2.return_annotation, + unify(sig2, sig2.bind(nested_type(inferred_type1), nested_type(index))), + ) + + # Directly infer type of 
mapping_sequence_getitem(mapping, key, index) + inferred_type12 = substitute( + sig12.return_annotation, + unify( + sig12, + sig12.bind(nested_type(mapping), nested_type(key), nested_type(index)), + ), + ) + + # The composed inference should match the direct inference + assert isinstance(unify(inferred_type2, inferred_type12), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify( + nested_type(mapping_sequence_getitem(mapping, key, index)), inferred_type12 + ), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "a,b,index", + [ + # Basic types + (1, 2, 0), + (1, 2, 1), + ("hello", "world", 0), + (3.14, 2.71, 1), + (True, False, 0), + # Complex types + ([1, 2], [3, 4], 1), + ({"a": 1}, {"b": 2}, 0), + ({1, 2}, {3, 4}, 1), + # Mixed but same types + ([1, 2, 3], [4, 5], 0), + ({"x": "a", "y": "b"}, {"z": "c"}, 1), + ], +) +def test_get_from_constructed_sequence(a, b, index): + """Test type inference through sequence construction and retrieval.""" + sig_construct = inspect.signature(sequence_from_pair) + sig_getitem = inspect.signature(sequence_getitem) + sig_composed = inspect.signature(get_from_constructed_sequence) + + # Infer type of sequence_from_pair(a, b) -> Sequence[T] + construct_subs = unify( + sig_construct, sig_construct.bind(nested_type(a), nested_type(b)) + ) + inferred_sequence_type = substitute(sig_construct.return_annotation, construct_subs) + + # Infer type of sequence_getitem(sequence, index) -> T + getitem_subs = unify( + sig_getitem, sig_getitem.bind(inferred_sequence_type, nested_type(index)) + ) + inferred_element_type = substitute(sig_getitem.return_annotation, getitem_subs) + + # Directly infer type of get_from_constructed_sequence(a, b, index) + direct_subs = unify( + sig_composed, + sig_composed.bind(nested_type(a), nested_type(b), nested_type(index)), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance( + unify(inferred_element_type, direct_type), collections.abc.Mapping + ) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(get_from_constructed_sequence(a, b, index)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "key,value,lookup_key", + [ + # Basic types + ("name", "Alice", "name"), + (1, "one", 1), + (True, "yes", True), + (3.14, "pi", 3.14), + # Complex value types + ("data", [1, 2, 3], "data"), + ("config", {"host": "localhost", "port": 8080}, "config"), + ("items", {1, 2, 3}, "items"), + # Different key types + (42, {"value": "answer"}, 42), + ("key", (1, 2, 3), "key"), + ], +) +def test_get_from_constructed_mapping(key, value, lookup_key): + """Test type inference through mapping construction and retrieval.""" + sig_construct = inspect.signature(mapping_from_pair) + sig_getitem = inspect.signature(mapping_getitem) + sig_composed = inspect.signature(get_from_constructed_mapping) + + # Infer type of mapping_from_pair(key, value) -> Mapping[K, V] + construct_subs = unify( + sig_construct, sig_construct.bind(nested_type(key), nested_type(value)) + ) + inferred_mapping_type = substitute(sig_construct.return_annotation, construct_subs) + + # Infer type of mapping_getitem(mapping, lookup_key) -> V + getitem_subs = unify( + sig_getitem, sig_getitem.bind(inferred_mapping_type, nested_type(lookup_key)) + ) + 
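# For example, for ("name", "Alice", "name"), getitem_subs is expected to + # bind V to str (an illustration; only unifiability is asserted below). + 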
inferred_value_type = substitute(sig_getitem.return_annotation, getitem_subs) + + # Directly infer type of get_from_constructed_mapping(key, value, lookup_key) + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(key), nested_type(value), nested_type(lookup_key) + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(inferred_value_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify( + nested_type(get_from_constructed_mapping(key, value, lookup_key)), + direct_type, + ), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "key1,val1,key2,val2,index", + [ + # Basic case + ("a", 1, "b", 2, 0), + ("x", "hello", "y", "world", 1), + # Different types + (1, "one", 2, "two", 0), + (True, 1.0, False, 0.0, 1), + # Complex values + ("list1", [1, 2], "list2", [3, 4], 0), + ("dict1", {"a": 1}, "dict2", {"b": 2}, 1), + ], +) +def test_sequence_of_mappings(key1, val1, key2, val2, index): + """Test type inference for creating a sequence of mappings.""" + sig_map = inspect.signature(mapping_from_pair) + sig_seq = inspect.signature(sequence_from_pair) + sig_composed = inspect.signature(sequence_of_mappings) + + # Step 1: Infer types of the two mappings + map1_subs = unify(sig_map, sig_map.bind(nested_type(key1), nested_type(val1))) + map1_type = substitute(sig_map.return_annotation, map1_subs) + + # Step 2: Infer type of sequence containing these mappings + # We need to unify the two mapping types first + unified_map_type = map1_type # Assuming they're compatible + + seq_subs = unify(sig_seq, sig_seq.bind(unified_map_type, unified_map_type)) + seq_type = substitute(sig_seq.return_annotation, seq_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(key1), nested_type(val1), nested_type(key2), nested_type(val2) + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The types should match + assert isinstance(unify(seq_type, direct_type), collections.abc.Mapping) + + # Note: nested_type(sequence_of_mappings(...)) returns concrete types (list[dict[K,V]]) + # while our function signature uses abstract types (Sequence[Mapping[K,V]]) + # This is expected behavior - concrete implementations vs abstract interfaces + + +@pytest.mark.parametrize( + "k1,v1,v2,k2,v3,v4,outer_idx,inner_key,inner_idx", + [ + # Basic test case + ("first", 1, 2, "second", 3, 4, 0, "first", 1), + ("a", "x", "y", "b", "z", "w", 1, "b", 0), + # Different types + (1, 10.0, 20.0, 2, 30.0, 40.0, 0, 1, 1), + ("data", [1], [2], "info", [3], [4], 1, "info", 0), + ], +) +def test_double_nested_get(k1, v1, v2, k2, v3, v4, outer_idx, inner_key, inner_idx): + """Test type inference through deeply nested structure construction and retrieval.""" + # Get signatures for all functions involved + sig_nested = inspect.signature(nested_sequence_mapping) + sig_seq_get = inspect.signature(sequence_getitem) + sig_map_get = inspect.signature(mapping_getitem) + sig_composed = inspect.signature(double_nested_get) + + # Step 1: Infer type of nested_sequence_mapping construction + nested_subs = unify( + sig_nested, + sig_nested.bind( + nested_type(k1), + nested_type(v1), + nested_type(v2), + nested_type(k2), + nested_type(v3), + nested_type(v4), + ), + ) + nested_seq_type = substitute(sig_nested.return_annotation, 
nested_subs) + # This should be Sequence[Mapping[K, Sequence[T]]] + + # Step 2: Get element from outer sequence + outer_get_subs = unify( + sig_seq_get, sig_seq_get.bind(nested_seq_type, nested_type(outer_idx)) + ) + mapping_type = substitute(sig_seq_get.return_annotation, outer_get_subs) + # This should be Mapping[K, Sequence[T]] + + # Step 3: Get sequence from mapping + inner_map_subs = unify( + sig_map_get, sig_map_get.bind(mapping_type, nested_type(inner_key)) + ) + sequence_type = substitute(sig_map_get.return_annotation, inner_map_subs) + # This should be Sequence[T] + + # Step 4: Get element from inner sequence + final_get_subs = unify( + sig_seq_get, sig_seq_get.bind(sequence_type, nested_type(inner_idx)) + ) + composed_type = substitute(sig_seq_get.return_annotation, final_get_subs) + # This should be T + + # Direct inference on the composed function + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(k1), + nested_type(v1), + nested_type(v2), + nested_type(k2), + nested_type(v3), + nested_type(v4), + nested_type(outer_idx), + nested_type(inner_key), + nested_type(inner_idx), + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify( + nested_type( + double_nested_get( + k1, v1, v2, k2, v3, v4, outer_idx, inner_key, inner_idx + ) + ), + direct_type, + ), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,seq,index", + [ + # Basic function applications + (lambda x: x * 2, [1, 2, 3], 0), + (lambda x: x * 2, [1, 2, 3], 2), + (lambda x: x.upper(), ["hello", "world"], 1), + (lambda x: len(x), ["a", "bb", "ccc"], 2), + (lambda x: x + 1.0, [1.0, 2.0, 3.0], 1), + ], +) +def test_apply_to_sequence_element(f, seq, index): + """Test type inference through sequence access and function application.""" + sig_getitem = inspect.signature(sequence_getitem) + sig_call = inspect.signature(call_func) + sig_composed = inspect.signature(apply_to_sequence_element) + + # Step 1: Infer type of sequence_getitem(seq, index) -> T + getitem_subs = unify( + sig_getitem, sig_getitem.bind(nested_type(seq), nested_type(index)) + ) + element_type = substitute(sig_getitem.return_annotation, getitem_subs) + + # Step 2: Infer type of call_func(f, element) -> U + call_subs = unify(sig_call, sig_call.bind(nested_type(f), element_type)) + composed_type = substitute(sig_call.return_annotation, call_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind(nested_type(f), nested_type(seq), nested_type(index)), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(apply_to_sequence_element(f, seq, index)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,seq,index", + [ + # Basic transformations + (lambda x: x * 2, [1, 2, 3], 1), + (lambda x: x.upper(), ["hello", "world"], 0), + (lambda x: len(x), ["a", "bb", "ccc"], 2), + (lambda x: x + 1, [10, 20, 30], 0), + ], +) +def test_map_and_get(f, seq, index): + """Test type 
inference through mapping and element retrieval.""" + sig_map = inspect.signature(map_sequence) + sig_getitem = inspect.signature(sequence_getitem) + sig_composed = inspect.signature(map_and_get) + + # Step 1: Infer type of map_sequence(f, seq) -> Sequence[U] + map_subs = unify(sig_map, sig_map.bind(nested_type(f), nested_type(seq))) + mapped_type = substitute(sig_map.return_annotation, map_subs) + + # Step 2: Infer type of sequence_getitem(mapped_seq, index) -> U + getitem_subs = unify(sig_getitem, sig_getitem.bind(mapped_type, nested_type(index))) + composed_type = substitute(sig_getitem.return_annotation, getitem_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind(nested_type(f), nested_type(seq), nested_type(index)), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(map_and_get(f, seq, index)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,g,value", + [ + # Basic function compositions + (lambda x: x * 2, lambda x: x + 1, 5), + (lambda x: str(x), lambda x: x.upper(), 42), + (lambda x: len(x), lambda x: x * 2, "hello"), + (lambda x: [x], lambda x: x[0], 1), + ], +) +def test_compose_and_apply(f, g, value): + """Test type inference through function composition and application.""" + sig_compose = inspect.signature(compose_mappings) + sig_call = inspect.signature(call_func) + sig_composed = inspect.signature(compose_and_apply) + + # Step 1: Infer type of compose_mappings(f, g) -> Callable[[T], V] + compose_subs = unify(sig_compose, sig_compose.bind(nested_type(f), nested_type(g))) + composed_func_type = substitute(sig_compose.return_annotation, compose_subs) + + # Step 2: Infer type of call_func(composed, value) -> V + call_subs = unify(sig_call, sig_call.bind(composed_func_type, nested_type(value))) + result_type = substitute(sig_call.return_annotation, call_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind(nested_type(f), nested_type(g), nested_type(value)), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(result_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(compose_and_apply(f, g, value)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,a,b,index", + [ + # Basic constructions and applications + (lambda x: x * 2, 1, 2, 0), + (lambda x: x * 2, 1, 2, 1), + (lambda x: x.upper(), "hello", "world", 0), + (lambda x: len(x), "a", "bb", 1), + ], +) +def test_construct_apply_and_get(f, a, b, index): + """Test type inference through construction, application, and retrieval.""" + sig_construct = inspect.signature(sequence_from_pair) + sig_apply = inspect.signature(apply_to_sequence_element) + sig_composed = inspect.signature(construct_apply_and_get) + + # Step 1: Infer type of sequence_from_pair(a, b) -> Sequence[T] + construct_subs = unify( + sig_construct, sig_construct.bind(nested_type(a), nested_type(b)) + ) + seq_type = substitute(sig_construct.return_annotation, construct_subs) + + # 
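(illustrative) for the first case (lambda x: x * 2, 1, 2, 0), seq_type + # here is Sequence[int]; a sketch of the expected binding, not asserted. + # 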
Step 2: Infer type of apply_to_sequence_element(f, seq, index) -> U + apply_subs = unify( + sig_apply, sig_apply.bind(nested_type(f), seq_type, nested_type(index)) + ) + composed_type = substitute(sig_apply.return_annotation, apply_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(f), nested_type(a), nested_type(b), nested_type(index) + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(construct_apply_and_get(f, a, b, index)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,seq,index1,index2", + [ + # Basic binary operations + (lambda x, y: x + y, [1, 2, 3], 0, 1), + (lambda x, y: x + y, [1, 2, 3], 1, 2), + (lambda x, y: x + y, ["hello", "world", "test"], 0, 2), + (lambda x, y: x * y, [2, 3, 4], 0, 2), + ], +) +def test_binary_on_sequence_elements(f, seq, index1, index2): + """Test type inference through sequence access and binary function application.""" + sig_getitem = inspect.signature(sequence_getitem) + sig_call_binary = inspect.signature(call_binary_func) + sig_composed = inspect.signature(binary_on_sequence_elements) + + # Step 1: Infer types of sequence_getitem calls + getitem1_subs = unify( + sig_getitem, sig_getitem.bind(nested_type(seq), nested_type(index1)) + ) + elem1_type = substitute(sig_getitem.return_annotation, getitem1_subs) + + getitem2_subs = unify( + sig_getitem, sig_getitem.bind(nested_type(seq), nested_type(index2)) + ) + elem2_type = substitute(sig_getitem.return_annotation, getitem2_subs) + + # Step 2: Infer type of call_binary_func(f, elem1, elem2) -> V + call_subs = unify( + sig_call_binary, sig_call_binary.bind(nested_type(f), elem1_type, elem2_type) + ) + composed_type = substitute(sig_call_binary.return_annotation, call_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(f), nested_type(seq), nested_type(index1), nested_type(index2) + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify( + nested_type(binary_on_sequence_elements(f, seq, index1, index2)), + direct_type, + ), + collections.abc.Mapping, + ) diff --git a/tests/test_semi_ring.py b/tests/test_semi_ring.py index 4f85b92a..0418b8c8 100644 --- a/tests/test_semi_ring.py +++ b/tests/test_semi_ring.py @@ -1,10 +1,14 @@ import random +import typing -from docs.source.semi_ring import Dict, Field, Let, Sum, eager, ops, opt +from docs.source.semi_ring import Dict, Field, Let, SemiRingDict, Sum, eager, ops, opt from effectful.ops.semantics import handler from effectful.ops.syntax import defop, trace from effectful.ops.types import Term +S = typing.TypeVar("S") +T = typing.TypeVar("T") + @trace def add1(v: int) -> int: @@ -12,32 +16,32 @@ def add1(v: int) -> int: def test_simple_sum(): - x = defop(str, name="x") - y = defop(object, name="y") + x = defop(SemiRingDict[str, T], name="x") + y = defop(SemiRingDict[str, T], name="y") k = defop(str, name="k") v = 
defop(int, name="v") with handler(eager): - e = Sum(Dict("a", 1, "b", 2), k, v, Dict("v", v())) + e = Sum(Dict(("a", 1), ("b", 2)), k, v, Dict(("v", v()))) assert e["v"] == 3 with handler(eager): - e = Let(Dict("a", 1, "b", 2), x, Field(x(), "b")) + e = Let(Dict(("a", 1), ("b", 2)), x, Field(x(), "b")) assert e == 2 with handler(eager): - e = Sum(Dict("a", 1, "b", 2), k, v, Dict(k(), add1(add1(v())))) + e = Sum(Dict(("a", 1), ("b", 2)), k, v, Dict((k(), add1(add1(v()))))) assert e["a"] == 3 assert e["b"] == 4 with handler(eager), handler(opt): e = Let( - Dict("a", 1, "b", 2), + Dict(("a", 1), ("b", 2)), x, Let( - Sum(x(), k, v, Dict(k(), add1(v()))), + Sum(x(), k, v, Dict((k(), add1(v())))), y, - Sum(y(), k, v, Dict(k(), add1(v()))), + Sum(y(), k, v, Dict((k(), add1(v())))), ), ) assert e["a"] == 3 @@ -45,19 +49,19 @@ def test_simple_sum(): def fusion_test(d): - x = defop(object, name="x") - y = defop(object, name="y") - k = defop(object, name="k") - v = defop(object, name="v") + x = defop(SemiRingDict[S, T], name="x") + y = defop(SemiRingDict[S, T], name="y") + k = defop(str, name="k") + v = defop(int, name="v") return ( Let( d, x, Let( - Sum(x(), k, v, Dict(k(), add1(v()))), + Sum(x(), k, v, Dict((k(), add1(v())))), y, - Sum(y(), k, v, Dict(k(), add1(v()))), + Sum(y(), k, v, Dict((k(), add1(v())))), ), ), (x, y, k, v), @@ -65,15 +69,11 @@ def fusion_test(d): def make_dict(n): - kv = [] - for i in range(n): - kv.append(i) - kv.append(random.randint(1, 10)) - return Dict(*kv) + return Dict(*[(i, random.randint(1, 10)) for i in range(n)]) def test_fusion_term(): - dvar = defop(object, name="dvar") + dvar = defop(SemiRingDict[str, T], name="dvar") with handler(eager), handler(opt): result, (x, _, k, v) = fusion_test(dvar())