diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py new file mode 100644 index 000000000..4b55674f9 --- /dev/null +++ b/effectful/handlers/jax/monoid.py @@ -0,0 +1,79 @@ +import jax + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import ( + CommutativeMonoid, + CommutativeMonoidWithZero, + Monoid, + Semilattice, + Streams, + distributes_over, + outer_stream, +) +from effectful.ops.semantics import evaluate, handler, typeof +from effectful.ops.syntax import deffn +from effectful.ops.types import Operation + + +@Operation.define +def cartesian_prod(x, y): + if x.ndim == 1: + x = x[:, None] + if y.ndim == 1: + y = y[:, None] + x, y = jnp.repeat(x, y.shape[0], axis=0), jnp.tile(y, (x.shape[0], 1)) + return jnp.hstack([x, y]) + + +Sum = CommutativeMonoid(kernel=jnp.add, identity=jnp.asarray(0)) +Product = CommutativeMonoidWithZero( + kernel=jnp.multiply, identity=jnp.asarray(1), zero=jnp.asarray(0) +) +Min = Semilattice(kernel=jnp.minimum, identity=jnp.asarray(float("inf"))) +Max = Semilattice(kernel=jnp.maximum, identity=jnp.asarray(float("-inf"))) +LogSumExp = CommutativeMonoid(kernel=jnp.logaddexp, identity=jnp.asarray(float("-inf"))) +CartesianProd = Monoid(kernel=cartesian_prod, identity=jnp.array([])) + +distributes_over.register(Max.plus, Min.plus) +distributes_over.register(Min.plus, Max.plus) +distributes_over.register(Sum.plus, Min.plus) +distributes_over.register(Sum.plus, Max.plus) +distributes_over.register(Product.plus, Sum.plus) +distributes_over.register(Sum.plus, LogSumExp.plus) + +ARRAY_REDUCE = { + Sum.plus: jnp.sum, + Product.plus: jnp.prod, + Min.plus: jnp.min, + Max.plus: jnp.max, + LogSumExp.plus: logsumexp, +} + + +@Monoid.reduce.register(jax.Array) +def _reduce_array(self, body: jax.Array, streams: Streams): + reductor = ARRAY_REDUCE[self.plus] + index = 
Operation.define(jax.Array) + + if not streams: + return self.identity + + # find and reduce an array stream + for stream_key, stream_body, streams_tail in outer_stream(streams): + if typeof(stream_body) != jax.Array: + continue + + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + (eval_body, eval_streams_tail) = evaluate(body), evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + + reduce_tail = ( + self.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + + return self._reduce_object(body, streams) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 58a10ba3d..5efebd22a 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,24 +4,32 @@ import numbers import typing from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, defdata, + deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, + _CustomSingleDispatchCallable, +) # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. 
jax.Array) are not subtypes of Iterable @@ -36,17 +44,21 @@ ) -def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: - """Determine an order to evaluate the streams based on their dependencies""" +def outer_stream( + streams: Streams, +) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: + """Returns the streams that can be ordered outermost in the loop nest as + well as the remaining streams in the nest. + + """ stream_vars = set(streams.keys()) - dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} - topo = TopologicalSorter(dependencies) + pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(pred) topo.prepare() - while topo.is_active(): - node_group = topo.get_ready() - for op in sorted(node_group): - yield (op, streams[op]) - topo.done(*node_group) + return ( + (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) + for op in topo.get_ready() + ) class Monoid[T]: @@ -114,58 +126,65 @@ def _(self, *args): return result @Operation.define - @functools.singledispatchmethod + @_CustomSingleDispatchCallable # type: ignore[arg-type] def reduce[A, B, U: Body]( + dispatch, self, + /, body: Annotated[U, Scoped[A | B]], streams: Annotated[Streams, Scoped[A]], ) -> Annotated[U, Scoped[B]]: - if callable(body): - return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) + return dispatch(typeof(body))(self, body, streams) # type: ignore[operator] - def generator(loop_order) -> Iterator[Interpretation]: - if len(loop_order) == 0: - return + @reduce.register(object) # type: ignore[attr-defined] + def _reduce_object(self, body: object, streams: Streams): + if not streams: + return self.identity - stream_key = loop_order[0][0] - stream_values = evaluate(streams[stream_key]) - stream_values_iter = iter(stream_values) # type: ignore[arg-type] + # find and reduce a ground stream + for stream_key, stream_body, streams_tail in outer_stream(streams): + if 
isinstance(stream_body, Term): + continue - # If we try to iterate and get a term instead of a real - # iterator, give up + stream_values_iter = iter(stream_body) + + # if we iterate and get a term instead of a real iterator, skip if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: - raise NotHandled - - if len(loop_order) == 1: - for val in stream_values_iter: - yield {stream_key: functools.partial(lambda v: v, val)} - else: - for val in stream_values_iter: - intp: Interpretation = { - stream_key: functools.partial(lambda v: v, val) - } - with handler(intp): - for intp2 in generator(loop_order[1:]): - yield coproduct(intp, intp2) - - loop_order = list(order_streams(streams)) - try: - return self.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - except NotHandled: - return typing.cast(U, defdata(self.reduce, body, streams)) + continue - @reduce.register # type: ignore[attr-defined] - def _(self, body: Mapping, streams): + new_reduces = [] + for stream_val in stream_values_iter: + with handler({stream_key: deffn(stream_val)}): + eval_args = evaluate((body, streams_tail)) + assert isinstance(eval_args, tuple) + new_reduces.append(self.reduce(*eval_args)) + + return self.plus(*new_reduces) + + return defdata(self.reduce, body, streams) + + @reduce.register(Callable) # type: ignore[attr-defined] + def _reduce_callable(self, body: Callable, streams): + if isinstance(body, Term): + return defdata(self.reduce, body, streams) + return lambda *a, **k: self.reduce(body(*a, **k), streams) + + @reduce.register(Mapping) # type: ignore[attr-defined] + def _reduce_mapping(self, body: Mapping, streams): + if isinstance(body, Term): + return defdata(self.reduce, body, streams) return {k: self.reduce(v, streams) for (k, v) in body.items()} - @reduce.register # type: ignore[attr-defined] - def _(self, body: Sequence, streams): + @reduce.register(list | tuple) # type: ignore[attr-defined] + def _reduce_sequence(self, body: 
Sequence, streams): + if isinstance(body, Term): + return defdata(self.reduce, body, streams) return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] - @reduce.register # type: ignore[attr-defined] - def _(self, body: Generator, streams): + @reduce.register(Generator) # type: ignore[attr-defined] + def _reduce_generator(self, body: Generator, streams): + if isinstance(body, Term): + return defdata(self.reduce, body, streams) return (self.reduce(x, streams) for x in body) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 4532ae72d..634aacea6 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,11 +1,41 @@ +import itertools from collections.abc import Callable, Mapping, Sequence from typing import Any, get_args, get_origin +import jax from hypothesis import strategies as st -from effectful.ops.syntax import deffn +import effectful.handlers.jax.numpy as _jnp +from effectful.internals.runtime import interpreter +from effectful.ops.semantics import apply, evaluate +from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import Operation +_JAX_ARRAY_SHAPE = (3,) + + +def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: + return st.integers(min_value=0, max_value=2**31 - 1).map( + lambda seed: jax.random.uniform( + jax.random.PRNGKey(seed), _JAX_ARRAY_SHAPE, minval=0.5, maxval=1.5 + ) + ) + + +# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` +# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. 
+_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: _jnp.stack([a, a + 1.0]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1.0, 2.0 * a]), +] + +_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, +] + def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: """Strategy for the value an *0-arg* Operation should return.""" @@ -15,9 +45,11 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): return st.lists(st.integers()) + if annotation is jax.Array: + return _jax_array_value_strategy() raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int]" + "supported: int, list[int], jax.Array" ) @@ -64,8 +96,12 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: return st.sampled_from(_BINARY_INT_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) + if ret is jax.Array and param_types == (jax.Array,): + return st.sampled_from(_UNARY_JAX_FNS) + if ret is jax.Array and param_types == (jax.Array, jax.Array): + return st.sampled_from(_BINARY_JAX_FNS) raise NotImplementedError( - f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + f"No callable strategy for free var with return {ret!r}, params {param_types!r}" ) @@ -82,4 +118,119 @@ def random_interpretation( return intp -__all__ = ["random_interpretation"] +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. 
+ + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + + _op_cache: dict[int, Operation] = {} + + def _canonical_op(idx: int, op: Operation) -> Operation: + """Cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + if idx in _op_cache: + return _op_cache[idx] + + op = Operation.define(op, name=f"__cv_{idx}") + _op_cache[idx] = op + return op + + cx = _canonicalize(x, _canonical_op) + cy = _canonicalize(y, _canonical_op) + return syntactic_eq(cx, cy) + + +def _canonicalize(expr, _canonical_op): + counter = itertools.count() + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _BaseTerm, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set): + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. 
+ def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs): + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return _BaseTerm(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter), var) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +__all__ = ["random_interpretation", "define_vars", "syntactic_eq_alpha"] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py new file mode 100644 index 000000000..4efd0eb21 --- /dev/null +++ b/tests/test_handlers_jax_monoid.py @@ -0,0 +1,82 @@ +import jax +import pytest + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.monoid import LogSumExp, Max, Min, Product, Sum +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.types import NotHandled, Operation +from tests._monoid_helpers import define_vars, syntactic_eq_alpha + +MONOIDS = [ + pytest.param(Sum, jnp.sum, id="Sum"), + pytest.param(Product, jnp.prod, id="Product"), 
+ pytest.param(Min, jnp.min, id="Min"), + pytest.param(Max, jnp.max, id="Max"), + pytest.param(LogSumExp, logsumexp, id="LogSumExp"), +] + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_1(monoid, reductor): + (x, X, k) = define_vars("x", "X", "k", typ=jax.Array) + + lhs = monoid.reduce(x(), {x: X()}) + rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) + + assert syntactic_eq_alpha(lhs, rhs) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_2(monoid, reductor): + (x, y, X, Y, k1, k2) = define_vars("x", "y", "X", "Y", "k1", "k2", typ=jax.Array) + + @Operation.define + def f(_a: jax.Array, _b: jax.Array) -> jax.Array: + raise NotHandled + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) + rhs = reductor( + bind_dims( + reductor( + bind_dims(f(unbind_dims(X(), k1), unbind_dims(Y(), k2)), k2), + axis=0, + ), + k1, + ), + axis=0, + ) + + assert syntactic_eq_alpha(lhs, rhs) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_3(monoid, reductor): + """Stream `y` is `g(x())` — depends on the bound element of X. 
The reducer + must inline ``g`` along the same named dim used to unbind `x`.""" + (x, y, X, k1, k2) = define_vars("x", "y", "X", "k1", "k2", typ=jax.Array) + + @Operation.define + def f(_a: jax.Array, _b: jax.Array) -> jax.Array: + raise NotHandled + + @Operation.define + def g(_a: jax.Array) -> jax.Array: + raise NotHandled + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) + rhs = reductor( + bind_dims( + reductor( + bind_dims( + f(unbind_dims(X(), k1), unbind_dims(g(unbind_dims(X(), k1)), k2)), + k2, + ), + axis=0, + ), + k1, + ), + axis=0, + ) + + assert syntactic_eq_alpha(lhs, rhs) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index a22928cca..d073827c9 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,16 +1,11 @@ -import functools -import itertools - import pytest from hypothesis import given, settings from hypothesis import strategies as st -from effectful.internals.runtime import interpreter from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum -from effectful.ops.semantics import apply, evaluate, fvsof, handler -from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq +from effectful.ops.semantics import evaluate, fvsof, handler from effectful.ops.types import NotHandled, Operation -from tests._monoid_helpers import random_interpretation +from tests._monoid_helpers import define_vars, random_interpretation, syntactic_eq_alpha _INT = st.integers(min_value=-100, max_value=100) @@ -38,116 +33,6 @@ ] -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - - -@functools.cache -def _canonical_op(idx: int) -> Operation: - """Globally cached canonical Operation, keyed by encounter index. - - Cached so that two independent canonicalize runs return the same - Operation object for the same index — letting ``syntactic_eq`` - compare canonical forms by Operation identity. 
- """ - return Operation.define(int, name=f"__cv_{idx}") - - -def syntactic_eq_alpha(x, y) -> bool: - """Alpha-equivalence-respecting variant of ``syntactic_eq``. - - Walks each expression bottom-up with :func:`evaluate` and renames - every bound variable to a deterministic canonical Operation. The - canonical names are assigned by a counter that increments in - ``evaluate``'s natural traversal order, so two alpha-equivalent - expressions canonicalize to syntactically identical results. - """ - return syntactic_eq(_canonicalize(x), _canonicalize(y)) - - -def _canonicalize(expr): - counter = itertools.count() - - def _passthrough(op, *args, **kwargs): - return defdata(op, *args, **kwargs) - - def _substitute(arg, renaming): - """Apply a bound-variable renaming using ``evaluate`` for traversal.""" - if not renaming: - return arg - with interpreter({apply: _passthrough, **renaming}): - return evaluate(arg) - - def _bound_var_order(args, kwargs, bound_set): - """Return bound variables in deterministic encounter order.""" - seen: list[Operation] = [] - seen_set: set[Operation] = set() - - def _capture(op, *a, **kw): - if op in bound_set and op not in seen_set: - seen.append(op) - seen_set.add(op) - return defdata(op, *a, **kw) - - # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, - # etc. for free; the apply handler captures bound vars used as - # ``x()`` anywhere in the body. - with interpreter({apply: _capture}): - evaluate((args, kwargs)) - - # Binders bypass the apply handler. Pick them up with a small structural - # walk that visits dict keys too. 
- def _walk_bare(obj): - if isinstance(obj, Operation): - if obj in bound_set and obj not in seen_set: - seen.append(obj) - seen_set.add(obj) - elif isinstance(obj, dict): - for k, v in obj.items(): - _walk_bare(k) - _walk_bare(v) - elif isinstance(obj, list | set | frozenset | tuple): - for v in obj: - _walk_bare(v) - - _walk_bare((args, kwargs)) - return seen - - def _apply_canonical(op, *args, **kwargs): - bindings = op.__fvs_rule__(*args, **kwargs) - all_bound: set[Operation] = set().union( - *bindings.args, *bindings.kwargs.values() - ) - if not all_bound: - return defdata(op, *args, **kwargs) - - order = _bound_var_order(args, kwargs, all_bound) - canonical = {var: _canonical_op(next(counter)) for var in order} - assert all_bound <= set(order) - - new_args = tuple( - _substitute( - arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} - ) - for i, arg in enumerate(args) - ) - new_kwargs = { - k: _substitute( - v, - {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, - ) - for k, v in kwargs.items() - } - - # avoid the renaming from defdata - return _BaseTerm(op, *new_args, **new_kwargs) - - with interpreter({apply: _apply_canonical}): - return evaluate(expr) - - @pytest.mark.parametrize("monoid", ALL_MONOIDS) @given(a=_INT, b=_INT, c=_INT) @settings(max_examples=50, deadline=None) @@ -357,6 +242,52 @@ def test_plus_zero(monoid): _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_1(monoid): + x, y = define_vars("x", "y") + + lhs = monoid.reduce(x(), {x: []}) + rhs = monoid.identity + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_2(monoid): + x, y = define_vars("x", "y") + Y = define_vars("Y", typ=list[int]) + + lhs = monoid.reduce(x(), {y: Y(), x: []}) + rhs = monoid.identity + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, Y]) + + +@pytest.mark.parametrize("monoid", 
ALL_MONOIDS) +def test_partial_3(monoid): + x, y, a, b = define_vars("x", "y", "a", "b") + Y = define_vars("Y", typ=list[int]) + + lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, a, b, Y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_4(monoid): + x, y, a, b = define_vars("x", "y", "a", "b") + + @Operation.define + def f(_x: int) -> list[int]: + raise NotHandled + + lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, a, b, f]) + + @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_reduce_body_sequence(monoid): x = Operation.define(int, name="x")