From 7c2275653cc9f98630f9d429e6000b093df6088e Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 6 May 2026 16:21:46 -0400 Subject: [PATCH 1/3] Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors --- effectful/internals/disjoint_set.py | 99 +++++ effectful/ops/monoid.py | 556 +++++++++++++++++++++++++++ effectful/ops/syntax.py | 78 ++++ pyproject.toml | 1 + tests/_monoid_helpers.py | 85 ++++ tests/test_internals_disjoint_set.py | 124 ++++++ tests/test_ops_monoid.py | 518 +++++++++++++++++++++++++ 7 files changed, 1461 insertions(+) create mode 100644 effectful/internals/disjoint_set.py create mode 100644 effectful/ops/monoid.py create mode 100644 tests/_monoid_helpers.py create mode 100644 tests/test_internals_disjoint_set.py create mode 100644 tests/test_ops_monoid.py diff --git a/effectful/internals/disjoint_set.py b/effectful/internals/disjoint_set.py new file mode 100644 index 000000000..73b5c5c52 --- /dev/null +++ b/effectful/internals/disjoint_set.py @@ -0,0 +1,99 @@ +class DisjointSet: + """Disjoint Set Union (Union-Find) data structure. + + Maintains a collection of disjoint sets over the integers 0..n-1, + supporting near-constant-time union and find operations via + path compression and union by rank. + + The amortized time complexity per operation is O(α(n)), where α + is the inverse Ackermann function (effectively constant for any + practical n). + + Example: + >>> dsu = DisjointSet(5) + >>> dsu.union(0, 1) + True + >>> dsu.union(1, 2) + True + >>> dsu.find(0) == dsu.find(2) + True + >>> dsu.find(0) == dsu.find(3) + False + """ + + def __init__(self, n): + """Initialize n singleton sets: {0}, {1}, ..., {n-1}. + + Args: + n: The number of elements. Elements are labeled 0..n-1. 
+ """ + self.parent = list(range(n)) + self.rank = [0] * n + + def _validate(self, x): + if x < 0 or x >= len(self.parent): + raise IndexError(f"Element {x} out of bounds") + + def find(self, x): + """Return the representative (root) of the set containing x. + + Two elements belong to the same set if and only if they have + the same representative. Applies path compression: every node + traversed is re-parented directly to its grandparent, flattening + the tree to speed up future queries. + + Args: + x: The element to look up. + + Returns: + The root element of x's set. + """ + self._validate(x) + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] # path compression + x = self.parent[x] + return x + + def union(self, *elements): + """Merge the sets containing all given elements into one. + + Accepts any number of elements and unions them all together. + Uses union by rank: shallower trees are attached under the root + of the deeper one, keeping the combined tree shallow. + + Args: + *elements: Two or more elements to merge into a single set. + Calling with 0 or 1 elements is a no-op and returns False. + + Returns: + True if any merging occurred (i.e., at least two of the + elements were in different sets); False if all elements + were already in the same set or fewer than 2 were given. 
+ """ + if len(elements) < 2: + return False + + merged = False + first = elements[0] + + for y in elements[1:]: + if self._union_pair(first, y): + merged = True + + return merged + + def _union_pair(self, x, y): + rx = self.find(x) + ry = self.find(y) + + if rx == ry: + return False + + if self.rank[rx] < self.rank[ry]: + rx, ry = ry, rx + + self.parent[ry] = rx + if self.rank[rx] == self.rank[ry]: + self.rank[rx] += 1 + + return True diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py new file mode 100644 index 000000000..58a10ba3d --- /dev/null +++ b/effectful/ops/monoid.py @@ -0,0 +1,556 @@ +import collections.abc +import functools +import itertools +import numbers +import typing +from collections import Counter, defaultdict +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass +from graphlib import TopologicalSorter +from typing import Annotated, Any + +from effectful.internals.disjoint_set import DisjointSet +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.syntax import ( + ObjectInterpretation, + Scoped, + _NumberTerm, + defdata, + implements, + iter_, + syntactic_eq, + syntactic_hash, +) +from effectful.ops.types import Interpretation, NotHandled, Operation, Term + +# Note: The streams value type should be something like Iterable[T], but some of +# our target stream types (e.g. 
jax.Array) are not subtypes of Iterable +type Streams[T] = Mapping[Operation[[], T], Any] + +type Body[T] = ( + Iterable[T] + | Callable[..., Body[T]] + | T + | Mapping[Any, Body[T]] + | Interpretation[T, Body[T]] +) + + +def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: + """Determine an order to evaluate the streams based on their dependencies""" + stream_vars = set(streams.keys()) + dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(dependencies) + topo.prepare() + while topo.is_active(): + node_group = topo.get_ready() + for op in sorted(node_group): + yield (op, streams[op]) + topo.done(*node_group) + + +class Monoid[T]: + kernel: Operation[[T, T], T] + identity: T + + def __init__(self, kernel: Callable[[T, T], T], identity: T): + self.identity = identity + self.kernel = ( + kernel if isinstance(kernel, Operation) else Operation.define(kernel) + ) + + def __repr__(self): + return f"{type(self)}({self.kernel}, {self.identity})" + + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + """Monoid addition with broadcasting over common collection types, + callables, and interpretations. 
+ + """ + if not args: + return typing.cast(S, self.identity) + + if any(isinstance(x, Term) for x in args): + return typing.cast(S, defdata(self.plus, *args)) + + return self._plus(*args) + + @functools.singledispatchmethod + def _plus[S](self, *args: S) -> S: + return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) + + @_plus.register(Sequence) + def _(self, *args): + return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Mapping) + def _(self, *args): + if isinstance(args[0], Interpretation): + keys = args[0].keys() + + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + + b_keys = b.keys() + if not keys == b_keys: + raise ValueError( + f"Expected interpretation of {keys} but got {b_keys}" + ) + + result = {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + return result + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + result = {k: self.plus(*vs) for (k, vs) in all_values.items()} + return result + + @Operation.define + @functools.singledispatchmethod + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + if callable(body): + return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) + + def generator(loop_order) -> Iterator[Interpretation]: + if len(loop_order) == 0: + return + + stream_key = loop_order[0][0] + stream_values = evaluate(streams[stream_key]) + stream_values_iter = iter(stream_values) # type: ignore[arg-type] + + # If we try to iterate and get a term instead of a real + # iterator, give up + if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + raise NotHandled + + if len(loop_order) == 1: + for val in 
stream_values_iter: + yield {stream_key: functools.partial(lambda v: v, val)} + else: + for val in stream_values_iter: + intp: Interpretation = { + stream_key: functools.partial(lambda v: v, val) + } + with handler(intp): + for intp2 in generator(loop_order[1:]): + yield coproduct(intp, intp2) + + loop_order = list(order_streams(streams)) + try: + return self.plus( + *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) + ) + except NotHandled: + return typing.cast(U, defdata(self.reduce, body, streams)) + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Mapping, streams): + return {k: self.reduce(v, streams) for (k, v) in body.items()} + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Sequence, streams): + return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Generator, streams): + return (self.reduce(x, streams) for x in body) + + +class IdempotentMonoid[T](Monoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class CommutativeMonoid[T](Monoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class CommutativeMonoidWithZero[T](CommutativeMonoid[T]): + zero: T + + def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): + super().__init__(kernel, identity) + self.zero = zero + + def __repr__(self): + return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" + + 
@Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class Semilattice[T](IdempotentMonoid[T], CommutativeMonoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +@Operation.define +def _arg_min[T]( + a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] +) -> tuple[numbers.Number, T | None]: + if isinstance(a[0], Term) or isinstance(b[0], Term): + raise NotHandled + return b if b[0] < a[0] else a # type: ignore + + +@Operation.define +def _arg_max[T]( + a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] +) -> tuple[numbers.Number, T | None]: + if isinstance(a[0], Term) or isinstance(b[0], Term): + raise NotHandled + return b if b[0] > a[0] else a # type: ignore + + +Min = Semilattice(kernel=min, identity=float("inf")) +Max = Semilattice(kernel=max, identity=float("-inf")) +ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) +ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) +Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) +Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) + + +@dataclass +class _ExtensibleBinaryRelation[S, T]: + tuples: set[tuple[S, T]] + + def register(self, s: S, t: T) -> None: + self.tuples.add((s, t)) + + def __call__(self, s: S, t: T) -> bool: + return (s, t) in self.tuples + + +distributes_over = _ExtensibleBinaryRelation( + { + (Max.plus, Min.plus), + (Min.plus, Max.plus), + (Sum.plus, Min.plus), + (Sum.plus, Max.plus), + 
(Product.plus, Sum.plus), + } +) + + +class PlusEmpty(ObjectInterpretation): + """plus() = 0""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args: + return monoid.identity + return fwd() + + +class PlusSingle(ObjectInterpretation): + """plus(x) = x""" + + @implements(Monoid.plus) + def plus(self, _, *args): + if len(args) == 1: + return args[0] + return fwd() + + +class PlusIdentity(ObjectInterpretation): + """x₁ + ... + 0 + ... + xₙ = x₁ + ... + xₙ""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any(x is monoid.identity for x in args): + return monoid.plus(*(x for x in args if x is not monoid.identity)) + return fwd() + + +class PlusAssoc(ObjectInterpretation): + """x + (y + z) = (x + y) + z = x + y + z""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any(isinstance(x, Term) and x.op is monoid.plus for x in args): + flat_args = itertools.chain.from_iterable( + t.args if isinstance(t, Term) and t.op is monoid.plus else (t,) + for t in args + ) + assert len(args) > 0 + return monoid.plus(*flat_args) + return fwd() + + +class PlusDistr(ObjectInterpretation): + """(a + b) * (c + d) = a * c + a * d + b * c + b * d""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any( + isinstance(x, Term) and distributes_over(monoid.plus, x.op) for x in args + ): + non_terms = [] + + # group terms by head operation + by_head_op = defaultdict(list) + for t in args: + if isinstance(t, Term): + by_head_op[t.op].append(t) + else: + non_terms.append(t) + + # distribute over each group + progress = False + final_sum = [] + for op, terms in by_head_op.items(): + if ( + len(terms) > 1 + and distributes_over(monoid.plus, op) + and not distributes_over(op, monoid.plus) + ): + progress = True + term_args = (t.args for t in terms) + dist_terms = ( + monoid.plus(*args) for args in itertools.product(*term_args) + ) + final_sum.append(op(*dist_terms)) + else: + final_sum += terms + if progress: + return monoid.plus(*non_terms, 
*final_sum) + return fwd() + + +class PlusZero(ObjectInterpretation): + """x₁ * ... * 0 * ... * xₙ = 0""" + + @implements(CommutativeMonoidWithZero.plus) + def plus(self, monoid, *args): + if any(x is monoid.zero for x in args): + return monoid.zero + return fwd() + + +class PlusConsecutiveDups(ObjectInterpretation): + """x ⊕ x ⊕ y = x ⊕ y""" + + @implements(IdempotentMonoid.plus) + def plus(self, monoid, *args): + dedup_args = ( + args[i] + for i in range(len(args)) + if i == 0 or not syntactic_eq(args[i - 1], args[i]) + ) + return fwd(monoid, *dedup_args) + + +class PlusDups(ObjectInterpretation): + """x ⊕ y ⊕ x = x ⊕ y""" + + @dataclass + class _HashableTerm: + term: Term + + def __eq__(self, other): + return syntactic_eq(self, other) + + def __hash__(self): + return syntactic_hash(self) + + @implements(Semilattice.plus) + def plus(self, monoid, *args): + # elim dups + args_count = Counter(self._HashableTerm(t) for t in args) + if len(args_count) < len(args): + dedup_args = [] + for t in args: + ht = self._HashableTerm(t) + if ht in args_count: + dedup_args.append(t) + del args_count[ht] + return fwd(monoid, *dedup_args) + return fwd() + + +NormalizePlusIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), + ], + ), +) + + +class ReduceNoStreams(ObjectInterpretation): + """Implements the identity + reduce(R, ∅, body) = 0 + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, _, streams): + if len(streams) == 0: + return monoid.identity + return fwd() + + +class ReduceFusion(ObjectInterpretation): + """Implements the identity + reduce(R, S1, reduce(R, S2, body)) = reduce(R, S1 ∪ S2, body) + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op == monoid.reduce: + return monoid.reduce(body.args[0], streams | body.args[1]) + return fwd() 
+ + +class ReduceSplit(ObjectInterpretation): + """Implements the identity + reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op == monoid.plus: + return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) + return fwd() + + +class ReduceFactorization(ObjectInterpretation): + """ + Implements factorization of independent terms. + For example, when having two independent distributions, + we can rewrite their marginalization as: + ∫p(x)⋅q(y)dxdy => ∫p(x)dx ⋅ ∫q(y)dy + + More specifically, in terms of reduces we are performing: + reduce(R, (S₁ × ... × Sₖ) , A₁ * ... * Aₖ) + => reduce(R, S₁, A₁) * ... * reduce(R, Sₖ, Aₖ) + where free(Aᵢ) ∩ free(Aⱼ) ∩ S = ∅ + and free(Aᵢ) ∩ S ⊆ Sᵢ + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and distributes_over(body.op, monoid.plus): + stream_vars = set(streams.keys()) + factors = [(arg, fvsof(arg)) for arg in body.args] + stream_ids = {v: i for (i, v) in enumerate(stream_vars)} + ds = DisjointSet(len(streams)) + + # streams are in the same partition as their dependencies + for stream_var, stream_id in stream_ids.items(): + stream_body = streams[stream_var] + deps = sorted([stream_ids[v] for v in fvsof(stream_body) & stream_vars]) + ds.union(stream_id, *deps) + + # factors are in the same partition as their dependencies + for factor, factor_fvs in factors: + factor_streams = sorted( + [stream_ids[v] for v in (factor_fvs & stream_vars)] + ) + ds.union(*factor_streams) + + placed_streams = set() + new_reduces = [] + for stream_key in streams: + if stream_key in placed_streams: + continue + + partition = ds.find(stream_ids[stream_key]) + partition_streams = { + k: v + for (k, v) in streams.items() + if ds.find(stream_ids[k]) == partition + } + partition_stream_keys = set(partition_streams.keys()) + + 
partition_factors = [ + t for t in factors if (t[1] & partition_stream_keys) + ] + + assert all( + (t[1] & stream_vars) <= partition_stream_keys + for t in partition_factors + ), "partition contains all streams required by factor" + + partition_term = body.op(*(t[0] for t in partition_factors)) + new_reduces.append((partition_term, partition_streams)) + placed_streams |= partition_stream_keys + + constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)] + + if len(new_reduces) > 1: + result = body.op( + *constant_factors, *(monoid.reduce(*args) for args in new_reduces) + ) + return result + + return fwd() + + +NormalizeReduceIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + ), +) + +NormalizeIntp = coproduct(NormalizePlusIntp, NormalizeReduceIntp) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 764016752..8fb12598f 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -852,6 +852,84 @@ def _(x: object, other) -> bool: return x == other +@_CustomSingleDispatchCallable +def syntactic_hash(__dispatch: Callable[[type], Callable[[Any], int]], x) -> int: + """Structural hash compatible with :func:`syntactic_eq`. + + Guarantees that ``syntactic_eq(x, y)`` implies + ``syntactic_hash(x) == syntactic_hash(y)``. + + :param x: A term. + :returns: An integer hash. 
+ """ + if dataclasses.is_dataclass(x) and not isinstance(x, type): + return hash( + ( + "dataclass", + type(x), + syntactic_hash( + { + field.name: getattr(x, field.name) + for field in dataclasses.fields(x) + } + ), + ) + ) + else: + return __dispatch(type(x))(x) + + +@syntactic_hash.register +def _(x: Term) -> int: + return hash( + ( + "term", + x.op, + len(x.args), + tuple(syntactic_hash(a) for a in x.args), + # sort kwargs so order doesn't affect the hash + tuple((k, syntactic_hash(x.kwargs[k])) for k in sorted(x.kwargs)), + ) + ) + + +@syntactic_hash.register +def _(x: collections.abc.Mapping) -> int: + # XOR over (key_hash, value_hash) pairs — order-independent, + # matching the set-based comparison in syntactic_eq's Mapping branch. + acc = 0 + for k in x: + acc ^= hash((hash(k), syntactic_hash(x[k]))) + return hash(("mapping", acc)) + + +@syntactic_hash.register +def _(x: collections.abc.Sequence) -> int: + if ( + isinstance(x, tuple) + and hasattr(x, "_fields") + and all(hasattr(x, f) for f in x._fields) + ): + return hash( + ( + "namedtuple", + type(x), + tuple(syntactic_hash(getattr(x, f)) for f in x._fields), + ) + ) + else: + # Use the abstract Sequence tag (not type(x)) because syntactic_eq + # treats any two Sequences of equal length and elementwise-equal + # contents as equal — e.g. [1,2] and (1,2) compare equal. + return hash(("sequence", len(x), tuple(syntactic_hash(a) for a in x))) + + +@syntactic_hash.register(object) +@syntactic_hash.register(str | bytes) +def _(x: object) -> int: + return hash(x) + + class ObjectInterpretation[T, V](collections.abc.Mapping): """A helper superclass for defining an ``Interpretation`` of many :class:`~effectful.ops.types.Operation` instances with shared state or behavior. 
diff --git a/pyproject.toml b/pyproject.toml index d565403f2..685aaf55f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ test = [ "pytest-cov", "pytest-xdist", "pytest-benchmark", + "hypothesis", "mypy", "ruff", "nbval", diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py new file mode 100644 index 000000000..4532ae72d --- /dev/null +++ b/tests/_monoid_helpers.py @@ -0,0 +1,85 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import Any, get_args, get_origin + +from hypothesis import strategies as st + +from effectful.ops.syntax import deffn +from effectful.ops.types import Operation + + +def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: + """Strategy for the value an *0-arg* Operation should return.""" + if annotation is int: + return st.integers() + if annotation is float: + return st.floats(allow_nan=False) + if get_origin(annotation) is list and get_args(annotation) == (int,): + return st.lists(st.integers()) + raise NotImplementedError( + f"No value strategy for return annotation {annotation!r}; " + "supported: int, list[int]" + ) + + +_UNARY_INT_FNS: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, +] + +_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, +] + +_UNARY_LIST_FNS: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], +] + + +def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: + """Pick a strategy producing a callable suitable for binding `op` in an + interpretation. Inspects the operation's signature. 
+ """ + sig = op.__signature__ + params = list(sig.parameters.values()) + ret = sig.return_annotation + param_types = tuple(p.annotation for p in params) + + if not params: + return _value_strategy_for(ret).map(deffn) + if ret is int and param_types == (int,): + return st.sampled_from(_UNARY_INT_FNS) + if ret is int and param_types == (int, int): + return st.sampled_from(_BINARY_INT_FNS) + if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): + return st.sampled_from(_UNARY_LIST_FNS) + raise NotImplementedError( + f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + ) + + +@st.composite +def random_interpretation( + draw: st.DrawFn, free_vars: Sequence[Operation] +) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op in free_vars: + intp[op] = draw(_strategy_for_op(op)) + return intp + + +__all__ = ["random_interpretation"] diff --git a/tests/test_internals_disjoint_set.py b/tests/test_internals_disjoint_set.py new file mode 100644 index 000000000..808b8d25d --- /dev/null +++ b/tests/test_internals_disjoint_set.py @@ -0,0 +1,124 @@ +import random + +import pytest + +from effectful.internals.disjoint_set import DisjointSet + + +@pytest.fixture +def dsu(): + return DisjointSet(10) + + +def test_initial_state(dsu): + for i in range(10): + assert dsu.find(i) == i + + +def test_simple_union(dsu): + assert dsu.union(1, 2) is True + assert dsu.find(1) == dsu.find(2) + + +def test_union_idempotent(dsu): + dsu.union(1, 2) + assert dsu.union(1, 2) is False + + +def test_union_chain(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + assert dsu.find(1) == dsu.find(3) + + +def test_union_multiple_elements_all_connected(dsu): + dsu.union(1, 2, 3, 4, 5) + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} + assert len(roots) == 1 + + +def 
test_union_multiple_elements_partial_overlap(dsu): + dsu.union(1, 2) + dsu.union(3, 4) + dsu.union(2, 3, 5) + + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} + assert len(roots) == 1 + + +def test_union_multiple_elements_with_existing_connections(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4, 5, 6) + + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5, 6]} + assert len(roots) == 1 + + +def test_union_single_element(dsu): + assert dsu.union(1) is False + + +def test_union_no_elements(dsu): + assert dsu.union() is False + + +def test_union_self(dsu): + assert dsu.union(3, 3) is False + assert dsu.find(3) == 3 + + +def test_transitivity(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4) + assert dsu.find(1) == dsu.find(4) + + +def test_disjoint_sets_remain_separate(dsu): + dsu.union(1, 2) + dsu.union(3, 4) + assert dsu.find(1) != dsu.find(3) + + +def test_randomized_unions(): + n = 50 + dsu = DisjointSet(n) + + groups = [{i} for i in range(n)] + + def find_group(x): + for g in groups: + if x in g: + return g + + for _ in range(100): + elems = random.sample(range(n), random.randint(2, 5)) + dsu.union(*elems) + + # merge ground-truth groups + merged = set() + for e in elems: + merged |= find_group(e) + + groups = [g for g in groups if g.isdisjoint(merged)] + groups.append(merged) + + # verify structure matches ground truth + for g in groups: + roots = {dsu.find(x) for x in g} + assert len(roots) == 1 + + +def test_path_compression_effect(): + dsu = DisjointSet(6) + dsu.union(0, 1) + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4) + + # Trigger compression + root_before = dsu.find(4) + root_after = dsu.find(4) + + assert root_before == root_after diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py new file mode 100644 index 000000000..a22928cca --- /dev/null +++ b/tests/test_ops_monoid.py @@ -0,0 +1,518 @@ +import functools +import itertools + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as 
st + +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.semantics import apply, evaluate, fvsof, handler +from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq +from effectful.ops.types import NotHandled, Operation +from tests._monoid_helpers import random_interpretation + +_INT = st.integers(min_value=-100, max_value=100) + +ALL_MONOIDS = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +COMMUTATIVE = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +IDEMPOTENT = [ + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +WITH_ZERO = [ + pytest.param(Product, id="Product"), +] + + +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +@functools.cache +def _canonical_op(idx: int) -> Operation: + """Globally cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + return Operation.define(int, name=f"__cv_{idx}") + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. 
+ """ + return syntactic_eq(_canonicalize(x), _canonicalize(y)) + + +def _canonicalize(expr): + counter = itertools.count() + + def _passthrough(op, *args, **kwargs): + return defdata(op, *args, **kwargs) + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _passthrough, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set): + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. 
+ def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs): + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return defdata(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter)) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(a=_INT, b=_INT, c=_INT) +@settings(max_examples=50, deadline=None) +def test_associativity(monoid, a, b, c): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert left == right + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_identity(monoid, a): + assert monoid.plus(monoid.identity, a) == a + assert monoid.plus(a, monoid.identity) == a + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +@given(a=_INT, b=_INT) +@settings(max_examples=50, deadline=None) +def test_commutativity(monoid, a, b): + assert monoid.plus(a, b) == monoid.plus(b, a) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) 
+@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_idempotence(monoid, a): + assert monoid.plus(a, a) == a + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_zero_absorbs(monoid, a): + assert monoid.plus(monoid.zero, a) == monoid.zero + assert monoid.plus(a, monoid.zero) == monoid.zero + + +def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: + """Run structural + semantic checks on a TermPair.""" + with handler(NormalizeIntp): + norm = evaluate(lhs) + + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings(max_examples=max_examples, deadline=None) + def _check_semantics(intp): + with handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert lhs_val == rhs_val + + _check_semantics() + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid): + _check_pair(lhs=monoid.plus(), rhs=monoid.identity) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_single(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(x()), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_right(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(x(), monoid.identity), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_left(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(monoid.identity, x()), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_right(monoid): + x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus(x(), monoid.plus(y(), z())), + rhs=monoid.plus(x(), y(), z()), + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_left(monoid): 
+ x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus(monoid.plus(x(), y()), z()), + rhs=monoid.plus(x(), y(), z()), + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_sequence(monoid): + a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus([a(), b()], [c(), d()]), + rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + free_vars=[a, b, c, d], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_mapping(monoid): + a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus({"x": a(), "y": b()}, {"x": c(), "z": d()}), + rhs={"x": monoid.plus(a(), c()), "y": b(), "z": d()}, + free_vars=[a, b, c, d], + ) + + +def test_plus_distributes(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) + rhs = Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +def test_plus_distributes_constant(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) + rhs = Product.plus( + 5, + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +def test_plus_distributes_multiple(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Sum.plus( + Min.plus(a(), b()), + Min.plus(c(), d()), + Max.plus(a(), b()), + Max.plus(c(), d()), + ) + rhs = Sum.plus( + Min.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + Max.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + 
+@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_consecutive(monoid): + """``a, a, b → a, b`` — only consecutive duplicates collapse.""" + a, b = define_vars("a", "b") + lhs = monoid.plus(a(), a(), b()) + return _check_pair(lhs=lhs, rhs=monoid.plus(a(), b()), free_vars=[a, b]) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_non_consecutive(monoid): + """``a, b, a`` — Semilattice (Min/Max) collapses via commutative + PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" + a, b = define_vars("a", "b") + lhs = monoid.plus(a(), b(), a()) + if isinstance(monoid, Semilattice): + rhs = monoid.plus(a(), b()) + else: + rhs = monoid.plus(a(), b(), a()) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b]) + + +def test_plus_commutative_idempotent_long(): + """Long alternation collapses via commutative dedup (Min/Max only).""" + a, b = define_vars("a", "b") + lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) + _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), free_vars=[a, b]) + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +def test_plus_zero(monoid): + a = define_vars("a") + lhs_right = monoid.plus(a(), monoid.zero) + lhs_left = monoid.plus(monoid.zero, a()) + _check_pair(lhs=lhs_right, rhs=monoid.zero, free_vars=[a]) + _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence(monoid): + x = Operation.define(int, name="x") + X = Operation.define(list[int], name="X") + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce([f(x()), g(x())], {x: X()}) + rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence_2(monoid): + x, y = define_vars("x", "y") + X, Y = define_vars("X", "Y", typ=list[int]) + + 
@Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) + rhs = [ + monoid.reduce(f(x()), {x: X(), y: Y()}), + monoid.reduce(g(y()), {x: X(), y: Y()}), + ] + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_mapping(monoid): + x = Operation.define(int, name="x") + X = Operation.define(list[int], name="X") + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce({"a": f(x()), "b": g(x())}, {x: X()}) + rhs = { + "a": monoid.reduce(f(x()), {x: X()}), + "b": monoid.reduce(g(x()), {x: X()}), + } + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_no_streams(monoid): + a = define_vars("a") + lhs = monoid.reduce(a(), {}) + rhs = monoid.identity + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_reduce(monoid): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) + rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, f]) + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +def test_reduce_plus(monoid): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) + rhs = monoid.plus( + monoid.reduce(a(), {a: A(), b: B()}), + monoid.reduce(b(), {a: A(), b: B()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + + +def test_reduce_independent_1(): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) + 
rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + + +def test_reduce_independent_2(): + a, b, c = define_vars("a", "b", "c") + A, B, C = define_vars("A", "B", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +def test_reduce_independent_3_negative(): + """Stream `b` depends on `a` (b: g(a())), so the proposed factorization + is unsound — the normalizer must NOT apply it.""" + a, b, c = define_vars("a", "b", "c") + A, C = define_vars("A", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + @Operation.define + def g(_x: int) -> list[int]: + raise NotHandled + + with handler(NormalizeIntp): + lhs = Sum.reduce( + Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} + ) + bogus_rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), + ) + assert fvsof(bogus_rhs) != fvsof(lhs) + # Structural-only negative check: the normalizer correctly refused to apply + # the bogus factorization. 
+ assert not syntactic_eq_alpha(lhs, bogus_rhs) + + +def test_reduce_independent_4(): + a, b, c = define_vars("a", "b", "c") + A, B, C = define_vars("A", "B", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + 7, + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) From 1d38f0dbbf322d4f262f8c95e366f4318f731beb Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 7 May 2026 15:27:57 -0400 Subject: [PATCH 2/3] wip --- effectful/internals/product_n.py | 2 +- effectful/ops/monoid.py | 174 +++++++++++++++++++++++++++- effectful/ops/semantics.py | 1 + effectful/ops/types.py | 5 +- tests/_monoid_helpers.py | 12 +- tests/test_handlers_llm_provider.py | 2 +- tests/test_ops_monoid.py | 95 +++++++++++++-- tests/test_ops_syntax.py | 2 +- 8 files changed, 269 insertions(+), 24 deletions(-) diff --git a/effectful/internals/product_n.py b/effectful/internals/product_n.py index 4b8bd2a81..87a9c6a42 100644 --- a/effectful/internals/product_n.py +++ b/effectful/internals/product_n.py @@ -69,7 +69,7 @@ def map_structure(func, expr): else: return type(expr)(map_structure(func, tuple(expr.items()))) elif isinstance(expr, collections.abc.Sequence): - if isinstance(expr, str | bytes): + if isinstance(expr, str | bytes | range): return expr elif ( isinstance(expr, tuple) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 58a10ba3d..748eb9cf3 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,18 +10,20 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet +from effectful.internals.runtime import interpreter from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, 
defdata, + deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Interpretation, NotHandled, Operation, Term +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -80,9 +82,13 @@ def plus[S: Body[T]](self, *args: S) -> S: def _plus[S](self, *args: S) -> S: return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) - @_plus.register(Sequence) + @_plus.register(tuple) def _(self, *args): - return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Generator) + def _(self, *args): + return (self.plus(*vs) for vs in zip(*args, strict=True)) @_plus.register(Mapping) def _(self, *args): @@ -161,8 +167,8 @@ def _(self, body: Mapping, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} @reduce.register # type: ignore[attr-defined] - def _(self, body: Sequence, streams): - return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + def _(self, body: tuple, streams): + return tuple(self.reduce(x, streams) for x in body) @reduce.register # type: ignore[attr-defined] def _(self, body: Generator, streams): @@ -252,12 +258,26 @@ def _arg_max[T]( return b if b[0] > a[0] else a # type: ignore +@Operation.define +def product[T]( + a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] 
| T] +) -> Iterable[tuple[T, ...]]: + if isinstance(a, Term) or isinstance(b, Term): + raise NotHandled + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] + + Min = Semilattice(kernel=min, identity=float("inf")) Max = Semilattice(kernel=max, identity=float("-inf")) ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) +CartesianProduct = Monoid(kernel=product, identity=[()]) @dataclass @@ -545,11 +565,153 @@ def reduce(self, monoid, body, streams): return fwd() +def outer_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: + """Returns the streams that can be ordered outermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(pred) + topo.prepare() + return ( + (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) + for op in topo.get_ready() + ) + + +def inner_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: + """Returns the streams that can be ordered innermost in the loop nest as + well as the remaining streams in the nest. 
+ + """ + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + for pred in fvsof(v) & stream_vars: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + topo.prepare() + return ( + ({k: v for (k, v) in streams.items() if k != op}, op, streams[op]) + for op in set(topo.get_ready()) | no_dependents + ) + + +def match_reduce(term: Term) -> tuple | None: + reduce_args = None + + def set_reduce_args(*args, **kwargs): + nonlocal reduce_args + reduce_args = args + + with interpreter({Monoid.reduce: set_reduce_args}): + term.op(*term.args, **term.kwargs) + return reduce_args + + +class ReduceDistributeCartesianProduct(ObjectInterpretation): + """Eliminates a reduce over a cartesian product. + ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) + This transform is also called inversion in the lifting + literature (e.g. [1]). + + More specifically, this transform implements the identity + reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2) + = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: v()}), S1), S2) + where × is the cartesian product and ⨂ distributes over ⨁. + + Note: This could be generalized to grouped inversion [2]. + + [1] Braz, Rd, Eyal Amir, and Dan Roth. "Lifted first-order + probabilistic inference." IJCAI. 2005. + [2] Taghipour, Nima, et al. "Completeness results for lifted + variable elimination." AISTATS. 2013. 
+ """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): + if not (isinstance(sum_body, Term)): + return fwd() + + # body is a product or multiplication of products + if distributes_over(sum_body.op, sum_monoid.plus): + prod_reduces = sum_body.args + else: + prod_reduces = [sum_body] + + products: list[tuple[Monoid, Callable, Operation, Term]] = [] + for prod_reduce in prod_reduces: + prod_args = match_reduce(prod_reduce) + if prod_args is None: + return fwd() + (prod_monoid, prod_body, prod_streams) = prod_args + if not ( + distributes_over(prod_monoid.plus, sum_monoid.plus) + and (len(products) == 0 or products[-1][0] == prod_monoid) + ): + return fwd() + + if len(prod_streams) > 1 or len(prod_streams) == 0: + return fwd() + (prod_op, prod_stream) = next(iter(prod_streams.items())) + products.append( + (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream) + ) + + assert len(products) > 0 + + for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): + if not ( + isinstance(cprod_term, Term) + and cprod_term.op == CartesianProduct.reduce + ): + continue + (cprod_body, cprod_streams) = cprod_term.args + + if not all( + prod_stream.op == cprod_op for (_, _, _, prod_stream) in products + ): + continue + + prod_op = Operation.define(products[0][2]) + inner_sum = sum_monoid.reduce( + Product.plus( + *(prod_body(prod_op()) for (_, prod_body, _, _) in products) + ), + {prod_op: cprod_body}, + ) + prod = prod_monoid.reduce(inner_sum, cprod_streams) + outer_sum = ( + sum_monoid.reduce(prod, outer_sum_streams) + if outer_sum_streams + else prod + ) + return outer_sum + + return fwd() + + NormalizeReduceIntp = functools.reduce( coproduct, typing.cast( list[Interpretation], - [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + [ + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + ], ), ) diff --git 
a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f7678fd24..8fd62bcd5 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -209,6 +209,7 @@ def evaluate[T]( @evaluate.register(object) @evaluate.register(str) @evaluate.register(bytes) +@evaluate.register(range) def _evaluate_object[T](expr: T, **kwargs) -> T: if dataclasses.is_dataclass(expr) and not isinstance(expr, type): return typing.cast( diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 40c1f4af5..d24be9745 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -488,7 +488,10 @@ def _instance_op(instance, *args, **kwargs): else: return default_result - instance_op = self.define(types.MethodType(_instance_op, instance)) + name = ("" if owner is None else f"{owner.__name__}_") + self.__name__ + instance_op = self.define( + types.MethodType(_instance_op, instance), name=name + ) instance.__dict__[self._name_on_instance] = instance_op return instance_op elif instance is not None: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 4532ae72d..f6397053b 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -21,7 +21,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: ) -_UNARY_INT_FNS: list[Callable[[int], int]] = [ +_UNARY_NUM_FNS: list[Callable[[int], int]] = [ lambda x: x, lambda x: x + 1, lambda x: x - 1, @@ -30,7 +30,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: 3 * x + 1, ] -_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ +_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ lambda x, y: x + y, lambda x, y: x - y, lambda x, y: x * y, @@ -58,10 +58,10 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: if not params: return _value_strategy_for(ret).map(deffn) - if ret is int and param_types == (int,): - return st.sampled_from(_UNARY_INT_FNS) - if ret is int and param_types == (int, int): - return 
st.sampled_from(_BINARY_INT_FNS) + if ret in (int, float) and param_types == (int,): + return st.sampled_from(_UNARY_NUM_FNS) + if ret in (int, float) and param_types == (int, int): + return st.sampled_from(_BINARY_NUM_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) raise NotImplementedError( diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index b56fd7bbd..9a2983901 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -240,7 +240,7 @@ def test_agent_tool_names_are_valid_integration(): agent = _ToolNameAgent() template = agent.ask tools = template.tools - expected_helper_tool_name = f"self__{agent.helper.__name__}" + expected_helper_tool_name = "self__helper" assert tools assert expected_helper_tool_name in tools assert all(re.fullmatch(r"[a-zA-Z0-9_-]+", name) for name in tools) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index a22928cca..e1c1400ff 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -6,7 +6,15 @@ from hypothesis import strategies as st from effectful.internals.runtime import interpreter -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + NormalizeIntp, + Product, + Semilattice, + Sum, +) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq from effectful.ops.types import NotHandled, Operation @@ -70,14 +78,11 @@ def syntactic_eq_alpha(x, y) -> bool: def _canonicalize(expr): counter = itertools.count() - def _passthrough(op, *args, **kwargs): - return defdata(op, *args, **kwargs) - def _substitute(arg, renaming): """Apply a bound-variable renaming using ``evaluate`` for traversal.""" if not renaming: return arg - with interpreter({apply: _passthrough, **renaming}): + 
with interpreter({apply: _BaseTerm, **renaming}): return evaluate(arg) def _bound_var_order(args, kwargs, bound_set): @@ -121,7 +126,7 @@ def _apply_canonical(op, *args, **kwargs): *bindings.args, *bindings.kwargs.values() ) if not all_bound: - return defdata(op, *args, **kwargs) + return _BaseTerm(op, *args, **kwargs) order = _bound_var_order(args, kwargs, all_bound) canonical = {var: _canonical_op(next(counter)) for var in order} @@ -496,8 +501,6 @@ def g(_x: int) -> list[int]: Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), ) assert fvsof(bogus_rhs) != fvsof(lhs) - # Structural-only negative check: the normalizer correctly refused to apply - # the bogus factorization. assert not syntactic_eq_alpha(lhs, bogus_rhs) @@ -516,3 +519,79 @@ def f(_x: int, _y: int) -> int: Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +def test_reduce_lifted_1(): + a, i = define_vars("a", "i") + A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) + + @Operation.define + def f(_: int) -> float: + raise NotHandled + + term1 = Sum.reduce( + Product.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N()})}, + ) + term2 = Product.reduce(Sum.reduce(f(a()), {a: A_domain()}), {i: N()}) + _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + + +def test_reduce_cartesian_1(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + assert term1 == term2 + + +def test_reduce_cartesian_2(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + assert term1 == term2 + + +def test_reduce_lifted_2(): + 
"""The worked example on page 396 of 'Lifted Variable Elimination: + Decoupling the Operators from the Constraint Language'. + + """ + a, i, s, t = define_vars("a", "i", "s", "t") + A, N, T = define_vars("A", "N", "T", typ=list[int]) + + @Operation.define + def A_domain(_i: int) -> list[int]: + raise NotHandled + + @Operation.define + def f1(_a: int, _s: int) -> float: + raise NotHandled + + @Operation.define + def f2(_t: int, _a: int) -> float: + raise NotHandled + + term1 = Sum.reduce( + Product.reduce(Product.plus(f1(a(), s()), f2(t(), a())), {a: A()}), + {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, + ) + + term2 = Sum.reduce( + Product.reduce( + Sum.reduce(Product.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + {i: N()}, + ), + {t: T()}, + ) + + _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 185b6132e..af8935eca 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -489,7 +489,7 @@ def _(self, x: bool) -> bool: ) assert isinstance(term_float, Term) - assert term_float.op.__name__ == "my_singledispatch" + assert term_float.op.__name__ == "MyClass_my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {} From 92586c43e3276cc3a883764012c0b65b9b7377eb Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 7 May 2026 15:41:39 -0400 Subject: [PATCH 3/3] cleanup --- effectful/ops/monoid.py | 19 +------------------ tests/_monoid_helpers.py | 2 +- tests/test_ops_monoid.py | 14 +++++++------- tests/test_ops_syntax.py | 1 - 4 files changed, 9 insertions(+), 27 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 748eb9cf3..575838c29 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,7 +4,7 @@ import numbers import typing from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, 
Sequence +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any @@ -565,23 +565,6 @@ def reduce(self, monoid, body, streams): return fwd() -def outer_stream( - streams: dict[Operation, Expr], -) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: - """Returns the streams that can be ordered outermost in the loop nest as - well as the remaining streams in the nest. - - """ - stream_vars = set(streams.keys()) - pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} - topo = TopologicalSorter(pred) - topo.prepare() - return ( - (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) - for op in topo.get_ready() - ) - - def inner_stream( streams: dict[Operation, Expr], ) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index f6397053b..772f91ecf 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -14,7 +14,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers()) + return st.lists(st.integers(), max_size=3) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " "supported: int, list[int]" diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index e1c1400ff..8dbd32b6e 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -257,8 +257,8 @@ def test_plus_assoc_left(monoid): def test_plus_sequence(monoid): a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) _check_pair( - lhs=monoid.plus([a(), b()], [c(), d()]), - rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + lhs=monoid.plus((a(), b()), (c(), d())), + rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), 
free_vars=[a, b, c, d], ) @@ -373,8 +373,8 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(x())], {x: X()}) - rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + lhs = monoid.reduce((f(x()), g(x())), {x: X()}) + rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) @@ -390,11 +390,11 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) - rhs = [ + lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) + rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), - ] + ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index af8935eca..1f5c47763 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -489,7 +489,6 @@ def _(self, x: bool) -> bool: ) assert isinstance(term_float, Term) - assert term_float.op.__name__ == "MyClass_my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {}