diff --git a/effectful/internals/product_n.py b/effectful/internals/product_n.py index 4b8bd2a8..87a9c6a4 100644 --- a/effectful/internals/product_n.py +++ b/effectful/internals/product_n.py @@ -69,7 +69,7 @@ def map_structure(func, expr): else: return type(expr)(map_structure(func, tuple(expr.items()))) elif isinstance(expr, collections.abc.Sequence): - if isinstance(expr, str | bytes): + if isinstance(expr, str | bytes | range): return expr elif ( isinstance(expr, tuple) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 58a10ba3..575838c2 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,24 +4,26 @@ import numbers import typing from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet +from effectful.internals.runtime import interpreter from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, defdata, + deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Interpretation, NotHandled, Operation, Term +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -80,9 +82,13 @@ def plus[S: Body[T]](self, *args: S) -> S: def _plus[S](self, *args: S) -> S: return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) - @_plus.register(Sequence) + @_plus.register(tuple) def _(self, *args): - return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Generator) + def _(self, *args): + return (self.plus(*vs) for vs in zip(*args, strict=True)) @_plus.register(Mapping) def _(self, *args): @@ -161,8 +167,8 @@ def _(self, body: Mapping, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} @reduce.register # type: ignore[attr-defined] - def _(self, body: Sequence, streams): - return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + def _(self, body: tuple, streams): + return tuple(self.reduce(x, streams) for x in body) @reduce.register # type: ignore[attr-defined] def _(self, body: Generator, streams): @@ -252,12 +258,26 @@ def _arg_max[T]( return b if b[0] > a[0] else a # type: ignore +@Operation.define +def product[T]( + a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] +) -> Iterable[tuple[T, ...]]: + if isinstance(a, Term) or isinstance(b, Term): + raise NotHandled + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] + + Min = Semilattice(kernel=min, identity=float("inf")) Max = Semilattice(kernel=max, identity=float("-inf")) ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) +CartesianProduct = Monoid(kernel=product, identity=[()]) @dataclass @@ -545,11 +565,136 @@ def reduce(self, monoid, body, streams): return fwd() +def inner_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: + """Returns the streams that can be ordered innermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + for pred in fvsof(v) & stream_vars: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + topo.prepare() + return ( + ({k: v for (k, v) in streams.items() if k != op}, op, streams[op]) + for op in set(topo.get_ready()) | no_dependents + ) + + +def match_reduce(term: Term) -> tuple | None: + reduce_args = None + + def set_reduce_args(*args, **kwargs): + nonlocal reduce_args + reduce_args = args + + with interpreter({Monoid.reduce: set_reduce_args}): + term.op(*term.args, **term.kwargs) + return reduce_args + + +class ReduceDistributeCartesianProduct(ObjectInterpretation): + """Eliminates a reduce over a cartesian product. + ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) + This transform is also called inversion in the lifting + literature (e.g. [1]). + + More specifically, this transform implements the identity + reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2) + = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: v()}), S1), S2) + where × is the cartesian product and ⨂ distributes over ⨁. + + Note: This could be generalized to grouped inversion [2]. + + [1] Braz, Rd, Eyal Amir, and Dan Roth. "Lifted first-order + probabilistic inference." IJCAI. 2005. + [2] Taghipour, Nima, et al. "Completeness results for lifted + variable elimination." AISTATS. 2013. + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): + if not (isinstance(sum_body, Term)): + return fwd() + + # body is a product or multiplication of products + if distributes_over(sum_body.op, sum_monoid.plus): + prod_reduces = sum_body.args + else: + prod_reduces = [sum_body] + + products: list[tuple[Monoid, Callable, Operation, Term]] = [] + for prod_reduce in prod_reduces: + prod_args = match_reduce(prod_reduce) + if prod_args is None: + return fwd() + (prod_monoid, prod_body, prod_streams) = prod_args + if not ( + distributes_over(prod_monoid.plus, sum_monoid.plus) + and (len(products) == 0 or products[-1][0] == prod_monoid) + ): + return fwd() + + if len(prod_streams) > 1 or len(prod_streams) == 0: + return fwd() + (prod_op, prod_stream) = next(iter(prod_streams.items())) + products.append( + (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream) + ) + + assert len(products) > 0 + + for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): + if not ( + isinstance(cprod_term, Term) + and cprod_term.op == CartesianProduct.reduce + ): + continue + (cprod_body, cprod_streams) = cprod_term.args + + if not all( + prod_stream.op == cprod_op for (_, _, _, prod_stream) in products + ): + continue + + prod_op = Operation.define(products[0][2]) + inner_sum = sum_monoid.reduce( + Product.plus( + *(prod_body(prod_op()) for (_, prod_body, _, _) in products) + ), + {prod_op: cprod_body}, + ) + prod = prod_monoid.reduce(inner_sum, cprod_streams) + outer_sum = ( + sum_monoid.reduce(prod, outer_sum_streams) + if outer_sum_streams + else prod + ) + return outer_sum + + return fwd() + + NormalizeReduceIntp = functools.reduce( coproduct, typing.cast( list[Interpretation], - [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + [ + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + ], ), ) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f7678fd2..8fd62bcd 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -209,6 +209,7 @@ def evaluate[T]( @evaluate.register(object) @evaluate.register(str) @evaluate.register(bytes) +@evaluate.register(range) def _evaluate_object[T](expr: T, **kwargs) -> T: if dataclasses.is_dataclass(expr) and not isinstance(expr, type): return typing.cast( diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 40c1f4af..d24be974 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -488,7 +488,10 @@ def _instance_op(instance, *args, **kwargs): else: return default_result - instance_op = self.define(types.MethodType(_instance_op, instance)) + name = ("" if owner is None else f"{owner.__name__}_") + self.__name__ + instance_op = self.define( + types.MethodType(_instance_op, instance), name=name + ) instance.__dict__[self._name_on_instance] = instance_op return instance_op elif instance is not None: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 4532ae72..772f91ec 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -14,14 +14,14 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers()) + return st.lists(st.integers(), max_size=3) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " "supported: int, list[int]" ) -_UNARY_INT_FNS: list[Callable[[int], int]] = [ +_UNARY_NUM_FNS: list[Callable[[int], int]] = [ lambda x: x, lambda x: x + 1, lambda x: x - 1, @@ -30,7 +30,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: 3 * x + 1, ] -_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ +_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ lambda x, y: x + y, lambda x, y: x - y, lambda x, y: x * y, @@ -58,10 +58,10 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: if not params: return _value_strategy_for(ret).map(deffn) - if ret is int and param_types == (int,): - return st.sampled_from(_UNARY_INT_FNS) - if ret is int and param_types == (int, int): - return st.sampled_from(_BINARY_INT_FNS) + if ret in (int, float) and param_types == (int,): + return st.sampled_from(_UNARY_NUM_FNS) + if ret in (int, float) and param_types == (int, int): + return st.sampled_from(_BINARY_NUM_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) raise NotImplementedError( diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index b56fd7bb..9a298390 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -240,7 +240,7 @@ def test_agent_tool_names_are_valid_integration(): agent = _ToolNameAgent() template = agent.ask tools = template.tools - expected_helper_tool_name = f"self__{agent.helper.__name__}" + expected_helper_tool_name = "self__helper" assert tools assert expected_helper_tool_name in tools assert all(re.fullmatch(r"[a-zA-Z0-9_-]+", name) for name in tools) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index a22928cc..8dbd32b6 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -6,7 +6,15 @@ from hypothesis import strategies as st from effectful.internals.runtime import interpreter -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + NormalizeIntp, + Product, + Semilattice, + Sum, +) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq from effectful.ops.types import NotHandled, Operation @@ -70,14 +78,11 @@ def syntactic_eq_alpha(x, y) -> bool: def _canonicalize(expr): counter = itertools.count() - def _passthrough(op, *args, **kwargs): - return defdata(op, *args, **kwargs) - def _substitute(arg, renaming): """Apply a bound-variable renaming using ``evaluate`` for traversal.""" if not renaming: return arg - with interpreter({apply: _passthrough, **renaming}): + with interpreter({apply: _BaseTerm, **renaming}): return evaluate(arg) def _bound_var_order(args, kwargs, bound_set): @@ -121,7 +126,7 @@ def _apply_canonical(op, *args, **kwargs): *bindings.args, *bindings.kwargs.values() ) if not all_bound: - return defdata(op, *args, **kwargs) + return _BaseTerm(op, *args, **kwargs) order = _bound_var_order(args, kwargs, all_bound) canonical = {var: _canonical_op(next(counter)) for var in order} @@ -252,8 +257,8 @@ def test_plus_assoc_left(monoid): def test_plus_sequence(monoid): a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) _check_pair( - lhs=monoid.plus([a(), b()], [c(), d()]), - rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + lhs=monoid.plus((a(), b()), (c(), d())), + rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), free_vars=[a, b, c, d], ) @@ -368,8 +373,8 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(x())], {x: X()}) - rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + lhs = monoid.reduce((f(x()), g(x())), {x: X()}) + rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) @@ -385,11 +390,11 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) - rhs = [ + lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) + rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), - ] + ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) @@ -496,8 +501,6 @@ def g(_x: int) -> list[int]: Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), ) assert fvsof(bogus_rhs) != fvsof(lhs) - # Structural-only negative check: the normalizer correctly refused to apply - # the bogus factorization. assert not syntactic_eq_alpha(lhs, bogus_rhs) @@ -516,3 +519,79 @@ def f(_x: int, _y: int) -> int: Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +def test_reduce_lifted_1(): + a, i = define_vars("a", "i") + A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) + + @Operation.define + def f(_: int) -> float: + raise NotHandled + + term1 = Sum.reduce( + Product.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N()})}, + ) + term2 = Product.reduce(Sum.reduce(f(a()), {a: A_domain()}), {i: N()}) + _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + + +def test_reduce_cartesian_1(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + assert term1 == term2 + + +def test_reduce_cartesian_2(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + assert term1 == term2 + + +def test_reduce_lifted_2(): + """The worked example on page 396 of 'Lifted Variable Elimination: + Decoupling the Operators from the Constraint Language'. + + """ + a, i, s, t = define_vars("a", "i", "s", "t") + A, N, T = define_vars("A", "N", "T", typ=list[int]) + + @Operation.define + def A_domain(_i: int) -> list[int]: + raise NotHandled + + @Operation.define + def f1(_a: int, _s: int) -> float: + raise NotHandled + + @Operation.define + def f2(_t: int, _a: int) -> float: + raise NotHandled + + term1 = Sum.reduce( + Product.reduce(Product.plus(f1(a(), s()), f2(t(), a())), {a: A()}), + {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, + ) + + term2 = Sum.reduce( + Product.reduce( + Sum.reduce(Product.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + {i: N()}, + ), + {t: T()}, + ) + + _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 185b6132..1f5c4776 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -489,7 +489,6 @@ def _(self, x: bool) -> bool: ) assert isinstance(term_float, Term) - assert term_float.op.__name__ == "my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {}